import os
import numpy as np
import torch
import pytorch_lightning as pl
import networkx as nx
from typing import Dict, List, Any, Tuple, Union
import sys
import json

# Put the directory above this file (resolved through symlinks) at the front
# of sys.path so the project packages imported below (models.vaca,
# utils.*) can be found regardless of the current working directory.
current_dir = os.path.dirname(os.path.realpath(__file__))
parent_dir = os.path.split(current_dir)[0]
if sys.path.count(parent_dir) == 0:
    sys.path.insert(0, parent_dir)

# Import VACA modules
from models.vaca.vaca import VACA
from utils.distributions import Normal
from utils.likelihoods import get_likelihood
from utils.constants import Cte


# Simple Data class to mimic PyTorch Geometric's Data object
class SimpleData:
    """Minimal stand-in for PyTorch Geometric's ``Data`` container.

    Every keyword argument is stored as an attribute of the same name. When
    ``num_graphs`` is not supplied it defaults to ``x.size(0)``, so in that
    case ``x`` must be passed as a tensor.
    """

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])

        # VACA may read ``num_graphs``; default it to the leading dim of ``x``.
        if "num_graphs" not in vars(self):
            self.num_graphs = self.x.size(0)

    def to(self, device):
        """Move every tensor attribute to ``device`` in place; return ``self``."""
        moved = {
            name: tensor.to(device)
            for name, tensor in vars(self).items()
            if isinstance(tensor, torch.Tensor)
        }
        for name, tensor in moved.items():
            setattr(self, name, tensor)
        return self

    def clone(self):
        """Return a new SimpleData whose tensor attributes are cloned.

        Non-tensor attributes are shared by reference with ``self``.
        """
        return SimpleData(
            **{
                name: value.clone() if isinstance(value, torch.Tensor) else value
                for name, value in vars(self).items()
            }
        )


class MyCausalMethod:
    """Causal-query estimator backed by a VACA model.

    The model is fitted once, on the first call to :meth:`estimate`, from the
    observational data and the causal graph; later calls reuse it. Supported
    queries are the average treatment effect (ATE) and the counterfactual
    total effect (CTF-TE).
    """

    class _SimpleScaler:
        """Column-wise standardiser handed to VACA as its ``scaler``.

        Defined at class level (not inside ``_learn_model``) so instances
        are picklable.
        """

        def __init__(self, data=None):
            if data is None:
                self.mean = 0.0
                self.std = 1.0
                return
            self.mean = torch.mean(data, dim=0, keepdim=True)
            std = torch.std(data, dim=0, keepdim=True)
            # Constant columns give std == 0 and a single row gives NaN
            # (unbiased estimator); use 1 there so transform() stays finite.
            std[~torch.isfinite(std) | (std == 0)] = 1
            self.std = std

        def transform(self, x):
            return (x - self.mean) / self.std

        def inverse_transform(self, x):
            return x * self.std + self.mean

    def __init__(self, seed=42):
        """Initialise an untrained estimator.

        Args:
            seed: Passed to ``pl.seed_everything`` for reproducible training
                and sampling.
        """
        pl.seed_everything(seed)
        self.model = None  # VACA model, fitted lazily by estimate()
        self.variable_type = None  # "DISCRETE" or "CONTINUOUS"
        self.nodes_list = None  # variable names, in column order of self.X
        self.X = None  # (num_samples, num_nodes) float32 tensor for 1-D inputs
        self.graph_edges = None  # (2, num_edges) long edge index

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _ordered_nodes(data, index_to_variable):
        """Return the keys of ``data`` ordered by node index.

        Graph edges are integer node indices, so column ``i`` of ``self.X``
        should hold the variable with index ``i``. When ``index_to_variable``
        (indexable by ``0..n-1``) names exactly the keys of ``data``, that
        order is used; otherwise the insertion order of ``data`` is kept.
        NOTE(review): assumes ``index_to_variable`` uses the same indexing
        as ``graph`` — confirm against the caller.
        """
        nodes = list(data.keys())
        if index_to_variable is None:
            return nodes
        try:
            position = {
                index_to_variable[i]: i for i in range(len(index_to_variable))
            }
            if len(position) == len(nodes) and all(n in position for n in nodes):
                return sorted(nodes, key=position.__getitem__)
        except (KeyError, IndexError, TypeError):
            pass  # not an index -> name sequence; keep the data order
        return nodes

    def _resolve_node_index(self, name):
        """Map a query variable name to its column index in ``self.X``.

        Resolution order: exact name; then a trailing ``_<int>`` suffix read
        as a column index (e.g. ``"X_2"`` -> column 2); then the first column
        whose name contains ``name`` case-insensitively.

        Raises:
            ValueError: if no column matches.
        """
        if name in self.nodes_list:
            return self.nodes_list.index(name)

        suffix = name.rsplit("_", 1)[-1] if "_" in name else ""
        try:
            idx = int(suffix)
        except ValueError:
            idx = -1  # non-numeric suffix: fall through to substring matching
        if 0 <= idx < len(self.nodes_list):
            return idx

        needle = name.lower()
        for i, node in enumerate(self.nodes_list):
            if needle in node.lower():
                return i
        raise ValueError(
            f"{name!r} does not match any variable in {self.nodes_list}"
        )

    @staticmethod
    def _to_scalar(value):
        """Return ``value`` as a float, unwrapping single-element containers."""
        return float(np.asarray(value).item())

    def _edge_attr(self):
        """Constant unit edge features: one row of ``[1.0]`` per graph edge."""
        return torch.ones(self.graph_edges.size(1), 1)

    def _rows_matching(self, col_idx, value, tol):
        """Boolean mask over rows of ``self.X`` where ``|X[:, col_idx] - value| < tol``."""
        return torch.abs(self.X[:, col_idx] - value) < tol

    def _decode(self, z, data):
        """Decode latent ``z`` under the intervention in ``data``; return a sample."""
        x_hat, _ = self.model.model.decoder(
            z,
            data.edge_index,
            edge_attr=data.edge_attr,
            intervention_mask=data.intervention_mask,
            intervention_values=data.intervention_values,
            return_type="sample",
            node_ids=data.node_ids,
        )
        return x_hat

    def _sample_under_intervention(self, node_idx, value, num_samples):
        """Decode ``num_samples`` standard-normal latents under do(node=value).

        Returns the decoded samples concatenated along dim 0.
        """
        data = self._create_intervention_data(node_idx, value)
        latent_width = len(self.nodes_list) * self.model.z_dim
        samples = [
            self._decode(torch.randn(1, latent_width), data)
            for _ in range(num_samples)
        ]
        return torch.cat(samples, dim=0)

    # ------------------------------------------------------------------ #
    # Model fitting
    # ------------------------------------------------------------------ #

    def _learn_model(self, data, graph, variable_type, index_to_variable):
        """Build a VACA model for ``data`` and ``graph`` and train it.

        Args:
            data: Mapping from variable name to its samples (array, tensor or
                list; presumably 1-D — TODO confirm).
            graph: Mapping ``source_index -> iterable of target indices``
                (integer node indices, as required by the long edge tensor).
            variable_type: Either ``"CONTINUOUS"`` or ``"DISCRETE"``.
            index_to_variable: Variable names indexed by node index; used
                only to order the columns of ``self.X``.

        Returns:
            The trained VACA model (also stored on ``self.model``).
        """
        self.variable_type = variable_type
        self.nodes_list = self._ordered_nodes(data, index_to_variable)
        num_nodes = len(self.nodes_list)

        # One column per variable; torch.as_tensor also accepts lists and
        # casts any input dtype so that X is always float32.
        self.X = torch.cat(
            [
                torch.as_tensor(data[node], dtype=torch.float32).unsqueeze(1)
                for node in self.nodes_list
            ],
            dim=1,
        )

        # NOTE(review): discrete values are used as raw floats; they are not
        # remapped to contiguous 0..K-1 codes. Confirm this is what the
        # categorical likelihood expects.

        edge_list = [
            [source, target]
            for source, targets in graph.items()
            for target in targets
        ]
        self.graph_edges = (
            torch.tensor(edge_list, dtype=torch.long).t().contiguous()
            if edge_list
            else torch.zeros((2, 0), dtype=torch.long)
        )

        # Per-node in-degree, passed to VACA as ``deg``.
        # NOTE(review): PyG's PNA layers document ``deg`` as an in-degree
        # histogram; confirm which form VACA expects.
        in_degree = torch.tensor(
            [
                sum(1 for _, target in edge_list if target == i)
                for i in range(num_nodes)
            ],
            dtype=torch.long,
        )

        # TODO: the categorical likelihood needs domain_size passed to get_likelihood
        likelihood_type = (
            Cte.CATEGORICAL if variable_type == "DISCRETE" else Cte.GAUSSIAN
        )
        likelihood = get_likelihood(likelihood_type, num_nodes)

        scaler = (
            self._SimpleScaler(self.X) if variable_type == "CONTINUOUS" else None
        )

        model_params = {
            "is_heterogeneous": False,
            "likelihood_x": likelihood,
            "deg": in_degree,
            "num_nodes": num_nodes,
            "edge_dim": 1,  # Simple edge features
            "scaler": scaler,
            "z_dim": 4,  # Using latent_dim value
            "h_dim_list_enc": [16, 16],  # Using gnn_hidden_dims value
            "h_dim_list_dec": [16, 16],  # Using decoder_hidden_dims value
            "beta": 0.1,  # Using lambda_KL value
            "architecture": "pna",
        }

        self.model = VACA(**model_params)
        self._train_model_directly()
        return self.model

    def _train_model_directly(self, num_epochs=20, lr=0.01, batch_size=128):
        """Fit ``self.model`` with a hand-written Adam loop (no DataLoader).

        Each mini-batch of rows of ``self.X`` is wrapped in a ``SimpleData``
        and the model's objective is maximised by minimising its negative.
        The learning rate is halved every 5 epochs. Afterwards the model is
        put in eval mode and frozen.

        Args:
            num_epochs: Number of passes over ``self.X``.
            lr: Initial Adam learning rate.
            batch_size: Maximum rows per mini-batch.

        Raises:
            ValueError: if ``self.X`` has no rows.
        """
        num_rows = len(self.X)
        if num_rows == 0:
            raise ValueError("Cannot train VACA on an empty dataset")

        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

        batch_size = max(1, min(batch_size, num_rows))
        edge_attr = self._edge_attr()
        node_ids = torch.arange(len(self.nodes_list))

        self.model.train()
        for epoch in range(num_epochs):
            total_loss = 0.0
            num_batches = 0

            for batch_indices in torch.split(torch.randperm(num_rows), batch_size):
                # squeeze(-1) is a no-op for the 2-D X built in _learn_model
                # unless there is only one variable.
                x_batch = self.X[batch_indices].squeeze(-1)
                batch_data = SimpleData(
                    x=x_batch,
                    edge_index=self.graph_edges,
                    edge_attr=edge_attr,
                    node_ids=node_ids,
                    batch=torch.zeros(len(batch_indices), dtype=torch.long),
                    num_graphs=len(batch_indices),
                )

                objective, _ = self.model(
                    batch_data, estimator=self.model.estimator, beta=self.model.beta
                )
                loss = -objective  # maximise the objective

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

            scheduler.step()

            if (epoch + 1) % 5 == 0:
                print(
                    f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/num_batches:.4f}"
                )

        self.model.eval()
        self.model.freeze()

    # ------------------------------------------------------------------ #
    # Query estimation
    # ------------------------------------------------------------------ #

    def _create_intervention_data(self, intervention_idx, intervention_value):
        """Build a ``SimpleData`` describing do(node ``intervention_idx`` = value).

        Args:
            intervention_idx: Column index of the intervened variable.
            intervention_value: Value the variable is set to.

        Returns:
            SimpleData with a (1, num_nodes) boolean ``intervention_mask`` and
            matching ``intervention_values``; ``x`` is the first row of
            ``self.X`` and is used for sizing only.
        """
        num_nodes = len(self.nodes_list)
        intervention_mask = torch.zeros(1, num_nodes, dtype=torch.bool)
        intervention_mask[0, intervention_idx] = True

        intervention_values = torch.zeros(1, num_nodes)
        intervention_values[0, intervention_idx] = intervention_value

        return SimpleData(
            x=self.X[0:1],
            edge_index=self.graph_edges,
            edge_attr=self._edge_attr(),
            intervention_mask=intervention_mask,
            intervention_values=intervention_values,
            node_ids=torch.arange(num_nodes),
            num_nodes=num_nodes,
        )

    def _estimate_ATE(self, query, num_samples=1000):
        """Estimate ATE = E[Y | do(T=t1)] - E[Y | do(T=t0)].

        Each expectation is a Monte-Carlo mean over ``num_samples`` latents
        drawn from ``torch.randn`` and decoded under the intervention on T.

        Args:
            query: Query with ``vars["T"]``, ``vars["Y"]`` and
                ``vars_values["T"] == [[t1], [t0]]``.
            num_samples: Latent draws per intervention.

        Returns:
            The ATE as a Python float.
        """
        T_idx = self._resolve_node_index(query.vars["T"][0].name)
        [T1_value], [T0_value] = query.vars_values["T"]
        T1_value, T0_value = float(T1_value), float(T0_value)
        Y_idx = self._resolve_node_index(query.vars["Y"][0].name)

        with torch.no_grad():
            samples_t1 = self._sample_under_intervention(T_idx, T1_value, num_samples)
            samples_t0 = self._sample_under_intervention(T_idx, T0_value, num_samples)
            ate = torch.mean(samples_t1[:, Y_idx]) - torch.mean(samples_t0[:, Y_idx])
        return ate.item()

    def _estimate_CTF_TE(self, query, epsilon=0.1):
        """Estimate the counterfactual total effect (CTF-TE).

        P(Y_{do(T=t1)} = y | V_F = v_F) - P(Y_{do(T=t0)} = y | V_F = v_F)

        Rows of ``self.X`` whose V_F lies within 1e-6 (discrete) or
        ``epsilon`` (continuous) of v_F are encoded with ``return_mean=True``
        (abduction). Each latent is then decoded under both interventions
        on T. The fraction of decoded Y values equal to y (discrete) or
        within ``epsilon`` of y (continuous) is compared.

        Args:
            query: Query with ``vars`` / ``vars_values`` for T, Y and V_F.
            epsilon: Tolerance for continuous matching of V_F and Y.

        Returns:
            The probability difference as a float, or ``np.nan`` when no row
            matches V_F.
        """
        T_idx = self._resolve_node_index(query.vars["T"][0].name)
        [T1_value], [T0_value] = query.vars_values["T"]
        T1_value, T0_value = float(T1_value), float(T0_value)

        Y_idx = self._resolve_node_index(query.vars["Y"][0].name)
        Y_value = self._to_scalar(query.vars_values["Y"][0])

        V_F_idx = self._resolve_node_index(query.vars["V_F"][0].name)
        V_F_value = self._to_scalar(query.vars_values["V_F"][0])

        discrete = self.variable_type == "DISCRETE"
        match_tol = 1e-6 if discrete else epsilon
        matched_samples = self.X[self._rows_matching(V_F_idx, V_F_value, match_tol)]
        if len(matched_samples) == 0:
            return np.nan  # nothing to condition on

        data_t1 = self._create_intervention_data(T_idx, T1_value)
        data_t0 = self._create_intervention_data(T_idx, T0_value)
        edge_attr = self._edge_attr()
        node_ids = torch.arange(len(self.nodes_list))

        cf_y_t1, cf_y_t0 = [], []
        with torch.no_grad():
            for sample in matched_samples:
                # Abduction: encode the factual row to its latent.
                z, _ = self.model.model.encoder(
                    sample.unsqueeze(0),
                    self.graph_edges,
                    edge_attr=edge_attr,
                    return_mean=True,
                    node_ids=node_ids,
                )
                # Prediction under each intervention (T1 first, then T0).
                cf_y_t1.append(self._decode(z, data_t1)[0, Y_idx].item())
                cf_y_t0.append(self._decode(z, data_t0)[0, Y_idx].item())

            cf_y_t1 = torch.tensor(cf_y_t1)
            cf_y_t0 = torch.tensor(cf_y_t0)

            if discrete:
                p_t1 = torch.mean((cf_y_t1 == Y_value).float())
                p_t0 = torch.mean((cf_y_t0 == Y_value).float())
            else:
                p_t1 = torch.mean((torch.abs(cf_y_t1 - Y_value) < epsilon).float())
                p_t0 = torch.mean((torch.abs(cf_y_t0 - Y_value) < epsilon).float())

            return (p_t1 - p_t0).item()

    def estimate(self, query, data, graph, index_to_variable):
        """
        Estimate the specified causal query using the VACA model.

        Args:
            query: The causal query to estimate (from causal-profiler)
            data: Dictionary of data (variable name -> values)
            graph: Dictionary representing the causal graph (parents -> children)
            index_to_variable: Mapping from indices to variable names

        Returns:
            float: The estimated value for the query

        Raises:
            ValueError: if the query type is neither ATE nor CTF_TE.
        """
        # The data are treated as discrete if ANY variable is integer-valued
        # with fewer than 10 distinct values.
        variable_type = "CONTINUOUS"
        for var_values in data.values():
            if np.all(np.mod(var_values, 1) == 0) and len(np.unique(var_values)) < 10:
                variable_type = "DISCRETE"
                break

        # The model is fitted on the first call only and reused afterwards.
        if self.model is None:
            self._learn_model(data, graph, variable_type, index_to_variable)

        # Dispatch on the enum member name (robust to str-valued enums),
        # falling back to the text after the last "." of str(query.type).
        query_type = getattr(query.type, "name", None)
        if not isinstance(query_type, str):
            query_type = str(query.type).rsplit(".", 1)[-1]
        handlers = {"ATE": self._estimate_ATE, "CTF_TE": self._estimate_CTF_TE}
        if query_type not in handlers:
            raise ValueError(f"Unsupported query type: {query.type}")
        return handlers[query_type](query)
