71 lines
2.1 KiB
Python
71 lines
2.1 KiB
Python
# app.py
|
|
import logging
|
|
import pandas as pd
|
|
import numpy as np
|
|
import yfinance as yf
|
|
from flask import Flask, request, jsonify
|
|
from flask_cors import CORS
|
|
from datetime import datetime, timedelta
|
|
from sklearn.preprocessing import StandardScaler
|
|
from model import predict
|
|
|
|
import torch
|
|
from stable_baselines3 import PPO
|
|
from transformers import PatchTSTForClassification
|
|
|
|
# === Model Config ===

# Length of the input window fed to the model; presumably the number of
# past time steps per sample -- confirm against model.py.
CONTEXT_LENGTH: int = 48

# Human-readable action labels, indexed by class id. The order presumably
# matches the classifier's output logits (0=SELL, 1=HOLD, 2=BUY) -- confirm.
LABEL_NAMES = ["SELL", "HOLD", "BUY"]

# ============ Logging ============

# Root-logger configuration: INFO and above, "time - LEVEL - message" lines.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# ============ Flask App ============

# WSGI application object for this service.
app = Flask(__name__)

# Attach flask-cors with its default options so browser clients on other
# origins can call the API. NOTE(review): flask-cors defaults are permissive
# (all routes, all origins) -- restrict before exposing publicly.
CORS(app=app)

# ============ Feature Extraction ============

def extract_features(df):
    """Append engineered price/volume features to an OHLCV DataFrame.

    Args:
        df: DataFrame with at least ``Open``, ``High``, ``Low``, ``Close`` and
            ``Volume`` columns, one row per bar (presumably yfinance bars --
            confirm against the caller).

    Returns:
        A new DataFrame holding the original columns plus:
          - ``return``: bar-over-bar percentage change of Close
          - ``daily_momentum``: Close / Open
          - ``range_efficiency``: |Close - Open| / (High - Low), NaN when High == Low
          - ``volume_momentum``: Volume / 5-bar rolling mean of Volume
          - ``adx``: 14-period TA-Lib ADX divided by 100
          - ``rsi_scaled``: 14-period TA-Lib RSI divided by 100
        Rows containing NaN or +/-inf in ANY column are dropped (e.g. the
        leading rows where the rolling window is not yet full) and the index
        is reset to 0..n-1.

    The caller's ``df`` is left unmodified.
    """
    # Deferred import: TA-Lib is only needed when features are computed.
    import talib

    # Work on a copy so the caller's frame is not mutated (and assigning
    # into a slice of another frame does not trigger SettingWithCopyWarning).
    df = df.copy()

    open_ = df['Open']
    close = df['Close']
    volume = df['Volume']

    df['return'] = close.pct_change()
    df['daily_momentum'] = close / open_
    # A zero-width bar (High == Low) has no meaningful efficiency -> NaN.
    df['range_efficiency'] = (close - open_).abs() / (df['High'] - df['Low']).replace(0, np.nan)
    df['volume_momentum'] = volume / volume.rolling(5).mean().replace(0, np.nan)

    # TA-Lib only accepts float64 ("double") input arrays.
    high64 = df['High'].astype('float64')
    low64 = df['Low'].astype('float64')
    close64 = close.astype('float64')
    df['adx'] = talib.ADX(high64, low64, close64, timeperiod=14) / 100.0
    df['rsi_scaled'] = talib.RSI(close64, timeperiod=14) / 100.0

    # Division by a zero Open or a zero previous Close yields +/-inf, which
    # dropna() alone would keep; treat it as missing so it never reaches
    # the model.
    df = df.replace([np.inf, -np.inf], np.nan)
    return df.dropna().reset_index(drop=True)

def extract_patch_features(model, X, batch_size=64):
    """Run ``model.base_model`` over ``X`` in mini-batches and collect embeddings.

    Args:
        model: torch module exposing ``base_model``, called with a Hugging
            Face PatchTST-style ``past_values=`` signature. It is switched to
            eval mode as a side effect.
        X: array-like or tensor of inputs, batched along axis 0. Presumably
            (n_samples, context_length, n_channels) -- confirm against caller.
        batch_size: number of samples per forward pass (default 64).

    Returns:
        numpy array that stacks ``hidden_states[-1][:, 0, :]`` of every batch
        along axis 0. For an empty ``X`` an empty float32 array of shape (0,)
        is returned.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    model.eval()

    # torch.cat([]) raises, so short-circuit empty input.
    if len(X) == 0:
        return np.empty((0,), dtype=np.float32)

    # Feed inputs on the same device as the weights so GPU-hosted models work.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    features = []
    with torch.no_grad():
        for start in range(0, len(X), batch_size):
            # as_tensor avoids the copy warning when X is already a tensor.
            batch = torch.as_tensor(
                X[start:start + batch_size], dtype=torch.float32, device=device
            )
            outputs = model.base_model(
                past_values=batch,
                output_hidden_states=True,
                return_dict=True,
            )
            # NOTE(review): in Hugging Face PatchTST the encoder hidden states
            # are 4-D (batch, channels, patches, d_model) with any CLS token at
            # [:, :, 0, :]; if so, [:, 0, :] selects channel 0 rather than a
            # CLS token. Indexing kept as-is to preserve the output shape --
            # confirm intent.
            cls_tokens = outputs.hidden_states[-1][:, 0, :]
            # Move to host memory so the final .numpy() also works on GPU.
            features.append(cls_tokens.cpu())

    return torch.cat(features, dim=0).numpy()

@app.route('/predict', methods=['POST'])
def predict_by_clicked_date_yfinance():
    """Return the model prediction for a symbol at a clicked date.

    Expects a JSON object body with non-empty ``"symbol"`` and
    ``"clicked_date"`` fields; both are forwarded unchanged to
    ``predict(clicked_date, symbol)`` (note: date first) and its result is
    returned as JSON.

    Responses:
        200: JSON-serialised result of ``predict``.
        400: ``{"error": ...}`` when the body is not a JSON object or a
             required field is missing/empty.
    """
    # silent=True: malformed or non-JSON bodies yield None instead of raising,
    # so the client gets a clear 400 rather than a 500 from data.get on None.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    missing = [field for field in ("symbol", "clicked_date") if not data.get(field)]
    if missing:
        return jsonify({"error": "missing required field(s): " + ", ".join(missing)}), 400

    symbol = data["symbol"]
    clicked_date_str = data["clicked_date"]
    logger.info("Predict request: symbol=%s clicked_date=%s", symbol, clicked_date_str)

    result = predict(clicked_date_str, symbol)
    return jsonify(result)

# ============ Run App ============

if __name__ == '__main__':
    import os

    # Development server only. Debug mode (Werkzeug interactive debugger +
    # reloader) stays on by default as before, but can now be disabled with
    # FLASK_DEBUG=0 -- the debugger allows arbitrary code execution and must
    # never be reachable from untrusted networks. PORT overrides the port.
    debug = os.environ.get('FLASK_DEBUG', '1').strip().lower() not in ('0', 'false', 'no', '')
    port = int(os.environ.get('PORT', '8000'))
    app.run(debug=debug, port=port)