import copy
import os

import torch
from tqdm import tqdm  # type: ignore

from models.deep_ensemble import ModelConv as Model
from utils import (Args, ClassificationStats, ConvArgs, ModelSaveDict,
                   get_conv_args)


def make_adv(
    x: torch.Tensor,
    y: torch.Tensor,
    loss: torch.Tensor,
    ft_ranges: torch.Tensor,
    *,
    eps: float = 0.01,
) -> torch.Tensor:
    """Return an FGSM-style adversarial version of ``x``.

    Each element of ``x`` is moved by ``eps * ft_ranges`` in the direction of
    the sign of ``d loss / d x``. The step is therefore scaled per feature by
    that feature's value range.

    Args:
        x: Input that ``loss`` was computed from; it must require grad so
            ``torch.autograd.grad`` can differentiate with respect to it.
        y: Unused; kept so existing callers keep working.
        loss: Scalar loss depending on ``x``. The graph is retained, so the
            caller may still call ``loss.backward()`` afterwards.
        ft_ranges: Per-feature range, broadcastable against ``x``.
        eps: Step size as a fraction of each feature's range. The default
            of 0.01 is the previously hard-coded value.

    Returns:
        ``x`` plus the perturbation. The perturbation itself is a constant,
        but the result stays connected to ``x`` in the autograd graph. It is
        not clamped to the feature range.
    """
    grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
    return x + (eps * ft_ranges) * torch.sign(grad)


def train(args: Args, run: int) -> ModelSaveDict:
    """Adversarially train a deep ensemble and return its final state dicts.

    Builds ``args.model_n`` ``Model`` instances, all optimised jointly by one
    Adam optimiser. Each step gives every member its own batch, drawn from a
    separate iterator over ``args.train_set``. The objective is the NLL on the
    clean batch plus the NLL on a perturbed batch produced by ``make_adv``.
    Every 10th epoch (starting at epoch 0), the mean of the members' outputs is
    scored on ``args.test_set`` and the accuracy is printed.

    Args:
        args: Experiment configuration; ``train_set`` and ``test_set`` must
            both be set.
        run: Run index, used only to build the save paths.

    Returns:
        Mapping from ``<model_dir>/model-<i>.pt`` to a CPU copy of member
        ``i``'s ``state_dict`` after the final epoch.

    Raises:
        ValueError: If ``args.train_set`` or ``args.test_set`` is None.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    models = [
        Model(
            (args.x_dim, args.x_dim), args.h_dim, args.y_dim, get_conv_args(args),
        ).to(args.device)
        for _ in range(args.model_n)
    ]

    # A single optimiser over the parameters of every ensemble member.
    params = [p for m in models for p in m.parameters()]
    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)

    # Per-feature value range of the training inputs; scales the adversarial step.
    ft_ranges = _feature_ranges(
        args.train_set, (args.in_ch, args.x_dim, args.x_dim)
    ).to(args.device)

    criterion = torch.nn.NLLLoss()

    # NOTE(review): best_stats is tracked but never returned or used to pick
    # weights; the saved models are those from the last epoch.
    best_stats: ClassificationStats = ClassificationStats(args.test_set.dataset, args)
    model_dict: ModelSaveDict = {}

    for epoch in range(args.epochs):
        for m in models:
            m.train()

        # One iterator per member, zipped together (independent orderings
        # only if the loader shuffles).
        for data in zip(*[args.train_set for _ in range(args.model_n)]):
            xs = torch.stack([x.to(args.device, non_blocking=True) for (x, _) in data])
            ys = torch.stack([y.to(args.device, non_blocking=True) for (_, y) in data])
            # Input gradients are needed to build the adversarial batch.
            xs.requires_grad_(True)

            # Outputs are stacked per member; class scores sit on dim 2 and
            # members are flattened into the batch for the loss.
            yhats = torch.stack([m(x) for (m, x) in zip(models, xs)])
            targets = ys.view(-1)
            loss = criterion(
                torch.log_softmax(yhats, dim=2).view(-1, yhats.size(2)), targets
            )

            x_adv = make_adv(xs, ys, loss, ft_ranges)
            yhats_adv = torch.stack([m(x) for (m, x) in zip(models, x_adv)])

            # Out-of-place add: don't mutate a tensor autograd has already used.
            loss = loss + criterion(
                torch.log_softmax(yhats_adv, dim=2).view(-1, yhats_adv.size(2)),
                targets,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            for m in models:
                m.eval()
            t = ClassificationStats(args.test_set.dataset, args)
            with torch.no_grad():
                for x, y in args.test_set:
                    x = x.to(args.device, non_blocking=True)
                    y = y.to(args.device, non_blocking=True)
                    # Ensemble prediction is the mean of the member outputs.
                    yhat = torch.stack([m(x) for m in models]).mean(dim=0)
                    t.set_multiclass(y, yhat)

            t.calc_stats()
            print(f"epoch: {epoch} acc: {t.accuracy:.4f}")
            if t.accuracy > best_stats.accuracy:
                best_stats = t

    model_dir = args.get_model_dir(run, classification=True)
    for i, m in enumerate(models):
        # Deep-copy so that .cpu() (which moves a module in place) leaves the
        # trained model where it is.
        model_dict[os.path.join(model_dir, f"model-{i}.pt")] = (
            copy.deepcopy(m).cpu().state_dict()
        )

    return model_dict


def _feature_ranges(loader, shape) -> torch.Tensor:
    """Return the element-wise (max - min) of every ``x`` yielded by ``loader``.

    The extremes are seeded from the first batch rather than from zeros, so
    features that are all positive or all negative get their true range.
    An empty ``loader`` yields ``torch.zeros(shape)``.
    """
    ft_min = None
    ft_max = None
    for x, _ in loader:
        mn = x.min(dim=0).values
        mx = x.max(dim=0).values
        ft_min = mn if ft_min is None else torch.min(ft_min, mn)
        ft_max = mx if ft_max is None else torch.max(ft_max, mx)

    if ft_min is None or ft_max is None:
        return torch.zeros(shape)
    return ft_max - ft_min


def eval(args: Args, run: int, model_dict: ModelSaveDict) -> ClassificationStats:
    """Rebuild the ensemble from ``model_dict`` and score it on ``args.test_set``.

    Args:
        args: Experiment configuration. ``train_set`` and ``test_set`` must both
            be set, although only ``test_set`` is read here.
        run: Run index, used to rebuild the ``model-<i>.pt`` keys produced by
            ``train``.
        model_dict: Mapping from save path to each member's ``state_dict``.

    Returns:
        Stats for the mean of the members' outputs over the whole test set,
        with ``calc_stats()`` already called.

    Raises:
        ValueError: If ``args.train_set`` or ``args.test_set`` is None.
        KeyError: If ``model_dict`` lacks one of the expected paths.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    model_dir = args.get_model_dir(run, classification=True)
    models = []
    for i in range(args.model_n):
        m = Model(
            (args.x_dim, args.x_dim), args.h_dim, args.y_dim, get_conv_args(args),
        ).to(args.device)
        m.load_state_dict(model_dict[os.path.join(model_dir, f"model-{i}.pt")])
        m.eval()
        models.append(m)

    t = ClassificationStats(args.test_set.dataset, args)
    with torch.no_grad():
        for x, y in args.test_set:
            x = x.to(args.device, non_blocking=True)
            y = y.to(args.device, non_blocking=True)
            # Ensemble prediction is the mean of the member outputs.
            yhat = torch.stack([m(x) for m in models]).mean(dim=0)
            t.set_multiclass(y, yhat)

    t.calc_stats()
    return t
