import copy
import os

import torch
from torch import nn

from models.r1bnn import Conv2dRank1, Linear
from models.r1bnn import ModelConv as Model
from models.r1bnn import RankOneBayesianVector, RankOneBayesianVectorConv
from models.set_encoder import LatentPerturber
from utils import (Args, ClassificationStatsTemp, ModelSaveDict, get_conv_args,
                   softmax_entropy)


def _split_param_groups(model: nn.Module) -> tuple:
    """Partition ``model``'s parameters into ``(bnns, shared)`` optimizer groups.

    Only the direct children of ``model`` are inspected:

    * ``nn.Sequential`` child: for every ``Conv2dRank1`` sublayer, the
      ``conv2d`` parameters go to ``shared`` and the ``s_vector`` then
      ``r_vector`` parameters go to ``bnns``. Other sublayers are skipped, so
      their parameters (if any) land in neither group.
    * ``Linear`` child: all parameters go to ``shared``.
    * ``RankOneBayesianVector`` / ``RankOneBayesianVectorConv`` child: all
      parameters go to ``bnns``.

    Raises:
        ValueError: if a direct child is of any other type.
    """
    bnns, shared = [], []
    for layer in model.children():
        if isinstance(layer, nn.Sequential):
            for sublayer in layer.children():
                if isinstance(sublayer, Conv2dRank1):
                    shared += list(sublayer.conv2d.parameters())
                    bnns += list(sublayer.s_vector.parameters())
                    bnns += list(sublayer.r_vector.parameters())
        elif isinstance(layer, Linear):
            shared += list(layer.parameters())
        elif isinstance(layer, (RankOneBayesianVector, RankOneBayesianVectorConv)):
            bnns += list(layer.parameters())
        else:
            raise ValueError(f"got an unknown layer type: {layer} {type(layer)}")
    return bnns, shared


def train(args: Args, run: int) -> ModelSaveDict:
    """Jointly train the rank-one BNN ``Model`` and a ``LatentPerturber``.

    Each step zips ``2 * args.model_n`` references to ``args.train_set`` and
    gets a tuple of that many ``(x, y)`` batches. The first ``model_n``
    batches drive the model update and the remaining ``model_n`` inputs drive
    the perturber update:

    1. Model step (``optimizer``): NLL on the class outputs (every output
       column except the last), plus ``args.kl_beta * model.kl()``, plus a
       temperature term. That term is the mean squared error between
       ``exp(temp_phi)`` and ``TEMP ** w``. Here ``temp_phi`` is the last
       output column of ``model.phi(..., theta=True)``,
       ``w = 1 - exp(-d / (2 * args.ls ** 2))``, and ``d`` is the min over
       dim 0 of the distances that call returns. The term is scaled by
       ``1 / args.model_n``.
    2. Perturber step (``phi_opt``): minimise ``entropy + dist.mean() - h``,
       all taken from ``model.phi(x_phi, perturber, theta=False)``.

    On epochs divisible by 10 (including epoch 0), accuracy and calibration
    error on the first 100 test batches are printed.

    Args:
        args: run configuration; ``train_set`` and ``test_set`` must be set.
        run: run index, used to build the save directory.

    Returns:
        Mapping from ``<model_dir>/model.pt`` and ``<model_dir>/perturber.pt``
        to CPU copies of the final model and perturber state dicts.

    Raises:
        ValueError: if ``args.train_set`` or ``args.test_set`` is None, or the
            model has a direct child of a type ``_split_param_groups`` does
            not handle.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    conv_args = get_conv_args(args)
    perturber = LatentPerturber(conv_args[-1][1]).to(args.device)
    phi_opt = torch.optim.Adam(perturber.parameters(), lr=args.beta_lr)

    model = Model(
        (args.x_dim, args.x_dim), args.h_dim, args.y_dim, args.model_n, conv_args,
    ).to(args.device)

    bnns, shared = _split_param_groups(model)
    optimizer = torch.optim.Adam(
        [
            # weight decay is applied to the shared weights only
            {"params": shared, "lr": args.lr, "weight_decay": args.weight_decay},
            {"params": bnns, "lr": args.lr},
        ]
    )

    criterion = torch.nn.NLLLoss()

    models: ModelSaveDict = {}
    # NOTE(review): best_stats is updated below but never returned or saved;
    # the model stored in ``models`` is the final one, not the best one.
    best_stats: ClassificationStatsTemp = ClassificationStatsTemp(
        args.test_set.dataset, args, args.y_dim
    )

    TEMP = 5  # exp(temp_phi) is regressed towards TEMP ** w
    for epoch in range(args.epochs):
        model.train()
        for data in zip(*[args.train_set for _ in range(args.model_n * 2)]):
            # fmt: off
            x = torch.stack([x.to(args.device, non_blocking=True) for (x, _) in data[: args.model_n]])
            y = torch.stack([y.to(args.device, non_blocking=True) for (_, y) in data[: args.model_n]])
            x_phi = torch.stack([x.to(args.device, non_blocking=True) for (x, _) in data[args.model_n :]])
            # fmt: on

            # regular loss: fold the model_n axis into the batch axis
            _, _, ch, ln, wd = x.size()
            flat_x = x.view(-1, ch, ln, wd)
            # the last output column is the temperature output; drop it here
            yhat = model(flat_x).view(-1, args.y_dim)[:, :-1]

            loss = (
                criterion(torch.log_softmax(yhat, dim=1), y.view(-1))
                + args.kl_beta * model.kl()
            )

            # temperature loss on the theta=True perturbed pass
            yhat_phi, _, _, dist = model.phi(flat_x, perturber, theta=True)
            temp_phi = yhat_phi.view(-1, args.y_dim)[:, -1]

            d, _ = dist.min(dim=0)
            w = 1 - torch.exp(-d.detach() / (2 * args.ls ** 2))  # type: ignore

            # already a scalar: mean over every element
            loss_phi = (((TEMP ** w) - torch.exp(temp_phi)) ** 2).mean()
            loss += loss_phi / args.model_n

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # perturber step on the second half of the zipped batches
            yhat_phi, h, subsets_idx, dist = model.phi(
                x_phi.view(-1, ch, ln, wd), perturber, theta=False
            )
            yhat_phi = yhat_phi.view(-1, args.y_dim)[:, :-1]

            entropy = softmax_entropy(yhat_phi, dim=1).mean()
            dist = dist[subsets_idx[0], subsets_idx[1]]

            # Only phi_opt steps here. Any gradient this backward leaves on the
            # model's parameters is discarded by optimizer.zero_grad() before
            # the next model backward.
            phi_loss = entropy + dist.mean() - h

            phi_opt.zero_grad()
            phi_loss.backward()
            phi_opt.step()

        if epoch % 10 == 0:
            model.eval()
            # NOTE(review): the class count 10 is hard-coded here, while
            # best_stats uses args.y_dim. The meaning of the trailing "* 10" in
            # ``length`` is not visible from this file. Confirm both.
            t = ClassificationStatsTemp(
                args.test_set.dataset, args, 10, length=100 * args.batch_size * 10,
            )
            with torch.no_grad():
                for i, (x, y) in enumerate(args.test_set):
                    x, y = x.to(args.device), y.to(args.device)

                    mus = model.mc(x, args.samples).mean(dim=0)
                    t.set_multiclass(y, mus)

                    if i == 99:  # quick check: first 100 test batches only
                        break

            t.calc_stats()
            print(f"epoch: {epoch} acc: {t.accuracy:.4f} cal: {t.cal_error:.4f}")

            if t.accuracy > best_stats.accuracy:
                best_stats = t

    model.zero_kl()
    model_dir = f"{args.get_model_dir(run, classification=True)}"
    # fmt: off
    models[os.path.join(model_dir, "model.pt")] = copy.deepcopy(model).cpu().state_dict()
    models[os.path.join(model_dir, "perturber.pt")] = copy.deepcopy(perturber).cpu().state_dict()
    # fmt: on
    return models


def eval(args: Args, run: int, models: ModelSaveDict) -> ClassificationStatsTemp:
    """Score the stored ``model.pt`` weights for ``run`` on the whole test set.

    Rebuilds ``Model`` from ``args`` and loads the state dict stored in
    ``models`` under ``<model_dir>/model.pt`` (the key that ``train`` writes).
    For every test batch it feeds the mean over dim 0 of ``model.mc(...)``
    (presumably the MC-sample axis; confirm) into the stats object.

    NOTE: the name shadows the builtin ``eval`` inside this module.

    Args:
        args: run configuration. Both ``train_set`` and ``test_set`` must be
            set, although only ``test_set`` is read.
        run: run index used to locate the saved weights.
        models: mapping from checkpoint paths to state dicts.

    Returns:
        The ``ClassificationStatsTemp`` after ``calc_stats()`` has run.

    Raises:
        ValueError: if ``args.train_set`` or ``args.test_set`` is None.
        KeyError: (for a plain dict) if ``models`` lacks this run's
            ``model.pt`` key.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    net = Model(
        (args.x_dim, args.x_dim), args.h_dim, args.y_dim, args.model_n, get_conv_args(args)
    ).to(args.device)
    checkpoint_key = os.path.join(
        args.get_model_dir(run, classification=True), "model.pt"
    )
    net.load_state_dict(models[checkpoint_key])
    net.eval()

    stats = ClassificationStatsTemp(args.test_set.dataset, args, 10)
    with torch.no_grad():
        for inputs, targets in args.test_set:
            inputs = inputs.to(args.device)
            targets = targets.to(args.device)
            averaged = net.mc(inputs, args.samples).mean(dim=0)
            stats.set_multiclass(targets, averaged)

    stats.calc_stats()
    return stats
