import numpy as np
import pandas as pd
import os
import torch
import glob
import time
import uuid
import json
from typing import Any, Dict
from sklearn.kernel_ridge import KernelRidge
from dodiscover.ci import KernelCITest
from dodiscover.score.ges import GES # First, run: cd ~/dodiscover; git checkout causal-benchmark
from dodiscover.constraint.pcalg import PC
from dodiscover.toporder.cam import CAM
from dodiscover.toporder.score import SCORE
from dodiscover.toporder.nogam import NoGAM
from dodiscover.toporder.das import DAS
from dodiscover.toporder.diffan import DiffAN
from dodiscover.continuous.grandag import GranDAG
from dodiscover.toporder.resit import RESIT
from dodiscover.toporder.random import RandomGuessing
from dodiscover.toporder.varsort import VarSort
from dodiscover.toporder.scoresort import ScoreSort
from dodiscover.constraint.fcialg import FCI
from lingam import DirectLiNGAM
from dodiscover import make_context
from utils._utils import pywhy_cpdag_to_numpy, pywhy_dag_to_numpy, pywhy_pag_to_numpy, is_cpdag, ConsoleManager

# warnings.filterwarnings("error")

class ExperimentManager:
    """Run one causal-discovery experiment (tuning or inference) on a single dataset.

    The manager instantiates the requested algorithm, reads the dataset from
    ``data_dir``, and writes predictions / logs under ``output_dir``.
    """

    # Methods whose fitted model exposes a topological order
    # (``order_``, or ``causal_order_`` for DirectLiNGAM).
    _METHODS_WITH_ORDER = frozenset([
        "das", "score", "nogam", "cam", "diffan", "grandag",
        "lingam", "resit", "random", "varsort", "scoresort",
    ])

    # Every method name understood by ``get_method_instance``.
    _ACCEPTED_METHODS = (
        "ges", "lingam", "pc", "fci", "cam", "score", "das", "nogam", "grandag",
        "diffan", "resit", "random", "varsort", "scoresort",
    )

    def __init__(
        self,
        method_name: str,
        method_parameters: Dict[str, Any],
        dataset_id: int,
        data_dir: str,
        output_dir: str,
        noise_distr: str,
        task: str
    ):
        """Manager for experiments on a single dataset.

        Parameters
        ----------
        method_name : str
            Name of the method used for the inference (see ``get_method_instance``).
        method_parameters : Dict[str, Any]
            Hyperparameters used to instantiate the inference method. The required
            keys depend on the method (e.g. ``"alpha"``, ``"lambda"``, ``"lr"``, ``"bs"``).
            The dict is also stored verbatim in the run metadata.
        dataset_id : int
            Id of the dataset, from 1 to number_of_seeds.
        data_dir : str
            Directory containing datasets and related groundtruths, readable with
            ``np.genfromtxt``. Its last path component must have the form
            ``<samples>_<size>_<density>`` (a trailing separator is allowed).
            Example directory: efs/data/<inference or hyperparameters>/<scenario>/<samples_size_density>
        output_dir : str
            Directory for output storage. Specific to scenario and method.
            E.g. inference directory:
            efs/tmp/causal-benchmark-logs/inference/<graph_algorithm>/<scenario>/<scenario_param>/<method_name>/100_small_sparse
        noise_distr : str
            Distribution of the noise terms. Either Gaussian or Random (i.e. non-Gaussian).
        task : str
            ``'tuning'`` for hyperparameter tuning, ``'inference'`` for causal graph
            inference, or ``'standardized'`` for inference on column-standardized data.
        """
        # TODO: check whether a random seed must be set here.

        self.dataset_id = dataset_id
        self.data_dir = data_dir
        self.method_name = method_name
        self.method_parameters = method_parameters  # stored for logging
        self.num_samples, self.graph_size, self.graph_density = self._parse_config_name(data_dir)
        self.noise_distr = noise_distr
        self.output_dir = output_dir
        self.task = task
        self.algorithm = self.get_method_instance(method_name, method_parameters, task)

        self.init_logs(task)

    @staticmethod
    def _parse_config_name(data_dir):
        """Split the last component of ``data_dir`` into (samples, size, density) strings.

        Trailing path separators are ignored, so ``"a/100_small_sparse/"`` and
        ``"a/100_small_sparse"`` parse identically. Raises ValueError if the
        component does not consist of exactly three ``_``-separated parts.
        """
        config_name = os.path.basename(os.path.normpath(data_dir))
        num_samples, graph_size, graph_density = config_name.split("_")
        return num_samples, graph_size, graph_density

    def init_logs(self, task):
        """Create the output folders needed by ``task``.

        ``self.output_dir`` is always created. For ``'inference'`` and ``'standardized'``
        the ``predictions`` subfolder (predicted adjacency matrices) and the ``tmp``
        subfolder (per-run metadata JSON files) are created as well. Folders that
        already exist (e.g. created concurrently by another job) are left untouched.

        Parameters
        ----------
        task : str
            ``'tuning'``, ``'inference'`` or ``'standardized'``.
        """
        os.makedirs(self.output_dir, exist_ok=True)
        if task in ("inference", "standardized"):
            # NOTE: when the "predictions" folder name changes, also update
            # delete_past_predictions() in the dataset inference script.
            os.makedirs(self.pred_output_dir, exist_ok=True)
            os.makedirs(self.tmp_logs_dir, exist_ok=True)

    @property
    def pred_output_dir(self):
        """Folder where predicted adjacency matrices are stored."""
        return os.path.join(self.output_dir, "predictions")

    @property
    def tmp_logs_dir(self):
        """Folder where per-run metadata JSON files are stored."""
        return os.path.join(self.output_dir, "tmp")

    def has_order(self, method):
        """Return True if ``method`` produces a topological order alongside its graph."""
        return method in self._METHODS_WITH_ORDER

    def tune(
        self,
        data_df: pd.DataFrame,
        params_id: str
    ):
        """Train a tunable method on ``data_df`` and log its validation score.

        Only ``'diffan'`` and ``'grandag'`` support tuning. The result is written as
        a JSON file ``{"id": params_id, "val_score": ..., "time": ...}`` with a
        unique (uuid1) name inside ``self.output_dir``.

        Parameters
        ----------
        data_df : pd.DataFrame
            Training data.
        params_id : str
            Identifier of the hyperparameter configuration, stored under ``"id"``.

        Raises
        ------
        ValueError
            If ``self.method_name`` does not support tuning.
        """
        if self.method_name not in ("diffan", "grandag"):
            raise ValueError(f"Tuning is not supported for method {self.method_name!r}")

        start = time.time()
        X = torch.tensor(data_df.to_numpy())
        if self.method_name == "diffan":
            X = X.float()  # convert to float32 before handing the data to DiffAN's trainer
            self.algorithm.trainer.setup(X)
            self.algorithm.trainer.fit()
        else:  # grandag
            self.algorithm.init_model(X.size(1))
            self.algorithm._prepare_data(X)
            self.algorithm.train()
        execution_time = round(time.time() - start, 2)
        best_score = self.algorithm.val_score

        # Unique file name so separate tuning runs do not overwrite each other's results.
        result_file = os.path.join(self.output_dir, uuid.uuid1().hex + ".json")
        with open(result_file, "w") as f:
            json.dump({"id": params_id, "val_score": best_score, "time": execution_time}, f)

    def inference(
        self,
        data_df: pd.DataFrame
    ):
        """Run causal discovery on ``data_df`` and store the prediction with its metadata.

        The predicted adjacency matrix is saved as a CSV in ``self.pred_output_dir``.
        A JSON file with the run metadata (dataset configuration, hyperparameters,
        runtime, prediction / groundtruth locations and, when the method provides
        one, the topological order) is saved in ``self.tmp_logs_dir``.

        Parameters
        ----------
        data_df : pd.DataFrame
            Input dataset for the inference.
        """
        context = make_context().variables(observed=data_df.columns).build()
        start = time.time()
        toporder = None
        if self.method_name == "lingam":
            self.algorithm.fit(data_df.to_numpy())
            W = self.algorithm.adjacency_matrix_
            # Per the lingam docs, W[i, j] != 0 encodes an edge j -> i; binarize and
            # transpose so that A_pred[i, j] = 1 encodes an edge i -> j.
            A_pred = np.transpose((np.abs(W) > 0).astype(int), (1, 0))
            toporder = self.algorithm.causal_order_
        elif self.method_name == "fci":
            self.algorithm.fit(data_df, context)
            A_pred = pywhy_pag_to_numpy(self.algorithm.graph_)
        else:
            self.algorithm.fit(data_df, context)
            if is_cpdag(self.method_name):
                A_pred = pywhy_cpdag_to_numpy(self.algorithm.graph_)
            else:
                A_pred = pywhy_dag_to_numpy(self.algorithm.graph_)
            if self.has_order(self.method_name):
                toporder = self.algorithm.order_
        execution_time = np.round(time.time() - start, 2)
        num_nodes = A_pred.shape[0]

        # Unique file name for the predictions
        pred_location = os.path.join(self.pred_output_dir, uuid.uuid1().hex + ".csv")
        np.savetxt(pred_location, A_pred, delimiter=",")

        metadata_output_file = os.path.join(self.tmp_logs_dir, uuid.uuid1().hex + ".json")
        ConsoleManager.metadata_storing(metadata_output_file)

        # Start from log_fields() so the key order is always the same.
        run_log = dict.fromkeys(self.log_fields())
        run_log["seed_id"] = self.dataset_id
        run_log["noise"] = self.noise_distr
        run_log["samples"] = self.num_samples
        run_log["size"] = self.graph_size
        run_log["density"] = self.graph_density
        run_log["num_nodes"] = num_nodes
        run_log["hyperparameters"] = self.method_parameters
        run_log["time"] = execution_time
        run_log["pred_location"] = pred_location
        run_log["gt_location"] = os.path.join(self.data_dir, f"groundtruth_{self.dataset_id}.csv")
        if toporder is not None:
            # Cast to built-in int so the order is JSON serializable
            run_log["order"] = [int(x) for x in toporder]
        with open(metadata_output_file, "w") as f:
            json.dump(run_log, f)
        ConsoleManager.done_msg()

    # ------------------ Utils ------------------
    def log_fields(self):
        """Return the (ordered) keys always present in an inference metadata log."""
        return ["seed_id", "noise", "samples", "size", "density", "num_nodes",
                "hyperparameters", "time", "pred_location", "gt_location"]

    def clear_folder(self, dir):
        """Delete every ``*.csv`` file directly inside ``dir`` (subfolders are untouched)."""
        for csv_path in glob.glob(os.path.join(dir, "*.csv")):
            os.remove(csv_path)

    def get_data(self, debug=False):
        """Load ``data_<dataset_id>.csv`` from ``self.data_dir`` as a DataFrame.

        Unless ``debug`` is set, the copy under ``/scratch`` is tried first, and
        ``self.data_dir`` is used if that file cannot be opened. When
        ``self.task == 'standardized'`` every column is divided by its standard
        deviation.

        Parameters
        ----------
        debug : bool
            If True, read directly from ``self.data_dir`` without trying ``/scratch``.

        Returns
        -------
        data_df : pd.DataFrame
            pandas dataframe of the input data.
        """
        file_name = f"data_{self.dataset_id}.csv"
        local_path = os.path.join(self.data_dir, file_name)
        if debug:
            data = np.genfromtxt(local_path, delimiter=",")
        else:
            # NOTE(review): os.path.join drops "/scratch" when data_dir is absolute,
            # in which case both attempts read the same file.
            try:
                data = np.genfromtxt(os.path.join("/scratch", self.data_dir, file_name), delimiter=",")
            except OSError:  # missing / unreadable scratch copy
                data = np.genfromtxt(local_path, delimiter=",")

        if self.task == "standardized":
            # Scale each column to unit standard deviation (no centering).
            data = data / data.std(axis=0)

        return pd.DataFrame(data)

    def check_valid_method(self):
        """Raise ValueError if ``self.method_name`` is not a supported inference algorithm."""
        accepted_methods = list(self._ACCEPTED_METHODS)
        if self.method_name not in accepted_methods:
            raise ValueError(
                f"The inference algorithm {self.method_name} is not known.\n"
                f" Please provide an algorithm in the list {accepted_methods}"
            )

    # For Slurm inference
    def get_method_instance(
        self,
        method_name: str,
        params: Dict[str, Any],
        task: str
    ):
        """Instantiate the causal discovery algorithm ``method_name``.

        Parameters
        ----------
        method_name : str
            One of the names in ``_ACCEPTED_METHODS``.
        params : Dict[str, Any]
            Hyperparameters. Depending on the method: ``"lambda"`` (ges),
            ``"alpha"`` (most methods), ``"lr"`` and ``"bs"`` (diffan, grandag).
        task : str
            For diffan / grandag, ``'tuning'`` builds an instance without CAM
            pruning and PNS; any other value builds the inference configuration.

        Raises
        ------
        ValueError
            If ``method_name`` is unknown.
        """
        if method_name == "ges":
            lambda_value = params["lambda"]
            return GES(score="bic", iterate=False, parameters={"lambda": lambda_value, "method": "scatter"})
        elif method_name == "pc":
            alpha_value = params["alpha"]
            ci_estimator = KernelCITest()
            return PC(ci_estimator, alpha=alpha_value)
        elif method_name == "score":
            alpha_value = params["alpha"]
            return SCORE(cam_cutoff=alpha_value, pns_num_neighbors=20)
        elif method_name == "cam":
            alpha_value = params["alpha"]
            return CAM(cam_cutoff=alpha_value, pns_num_neighbors=20)
        elif method_name == "nogam":
            alpha_value = params["alpha"]
            return NoGAM(cam_cutoff=alpha_value, pns_num_neighbors=20)
        elif method_name == "das":
            alpha_value = params["alpha"]
            return DAS(das_cutoff=alpha_value, cam_cutoff=alpha_value)
        elif method_name == "lingam":
            return DirectLiNGAM()
        elif method_name == "fci":
            alpha_value = params["alpha"]
            ci_estimator = KernelCITest()
            return FCI(ci_estimator, alpha=alpha_value)
        elif method_name == "resit":
            alpha_value = params["alpha"]
            regressor = KernelRidge(kernel='rbf', gamma=0.1, alpha=0.01)
            return RESIT(regressor, alpha=alpha_value)
        elif method_name == "random":
            return RandomGuessing()
        elif method_name == "varsort":
            return VarSort()
        elif method_name == "scoresort":
            return ScoreSort()
        elif method_name == "diffan":
            alpha_value = params["alpha"]
            learning_rate = params["lr"]
            batch_size = params["bs"]
            if task == "tuning":
                return DiffAN(cam_cutoff=alpha_value, lr=learning_rate, mbs=batch_size, eval_mbs=256, cam_pruning=False, pns=False)
            return DiffAN(cam_cutoff=alpha_value, lr=learning_rate, mbs=batch_size, eval_mbs=128, pns_num_neighbors=20)
        elif method_name == "grandag":
            alpha_value = params["alpha"]
            learning_rate = params["lr"]
            batch_size = params["bs"]
            if task == "tuning":
                return GranDAG(cam_cutoff=alpha_value, lr=learning_rate, mbs=batch_size, cam_pruning=False, pns=False, num_train_iter=50000)
            return GranDAG(cam_cutoff=alpha_value, lr=learning_rate, mbs=batch_size, pns_num_neighbors=20)
        else:
            raise ValueError(f"Unknown input method: {method_name!r}")