""""""
from __future__ import annotations

import argparse
import warnings
from pathlib import Path
from typing import Dict, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset

from son_goku import SonGokuScheduler
from experiments.train_utils import MultiTaskTrainer, TaskSpec
from experiments.collection import base as collection_base


class MMIMDbDataset(Dataset):
    def __init__(self, split: str, tokenizer, max_length: int = 256, cache_dir: str | None = None):
        self.ds = load_dataset("mm_imdb", split=split, cache_dir=cache_dir)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]
        text = sample["text"]
        enc = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        enc = {k: v.squeeze(0) for k, v in enc.items()}

        genres = torch.tensor(sample["label"], dtype=torch.float32)
        year = int(sample.get("year", 2000))
        decade = max(0, min(9, (year // 10) - 191))  # compress into leq 10 bins
        title = sample.get("title", "")
        vowel = 1 if title[:1].lower() in "aeiou" else 0

        targets = {
            "genre": genres,
            "decade": torch.tensor(decade, dtype=torch.long),
            "vowel": torch.tensor(vowel, dtype=torch.long),
        }
        return enc, targets


class MMIMDbModel(nn.Module):
    def __init__(self, encoder_name: str = "distilbert-base-uncased", hidden_dim: int = 768, num_genres: int = 23):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.dropout = nn.Dropout(0.1)
        self.heads = nn.ModuleDict(
            {
                "genre": nn.Linear(hidden_dim, num_genres),
                "decade": nn.Linear(hidden_dim, 10),
                "vowel": nn.Linear(hidden_dim, 2),
            }
        )

    def encode(self, enc_inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        out = self.encoder(**enc_inputs)
        cls = out.last_hidden_state[:, 0]
        return self.dropout(cls)

    def shared_parameters(self):
        return self.encoder.parameters()


def multilabel_accuracy(logits: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Element-wise (Hamming) accuracy for multi-label predictions.

    Each logit becomes a hard 0/1 prediction by thresholding its sigmoid
    probability. The score is the fraction of individual label slots whose
    prediction equals ``target``.

    Args:
        logits: Raw, pre-sigmoid scores.
        target: 0/1 labels, compared element-wise with the predictions.
        threshold: Probability above which a label counts as predicted.

    Returns:
        A 0-dim float32 tensor holding the mean of the element-wise matches.
    """
    probabilities = torch.sigmoid(logits)
    hard_labels = probabilities.gt(threshold).float()
    matches = hard_labels.eq(target)
    return matches.float().mean()


def _collate_batch(batch):
    """Stack a list of ``(enc, targets)`` pairs into one ``(enc_batch, targets_batch)`` pair.

    This lives at module level, not nested in :func:`main`, so it can be
    pickled when ``DataLoader`` starts workers with the ``spawn`` start method
    (the default on Windows and macOS). A nested function cannot be pickled.

    Tensors under the same key are stacked along a new dim 0, so they must
    share a shape across the batch.
    """
    encs, targets = zip(*batch)
    enc_batch = {key: torch.stack([enc[key] for enc in encs]) for key in encs[0]}
    targets_batch = {key: torch.stack([tgt[key] for tgt in targets]) for key in targets[0]}
    return enc_batch, targets_batch


def _argmax_accuracy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Top-1 accuracy: fraction of rows whose argmax over dim 1 equals ``target``."""
    return (logits.argmax(dim=1) == target).float().mean()


def _make_task_forward(key: str):
    """Build a TaskSpec forward fn that runs the shared encoder, then head ``key``.

    The returned callable takes ``(model, batch, device)``. It moves the encoder
    inputs to ``device``, encodes them, and applies ``model.heads[key]``. It
    returns ``(logits, targets[key])`` with the targets also on ``device``.
    """

    def forward(m, batch, device):
        encs, targets = batch
        encs = {k: v.to(device) for k, v in encs.items()}
        logits = m.heads[key](m.encode(encs))
        return logits, targets[key].to(device)

    return forward


def main(args: argparse.Namespace) -> None:
    """Train the SON-GOKU multi-task model on MM-IMDb.

    Builds train/validation/test loaders, the shared-encoder model, and the
    SON-GOKU scheduler. Then it trains for ``args.epochs`` epochs and prints
    train, validation and test metrics after each epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # --data-root arrives from argparse as a plain str; wrap it so the `/` join below works.
    if args.data_root:
        data_root = Path(args.data_root)
    else:
        data_root = Path(collection_base.default_data_root()) / "mmimdb"
    cache_dir = str(data_root / "hf_cache")

    tokenizer = AutoTokenizer.from_pretrained(args.encoder)

    def make_loader(split: str, shuffle: bool) -> DataLoader:
        dataset = MMIMDbDataset(split, tokenizer, max_length=args.max_length, cache_dir=cache_dir)
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.num_workers,
            collate_fn=_collate_batch,
        )

    train_loader = make_loader("train", shuffle=True)
    val_loader = make_loader("validation", shuffle=False)
    test_loader = make_loader("test", shuffle=False)

    model = MMIMDbModel(encoder_name=args.encoder, hidden_dim=args.hidden_dim)
    # Move parameters before the optimizer is built (the order PyTorch recommends).
    # nn.Module.to is in-place, so this is harmless if the trainer moves the model again.
    model.to(device)

    tasks = (
        TaskSpec("genre", _make_task_forward("genre"), nn.BCEWithLogitsLoss(), multilabel_accuracy),
        TaskSpec("decade", _make_task_forward("decade"), nn.CrossEntropyLoss(), _argmax_accuracy),
        TaskSpec("title_initial", _make_task_forward("vowel"), nn.CrossEntropyLoss(), _argmax_accuracy),
    )

    # The scheduler works on gradients of the shared encoder parameters only.
    shared_dim = sum(p.numel() for p in model.shared_parameters())
    scheduler = SonGokuScheduler(
        num_tasks=len(tasks),
        grad_dim=shared_dim,
        refresh_period=args.refresh_period,
        tau_init=1.0,
        tau_target=0.3,
        warmup_steps=args.warmup_steps,
        anneal_rate=5e-4,
        sketch_dim=args.sketch_dim,
        random_state=args.seed,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    trainer = MultiTaskTrainer(model, tasks, scheduler, optimizer, device)

    for epoch in range(args.epochs):
        train_metrics = trainer.train_epoch(train_loader, epoch)
        val_metrics = trainer.evaluate(val_loader)
        # NOTE(review): the test split is reported every epoch; pick checkpoints on validation only.
        test_metrics = trainer.evaluate(test_loader)
        print(f"[Epoch {epoch}] train={train_metrics}  val={val_metrics}  test={test_metrics}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train SON-GOKU on MM-IMDb")
    parser.add_argument("--data-root", type=str, default=None)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--encoder", type=str, default="distilbert-base-uncased")
    parser.add_argument("--hidden-dim", type=int, default=768)
    parser.add_argument("--max-length", type=int, default=256)
    parser.add_argument("--refresh-period", type=int, default=16)
    parser.add_argument("--warmup-steps", type=int, default=200)
    parser.add_argument("--sketch-dim", type=int, default=64)
    parser.add_argument("--seed", type=int, default=0)
    main(parser.parse_args())
