import argparse
import os
from functools import partial
from pathlib import Path
from typing import List

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from causica.datasets.causica_dataset_format import Variable
from causica.datasets.tensordict_utils import expand_td_with_batch_shape
from causica.distributions import ContinuousNoiseDist
from causica.lightning.data_modules.basic_data_module import BasicDECIDataModule
from causica.lightning.modules.deci_module import DECIModule
from causica.sem.distribution_parameters_sem import DistributionParametersSEM
from causica.training.auglag import AugLagLRConfig
from pytorch_lightning.callbacks import TQDMProgressBar
from scipy.special import logsumexp
from tensordict import TensorDict
from tqdm import tqdm

from CITNP.utils.processing import normalise_variable

# Test-mode switch, used by testing to run the notebook as a script.
# NOTE: environment values are strings, so ANY non-empty TEST_RUN value
# (including "0" or "false") turns this on; unset or empty leaves it off.
test_run = bool(os.getenv("TEST_RUN"))


def load_data(data_path: Path):
    """
    Read the arrays stored in an HDF5 file into memory.

    Args:
        data_path: Location of the HDF5 file; it is opened read-only.

    Returns:
        A dict mapping each of the dataset names ``obs_data``, ``int_data``,
        ``causal_graphs``, ``intvn_indices`` and ``variable_counts`` to the
        full contents of the dataset of the same name in the file.
    """
    dataset_names = (
        "obs_data",
        "int_data",
        "causal_graphs",
        "intvn_indices",
        "variable_counts",
    )
    with h5py.File(data_path, "r") as h5_file:
        # ``[:]`` materialises every dataset before the file is closed.
        return {name: h5_file[name][:] for name in dataset_names}


def convert_data_to_pd(data: np.ndarray):
    """
    Wrap a 2-D array in a pandas DataFrame.

    Columns are named ``X0``, ``X1``, ... following the order of the array's
    second axis.
    """
    num_columns = data.shape[1]
    column_names = ["X" + str(col) for col in range(num_columns)]
    return pd.DataFrame(data, columns=column_names)


def per_variable_log_prob(value: TensorDict, do_sem: DistributionParametersSEM):
    """
    Evaluate the log-probability of ``value`` separately for each variable.

    ``value`` is expanded to the SEM's batch shape and passed through the
    SEM's functional relationships to parameterise the noise distribution.
    Each independent component of that noise distribution is then evaluated
    at the matching entry of the original (un-expanded) ``value``.

    Returns:
        A dict from variable name to the result of that variable's
        ``log_prob`` call.
    """
    batched_value = expand_td_with_batch_shape(value, do_sem.batch_shape)
    predictions = do_sem.func(batched_value, do_sem.graph)
    joint_noise = do_sem.noise_dist(predictions)

    log_probs = {}
    # NOTE: relies on the private ``_independent_noise_dists`` mapping of the
    # joint noise object to get at the per-variable distributions.
    for var_name, var_dist in joint_noise._independent_noise_dists.items():
        log_probs[var_name] = var_dist.log_prob(value[var_name])
    return log_probs


def append_to_dict(dict_to_append, dict):
    """
    Accumulate the values of ``dict`` into the lists held by ``dict_to_append``.

    For every key of ``dict`` the value is appended to the list stored under
    the same key in ``dict_to_append``; a new single-element list is started
    for keys not seen before. ``dict_to_append`` is modified in place and
    nothing is returned.
    """
    for name in dict:
        dict_to_append.setdefault(name, []).append(dict[name])


def apply_to_dict(func, dict):
    """
    Return a new dict with ``func`` applied to every value of ``dict``.

    Keys and their order are preserved; the input mapping is left untouched.
    """
    transformed = {}
    for name, tensors in dict.items():
        transformed[name] = func(tensors)
    return transformed


def calculate_neg_log_prob(log_prob):
    """
    Negative log of the mean (over axis 0) of the summed-likelihoods.

    The log-probabilities are first summed over axis 1. The resulting totals
    are then averaged in probability space across axis 0 using a
    log-sum-exp minus ``log(size of axis 0)``, and the sign is flipped.
    """
    num_graph_samples = log_prob.shape[0]
    per_graph_total = log_prob.sum(1)
    log_mean_likelihood = logsumexp(per_graph_total, axis=0) - np.log(
        num_graph_samples
    )
    return -log_mean_likelihood


def calculate_intvn_dist(
    data_module: BasicDECIDataModule,
    sem_list: List[DistributionParametersSEM],
    intvn_indices: int,
    int_data: np.ndarray,
    int_data_dict: TensorDict,
    num_sample_sems: int = 10,
):
    """
    Score interventional data under a collection of sampled SEMs.

    For every SEM in ``sem_list`` and every row of ``int_data``, the SEM is
    intervened on variable ``X{intvn_indices}`` with that row's value, and
    the per-variable log-probability of the matching row of ``int_data_dict``
    is evaluated. The first ``num_sample_sems`` SEMs additionally draw one
    sample of the outcome variable per row.

    Args:
        data_module: Data module; if it has a ``normalizer`` attribute that
            is not ``None``, samples are passed through ``normalizer.inv``.
        sem_list: The SEMs to evaluate.
        intvn_indices: Column index of the intervened variable. A Python int
            or a NumPy integer scalar.
        int_data: Interventional data, one row per data point.
        int_data_dict: The same interventional data, sliceable by row index
            and keyed by ``X{i}`` variable names.
        num_sample_sems: Number of leading SEMs that also produce outcome
            samples (defaults to the previously hard-coded 10).

    Returns:
        A tuple ``(neg_log_prob_dict, all_samples)``. ``neg_log_prob_dict``
        maps each variable name returned by ``per_variable_log_prob`` to the
        output of ``calculate_neg_log_prob``. ``all_samples`` is a
        ``(num_sample_sems, number of rows)`` array of outcome samples; rows
        beyond ``len(sem_list)`` stay zero.
    """
    int_data = torch.from_numpy(int_data).to(torch.float32)
    int_values = int_data[:, intvn_indices]
    # The outcome is column 1 when column 0 is intervened on, otherwise column 0.
    outcome_idx = 1 if intvn_indices == 0 else 0
    outcome_node = int_data[:, outcome_idx]

    normalizer = getattr(data_module, "normalizer", None)

    # Variable names are loop-invariant, so build them once. NumPy scalars are
    # unwrapped with ``.item()``; plain Python ints are used as they are.
    if hasattr(intvn_indices, "item"):
        intvn_key = f"X{intvn_indices.item()}"
    else:
        intvn_key = f"X{intvn_indices}"
    outcome_key = f"X{outcome_idx}"

    all_scm_log_prob = {}
    all_samples = np.zeros((num_sample_sems, outcome_node.shape[0]))
    for sem_idx, sem in tqdm(enumerate(sem_list), total=len(sem_list)):
        all_var_log_prob = {}
        for idx, curr_int in enumerate(int_values):
            int_dict = TensorDict({intvn_key: curr_int}, batch_size=tuple())
            do_sem = sem.do(interventions=int_dict)

            # Sampling happens before scoring so the order of random draws
            # matches the original implementation.
            if sem_idx < num_sample_sems:
                samples = do_sem.sample(torch.Size([1]))
                if normalizer is not None:
                    outcome_samples = normalizer.inv(samples)[outcome_key]
                else:
                    outcome_samples = samples[outcome_key]

                all_samples[sem_idx, idx] = outcome_samples

            # Original author's note: log probs come back per variable in
            # index order, with the intervened variable skipped.
            var_log_prob = per_variable_log_prob(int_data_dict[idx : idx + 1], do_sem)
            append_to_dict(all_var_log_prob, var_log_prob)

        # Join the per-row results of this SEM along the data axis.
        final_dict = apply_to_dict(partial(np.concatenate, axis=0), all_var_log_prob)
        append_to_dict(all_scm_log_prob, final_dict)

    # Stack across SEMs so axis 0 indexes the SEM and axis 1 the data row.
    all_scm_log_prob = apply_to_dict(partial(np.stack, axis=0), all_scm_log_prob)

    neg_log_prob_dict = apply_to_dict(calculate_neg_log_prob, all_scm_log_prob)
    return neg_log_prob_dict, all_samples


def _build_deci_data_module(data: np.ndarray) -> BasicDECIDataModule:
    """Wrap a 2-D array in an un-normalised DECI data module with columns ``X0``, ``X1``, ..."""
    return BasicDECIDataModule(
        dataframe=convert_data_to_pd(data),
        normalize=False,
        batch_size=128,
        variables=[
            Variable(group_name=f"X{col}", name=f"X{col}")
            for col in range(data.shape[1])
        ],
    )


def main(
    data_path: Path,
    data_path_name: str,
    sample_size: int = 500,
    data_start: int = 0,
    data_end: int = 100,
):
    """
    Fit DECI on the first dataset in ``data_path`` and score its interventional data.

    The observational data is used to train a ``DECIModule``; 100 SEMs are
    then sampled from the trained model and passed to ``calculate_intvn_dist``
    together with the interventional data. A scatter plot of the
    interventional data against sampled outcomes and a one-row CSV of the
    negative log-probabilities are written to disk.

    Args:
        data_path: HDF5 file read by ``load_data``.
        data_path_name: Name used for the results sub-folder and as the
            prefix of the CSV and plot filenames.
        sample_size: Number of leading samples kept per dataset.
        data_start: Index used only in the output filenames.
        data_end: Currently unused; kept for interface compatibility.
    """
    result_folder = (
        Path("CausalInferenceNeuralProcess/baselines")
        / "deci"
        / "results"
        / data_path_name
    )
    result_folder.mkdir(parents=True, exist_ok=True)

    data = load_data(data_path)

    all_obs_data = data["obs_data"][:, :sample_size, :]
    all_int_data = data["int_data"][:, :sample_size, :]
    all_intvn_indices = data["intvn_indices"]

    # Only the first dataset in the file is processed.
    for i in range(1):
        pl.seed_everything(seed=1)  # set the random seed

        obs_data = all_obs_data[i]
        int_data = all_int_data[i]
        intvn_indices = all_intvn_indices[i]

        data_module = _build_deci_data_module(obs_data)
        # The interventional module is only used for its ``dataset_train``.
        int_data_module = _build_deci_data_module(int_data)

        lightning_module = DECIModule(
            noise_dist=ContinuousNoiseDist.SPLINE,
            prior_sparsity_lambda=43.0,
            init_rho=30.0,
            init_alpha=0.20,
            auglag_config=AugLagLRConfig(
                max_inner_steps=3400,
                max_outer_steps=8,
                lr_init_dict={
                    "icgnn": 0.00076,
                    "vardist": 0.0098,
                    "functional_relationships": 3e-4,
                    "noise_dist": 0.0070,
                },
            ),
        )

        trainer = pl.Trainer(
            accelerator="auto",
            max_epochs=2000,
            fast_dev_run=test_run,
            callbacks=[TQDMProgressBar(refresh_rate=19)],
            enable_checkpointing=False,
        )
        trainer.fit(lightning_module, datamodule=data_module)

        sem_module = lightning_module.sem_module
        sem_dist = sem_module()
        sem_samples = sem_dist.sample(sample_shape=torch.Size([100]))

        with torch.no_grad():
            average_neg_log_prob, all_samples = calculate_intvn_dist(
                data_module=data_module,
                sem_list=sem_samples,
                intvn_indices=intvn_indices,
                int_data=int_data,
                int_data_dict=int_data_module.dataset_train,
            )

        # Plot the interventional data and samples
        outcome_idx = 1 if intvn_indices == 0 else 0
        plt.scatter(
            int_data[:, intvn_indices], int_data[:, outcome_idx], label="outcome"
        )
        for sample_idx in range(all_samples.shape[0]):
            plt.scatter(
                int_data[:, intvn_indices],
                all_samples[sample_idx, :],
                # Label only the first series so the legend has one "sample" entry.
                label="sample" if sample_idx == 0 else None,
                alpha=0.05,
                color="red",
            )
        # Build the legend once, after all series have been drawn.
        plt.legend()
        plt.savefig(
            f"CausalInferenceNeuralProcess/baselines/deci/{data_path_name}_{data_start}.png"
        )
        plt.close()

        # Save the results
        save_name = data_path_name + f"_deci_nlpd_data{data_start}.csv"
        save_path = result_folder / save_name
        df = pd.DataFrame([average_neg_log_prob])
        df.to_csv(save_path, index=False)

        print(f"Average negative log prob for dataset {i}: {average_neg_log_prob}")


if __name__ == "__main__":
    # Script entry point: choose one of the five test files by index and run DECI on it.
    data_path_name = "sachs"

    all_data_names = [f"data_idx{k}" for k in range(5)]

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_start", type=int, default=0, help="data_start")
    parser.add_argument("--data_end", type=int, default=5, help="data_end")
    args = parser.parse_args()

    # ``data_start`` doubles as the index into the list of dataset file names.
    data_name = all_data_names[args.data_start]

    data_path = Path(
        f"CausalInferenceNeuralProcess/CITNP/datasets/synth_training_data/{data_path_name}/test/{data_name}.hdf5"
    )

    sample_size = 500

    main(
        data_path,
        data_path_name,
        sample_size=sample_size,
        data_start=args.data_start,
        data_end=args.data_end,
    )
