import os
from collections.abc import Iterable
from random import seed
from time import perf_counter, time
from typing import Any, cast

import numpy as np
import pandas as pd
import torch
import yaml
from torch import optim

from dcg.graph import CausalGraph
from dcg.node import CausalNode
from dcg.distributional.continuous import Normal
from dcg.training import *
from do_shap.factory import linear
from do_shap.shap import (
    shap, MeanStd, NODE_LIKE, _cache_redirect, cache_ratio_find_N
)
from scripts.frontiers import sample_dag_all_ancestors_not_all_parents, FR2


class DGP:
    """Linear-Gaussian data-generating process over a DAG.

    Each node is drawn as ``Normal(mean of its parents' values, SCALE)``;
    nodes without parents use mean 0. Calling the instance returns ``n``
    joint samples as a DataFrame, optionally with some nodes clamped to
    given values (interventions).
    """

    SCALE = 1  # standard deviation of every node's Gaussian noise

    def __init__(
        self,
        G: dict[str, list[str]]
    ):
        # G maps each node name to the list of its parents' names. The key
        # order fixes the column order of the generated DataFrame.
        self.G: dict[str, list[str]] = G
        self.nodes: list[str] = list(G.keys())

    @property
    def K(self) -> int:
        """Number of nodes in the graph."""
        return len(self.G)

    def _parents(self, node: str) -> list[str]:
        return self.G[node]

    def _sample(
        self, n: int, parents: list[np.ndarray], *,
        rs: np.random.RandomState
    ) -> np.ndarray:
        """Draw ``n`` values of one node given its parents' ``(n,)`` samples.

        Raises:
            ValueError: if the parents' samples do not average to shape (n,).
        """
        if parents:
            value = np.mean(parents, axis=0)
            # Explicit check instead of `assert`, which `python -O` strips.
            if value.shape != (n,):
                raise ValueError(
                    f'parent samples must have shape ({n},), '
                    f'got mean of shape {value.shape}'
                )
        else:
            value = 0

        return rs.normal(loc=value, scale=self.SCALE, size=(n,))

    def __call__(
        self, n: int, rs: np.random.RandomState,
        **intv: np.ndarray,
    ) -> pd.DataFrame:
        """Sample ``n`` rows, clamping any node given in ``intv``.

        Nodes are sampled only once all their parents are available, so the
        keys of ``G`` need not be topologically ordered. When they are, nodes
        are drawn in key order, i.e. exactly the same RNG call sequence as a
        straight pass over ``self.nodes``.

        Raises:
            ValueError: if some nodes can never be sampled (cycle in ``G`` or
                a parent that is neither a node nor an intervention).
        """
        pending = [node for node in self.nodes if node not in intv]
        while pending:
            blocked: list[str] = []
            for node in pending:
                parents = self._parents(node)
                if all(parent in intv for parent in parents):
                    intv[node] = self._sample(
                        n, [intv[parent] for parent in parents], rs=rs
                    )
                else:
                    blocked.append(node)

            if len(blocked) == len(pending):  # no progress in a full pass
                raise ValueError(
                    f'cannot sample nodes {blocked}: the graph has a cycle '
                    'or references parents that are not defined'
                )
            pending = blocked

        return pd.DataFrame(np.stack(
            [intv[node] for node in self.nodes], axis=1
        ), columns=self.nodes)


# Local replica of the library's dcg_shap, extended with a timer around the
# value function so its share of the total runtime can be reported.
def dcg_shap(
    x: np.ndarray,
    V: Iterable[NODE_LIKE],
    target: NODE_LIKE,
    graph: CausalGraph,
    *,
    mc_n: int = 1000,  # Monte Carlo samples
    **kwargs
) -> tuple[MeanStd, float]:
    """Run ``shap`` with a DCG interventional value function, timing it.

    Args:
        x: samples to explain; numpy arrays are converted to float tensors.
        V: identifiers of the nodes whose contributions are estimated (any
            iterable accepted element-wise by ``graph[...]``, e.g. a pandas
            Index).
        target: identifier of the node whose value is explained.
        graph: trained causal graph used to sample under interventions.
        mc_n: Monte Carlo replicas drawn per explained sample and coalition.
        **kwargs: forwarded unchanged to ``shap``.

    Returns:
        ``(shap result, seconds)``, where ``seconds`` is the total time spent
        inside the value function over all of ``shap``'s calls to it.
    """
    # Preprocess the given sample and move it to the graph's device
    if isinstance(x, np.ndarray):
        x = torch.Tensor(x.astype(float))

    x_parsed, n = graph._preprocess_x(x)
    x = {
        node: values.to(graph.device)
        for node, values in cast(
            dict[CausalNode, torch.Tensor], x_parsed
        ).items()
    }

    # Resolve V and the target to CausalNode instances
    nodes = [graph[v] for v in V]
    target = cast(CausalNode, graph[target])

    # Seconds accumulated inside the value function across all calls.
    nu_seconds = 0.0

    def f(
        x: dict[CausalNode, torch.Tensor], subset: tuple[int, ...]
    ) -> np.ndarray:
        """Monte Carlo estimate of ``target`` under do(V[subset] = x).

        The intervened values of the n explained samples are tiled mc_n
        times, and the sampled target is averaged over those mc_n replicas.
        """
        nonlocal nu_seconds
        # perf_counter is monotonic and high-resolution; time() is wall clock.
        start = perf_counter()

        with torch.no_grad():
            res = graph.sample(
                mc_n * n, target_node=target, interventions={
                    nodes[i]: x[nodes[i]].repeat(mc_n, 1)
                    for i in subset
                }
            ).view(mc_n, -1).mean(0).cpu().numpy()

        nu_seconds += perf_counter() - start
        return res

    result = shap(x, len(nodes), f, **kwargs)
    return result, nu_seconds


# Seeds for Python's `random`, NumPy and PyTorch, respectively.
SEED_PYTHON = 123
SEED_NUMPY = 1234
SEED_TORCH = 12345

# Experiment grid: every k in [MIN_K, MAX_K] is passed, together with P,
# to the DAG sampler; RATIO_COALITIONS feeds cache_ratio_find_N.
MIN_K, MAX_K = 5, 15
P = 0.25
RATIO_COALITIONS = 0.5
REPS = 30  # replications of the experiment for each k

# Synthetic dataset size and split fractions; validation gets the remainder
# (1 - TRAIN_SIZE - TEST_SIZE).
DATASET_SIZE = 1000
TRAIN_SIZE = 0.8
TEST_SIZE = 0.1

# DCG training hyper-parameters
BATCH_SIZE = 100
LR = 0.01
PATIENCE = 10

# SHAP: Monte Carlo samples per value-function evaluation (mc_n)
SHAP_N = 100


if __name__ == '__main__':
    # Resolve (and create) the output location before doing any work. Reading
    # RESULTS_DIR only after all replications meant that an unset variable or
    # a missing directory discarded the entire run at the very end.
    results_dir = os.getenv('RESULTS_DIR')
    if results_dir is None:
        raise SystemExit('RESULTS_DIR environment variable is not set')
    output_dir = os.path.join(results_dir, 'fra')
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'time_ablation.yaml')

    def save_results(records: list[dict[str, Any]]) -> None:
        """(Over)write ``output_path`` with every record collected so far."""
        with open(output_path, 'w') as fout:
            yaml.safe_dump(records, fout)

    # Set random state and seeds
    seed(SEED_PYTHON)
    rs = np.random.RandomState(seed=SEED_NUMPY)
    torch.random.manual_seed(SEED_TORCH)

    results: list[dict[str, Any]] = []
    for k in range(MIN_K, MAX_K + 1):
        # int(): yaml.safe_dump cannot represent numpy scalar types, which
        # cache_ratio_find_N might return (its return type is not visible here).
        nperms = int(max(30, cache_ratio_find_N(k, RATIO_COALITIONS)))
        for rep in range(REPS):
            d: dict[str, Any] = dict(
                k=k, p=P, nperms=nperms, rep=rep
            )
            print(d)
            results.append(d)

            # Create graph
            graph = sample_dag_all_ancestors_not_all_parents(k, P)

            d['edges'] = ' '.join(
                str(node)
                for edge in graph.edges
                for node in edge
            )

            def to_node(i: int) -> str:
                return f'v{i}'

            # Create DGP: one column per graph node, named v<i>
            dgp = DGP({
                to_node(i): list(map(to_node, graph.parents(i)))
                for i in graph.nodes
            })

            # Sample the dataset; the last column is the SHAP target
            df = dgp(DATASET_SIZE, rs=rs)
            features, target = df.columns[:-1], df.columns[-1]
            X = torch.Tensor(df.values)

            # Contiguous train / val / test split. Explicit end indices (not
            # negative slices) stay correct when the test split is empty,
            # where X[-0:] would select the whole dataset.
            train_end = int(DATASET_SIZE * TRAIN_SIZE)
            val_end = DATASET_SIZE - int(DATASET_SIZE * TEST_SIZE)
            Xtrain = X[:train_end]
            Xval = X[train_end:val_end]
            x_test = df.iloc[val_end:].values

            # Train DCG
            device = torch.device('cpu')  # for Macbook; use get_device() if not
            definition = '\n'.join(
                f'{destination} norm 1 ' + ' '.join(sources)
                for destination, sources in dgp.G.items()
            )
            dcg = CausalGraph.from_definition(
                CausalGraph.parse_definition(
                    # Here we link variable types to their classes
                    definition,
                    norm=Normal,
                ),
                net_f=linear,
            ).warm_start(Xtrain).to(device)

            train(
                dcg,
                data_loader(
                    TensorDataset(Xtrain), BATCH_SIZE, drop_last=True
                ),
                data_loader(
                    TensorDataset(Xval), BATCH_SIZE, drop_last=False
                ),
                loss_f(ex_n=100),  # avg negative log-likelihood

                # Since we're using a linear model, this is enough
                optimizer=optim.Adam,
                optimizer_kwargs=dict(lr=LR),
                patience=PATIENCE,

                # NOTE(review): this was commented "don't show the progress
                # bar", which contradicts True; confirm the intended value.
                use_tqdm=True,
                silent=True
            )

            def run_shap(cache_size, cache_redirect) -> float:
                """Run dcg_shap on the test split; return value-function time.

                Only the caching arguments differ between the three ablation
                variants below; everything else is shared.
                """
                _, time_nu = dcg_shap(
                    x_test, features, target, dcg,
                    mc_n=SHAP_N, rs=rs,
                    max_perms=nperms,
                    cache_size=cache_size,
                    cache_redirect=cache_redirect,
                )
                return time_nu

            # perf_counter is monotonic and high-resolution; time() is wall
            # clock and can jump (e.g. NTP adjustments) mid-measurement.

            # DCG, without cache (size 0, identity redirect)
            start = perf_counter()
            time_nu = run_shap(0, lambda x: x)
            d['dcg_no_cache_time'] = perf_counter() - start
            d['dcg_no_cache_time_nu'] = time_nu

            # DCG, with cache (no maxsize; _cache_redirect sorts the tuple)
            start = perf_counter()
            time_nu = run_shap(None, _cache_redirect)
            d['dcg_cache_time'] = perf_counter() - start
            d['dcg_cache_time_nu'] = time_nu

            # DCG, with do-RA. FR2 is built inside the timed region so its
            # setup cost counts towards this variant.
            start = perf_counter()
            fr2 = FR2(graph)
            time_nu = run_shap(None, fr2.cache_redirect)
            d['dcg_reduction_time'] = perf_counter() - start
            d['dcg_reduction_time_nu'] = time_nu

            d['frontiers'] = len(fr2.fr_cache)
            d['values'] = len(fr2.v_cache)
            d['total_values'] = nperms * (k + 1)

            # Checkpoint after each replication so a crash later in the run
            # keeps every completed replication on disk.
            save_results(results)

    # Final write (also covers an empty grid, which never checkpoints).
    save_results(results)
