""""""
from __future__ import annotations

import argparse
from typing import Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader

from son_goku import SonGokuScheduler
from experiments.train_utils import MultiTaskTrainer, TaskSpec
from experiments.preprocessing import stocks as stocks_prep
from experiments.collection import base as collection_base


class StocksModel(nn.Module):
    """Shared Conv1d + bidirectional-LSTM encoder with four linear task heads.

    The encoder (``conv`` followed by ``lstm``) produces one feature vector of
    width ``2 * hidden_dim`` per sequence. Four independent linear heads read
    that vector: ``return_head`` (``output_dim`` outputs) and ``vol_head``,
    ``mean_head`` and ``recon_head`` (one output each).

    Args:
        input_dim: Number of input channels per timestep.
        hidden_dim: Hidden size of each LSTM direction.
        output_dim: Output width of ``return_head``.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128, output_dim: int = 4):
        super().__init__()
        conv_channels = 64
        # kernel_size=5 with padding=2 keeps the temporal length unchanged.
        conv_kwargs = {"kernel_size": 5, "padding": 2}

        # NOTE: submodules are created in a fixed order (conv, lstm, heads) so
        # that seeded parameter initialisation stays reproducible.
        self.conv = nn.Sequential(
            nn.Conv1d(input_dim, conv_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv1d(conv_channels, conv_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
        )
        self.lstm = nn.LSTM(
            input_size=conv_channels,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # Both LSTM directions are concatenated, hence twice the hidden size.
        feature_dim = hidden_dim * 2
        self.return_head = nn.Linear(feature_dim, output_dim)
        self.vol_head = nn.Linear(feature_dim, 1)
        self.mean_head = nn.Linear(feature_dim, 1)
        self.recon_head = nn.Linear(feature_dim, 1)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of sequences into one feature vector per sequence.

        Args:
            x: Input of shape ``(B, T, C)``.

        Returns:
            The LSTM output at the last timestep, shape ``(B, 2 * hidden_dim)``.
        """
        # Conv1d wants channels-first (B, C, T); the LSTM wants (B, T, 64).
        sequence = self.conv(x.permute(0, 2, 1)).permute(0, 2, 1)
        outputs, _state = self.lstm(sequence)
        return outputs[:, -1]

    def shared_parameters(self):
        """Return the encoder parameters (conv first, then LSTM) as a list."""
        return [*self.conv.parameters(), *self.lstm.parameters()]


def main(args: argparse.Namespace) -> None:
    """Train and evaluate the four-task stocks model with the SON-GOKU scheduler.

    Builds the dataloaders, infers the model dimensions from one training
    batch, defines four regression tasks on top of the shared encoder and then
    runs ``args.epochs`` epochs, printing train/val/test metrics after each.

    Args:
        args: Parsed command-line options. The attributes read are
            ``dataset``, ``data_root``, ``batch_size``, ``num_workers``,
            ``download``, ``hidden_dim``, ``refresh_period``,
            ``warmup_steps``, ``sketch_dim``, ``seed``, ``lr`` and ``epochs``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # An explicit --data-root wins; otherwise fall back to "<default root>/stocks".
    data_root = args.data_root or collection_base.default_data_root() / "stocks"
    train_loader, val_loader, test_loaders = stocks_prep.create_dataloaders(
        dataset_name=args.dataset,
        root=str(data_root),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        modality_first=False,
        download=args.download,
    )
    # Only the first entry under "timeseries" is evaluated; take it without
    # materialising the remaining entries.
    test_loader = next(iter(test_loaders["timeseries"]))

    # Infer input/output widths from the trailing dimension of one batch.
    sample_x, sample_y = next(iter(train_loader))
    input_dim = sample_x.shape[-1]
    output_dim = sample_y.shape[-1]

    model = StocksModel(
        input_dim=input_dim,
        hidden_dim=args.hidden_dim,
        output_dim=output_dim,
    ).to(device)

    # Each task forward takes (model, batch, device) and returns (pred, target).

    def fwd_return(m, batch, device):
        """Predict the dataset targets ``y`` directly."""
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        feats = m.forward_features(x)
        return m.return_head(feats), y

    def fwd_vol(m, batch, device):
        """Predict the input's recent volatility (self-supervised)."""
        x, _ = batch
        x = x.to(device)
        feats = m.forward_features(x)
        # Std over the last (up to) five timesteps per channel, then averaged
        # across channels -> (B, 1).
        target = x[:, -5:, :].std(dim=1).mean(dim=1, keepdim=True)
        return m.vol_head(feats), target

    def fwd_mean(m, batch, device):
        """Predict the mean of the targets along dim 1."""
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        feats = m.forward_features(x)
        target = y.mean(dim=1, keepdim=True)
        return m.mean_head(feats), target

    def fwd_recon(m, batch, device):
        """Reconstruct the last-timestep value of input channel 0."""
        x, _ = batch
        x = x.to(device)
        feats = m.forward_features(x)
        # Reconstruct last open price of first input stock.
        target = x[:, -1:, 0]  # (B, 1)
        # Keep pred as (B, 1). Squeezing it to (B,) would broadcast against the
        # (B, 1) target into a (B, B) matrix, so the loss and metric would be
        # averaged over every cross-sample pair instead of per sample.
        pred = m.recon_head(feats)
        return pred, target

    def mae(p, t):
        """Mean absolute error between prediction ``p`` and target ``t``."""
        return torch.mean(torch.abs(p - t))

    tasks = (
        TaskSpec("returns", fwd_return, nn.L1Loss(), mae),
        TaskSpec("volatility", fwd_vol, nn.L1Loss(), mae),
        TaskSpec("sector_mean", fwd_mean, nn.L1Loss(), mae),
        TaskSpec("open_recon", fwd_recon, nn.L1Loss(), mae),
    )

    # The scheduler is sized by the number of shared (encoder) parameters only.
    shared_dim = sum(p.numel() for p in model.shared_parameters())
    scheduler = SonGokuScheduler(
        num_tasks=len(tasks),
        grad_dim=shared_dim,
        refresh_period=args.refresh_period,
        tau_init=1.0,
        tau_target=0.3,
        warmup_steps=args.warmup_steps,
        anneal_rate=5e-4,
        sketch_dim=args.sketch_dim,
        random_state=args.seed,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    trainer = MultiTaskTrainer(model, tasks, scheduler, optimizer, device)

    for epoch in range(args.epochs):
        train_metrics = trainer.train_epoch(train_loader, epoch)
        val_metrics = trainer.evaluate(val_loader)
        test_metrics = trainer.evaluate(test_loader)
        print(f"[Epoch {epoch}] train={train_metrics}  val={val_metrics}  test={test_metrics}")


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run training.
    cli = argparse.ArgumentParser(description="Train SON-GOKU on STOCKS datasets")
    cli.add_argument(
        "--dataset",
        type=str,
        default="STOCKS-FandB",
        choices=["STOCKS-FandB", "STOCKS-HEALTH"],
    )
    # (flag, type, default) triples, registered in this order so that the
    # generated --help listing keeps its layout.
    typed_options = (
        ("--data-root", str, None),
        ("--batch-size", int, 32),
        ("--num-workers", int, 0),
        ("--epochs", int, 5),
        ("--lr", float, 1e-3),
        ("--hidden-dim", int, 128),
        ("--refresh-period", int, 16),
        ("--warmup-steps", int, 200),
        ("--sketch-dim", int, 64),
        ("--seed", int, 0),
    )
    for flag, value_type, default_value in typed_options:
        cli.add_argument(flag, type=value_type, default=default_value)
    cli.add_argument("--download", action="store_true")

    main(cli.parse_args())
