from pathlib import Path
import argparse
import random
import re
import json
import logging
import logging.config
from time import time

import torch
import numpy as np

from models.get_model import get_model
from utils.enums import Devices, Datasets
from utils.functions import read_state_dict, get_log_config
from compute.grad_wrt_to_inputs import compute_jacobian
from compute.matrix_2norm import compute_matrix_2norm_power_method_batched


def compute_jac_average_model(models, images, out_dim):
    """Return the input Jacobian averaged over ``models``, flattened per sample.

    For every model the Jacobian of ``model(images)`` with respect to
    ``images`` is obtained from ``compute_jacobian`` and accumulated; the sum
    is then divided by ``len(models)``.

    Args:
        models: sized iterable of callables mapping ``images`` to outputs.
        images: input batch; its ``requires_grad`` flag is switched on for the
            computation and switched off again before returning.
        out_dim: size of the model output, used as the second dimension of
            the accumulation buffer.

    Returns:
        Tensor whose dimensions from index 2 onwards (the input dimensions)
        are flattened into one, i.e. ``batch x out_dim x input_size`` given
        the buffer shape built below.
    """
    # the Jacobian is taken w.r.t. the inputs, so they have to track gradients
    images.requires_grad_(True)

    # accumulation buffer: batch x out_dim x <input dims>
    # assumes compute_jacobian returns tensors of this shape - TODO confirm
    jac_shape = [images.shape[0], out_dim, *images.shape[1:]]
    jac_sum = torch.zeros(jac_shape).to(images.device)

    for net in models:
        jac_sum += compute_jacobian(net(images), images)

    jac_sum /= len(models)
    images.requires_grad_(False)

    # the 2-norm is computed on a matrix, hence the input dimensions are
    # collapsed into a single one
    return jac_sum.flatten(start_dim=2, end_dim=-1)


def compute_func_var_and_bias2(images, labels, models, out_dim):
    """Compute per-sample output variance and squared bias across ``models``.

    With ζ indexing the models, the returned quantities are, for every x in
    the batch:

        var(x)   = E_ζ[ || f_θ(x,ζ) - E_ζ[f_θ(x,ζ)] ||^2 ]
        bias2(x) = || y(x) - E_ζ[f_θ(x,ζ)] ||^2

    Averaging them over the test inputs (done by the caller) gives
    E_xtest[Var(f_θ(x,ζ))] and E_xtest[bias^2].

    Args:
        images: input batch; the first dimension is the batch size.
        labels: targets y(x); must be subtractable from the mean model
            output of shape ``batch_size x out_dim``.
        models: sequence of callables; each must map ``images`` to a
            ``batch_size x out_dim`` tensor.
        out_dim: expected size of the model output.

    Returns:
        Tuple ``(var, bias2)`` of two tensors of shape ``(batch_size,)``.

    Raises:
        AssertionError: if an intermediate tensor does not have the expected
            shape.
    """
    batch_size = images.shape[0]
    n_models = len(models)

    # f_θ(x,ζ) for every ζ, stacked to: len(models) x batch_size x out_dim
    outputs = torch.stack([net(images) for net in models])
    assert outputs.shape == (n_models, batch_size, out_dim)

    # E_ζ[f_θ(x,ζ)], shape: batch_size x out_dim
    mean_output = outputs.mean(dim=0)
    assert mean_output.shape == (batch_size, out_dim)

    # sum over ζ of || f_θ(x,ζ) - E_ζ[f_θ(x,ζ)] ||^2
    var = torch.zeros(batch_size).to(images.device)
    for member_output in outputs:
        deviation = member_output - mean_output
        sq_dist = torch.linalg.vector_norm(deviation, 2, dim=1).pow(2)
        # one squared norm per sample of the batch
        assert sq_dist.shape == (batch_size,)
        var += sq_dist

    # turn the sum into the expectation over ζ
    var /= n_models

    # || y(x) - E_ζ[f_θ(x,ζ)] ||^2, one value per sample of the batch
    bias2 = torch.linalg.vector_norm(labels - mean_output, 2, dim=1).pow(2)
    assert bias2.shape == (batch_size,)

    return var, bias2


def _seed_everything(seed):
    """Seed every RNG used by the script and request deterministic cuDNN."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _parse_args():
    """Parse the command-line options of the script."""
    parser = argparse.ArgumentParser()
    # the three options below have no usable default: fail early with a clear
    # argparse message instead of a TypeError further down
    parser.add_argument("--path", dest="path", type=str, required=True)
    parser.add_argument("--dataset_path", dest="dataset_path", type=str, required=True)
    parser.add_argument("--runs", dest="runs", nargs="+", type=str, required=True)
    parser.add_argument("--timestamp", dest="timestamp", type=int, default=int(time()))
    parser.add_argument("--batch_size", dest="batch_size", type=int, default=512)
    parser.add_argument("--device", dest="device", type=str, default="CPU")
    return parser.parse_args()


def _load_models(runs, model_name, dims, path, device):
    """Build one model per run and load its stored weights, in eval mode."""
    models = []
    for run in runs:
        model = get_model(model_name, dims)
        model = model.to(device)

        # epoch=-1 presumably selects the last stored epoch - TODO confirm
        state_dict = read_state_dict(run, path, device, epoch=-1)
        model.load_state_dict(state_dict["model"])
        # evaluation only: layers such as dropout / batch norm must not run
        # in training mode
        model.eval()
        models.append(model)
    return models


def main():
    """Compute bias/variance statistics of models trained with different seeds.

    Writes ``<path>/computed/<name>.json`` containing the supremum and
    infimum of ||∇_x fbar(x)||_2 over the train loader, the test-set averages
    of the functional variance and squared bias, the per-sample variances,
    and the variance at the inputs x'=0 and x'=x_test_mean.
    """
    _seed_everything(42)
    args = _parse_args()

    # setup
    path = Path(args.path)
    dataset_path = Path(args.dataset_path)

    # raw strings: "\+" is an invalid escape sequence in a normal literal
    name = re.sub(r"seed_.*\+", "seed_X+", args.runs[0])
    name = re.sub(r"runtimestamp_.*", f"runtimestamp_{args.timestamp}", name)

    log_config = get_log_config(path, f"Bias-Var-Tradeoff+{name}")
    logging.config.dictConfig(log_config)

    # create dir for results
    (path / "computed").mkdir(exist_ok=True, parents=True)

    device = Devices[args.device].value
    # run names start with "<dataset>+<model_name>+..."
    run_parts = args.runs[0].split("+")
    dataset = run_parts[0]
    model_name = run_parts[1]

    _, test_ds, train_dl, test_dl, dims = Datasets[dataset].value(
        dataset_path=dataset_path,
        batch_size=args.batch_size,
        batch_size_test=args.batch_size,
    )
    # dims[-1] is passed on as the model output size, dims[0] is used as the
    # shape of a single input
    out_dim = dims[-1]

    # get models across seeds
    models = _load_models(args.runs, model_name, dims, path, device)

    # compute sup||∇_x fbar(x)||_2 and inf||∇_x fbar(x)||_2
    supremum_norm = 0.0
    infimum_norm = None
    for i, (images, _) in enumerate(train_dl):
        logging.info(f"Processing train batch {i+1}/{len(train_dl)}...")
        images = images.to(device)
        # ∇_x fbar(x)
        mean_jacs = compute_jac_average_model(models, images, out_dim)
        # ||∇_x fbar(x)||_2
        mean_jacs_norms = compute_matrix_2norm_power_method_batched(mean_jacs)
        # keep plain floats so no tensor (or autograd graph) outlives the batch
        cur_sup_norm = torch.max(mean_jacs_norms).item()
        cur_inf_norm = torch.min(mean_jacs_norms).item()
        if cur_sup_norm > supremum_norm:
            supremum_norm = cur_sup_norm
        if (infimum_norm is None) or (cur_inf_norm < infimum_norm):
            infimum_norm = cur_inf_norm

    if infimum_norm is None:
        raise RuntimeError("The train dataloader yielded no batches; cannot compute Jacobian norms.")

    logging.info("fbar norm inf and sup computation done!")

    # compute the variance and bias squared; only forward passes are needed,
    # so autograd is disabled to avoid retaining a graph for every batch
    var = 0
    bias2 = 0
    var_batches = []
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_dl):
            logging.info(f"Processing test batch {i+1}/{len(test_dl)}...")
            images = images.to(device)
            labels = labels.to(device)
            # var_batch = array of [ E_ζ[|| f_θ(x,ζ) - E_ζ[f_θ(x,ζ)] ||^2] ] of batchsize
            # bias2_batch = array of [|| y(x) - E_ζ[f_θ(x,ζ)] ||^2] of batchsize
            var_batch, bias2_batch = compute_func_var_and_bias2(images, labels, models, out_dim)
            var += torch.sum(var_batch).item()
            bias2 += torch.sum(bias2_batch).item()
            var_batches.append(var_batch.cpu())

        # E_xtest[Var(f_θ(x,ζ))] = E_xtest[E_ζ[|| f_θ(x,ζ) - E_ζ[f_θ(x,ζ)] ||^2]]
        var /= len(test_ds)
        # E_xtest[bias^2] = E_xtest[|| y(x) - E_ζ[f_θ(x,ζ)] ||^2]
        bias2 /= len(test_ds)
        logging.info("Variance and Bias squared computation done!")

        # the labels are irrelevant for the variance, any tensor of the right shape works
        fake_labels = torch.zeros((1, out_dim)).to(device)

        # compute the variance for x' = 0
        zeros = torch.zeros(dims[0]).unsqueeze(0).to(device)
        var_0, _ = compute_func_var_and_bias2(zeros, fake_labels, models, out_dim)
        # NOTE: a single input is evaluated, so the value is Var(f_θ(x',ζ))
        # itself; it used to be divided by len(test_ds), which scaled it down
        # by the size of the test set
        var_0 = torch.sum(var_0).item()

        logging.info("Variance for x'=0 computation done!")

        # compute the variance for x' = x_test_mean
        X_test = torch.stack([sample[0] for sample in test_ds])
        x_test_mean = torch.mean(X_test, axis=0).unsqueeze(0).to(device)
        var_x_mean, _ = compute_func_var_and_bias2(x_test_mean, fake_labels, models, out_dim)
        # single input as well: no averaging over the test set
        var_x_mean = torch.sum(var_x_mean).item()

        logging.info("Variance for x'=x_test_mean computation done!")

    # concatenate the per-sample variances once instead of on every batch
    var_arr = torch.cat(var_batches).tolist() if var_batches else []

    obj = {
        "sup_norm_fbar": float(supremum_norm),
        "inf_norm_fbar": float(infimum_norm),
        "func_var": float(var),
        "func_var_0": float(var_0),
        "func_var_x_test_mean": float(var_x_mean),
        "func_var_arr": var_arr,
        "bias2": float(bias2),
    }
    with (path / "computed" / f"{name}.json").open("w") as f:
        json.dump(obj, f)


if __name__ == "__main__":
    main()
