import argparse
import csv
import json
import os
import sys
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from scipy.stats import wishart
from sklearn import kernel_ridge, linear_model
from sklearn.metrics import r2_score
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler

from pycomets.gcm import GCM
from pycomets.pcm import PCM
from pycomets.regression import LM

sys.path.append("path")
from encoders import CompositionEncMix, get_mlp
from invertible_network_utils import construct_invertible_mlp
from latent_spaces import LatentSpace, ProductLatentSpace
from spaces import NRealSpace
from utils import (
    content_style_from_subsets,
    evaluate_prediction,
    powerset,
    topk_gumbel_softmax,
)


def generate_nonlinear_model(grid_search_eval=None):
    """
    Build the (unfitted) nonlinear regressor used for evaluation.

    Args:
        grid_search_eval (bool, optional): If truthy, return a 3-fold
            ``GridSearchCV`` over an RBF ``KernelRidge`` (``alpha`` and
            ``gamma`` grids); otherwise return a lightweight ``MLPRegressor``.
            Defaults to ``None``, which falls back to the module-level
            ``args.grid_search_eval`` (the previous, implicit behaviour).

    Returns:
        An unfitted scikit-learn estimator.
    """
    if grid_search_eval is None:
        # Backward-compatible fallback on the script-level ``args`` global.
        grid_search_eval = args.grid_search_eval

    if not grid_search_eval:
        return MLPRegressor(max_iter=5000)  # lightweight option

    # grid search is time- and memory-intensive
    return GridSearchCV(
        kernel_ridge.KernelRidge(kernel="rbf", gamma=0.1),
        param_grid={
            "alpha": [1e0, 0.1, 1e-2, 1e-3],
            "gamma": np.logspace(-2, 2, 4),
        },
        cv=3,
        n_jobs=-1,
    )


def infer_content_indices_gumbel_softmax(
    args, hzs: dict, content_size_dict: dict
):
    """
    Select content indices per (subset, view) with a hard top-k Gumbel-Softmax.

    For every subset in ``args.subsets`` and every view ``k`` in it, the view's
    encodings ``hzs[k]["hz"]`` are averaged over axis 0 and used as logits.
    ``content_size_dict[subset]`` of their positions are then drawn with
    ``topk_gumbel_softmax(..., tau=1.0, hard=True)``. The content size per
    subset is therefore assumed known in advance.

    Args:
        args: Namespace providing ``subsets``.
        hzs (dict): Maps view index to ``{"hz": tensor}``.
        content_size_dict (dict): Maps subset to its content block size.

    Returns:
        dict: ``{subset: {view: [selected indices]}}``. Each view gets its own
        sample, so indices may differ across views of the same subset.
    """

    def _draw_indices(subset, view):
        mean_logits = hzs[view]["hz"].mean(0)[None]
        mask = topk_gumbel_softmax(
            k=content_size_dict[subset],
            logits=mean_logits,
            tau=1.0,
            hard=True,
        )
        # positions of the selected (non-zero) entries along the last axis
        return torch.where(mask)[-1].tolist()

    return {
        subset: {view: _draw_indices(subset, view) for view in subset}
        for subset in args.subsets
    }


def infer_content_indices(args, hzs, *mode_specific_args):
    """
    Return content indices according to ``args.selection``.

    Modes:
        * ``"ground_truth"``: ``args.view_specific_content_indexing``.
        * ``"concat"``: ``args.est_content_dict``.
        * ``"gumbel_softmax"``: sampled via
          ``infer_content_indices_gumbel_softmax(args, hzs, *mode_specific_args)``.

    Args:
        args: Namespace with ``selection`` and the mode's precomputed fields.
        hzs: Per-view encodings. Only used by the Gumbel-Softmax mode.
        mode_specific_args: Extra positional arguments forwarded to the
            Gumbel-Softmax mode (e.g. the content size dict).

    Returns:
        The content indices for the selected mode.

    Raises:
        ValueError: If ``args.selection`` is not one of the modes above.
    """
    mode = args.selection
    if mode == "ground_truth":
        return args.view_specific_content_indexing
    if mode == "concat":
        return args.est_content_dict
    if mode == "gumbel_softmax":
        return infer_content_indices_gumbel_softmax(
            args, hzs, *mode_specific_args
        )
    raise ValueError(f"mode={args.selection} not supported")


def generate_data(
    latent_space,
    models,
    num_batches=1,
    batch_size=4096,
    loss_func=None,
    args=None,
):
    """
    Sample latents, encode every view, and collect per-subset content/style data.

    For each of ``num_batches`` batches, latents are drawn with
    ``latent_space.sample_latent(batch_size)``. Each view ``k`` is encoded with
    ``models["backbone"].view_specific_forward(zs, k, args.S_k)``. Content
    indices for every (subset, view) pair come from a single
    ``infer_content_indices`` call per batch; that call already covers all
    subsets and views. Everything runs under ``torch.no_grad()`` with the
    backbone in ``eval()`` mode.

    Args:
        latent_space (LatentSpace): Provides ``sample_latent(batch_size)``.
        models (dict): Must contain ``"backbone"`` exposing ``eval()`` and
            ``view_specific_forward(zs, k, S_k)``.
        num_batches (int, optional): Number of batches to generate. Defaults to 1.
        batch_size (int, optional): Samples per batch. Defaults to 4096.
        loss_func (callable, optional): Unused; kept for interface compatibility.
        args (argparse.Namespace): Must provide ``n_views`` and ``S_k``. It must
            also provide ``subsets``; ``None`` or ``[]`` means no subsets. When
            subsets are given, it additionally needs ``content_dict``,
            ``style_dict``, ``content_size_dict`` and ``selection``.

    Returns:
        tuple: ``(data_dict, hz_dict, all_z)``
            - data_dict (dict): ``data_dict[subset][k]["c"]`` / ``["s"]`` hold the
              ground-truth content / style latents, stacked over batches to
              ``[num_batches, batch_size, ...]``. Empty when there are no subsets.
            - hz_dict (dict): ``hz_dict[k]["hz"]`` holds the stacked encoder outputs
              ``[num_batches, batch_size, ...]``. ``hz_dict[k]["est_c_ind"][subset]``
              holds the stacked per-batch content indices
              ``[num_batches, n_content_indices]``.
            - all_z (numpy.ndarray): All sampled latents, stacked over batches.
    """
    models["backbone"].eval()

    # ``None`` and ``[]`` both mean "no subsets". Previously ``None`` crashed
    # while building hz_dict below.
    subsets = args.subsets or []

    data_dict = {
        subset: {k: {"c": [], "s": []} for k in subset} for subset in subsets
    }

    hz_dict = {
        k: {
            "hz": [],  # unified encoded information
            "est_c_ind": {s: [] for s in subsets if k in s},  # for all subsets
        }
        for k in range(args.n_views)
    }

    all_z = []

    with torch.no_grad():
        for _ in range(num_batches):
            zs = latent_space.sample_latent(batch_size)  # [batch_size, n_z]
            all_z.append(zs)

            # Estimated latents of each view (unified encoder). Tensors are kept
            # in ``hzs`` for index inference; numpy copies go to hz_dict.
            hzs = {}
            for k in range(args.n_views):
                hz = models["backbone"].view_specific_forward(zs, k, args.S_k)
                hzs[k] = {"hz": hz}
                hz_dict[k]["hz"].append(hz.detach().cpu().numpy())

            if not subsets:
                continue

            # infer_content_indices returns indices for every subset and view,
            # so one call per batch suffices. It used to be re-run for each
            # (subset, view) pair.
            est_content_dict = infer_content_indices(
                args, hzs, args.content_size_dict
            )

            for subset in subsets:
                content_z = zs[:, list(args.content_dict[subset])]
                for k in subset:
                    style_z = zs[:, list(args.style_dict[subset][k])]
                    data_dict[subset][k]["c"].append(
                        content_z.detach().cpu().numpy()
                    )
                    data_dict[subset][k]["s"].append(
                        style_z.detach().cpu().numpy()
                    )
                    # indices may differ between views of the same subset
                    hz_dict[k]["est_c_ind"][subset].append(
                        est_content_dict[subset][k]
                    )

    for subset_dict in data_dict.values():
        for view_dict in subset_dict.values():
            view_dict["c"] = np.stack(view_dict["c"], axis=0)
            view_dict["s"] = np.stack(view_dict["s"], axis=0)

    for view_dict in hz_dict.values():
        view_dict["hz"] = np.stack(view_dict["hz"], axis=0)
        for subset, indices in view_dict["est_c_ind"].items():
            # [num_batches, n_content_indices]
            view_dict["est_c_ind"][subset] = np.stack(indices, axis=0)

    return data_dict, hz_dict, torch.stack(all_z, 0).detach().cpu().numpy()


def generate_latent_space(args):
    """
    Build the latent space, sampling latents with covariance ``Sigma_z``.

    When ``args.evaluate`` is false, ``Sigma_z`` is constructed and written to
    ``<args.save_dir>/Sigma_z.csv``:
      * ``args.n_dependent_dims == 0``: identity.
      * ``args.B`` given: ``(I - B)^{-1} (I - B)^{-T}``, the covariance of
        ``(I - B)^{-1} eps`` for ``eps`` with identity covariance.
      * otherwise: identity, with the leading
        ``n_dependent_dims x n_dependent_dims`` block replaced by a Wishart draw
        (``df = n_dependent_dims``, identity scale).
    When ``args.evaluate`` is true, ``Sigma_z`` is read back from that CSV and
    printed.

    The returned space samples through ``space.normal(None, 1.0, size, device,
    Sigma=Sigma_z)``.

    Args:
        args: Namespace with ``n_dependent_dims``, ``latent_dim``, ``save_dir``,
            ``evaluate`` and, when generating with dependent dims, ``B``.

    Returns:
        ProductLatentSpace: Wraps a single ``NRealSpace(args.latent_dim)``.
    """
    assert args.n_dependent_dims <= args.latent_dim
    Sigma_z_path = os.path.join(args.save_dir, "Sigma_z.csv")
    if not args.evaluate:
        if args.n_dependent_dims == 0:
            Sigma_z = np.eye(args.latent_dim)
        elif args.B is not None:
            BB = np.array(args.B)
            IB_inv = np.linalg.inv(np.eye(args.latent_dim) - BB)
            Sigma_z = IB_inv @ IB_inv.T
        else:
            # Dependent case without an explicit B: the first n_dependent_dims
            # latents get a random (Wishart) covariance block; the remaining
            # latents stay independent with unit variance.
            Sigma_z = np.eye(args.latent_dim)
            Sigma_z_dep = wishart.rvs(
                args.n_dependent_dims, np.eye(args.n_dependent_dims), size=1
            )
            Sigma_z[: args.n_dependent_dims, : args.n_dependent_dims] = (
                Sigma_z_dep
            )

        np.savetxt(Sigma_z_path, Sigma_z, delimiter=",")
    else:
        # ndmin=2 keeps the matrix 2-D. Without it a 1x1 Sigma (latent_dim == 1)
        # would be read back as a 0-d array.
        Sigma_z = np.loadtxt(Sigma_z_path, delimiter=",", ndmin=2)
        print(Sigma_z)
    space = NRealSpace(args.latent_dim)

    # Here just one latent space. NOTE(review): the default ``device`` is the
    # module-level global, bound when this function runs.
    def sample_latent(space, size, device=device):
        return space.normal(None, 1.0, size, device, Sigma=Sigma_z)

    latent_spaces_list = [
        LatentSpace(space=space, sample_latent=sample_latent)
    ]
    latent_space = ProductLatentSpace(spaces=latent_spaces_list)
    return latent_space


def init_or_load_mixing_fns(device, args):
    """
    Create fresh mixing functions, or load saved ones when evaluating.

    Each view ``i`` gets an invertible MLP on ``len(args.S_k[i])`` dimensions
    (the invertible MLP needs equal input and output size).
      * Not evaluating: the functions are built fresh. They are not frozen or
        moved here; they stay untrained only because the optimizer collects
        encoder parameters alone.
      * Evaluating (``args.evaluate``): the functions are rebuilt with the same
        settings, loaded from ``<args.save_dir>/mixing_fns.pt`` (a dict mapping
        view index to state dict), moved to ``device`` and frozen. No
        throwaway fresh functions are built first.

    Args:
        device (torch.device | str): Device for loaded mixing functions.
        args (argparse.Namespace): Needs ``n_views``, ``S_k``,
            ``n_mixing_layer``, ``evaluate``, ``shared_mixing_function`` and,
            when evaluating, ``save_dir``.

    Returns:
        torch.nn.ModuleList: The mixing functions, one per view. When
        ``args.shared_mixing_function`` is set, this is instead a plain list
        repeating the first function ``n_views`` times.
    """
    print(f"device: {device}")

    def _make_mixing_fn(view_idx):
        # Identical settings for fresh and reloaded functions, so saved state
        # dicts match the architecture.
        return construct_invertible_mlp(
            n=len(args.S_k[view_idx]),
            n_layers=args.n_mixing_layer,
            cond_thresh_ratio=0.001,
            n_iter_cond_thresh=25000,
        )

    F = torch.nn.ModuleList()  # set of mixing functions, not trainable
    if args.evaluate:
        mixing_fn_state_dict = torch.load(
            os.path.join(args.save_dir, "mixing_fns.pt"),
            map_location=torch.device("cpu"),
        )
        for i, param_dict in mixing_fn_state_dict.items():
            f_i = _make_mixing_fn(i)
            f_i.load_state_dict(param_dict)
            f_i.to(device)
            # disable gradient descent for mixing functions
            for p in f_i.parameters():
                p.requires_grad = False
            F.append(f_i)
    else:
        for i in range(args.n_views):
            F.append(_make_mixing_fn(i))

    if args.shared_mixing_function:
        F = [F[0]] * args.n_views
    return F


def init_or_load_encoder_models(device, args, encoding_size=None):
    """
    Build one MLP encoder per view, or load trained weights when evaluating.

    Encoder ``i`` maps ``len(args.S_k[i])`` inputs to ``encoding_size`` outputs.
    If ``encoding_size`` is falsy it maps to ``len(args.S_k[i])`` outputs. The
    hidden widths are multiples (10x, 50x, 50x, 50x, 50x, 10x) of the input
    size. When ``args.evaluate`` is true, weights come from
    ``<args.save_dir>/model.pt`` under the keys ``encoder_{i}_state_dict``.
    Only the loaded encoders are built; no throwaway fresh encoders are built
    first.

    Args:
        device (torch.device | str): Device the encoders are moved to.
        args (argparse.Namespace): Needs ``n_views``, ``S_k``, ``evaluate``,
            ``shared_encoder`` and, when evaluating, ``save_dir``.
        encoding_size (int, optional): Output size of every encoder.
            Defaults to None (use each view's input size).

    Returns:
        torch.nn.ModuleList: The encoders, one per view. When
        ``args.shared_encoder`` is set, this is instead a plain list repeating
        the first encoder ``n_views`` times.
    """

    def _make_encoder(view_idx):
        n_in = len(args.S_k[view_idx])
        return get_mlp(
            n_in=n_in,
            n_out=encoding_size or n_in,
            layers=[
                n_in * 10,
                n_in * 50,
                n_in * 50,
                n_in * 50,
                n_in * 50,
                n_in * 10,
            ],
        )

    G = torch.nn.ModuleList()
    if args.evaluate:
        save_path = os.path.join(args.save_dir, "model.pt")
        ckpt = torch.load(save_path, map_location=torch.device("cpu"))
        for i in range(args.n_views):
            g_i = _make_encoder(i)
            g_i.load_state_dict(ckpt[f"encoder_{i}_state_dict"])
            g_i.to(device)
            G.append(g_i)
    else:
        for i in range(args.n_views):
            g_i = _make_encoder(i)
            G.append(g_i)
            g_i.to(device)

    if args.shared_encoder:
        G = [G[0]] * args.n_views
    return G


def init_or_load_training_models(mixing_fns, encoderes, device, args):
    """
    Compose encoders and mixing functions into one backbone module.

    Args:
        mixing_fns (list): Mixing functions, one per view.
        encoderes (list): Encoders, one per view.
        device (torch.device | str): Device the backbone is moved to.
        args: Unused.

    Returns:
        dict: ``{"backbone": CompositionEncMix(...)}``, already moved to ``device``.
    """
    # torch.nn.Module wrapper for the encoder / mixing-function composition
    composed = CompositionEncMix(mixing_fns=mixing_fns, encoders=encoderes)
    composed.to(device)
    models = {"backbone": composed}
    return models


def init_or_load_optimizer(
    models: dict, optimizer_class=torch.optim.Adam, args=None
):
    """
    Collect the encoders' parameters and build the optimizer over them.

    Only encoder parameters are trainable; mixing functions are excluded.

    Args:
        models (dict): Must contain ``"backbone"`` with an ``encoders`` sequence.
        optimizer_class (type): Optimizer class (default: ``torch.optim.Adam``).
        args (argparse.Namespace): Needs ``shared_encoder`` and ``lr``.

    Returns:
        tuple: ``(params, optimizer)``. ``params`` is always a list of
        parameters, never a consumed generator.
    """
    encoders = models["backbone"].encoders
    if args.shared_encoder:
        # All views reuse encoders[0]. Materialise the parameters as a list:
        # a bare generator would be exhausted by the optimizer before the
        # caller sees it.
        params = list(encoders[0].parameters())
    else:
        params = [p for g_i in encoders for p in g_i.parameters()]

    # Define the optimiser (Adam unless another class is passed)
    optimizer = optimizer_class(params, lr=args.lr)
    return params, optimizer


def update_args(args):
    """
    Derive subset, content/style and indexing information and store it on ``args``.

    Sets on ``args``:
      * ``subsets``: only when it was ``None``; filled from ``powerset`` over the
        views.
      * ``content_dict`` / ``style_dict``: from ``content_style_from_subsets``.
      * ``content_size_dict``: number of content latents per subset.
      * ``latent_dim``: number of distinct latent indices across ``args.S_k``.
      * ``view_specific_content_indexing``: only for ``selection ==
        "ground_truth"``. Holds the position of each content latent within each
        view's ``S_k`` list.
      * ``est_content_dict``: only for ``selection == "concat"``. Splits
        ``range(encoding_size)`` into one chunk per subset; all views of a
        subset share the same chunk.

    Args:
        args (Namespace): The input arguments (mutated in place).

    Returns:
        Namespace: The same, updated ``args`` object.
    """
    # [n_views, n_sk]: the view-specific latents as given in args
    view_latents = torch.tensor(args.S_k)

    # retrieve subsets, content dict and style dict for all subsets and views
    if args.subsets is None:
        all_subsets, _ = powerset(
            range(args.n_views), only_consider_whole_set=False
        )
        args.subsets = all_subsets

    args.content_dict, args.style_dict = content_style_from_subsets(
        subsets=args.subsets, zs=view_latents
    )

    # content size per subset, for the known-content-size mode
    args.content_size_dict = {
        subset: len(members) for subset, members in args.content_dict.items()
    }

    # make sure the number of latents aligns with S_k
    args.latent_dim = len(set(chain.from_iterable(args.S_k)))

    per_view_indexing = {s: {} for s in args.subsets}
    if args.selection == "ground_truth":
        for s in args.subsets:
            per_view_indexing[s].update(
                {
                    k: [args.S_k[k].index(c) for c in args.content_dict[s]]
                    for k in s
                }
            )
        args.view_specific_content_indexing = per_view_indexing
    elif args.selection == "concat":  # concat all content indices
        chunks = np.array_split(range(args.encoding_size), len(args.subsets))
        args.est_content_dict = {
            subset: dict.fromkeys(subset, chunk)
            for subset, chunk in zip(args.subsets, chunks)
        }
    return args


def parse_args(dargs):
    """Turn a settings mapping (e.g. loaded JSON) into an ``argparse.Namespace``."""
    namespace = argparse.Namespace()
    for key, value in dargs.items():
        setattr(namespace, key, value)
    return namespace


## Evaluation ##

device = "cpu"

# if torch.cuda.is_available():
#     device = "cuda:7"
# else:
#     device = "cpu"

save_dir = "path"
# model_path = os.path.join(save_dir, "model.pt")
setting_path = os.path.join(save_dir, "settings.json")
# Separate name for the handle so ``setting_path`` stays the path. The
# ``with`` block closes the file, so no explicit close() is needed.
with open(setting_path, "r", encoding="utf-8") as setting_file:
    dargs = json.load(setting_file)
dargs["save_dir"] = save_dir
args = parse_args(dargs)
args = update_args(args)
num_batches = args.num_eval_batches
file_name = "Evaluation"
# Important to set to True here! The init/load helpers below then read the
# saved Sigma_z, mixing functions and encoder weights.
args.evaluate = True

latent_space = generate_latent_space(args)
mixing_fns = init_or_load_mixing_fns(
    device, args
)  # mixing function always gives S_k
encoders = init_or_load_encoder_models(
    device,
    args,
    encoding_size=args.encoding_size if args.selection == "concat" else None,
)
models = init_or_load_training_models(
    mixing_fns=mixing_fns, encoderes=encoders, device=device, args=args
)
params, optimizer = init_or_load_optimizer(models=models, args=args)

data_dict, hz_dict, all_zs = generate_data(
    latent_space=latent_space, models=models, num_batches=num_batches, args=args
)

# Standardize each view's estimated latents hz over all batches. Each view is
# reshaped back to its own shape, since views may differ in width.
for k, v in hz_dict.items():
    view_shape = v["hz"].shape  # [num_batches, batch_size, nSk]
    hz_dict[k]["hz"] = (
        StandardScaler()
        .fit_transform(np.concatenate(v["hz"], axis=0))
        .reshape(*view_shape)
    )

k = 0  # view 1
l = 1  # view 2
save_path = args.save_dir
subset = (0, 1)  # the pair of views whose shared content is exported
for i in range(num_batches):
    # Content indices are view-specific (ground_truth / gumbel modes), so each
    # view's block is selected with that view's own predicted indices.
    content_idx_k = hz_dict[k]["est_c_ind"][subset][i]
    content_idx_l = hz_dict[l]["est_c_ind"][subset][i]
    # recovered Z from view 0
    z0_hat0 = hz_dict[k]["hz"][i][:, content_idx_k]
    # recovered Z from view 1
    z0_hat1 = hz_dict[l]["hz"][i][:, content_idx_l]
    z0_est = np.column_stack([z0_hat0, z0_hat1])
    file_path = os.path.join(save_path, f"z0est_batch{i}.csv")
    np.savetxt(file_path, z0_est, delimiter=",")

    # ground-truth latents, exported in column order z0, z1, z2, x, y
    z0 = all_zs[i, :, 0][:, None]
    z1 = all_zs[i, :, 1][:, None]
    z2 = all_zs[i, :, 2][:, None]
    x = all_zs[i, :, 3][:, None]
    y = all_zs[i, :, 4][:, None]
    z_true = np.column_stack([z0, z1, z2, x, y])
    file_path = os.path.join(save_path, f"ztrue_batch{i}.csv")
    np.savetxt(file_path, z_true, delimiter=",")

# ## Load saved data and compare pycomets and comets in R ##

# exper_id = "five_latents"
# batch_num = 0
# df = pd.read_csv(
#     f"~/multiview-crl-eval/results/numerical/{exper_id}/ztrue_batch{batch_num}.csv",
#     header=None,
# )
# df_est = pd.read_csv(
#     f"~/multiview-crl-eval/results/numerical/{exper_id}/z0est_batch{batch_num}.csv",
#     header=None,
# )
# df_all = pd.concat([df, df_est], axis=1)
# df_all.columns = ["z0", "z1", "z2", "x", "y", "z0_est0", "z0_est1"]

# pcm = PCM()
# pcm.test(
#     reg_yonxz=LM(),
#     reg_ronz=LM(),
#     reg_vonxz=LM(),
#     reg_yhatonz=LM(),
#     reg_yonz=LM(),
#     Y=df_all[["z0"]].to_numpy(),
#     X=df_all[["z0_est0"]].to_numpy(),
#     Z=df_all[["z1", "z2", "x", "y"]].to_numpy(),
#     estimate_variance=False,
# )

# pcm.test(
#     reg_yonxz=LM(),
#     reg_ronz=LM(),
#     reg_vonxz=LM(),
#     reg_yhatonz=LM(),
#     reg_yonz=LM(),
#     Y=df_all[["z1"]].to_numpy(),
#     X=df_all[["z0_est0"]].to_numpy(),
#     Z=df_all[["z0", "z2", "x", "y"]].to_numpy(),
#     estimate_variance=False,
# )
# pcm.test(
#     reg_yonxz=LM(),
#     reg_ronz=LM(),
#     reg_vonxz=LM(),
#     reg_yhatonz=LM(),
#     reg_yonz=LM(),
#     Y=df_all[["z2"]].to_numpy(),
#     X=df_all[["z0_est0"]].to_numpy(),
#     Z=df_all[["z0", "z1", "x", "y"]].to_numpy(),
#     estimate_variance=False,
# )


# def tmex(df_all, alpha=0.05):
#     pcm = PCM()
#     pcm.test(
#         reg_yonxz=LM(),
#         reg_ronz=LM(),
#         reg_vonxz=LM(),
#         reg_yhatonz=LM(),
#         reg_yonz=LM(),
#         Y=df_all[["z0"]].to_numpy(),
#         X=df_all[["z0_est0"]].to_numpy(),
#         Z=df_all[["z1", "z2", "x", "y"]].to_numpy(),
#         estimate_variance=False,
#         rep=5,
#     )
#     pval1 = pcm.pval
#     pcm.test(
#         reg_yonxz=LM(),
#         reg_ronz=LM(),
#         reg_vonxz=LM(),
#         reg_yhatonz=LM(),
#         reg_yonz=LM(),
#         Y=df_all[["z1"]].to_numpy(),
#         X=df_all[["z0_est0"]].to_numpy(),
#         Z=df_all[["z0", "z2", "x", "y"]].to_numpy(),
#         estimate_variance=False,
#         rep=5,
#     )
#     pval2 = pcm.pval
#     pcm.test(
#         reg_yonxz=LM(),
#         reg_ronz=LM(),
#         reg_vonxz=LM(),
#         reg_yhatonz=LM(),
#         reg_yonz=LM(),
#         Y=df_all[["z2"]].to_numpy(),
#         X=df_all[["z0_est0"]].to_numpy(),
#         Z=df_all[["z0", "z1", "x", "y"]].to_numpy(),
#         estimate_variance=False,
#         rep=5,
#     )
#     pval3 = pcm.pval
#     score1 = int(pval1 < alpha) - 1
#     score2 = int(pval2 < alpha)
#     score3 = int(pval3 < alpha)
#     tmex_score = (score1 + score2 + score3) / 3
#     return tmex_score, pval1, pval2, pval3


# tmex(df_all, alpha=0.05)

# ## OLD CODE below ##

# gcm = GCM()
# gcm.test(reg_yz=LM(), reg_xz=LM(), Y=z0_hat0, X=z1, Z=z0)
# print(f'gcm resid cor view 0: {gcm.get_cor(type="pearson")}')
# plt.cla()
# plt.scatter(gcm.rX, gcm.rY)
# plt.savefig("tmp_gcm0.jpeg")

# gcm.test(reg_yz=LM(), reg_xz=LM(), Y=z0_hat1, X=z2, Z=z0)
# print(f'gcm resid cor view 1: {gcm.get_cor(type="pearson")}')
# plt.cla()
# plt.scatter(gcm.rX, gcm.rY)
# plt.savefig("tmp_gcm1.jpeg")

# gcm.test(reg_yz=LM(), reg_xz=LM(), Y=z0_hat0, X=z2, Z=z0)
# print(gcm.get_cor(type="pearson"))
# plt.cla()
# plt.scatter(gcm.rX, gcm.rY)
# plt.savefig("tmp.jpeg")

# gcm.test(reg_yz=LM(), reg_xz=LM(), Y=z0_hat1, X=z1, Z=z0)
# print(gcm.get_cor(type="pearson"))
# plt.cla()
# plt.scatter(gcm.rX, gcm.rY)
# plt.savefig("tmp.jpeg")

# # (this run through all batches and compute the mean and std of the scores)
# # predict individual latents from the estimated content block
# for subset_idx, subset in enumerate(data_dict):
#     scores = {
#         latent_idx: {"linear": [], "nonlinear": []}
#         for latent_idx in range(args.latent_dim)
#     }
#     for k in subset:
#         for i in range(num_batches):
#             predicted_content_idx = hz_dict[k]["est_c_ind"][subset][i]
#             batch_size = hz_dict[k]["hz"][i].shape[0]
#             inputs = np.take_along_axis(
#                 hz_dict[k]["hz"][i],
#                 np.tile(predicted_content_idx[None], (batch_size, 1)),
#                 axis=-1,
#             )
#             for latent_idx in range(args.latent_dim):
#                 # labels = StandardScaler().fit_transform(data_dict[subset][k][keyword])
#                 labels = all_zs[i, :, latent_idx][
#                     :, None
#                 ]  # [batch_size, n_keyword]
#                 (
#                     train_inputs,
#                     test_inputs,
#                     train_labels,
#                     test_labels,
#                 ) = train_test_split(
#                     labels, inputs
#                 )  # train_test_split(inputs, labels)
#                 data = [train_inputs, train_labels, test_inputs, test_labels]
#                 r2_linear = evaluate_prediction(
#                     linear_model.LinearRegression(n_jobs=-1), r2_score, *data
#                 )
#                 if args.evaluate:
#                     # nonlinear regression
#                     r2_nonlinear = evaluate_prediction(
#                         generate_nonlinear_model(), r2_score, *data
#                     )
#                 else:
#                     r2_nonlinear = -1.0  # not computed
#                 scores[latent_idx]["linear"].append(r2_linear)
#                 scores[latent_idx]["nonlinear"].append(r2_nonlinear)
#         for latent_idx in range(args.latent_dim):
#             file_path = os.path.join(
#                 args.save_dir, f"{file_name}_label_to_input.csv"
#             )
#             fileobj = open(file_path, "a+")
#             writer = csv.writer(fileobj)
#             wri = [
#                 subset,
#                 "view",
#                 k,
#                 "latent_idx",
#                 latent_idx,
#                 "linear mean",
#                 f"{np.mean(scores[latent_idx]['linear']):.3f} +- {np.std(scores[latent_idx]['linear']) :.3f}",
#                 "nonlinear mean",
#                 f"{np.mean(scores[latent_idx]['nonlinear']):.3f} +- {np.std(scores[latent_idx]['nonlinear']):.3f}",
#             ]
#             writer.writerow(wri)
#             fileobj.close()
