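"""Train SON-GOKU on NYUv2 tasks.

A shared ResNet-18 backbone feeds four task heads: segmentation, depth,
surface normals and colour temperature. Training goes through
``MultiTaskTrainer`` with a ``SonGokuScheduler``. Run this module as a
script; the command-line flags are defined at the bottom of the file.
"""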
from __future__ import annotations

import argparse
from typing import Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models
import torch.nn.functional as F

from son_goku import SonGokuScheduler
from experiments.train_utils import MultiTaskTrainer, TaskSpec, _move_to_device
from experiments.preprocessing import nyuv2 as nyuv2_prep
from experiments.collection import base as collection_base


class NYUv2Backbone(nn.Module):
    """Shared ResNet-18 trunk used by every NYUv2 task head.

    Built from torchvision's ``resnet18`` with ``weights=None``, i.e. random
    initialisation. Only the convolutional part is kept:
    conv1 -> bn1 -> relu -> maxpool -> layer1..layer4. The network's own
    avgpool/fc classifier is not included, so ``forward`` returns a spatial
    feature map with ``out_channels`` channels.
    """

    def __init__(self):
        super().__init__()
        reference = models.resnet18(weights=None)
        # These submodules run sequentially in exactly this order.
        trunk_order = (
            "conv1", "bn1", "relu", "maxpool",
            "layer1", "layer2", "layer3", "layer4",
        )
        self.stem = nn.Sequential(*(getattr(reference, name) for name in trunk_order))
        # Channel width of the feature map produced by resnet18's layer4.
        self.out_channels = 512

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the trunk and return the final feature map."""
        features = self.stem(x)
        return features

    def shared_parameters(self):
        """Return an iterator over every backbone parameter (the task-shared weights)."""
        return self.parameters()


class NYUv2Model(nn.Module):
    """ResNet-18 trunk plus four task heads for NYUv2 multi-task training.

    Heads (all read the backbone output):
        seg_head:    global average pool -> flatten -> linear to ``num_classes``
                     logits. NOTE(review): this yields one logit vector per
                     image, not per-pixel segmentation logits.
        depth_head:  three stride-2 transposed convs (each doubles H and W),
                     ending in 1 output channel.
        normal_head: same decoder shape as ``depth_head``, ending in 3 channels.
        color_head:  linear layer mapping a pooled ``(batch, out_channels)``
                     vector to 3 logits (cool / neutral / warm). It does not
                     pool by itself; the caller must pool the feature map first.
    """

    def __init__(self, num_classes: int = 14):
        super().__init__()
        self.backbone = NYUv2Backbone()
        width = self.backbone.out_channels
        self.seg_head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(width, num_classes),
        )
        # Build order is kept (depth before normal before color) so parameter
        # initialisation consumes the RNG in the same sequence.
        self.depth_head = self._upsampling_decoder(width, 1)
        self.normal_head = self._upsampling_decoder(width, 3)
        self.color_head = nn.Linear(width, 3)  # cool / neutral / warm

    @staticmethod
    def _upsampling_decoder(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build three 2x-upsampling transposed convs (128 -> 64 -> ``out_channels``)."""

        def up(cin: int, cout: int) -> nn.ConvTranspose2d:
            # k=3, s=2, p=1, output_padding=1 maps H -> 2H exactly.
            return nn.ConvTranspose2d(cin, cout, 3, stride=2, padding=1, output_padding=1)

        return nn.Sequential(
            up(in_channels, 128),
            nn.ReLU(inplace=True),
            up(128, 64),
            nn.ReLU(inplace=True),
            up(64, out_channels),
        )

    def shared_parameters(self):
        """Return the parameters shared by all tasks, i.e. the backbone's."""
        return self.backbone.shared_parameters()

    def encode(self, rgb: torch.Tensor) -> torch.Tensor:
        """Return the backbone feature map for an RGB batch."""
        return self.backbone(rgb)


def compute_normals(depth: torch.Tensor) -> torch.Tensor:
    """Estimate unit surface normals from a depth map using forward differences.

    For each pixel the unnormalised normal is ``(-dz/dx, -dz/dy, 1)``. Here
    ``dz/dx`` is the difference along the last axis and ``dz/dy`` the
    difference along the second-to-last axis. Each difference is zero-padded
    at the far edge (last column / last row), so those border pixels use a
    zero gradient on that axis.

    Args:
        depth: depth tensor; the three components are concatenated on
            ``dim=1``. Presumably shaped ``(batch, 1, H, W)`` so the result
            has 3 channels (TODO confirm against callers).

    Returns:
        Tensor with the normals stacked on ``dim=1`` and L2-normalised along
        that dimension (``eps=1e-6``).
    """
    grad_x = F.pad(torch.diff(depth, dim=3), (0, 1, 0, 0))
    grad_y = F.pad(torch.diff(depth, dim=2), (0, 0, 0, 1))
    unnormalised = torch.cat((grad_x.neg(), grad_y.neg(), torch.ones_like(depth)), dim=1)
    return F.normalize(unnormalised, dim=1, eps=1e-6)


def color_temperature_label(rgb: torch.Tensor, threshold: float = 0.05) -> torch.Tensor:
    """Assign each image a coarse colour-temperature class from its mean R - B.

    The score is ``mean(red) - mean(blue)`` taken over the two spatial axes.
    Classes, following ``torch.bucketize`` with edges ``[-threshold, threshold]``
    (right=False):

        0 (cool):    score <= -threshold
        1 (neutral): -threshold < score <= threshold
        2 (warm):    score > threshold

    Args:
        rgb: floating-point tensor indexed ``(batch, channel, H, W)`` with
            channel 0 = red and channel 2 = blue. The meaningful scale of
            ``threshold`` depends on the pixel range, presumably [0, 1]
            (TODO confirm against the preprocessing).
        threshold: half-width of the neutral band. The default 0.05 keeps the
            previous behaviour.

    Returns:
        int64 tensor of shape ``(batch,)`` with values in {0, 1, 2}.

    Raises:
        ValueError: if ``threshold`` is negative, because the edges would not
            be monotonically increasing.
    """
    if threshold < 0:
        raise ValueError(f"threshold must be non-negative, got {threshold}")
    channel_means = rgb.mean(dim=(2, 3))
    score = channel_means[:, 0] - channel_means[:, 2]
    # Build the edges in the score's dtype/device so no implicit promotion or
    # float32 rounding of the edges happens for float64/half inputs.
    edges = torch.tensor([-threshold, threshold], dtype=score.dtype, device=score.device)
    return torch.bucketize(score, edges)  # 0 cool, 1 neutral, 2 warm


def main(args: argparse.Namespace) -> None:
    """Build the NYUv2 loaders, model and SON-GOKU scheduler, then train.

    Each epoch runs ``trainer.train_epoch`` and then ``trainer.evaluate`` on
    the validation and test loaders, and prints the three metric results.

    Args:
        args: parsed command-line namespace. Reads ``seed``, ``data_root``,
            ``batch_size``, ``num_workers``, ``download``, ``image_size``,
            ``num_classes``, ``refresh_period``, ``warmup_steps``,
            ``sketch_dim``, ``lr`` and ``epochs``.

    Raises:
        ValueError: from the segmentation forward function when a dict batch
            has no ``"label"`` entry.
    """
    # ``--seed`` previously reached only the scheduler. Seed torch as well so
    # model weight initialisation is reproducible for a given seed.
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # ``/`` binds tighter than ``or``: <default root>/nyuv2 is used only when
    # no --data-root was supplied.
    data_root = args.data_root or collection_base.default_data_root() / "nyuv2"
    train_loader, val_loader, test_loader = nyuv2_prep.create_dataloaders(
        root=str(data_root),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        download=args.download,
        image_size=args.image_size,
    )

    model = NYUv2Model(num_classes=args.num_classes).to(device)

    def _split_batch(batch):
        """Return ``(data, labels)`` from a ``(data, labels)`` pair or a dict batch."""
        if isinstance(batch, (list, tuple)):
            data, labels = batch
            return data, labels
        # Dict batches hold the inputs directly; labels (if any) are under "label".
        return batch, batch.get("label")

    def _to_target(labels, device):
        """Move class labels to ``device``, converting non-tensor labels first."""
        if labels is None:
            raise ValueError("segmentation task needs labels, but the batch has no 'label' entry")
        if isinstance(labels, torch.Tensor):
            return labels.to(device)
        return torch.as_tensor(labels, device=device)

    def _accuracy(pred, target):
        """Fraction of rows whose arg-max logit equals the target class."""
        return (pred.argmax(1) == target).float().mean()

    def fwd_seg(m, batch, device):
        data, labels = _split_batch(batch)
        rgb = _move_to_device(data["rgb"], device)
        logits = m.seg_head(m.encode(rgb))
        return logits, _to_target(labels, device)

    def fwd_depth(m, batch, device):
        data, _ = _split_batch(batch)
        rgb = _move_to_device(data["rgb"], device)
        depth_gt = _move_to_device(data["depth"], device)
        pred = m.depth_head(m.encode(rgb))
        # The head's output resolution need not match the ground truth, so the
        # target is resized to the prediction rather than the other way round.
        depth_resized = F.interpolate(depth_gt, size=pred.shape[-2:], mode="bilinear", align_corners=False)
        return pred, depth_resized

    def fwd_normal(m, batch, device):
        data, _ = _split_batch(batch)
        rgb = _move_to_device(data["rgb"], device)
        depth_gt = _move_to_device(data["depth"], device)
        pred = m.normal_head(m.encode(rgb))
        depth_resized = F.interpolate(depth_gt, size=pred.shape[-2:], mode="bilinear", align_corners=False)
        # Normal targets are derived from the resized ground-truth depth.
        return pred, compute_normals(depth_resized)

    def fwd_color(m, batch, device):
        data, _ = _split_batch(batch)
        rgb = _move_to_device(data["rgb"], device)
        feats = m.encode(rgb)
        # color_head is a plain Linear, so pool the feature map to a vector first.
        pooled = F.adaptive_avg_pool2d(feats, 1).flatten(1)
        logits = m.color_head(pooled)
        # Labels are computed on the fly from the input image itself.
        return logits, color_temperature_label(rgb)

    tasks = (
        TaskSpec("segmentation", fwd_seg, nn.CrossEntropyLoss(), _accuracy),
        TaskSpec("depth", fwd_depth, nn.L1Loss(), lambda p, t: torch.mean(torch.abs(p - t))),
        TaskSpec("normals", fwd_normal, nn.L1Loss(), None),
        TaskSpec("color_temp", fwd_color, nn.CrossEntropyLoss(), _accuracy),
    )

    # The scheduler sees gradients over the shared (backbone) parameters only.
    shared_dim = sum(p.numel() for p in model.shared_parameters())
    scheduler = SonGokuScheduler(
        num_tasks=len(tasks),
        grad_dim=shared_dim,
        refresh_period=args.refresh_period,
        tau_init=1.0,
        tau_target=0.3,
        warmup_steps=args.warmup_steps,
        anneal_rate=5e-4,
        sketch_dim=args.sketch_dim,
        random_state=args.seed,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    trainer = MultiTaskTrainer(model, tasks, scheduler, optimizer, device)

    for epoch in range(args.epochs):
        train_metrics = trainer.train_epoch(train_loader, epoch)
        val_metrics = trainer.evaluate(val_loader)
        # NOTE: test metrics are reported every epoch for monitoring only; do
        # not use them for model/epoch selection (use val_metrics for that).
        test_metrics = trainer.evaluate(test_loader)
        print(f"[Epoch {epoch}] train={train_metrics}  val={val_metrics}  test={test_metrics}")


if __name__ == "__main__":
    # Command-line entry point: parse flags and hand them to main().
    cli = argparse.ArgumentParser(description="Train SON-GOKU on NYUv2 tasks")
    cli.add_argument("--data-root", type=str, default=None)
    # Typed flags, registered in the same order as before so --help is unchanged.
    for flag, flag_type, flag_default in (
        ("--batch-size", int, 4),
        ("--num-workers", int, 2),
        ("--epochs", int, 5),
        ("--lr", float, 1e-4),
        ("--num-classes", int, 14),
        ("--image-size", int, 224),
        ("--refresh-period", int, 16),
        ("--warmup-steps", int, 500),
        ("--sketch-dim", int, 64),
        ("--seed", int, 0),
    ):
        cli.add_argument(flag, type=flag_type, default=flag_default)
    cli.add_argument("--download", action="store_true")
    main(cli.parse_args())
