import os
import torch
import copy
import torch.nn as nn

from FishLeg.src.optim.FishLeg import *

from .metrics import *

# Optimizer classes selectable from a config, keyed by lowercase name.
OPTIMIZERS = dict(
    adam=torch.optim.Adam,
    fishleg=FishLeg,
)

# Loss / likelihood classes selectable from a config, keyed by lowercase name.
LOSSES = dict(
    fixedgaussianlikelihood=FixedGaussianLikelihood,
    crossentropy=nn.CrossEntropyLoss,
    bernoullilikelihood=BernoulliLikelihood,
    softmaxlikelihood=SoftMaxLikelihood,
)

# Accuracy metric functions, keyed by lowercase name ("class" is a Python
# keyword, so this registry is written as a literal).
ACCURACY = {
    "mse": MSE_accuracy,
    "class": class_accuracy,
}


def get_off_diagonal_elements(M):
    """Return a copy of ``M`` whose main diagonal (over the last two dims) is zero.

    The input tensor is not modified; any leading dimensions are kept and the
    diagonal of every trailing matrix is zeroed.
    """
    out = torch.clone(M)
    # torch.diagonal returns a view, so zeroing it writes into ``out``.
    torch.diagonal(out, offset=0, dim1=-2, dim2=-1).zero_()
    return out


def _build_criterion(loss_func, device):
    """Instantiate the loss named by ``loss_func["name"]`` from ``LOSSES``."""
    name = loss_func["name"].lower()
    if name == "crossentropy":
        # nn.CrossEntropyLoss takes no ``device`` keyword; the project
        # likelihood classes are constructed with one.
        return LOSSES[name](**loss_func["args"])
    return LOSSES[name](**loss_func["args"], device=device)


def _build_optimizer(model, opt_name, optimizer, criterion, aux_loader, device):
    """Create the optimizer selected by ``opt_name`` with ``optimizer["args"]``."""
    if opt_name.lower() == "fishleg":

        def nll(model, data):
            # Negative log-likelihood of the observed labels under the model.
            data_x, data_y = data
            pred_y = model.forward(data_x)
            return criterion.nll(data_y, pred_y)

        def draw(model, data):
            # Replace the labels with samples drawn from the model's likelihood.
            data_x, _ = data
            pred_y = model.forward(data_x)
            return (data_x, criterion.draw(pred_y))

        return FishLeg(
            model,
            draw,
            nll,
            aux_loader,
            likelihood=criterion,
            **optimizer["args"],
            device=device,
        )
    return OPTIMIZERS[opt_name.lower()](model.parameters(), **optimizer["args"])


def _evaluate(model, loader, criterion, accuracy_fn, device):
    """Return the mean per-batch ``(loss, accuracy)`` of ``model`` over ``loader``.

    Runs in eval mode under ``torch.no_grad()``. No gradients are needed here,
    and building autograd graphs would only waste memory.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    n = 0
    with torch.no_grad():
        for n, (batch_data, batch_labels) in enumerate(loader, start=1):
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)
            preds = model.forward(batch_data)
            total_loss += criterion(preds, batch_labels).item()
            total_acc += accuracy_fn(preds, batch_labels).item()
    if n == 0:
        raise ValueError("Evaluation loader yielded no batches.")
    return total_loss / n, total_acc / n


def _train_epoch(model, loader, criterion, opt, accuracy_fn, device, track_aux):
    """Run one optimisation pass over ``loader``.

    Returns the mean per-batch ``(loss, accuracy, aux_loss)``. ``aux_loss`` is
    0.0 unless ``track_aux`` is set, in which case ``opt.aux_loss`` is
    accumulated after each step.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    total_aux = 0
    n = 0
    for n, (batch_data, batch_labels) in enumerate(loader, start=1):
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)
        opt.zero_grad()

        preds = model.forward(batch_data)
        loss = criterion(preds, batch_labels)
        loss.backward()
        opt.step()

        batch_loss = loss.item()
        # Accuracy uses the predictions made before this optimiser step.
        batch_acc = accuracy_fn(preds, batch_labels).item()
        total_loss += batch_loss
        total_acc += batch_acc
        if track_aux:
            # NOTE(review): if opt.aux_loss is a tensor attached to a graph,
            # accumulating it keeps that graph alive — confirm its type.
            total_aux += opt.aux_loss

        if n % 50 == 0:
            print(
                f"Batch {n}: Train Loss: {batch_loss:.3f}, Train Accuracy: {batch_acc:.3f}"
            )
            if track_aux:
                print("Auxiliary Loss", opt.aux_loss)

    if n == 0:
        raise ValueError("Training loader yielded no batches.")
    return total_loss / n, total_acc / n, total_aux / n


def _save_checkpoint(model_save_path, epoch, model, opt, train_loss, test_loss):
    """Write model/optimizer state to ``model_checkpoint_best.pth`` under ``model_save_path``."""
    if model_save_path:
        # Create the directory up front so the save does not fail mid-run.
        os.makedirs(model_save_path, exist_ok=True)
    save_path = os.path.join(model_save_path, "model_checkpoint_best.pth")
    with open(save_path, mode="wb") as file_:
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opt.state_dict(),
                "train_loss": train_loss,
                "test_loss": test_loss,
            },
            file_,
        )


def _print_epoch(epoch, epochs, train_loss, test_loss, train_acc, test_acc):
    """Print the per-epoch summary banner."""
    print("--------------------")
    print("Epoch {}/{}".format(epoch, epochs))
    print(
        "Train Loss: {:.3f}, Test Loss: {:.3f}, Train Accuracy: {:.3f}, Test Accuracy: {:.3f}".format(
            train_loss,
            test_loss,
            train_acc,
            test_acc,
        )
    )


def pretrain(
    model,
    train_loader,
    aux_loader,
    test_loader,
    optimizer,
    loss_func,
    epochs,
    accuracy,
    device,
    model_save_path,
    opt_state=None,
):
    """Train ``model`` for ``epochs`` epochs, evaluating on train/test data each epoch.

    Args:
        model: module whose ``forward`` maps a batch of inputs to predictions.
        train_loader: iterable of ``(inputs, labels)`` batches used for training.
        aux_loader: passed through to ``FishLeg`` (unused by other optimizers).
        test_loader: iterable of ``(inputs, labels)`` batches used for evaluation.
        optimizer: config dict with ``"name"`` (a key of ``OPTIMIZERS``,
            case-insensitive) and ``"args"`` (constructor keyword arguments).
            It is not modified.
        loss_func: config dict with ``"name"`` (a key of ``LOSSES``,
            case-insensitive) and ``"args"`` (constructor keyword arguments).
        epochs: number of training epochs to run after the initial evaluation.
        accuracy: key of ``ACCURACY`` (case-insensitive) selecting the metric.
        device: device that batches are moved to (and passed to the loss/FishLeg).
        model_save_path: directory for ``model_checkpoint_best.pth``, written
            whenever the test loss improves on every previous epoch (including
            the initial, untrained evaluation).
        opt_state: optional optimizer state dict to restore before training.

    Returns:
        ``(scores, optimizer_state_dict)``. ``scores`` holds per-epoch lists.
        ``"Epochs"`` and the train/test lists have ``epochs + 1`` entries
        (epoch 0 is the initial evaluation). ``"aux_losses"`` has only
        ``epochs`` entries.

    Raises:
        KeyError: if a configured name is not in its registry.
        ValueError: if a loader yields no batches.
    """
    # Read (rather than pop) the name so the caller's config dict is left
    # intact; popping made a second call with the same dict raise KeyError.
    opt_name = optimizer["name"]
    use_fishleg = opt_name.lower() == "fishleg"
    accuracy_fn = ACCURACY[accuracy.lower()]

    criterion = _build_criterion(loss_func, device)
    opt = _build_optimizer(model, opt_name, optimizer, criterion, aux_loader, device)

    if opt_state:
        opt.load_state_dict(opt_state)

    print("--------------------")
    print("Pre-training Phase")
    print("--------------------")

    # Initial (epoch 0) evaluation before any parameter update.
    train_loss, train_acc = _evaluate(model, train_loader, criterion, accuracy_fn, device)
    test_loss, test_acc = _evaluate(model, test_loader, criterion, accuracy_fn, device)

    train_losses = [train_loss]
    train_accs = [train_acc]
    test_losses = [test_loss]
    test_accs = [test_acc]
    aux_losses = []

    _print_epoch(0, epochs, train_loss, test_loss, train_acc, test_acc)

    ## TRAINING LOOP
    for e in range(1, epochs + 1):
        train_loss, train_acc, aux_loss = _train_epoch(
            model, train_loader, criterion, opt, accuracy_fn, device, use_fishleg
        )
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        aux_losses.append(aux_loss)

        test_loss, test_acc = _evaluate(model, test_loader, criterion, accuracy_fn, device)

        _print_epoch(e, epochs, train_loss, test_loss, train_acc, test_acc)
        if use_fishleg:
            print("Auxiliary Loss", aux_loss)

        # Compare before appending so "best" means better than all previous epochs.
        if test_loss < min(test_losses):
            _save_checkpoint(model_save_path, e, model, opt, train_loss, test_loss)

        test_losses.append(test_loss)
        test_accs.append(test_acc)

    scores = {
        "Epochs": list(range(0, epochs + 1)),
        "train_losses": train_losses,
        "train_accuracy": train_accs,
        "test_losses": test_losses,
        "test_accuracy": test_accs,
        "aux_losses": aux_losses,
    }

    return scores, opt.state_dict()
