from functools import partial

import os
import pickle
import yaml

import logging

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from xgboost import XGBRegressor

import hydra
from omegaconf import DictConfig, OmegaConf
from utils.hydra import register_resolvers, get_output_dir

from dcg.distributional.discrete import Categorical
from dcg.distributional.continuous import Normal, Beta, Exponential
from dcg.modules import PredictorDCU
from dcg.flow import NCF
from dcg.latents import Normal as LatentNormal
from dcg.graph import CausalGraph
from dcg.training import (
    data_loader, TensorDataset, train, loss_f as _loss_f,
    test_loglk, plot_losses
)

from cnf import create_cnf, cnf_loss, cnf_shap

from do_shap import factory
from do_shap.shap import shap, dgp_shap, marginal_shap, dcg_shap
from do_shap.frontiers import FR1, FR2, DAG, Edge

from utils.file import try_mkdir
from utils.ds import train_val_test_split
from utils.density import density_plots

from dgp import DGP


# Logger named after this module
log = logging.getLogger(__name__)

# Custom hydra resolvers have to be registered before `main` is decorated/run
register_resolvers()

# Lookup tables used to resolve config strings into callables:
#   cfg.dgp.which -> data generating process
DGP_DICT = {'dgp': DGP}

#   cfg.net.kind -> network factory
NET_DICT = {
    'linear': factory.linear,
    'ff': factory.ff,
}


def _save_shap(results, path: str) -> None:
    """Pickle a SHAP result to `path` as a dict with keys `mean` and `std`.

    Args:
        results: indexable SHAP output; `results[0]` is stored under `mean`
            and `results[1]` under `std`.
        path: destination file (opened in binary write mode).
    """
    with open(path, 'wb') as f:
        pickle.dump(dict(mean=results[0], std=results[1]), f)


def _fit_or_load(
    model, tag: str, label: str, *, load: bool, splits, loss, loglk_f,
    cfg: DictConfig, models_folder: str, metrics_folder: str,
    plots_folder: str,
) -> None:
    """Train `model` and store its artifacts, or restore it from disk.

    Artifacts written when training (all prefixed by `tag`):
    `<models>/<tag>_model.pt`, `<metrics>/<tag>_losses.pt`,
    `<plots>/<tag>_losses.png` and `<metrics>/<tag>_losses.yaml`
    (train/val/test log-likelihoods).

    Args:
        model: module accepted by `train`, exposing `state_dict`,
            `load_state_dict` and `eval`.
        tag: file-name prefix, e.g. `'dcn'`.
        label: human readable name used in log messages, e.g. `'DCN'`.
        load: if truthy, skip training and load `<tag>_model.pt` instead.
        splits: `(train, val, test)` tensors; train/val feed the training
            loop, all three are scored with `loglk_f`.
        loss: training loss passed through to `train`.
        loglk_f: callable mapping one split tensor to a value convertible
            with `float()` (its log-likelihood).
        cfg: experiment config; only `cfg.train.*` is read here.
        models_folder: directory of the model checkpoint.
        metrics_folder: directory of the losses and log-likelihood files.
        plots_folder: directory of the loss plot.
    """
    model_path = os.path.join(models_folder, f'{tag}_model.pt')

    if load:
        # Deserialize on CPU so that a checkpoint written on a CUDA machine
        # can still be read on a CPU-only one; load_state_dict then copies
        # the values into the model's parameters on whatever device it is.
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        model.eval()
        return

    X_train, X_val, _ = splits

    log.info(f'Training {label}')
    losses = train(
        model,
        data_loader(
            TensorDataset(X_train), cfg.train.batch_size, drop_last=True
        ),
        data_loader(
            TensorDataset(X_val), cfg.train.batch_size, drop_last=False
        ),
        loss,  # avg negative log-likelihood
        optimizer_kwargs=dict(
            lr=cfg.train.lr,
            weight_decay=cfg.train.weight_decay),
        n_epochs=cfg.train.n_epochs,
        patience=cfg.train.patience,

        # Verbose training: tqdm flag on, not silent
        use_tqdm=True,
        silent=False
    )

    # Save losses
    torch.save(losses, os.path.join(metrics_folder, f'{tag}_losses.pt'))

    # Save model
    torch.save(model.state_dict(), model_path)

    # Plot losses; close the figure afterwards so that consecutive models
    # neither accumulate open figures nor end up sharing one.
    plot_losses(*losses)
    plt.savefig(os.path.join(plots_folder, f'{tag}_losses.png'), dpi=300)
    plt.close()

    # Compute loglks
    metrics = {}
    for split, data in zip(['train', 'val', 'test'], splits):
        metrics[split + '_loglk'] = res = float(loglk_f(data))
        log.info(f'{split} loglk: {res}')

    with open(os.path.join(metrics_folder, f'{tag}_losses.yaml'), 'w') as f:
        yaml.safe_dump(metrics, f, sort_keys=False)

    # NOTE: the scatter-matrix/density plots (utils.density.density_plots)
    # that used to follow are currently disabled.


@hydra.main(
    version_base=None,
    config_path=os.getenv('CONF_DIR'),
    config_name='markovian'
)
def main(cfg: DictConfig) -> None:
    """Run the experiment described by the hydra config `cfg`.

    Steps, in order:
      1. Create the output folders and store the config (refusing to
         overwrite a previous run unless `cfg.overwrite` is set).
      2. Sample a dataset from the selected DGP, split it and fit an
         XGBRegressor predicting `y` from `cfg.V_model`.
      3. For every method enabled with `cfg.shap.<method>.run` (dgp,
         marginal, dcg_linear, dcn, dcg, cnf) train or load the required
         model, compute SHAP values on the test split and pickle them into
         the metrics folder as `<method>_shap.pkl`.

    Raises:
        FileExistsError: if `conf.yaml` already exists in the output folder
            and `cfg.overwrite` is falsy.
    """
    try_mkdir(results_folder := get_output_dir())

    # Store config, but check if it's been run beforehand
    config_path = os.path.join(results_folder, 'conf.yaml')
    if os.path.exists(config_path) and not cfg.overwrite:
        raise FileExistsError(
            f'Experiment already run but overwrite=False: {config_path}'
        )

    with open(config_path, 'w') as f:
        f.write(OmegaConf.to_yaml(cfg))

    # Create folders for all outputs
    try_mkdir(data_folder := os.path.join(results_folder, 'data'))
    try_mkdir(metrics_folder := os.path.join(results_folder, 'metrics'))
    try_mkdir(models_folder := os.path.join(results_folder, 'models'))
    try_mkdir(plots_folder := os.path.join(results_folder, 'plots'))
    folders = dict(
        models_folder=models_folder,
        metrics_folder=metrics_folder,
        plots_folder=plots_folder,
    )

    # Begin experiment
    # Set random seeds (different offsets for numpy, torch and the local rs)
    np.random.seed(cfg.random_seed)
    torch.random.manual_seed(cfg.random_seed + 1)

    # Select which DGP to use
    dgp = DGP_DICT[cfg.dgp.which]

    # For this kind of data, it's preferable to work with CPU rather than MPS
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define training loss
    loss_f = _loss_f(ex_n=cfg.train.loss_mc_n)

    # Generate data
    rs = np.random.RandomState(seed=cfg.random_seed + 2)
    df: pd.DataFrame = dgp(cfg.dgp.n_samples, latents=cfg.dgp.latents, rs=rs)
    X = torch.Tensor(df.values.astype(float))

    # Raw data; the file is re-written below once predictions are available
    df.to_csv(os.path.join(data_folder, 'df.csv'), index=False)

    # Split data in train/validation/test
    idx_train, idx_val, idx_test, Xtrain, Xval, Xtest = train_val_test_split(
        X,
        train_size=cfg.train_size,
        test_size=cfg.test_size,
        rs=rs,
        return_index=True,
    )
    splits = (Xtrain, Xval, Xtest)

    # Save idxs for later
    with open(os.path.join(data_folder, 'idxs.pkl'), 'wb') as f:
        pickle.dump(dict(
            idx_train=idx_train,
            idx_val=idx_val,
            idx_test=idx_test,
        ), f)

    # Train the model to explain (fitted on train + validation rows)
    V: list[str] = list(cfg.V)  # all variables
    V_model: list[str] = list(cfg.V_model)  # should be ['z', 'x', 'c']
    train_data = df.iloc[np.concatenate([idx_train, idx_val])].copy()
    test_data = df.iloc[np.concatenate([idx_test])].copy()

    reg = XGBRegressor().fit(
        train_data[V_model],
        train_data.y
    )

    # Compute and save predictions (same file as the raw data, plus `pred`)
    df2 = df.copy()
    df2['pred'] = reg.predict(df[V_model])
    df2.to_csv(os.path.join(data_folder, 'df.csv'), index=False)

    # Define FRA: cfg.fra_definition is a flat list of edge endpoints,
    # consumed in consecutive pairs.
    fra_dag = DAG(len(V) + 1)
    for edge in zip(cfg.fra_definition[::2], cfg.fra_definition[1::2]):
        fra_dag.add_edge(Edge(*edge))
    fra = FR1(fra_dag)

    # Define SHAP kwargs: forward only the sampling options found in cfg.shap
    shap_options = ('max_perms', 'adaptive', 'min_perms', 'cache_size')
    shap_kwargs = {k: v for k, v in cfg.shap.items() if k in shap_options}
    shap_kwargs['cache_redirect'] = fra.cache_redirect

    def predict_filter(reg, X, *, V_model):
        """Predict with `reg` using only the `V_model` columns of `X`."""
        return reg.predict(X[V_model])

    # Predictor over DataFrames shared by the DGP, marginal and CNF methods
    predict = partial(predict_filter, reg, V_model=V_model)

    def predictor_dcu_f(X: torch.Tensor) -> torch.Tensor:
        """Evaluate the regressor on a tensor and return a column tensor.

        The result is float32, lives on the same device as `X` and gets an
        extra dimension at position 1.
        """
        res = reg.predict(X.detach().cpu().numpy())

        # torch.Tensor(..., device=...) is the legacy constructor and only
        # accepts CPU devices; as_tensor works for any device.
        return torch.as_tensor(
            res, dtype=torch.float32, device=X.device
        ).unsqueeze(1)

    def compute_dcg_density(
        graph: CausalGraph, col: str, x: np.ndarray
    ) -> np.ndarray:
        """Density of variable `col` at points `x` according to `graph`.

        Only needed by the (currently disabled) density plots.
        """
        with torch.no_grad():
            return np.exp(graph.loglk({
                col: torch.as_tensor(
                    x[:, np.newaxis], dtype=torch.float32,
                    device=graph.device
                )
            }).cpu().numpy())

    def build_graph(net_f, *, cont, flow, flow01, flow_pos) -> CausalGraph:
        """Create a warm-started CausalGraph from `cfg.definition`.

        The keyword arguments select the class used for each continuous
        variable type of the definition file; latent, categorical, `norm`
        and target types are shared by every variant.
        """
        # Definition format: name type dimension [parents, ...]
        with open(cfg.definition) as f:
            definition = f.read()

        return CausalGraph.from_definition(
            CausalGraph.parse_definition(
                # Here we link variable types to their classes
                definition,
                lat=LatentNormal,
                cat=Categorical,
                cont=cont,
                norm=Normal,
                flow=flow,
                flow01=flow01,
                flow_pos=flow_pos,
                target=partial(
                    PredictorDCU, discrete=False, predict=predictor_dcu_f
                )
            ),
            net_f=net_f
        ).warm_start(Xtrain).to(device)

    def explain_with_graph(tag: str, label: str, graph: CausalGraph) -> None:
        """Train (or load) `graph`, then compute and store its SHAP values.

        `tag` is both the key under `cfg.shap` and the file-name prefix.
        """
        _fit_or_load(
            graph, tag, label,
            load=cfg.shap[tag].load,
            splits=splits,
            loss=loss_f,
            loglk_f=partial(test_loglk, graph),
            cfg=cfg,
            **folders
        )

        log.info(f'Computing {label}-SHAP')
        results = dcg_shap(
            test_data.values,
            V, 'y',
            graph,
            mc_n=cfg.shap.n_samples,
            rs=rs,
            **shap_kwargs
        )
        _save_shap(results, os.path.join(metrics_folder, f'{tag}_shap.pkl'))

        fra.clear()  # reset for the next execution

    def ncf(flow_factory):
        """NCF node class whose flow is built by `flow_factory`."""
        return partial(NCF, flow_f=partial(flow_factory, cfg=cfg.flow))

    # Begin computing SHAP values
    if cfg.shap.dgp.run:  # Data Generating Process
        log.info('Computing DGP-SHAP')

        # Compute SHAP exactly, since this will be our ground truth
        dgp_results = dgp_shap(
            test_data, V, 'y', partial(dgp, latents=cfg.dgp.latents),
            predict,
            N=cfg.shap.dgp.n_samples, rs=rs,
            max_perms=..., adaptive=False,
            cache_redirect=fra.cache_redirect,
        )
        _save_shap(dgp_results, os.path.join(metrics_folder, 'dgp_shap.pkl'))

        fra.clear()  # reset for the next execution

    if cfg.shap.marginal.run:  # Use Marginal interventions with a predictor
        log.info('Computing Marginal-SHAP')

        marginal_results = marginal_shap(
            test_data, V, 'y',
            predict,
            train_data, rs=rs, **shap_kwargs
        )
        _save_shap(
            marginal_results,
            os.path.join(metrics_folder, 'marginal_shap.pkl')
        )

        fra.clear()  # reset for the next execution

    # Deep Causal Graph with linear cond. and Normal distr. for cont. variables
    if cfg.shap.dcg_linear.run:
        explain_with_graph('dcg_linear', 'DCG-linear', build_graph(
            partial(factory.linear, cfg=cfg.net),
            cont=Normal,
            flow=Normal,
            flow01=Normal,
            flow_pos=Normal,
        ))

    if cfg.shap.dcn.run:  # Distributional Causal Nodes
        explain_with_graph('dcn', 'DCN', build_graph(
            partial(NET_DICT[cfg.net.kind], cfg=cfg.net),
            cont=Normal,
            flow=Normal,
            flow01=Beta,
            flow_pos=Exponential,
        ))

    if cfg.shap.dcg.run:  # Deep Causal Graph
        explain_with_graph('dcg', 'DCG', build_graph(
            partial(NET_DICT[cfg.net.kind], cfg=cfg.net),
            cont=ncf(factory.flow),
            flow=ncf(factory.flow),
            flow01=ncf(factory.flow01),
            flow_pos=ncf(factory.flow_pos),
        ))

    if cfg.shap.cnf.run:  # Causal Normalizing Flows
        A = torch.Tensor([  # aij = 1 <-> x_j -> x_i
            #U, Z, X, A, B, C
            [0, 0, 0, 0, 0, 0],  # U
            [0, 0, 0, 0, 0, 0],  # Z
            [1, 1, 0, 0, 0, 0],  # X
            [0, 0, 1, 0, 0, 0],  # A
            [1, 0, 0, 1, 0, 0],  # B
            [0, 0, 0, 0, 1, 0],  # C
        ])

        cnf = create_cnf(X[:, :-1], A, cfg.cnf)  # discard y

        # Train (or load) CNF; the last column (y) is dropped from each split
        _fit_or_load(
            cnf, 'cnf', 'CNF',
            load=cfg.shap.cnf.load,
            splits=tuple(data[:, :-1] for data in splits),
            loss=cnf_loss,
            loglk_f=lambda data: cnf.log_prob(data).mean(),
            cfg=cfg,
            **folders
        )

        # Compute SHAP
        log.info('Computing CNF-SHAP')
        if cfg.shap.cnf.use_mps:
            # In this case, since we'll be computing big samples
            # it's better to work in MPS (even with partial support)
            # as in CPU it will take too long
            cnf = cnf.to('mps')

        cnf_results = cnf_shap(
            test_data,
            V,
            predict,
            cnf,
            df.drop('y', axis=1).columns,
            mc_n=cfg.shap.n_samples,
            rs=rs,
            use_mps=cfg.shap.cnf.use_mps,
            **shap_kwargs
        )
        _save_shap(cnf_results, os.path.join(metrics_folder, 'cnf_shap.pkl'))

        fra.clear()  # reset for the next execution

# Script entry point: hydra builds `cfg` and passes it to `main`
if __name__ == "__main__":
    main()
