import copy
import os

import torch

from models.deep_ensemble import ModelConv as Model
from models.set_encoder import LatentPerturber
from utils import (Args, ClassificationStatsTemp, ModelSaveDict, get_conv_args,
                   softmax_entropy)


def train(args: Args, run: int) -> ModelSaveDict:
    """Jointly train an ensemble of ``args.model_n`` models and a latent perturber.

    Every training iteration performs two optimisation steps:

    1. Model step (``optimizer``): NLL on the class logits of clean inputs plus
       a squared-error term on the exponentiated last output column of the
       perturbed forward pass (``m.phi(..., theta=True)``).
    2. Perturber step (``phi_opt``): a loss built from ``softmax_entropy`` of
       the perturbed logits, the selected distances and the ``h`` term returned
       by ``m.phi(..., theta=False)``, evaluated on a separate set of batches.

    Every 10th epoch (starting at epoch 0) the ensemble-averaged output is
    scored on ``args.test_set`` and accuracy / calibration error are printed.

    Args:
        args: experiment configuration; must provide non-None ``train_set``
            and ``test_set``.
        run: run index, forwarded to ``args.get_model_dir`` to build save paths.

    Returns:
        Mapping from target file path to a CPU ``state_dict`` for the perturber
        (``perturber.pt``) and for each ensemble member (``model-{i}.pt``).

    Raises:
        ValueError: if ``args.train_set`` or ``args.test_set`` is None.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    conv_args = get_conv_args(args)
    # The perturber is sized from the second entry of the last conv spec.
    perturber = LatentPerturber(conv_args[-1][1]).to(args.device)
    phi_opt = torch.optim.Adam(perturber.parameters(), lr=args.beta_lr)

    models = [
        Model((args.x_dim, args.x_dim), args.h_dim, args.y_dim, conv_args).to(args.device)
        for _ in range(args.model_n)
    ]

    # A single optimizer updates the parameters of every ensemble member.
    params = [p for m in models for p in m.parameters()]
    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    criterion = torch.nn.NLLLoss()

    # Base of the temperature target used in the model step (TEMP ** w).
    TEMP = 5
    for epoch in range(args.epochs):
        for m in models:
            m.train()

        # Iterate 2 * model_n streams over the same training set in lockstep:
        # the first model_n batches feed the model step, the rest feed the
        # perturber step.
        # NOTE(review): presumably train_set is a DataLoader, so each stream is
        # an independent pass over the data - confirm.
        streams = [args.train_set] * (args.model_n * 2)
        for data in zip(*streams):
            model_batches = data[: args.model_n]
            phi_batches = data[args.model_n:]

            xs = torch.stack([x.to(args.device, non_blocking=True) for (x, _) in model_batches])
            ys = torch.stack([y.to(args.device, non_blocking=True) for (_, y) in model_batches])
            x_phis = torch.stack([x.to(args.device, non_blocking=True) for (x, _) in phi_batches])

            # ---- model step -------------------------------------------------
            yhats = torch.stack([m(x) for (m, x) in zip(models, xs)]).view(-1, args.y_dim)
            # The last output column is not a class logit; it is dropped here.
            # NOTE(review): assumes the model emits exactly y_dim columns with
            # the temperature head last - confirm against ModelConv.
            logits = yhats[:, :-1]
            loss = criterion(torch.log_softmax(logits, dim=1), ys.view(-1))

            phi_outputs, phi_dists = [], []
            for m, x in zip(models, xs):
                y, _, _, d = m.phi(x, perturber, theta=True)
                phi_outputs.append(y)
                phi_dists.append(d)

            yhat_phi = torch.stack(phi_outputs)
            temp_phi = yhat_phi.view(-1, args.y_dim)[:, -1]
            dist = torch.stack(phi_dists)

            # Squared error between exp(last output column) on perturbed inputs
            # and the target TEMP ** w, where w = 1 - exp(-d / (2 * ls^2)) and
            # d is the minimum of `dist` along dim 1. The distances are
            # detached, so this term does not back-propagate through them.
            d_min, _ = dist.min(dim=1)
            w = 1 - torch.exp(-d_min.detach() / (2 * args.ls ** 2))
            loss_phi = (((TEMP ** w.view(-1)) - torch.exp(temp_phi)) ** 2).mean()
            loss += loss_phi / args.model_n

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # ---- perturber step ---------------------------------------------
            phi_terms = []
            for m, x in zip(models, x_phis):
                y, h, s, d = m.phi(x, perturber, theta=False)
                y = y[:, :-1]  # class logits only
                d = d[s[0], s[1]]  # keep the distances selected by the index pair s
                phi_terms.append(softmax_entropy(y).mean() + d.mean() - h)

            phi_loss = sum(phi_terms) / args.model_n

            phi_opt.zero_grad()
            phi_loss.backward()
            phi_opt.step()

        if epoch % 10 == 0:
            for m in models:
                m.eval()

            t = ClassificationStatsTemp(args.test_set.dataset, args, 10)
            with torch.no_grad():
                for x, y in args.test_set:
                    x = x.to(args.device, non_blocking=True)
                    y = y.to(args.device, non_blocking=True)

                    # Ensemble prediction: mean of the members' raw outputs.
                    yhat = torch.stack([m(x) for m in models]).mean(dim=0)
                    t.set_multiclass(y, yhat)

            t.calc_stats()
            print(f"epoch: {epoch} accuracy: {t.accuracy:.4f} cal: {t.cal_error:.4f}")

    # Snapshot CPU copies so the returned dict does not alias the live modules.
    model_dir = args.get_model_dir(run, classification=True)
    model_dict: ModelSaveDict = {}
    model_dict[os.path.join(model_dir, "perturber.pt")] = copy.deepcopy(perturber).cpu().state_dict()
    for i, m in enumerate(models):
        model_dict[os.path.join(model_dir, f"model-{i}.pt")] = copy.deepcopy(m).cpu().state_dict()

    return model_dict


def eval(args: Args, run: int, model_dict: ModelSaveDict) -> ClassificationStatsTemp:
    """Rebuild the ensemble from ``model_dict`` and score it on ``args.test_set``.

    One model per index in ``range(args.model_n)`` is constructed, loaded from
    the entry keyed by ``<model_dir>/model-{i}.pt`` and switched to eval mode.
    The prediction passed to the statistics object is the mean of the members'
    raw outputs.

    NOTE: the function name shadows the builtin ``eval`` inside this module; it
    is left unchanged so existing callers keep working.

    Args:
        args: experiment configuration; must provide non-None ``train_set``
            and ``test_set``.
        run: run index, forwarded to ``args.get_model_dir`` to rebuild the keys.
        model_dict: mapping from save path to state dict, keyed the same way
            ``train`` builds its return value.

    Returns:
        The ``ClassificationStatsTemp`` instance after ``calc_stats()``.

    Raises:
        ValueError: if ``args.train_set`` or ``args.test_set`` is None.
        KeyError: if ``model_dict`` lacks an expected ``model-{i}.pt`` entry.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    # Loop-invariant: the architecture spec and save directory are the same for
    # every ensemble member (train() likewise shares one conv_args).
    conv_args = get_conv_args(args)
    model_dir = args.get_model_dir(run, classification=True)

    models = []
    for i in range(args.model_n):
        m = Model((args.x_dim, args.x_dim), args.h_dim, args.y_dim, conv_args).to(args.device)
        m.load_state_dict(model_dict[os.path.join(model_dir, f"model-{i}.pt")])
        m.eval()
        models.append(m)

    # NOTE(review): the meaning of the third argument (10) is defined by
    # ClassificationStatsTemp - presumably a bin count; confirm.
    t = ClassificationStatsTemp(args.test_set.dataset, args, 10)
    with torch.no_grad():
        for x, y in args.test_set:
            x = x.to(args.device, non_blocking=True)
            y = y.to(args.device, non_blocking=True)

            yhat = torch.stack([m(x) for m in models]).mean(dim=0)
            t.set_multiclass(y, yhat)

    t.calc_stats()
    return t
