"""Download daily stock prices with yfinance and cache windowed log-return tensors."""

from __future__ import annotations

import datetime
import hashlib
import os
from pathlib import Path
from typing import Dict, List

import torch

from . import base

try:
    import yfinance as yf
    _YF_AVAILABLE = True
except ImportError:
    _YF_AVAILABLE = False


# Per-dataset ticker configuration consumed by ``download_and_prepare``:
#   "stocks"        - every symbol whose price history is downloaded;
#   "input_stocks"  - symbols whose past log-returns form the input windows (X);
#   "output_stocks" - symbols whose next-step log-return is the target (Y);
#   "cache_name"    - file stem of the cached ``.pt`` file.
# Symbols for which the download returns no rows are dropped from the
# input/output lists at preparation time, so stale tickers are tolerated.
STOCK_CONFIGS = {
    "STOCKS-FandB": {
        "stocks": ['CAG', 'CMG', 'CPB', 'DPZ', 'DRI', 'GIS', 'HRL', 'HSY', 'K', 'KHC', 'LW', 'MCD', 'MDLZ', 'MKC', 'SBUX', 'SJM', 'TSN', 'YUM'],
        "input_stocks": ['CAG', 'CMG', 'CPB', 'DPZ', 'DRI', 'GIS', 'K', 'KHC', 'LW', 'MDLZ', 'MKC', 'SJM', 'TSN', 'YUM'],
        "output_stocks": ['HRL', 'HSY', 'MCD', 'SBUX'],
        "cache_name": "STOCKS-FandB"
    },
    # NOTE(review): "ABT" is listed in both input_stocks and output_stocks here,
    # whereas the STOCKS-FandB lists are disjoint -- confirm this is intended.
    "STOCKS-HEALTH": {
        "stocks": ["ABT", "ABBV", "ABMD", "A", "ALXN", "ALGN", "ABC", "AMGN", "ANTM", "BAX", "BDX", "BIO", "BIIB", "BSX", "BMY", "CAH", "CTLT", "CNC", "CERN", "CI", "COO", "CVS", "DHR", "DVA", "XRAY", "DXCM", "EW", "GILD", "HCA", "HSIC", "HOLX", "HUM", "IDXX", "ILMN", "INCY", "ISRG", "IQV", "JNJ", "LH", "LLY", "MCK", "MDT", "MRK", "MTD", "PKI", "PRGO", "PFE", "DGX", "REGN", "RMD", "STE", "SYK", "TFX", "TMO", "UNH", "UHS", "VAR", "VRTX", "VTRS", "WAT", "WST", "ZBH", "ZTS"],
        "input_stocks": ["ABT", "ABBV", "ABMD", "A", "ALXN", "ALGN", "ABC", "AMGN", "ANTM", "BAX", "BDX", "BIO", "BIIB", "BSX", "BMY", "CAH", "CTLT", "CNC", "CERN", "CI", "COO", "DHR", "DVA", "XRAY", "DXCM", "EW", "GILD", "HCA", "HSIC", "HOLX", "HUM", "IDXX", "ILMN", "INCY", "ISRG", "IQV", "JNJ", "LH", "LLY", "MDT", "MTD", "PKI", "PRGO", "PFE", "DGX", "REGN", "RMD", "STE", "SYK", "TMO", "UHS", "VAR", "VRTX", "VTRS", "WAT", "ZBH", "ZTS"],
        "output_stocks": ["MRK", "WST", "CVS", "MCK", "ABT", "UNH", "TFX"],
        "cache_name": "STOCKS-HEALTH"
    },
}


def _hash_list(items: List[str]) -> str:
    h = hashlib.sha1()
    for item in sorted(items):
        h.update(item.encode("utf-8"))
    return h.hexdigest()[:10]


def _fetch(symbol: str, start: datetime.datetime, end: datetime.datetime):
    """Return ``yf.Ticker(symbol).history`` for the range *start* to *end*."""
    return yf.Ticker(symbol).history(start=start, end=end)


def download_and_prepare(
    dataset_name: str = "STOCKS-FandB",
    root: str | Path | None = None,
    download: bool = True,
    start_date: datetime.datetime = datetime.datetime(2000, 6, 1),
    end_date: datetime.datetime = datetime.datetime(2021, 2, 28),
) -> Dict:
    """Download opening prices for a stock dataset and cache windowed log-returns.

    The opening price of every symbol in ``STOCK_CONFIGS[dataset_name]["stocks"]``
    is fetched with yfinance. Prices are aligned on the dates where every fetched
    symbol has a value, converted to log-returns and normalised. The returns are
    then cut into sliding windows of ``window_size`` steps. The result is saved
    with ``torch.save`` to ``<data_root>/cached_data/<cache_name>.pt`` as a dict
    with keys ``X``, ``Y``, ``valid_stocks``, ``input_stocks``,
    ``output_stocks``, ``window_size``, ``val_split`` and ``test_split``.

    Args:
        dataset_name: Key into ``STOCK_CONFIGS``.
        root: Dataset directory. Defaults to
            ``base.default_data_root() / "stocks" / dataset_name``.
        download: When False, nothing is fetched. Metadata is returned even
            if the cache file does not exist yet.
        start_date: Passed to yfinance as ``start``.
        end_date: Passed to yfinance as ``end``.

    Returns:
        ``base.build_metadata(dataset_name, data_root, {"cache": cache_file})``.

    Raises:
        ValueError: ``dataset_name`` is not a key of ``STOCK_CONFIGS``.
        ImportError: the cache must be built but yfinance is not installed.
        RuntimeError: no data was fetched, no input/output symbol has data, or
            too few aligned dates remain to build a single window.
    """
    if dataset_name not in STOCK_CONFIGS:
        raise ValueError(f"Unknown dataset {dataset_name}. Options: {list(STOCK_CONFIGS)}")

    cfg = STOCK_CONFIGS[dataset_name]
    data_root = base.ensure_dir(root or base.default_data_root() / "stocks" / dataset_name)
    cache_dir = base.ensure_dir(data_root / "cached_data")
    cache_file = cache_dir / f"{cfg['cache_name']}.pt"

    if cache_file.exists() or not download:
        return base.build_metadata(dataset_name, data_root, {"cache": cache_file})

    # yfinance is only needed to *build* the cache; checking it after the cache
    # lookup keeps an existing cache (or download=False) usable without it.
    if not _YF_AVAILABLE:
        raise ImportError("yfinance not installed. Please `pip install yfinance` to download stocks.")

    import pandas as pd

    frames = []
    valid = []
    for sym in cfg["stocks"]:
        df = _fetch(sym, start_date, end_date)
        if len(df) == 0:
            continue  # no history for this symbol/range: drop it
        df = df.reset_index()
        df.insert(0, "Symbol", sym)
        frames.append(df)
        valid.append(sym)

    if not frames:
        raise RuntimeError("No stock data fetched; check network/API limits.")

    df_all = pd.concat(frames, ignore_index=True)
    # Defensive: pivot rejects repeated (Date, Symbol) pairs.
    df_all = df_all.drop_duplicates(subset=["Date", "Symbol"], keep="last")
    # Date x Symbol table of opening prices (columns sorted by symbol). Only
    # dates on which every symbol has a price are kept. The previous flat
    # ``.view(-1, n_symbols)`` silently misaligned columns, or crashed, whenever
    # symbols covered different date ranges.
    opens = (
        df_all.pivot(index="Date", columns="Symbol", values="Open")
        .sort_index()
        .dropna(how="any")
    )

    symbol_to_idx = {s: i for i, s in enumerate(opens.columns)}
    input_stocks = [s for s in cfg["input_stocks"] if s in symbol_to_idx]
    output_stocks = [s for s in cfg["output_stocks"] if s in symbol_to_idx]
    if not input_stocks or not output_stocks:
        raise RuntimeError(
            f"No usable input/output stocks for {dataset_name}; fetched symbols: {valid}"
        )
    input_idx = [symbol_to_idx[s] for s in input_stocks]
    output_idx = [symbol_to_idx[s] for s in output_stocks]

    window_size = 500
    val_split = 3200
    test_split = 3700

    # Cast to the default dtype, as building the tensor from Python floats did.
    X = torch.tensor(opens.to_numpy(), dtype=torch.get_default_dtype())
    RX = torch.log(X[1:] / X[:-1])  # log-returns, shape (n_dates - 1, n_symbols)

    n_windows = RX.shape[0] - window_size
    if n_windows <= 0:
        raise RuntimeError(
            f"Only {X.shape[0]} aligned dates for {dataset_name}; "
            f"need more than {window_size + 1} to build a window."
        )

    # Normalise by statistics of the leading (window + validation-split) span.
    RX = RX / torch.std(RX[: window_size + val_split])
    Y = RX[window_size:, output_idx]
    Y = Y / torch.std(Y[:val_split])

    # Window i covers rows i .. i+window_size-1 of the input returns. Its target
    # Y[i] is row window_size + i of RX. unfold yields (n_windows + 1, n_in,
    # window_size). The last window has no target, so it is dropped, and the
    # result is transposed to (n_windows, window_size, n_in).
    X_seq = (
        RX[:, input_idx]
        .unfold(0, window_size, 1)[:n_windows]
        .transpose(1, 2)
        .contiguous()
    )

    torch.save(
        {
            "X": X_seq,
            "Y": Y,
            "valid_stocks": valid,
            "input_stocks": input_stocks,
            "output_stocks": output_stocks,
            "window_size": window_size,
            "val_split": val_split,
            "test_split": test_split,
        },
        cache_file,
    )

    return base.build_metadata(dataset_name, data_root, {"cache": cache_file})


# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = ["download_and_prepare", "STOCK_CONFIGS"]
