import numpy as np
import pandas as pd
import networkx as nx
import torch
import os
import sys

# Import from dowhy.gcm for DCM
from dowhy.gcm import (
    draw_samples,
    interventional_samples,
    counterfactual_samples,
    InvertibleStructuralCausalModel,
)

# Import DCM model creation utilities
# Put the parent of this file's directory at the front of the module search
# path (once), presumably so the project-local ``model`` package imported
# below can be resolved -- TODO confirm the package lives there.
current_dir = os.path.split(os.path.realpath(__file__))[0]
dcm_root = os.path.abspath(os.path.join(current_dir, os.pardir))
if dcm_root not in sys.path:
    sys.path[:0] = [dcm_root]

from model.diffusion import create_model_from_graph


class MyCausalMethod:
    """Causal-effect estimator backed by a Diffusion-based Causal Model (DCM).

    The model is created with ``create_model_from_graph`` and fitted with
    ``dowhy.gcm.fit`` on the first call to :meth:`estimate`; every later call
    reuses that fitted model, even if a different graph or data set is passed.
    Supported query types are ``"ATE"`` and ``"Ctf-TE"``.
    """

    # Number of samples drawn per intervention arm when estimating an ATE.
    NUM_INTERVENTIONAL_SAMPLES = 1000

    # Fallback (treated, control) levels used when the query has no "T" entry.
    # Each level is wrapped in a one-element sequence so that it has the same
    # shape the unpacking code expects from ``query.vars_values["T"]``.
    DEFAULT_TREATMENT_VALUES = ((1,), (0,))

    def __init__(self):
        """
        A causal method that uses Diffusion-based Causal Models (DCM) to estimate
        causal queries (ATE and CTF-TE).
        """
        # Fitted lazily by ``estimate``; ``None`` means "not fitted yet".
        self.model = None
        # Passed unchanged to ``create_model_from_graph``.
        self.params = {
            "num_epochs": 200,  # Can be adjusted depending on dataset complexity
            "lr": 1e-4,
            "batch_size": 64,
            "hidden_dim": 64,
            "use_positional_encoding": False,
            "weight_decay": 0,
            "lambda_loss": 0,
            "clip": False,
            "verbose": False,
        }

    def estimate(self, query, data, graph, index_to_variable):
        """
        Estimates the causal effect based on the provided query.

        Args:
            query: Query object from causal_profiler. Its ``type`` (or
                ``type.value`` when present) selects the estimator.
            data: ``pandas.DataFrame``, or a dictionary mapping variable names
                to observed data arrays.
            graph: ``networkx.DiGraph``, or a mapping from a node index to an
                iterable of neighbor indices (see ``_convert_to_nx_graph``).
            index_to_variable: List mapping indices to variable names

        Returns:
            float: The estimated causal effect. ``0.0`` is returned when the
            query type is unsupported or when any exception is raised during
            estimation (the traceback is printed); ``nan`` is returned by the
            ``"Ctf-TE"`` path when no observed row matches the conditioning
            values.
        """
        try:
            # Convert graph and data to formats expected by DCM
            nx_graph = self._convert_to_nx_graph(graph, index_to_variable)
            df_data = self._convert_to_dataframe(data, index_to_variable)

            # Fit only once: an already fitted model is reused as-is.
            if self.model is None:
                self.model = create_model_from_graph(nx_graph, self.params)
                # Fit using dowhy's fit method
                from dowhy.gcm import fit

                fit(self.model, df_data)

            # The type may be an enum-like object (with ``.value``) or a plain
            # value that is compared through its string form.
            query_type = (
                query.type.value if hasattr(query.type, "value") else str(query.type)
            )

            if query_type == "ATE":
                return self._estimate_ate(query, index_to_variable)
            if query_type == "Ctf-TE":
                return self._estimate_ctf_te(query, df_data, index_to_variable)

            print(f"Unsupported query type: {query_type}")
            return 0.0

        except Exception as e:
            # Deliberate best-effort boundary: report the failure and fall
            # back to 0.0 instead of propagating.
            print(f"Error in DCM estimation: {e}")
            import traceback

            traceback.print_exc()
            return 0.0

    @staticmethod
    def _unwrap_single(entry):
        """Return the sole element of a one-element iterable, or ``entry`` itself
        when it is not iterable (i.e. already a scalar).

        Raises:
            ValueError: If ``entry`` is iterable but does not hold exactly one
                element.
        """
        try:
            (value,) = entry
        except TypeError:
            return entry
        return value

    @classmethod
    def _treatment_levels(cls, query):
        """Read the (treated, control) treatment levels of ``query`` as floats.

        The levels come from ``query.vars_values["T"]``, which must hold
        exactly two entries; each entry may be a one-element sequence or a
        bare scalar. When the key is absent, ``DEFAULT_TREATMENT_VALUES``
        (treated = 1, control = 0) is used.
        """
        treated, control = query.vars_values.get("T", cls.DEFAULT_TREATMENT_VALUES)
        return float(cls._unwrap_single(treated)), float(cls._unwrap_single(control))

    @staticmethod
    def _constant_intervention(variable, value):
        """Build an intervention dict whose function ignores its input and
        always returns ``value`` for ``variable``."""
        return {variable: lambda _: value}

    def _estimate_ate(self, query, index_to_variable):
        """
        Estimate Average Treatment Effect (ATE).

        Draws ``NUM_INTERVENTIONAL_SAMPLES`` interventional samples for the
        treated level and for the control level of the first intervened
        variable, and returns the difference between the sample means of the
        first target variable (treated minus control).

        Args:
            query: Query object providing ``get_intervened_info()``,
                ``get_target_info()`` and ``vars_values``.
            index_to_variable: Unused; kept for a uniform signature.

        Returns:
            float: Mean outcome under treatment minus mean outcome under control.
        """
        # Only the first treatment and the first outcome variable are used.
        _, treatment_vars = query.get_intervened_info()
        _, outcome_vars = query.get_target_info()
        treatment_var = treatment_vars[0].name
        outcome_var = outcome_vars[0].name

        t1_value, t0_value = self._treatment_levels(query)

        num_samples = self.NUM_INTERVENTIONAL_SAMPLES
        treatment_samples = interventional_samples(
            self.model,
            self._constant_intervention(treatment_var, t1_value),
            num_samples_to_draw=num_samples,
        )
        control_samples = interventional_samples(
            self.model,
            self._constant_intervention(treatment_var, t0_value),
            num_samples_to_draw=num_samples,
        )

        treatment_mean = treatment_samples[outcome_var].mean()
        control_mean = control_samples[outcome_var].mean()
        return float(treatment_mean - control_mean)

    def _estimate_ctf_te(self, query, factual_data, index_to_variable):
        """
        Estimate Counterfactual Treatment Effect.

        Rows of ``factual_data`` whose conditioning variables are close
        (``atol=1e-2``) to the values stored in ``query.vars_values`` under the
        conditioning label are kept. Counterfactual samples are generated for
        those rows under the treated and the control level, and the effect on
        the first target variable is returned.

        Args:
            query: Query object providing ``get_intervened_info()``,
                ``get_target_info()``, ``get_conditioned_info()`` and
                ``vars_values``.
            factual_data: ``pandas.DataFrame`` of observed data.
            index_to_variable: Unused; kept for a uniform signature.

        Returns:
            float: If ``query.vars_values`` has a ``"Y"`` entry, the difference
            (treated minus control) between the fractions of counterfactual
            outcomes within ``1e-3`` of that value; otherwise the difference
            between the counterfactual outcome means. ``nan`` when no row
            matches the conditioning values.
        """
        _, treatment_vars = query.get_intervened_info()
        _, outcome_vars = query.get_target_info()
        conditioning_label, conditioning_vars = query.get_conditioned_info()

        # Only the first treatment and the first outcome variable are used.
        treatment_var = treatment_vars[0].name
        outcome_var = outcome_vars[0].name

        t1_value, t0_value = self._treatment_levels(query)

        # Get the factual values that we're conditioning on
        factual_values = query.vars_values.get(conditioning_label, [])

        # Keep rows matching every conditioning value. ``zip`` stops at the
        # shorter sequence, so variables without a value impose no constraint,
        # and explicit ``None`` values are skipped as well.
        mask = np.ones(len(factual_data), dtype=bool)
        for cond_var, cond_value in zip(conditioning_vars, factual_values):
            if cond_value is None:
                continue
            mask &= np.isclose(
                factual_data[cond_var.name].values, cond_value, atol=1e-2
            )

        # If no matching records, return np.nan (failure)
        if not np.any(mask):
            return np.nan

        filtered_factual = factual_data[mask]

        cf_t1 = counterfactual_samples(
            self.model,
            self._constant_intervention(treatment_var, t1_value),
            observed_data=filtered_factual,
        )
        cf_t0 = counterfactual_samples(
            self.model,
            self._constant_intervention(treatment_var, t0_value),
            observed_data=filtered_factual,
        )

        y_cf_t1 = cf_t1[outcome_var].values
        y_cf_t0 = cf_t0[outcome_var].values

        if "Y" not in query.vars_values:
            # Continuous case - calculate difference in expectations
            return float(np.mean(y_cf_t1) - np.mean(y_cf_t0))

        # Discrete case - compare how often each arm hits the requested Y value
        y_value = query.vars_values["Y"]
        p_t1 = np.mean(np.isclose(y_cf_t1, y_value, atol=1e-3))
        p_t0 = np.mean(np.isclose(y_cf_t0, y_value, atol=1e-3))
        return float(p_t1 - p_t0)

    def _convert_to_nx_graph(self, graph, index_to_variable):
        """
        Convert the graph to a NetworkX DiGraph.

        Args:
            graph: Either a ``networkx.DiGraph`` (returned unchanged) or a
                mapping ``{node_index: iterable of neighbor indices}``; one
                edge ``node -> neighbor`` is added per listed neighbor.
            index_to_variable: Sequence translating indices to node names;
                every name becomes a node, connected or not.

        Returns:
            networkx.DiGraph: Graph whose nodes are variable names.
        """
        if isinstance(graph, nx.DiGraph):
            return graph

        nx_graph = nx.DiGraph()
        nx_graph.add_nodes_from(index_to_variable)
        nx_graph.add_edges_from(
            (index_to_variable[source], index_to_variable[target])
            for source, targets in graph.items()
            for target in targets
        )
        return nx_graph

    def _convert_to_dataframe(self, data, index_to_variable):
        """
        Convert data to a pandas DataFrame.

        Args:
            data: A ``pandas.DataFrame`` (returned unchanged) or a mapping from
                column name to values. Values are flattened to one dimension;
                objects without a ``flatten`` method (e.g. plain lists) are
                first converted with ``numpy.asarray``.
            index_to_variable: Unused; kept for a uniform signature.

        Returns:
            pandas.DataFrame: One column per key of ``data``.
        """
        if isinstance(data, pd.DataFrame):
            return data

        return pd.DataFrame(
            {
                name: (
                    values.flatten()
                    if hasattr(values, "flatten")
                    else np.asarray(values).flatten()
                )
                for name, values in data.items()
            }
        )
