import os

import torch

from models.deep_ensemble import ModelConv as Model
from models.set_encoder import LatentPerturber
from models.swag import SWAG
from utils import (Args, ClassificationStatsTemp, ModelSaveDict,
                   adjust_learning_rate, get_conv_args, softmax_entropy)


def schedule(epoch: int, args: "Args") -> float:
    """Return the SGD learning rate for ``epoch`` (SWAG-style schedule).

    With ``t = epoch / args.swag_start`` the rate is:

    * ``args.swag_lr_init`` for ``t <= 0.5``;
    * linearly interpolated from ``args.swag_lr_init`` down to
      ``args.swag_lr`` for ``0.5 < t <= 0.9``;
    * ``args.swag_lr`` (computed as ``swag_lr_init * swag_lr / swag_lr_init``)
      for ``t > 0.9``.

    If ``args.swag_start <= 0`` there is no decay window. For example, ``train``
    sets it to ``epochs // 2``, which is 0 for a single-epoch run. In that case
    every epoch is treated as past the window and the final rate is returned,
    instead of raising ``ZeroDivisionError``.

    Args:
        epoch: zero-based epoch index.
        args: config providing ``swag_start``, ``swag_lr`` and ``swag_lr_init``.

    Returns:
        The learning rate to apply for this epoch.
    """
    lr_ratio = args.swag_lr / args.swag_lr_init
    if args.swag_start <= 0:
        # No pre-SWAG phase to decay over: use the final learning rate.
        return args.swag_lr_init * lr_ratio

    t = epoch / args.swag_start
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        # Linear ramp from 1.0 (at t = 0.5) down to lr_ratio (at t = 0.9).
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
        factor = lr_ratio
    return args.swag_lr_init * factor


def train(args: Args, run: int) -> ModelSaveDict:
    """Train a conv classifier jointly with a latent perturber and collect SWAG snapshots.

    For every training batch two updates run back to back:

    1. Classifier step (``optimizer``, SGD on ``model``). The loss is the NLL
       on the clean batch ``x`` plus a squared-error term. That term pulls
       ``exp`` of the last output column of
       ``model.phi(x, perturber, theta=True)`` towards ``TEMP ** w``. Here
       ``w = 1 - exp(-d / (2 * args.ls ** 2))`` and ``d`` is the minimum over
       dimension 0 of the ``dist`` tensor returned by ``model.phi``.
    2. Perturber step (``phi_opt``, Adam on ``perturber``). On a second batch
       ``x_phi`` it minimises ``softmax_entropy`` of the predictions plus the
       mean of the selected ``dist`` entries, minus ``h``. All three come from
       ``model.phi(..., theta=False)``.

    The learning rate follows ``schedule``. From epoch ``args.swag_start``
    onwards the current weights are added to the SWAG model after every
    epoch. Every 10th epoch, accuracy and calibration error on up to 100 test
    batches are printed.

    Args:
        args: experiment configuration. Side effect: ``args.swag_start`` is
            overwritten with ``args.epochs // 2``.
        run: run index used to build the output directory.

    Returns:
        Mapping from save path to CPU state dict for the SWAG model
        (``model.pt``) and the perturber (``perturber.pt``). Both modules are
        moved to CPU in the process.

    Raises:
        ValueError: if ``args.train_set`` or ``args.test_set`` is None.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test cannot be none")

    conv_args = get_conv_args(args)
    # Presumably conv_args[-1][1] is the channel width of the last conv layer,
    # i.e. the latent size the perturber acts on. TODO confirm via get_conv_args.
    perturber = LatentPerturber(conv_args[-1][1]).to(args.device)
    phi_opt = torch.optim.Adam(perturber.parameters(), lr=args.beta_lr)

    # The same constructor args are handed to SWAG so it can build a matching network.
    model_args = [(args.x_dim, args.x_dim), args.h_dim, args.y_dim, conv_args]
    model = Model(*model_args).to(args.device)  # type: ignore

    swag_model = SWAG(
        Model,
        args.device,
        no_cov_mat=False,
        max_num_models=args.model_n,
        model_args=model_args,
    ).to(args.device)

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.swag_lr_init,
        momentum=args.swag_momentum,
        weight_decay=args.weight_decay,
    )

    criterion = torch.nn.NLLLoss()
    # SWAG collection starts halfway through training; schedule() also reads this value.
    args.swag_start = args.epochs // 2

    # Upper bound of the temperature target TEMP ** w (w lies in [0, 1) when d >= 0).
    TEMP = 5
    for epoch in range(args.epochs):
        lr = schedule(epoch, args)
        adjust_learning_rate(optimizer, lr)
        model.train()
        # Two iterators over the same loader: (x, y) feeds the classifier step and
        # x_phi feeds the perturber step. They differ only if the loader shuffles
        # (TODO confirm). `i` is only used by the commented-out debug print.
        for i, ((x, y), (x_phi, _)) in enumerate(zip(args.train_set, args.train_set)):

            x, y = (
                x.to(args.device, non_blocking=True),
                y.to(args.device, non_blocking=True),
            )
            x_phi = x_phi.to(args.device, non_blocking=True)

            # The last output column is split off as a temperature logit; the
            # classification loss uses the remaining columns. `temp` is only
            # referenced by the commented-out debug print.
            yhat = model(x)
            yhat, temp = yhat[:, :-1], yhat[:, -1:]

            loss = criterion(torch.log_softmax(yhat, dim=1), y)

            # if epoch > args.epochs // 2:
            yhat_phi, _, _, dist = model.phi(x, perturber, theta=True)
            yhat_phi, temp_phi = yhat_phi[:, :-1], yhat_phi[:, -1]

            # Distance-based weight for the temperature target (not a KL divergence).
            # w = 0 at d = 0 and approaches 1 as d grows, with length scale args.ls.
            # d is detached, so no gradient flows through the weight itself.
            d, _ = dist.min(dim=0)
            w = 1 - torch.exp(-d.detach() / (2 * args.ls ** 2))  # type: ignore

            # Squared error between exp(temp_phi) and TEMP ** w, averaged over the batch.
            loss_phi = ((TEMP ** w - torch.exp(temp_phi)) ** 2).mean()

            # if i % 500 == 0:
            #     print(
            #         f"loss: {loss.item():.4f}, phi loss: {loss_phi.sum().item():.4f} "
            #         f"temp: min: {temp.min():.4f}, max: {temp.max():.4f}, "
            #         # f"dist min: {dist.min():.4f} max: {dist.max():.4f} "
            #         # f"tmp: min: {temp.min():.4f} max: {temp.max():.4f} "
            #         # f"tmp phi: min: {temp_phi.min():.4f} max: {temp_phi.max():.4f} "
            #         f"w: min {w.min():.4f} max {w.max():.4f} "
            #         f"d: {d.min():.4f} {d.max():.4f} "
            #     )

            loss += loss_phi

            # Classifier (theta) step. Any perturber gradients this backward may
            # produce are cleared by phi_opt.zero_grad() before the perturber step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Perturber (phi) step on the second batch.
            yhat_phi, h, subsets_idx, dist = model.phi(x_phi, perturber, theta=False)
            yhat_phi = yhat_phi[:, :-1]

            entropy = softmax_entropy(yhat_phi).mean()
            # Keep only the dist entries addressed by the two index arrays in subsets_idx.
            dist = dist[subsets_idx[0], subsets_idx[1]]

            # print(f"entropy: {entropy} dist: {dist.mean()} h: {h}")
            # Minimise prediction entropy and the selected distances; maximise h.
            loss_phi = entropy + dist.mean() - h

            # If model.phi does not detach the model's parameters, this backward
            # also leaves gradients on them. Those are cleared by
            # optimizer.zero_grad() before the next classifier step.
            phi_opt.zero_grad()
            loss_phi.backward()
            phi_opt.step()

        if epoch >= args.swag_start:
            swag_model.collect_model(model)
            # `epoch == 0` can only be true here when args.swag_start == 0.
            # NOTE(review): nothing in this function consumes the result of
            # sample(0.0); confirm whether this call is still needed.
            if epoch == 0 or epoch % 5 == 0 or epoch == args.epochs - 1:
                swag_model.sample(0.0)

        # Periodic sanity check on the first 100 test batches using the raw (non-SWAG) model.
        if epoch % 10 == 0:
            # NOTE(review): at most 100 batches are evaluated, but length is
            # 100 * batch_size * 10. Confirm the trailing * 10, and whether the
            # positional 10 (presumably the class count) should be args.y_dim,
            # against ClassificationStatsTemp.
            t = ClassificationStatsTemp(
                args.test_set.dataset, args, 10, length=100 * args.batch_size * 10,
            )
            # model.train() at the top of the next epoch restores training mode.
            model.eval()
            with torch.no_grad():
                for i, (x, y) in enumerate(args.test_set):
                    x, y = x.to(args.device), y.to(args.device)

                    yhat = model(x)
                    t.set_multiclass(y, yhat)

                    if i == 99:
                        break

            t.calc_stats()
            print(f"epoch: {epoch} accuracy: {t.accuracy:.4f} cal: {t.cal_error:.4f}")

    # Module.cpu() moves each module in place before its state dict is taken.
    models: ModelSaveDict = {
        os.path.join(
            args.get_model_dir(run, classification=True), "model.pt",
        ): swag_model.cpu().state_dict(),
        os.path.join(
            args.get_model_dir(run, classification=True), "perturber.pt",
        ): perturber.cpu().state_dict(),
    }

    return models


def eval(args: Args, run: int, models: ModelSaveDict) -> ClassificationStatsTemp:
    """Evaluate a trained SWAG classifier on the full ``args.test_set``.

    Rebuilds the SWAG wrapper with the same constructor arguments as ``train``.
    It then loads this run's ``model.pt`` weights from ``models`` and averages
    ``swag_model.mc_class`` predictions before scoring them.

    NOTE: this module-level name shadows the ``eval`` builtin inside this module.

    Args:
        args: experiment configuration; ``args.samples`` is forwarded to
            ``SWAG.mc_class``.
        run: run index used to locate the saved weights.
        models: mapping from save path to state dict, as returned by ``train``.

    Returns:
        The ``ClassificationStatsTemp`` instance after ``calc_stats()``.

    Raises:
        ValueError: if ``args.train_set`` or ``args.test_set`` is None.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test cannot be none")

    model_args = [(args.x_dim, args.x_dim), args.h_dim, args.y_dim, get_conv_args(args)]
    swag_model = SWAG(
        Model,
        args.device,
        no_cov_mat=False,
        max_num_models=args.model_n,
        model_args=model_args,
    ).to(args.device)

    model_path = os.path.join(args.get_model_dir(run, classification=True), "model.pt")
    swag_model.load_state_dict(models[model_path])

    # Predictions come from swag_model, so it is the module switched to eval
    # mode. A freshly constructed nn.Module starts in training mode, and
    # load_state_dict does not change that.
    # NOTE(review): if the network uses BatchNorm, sampled SWAG weights may need
    # their running statistics recomputed; confirm how mc_class handles this.
    swag_model.eval()

    t = ClassificationStatsTemp(args.test_set.dataset, args, 10)
    with torch.no_grad():
        for x, y in args.test_set:
            x, y = x.to(args.device), y.to(args.device)

            # Average the Monte-Carlo predictions over dim 0 (presumably one entry per sample).
            yhat = swag_model.mc_class(x, args.samples).mean(dim=0)
            t.set_multiclass(y, yhat)

    t.calc_stats()
    return t
