Files
SKRIPSI/web-model/index.py
reizha 831085d8b3 feat
2025-07-11 19:14:55 +07:00

71 lines
2.1 KiB
Python

# app.py
import logging
import pandas as pd
import numpy as np
import yfinance as yf
from flask import Flask, request, jsonify
from flask_cors import CORS
from datetime import datetime, timedelta
from sklearn.preprocessing import StandardScaler
from model import predict
import torch
from stable_baselines3 import PPO
from transformers import PatchTSTForClassification
# === Model Config ===
# Input window length. NOTE(review): not referenced by the code visible in
# this file; presumably the PatchTST context length used elsewhere — confirm.
CONTEXT_LENGTH = 48

# Human-readable class labels in index order:
# position 0 -> "SELL", 1 -> "HOLD", 2 -> "BUY".
LABEL_NAMES = [
    "SELL",
    "HOLD",
    "BUY",
]
# ============ Logging ============
# Configure root logging once at import time: INFO and above, each record
# prefixed with its timestamp and level name.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)
# ============ Flask App ============
# The WSGI application object; the route handlers in this module register on it.
app = Flask(import_name=__name__)

# Enable CORS on every route using flask_cors defaults, so a browser front end
# served from a different origin can call this API.
# NOTE(review): the defaults accept any origin — consider passing `origins=`
# before exposing this service publicly.
CORS(app=app)
# ============ Feature Extraction ============
def extract_features(df):
    """Return a copy of ``df`` with the engineered price/volume feature columns.

    The caller's DataFrame is not modified.

    Args:
        df: price frame with ``Open``, ``High``, ``Low``, ``Close`` and
            ``Volume`` columns, one row per bar (presumably in chronological
            order, which the rolling window and TA-Lib indicators assume —
            confirm with the data source).

    Returns:
        A new DataFrame containing the original columns plus:
          * ``return``           – percent change of ``Close`` vs. previous row
          * ``daily_momentum``   – ``Close / Open``
          * ``range_efficiency`` – ``|Close - Open| / (High - Low)``
          * ``volume_momentum``  – ``Volume`` / 5-row rolling mean of ``Volume``
          * ``adx``              – 14-period ADX divided by 100
          * ``rsi_scaled``       – 14-period RSI divided by 100
        Rows with a NaN in any column (indicator warm-up rows, zero
        denominators) are dropped and the index is reset to ``0..n-1``.
    """
    # Local import: TA-Lib is only required when this function is called.
    import talib

    # Work on a copy so adding feature columns doesn't mutate the caller's frame.
    df = df.copy()
    open_, high, low, close = df['Open'], df['High'], df['Low'], df['Close']

    df['return'] = close.pct_change()
    # Zero denominators become NaN (row dropped below) rather than +/-inf,
    # which dropna() would not remove.
    df['daily_momentum'] = close / open_.replace(0, np.nan)
    df['range_efficiency'] = (close - open_).abs() / (high - low).replace(0, np.nan)
    df['volume_momentum'] = df['Volume'] / df['Volume'].rolling(5).mean().replace(0, np.nan)
    # ADX and RSI are on a 0-100 scale; divide to bring them into 0-1.
    df['adx'] = talib.ADX(high, low, close, timeperiod=14) / 100.0
    df['rsi_scaled'] = talib.RSI(close, timeperiod=14) / 100.0
    return df.dropna().reset_index(drop=True)
def extract_patch_features(model, X, batch_size=64):
    """Run ``model.base_model`` over ``X`` in mini-batches and collect embeddings.

    Side effect: switches ``model`` to eval mode (and leaves it there).
    Gradients are disabled for the forward passes.

    Args:
        model: presumably a ``PatchTSTForClassification``; only ``.eval()``,
            ``.parameters()`` and
            ``.base_model(past_values=..., output_hidden_states=True, return_dict=True)``
            are used. Inputs are moved to the device of the model's first
            parameter (CPU if it has none).
        X: array-like sliced along its first axis into batches; presumably
            shaped (n_samples, context_length, n_channels) — TODO confirm.
        batch_size: number of samples per forward pass (must be >= 1).

    Returns:
        A NumPy array (host memory) formed by concatenating, along axis 0,
        ``hidden_states[-1][:, 0, :]`` from every batch.

    Raises:
        ValueError: if ``batch_size`` < 1.
        RuntimeError: (raised by ``torch.cat``) if ``X`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model.eval()
    # Build each batch on the model's device so a GPU-resident model works.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    features = []
    with torch.no_grad():
        for start in range(0, len(X), batch_size):
            batch = torch.tensor(
                X[start:start + batch_size], dtype=torch.float32, device=device
            )
            outputs = model.base_model(
                past_values=batch,
                output_hidden_states=True,
                return_dict=True,
            )
            # NOTE(review): in HF PatchTST the encoder hidden states appear to be
            # 4-D (batch, channels, patches[+CLS], d_model); if so, [:, 0, :]
            # selects channel 0, not the CLS token ([:, :, 0, :]), and
            # hidden_states[-1] may differ from last_hidden_state. Kept as-is
            # because downstream consumers may be trained on this exact layout —
            # verify against the training code before changing.
            cls_tokens = outputs.hidden_states[-1][:, 0, :]
            # Move to host memory per batch: .numpy() requires a CPU tensor.
            features.append(cls_tokens.cpu())
    return torch.cat(features, dim=0).numpy()
@app.route('/predict', methods=['POST'])
def predict_by_clicked_date_yfinance():
    """Handle ``POST /predict``: return the model prediction for a symbol/date.

    Expects a JSON object body like
    ``{"symbol": "...", "clicked_date": "..."}``. Both values are forwarded
    unchanged to ``predict(clicked_date, symbol)`` (imported from ``model``);
    the date string must be in whatever format that function expects.

    Responses:
        200: ``jsonify(result)`` of the prediction.
        400: ``{"error": ...}`` if the body is not a JSON object or either
             field is missing/empty.
        500: ``{"error": ...}`` if ``predict`` raises; the traceback is logged.
    """
    # silent=True: malformed or non-JSON bodies yield None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    symbol = data.get("symbol")
    clicked_date_str = data.get("clicked_date")
    if not symbol or not clicked_date_str:
        return jsonify({"error": "Both 'symbol' and 'clicked_date' are required"}), 400

    try:
        result = predict(clicked_date_str, symbol)
    except Exception:
        # Top-level request boundary: keep the traceback in the server log and
        # give the client a JSON error instead of Flask's HTML 500 page.
        logger.exception(
            "Prediction failed for symbol=%r clicked_date=%r", symbol, clicked_date_str
        )
        return jsonify({"error": "Prediction failed"}), 500
    return jsonify(result)
# ============ Run App ============
if __name__ == '__main__':
    import os

    # Debug mode turns on Werkzeug's interactive in-browser debugger (which can
    # execute arbitrary code) and the auto-reloader, so it is opt-in via
    # FLASK_DEBUG=1 instead of hard-coded on. PORT overrides the default 8000.
    debug = os.environ.get("FLASK_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}
    port = int(os.environ.get("PORT", "8000"))
    app.run(debug=debug, port=port)