import copy
import os

import torch

from models.mc_dropout import MCDropoutHeteroConv
from utils import (Args, ClassificationStats, ModelSaveDict, dropout_eval,
                   get_conv_args)


def train(args: Args, run: int) -> ModelSaveDict:
    """Train an MC-dropout classifier and return its final weights.

    The network is optimised with Adam on the negative log-likelihood of the
    log-softmax of its output. On every 10th epoch (starting with epoch 0) at
    most 100 batches of the test set are scored using ``args.samples``
    stochastic forward passes, and accuracy, calibration error and NLL are
    printed.

    Args:
        args: Experiment configuration. This function reads the ``train_set``
            and ``test_set`` iterables of ``(x, y)`` batches, the model
            dimensions, the optimiser settings, ``epochs``, ``samples``,
            ``batch_size`` and ``device``.
        run: Run index; only used to build the key under which the weights
            are stored.

    Returns:
        A dict with a single entry keyed by ``<model dir>/model.pt`` holding
        a CPU copy of the model's state dict after the last epoch.

    Raises:
        ValueError: If ``args.train_set`` or ``args.test_set`` is ``None``.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    model = MCDropoutHeteroConv(
        (args.x_dim, args.x_dim), args.h_dim, args.ps, args.y_dim, get_conv_args(args)
    ).to(args.device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    criterion = torch.nn.NLLLoss()

    models: ModelSaveDict = {}
    # NOTE(review): best_stats is updated below but never read, so it has no
    # influence on what is returned. The ``>`` comparison also looks inverted
    # for an NLL (lower is usually better) -- confirm the intent before
    # relying on it, e.g. for best-checkpoint selection.
    best_stats: ClassificationStats = ClassificationStats(args.test_set.dataset, args)

    for epoch in range(args.epochs):
        model.train()
        for x, y in args.train_set:
            x, y = x.to(args.device), y.to(args.device)

            mu = model(x)

            # NLLLoss expects log-probabilities, hence the explicit log_softmax.
            loss = criterion(torch.log_softmax(mu, dim=1), y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Only evaluate on every 10th epoch.
        if epoch % 10 != 0:
            continue

        # presumably puts the model in eval mode while keeping dropout
        # stochastic for MC sampling -- confirm in utils.dropout_eval
        dropout_eval(model)
        t = ClassificationStats(
            args.test_set.dataset, args, length=100 * args.batch_size * 10
        )
        with torch.no_grad():
            for i, (x, y) in enumerate(args.test_set):
                x, y = x.to(args.device), y.to(args.device)

                # Average the MC samples over the leading (sample) dimension.
                mus = model.mc(x, args.samples).mean(dim=0)
                t.set_multiclass(y, mus)

                # Cap the periodic evaluation at 100 batches.
                if i == 99:
                    break

        t.calc_stats()
        print(
            f"epoch: {epoch} acc: {t.accuracy:.4f}, cal: {t.cal_error:.4f} nll: {t.nll:.4f}"
        )
        if t.nll > best_stats.nll:
            best_stats = t

    # Store a detached CPU copy so the returned weights do not live on the
    # training device.
    models[
        os.path.join(f"{args.get_model_dir(run, classification=True)}", "model.pt")
    ] = (copy.deepcopy(model).cpu().state_dict())

    return models


def eval(args: Args, run: int, models: ModelSaveDict) -> ClassificationStats:
    """Score a saved MC-dropout classifier on the whole test set.

    A fresh model is built from ``args``, loaded with the weights stored in
    ``models`` under ``<model dir>/model.pt`` for this ``run``, and every
    batch of ``args.test_set`` is scored with ``args.samples`` stochastic
    forward passes.

    Args:
        args: Experiment configuration. This function reads ``test_set``, the
            model dimensions, ``samples`` and ``device``.
        run: Run index; used to build the key looked up in ``models``.
        models: Mapping from model path to state dict; it must contain the
            entry for this run.

    Returns:
        The ``ClassificationStats`` object after ``calc_stats()`` has been
        called on it.

    Raises:
        ValueError: If ``args.test_set`` is ``None``.
    """
    if args.test_set is None:
        raise ValueError("test set cannot be None")

    model = MCDropoutHeteroConv(
        (args.x_dim, args.x_dim), args.h_dim, args.ps, args.y_dim, get_conv_args(args)
    ).to(args.device)
    model.load_state_dict(
        models[
            os.path.join(f"{args.get_model_dir(run, classification=True)}", "model.pt")
        ]
    )

    # run the testing loop
    # presumably puts the model in eval mode while keeping dropout stochastic
    # for MC sampling -- confirm in utils.dropout_eval
    dropout_eval(model)
    t = ClassificationStats(args.test_set.dataset, args)
    with torch.no_grad():
        for x, y in args.test_set:
            x, y = x.to(args.device), y.to(args.device)

            # Average the MC samples over the leading (sample) dimension.
            mus = model.mc(x, args.samples).mean(dim=0)
            t.set_multiclass(y, mus)

    t.calc_stats()
    return t
