"""Experiment script to estimate the bias, variance, and diversity terms for
the CIFAR10 CNN models in the standard parameterization.
"""
from __future__ import annotations

import logging
import os
from typing import Any

import torch
from torch.nn import Softmax, CrossEntropyLoss, KLDivLoss, LogSoftmax
from torch.utils.data import DataLoader

from data_proc.cifar10 import PrepareCorruptCIFAR
from helpers.hooks import GetHookVals
from env.user import PROJECT_PATH
from helpers.saving import save_to_pickle
from helpers.logger import get_logger

# Module-level logger from the project helper; INFO level so the per-batch
# progress messages logged by the estimation loops are emitted.
logger = get_logger()
logger.setLevel(logging.INFO)


def load_model_and_weights(to_ckpt: str, map_location: str = "cuda:0"):
    """Load a pickled PyTorch model and return it with its last parameter.

    Args:
        to_ckpt: Path to a checkpoint holding a whole pickled module (the
            loaded object must support ``named_parameters()``).
        map_location: Device the checkpoint tensors are mapped onto. Defaults
            to ``"cuda:0"``, the value this function always used.

    Returns:
        ``(model, weights)`` where ``weights`` is the last entry of
        ``model.named_parameters()`` with gradient tracking disabled. Callers
        use it as the final layer's ``(num_classes, subpredictors)`` weight
        matrix, which only holds if that layer has no bias -- TODO confirm.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    model = torch.load(to_ckpt, weights_only=False, map_location=map_location)

    # the last registered parameter is taken to be the final layer's weights
    _, weights = list(model.named_parameters())[-1]
    weights.requires_grad = False

    return model, weights


def activations_matrix(
        model: torch.nn.Module,
        x: torch.Tensor,
) -> dict[str, Any]:
    """Capture the features fed into the model's final layer for input ``x``.

    A forward hook is attached to the module that is second-to-last in
    ``model.named_modules()``. That module's name must end with an integer
    index into ``model.net`` (e.g. ``"net.7"``). The model is run once in eval
    mode without gradient tracking, and the hook is always removed afterwards,
    even if the forward pass raises.

    Args:
        model: Network exposing an indexable ``net`` container.
        x: Input batch, already on the model's device.

    Returns:
        The hook helper's ``activation`` mapping; the captured values are
        stored under the key ``"flatten"`` (presumably the hooked module's
        output -- see ``GetHookVals.getActivation``).
    """
    hook_vals = GetHookVals()

    # the penultimate module is assumed to be the one feeding the final layer
    penultimate_name = list(model.named_modules())[-2][0]
    layer_idx = int(penultimate_name.split(".")[-1])
    handle = model.net[layer_idx].register_forward_hook(
        hook_vals.getActivation("flatten"),
    )

    try:
        with torch.no_grad():
            model.eval()
            model.forward(x)
    finally:
        # never leave a stale hook on the model, even if forward() fails
        handle.remove()

    return hook_vals.activation

def combiner(
        subpredictors: int,
        test_preds: torch.Tensor,
) -> torch.Tensor:
    """Average the subpredictors' class log-probabilities.

    ``test_preds`` holds raw logits with classes on axis 1 and subpredictors
    on axis 2 (the caller lays it out as (trial, class, subpredictor, sample)).
    Logits are log-softmaxed over the class axis and then averaged over the
    subpredictor axis, which removes that axis. The result is left
    unnormalized; callers apply a softmax themselves.
    """
    to_log_probs = LogSoftmax(dim=1)
    log_probs = to_log_probs(test_preds)
    summed = log_probs.sum(dim=2)
    return summed * (1 / subpredictors)

def estimate_centroid(
        trials: int,
        test_preds: torch.Tensor,
) -> torch.Tensor:
    """Return every subpredictor's centroid in unnormalized log space.

    ``test_preds`` holds raw logits with trials on axis 0 and classes on
    axis 1. The logits are log-softmaxed over classes and then averaged over
    trials, which drops the trial axis. The result is not renormalized; the
    caller's softmax / cross-entropy takes care of that.
    """
    per_trial_log_probs = LogSoftmax(dim=1)(test_preds)
    # mean over trials, written as sum * (1/trials)
    return per_trial_log_probs.sum(dim=0) * (1 / trials)

def estimate_bias_var_div(
        trials: int,
        subpredictors: int,
        num_classes: int,
        testset: torch.utils.data.Dataset,
        experiment_name: str,
        batch_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Estimate the bias, variance, and diversity of a neural network.

    Each checkpoint whose file name contains ``"best"`` under
    ``PROJECT_PATH/models/<experiment_name>/<subpredictors>`` is one trial.
    For every trial, the final layer is split into ``subpredictors``
    subpredictors with logits ``subpredictors * w[c, i] * h[i]`` (final-layer
    weight times captured input feature). From these the function estimates:

    * bias: cross-entropy between the label and each subpredictor's centroid
      (its log-probabilities averaged over trials);
    * variance: KL(centroid || subpredictor), averaged over trials;
    * diversity: KL(combiner || subpredictor), where the combiner averages
      the log-probabilities of all subpredictors within a trial.

    Computation is designed to run on a GPU. If a memory allocation error is
    received, tune the batch size.

    Args:
        trials: Number of "best" checkpoints expected on disk.
        subpredictors: Number of subpredictors (the final layer's input
            width); also the name of the checkpoint sub-folder.
        num_classes: Number of output classes.
        testset: Dataset yielding ``(x, y)`` pairs, with ``y`` class targets
            accepted by ``CrossEntropyLoss``.
        experiment_name: Sub-folder of ``PROJECT_PATH/models``.
        batch_size: Test-set batch size.

    Returns:
        ``(bias, variance, diversity, subpredictors)``. The three terms are
        float64 scalar tensors on the GPU, averaged over test samples and
        subpredictors (variance and diversity also over trials).

    Raises:
        ValueError: If the number of "best" checkpoints differs from ``trials``.
    """
    # retrieve model files ("best" checkpoints only, in a deterministic order)
    model_folder = os.path.join(
        PROJECT_PATH, "models", experiment_name, str(subpredictors),
    )
    dir_contents = sorted(
        name for name in os.listdir(path=model_folder) if "best" in name
    )
    if len(dir_contents) != trials:
        # Too few would leave all-zero (uniform) rows in test_preds and silently
        # skew every term; too many would overflow its trial axis.
        raise ValueError(
            f"Expected {trials} 'best' checkpoints in {model_folder}, "
            f"found {len(dir_contents)}."
        )

    testloader = DataLoader(
        dataset=testset,
        batch_size=batch_size,
        shuffle=False,
    )

    # Accumulate in float64: the summed losses are float64, and a float32
    # accumulator would round every in-place addition (check_normal_decomposition
    # also accumulates in float64, so the two estimates stay comparable).
    bias_term = torch.tensor(0.0, device="cuda", dtype=torch.float64)
    variance_term = torch.tensor(0.0, device="cuda", dtype=torch.float64)
    diversity_term = torch.tensor(0.0, device="cuda", dtype=torch.float64)

    ce_loss = CrossEntropyLoss(reduction="none")
    # KLDivLoss(input=log q, target=p) computes p * (log p - log q), i.e. KL(p || q)
    kl_loss = KLDivLoss(reduction="none")
    sm_est = Softmax(dim=1)
    sm_comb = Softmax(dim=2)
    lsm_sub = LogSoftmax(dim=2)
    for batch, (x, y) in enumerate(testloader):
        logger.info(f"Estimating bias, variance, and diversity: computing batch {batch} out of {len(testloader) - 1}")

        # raw subpredictor logits, indexed (trial, class, subpredictor, sample)
        test_preds = torch.zeros(
            trials,
            num_classes,
            subpredictors,
            x.shape[0],
            device="cuda",
            dtype=torch.float64,
        )

        x = x.to("cuda")
        y = y.to("cuda")
        for k, trial in enumerate(dir_contents):
            # retrieve model and final-layer weights trained on this trial set
            to_ckpt = os.path.join(model_folder, trial)
            model, weights_comb = load_model_and_weights(to_ckpt=to_ckpt)
            weights_comb = weights_comb.double()

            # features entering the final layer, in float64 to limit rounding
            activations = activations_matrix(model=model, x=x)
            features = activations["flatten"].double()

            # (classes, subpreds, 1) * (subpreds, batch) -> (classes, subpreds, batch).
            # Scaling by the subpredictor count makes the mean over subpredictors
            # equal sum_i w[c, i] * h[i], i.e. the final-layer logit if the head
            # is a bias-free linear map (TODO confirm).
            subpred_logits = weights_comb.unsqueeze(dim=2) * features.T
            subpred_logits *= subpredictors
            test_preds[k] = subpred_logits  # we save the unnormalized logits

        # per-subpredictor centroid: (sample, subpredictor, class)
        centroid_est = torch.permute(
            estimate_centroid(trials=trials, test_preds=test_preds),
            dims=(2, 1, 0),
        )
        # per-trial combiner: (trial, sample, class), normalized over classes
        combine = torch.permute(
            combiner(subpredictors=subpredictors, test_preds=test_preds),
            dims=(0, 2, 1),
        )
        prob_combine = sm_comb(combine)

        for predictor in range(subpredictors):
            centroid_p = centroid_est[:, predictor, :]
            bias_term += torch.sum(ce_loss(centroid_p, y))

            # this subpredictor's log-probabilities per trial: (trial, sample, class)
            log_subpred = lsm_sub(
                torch.permute(test_preds[:, :, predictor, :], dims=(0, 2, 1)),
            )

            # variance: KL(centroid || subpredictor), centroid broadcast over trials
            prob_est = sm_est(centroid_p)
            variance_term += torch.sum(
                kl_loss(log_subpred, prob_est), dim=(0, 1, 2),
            )

            # diversity: KL(combiner || subpredictor)
            diversity_term += torch.sum(
                kl_loss(log_subpred, prob_combine), dim=(0, 1, 2),
            )

    n_test = len(testset)
    bias_term *= (1/subpredictors) * (1/n_test)
    variance_term *= (1/subpredictors) * (1/n_test) * (1/trials)
    diversity_term *= (1/subpredictors) * (1/n_test) * (1/trials)

    return bias_term, variance_term, diversity_term, subpredictors

def check_normal_decomposition(
    trials: int,
    subpredictors: int,
    num_classes: int,
    testset: torch.utils.data.Dataset,
    experiment_name: str,
    batch_size: int = 1024,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Estimate the usual bias and variance terms of the full models.

    Used to compare to the bias-variance-diversity decomposition. Also, serves
    as a check that the risks of the two decompositions match up to small
    numerical errors. Each model is evaluated in eval mode without gradient
    tracking, matching how the features for that decomposition are captured.

    Args:
        trials: Number of "best" checkpoints expected on disk.
        subpredictors: Name of the checkpoint sub-folder (the width sweep value).
        num_classes: Number of output classes.
        testset: Dataset yielding ``(x, y)`` pairs, with ``y`` class targets
            accepted by ``CrossEntropyLoss``.
        experiment_name: Sub-folder of ``PROJECT_PATH/models``.
        batch_size: Test-set batch size (default 1024, the former fixed value).

    Returns:
        ``(bias, variance, subpredictors)``. Bias and variance are float64
        scalar tensors on the GPU, averaged over test samples (variance also
        over trials).

    Raises:
        ValueError: If the number of "best" checkpoints differs from ``trials``.
    """
    # retrieve model files ("best" checkpoints only, in a deterministic order)
    model_folder = os.path.join(
        PROJECT_PATH, "models", experiment_name, str(subpredictors),
    )
    dir_contents = sorted(
        name for name in os.listdir(path=model_folder) if "best" in name
    )
    if len(dir_contents) != trials:
        # Too few would leave all-zero (uniform) rows in test_preds and silently
        # skew the terms; too many would overflow its trial axis.
        raise ValueError(
            f"Expected {trials} 'best' checkpoints in {model_folder}, "
            f"found {len(dir_contents)}."
        )

    testloader = DataLoader(
        dataset=testset,
        shuffle=False,
        batch_size=batch_size,
    )

    bias_term = torch.tensor(0.0, device="cuda", dtype=torch.float64)
    variance_term = torch.tensor(0.0, device="cuda", dtype=torch.float64)

    ce_loss = CrossEntropyLoss(reduction="none")
    # KLDivLoss(input=log q, target=p) computes p * (log p - log q), i.e. KL(p || q)
    kl_loss = KLDivLoss(reduction="none")
    sm_avg = Softmax(dim=1)
    lsm = LogSoftmax(dim=2)
    for batch, (x, y) in enumerate(testloader):
        logger.info(f"Estimating bias, variance: computing batch {batch} out of {len(testloader) - 1}")

        # raw model logits, indexed (trial, sample, class)
        test_preds = torch.zeros(
            trials,
            x.shape[0],
            num_classes,
            device="cuda",
            dtype=torch.float64,
        )

        x = x.to("cuda")
        for k, trial in enumerate(dir_contents):
            to_ckpt = os.path.join(model_folder, trial)
            model, _ = load_model_and_weights(to_ckpt=to_ckpt)

            # Evaluate like activations_matrix does (eval mode, no autograd);
            # otherwise any dropout / batch-norm layers would run in training
            # mode and the two decompositions would not be comparable.
            model.eval()
            with torch.no_grad():
                test_preds[k] = model.forward(x).double()

        # log-probabilities per trial, computed once and reused below
        log_preds = lsm(test_preds)

        # normalized geometric mean of the trials (unnormalized in log space)
        avg_combine = torch.sum(log_preds, dim=0)
        avg_combine *= (1/trials)

        bias_term += torch.sum(ce_loss(avg_combine, y.to("cuda")))

        # variance: KL(average model || trial model), broadcast over trials
        prob_avg_combine = sm_avg(avg_combine)
        variance_term += torch.sum(
            kl_loss(log_preds, prob_avg_combine), dim=(0, 1, 2),
        )

    bias_term *= 1/len(testset)
    variance_term *= (1/len(testset)) * (1/trials)

    return bias_term, variance_term, subpredictors


def main():
    """Run both decompositions over the width sweep and pickle the results.

    For every (subpredictors, batch size) pair, the bias-variance-diversity
    estimate and the usual bias-variance estimate are computed. The two
    result lists are saved under ``PROJECT_PATH/analysis/<experiment_name>``.
    """
    #--------------------------------------------------------------------------
    # Experiment details
    experiment_name = 'cifar_cnn_dd_01_label_corruption_small'
    trials = 50
    num_classes = 10

    #--------------------------------------------------------------------------
    # Data
    pcifar10 = PrepareCorruptCIFAR()
    testset = pcifar10.test_set

    # (number of subpredictors, batch size); lower the batch size if a GPU
    # memory allocation error occurs
    sweep = [
        (8, 1024),
        (16, 1024),
        (32, 1024),
        (64, 512),
        (128, 256),
        (512, 64),
        (1024, 64),
        (2048, 64),
        (4096, 64),
    ]

    results_bvd = []
    results_normal = []
    for subpredictors, batch_size in sweep:
        results_bvd.append(
            estimate_bias_var_div(
                trials=trials,
                subpredictors=subpredictors,
                num_classes=num_classes,
                testset=testset,
                experiment_name=experiment_name,
                batch_size=batch_size,
            )
        )

        results_normal.append(
            check_normal_decomposition(
                trials=trials,
                subpredictors=subpredictors,
                num_classes=num_classes,
                testset=testset,
                experiment_name=experiment_name,
            )
        )

    # save to analysis folder
    folder_path = os.path.join(PROJECT_PATH, "analysis", experiment_name)
    save_to_pickle(
        folder_path=folder_path,
        file_name="results_subpred_sp_cifar10",
        file=results_bvd,
        safe_mode=False,
    )
    save_to_pickle(
        folder_path=folder_path,
        file_name="results_normal_sp_cifar10",
        file=results_normal,
        safe_mode=False,
    )

if __name__ == "__main__":
    # Run the experiment sweep only when executed as a script, not on import.
    main()








