import os
from collections import defaultdict
from dataclasses import dataclass
from functools import partial

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm
from wandb.wandb_run import Run

from .dataset import Batch, TSPDataset
from .model import Model
from .solve import solve
from .utils import solution_cost


@dataclass
class Trainer:
    """Supervised training loop with periodic evaluation, logging and checkpointing.

    Evaluation, logging and checkpointing only happen when a ``logger`` is passed to
    :meth:`train`; without one, the loop just optimizes the model.
    """

    # Device handed to ``Batch.from_samples`` when collating samples into batches.
    device: torch.device
    # Number of samples per evaluation batch.
    evaluation_batch_size: int
    # An evaluation + checkpoint is triggered whenever ``iter_id % evaluation_every == 0``.
    evaluation_every: int
    # Upper bound on the number of evaluation batches drawn from each dataset.
    evaluation_iters: int
    # Number of samples per training batch.
    training_batch_size: int
    # Iteration id at which training stops (iterations run in ``[initial_iter, training_iters)``).
    training_iters: int
    # Multiplier applied to logged steps and progress-bar counts.
    # NOTE(review): presumably the number of distributed processes -- confirm with the caller.
    world_size: int

    def train(
        self,
        model: Model | DDP,
        optimizer: optim.Optimizer,
        scheduler: optim.lr_scheduler.LRScheduler,
        train_dataset: TSPDataset,
        datasets: list[TSPDataset],
        initial_iter: int = 0,
        logger: Run | None = None,
    ):
        """Train ``model`` from iteration ``initial_iter`` up to ``self.training_iters``.

        Each iteration draws one batch (sampled with replacement from ``train_dataset``),
        minimizes the mean cross-entropy between the model logits and ``batch.l``, and steps
        both the optimizer and the scheduler.

        Args:
            model: The model to optimize, optionally wrapped in DDP.
            optimizer: Optimizer stepped once per iteration.
            scheduler: Learning-rate scheduler stepped once per iteration.
            train_dataset: Dataset the training batches are sampled from; also evaluated.
            datasets: Additional datasets evaluated under their ``name`` attribute.
            initial_iter: Iteration id to start from (e.g. when resuming a run).
            logger: Run used for metrics logging. When ``None``, evaluation, logging,
                checkpointing and the progress bar are all disabled.
        """
        if logger is not None:
            logger.summary["experiment-dir"] = os.getcwd()
            logger.summary["n-params"] = sum(p.numel() for p in model.parameters())
            logger.summary["training-size"] = len(train_dataset)

        remaining_iters = self.training_iters - initial_iter
        if remaining_iters > 0:
            train_dataloader = DataLoader(
                train_dataset,
                batch_size=self.training_batch_size,
                sampler=RandomSampler(
                    train_dataset,
                    replacement=True,
                    num_samples=remaining_iters * self.training_batch_size,
                ),
                collate_fn=partial(Batch.from_samples, device=self.device),
                pin_memory=True,
            )
        else:
            # ``RandomSampler`` rejects a non-positive ``num_samples``. When there is nothing
            # left to train, skip the loop and fall through to the final evaluation.
            train_dataloader = ()

        for iter_id, batch in tqdm(
            zip(range(initial_iter, self.training_iters), train_dataloader),
            desc="Training iters",
            unit=" sample",
            unit_scale=self.world_size * self.training_batch_size,
            total=self.training_iters,
            initial=initial_iter,
            disable=logger is None,
        ):
            if logger is not None and iter_id % self.evaluation_every == 0:
                self._evaluate_and_checkpoint(
                    model, optimizer, scheduler, train_dataset, datasets, logger, iter_id
                )

            # ``eval`` switches the model to evaluation mode, so restore training mode here.
            model.train()
            optimizer.zero_grad()
            logits = model(batch.x, batch.m, batch.o, batch.e)
            loss = F.cross_entropy(logits, batch.l, reduction="mean")
            loss.backward()
            optimizer.step()
            scheduler.step()

        # Final evaluation.
        if logger is not None:
            self._evaluate_and_checkpoint(
                model, optimizer, scheduler, train_dataset, datasets, logger, self.training_iters
            )

    def _evaluate_and_checkpoint(
        self,
        model: Model | DDP,
        optimizer: optim.Optimizer,
        scheduler: optim.lr_scheduler.LRScheduler,
        train_dataset: TSPDataset,
        datasets: list[TSPDataset],
        logger: Run,
        iter_id: int,
    ) -> None:
        """Evaluate on every dataset, log the metrics at ``iter_id`` and save a checkpoint."""
        metrics = {"train": self.eval(model, train_dataset)}
        metrics |= {d.name: self.eval(model, d) for d in datasets}
        metrics |= {"lr": scheduler.get_last_lr()[0]}
        logger.log(
            metrics,
            # The same step formula is used for periodic and final evaluations so that the
            # logged steps stay monotonically increasing for any ``world_size``.
            step=self.world_size * iter_id * self.training_batch_size,
            commit=True,
        )
        Trainer.checkpoint(model, optimizer, scheduler, logger, iter_id)

    @torch.inference_mode(True)
    def eval(self, model: Model | DDP, dataset: TSPDataset) -> dict[str, float]:
        """Return the mean of every per-sample metric over a random subset of ``dataset``.

        At most ``evaluation_iters * evaluation_batch_size`` samples are drawn, without
        replacement. The model is left in evaluation mode.

        Args:
            model: The model to evaluate, optionally wrapped in DDP.
            dataset: Dataset to sample from; its ``name`` labels the progress bar.

        Returns:
            A mapping from metric name (as produced by :meth:`batch_metrics`) to its mean.
        """
        model.eval()
        per_batch = defaultdict(list)
        total_samples = min(len(dataset), self.evaluation_iters * self.evaluation_batch_size)
        dataloader = DataLoader(
            dataset,
            batch_size=self.evaluation_batch_size,
            sampler=RandomSampler(
                dataset,
                replacement=False,
                num_samples=total_samples,
            ),
            collate_fn=partial(Batch.from_samples, device=self.device),
        )

        for batch in tqdm(
            dataloader,
            desc=f"Evaluating {dataset.name}",
            unit=" sample",
            unit_scale=self.evaluation_batch_size,
            leave=False,
        ):
            for name, values in self.batch_metrics(model, batch).items():
                per_batch[name].append(values)

        # Concatenate before averaging so every sample has the same weight, even when the
        # last batch is smaller than the others.
        return {
            name: torch.concat(values, dim=0).float().mean().cpu().item()
            for name, values in per_batch.items()
        }

    @torch.compile(dynamic=True)
    def batch_metrics(self, model: Model | DDP, batch: Batch) -> dict[str, Tensor]:
        """Compute unreduced metrics for a single batch.

        Args:
            model: The model to evaluate, optionally wrapped in DDP.
            batch: The batch to evaluate on.

        Returns:
            A dict with three tensors:
            ``"loss"``: cross-entropy between the logits and ``batch.l``, unreduced.
            ``"accuracy"``: boolean tensor, whether the arg-max logit equals ``batch.l``.
            ``"opt-gap"``: relative gap, in percent, between the cost of the solution
            returned by ``solve`` and the cost of ``batch.s``.
            NOTE(review): presumably ``batch.s`` holds the reference (optimal) solutions --
            confirm against the dataset.
        """
        metrics = dict()

        logits = model(batch.x, batch.m, batch.o, batch.e)
        metrics["loss"] = F.cross_entropy(logits, batch.l, reduction="none")
        metrics["accuracy"] = logits.argmax(dim=1) == batch.l

        x, s = batch.x[:, : batch.n], batch.s[:, : batch.n]  # Remove padding!
        opt = torch.vmap(solution_cost)(x, s)

        # Randomly permute the cities to make sure the optimal solution isn't the trivial "predict next
        # city in the sequence" solution.
        x_ = x[:, torch.randperm(batch.n)]
        s_ = solve(x_, model)
        cost = torch.vmap(solution_cost)(x_, s_)

        metrics["opt-gap"] = 100 * (cost - opt) / opt

        return metrics

    @staticmethod
    def checkpoint(
        model: Model | DDP,
        optimizer: optim.Optimizer,
        scheduler: optim.lr_scheduler.LRScheduler,
        logger: Run,
        iter_id: int,
    ):
        """Save the training state to ``checkpoint.pth``, relative to the working directory.

        The file is overwritten on every call. It stores the model, optimizer and scheduler
        state dicts together with the logger's run id and the iteration id.

        Args:
            model: The model to save; a DDP wrapper is stripped first.
            optimizer: Optimizer whose state is saved.
            scheduler: Scheduler whose state is saved.
            logger: Run whose ``id`` is stored alongside the states.
            iter_id: Iteration id stored alongside the states.
        """
        # Save the wrapped module so the state-dict keys do not depend on the DDP wrapper.
        if isinstance(model, DDP):
            model = model.module

        torch.save(
            {
                "model-state": model.state_dict(),
                "optimizer-state": optimizer.state_dict(),
                "scheduler-state": scheduler.state_dict(),
                "run-id": logger.id,
                "iter-id": iter_id,
            },
            "checkpoint.pth",
        )
