feat
This commit is contained in:
70
web-model/index.py
Normal file
70
web-model/index.py
Normal file
@ -0,0 +1,70 @@
|
||||
# app.py
|
||||
import logging
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import yfinance as yf
|
||||
from flask import Flask, request, jsonify
|
||||
from flask_cors import CORS
|
||||
from datetime import datetime, timedelta
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from model import predict
|
||||
|
||||
import torch
|
||||
from stable_baselines3 import PPO
|
||||
from transformers import PatchTSTForClassification
|
||||
|
||||
# === Model Config ===
|
||||
CONTEXT_LENGTH = 48
|
||||
LABEL_NAMES = ['SELL', 'HOLD', 'BUY']
|
||||
|
||||
# ============ Logging ============
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ============ Flask App ============
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
|
||||
# ============ Feature Extraction ============
|
||||
def extract_features(df):
|
||||
import talib
|
||||
|
||||
df['return'] = df['Close'].pct_change()
|
||||
df['daily_momentum'] = df['Close'] / df['Open']
|
||||
df['range_efficiency'] = abs(df['Close'] - df['Open']) / (df['High'] - df['Low']).replace(0, np.nan)
|
||||
df['volume_momentum'] = df['Volume'] / df['Volume'].rolling(5).mean().replace(0, np.nan)
|
||||
df['adx'] = talib.ADX(df['High'], df['Low'], df['Close'], timeperiod=14) / 100.0
|
||||
df['rsi_scaled'] = talib.RSI(df['Close'], timeperiod=14) / 100.0
|
||||
return df.dropna().reset_index(drop=True)
|
||||
|
||||
|
||||
def extract_patch_features(model, X):
|
||||
model.eval()
|
||||
features = []
|
||||
with torch.no_grad():
|
||||
for i in range(0, len(X), 64):
|
||||
batch = torch.tensor(X[i:i+64], dtype=torch.float32)
|
||||
outputs = model.base_model(
|
||||
past_values=batch,
|
||||
output_hidden_states=True,
|
||||
return_dict=True
|
||||
)
|
||||
cls_tokens = outputs.hidden_states[-1][:, 0, :]
|
||||
features.append(cls_tokens)
|
||||
return torch.cat(features, dim=0).numpy()
|
||||
|
||||
@app.route('/predict', methods=['POST'])
|
||||
def predict_by_clicked_date_yfinance():
|
||||
data = request.get_json()
|
||||
symbol = data.get("symbol")
|
||||
clicked_date_str = data.get("clicked_date")
|
||||
result = predict(clicked_date_str, symbol)
|
||||
return jsonify(result)
|
||||
|
||||
# ============ Run App ============
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True, port=8000)
|
Reference in New Issue
Block a user