"""Experiment script to estimate the bias, variance, diversity and risk for
subpredictors computed using the mean-field parameterization."""

from __future__ import annotations

import logging
import os
from typing import Any

import torch
from torch.nn import MSELoss
from torch.utils.data import DataLoader

from data_proc.california_housing import PrepareCaliforniaHousing, CaliforniaHousing
from env.user import PROJECT_PATH
from helpers.hooks import GetHookVals
from helpers.logger import get_logger
from helpers.saving import save_to_pickle

# Module-level logger from the project helper; INFO so per-batch progress
# messages emitted by the estimators below are shown.
logger = get_logger()
logger.setLevel(logging.INFO)


def load_model_and_weights(to_ckpt: str):
    """Load a fully pickled Pytorch model and its ``ho`` layer weights.

    Args:
        to_ckpt: path to a checkpoint written with ``torch.save(model, ...)``.

    Returns:
        ``(model, weights)`` where ``weights`` is ``model.ho.weight`` with
        gradient tracking switched off (presumably the output-layer weights of
        the MFP MLP -- confirm against the architecture).
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only point this at
    # checkpoints produced by this project's own training runs.
    net = torch.load(to_ckpt, weights_only=False)

    comb = net.ho.weight
    comb.requires_grad_(False)

    return net, comb

def activations_matrix(
        model: torch.nn.Module,
        x: torch.Tensor,
) -> dict[str, Any]:
    """Fetch hidden activation values given a Pytorch model and an input.

    Model is assumed to be the MFP MLP in the /archs folder. A forward hook is
    temporarily attached to ``model.relu_ih`` and the model is run on ``x``
    without gradient tracking.

    Args:
        model: network exposing a ``relu_ih`` submodule. It is switched to
            eval mode as a side effect.
        x: input batch, expected on the same device as ``model``.

    Returns:
        The ``activation`` dict of the ``GetHookVals`` helper. The entry under
        ``"relu"`` holds whatever the hook recorded for ``model.relu_ih``
        (callers treat it as a tensor).
    """
    hook_vals = GetHookVals()
    with torch.no_grad():
        model.eval()
        handle = model.relu_ih.register_forward_hook(
            hook_vals.getActivation("relu"),
        )
        try:
            model(x)  # output unused; only the hooked activation is needed
        finally:
            # Always detach the hook, even if the forward pass fails, so it is
            # not left attached to a model that callers may reuse.
            handle.remove()

    return hook_vals.activation

def combiner(
        subpredictors: int,
        test_preds: torch.Tensor,
) -> torch.Tensor:
    """Average the subpredictor outputs into one combined prediction.

    Sums ``test_preds`` over dimension 1 and scales by ``1/subpredictors``.
    In this file ``test_preds`` has shape ``(trials, subpredictors, batch)``,
    so the result has shape ``(trials, batch)``.
    """
    weight = 1 / subpredictors
    summed = test_preds.sum(dim=1)
    return weight * summed

def estimate_centroid(
        trials:int,
        test_preds: torch.Tensor,
) -> torch.Tensor:
    """Average predictions over trials to obtain each subpredictor's centroid.

    Sums ``test_preds`` over dimension 0 and scales by ``1/trials``. In this
    file dimension 0 indexes trials, e.g. ``(trials, subpredictors, batch)``
    -> ``(subpredictors, batch)``.
    """
    weight = 1 / trials
    summed = test_preds.sum(dim=0)
    return weight * summed

def estimate_bias_var_div(
        trials: int,
        subpredictors: int,
        testset: torch.utils.data.Dataset,
        experiment_name: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Estimate the bias, variance, and diversity of a neural network.

    Each hidden unit ``i`` of the MFP MLP is treated as a subpredictor with
    output ``q_i = w_i * relu_i(x)``, where ``w_i`` is the matching entry of
    ``model.ho.weight``. The ``"best"`` checkpoints found in
    ``PROJECT_PATH/models/<experiment_name>/<subpredictors>`` provide one model
    per trial. With ``c_i`` the trial-average of ``q_i`` (``estimate_centroid``)
    and ``q`` the per-trial ``combiner`` output, the returned terms are:

    * bias: mean over subpredictors and samples of ``(c_i - y)^2``;
    * variance: mean over trials, subpredictors, samples of ``(q_i - c_i)^2``;
    * diversity: mean over trials, subpredictors, samples of ``(q_i - q)^2``.

    Computation is designed to run on a GPU. If a memory allocation error is
    received, tune the batch size.

    Args:
        trials: number of independently trained models (checkpoints) expected.
        subpredictors: hidden width; also the name of the checkpoint subfolder.
        testset: dataset yielding ``(x, y)`` pairs with a single target.
        experiment_name: subfolder of ``PROJECT_PATH/models``.

    Returns:
        ``(bias, variance, diversity, subpredictors)``; the three terms are
        0-dim float64 tensors on the GPU.

    Raises:
        ValueError: if the number of ``"best"`` checkpoints differs from
            ``trials``.
    """
    # retrieve model files (only the best checkpoint of each trial)
    model_folder = os.path.join(
        PROJECT_PATH, "models", experiment_name, str(subpredictors),
    )
    checkpoints = sorted(f for f in os.listdir(model_folder) if "best" in f)
    if len(checkpoints) != trials:
        # Too many would overflow ``test_preds``; too few would leave all-zero
        # rows that silently distort every average below.
        raise ValueError(
            f"expected {trials} 'best' checkpoints in {model_folder}, "
            f"found {len(checkpoints)}"
        )

    # Load each trial's model once; the checkpoints do not change per batch.
    trained = [
        load_model_and_weights(to_ckpt=os.path.join(model_folder, ckpt))
        for ckpt in checkpoints
    ]

    # create test set dataloader
    testloader = DataLoader(
        dataset=testset,
        batch_size=128,
        shuffle=False,
    )

    # float64 accumulators to match the double-precision products below
    bias_term = torch.tensor(0.0, dtype=torch.float64, device="cuda")
    variance_term = torch.tensor(0.0, dtype=torch.float64, device="cuda")
    diversity_term = torch.tensor(0.0, dtype=torch.float64, device="cuda")

    loss = MSELoss(reduction="none")
    for batch, (x, y) in enumerate(testloader):
        logger.info(f"Estimating bias, variance, and diversity: computing batch {batch} out of {len(testloader) - 1}")

        x = x.to("cuda")
        # targets as a (1, batch) row so they broadcast over subpredictors
        y = y.to(device="cuda", dtype=torch.float64).reshape(1, -1)

        # subpredictor outputs, (trials, subpredictors, batch); float64 so the
        # double-precision products are not truncated when stored
        test_preds = torch.zeros(
            trials,
            subpredictors,
            x.shape[0],
            dtype=torch.float64,
            device="cuda",
        )
        for k, (model, weights_comb) in enumerate(trained):
            activations = activations_matrix(model=model, x=x)
            # MFP subpredictors: weight-scaled hidden activations, (batch, subpredictors)
            subpred = torch.mul(
                weights_comb.double(), activations["relu"].double(),
            )
            test_preds[k] = subpred.T

        est = estimate_centroid(trials=trials, test_preds=test_preds)  # (S, B)
        combine = combiner(subpredictors=subpredictors, test_preds=test_preds)  # (T, B)

        bias_term += loss(est, y).sum()
        variance_term += loss(est, test_preds).sum()
        # (S, T, B) predictions against the (T, B) combined output
        diversity_term += loss(combine, test_preds.transpose(0, 1)).sum()

    bias_term *= (1/subpredictors) * (1/len(testset))
    variance_term *= (1/subpredictors) * (1/len(testset)) * (1/trials)
    diversity_term *= (1/subpredictors) * (1/len(testset)) * (1/trials)

    return bias_term, variance_term, diversity_term, subpredictors


def check_normal_decomposition(
        trials: int,
        subpredictors: int,
        testset: torch.utils.data.Dataset,
        experiment_name: str,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Estimate the usual bias and variance terms.

    Used to compare to the bias-variance-diversity decomposition. It also
    serves as a check that the risks of the two decompositions match up to
    small numerical errors.

    For the ``"best"`` checkpoint of every trial, the full model output ``f_k``
    is computed on the test set. With ``f_bar`` the average over trials:

    * bias: mean over samples of ``(f_bar - y)^2``;
    * variance: mean over trials and samples of ``(f_k - f_bar)^2``.

    Args:
        trials: number of independently trained models (checkpoints) expected.
        subpredictors: hidden width; also the name of the checkpoint subfolder.
        testset: dataset yielding ``(x, y)`` pairs with a single target.
        experiment_name: subfolder of ``PROJECT_PATH/models``.

    Returns:
        ``(bias, variance, subpredictors)``; both terms are 0-dim float64
        tensors on the GPU.

    Raises:
        ValueError: if the number of ``"best"`` checkpoints differs from
            ``trials``.
    """
    # retrieve model files (only the best checkpoint of each trial)
    model_folder = os.path.join(
        PROJECT_PATH, "models", experiment_name, str(subpredictors),
    )
    checkpoints = sorted(f for f in os.listdir(model_folder) if "best" in f)
    if len(checkpoints) != trials:
        # Too many would overflow ``test_preds``; too few would leave all-zero
        # rows that silently distort both averages.
        raise ValueError(
            f"expected {trials} 'best' checkpoints in {model_folder}, "
            f"found {len(checkpoints)}"
        )

    # Load each trial's model once and put it in eval mode, consistent with
    # how activations are gathered for the bias-variance-diversity estimate.
    models = []
    for ckpt in checkpoints:
        model, _ = load_model_and_weights(
            to_ckpt=os.path.join(model_folder, ckpt),
        )
        model.eval()
        models.append(model)

    # create a test set dataloader
    testloader = DataLoader(
        dataset=testset,
        shuffle=False,
        batch_size=128,
    )

    # float64 accumulators to match the double-precision model outputs
    bias_term = torch.tensor(0.0, dtype=torch.float64, device="cuda")
    variance_term = torch.tensor(0.0, dtype=torch.float64, device="cuda")

    loss = MSELoss(reduction="none")
    for batch, (x, y) in enumerate(testloader):
        logger.info(f"Estimating bias, variance: computing batch {batch} out of {len(testloader) - 1}")

        x = x.to("cuda")
        # flatten targets to (batch,) -- assumes a single regression target
        y = y.to(device="cuda", dtype=torch.float64).reshape(-1)

        # model outputs, (trials, batch); float64 so nothing is truncated
        test_preds = torch.zeros(
            trials,
            x.shape[0],
            dtype=torch.float64,
            device="cuda",
        )
        with torch.no_grad():  # inference only; no autograd graph needed
            for k, model in enumerate(models):
                test_preds[k] = model(x).double().reshape(-1)

        avg_combine = estimate_centroid(trials=trials, test_preds=test_preds)

        bias_term += loss(avg_combine, y).sum()
        variance_term += loss(avg_combine, test_preds).sum()

    bias_term *= 1/len(testset)
    variance_term *= (1/len(testset)) * (1/trials)

    return bias_term, variance_term, subpredictors


def main():
    """Run both decompositions for every hidden width and pickle the results.

    For each width, ``estimate_bias_var_div`` runs first and then
    ``check_normal_decomposition``. The two result lists are saved under
    ``PROJECT_PATH/analysis/<experiment_name>``.
    """
    #--------------------------------------------------------------------------
    # Experiment details
    experiment_name = 'small_calif_housing_mfp_fixed_epoch'
    trials = 50

    #--------------------------------------------------------------------------
    # Data
    calif = PrepareCaliforniaHousing()
    testset = CaliforniaHousing(
        Xarray=calif.X_test,
        yarray=calif.y_test,
    )

    #--------------------------------------------------------------------------
    # Compute decompositions, one entry per hidden width
    results_bvd = []
    results_norm = []
    for n_sub in (5, 10, 50, 100, 500, 1000, 5000):
        shared = dict(
            trials=trials,
            subpredictors=n_sub,
            testset=testset,
            experiment_name=experiment_name,
        )
        results_bvd.append(estimate_bias_var_div(**shared))
        results_norm.append(check_normal_decomposition(**shared))

    #--------------------------------------------------------------------------
    # Save results to the analysis folder (overwrites existing files)
    folder_path = os.path.join(PROJECT_PATH, "analysis", experiment_name)
    outputs = (
        ("results_subpred_mfp_fixed_epoch_best", results_bvd),
        ("results_normal_mfp_fixed_epoch_best", results_norm),
    )
    for file_name, payload in outputs:
        save_to_pickle(
            folder_path=folder_path,
            file_name=file_name,
            file=payload,
            safe_mode=False,
        )


if __name__ == "__main__":
    main()








