import io, os, re, sys
import pandas as pd
from typing import List, Optional
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests

# Resolve the project layout relative to this file, then make the Kronos
# checkout importable regardless of the current working directory.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
PROJECT_ROOT = os.path.abspath(os.path.join(BASE_DIR, os.pardir))
KRONOS_PATH = os.path.join(PROJECT_ROOT, 'Kronos')
# Insert at the front of sys.path, but only if it is not already listed.
if not any(entry == KRONOS_PATH for entry in sys.path):
    sys.path.insert(0, KRONOS_PATH)

from model import Kronos, KronosTokenizer, KronosPredictor   # Kronos 仓库

# FastAPI application object; every route in this file is registered on it.
app = FastAPI(title="Kronos Web API", version="1.8")

# CORS is fully open: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged -- confirm this is intended before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# -- Optional: prefer locally stored models -- #
MODELS_DIR = os.path.join(PROJECT_ROOT, "models")


def _prefer_local(repo_name: str) -> str:
    """Return a local model directory for *repo_name* when one exists.

    The last ``/``-separated component of ``repo_name`` is looked up under
    ``MODELS_DIR``. If that directory exists its path is returned; otherwise
    ``repo_name`` is returned unchanged.
    """
    candidate = os.path.join(MODELS_DIR, repo_name.rsplit("/", 1)[-1])
    if os.path.isdir(candidate):
        return candidate
    return repo_name

# Model registry: size key -> model source, tokenizer source and context limit.
# Every source is passed through _prefer_local, so a directory under
# MODELS_DIR takes precedence over the repository name.
SPEC = {
    size: {
        "model": _prefer_local(model_repo),
        "tokenizer": _prefer_local(tokenizer_repo),
        "max_ctx": ctx,
    }
    for size, model_repo, tokenizer_repo, ctx in (
        ("mini", "NeoQuasar/Kronos-mini", "NeoQuasar/Kronos-Tokenizer-2k", 2048),
        ("small", "NeoQuasar/Kronos-small", "NeoQuasar/Kronos-Tokenizer-base", 512),
        ("base", "NeoQuasar/Kronos-base", "NeoQuasar/Kronos-Tokenizer-base", 512),
    )
}
# Single-slot cache holding the most recently loaded model/tokenizer pair.
MODEL_CACHE = dict.fromkeys(("name", "tokenizer", "model"))

# Response body shared by /api/predict and /api/predict_live.
class PredictResponse(BaseModel):
    model_name: str  # model source taken from SPEC[...]["model"]
    device: str  # device string echoed from the request
    lookback: int  # history length reported after clamping to the model context
    pred_len: int  # number of predicted steps
    x_timestamps: List[str]  # stringified timestamps of the history rows
    y_timestamps: List[str]  # stringified timestamps of the predicted steps
    # /api/predict: "close" values of the CSV rows that follow the lookback window.
    # /api/predict_live: "close" values of the fetched history.
    actual_close: Optional[List[float]] = None
    # Predicted "close" column (or the first prediction column if "close" is absent).
    pred_close: List[float]

def load_model_if_needed(size_key: str):
    """Return ``(tokenizer, model, max_ctx, spec)`` for the requested model size.

    The most recently loaded pair is kept in ``MODEL_CACHE`` and reused when
    the same model is requested again; any other size replaces the cached pair.

    Raises:
        HTTPException: 400 when ``size_key`` is not a key of ``SPEC``.
    """
    mspec = SPEC.get(size_key)
    if mspec is None:
        raise HTTPException(status_code=400, detail=f"Unsupported model size: {size_key}")

    already_loaded = MODEL_CACHE["model"] is not None and MODEL_CACHE["name"] == mspec["model"]
    if already_loaded:
        return MODEL_CACHE["tokenizer"], MODEL_CACHE["model"], mspec["max_ctx"], mspec

    # Cache miss: load the tokenizer first, then the model, then publish both.
    tokenizer = KronosTokenizer.from_pretrained(mspec["tokenizer"])
    model = Kronos.from_pretrained(mspec["model"])
    MODEL_CACHE.update(name=mspec["model"], tokenizer=tokenizer, model=model)
    return tokenizer, model, mspec["max_ctx"], mspec

def _normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
    df.columns = [c.strip().lower() for c in df.columns]
    return df

def _parse_csv(content: bytes) -> pd.DataFrame:
    try:
        return pd.read_csv(io.BytesIO(content))
    except Exception:
        return pd.read_csv(io.BytesIO(content), sep=None, engine="python")

def _ensure_time_column(df: pd.DataFrame) -> pd.DataFrame:
    if "timestamps" in df.columns:
        try:
            df["timestamps"] = pd.to_datetime(df["timestamps"], errors="raise")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Column 'timestamps' cannot be parsed: {e}")
    else:
        df["timestamps"] = pd.date_range("2000-01-01", periods=len(df), freq="H")
    return df

# ---------------- CSV upload prediction ----------------
@app.post("/api/predict", response_model=PredictResponse)
async def predict(
    file: UploadFile = File(...),
    model_size: str = Form("small"),
    device: str = Form("cpu"),
    lookback: int = Form(400),
    pred_len: int = Form(120),
):
    """Predict from an uploaded CSV and return the rows that follow for comparison.

    The first ``lookback`` rows (capped at the model's context size) are the
    model input; the next ``pred_len`` rows supply the target timestamps and
    the ``actual_close`` values.

    The CSV must contain ``open``, ``high``, ``low`` and ``close`` columns
    (case-insensitive). ``volume`` and ``amount`` are used when present. A
    missing ``timestamps`` column is replaced by a synthetic hourly series.

    Raises:
        HTTPException: 400 for invalid parameters, an empty file, missing
            columns, unparsable timestamps, an unknown model size or too few
            rows; 500 for any other failure.
    """
    try:
        # Cheap parameter checks come first so bad requests fail before any
        # parsing or model loading. fullmatch rejects a trailing newline that
        # ``re.match`` with ``$`` would let through.
        if device != "cpu" and not re.fullmatch(r"cuda:\d+", device):
            raise HTTPException(status_code=400, detail="device must be 'cpu' or like 'cuda:0'")
        if lookback < 1 or pred_len < 1:
            raise HTTPException(status_code=400, detail="lookback and pred_len must be positive integers")

        content = await file.read()
        if not content:
            raise HTTPException(status_code=400, detail="Empty file")
        df = _normalize_columns(_parse_csv(content))

        need = ["open", "high", "low", "close"]
        miss = [c for c in need if c not in df.columns]
        if miss:
            raise HTTPException(status_code=400, detail=f"Missing columns: {miss}. Required: {need}")

        df = _ensure_time_column(df)

        tokenizer, model, max_ctx, mspec = load_model_if_needed(model_size)
        lookback = min(lookback, max_ctx)
        if len(df) < lookback + pred_len:
            raise HTTPException(status_code=400, detail="Data too short for given lookback + pred_len")

        cols = need + [c for c in ("volume", "amount") if c in df.columns]

        # Positional slicing: correct whatever index the parsed frame carries.
        history = df.iloc[:lookback]
        horizon = df.iloc[lookback:lookback + pred_len]
        x_df = history[cols]
        x_ts = history["timestamps"]
        y_ts = horizon["timestamps"]

        predictor = KronosPredictor(model, tokenizer, device=device, max_context=max_ctx)
        pred_df = predictor.predict(
            df=x_df, x_timestamp=x_ts, y_timestamp=y_ts,
            pred_len=pred_len, T=1.0, top_p=0.9, sample_count=1
        )
        # Fall back to the first column if the predictor returns no "close".
        if "close" in pred_df.columns:
            pred_close = pred_df["close"].tolist()
        else:
            pred_close = pred_df.iloc[:, 0].tolist()
        actual_close = horizon["close"].tolist()

        return PredictResponse(
            model_name=mspec["model"], device=device,
            lookback=lookback, pred_len=pred_len,
            x_timestamps=x_ts.astype(str).tolist(),
            y_timestamps=y_ts.astype(str).tolist(),
            actual_close=actual_close, pred_close=pred_close,
        )
    except HTTPException:
        # Deliberate client/server errors raised above pass through untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {type(e).__name__}: {e}") from e

# ---------------- Warm-up (download / load the model) ----------------
@app.post("/api/warmup")
async def warmup(model_size: str = Form("small")):
    # Loading through load_model_if_needed fills MODEL_CACHE, so a later
    # prediction request for the same size reuses the loaded pair.
    *_, max_ctx, mspec = load_model_if_needed(model_size)
    summary = {
        "status": "ok",
        "model_size": model_size,
        "max_context": max_ctx,
        "model": mspec["model"],
    }
    return summary

@app.get("/api/health")
async def health():
    # Liveness probe: a constant payload, no model or network access.
    payload = {"status": "ok"}
    return payload

@app.get("/")
async def root():
    # Service index: the service name plus the paths of the main endpoints.
    index = {"service": "Kronos Web API"}
    index["health"] = "/api/health"
    index["predict"] = "/api/predict"
    index["warmup"] = "/api/warmup"
    return index

# ====================== 实时：交易所K线（Binance & OKX） ======================

# 统一频率（小写小时，避免 FutureWarning）
_FREQ = {
    "1m":"T","3m":"3T","5m":"5T","15m":"15T","30m":"30T",
    "1h":"h","2h":"2h","4h":"4h","6h":"6h","8h":"8h","12h":"12h",
    "1d":"D","3d":"3D","1w":"7D","1M":"MS"
}

def _future_index(last_ts: pd.Timestamp, interval: str, n: int) -> pd.DatetimeIndex:
    """Build the ``n`` timestamps that follow ``last_ts`` at the given interval.

    The first entry is ``last_ts`` plus one step of the pandas offset that
    ``_FREQ`` maps ``interval`` to; later entries advance by that same offset.

    Raises:
        HTTPException: 400 when ``interval`` is not a key of ``_FREQ``.
    """
    alias = _FREQ.get(interval)
    if alias is None:
        raise HTTPException(status_code=400, detail=f"Unsupported interval: {interval}")
    step = pd.tseries.frequencies.to_offset(alias)
    return pd.date_range(start=last_ts + step, periods=n, freq=step)

# ---------- Binance ----------
def _binance_klines(symbol: str, interval: str, limit: int) -> pd.DataFrame:
    """Fetch klines from the Binance REST API as a normalized DataFrame.

    Returns the columns ``timestamps, open, high, low, close, volume,
    amount``, where ``amount`` is Binance's ``quote_asset_volume`` and
    ``timestamps`` is derived from ``open_time`` (milliseconds).

    Raises:
        HTTPException: 502 when the payload is not a non-empty list.
        requests.HTTPError: when Binance answers with an error status.
    """
    response = requests.get(
        "https://api.binance.com/api/v3/klines",
        params={"symbol": symbol.upper(), "interval": interval, "limit": limit},
        timeout=15,
    )
    response.raise_for_status()
    payload = response.json()
    if not (isinstance(payload, list) and payload):
        raise HTTPException(status_code=502, detail=f"Binance response invalid: {payload}")

    raw = pd.DataFrame(payload, columns=[
        "open_time", "open", "high", "low", "close", "volume", "close_time",
        "quote_asset_volume", "trades", "taker_base", "taker_quote", "ignore",
    ])
    raw["timestamps"] = pd.to_datetime(raw["open_time"], unit="ms")
    # Convert the price/volume fields to numbers; unparsable values become NaN.
    numeric_fields = ("open", "high", "low", "close", "volume", "quote_asset_volume")
    raw = raw.assign(**{
        field: pd.to_numeric(raw[field], errors="coerce") for field in numeric_fields
    })
    renamed = raw.rename(columns={"quote_asset_volume": "amount"})
    return renamed[["timestamps", "open", "high", "low", "close", "volume", "amount"]]

# ---------- OKX ----------
_OKX_BAR = {
    "1m":"1m","3m":"3m","5m":"5m","15m":"15m","30m":"30m",
    "1h":"1H","2h":"2H","4h":"4H","6h":"6H","8h":"8H","12h":"12H",
    "1d":"1D","3d":"3D","1w":"1W","1M":"1M"
}
def _to_okx_inst(symbol: str) -> str:
    """把 'okbusdt'/'OKB-USDT'/'OKB/USDT' 等转为 'OKB-USDT'"""
    s = symbol.upper().replace('/', '').replace('_','').replace('-','')
    quotes = ["USDT","USDC","BTC","ETH"]
    for q in quotes:
        if s.endswith(q):
            base = s[:-len(q)]
            return f"{base}-{q}"
    return symbol.upper().replace('_','-').replace('/','-')

def _okx_klines(symbol: str, interval: str, limit: int) -> pd.DataFrame:
    """Fetch candles from the OKX REST API as a normalized DataFrame.

    Returns the columns ``timestamps, open, high, low, close, volume,
    amount`` in oldest-to-newest order, where ``volume`` is OKX's ``vol`` and
    ``amount`` is ``volCcyQuote``.

    Raises:
        HTTPException: 400 for an interval without an OKX bar mapping; 502
            when the payload is not a successful, non-empty response.
        requests.HTTPError: when OKX answers with an error status.
    """
    inst = _to_okx_inst(symbol)
    bar = _OKX_BAR.get(interval)
    if not bar:
        raise HTTPException(status_code=400, detail=f"Unsupported interval for OKX: {interval}")
    url = "https://www.okx.com/api/v5/market/candles"
    r = requests.get(url, params={"instId": inst, "bar": bar, "limit": limit}, timeout=15)
    r.raise_for_status()
    j = r.json()
    # Guard the shape before calling .get: a non-dict payload is also invalid.
    if not isinstance(j, dict) or j.get("code") != "0" or not j.get("data"):
        raise HTTPException(status_code=502, detail=f"OKX response invalid: {j}")
    rows = list(reversed(j["data"]))  # OKX lists newest first; flip to oldest first
    df = pd.DataFrame(rows, columns=[
        "ts", "open", "high", "low", "close", "vol", "volCcy", "volCcyQuote", "confirm"
    ])
    # Cast "ts" to a number before the epoch conversion: pandas deprecates
    # passing strings to to_datetime together with ``unit``.
    df["timestamps"] = pd.to_datetime(pd.to_numeric(df["ts"]), unit="ms")
    for c in ["open", "high", "low", "close", "vol", "volCcyQuote"]:
        df[c] = pd.to_numeric(df[c], errors="coerce")
    df = df.rename(columns={"vol": "volume", "volCcyQuote": "amount"})
    return df[["timestamps", "open", "high", "low", "close", "volume", "amount"]]

def _detect_exchange_by_symbol(symbol: str) -> str:
    """OKB* → okx，其它默认 binance"""
    s = symbol.upper().replace('-','').replace('/','').replace('_','')
    return "okx" if s.startswith("OKB") else "binance"

def _fetch_klines(exchange: str, symbol: str, interval: str, limit: int) -> pd.DataFrame:
    """Dispatch a kline request to the fetcher for *exchange*.

    The exchange name is matched case-insensitively and ignoring surrounding
    whitespace. Transport-level failures from ``requests`` (timeouts,
    connection errors, upstream error statuses) are reported as a 502 instead
    of escaping as an unhandled server error.

    Raises:
        HTTPException: 400 for an unknown exchange; 502 when the upstream
            request fails; whatever the selected fetcher raises itself.
    """
    fetchers = {"binance": _binance_klines, "okx": _okx_klines}
    name = (exchange or "").strip().lower()
    fetcher = fetchers.get(name)
    if fetcher is None:
        raise HTTPException(status_code=400, detail=f"Unsupported exchange: {exchange}")
    try:
        return fetcher(symbol, interval, limit)
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=502, detail=f"Request to {name} failed: {exc}"
        ) from exc

# ---------------- Live prediction (auto-detect exchange + auto-shrink lookback) ----------------
@app.get("/api/predict_live", response_model=PredictResponse)
def predict_live(
    symbol: str = "BTCUSDT",
    interval: str = "15m",
    model_size: str = "small",
    device: str = "cpu",
    lookback: int = 400,
    pred_len: int = 60,
    exchange: Optional[str] = None,   # optional; when given it overrides auto-detection
):
    """Fetch recent candles from an exchange and predict the next ``pred_len`` bars.

    The exchange is taken from ``exchange`` when supplied, otherwise guessed
    from the symbol. ``lookback`` is capped at the model's context size and
    then at the number of candles actually returned, so the reported
    ``lookback``, ``x_timestamps`` and ``actual_close`` all describe the
    window that was fed to the model. ``actual_close`` holds the historical
    closes of that window.

    Raises:
        HTTPException: 400 for invalid parameters, an unsupported interval,
            model size or exchange; 502 when the exchange returns no candles.
    """
    if lookback < 1 or pred_len < 1:
        raise HTTPException(status_code=400, detail="lookback and pred_len must be positive integers")
    # Reject unknown intervals before spending a network round trip.
    if interval not in _FREQ:
        raise HTTPException(status_code=400, detail=f"Unsupported interval: {interval}")

    ex = (exchange or _detect_exchange_by_symbol(symbol)).strip().lower()

    # Resolve the model first: this validates model_size early and provides
    # the context limit that bounds how much history is useful.
    tokenizer, model, max_ctx, mspec = load_model_if_needed(model_size)
    lookback = min(lookback, max_ctx)

    # Per-request candle limits: Binance=1000, OKX=300.
    max_limit = 1000 if ex == "binance" else 300
    need = min(max(lookback, 50), max_limit)

    df = _fetch_klines(ex, symbol, interval, limit=need)
    if len(df) < 1:
        raise HTTPException(status_code=502, detail=f"No candles returned from {ex}")

    # If fewer candles came back than requested, shrink lookback instead of failing.
    lookback = min(lookback, len(df))

    # NOTE(review): the newest candle returned by the exchange may still be
    # forming -- confirm whether it should be excluded from the history.
    hist = df.tail(lookback).copy()
    cols = ["open", "high", "low", "close"]
    cols += [c for c in ("volume", "amount") if c in hist.columns]
    x_df = hist[cols]
    x_ts = hist["timestamps"]

    y_idx = _future_index(x_ts.iloc[-1], interval, pred_len)
    y_ts = pd.Series(y_idx, name="timestamps")

    predictor = KronosPredictor(model, tokenizer, device=device, max_context=max_ctx)
    pred_df = predictor.predict(
        df=x_df,
        x_timestamp=x_ts,
        y_timestamp=y_ts,
        pred_len=pred_len, T=1.0, top_p=0.9, sample_count=1
    )
    # Fall back to the first column if the predictor returns no "close".
    if "close" in pred_df.columns:
        pred_close = pred_df["close"].tolist()
    else:
        pred_close = pred_df.iloc[:, 0].tolist()
    actual_close = hist["close"].tolist()

    return PredictResponse(
        model_name=mspec["model"], device=device,
        lookback=lookback, pred_len=pred_len,
        x_timestamps=x_ts.astype(str).tolist(),
        y_timestamps=y_ts.astype(str).tolist(),
        actual_close=actual_close,
        pred_close=pred_close,
    )
