# CAREFL model class
import os
import time
from typing import Any, Optional, Tuple

import networkx as nx
import numpy as np
import pandas as pd
import torch
import wandb
from numpy.typing import NDArray

import custom_models.cnf_src.causal_nf.config as causal_nf_config
import custom_models.cnf_src.causal_nf.utils.io as causal_nf_io
import custom_models.cnf_src.causal_nf.utils.training as causal_nf_train
import custom_models.cnf_src.causal_nf.utils.wandb_local as wandb_local
from custom_models.cnf_src.causal_nf.config import cfg
from custom_models.cnf_src.causal_nf.preparators.custom_preparator import (
    CustomPreparator,
)
from custom_models.CustomCausalModel import CustomCausalModel


class CAREFL(CustomCausalModel):
    def __init__(self, causal_graph: nx.DiGraph):
        """Keep a reference to the causal graph used by the other methods.

        Args:
            causal_graph: Directed graph stored as ``self.causal_graph``. It is
                handed to the data preparator in ``fit`` and its node count
                sets the width of the sampling input in ``get_average_effect``.
        """
        self.causal_graph = causal_graph

    def identify_effect(
        self,
        treatment: Optional[dict[str, float]] = {},
        outcome: Optional[dict[str, float]] = {},
        obs_data: Optional[pd.DataFrame] = None,
        int_data: Optional[pd.DataFrame] = None,
        method_params: dict[str, Any] = {},
        seed: Optional[int] = None,
        save_dir: Optional[str] = None,
    ) -> dict[str, Any]:
        """Placeholder: no identification step is carried out.

        Every argument is accepted and ignored. The method returns ``None``
        regardless of the declared return type.
        """
        return None

    def fit(
        self,
        data: Optional[pd.DataFrame] = None,
        int_table: Optional[pd.DataFrame] = None,
        method_params: dict[str, Any] = {},
        seed: Optional[int] = None,
        save_dir: Optional[str] = None,
        outcome: Optional[str] = None,
        treatment: Optional[dict[str, float]] = {},
        evidence: dict[str, float] = {},
    ) -> dict[str, Any]:
        """Train the CAREFL model on ``data`` and report training runtime.

        The method builds the configuration from ``default_config.yaml`` and
        ``carefl.yaml`` (looked up under
        ``<cwd>/src/custom_models/cnf_src/causal_nf/configs``), prepares the
        data loaders, creates the model and trainer, starts an offline wandb
        run and fits the model. The fitted model and the preparator are kept
        on ``self.model`` and ``self.preparator``.

        Args:
            data: Training data passed to ``CustomPreparator.loader``.
            int_table: Unused by this method.
            method_params: Unused by this method.
            seed: Seed for ``torch.manual_seed`` (skipped when ``None``) and
                forwarded as-is to ``causal_nf_train.set_reproducibility``.
            save_dir: Directory that receives ``train_info.csv``. It is
                created if missing. When ``None`` the CSV is not written.
            outcome: Unused by this method.
            treatment: Unused by this method.
            evidence: Unused by this method.

        Returns:
            A dict with ``"Training Time"`` (process CPU seconds measured with
            ``time.process_time`` around ``trainer.fit``) and, when the
            training outputs carry a ``"gpu_mem"`` entry,
            ``"Avg. GPU Memory"``.
        """
        # NOTE(review): the device is picked before the config is read, so a
        # ``cfg.device == "cpu"`` setting below does not change ``device``.
        # Confirm this is intended.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # torch.manual_seed() rejects None, so only seed when one was given.
        if seed is not None:
            torch.manual_seed(seed)

        # Load configs (resolved relative to the current working directory)
        config_dir = os.path.join(
            os.getcwd(),
            "src",
            "custom_models",
            "cnf_src",
            "causal_nf",
            "configs",
        )
        default_config_file = os.path.join(config_dir, "default_config.yaml")
        config_file = os.path.join(config_dir, "carefl.yaml")

        config = causal_nf_config.build_config(
            config_file=config_file,
            config_default_file=default_config_file,
        )
        # Validate config
        causal_nf_config.assert_cfg_and_config(cfg, config)

        # Hide GPUs from CUDA when the config asks for CPU
        if cfg.device in ["cpu"]:
            os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

        # Set seed
        causal_nf_train.set_reproducibility(cfg, seed)

        # Load and prepare data
        self.preparator = CustomPreparator.loader(
            data, self.causal_graph, cfg.dataset, device=device
        )

        loaders = self.preparator.get_dataloaders(
            batch_size=cfg.train.batch_size, num_workers=cfg.train.num_workers
        )

        for i, loader in enumerate(loaders):
            causal_nf_io.print_info(f"[{i}] num_batches: {len(loader)}")

        self.model = causal_nf_train.load_model(
            cfg=cfg, preparator=self.preparator
        ).to(device)

        config["param_count"] = self.model.param_count()

        run = wandb.init(
            mode="offline",
            group=None,
            project="prova",
            config=config,
        )

        # The wandb run id names the experiment folder; checkpoints and logs
        # share the same directory.
        run_dir = os.path.join(cfg.root_dir, run.id)

        trainer, logger = causal_nf_train.load_trainer(
            cfg=cfg,
            dirpath=run_dir,
            logger_dir=run_dir,
            include_logger=True,
            model_checkpoint=cfg.train.model_checkpoint,
            cfg_early=cfg.early_stopping,
            preparator=self.preparator,
        )

        causal_nf_io.print_info(f"Experiment folder: {logger.save_dir}\n\n")

        wandb_local.log_config(dict(config), root=logger.save_dir)

        wandb_local.copy_config(
            config_default=causal_nf_config.DEFAULT_CONFIG_FILE,
            config_experiment=config_file,
            root=logger.save_dir,
        )

        # The first loader is used for training, the second for validation.
        train_start_time = time.process_time()
        trainer.fit(
            self.model, train_dataloaders=loaders[0], val_dataloaders=loaders[1]
        )
        delta_train_time = time.process_time() - train_start_time

        print("Training finished")

        step_outputs = self.model.train_step_outputs

        # Save train_loss: one row per entry of train_step_outputs, indexed
        # from 1 under the name "epoch". Skipped when no directory was given,
        # so a finished training run is not lost to a path error.
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            train_info_df = pd.DataFrame(step_outputs)
            train_info_df["epoch"] = train_info_df.index + 1
            train_info_df.set_index("epoch", inplace=True)
            train_info_df.to_csv(os.path.join(save_dir, "train_info.csv"))

        runtime = {
            "Training Time": delta_train_time,
        }
        # NOTE(review): the key says "Avg." but the value is the maximum of
        # the recorded "gpu_mem" entries; the key is kept so existing
        # consumers keep working.
        if len(step_outputs) > 0 and "gpu_mem" in step_outputs[0]:
            runtime["Avg. GPU Memory"] = max(d["gpu_mem"] for d in step_outputs)

        return runtime

    def estimate_effect(
        self,
        outcome: str,
        treatment: Optional[dict[str, float]] = {},
        evidence: dict[str, float] = {},
        method_params: dict[str, Any] = {},
        seed: Optional[int] = None,
        save_dir: Optional[str] = None,
        data: Optional[pd.DataFrame] = None,
        int_table: Optional[pd.DataFrame] = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Estimate the ATE and the evidence-conditioned effect (CATE).

        ``get_average_effect`` is run twice with the same treatment and control
        assignments: once with empty evidence (ATE) and once with ``evidence``
        and ``data`` (CATE). Requires ``self.model`` and ``self.preparator``,
        which ``fit`` sets.

        Args:
            outcome: Name of the outcome column.
            treatment: Passed to ``self.extract_treat_control`` to obtain the
                treatment and control assignments.
            evidence: Conditioning values for the second pass.
            method_params: Unused by this method.
            seed: Unused by this method.
            save_dir: Unused by this method.
            data: Forwarded to the conditional pass only.
            int_table: Unused by this method.

        Returns:
            ``(results, runtime)``. ``results`` holds the distributions,
            samples and effects of both passes; its ``"evidence"`` and
            ``"CATE"`` entries are ``None`` when ``evidence`` equals ``{}``.
            ``runtime`` holds the sampling time of each pass.
        """
        arm_treated, arm_control = self.extract_treat_control(treatment)

        # The sampling inputs are created on the CPU, so the model and the
        # scaler tensors are moved there first.
        cpu = "cpu"
        self.model.to(cpu)
        self.model.input_scaler.to(cpu)
        scaler = self.preparator.scaler_transform
        scaler.shift = scaler.shift.to(cpu)
        scaler.scale = scaler.scale.to(cpu)

        arms = {
            "treatment": arm_treated,
            "control": arm_control,
            "outcome": outcome,
        }

        # Unconditional pass: empty evidence, no data.
        (
            ate,
            ate_seconds,
            treated_distr,
            state_names,
            treated_samples,
            control_distr,
            control_samples,
        ) = self.get_average_effect(evidence={}, **arms)

        # Conditional pass: the samples are conditioned on the evidence inside
        # get_average_effect. An evidence variable may be discrete while the
        # model outputs continuous values, which is why the samples are
        # quantized there.
        (
            cate,
            cate_seconds,
            cond_treated_distr,
            _,
            cond_treated_samples,
            cond_control_distr,
            cond_control_samples,
        ) = self.get_average_effect(evidence=evidence, data=data, **arms)

        has_evidence = evidence != {}

        results = {
            "target": outcome,
            "state_names": list(state_names),
            "Interventional Distribution": treated_distr,
            "Conditional Interventional Distribution": cond_treated_distr,
            "ATE": ate,
            "evidence": evidence if has_evidence else None,
            "CATE": cate if has_evidence else None,
            "Interventional Samples": treated_samples,
            "Conditional Interventional Samples": cond_treated_samples,
            "Control Samples": control_samples,
            "Conditional Control Samples": cond_control_samples,
            "Control Interventional Distribution": control_distr,
            "Conditional Control Distribution": cond_control_distr,
        }

        runtime = {
            "Estimation Time ATE": ate_seconds,
            "Estimation Time CATE": cate_seconds,
        }

        return results, runtime

    def get_average_effect(
        self, treatment, control, outcome, evidence, data=None, n_samples: int = 10000
    ) -> tuple[float, float, Any, Any, Any, Any, Any]:
        """Sample both intervention arms and compare the binarized outcome.

        The model is sampled once under the ``treatment`` intervention and once
        under the ``control`` intervention. Both sample sets are passed through
        ``self.condition_and_quantize`` together with ``evidence`` and
        ``data``. The ``outcome`` column is then binarized (values ``> 0``
        become ``1``, everything else ``-1``) and the effect is the difference
        of the arm means.

        Args:
            treatment: Intervention dict for the treated arm, passed to
                ``self.model.predict`` as ``int_dict``.
            control: Intervention dict for the control arm.
            outcome: Column whose values define the effect.
            evidence: Conditioning values forwarded to
                ``condition_and_quantize``.
            data: Optional data forwarded to ``condition_and_quantize``.
            n_samples: Number of rows sampled per arm.

        Returns:
            A 7-tuple, in this order:
            ``average_effect`` (treated mean minus control mean of the
            binarized outcome), the process CPU time spent on the two
            ``predict`` calls, the treated distribution, the state names
            reported for the control arm, a copy of the conditioned treated
            samples, the control distribution and a copy of the conditioned
            control samples.
        """
        time_start = time.process_time()
        n_nodes = len(list(self.causal_graph.nodes()))

        def _sample(int_dict):
            # A fresh pair of zero tensors is built for every call so the two
            # arms never share input objects.
            placeholder = (
                torch.zeros((n_samples, n_nodes)),
                torch.zeros((n_samples, n_nodes)),
            )
            return self.model.predict(
                placeholder,
                intervene=True,
                int_dict=int_dict,
            )["int_samples"]

        # The treated arm is sampled first, then the control arm.
        treated_samples = _sample(treatment)
        control_samples = _sample(control)

        delta_estimate_time = time.process_time() - time_start

        labels = self.model.preparator.dataset.labels
        control_samples = pd.DataFrame(control_samples.cpu(), columns=labels)
        treated_samples = pd.DataFrame(treated_samples.cpu(), columns=labels)

        # Condition on the evidence and quantize the samples
        treated_samples, control_samples = self.condition_and_quantize(
            treated_samples, control_samples, evidence, data=data, quantize=True
        )

        # Copies of all columns are returned to the caller; only the outcome
        # column is used for the effect below.
        full_treated_samples = treated_samples.copy()
        full_control_samples = control_samples.copy()

        # Binarize the outcome column of each arm: > 0 -> 1, otherwise -1
        control_outcome = np.where(control_samples[outcome] > 0, 1, -1)
        treated_outcome = np.where(treated_samples[outcome] > 0, 1, -1)

        # Get distribution from samples (the third returned value is unused)
        state_names, control_distr, _ = self.get_probability_distribution(
            control_outcome, [-1, 1]
        )
        _, treated_distr, _ = self.get_probability_distribution(
            treated_outcome, [-1, 1]
        )

        # Calculate the average effect
        average_effect = np.mean(treated_outcome) - np.mean(control_outcome)

        return (
            average_effect,
            delta_estimate_time,
            treated_distr,
            state_names,
            full_treated_samples,
            control_distr,
            full_control_samples,
        )


