"""Train a shared-trunk AV-MNIST model on digit and parity tasks with the SON-GOKU scheduler."""
from __future__ import annotations

import argparse

import torch
from torch import nn
from torch.utils.data import DataLoader

from son_goku import SonGokuScheduler
from experiments.train_utils import MultiTaskTrainer, TaskSpec, _move_to_device
from experiments.preprocessing import avmnist as avmnist_prep
from experiments.collection import base as collection_base


class AVMNISTModel(nn.Module):
    """Two-branch (image + audio) encoder with one classification head per task.

    The image branch (2-D convolutions) and the audio branch (1-D convolutions)
    each emit a feature vector. The two vectors are concatenated and passed
    through a fusion layer. The fused representation feeds the task heads
    ``"digit"`` (10 outputs) and ``"parity"`` (2 outputs).

    Args:
        img_feat: Width of the image-branch output vector.
        audio_feat: Width of the audio-branch output vector.
        fused: Width of the fused representation consumed by the heads.
    """

    def __init__(self, img_feat: int = 128, audio_feat: int = 128, fused: int = 256):
        super().__init__()
        # Image branch: two conv/ReLU/2x-max-pool stages, then a projection.
        # The 64 * 7 * 7 flatten width presumably assumes 1-channel 28x28
        # inputs (28 -> 14 -> 7); confirm against the data loader.
        image_layers = [
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, img_feat),
            nn.ReLU(inplace=True),
        ]
        self.img_encoder = nn.Sequential(*image_layers)

        # Audio branch: single input channel, two strided 1-D convs. The adaptive
        # pool fixes the time axis at 16 positions, so the projection width does
        # not depend on the input length.
        audio_layers = [
            nn.Conv1d(1, 16, 5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv1d(16, 32, 5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool1d(16),
            nn.Flatten(),
            nn.Linear(32 * 16, audio_feat),
            nn.ReLU(inplace=True),
        ]
        self.audio_encoder = nn.Sequential(*audio_layers)

        # Fusion over the concatenated branch features.
        self.fuse = nn.Sequential(
            nn.Linear(img_feat + audio_feat, fused),
            nn.ReLU(inplace=True),
        )

        # Task-specific heads; these are NOT part of shared_parameters().
        digit_head = nn.Linear(fused, 10)
        parity_head = nn.Linear(fused, 2)
        self.heads = nn.ModuleDict({"digit": digit_head, "parity": parity_head})

    def encode(self, batch) -> torch.Tensor:
        """Return the fused shared representation for ``batch``.

        ``batch`` must support ``batch["image"]`` and ``batch["audio"]``.
        The branch outputs are concatenated along dim 1 before fusion.
        """
        branch_outputs = (
            self.img_encoder(batch["image"]),
            self.audio_encoder(batch["audio"]),
        )
        return self.fuse(torch.cat(branch_outputs, dim=1))

    def shared_parameters(self):
        """Return a list of the trunk parameters (image encoder, audio encoder, fusion).

        Order is image encoder, then audio encoder, then fusion layer. Head
        parameters are excluded.
        """
        trunk = (self.img_encoder, self.audio_encoder, self.fuse)
        return [param for module in trunk for param in module.parameters()]


def accuracy(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the fraction of rows whose arg-max column equals ``target``.

    Args:
        preds: Scores with classes along dim 1 (e.g. logits).
        target: Class indices, one per row of ``preds``.

    Returns:
        A 0-d float32 tensor. An empty batch yields NaN (mean over no elements).
    """
    predicted = torch.argmax(preds, dim=1)
    hits = predicted.eq(target)
    return hits.float().mean()


def main(args: argparse.Namespace) -> None:
    """Train ``AVMNISTModel`` on the digit and parity tasks with SON-GOKU scheduling.

    Setup:

    - AV-MNIST train/test loaders.
    - A shared-trunk model with two heads.
    - A ``SonGokuScheduler`` sized to the flattened shared parameters.
    - An Adam optimizer over all model parameters.

    It then runs ``args.epochs`` rounds of ``train_epoch`` followed by
    ``evaluate`` and prints both metric sets each epoch.

    Args:
        args: Parsed command-line options (data_root, batch_size, num_workers,
            download, lr, refresh_period, warmup_steps, sketch_dim, seed, epochs).
    """
    # Seed torch's global RNG so ``--seed`` also governs weight initialisation
    # (and any loader shuffling that draws on torch's default generator).
    # Previously the seed only reached the scheduler.
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # An explicit --data-root wins; otherwise fall back to the collection default.
    data_root = args.data_root or collection_base.default_data_root() / "avmnist"

    train_loader, test_loader = avmnist_prep.create_dataloaders(
        root=str(data_root),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        download=args.download,
    )

    # Place the model on the target device before the optimizer is constructed,
    # the order PyTorch recommends. Moving an already-placed model again is a no-op.
    model = AVMNISTModel().to(device)

    def _unpack(batch):
        """Split a batch into ``(data, labels)``.

        Accepts a ``(data, labels)`` list/tuple, or a mapping with a ``"label"``
        entry. In the mapping case the whole mapping is passed on as ``data``.
        """
        if isinstance(batch, (list, tuple)):
            data, labels = batch
            return data, labels
        return batch, batch["label"]

    def _make_forward(head, to_target):
        """Build a task forward fn: shared ``encode`` -> ``heads[head]`` -> (logits, target)."""

        def forward(m, batch, device):
            data, labels = _unpack(batch)
            feats = m.encode(_move_to_device(data, device))
            logits = m.heads[head](feats)
            return logits, to_target(labels).to(device)

        return forward

    tasks = (
        TaskSpec(
            "digit",
            _make_forward("digit", lambda labels: labels),
            nn.CrossEntropyLoss(),
            accuracy,
        ),
        # Parity target is the label mod 2, cast to int64 for CrossEntropyLoss.
        TaskSpec(
            "parity",
            _make_forward("parity", lambda labels: (labels % 2).long()),
            nn.CrossEntropyLoss(),
            accuracy,
        ),
    )

    # The scheduler works on gradients of the shared trunk only, flattened to one vector.
    shared_dim = sum(p.numel() for p in model.shared_parameters())
    scheduler = SonGokuScheduler(
        num_tasks=len(tasks),
        grad_dim=shared_dim,
        refresh_period=args.refresh_period,
        tau_init=1.0,
        tau_target=0.3,
        warmup_steps=args.warmup_steps,
        anneal_rate=5e-4,
        sketch_dim=args.sketch_dim,
        random_state=args.seed,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    trainer = MultiTaskTrainer(model, tasks, scheduler, optimizer, device)

    for epoch in range(args.epochs):
        train_metrics = trainer.train_epoch(train_loader, epoch)
        test_metrics = trainer.evaluate(test_loader)
        print(f"[Epoch {epoch}] train={train_metrics}  test={test_metrics}")


if __name__ == "__main__":
    # CLI entry point: each option below becomes an ``args.<name>`` field read by main().
    cli = argparse.ArgumentParser(description="Train SON-GOKU on AV-MNIST")
    # (flag, type, default) for every valued option, registered in --help order.
    valued_options = (
        ("--data-root", str, None),
        ("--batch-size", int, 64),
        ("--num-workers", int, 4),
        ("--epochs", int, 5),
        ("--lr", float, 1e-3),
        ("--refresh-period", int, 16),
        ("--warmup-steps", int, 500),
        ("--sketch-dim", int, 64),
        ("--seed", int, 0),
    )
    for flag, kind, default in valued_options:
        cli.add_argument(flag, type=kind, default=default)
    # Boolean switch: False unless the flag is given.
    cli.add_argument("--download", action="store_true")
    main(cli.parse_args())
