import copy
import os

import torch
from torch import nn

from models.r1bnn import Conv2dRank1, Linear
from models.r1bnn import ModelConv as Model
from models.r1bnn import RankOneBayesianVector, RankOneBayesianVectorConv
from utils import Args, ClassificationStats, ModelSaveDict, get_conv_args


def train(args: Args, run: int) -> ModelSaveDict:
    """Train a ``ModelConv`` on ``args.train_set`` and return its saved weights.

    The model's parameters are split into two Adam parameter groups: the
    "shared" group (inner ``conv2d`` of every ``Conv2dRank1`` plus every
    ``Linear`` layer) receives ``args.weight_decay``, while the "bnn" group
    (the ``s_vector``/``r_vector`` of every ``Conv2dRank1`` and any top-level
    ``RankOneBayesianVector``/``RankOneBayesianVectorConv``) does not.

    Each training step draws one batch from each of ``args.model_n``
    simultaneous iterators over ``args.train_set``, stacks them, and minimises
    the NLL of the log-softmax outputs plus ``args.kl_beta * model.kl()``.
    Every 10th epoch (starting at epoch 0) the model is scored on
    ``args.test_set`` and the accuracy is printed.

    Args:
        args: Run configuration; must provide non-``None`` ``train_set`` and
            ``test_set``.
        run: Run index, forwarded to ``args.get_model_dir``.

    Returns:
        A dict with a single entry mapping
        ``<args.get_model_dir(run, classification=True)>/model.pt`` to the CPU
        ``state_dict`` of the model after the final epoch.

    Raises:
        ValueError: If either dataset is ``None`` or the model has a top-level
            child of an unrecognised type.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    model = Model(
        (args.x_dim, args.x_dim),
        args.h_dim,
        args.y_dim,
        args.model_n,
        get_conv_args(args),
    ).to(args.device)

    # Partition parameters so that weight decay is only applied to the shared
    # group (see the optimizer construction below).
    shared, bnns = [], []
    for layer in model.children():
        if isinstance(layer, nn.Sequential):
            # NOTE(review): sublayers that are not Conv2dRank1 are skipped, so
            # any parameters they own are not handed to the optimizer.
            for sublayer in layer.children():
                if isinstance(sublayer, Conv2dRank1):
                    shared.extend(sublayer.conv2d.parameters())
                    bnns.extend(sublayer.s_vector.parameters())
                    bnns.extend(sublayer.r_vector.parameters())
        elif isinstance(layer, Linear):
            shared.extend(layer.parameters())
        elif isinstance(layer, (RankOneBayesianVector, RankOneBayesianVectorConv)):
            bnns.extend(layer.parameters())
        else:
            raise ValueError(f"got an unknown layer type: {layer} {type(layer)}")

    optimizer = torch.optim.Adam(
        [
            {"params": shared, "lr": args.lr, "weight_decay": args.weight_decay},
            {"params": bnns, "lr": args.lr},
        ]
    )

    criterion = torch.nn.NLLLoss()

    # NOTE(review): best_stats is tracked but is not part of the return value;
    # the weights saved below are those of the final epoch.
    best_stats: ClassificationStats = ClassificationStats(args.test_set.dataset, args)

    for epoch in range(args.epochs):
        model.train()
        # Iterate the training loader ``model_n`` times in lockstep so every
        # step receives ``model_n`` batches.
        for data in zip(*[args.train_set for _ in range(args.model_n)]):
            x = torch.stack([xb.to(args.device) for (xb, _) in data])
            y = torch.stack([yb.to(args.device) for (_, yb) in data])

            # The stacked input is 5-D; the two leading axes are folded into
            # one before the forward pass.
            _, _, ch, length, width = x.size()
            yhat = model(x.view(-1, ch, length, width))

            # NLL on log-softmax outputs plus the weighted KL term.
            log_probs = torch.log_softmax(yhat.view(-1, args.y_dim), dim=1)
            loss = criterion(log_probs, y.view(-1)) + args.kl_beta * model.kl()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            model.eval()
            t = ClassificationStats(args.test_set.dataset, args)
            with torch.no_grad():
                for x, y in args.test_set:
                    x, y = x.to(args.device), y.to(args.device)

                    _, ch, length, width = x.size()
                    # Average ``model.mc`` output over its leading axis
                    # (presumably the ``args.samples`` draws -- confirm).
                    mus = model.mc(
                        x.view(-1, ch, length, width), args.samples
                    ).mean(dim=0)
                    t.set_multiclass(y, mus)

            t.calc_stats()
            print(f"epoch: {epoch} acc: {t.accuracy:.4f}")

            # Only compare when fresh stats exist; previously this ran every
            # epoch against a stale ``t`` from the last evaluation.
            if t.accuracy > best_stats.accuracy:
                best_stats = t

    model.zero_kl()
    save_path = os.path.join(
        f"{args.get_model_dir(run, classification=True)}", "model.pt"
    )
    # Deep-copy before moving to CPU so the live model stays on its device.
    models: ModelSaveDict = {save_path: copy.deepcopy(model).cpu().state_dict()}
    return models


def eval(args: Args, run: int, models: ModelSaveDict) -> ClassificationStats:
    """Rebuild the model for ``run``, load its saved weights and score the test set.

    Args:
        args: Run configuration; must provide non-``None`` ``train_set`` and
            ``test_set``.
        run: Run index, forwarded to ``args.get_model_dir`` to locate the
            entry in ``models``.
        models: Mapping from model path to state dict, as returned by
            ``train``.

    Returns:
        The ``ClassificationStats`` accumulated over ``args.test_set`` after
        ``calc_stats()`` has been called on it.

    Raises:
        ValueError: If either dataset is ``None``.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    input_shape = (args.x_dim, args.x_dim)
    net = Model(input_shape, args.h_dim, args.y_dim, args.model_n, get_conv_args(args))
    net = net.to(args.device)

    # Weights are stored under "<model dir>/model.pt".
    checkpoint_key = os.path.join(
        args.get_model_dir(run, classification=True), "model.pt"
    )
    net.load_state_dict(models[checkpoint_key])
    net.eval()

    stats = ClassificationStats(args.test_set.dataset, args)
    with torch.no_grad():
        for inputs, targets in args.test_set:
            inputs = inputs.to(args.device)
            targets = targets.to(args.device)

            # Average the ``mc`` output over its leading axis before scoring.
            mean_pred = net.mc(inputs, args.samples).mean(dim=0)
            stats.set_multiclass(targets, mean_pred)

    stats.calc_stats()
    return stats
