import numpy as np
import torch
import sys
import os
from typing import Dict, Any, List, Union

# Try to import from the NCM module, handling potential import errors
try:
    # Add parent directory to the Python path (needed when run in the NCM/evaluation directory)
    current_dir = os.path.dirname(os.path.realpath(__file__))
    parent_dir = os.path.join(current_dir, os.pardir)
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)

    from src.ds import CausalGraph
    from src.scm import NCM
except ImportError as e:
    print(f"Error importing NCM modules: {e}")
    print("Make sure NCM is installed. You may need to run: pip install -e .")
    raise

# Import causal profiler
try:
    from causal_profiler import QueryType
except ImportError as e:
    print(f"Error importing causal_profiler: {e}")
    print("Make sure causal_profiler is installed from the causal-profiler repo")
    raise


class NCMMethod:
    """Estimate causal queries by fitting a Neural Causal Model (NCM) to data.

    A fresh NCM is built and trained on every call to ``estimate``. Only ATE
    queries are supported; any other query type, and any failure during
    estimation, yields ``np.nan`` rather than an exception.
    """

    def __init__(
        self,
        max_epochs: int = 100,
        batch_size: int = 1000,
        verbose: bool = True,
        lr: float = 0.01,
        patience: int = 5,
        n_samples: int = 10000,
    ):
        """
        Causal estimation using NCMs

        Args:
            max_epochs: Maximum number of training epochs
            batch_size: Batch size for NCM training
            verbose: Whether to print progress information
            lr: Learning rate of the Adam optimizer used for training
            patience: Number of consecutive non-improving epochs tolerated
                before training stops early
            n_samples: Number of samples drawn per interventional
                distribution when estimating the ATE by Monte Carlo
        """
        self.ncm = None
        self.cg = None
        self.index_to_variable = None
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.verbose = verbose
        self.lr = lr
        self.patience = patience
        self.n_samples = n_samples

    @staticmethod
    def _to_scalar(value) -> float:
        """
        Collapse a treatment value to a single Python float.

        Accepts a plain number, a (nested) sequence, a numpy array or a torch
        tensor and returns its first element. Plain numbers matter because
        the default treatment values used by ``estimate`` are ``[1, 0]``.

        Raises:
            ValueError: If ``value`` contains no elements.
        """
        if isinstance(value, torch.Tensor):
            flat_tensor = value.detach().reshape(-1)
            if flat_tensor.numel() == 0:
                raise ValueError("Treatment value is empty")
            return float(flat_tensor[0].item())

        flat_array = np.asarray(value, dtype=float).reshape(-1)
        if flat_array.size == 0:
            raise ValueError("Treatment value is empty")
        return float(flat_array[0])

    def _init_graph_from_adj_list(self, adj_list, index_to_variable):
        """
        Initialize a CausalGraph from an adjacency list

        Args:
            adj_list: Dictionary mapping parent nodes to lists of child nodes.
                Both parents and children are used as indices into
                ``index_to_variable``.
            index_to_variable: Sequence mapping node indices to variable
                names; its entries become the graph's vertex set.
        """
        # Convert adjacency list to directed edges between variable names
        directed_edges = [
            (index_to_variable[parent], index_to_variable[child])
            for parent, children in adj_list.items()
            for child in children
        ]

        if self.verbose:
            print(f"Directed edges: {directed_edges}")

        # Create the causal graph using NCM's CausalGraph class
        self.cg = CausalGraph(
            V=tuple(index_to_variable),
            directed_edges=directed_edges,
            bidirected_edges=[],  # Assuming no bidirected edges in adjacency list format
        )

    def estimate(self, query, data, graph, index_to_variable):
        """
        Estimate a causal query using Neural Causal Models (NCMs).

        Args:
            query: Query object from causal_profiler (contains query type, variables, values)
            data: Dictionary mapping variable names to observed data matrices
            graph: Adjacency list representing the causal graph (parent index
                -> list of child indices)
            index_to_variable: List mapping indices to variable names

        Returns:
            Estimated causal effect value, or ``np.nan`` if the query type is
            unsupported or estimation fails.
        """
        try:
            query_type = query.type

            # Reject unsupported queries before paying for a training run.
            if query_type != QueryType.ATE:
                print(f"Unsupported query type: {query_type}")
                return np.nan

            # A new model is always trained: the graph and data may differ
            # between calls, so nothing is reused from a previous estimate.
            # For this experiment the graph is assumed to be an adjacency list.
            self._init_graph_from_adj_list(graph, index_to_variable)
            self._fit_ncm(data)

            # Extract treatment and outcome variables
            _, treatment_vars = query.get_intervened_info()
            _, outcome_vars = query.get_target_info()

            # Get treatment values (T1 and T0); defaults to 1 vs 0
            treatment_values = query.vars_values.get("T", [1, 0])

            return self._estimate_ate(treatment_vars, outcome_vars, treatment_values)

        except Exception as e:
            # Best-effort contract: callers receive NaN instead of an exception.
            print(f"Error in NCM estimation: {e}")
            return np.nan

    def _fit_ncm(self, data):
        """
        Fit a Neural Causal Model to observed data

        Builds a new NCM over ``self.cg`` and trains it with Adam, stopping
        early once the loss has not improved for ``self.patience`` epochs.
        A training error is reported and ends training; the partially
        trained model is kept.

        Args:
            data: Dictionary mapping variable names to data arrays. Torch
                tensors are used as given; anything else is converted to a
                float tensor, and 1-D inputs gain a trailing dimension.
        """
        if self.verbose:
            print(f"Training NCM on {len(data)} variables...")

        torch_data = {}
        for var_name, var_data in data.items():
            if isinstance(var_data, torch.Tensor):
                # Already a tensor: pass through untouched
                torch_data[var_name] = var_data
                continue

            # Arrays, lists and other array-likes are copied into a float tensor
            tensor = torch.tensor(np.asarray(var_data)).float()
            if tensor.dim() == 1:
                # Add a trailing dimension so each row is one sample
                tensor = tensor.unsqueeze(1)
            torch_data[var_name] = tensor

        # Create NCM with the causal graph
        self.ncm = NCM(self.cg)
        optimizer = torch.optim.Adam(self.ncm.parameters(), lr=self.lr)

        # Training loop with early stopping
        best_loss = float("inf")
        stale_epochs = 0

        for epoch in range(self.max_epochs):
            try:
                optimizer.zero_grad()

                # Compute negative log likelihood loss
                loss = self.ncm.biased_nll(torch_data, n=self.batch_size)

                # Backpropagate and update parameters
                loss.backward()
                optimizer.step()

                # Read the loss once per epoch
                loss_value = loss.item()
            except Exception as e:
                print(f"Error during NCM training: {e}")
                break

            # Track progress
            if self.verbose and epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_value:.4f}")

            # Early stopping check
            if loss_value < best_loss:
                best_loss = loss_value
                stale_epochs = 0
            else:
                stale_epochs += 1

            if stale_epochs >= self.patience:
                if self.verbose:
                    print(f"Early stopping at epoch {epoch}")
                break

        if self.verbose:
            # The value reported is the lowest loss seen during training.
            print(f"NCM training completed with final loss: {best_loss:.4f}")

    def _estimate_ate(self, treatment_vars, outcome_vars, treatment_values):
        """
        Estimate Average Treatment Effect (ATE): E[Y|do(T=T1)] - E[Y|do(T=T0)]

        The trained NCM is sampled ``self.n_samples`` times under each
        intervention and the outcome means are subtracted. Only the first
        treatment variable and the first outcome variable are used.

        Args:
            treatment_vars: List of treatment variables (objects with ``.name``)
            outcome_vars: List of outcome variables (objects with ``.name``)
            treatment_values: Pair of (T1_value, T0_value); each may be a
                plain number, sequence, numpy array or tensor

        Returns:
            Estimated ATE, or ``np.nan`` if inputs are empty or sampling fails
        """
        if not treatment_vars or not outcome_vars:
            return np.nan

        # Extract variable names (usually just one for each)
        treatment_var = treatment_vars[0].name
        outcome_var = outcome_vars[0].name

        n_samples = self.n_samples

        try:
            t1_value, t0_value = treatment_values
            t1_scalar = self._to_scalar(t1_value)
            t0_scalar = self._to_scalar(t0_value)

            with torch.no_grad():
                # Per-sample shape of the treatment variable, e.g. [] or [d],
                # taken from a single draw of the model
                sample_shape = list(self.ncm(n=1)[treatment_var].shape)[1:]
                full_shape = [n_samples] + sample_shape

                do_t1 = {treatment_var: torch.full(full_shape, t1_scalar)}
                do_t0 = {treatment_var: torch.full(full_shape, t0_scalar)}

                # Sample from the two interventional distributions
                samples_t1 = self.ncm(n=n_samples, do=do_t1)
                samples_t0 = self.ncm(n=n_samples, do=do_t0)

            # .cpu() keeps the numpy conversion valid for non-CPU tensors
            y_do_t1 = samples_t1[outcome_var].detach().cpu().numpy().mean()
            y_do_t0 = samples_t0[outcome_var].detach().cpu().numpy().mean()

            # Compute ATE
            ate = y_do_t1 - y_do_t0

            if self.verbose:
                print(f"ATE({treatment_var} → {outcome_var}): {ate:.4f}")
                print(f"E[{outcome_var}|do({treatment_var}={t1_value})]: {y_do_t1:.4f}")
                print(f"E[{outcome_var}|do({treatment_var}={t0_value})]: {y_do_t0:.4f}")

            return ate

        except Exception as e:
            print(f"Error estimating ATE: {e}")
            return np.nan
