""""""
from __future__ import annotations

import argparse
import random
from typing import Dict, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models, transforms
from torchvision.transforms import functional as F

from son_goku import SonGokuScheduler
from experiments.train_utils import MultiTaskTrainer, TaskSpec
from experiments.collection import base as collection_base


# ---------------------------- data ---------------------------- #

class CifarMultiTaskDataset(Dataset):
    """CIFAR-10 wrapper that yields one sample dict covering five tasks.

    Every item contains three normalized views of the same base image
    (``clean``, ``rotated``, ``corrupted``) together with five integer labels:

    * ``class_label``      -- the original CIFAR-10 label,
    * ``quadrant_label``   -- index (0-3) of the brightest image quadrant,
    * ``texture_label``    -- bin (0-7) of a gradient-variance score,
    * ``rotation_label``   -- index into ``[0, 90, 180, 270]`` degrees,
    * ``corruption_label`` -- index into ``self.corruption_transforms``.

    Args:
        train: Selects the train or test split of ``datasets.CIFAR10``.
        root: Directory handed to ``datasets.CIFAR10``.
        download: Forwarded to ``datasets.CIFAR10``.

    NOTE(review): ``self.train`` is stored but never consulted, so the
    rotation and corruption draws are random for the test split as well;
    evaluation numbers for those two tasks vary between passes.
    """

    def __init__(self, train: bool, root: str, download: bool = True):
        self.base = datasets.CIFAR10(root=root, train=train, download=download)
        self.train = train
        # Tensor conversion followed by channel-wise normalization with fixed
        # mean / std constants; applied to all three image views.
        self.clean_tf = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
            ]
        )
        # The position in this list doubles as the corruption class label.
        self.corruption_transforms = [
            lambda img: F.adjust_brightness(img, 0.5 + random.random()),
            lambda img: F.adjust_contrast(img, 0.5 + random.random()),
            lambda img: F.gaussian_blur(img, kernel_size=3),
            lambda img: F.posterize(img, bits=3),
            lambda img: F.solarize(img, threshold=128),
        ]

    def __len__(self):
        """Return the number of images in the wrapped CIFAR-10 split."""
        return len(self.base)

    def _quadrant_label(self, pil_img) -> int:
        """Return the index of the quadrant with the highest mean gray level.

        Quadrants are numbered 0 = top-left, 1 = top-right, 2 = bottom-left,
        3 = bottom-right. Ties resolve to the lowest index (``np.argmax``).
        """
        gray = np.array(pil_img.convert("L"))
        h, w = gray.shape
        mid_h, mid_w = h // 2, w // 2
        quads = [
            gray[:mid_h, :mid_w].mean(),
            gray[:mid_h, mid_w:].mean(),
            gray[mid_h:, :mid_w].mean(),
            gray[mid_h:, mid_w:].mean(),
        ]
        return int(np.argmax(quads))

    def _texture_label(self, pil_img) -> int:
        """Return a texture bin in ``[0, 7]`` for the image.

        The score is the variance of the finite-difference gradient magnitude
        of the gray image scaled to [0, 1], quantized into 8 equal-width bins
        over [0, 0.05]. Scores at or above 0.05 fall into the last bin.
        """
        gray = np.array(pil_img.convert("L")).astype(np.float32) / 255.0
        # variance of the gradient magnitude (np.gradient finite differences)
        gx = np.gradient(gray, axis=0)
        gy = np.gradient(gray, axis=1)
        mag = np.sqrt(gx ** 2 + gy ** 2)
        score = float(np.var(mag))
        bins = np.linspace(0, 0.05, 9)
        num_bins = len(bins) - 1
        raw = int(np.digitize(score, bins)) - 1
        # np.digitize returns len(bins) for scores >= bins[-1], which would
        # produce label 8 -- one past the last valid class. Clamp to the
        # valid range so the label always fits an 8-way classifier.
        return min(max(raw, 0), num_bins - 1)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Build the multi-task sample dict for the image at ``idx``.

        The quadrant and texture labels are computed from the untransformed
        image; the rotation angle and the corruption are drawn with the
        global ``random`` module on every call.
        """
        img, label = self.base[idx]

        quadrant_label = torch.tensor(self._quadrant_label(img), dtype=torch.long)
        texture_label = torch.tensor(self._texture_label(img), dtype=torch.long)

        # rotation task
        angle_choices = [0, 90, 180, 270]
        angle_idx = random.randrange(len(angle_choices))
        rot_img = F.rotate(img, angle_choices[angle_idx])

        # corruption task
        corr_idx = random.randrange(len(self.corruption_transforms))
        corr_img = self.corruption_transforms[corr_idx](img)

        sample = {
            "clean": self.clean_tf(img),
            "class_label": torch.tensor(label, dtype=torch.long),
            "quadrant_label": quadrant_label,
            "texture_label": texture_label,
            "rotated": self.clean_tf(rot_img),
            "rotation_label": torch.tensor(angle_idx, dtype=torch.long),
            "corrupted": self.clean_tf(corr_img),
            "corruption_label": torch.tensor(corr_idx, dtype=torch.long),
        }
        return sample


# --------------------------- models --------------------------- #

class CifarMTLModel(nn.Module):
    """ResNet-18 feature extractor shared by five linear task heads.

    The backbone's final ``fc`` layer is replaced with ``nn.Identity`` so that
    ``encode`` returns the pooled features. Each entry of ``self.heads`` maps
    those features to the logits of one task.

    Args:
        feat_dim: Input width of every head (defaults to 512).
    """

    def __init__(self, feat_dim: int = 512):
        super().__init__()

        # Randomly initialised backbone with its classifier removed.
        self.backbone = models.resnet18(weights=None)
        self.backbone.fc = nn.Identity()

        # Output width of each head, listed in registration order.
        head_widths = (
            ("class", 10),
            ("quadrant", 4),
            ("texture", 8),
            ("corruption", 5),
            ("rotation", 4),
        )
        self.heads = nn.ModuleDict(
            {task: nn.Linear(feat_dim, width) for task, width in head_widths}
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the shared backbone and return its output."""
        features = self.backbone(x)
        return features

    def shared_parameters(self):
        """Return an iterator over the backbone's parameters (heads excluded)."""
        trunk = self.backbone
        return trunk.parameters()


# --------------------------- helpers -------------------------- #

def accuracy(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the fraction of rows whose arg-max over dim 1 equals ``target``.

    The result is a 0-dim float tensor.
    """
    predicted = preds.argmax(dim=1)
    hits = predicted == target
    return hits.float().mean()


def build_tasks(model: CifarMTLModel) -> Tuple[TaskSpec, ...]:
    """Create one ``TaskSpec`` per task head.

    Each spec is built from the task name, a forward function, a fresh
    ``nn.CrossEntropyLoss`` and the ``accuracy`` metric. The forward function
    encodes the task's image view, applies the head of the same name and
    returns ``(logits, target)``.

    Args:
        model: Accepted for interface compatibility; not read here because the
            forward functions receive the module as their first argument.

    Returns:
        Specs ordered class, quadrant, texture, corruption, rotation.
    """
    # (task / head name, batch key of the image view, batch key of the label)
    layout = (
        ("class", "clean", "class_label"),
        ("quadrant", "clean", "quadrant_label"),
        ("texture", "clean", "texture_label"),
        ("corruption", "corrupted", "corruption_label"),
        ("rotation", "rotated", "rotation_label"),
    )

    def _bind(image_field: str, target_field: str, head_name: str):
        # A factory is used so each closure captures its own keys rather than
        # the loop variables of the generator below.
        def _forward(m: nn.Module, batch, device):
            features = m.encode(batch[image_field])
            return m.heads[head_name](features), batch[target_field]

        return _forward

    return tuple(
        TaskSpec(task, _bind(image_field, target_field, task), nn.CrossEntropyLoss(), accuracy)
        for task, image_field, target_field in layout
    )


# ----------------------------- main --------------------------- #

def main(args: argparse.Namespace) -> None:
    """Train and evaluate the multi-task CIFAR-10 model.

    Builds the train / test datasets and loaders, the model, its task specs,
    a ``SonGokuScheduler`` and an SGD optimizer, then runs ``args.epochs``
    rounds of ``trainer.train_epoch`` followed by ``trainer.evaluate`` and
    prints the metrics of each epoch.

    Args:
        args: Parsed command-line options. Reads ``data_root``, ``download``,
            ``batch_size``, ``num_workers``, ``refresh_period``,
            ``warmup_steps``, ``sketch_dim``, ``seed``, ``lr`` and ``epochs``.
    """
    # Seed every RNG this script draws from. Previously ``args.seed`` only
    # reached the scheduler, so weight initialisation, loader shuffling and
    # the dataset's random rotations / corruptions differed between runs.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # An explicit --data-root wins; otherwise use the project default location.
    data_root = args.data_root or collection_base.default_data_root() / "cifar10"
    train_ds = CifarMultiTaskDataset(train=True, root=str(data_root), download=args.download)
    test_ds = CifarMultiTaskDataset(train=False, root=str(data_root), download=args.download)

    # Pinned host memory is only useful when batches are copied to a GPU.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )

    model = CifarMTLModel()
    tasks = build_tasks(model)

    # The scheduler is told the total number of shared (backbone) scalars.
    shared_dim = sum(p.numel() for p in model.shared_parameters())
    scheduler = SonGokuScheduler(
        num_tasks=len(tasks),
        grad_dim=shared_dim,
        refresh_period=args.refresh_period,
        beta=0.9,
        tau_init=1.0,
        tau_target=0.3,
        warmup_steps=args.warmup_steps,
        anneal_rate=5e-4,
        sketch_dim=args.sketch_dim,
        random_state=args.seed,
    )

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    trainer = MultiTaskTrainer(model, tasks, scheduler, optimizer, device, grad_clip=5.0)

    for epoch in range(args.epochs):
        train_metrics = trainer.train_epoch(train_loader, epoch)
        test_metrics = trainer.evaluate(test_loader)
        print(f"[Epoch {epoch}] train={train_metrics}  test={test_metrics}")


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser(description="Train SON-GOKU on CIFAR-10 with auxiliary tasks.")
    parser.add_argument("--data-root", type=str, default=None, help="Optional data directory (defaults to ./data/cifar10).")

    # Plain numeric options, registered in the order they appear in --help.
    numeric_options = (
        ("--batch-size", int, 128),
        ("--num-workers", int, 4),
        ("--epochs", int, 10),
        ("--lr", float, 0.1),
        ("--refresh-period", int, 32),
        ("--warmup-steps", int, 1000),
        ("--sketch-dim", int, 128),
        ("--seed", int, 0),
    )
    for flag, caster, fallback in numeric_options:
        parser.add_argument(flag, type=caster, default=fallback)

    parser.add_argument("--download", action="store_true", help="Allow dataset download.")
    main(parser.parse_args())
