import numpy as np
import torch
import networkx as nx
from torch.utils.data import DataLoader, TensorDataset
import lightning as L
import os
import sys

# Add the parent directory to the path to import from decaflow
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from decaflow.models import Encoder, Decoder, DeCaFlow
from decaflow.utils.metrics import compute_ate


class MyCausalMethod:
    """
    Implementation of causal estimation using DeCaFlow model
    """

    def __init__(
        self,
        num_hidden=2,
        flow_type="nsf",
        hidden_features=[64, 64, 64],
        max_epochs=100,
        batch_size=256,
        lr=1e-3,
    ):
        """
        Store the hyper-parameters of the DeCaFlow method

        No model is built here; ``self.model`` stays ``None`` until a model is
        assigned to it.

        Args:
            num_hidden: Number of hidden confounders to model
            flow_type: Type of flow to use ('nsf', 'maf', etc.)
            hidden_features: Hidden layer sizes for neural networks; the
                sequence is copied, so later changes to the caller's list do
                not affect this instance
            max_epochs: Maximum training epochs
            batch_size: Batch size for training
            lr: Learning rate
        """
        self.num_hidden = num_hidden
        self.flow_type = flow_type
        # Copy instead of aliasing: the default list object is shared by every
        # call, so storing it directly would let one instance's mutation leak
        # into all instances created afterwards.
        self.hidden_features = list(hidden_features)
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.model = None

    def graph_to_adjacency(self, adj_list, index_to_variable):
        """
        Convert an adjacency list from causal-profiler to an adjacency matrix for DeCaFlow

        Args:
            adj_list: Adjacency list from causal-profiler
            index_to_variable: Mapping from indices to variable names

        Returns:
            adjacency_matrix: torch.BoolTensor adjacency matrix
        """
        n = len(index_to_variable)
        adjacency_matrix = torch.zeros((n, n), dtype=torch.bool)

        # Fill the adjacency matrix based on the adjacency list
        for parent, children in adj_list.items():
            for child in children:
                adjacency_matrix[parent, child] = True

        return adjacency_matrix

    def prepare_data(self, data, index_to_variable):
        """
        Prepare data for training the DeCaFlow model

        Args:
            data: Dictionary of observational data
            index_to_variable: Mapping from indices to variable names

        Returns:
            train_loader: DataLoader for training
            n_features: Number of observed features
        """
        # Convert data to torch tensor
        data_array = np.stack([data[var] for var in index_to_variable], axis=1)
        data_tensor = torch.tensor(data_array, dtype=torch.float32)

        # Create dataset and dataloader
        dataset = TensorDataset(data_tensor)
        train_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        # Get number of features
        n_features = data_tensor.shape[1]

        return train_loader, n_features, data_tensor

    def train_model(self, train_loader, n_features):
        """
        Build a DeCaFlow model from an encoder/decoder pair and fit it

        The encoder and decoder are first constructed from ``self.adj``. If
        anything in that construction raises (including ``self.adj`` not
        being set), "Null jacobian" is printed and both are rebuilt from a
        plain ``order`` list instead.

        Args:
            train_loader: DataLoader with training data
            n_features: Number of observed features

        Returns:
            model: The DeCaFlow model after ``trainer.fit`` has run
        """

        def encoder_options():
            # Encoder: num_hidden features conditioned on the observed ones
            return dict(
                flow_type=self.flow_type,
                num_hidden=self.num_hidden,
                features=self.num_hidden,
                context=n_features,
                hidden_features=[64, 64],
                activation=torch.nn.ReLU,
            )

        def decoder_options():
            # Decoder: observed features conditioned on the hidden ones
            return dict(
                flow_type=self.flow_type,
                num_hidden=self.num_hidden,
                features=n_features,
                context=self.num_hidden,
                hidden_features=self.hidden_features,
                activation=torch.nn.ReLU,
            )

        try:
            # Preferred path: structure both networks with the adjacency matrix
            encoder = Encoder(adjacency=self.adj, **encoder_options())
            decoder = Decoder(adjacency=self.adj, **decoder_options())
        except Exception:
            # NOTE(review): every exception lands here, not only the
            # null-jacobian case the message refers to.
            print("Null jacobian")
            # Null jacobian, use order
            encoder = Encoder(
                order=list(range(self.num_hidden)),
                **encoder_options(),
            )
            decoder = Decoder(
                order=list(range(n_features)),
                **decoder_options(),
            )

        plateau_options = {
            "mode": "min",
            "factor": 0.95,
            "patience": 10,
            "verbose": False,
            "cooldown": 0,
        }
        model = DeCaFlow(
            encoder=encoder,
            flow=decoder,
            regularize=True,
            warmup=20,
            lr=self.lr,
            optimizer_cls=torch.optim.Adam,
            scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau,
            scheduler_kwargs=plateau_options,
            scheduler_monitor="train_loss",
        )

        # Fit without checkpoints, logger, or progress bar
        trainer = L.Trainer(
            max_epochs=self.max_epochs,
            enable_checkpointing=False,
            logger=False,
            enable_progress_bar=False,
        )
        trainer.fit(model, train_loader)

        return model

    def estimate_ate(self, query, model, index_to_variable):
        """
        Estimate the Average Treatment Effect (ATE) for the query's first
        treatment/outcome pair

        Args:
            query: Causal query object; ``query.vars["T"]`` / ``query.vars["Y"]``
                hold the treatment and outcome variables and
                ``query.vars_values["T"]`` holds two single-element value lists
            model: Trained DeCaFlow model
            index_to_variable: Sequence of variable names; positions are
                shifted down by ``self.num_hidden`` before use

        Returns:
            ate: The outcome's entry of the ATE vector, as a Python number
        """
        treatments = query.vars["T"]
        outcomes = query.vars["Y"]

        # Exactly two single-valued treatment settings are expected
        (value_first,), (value_second,) = query.vars_values["T"]

        # Positions in index_to_variable are shifted by the number of hidden variables
        shift = self.num_hidden
        treatment_pos = index_to_variable.index(treatments[0].name) - shift
        outcome_pos = index_to_variable.index(outcomes[0].name) - shift

        effects = compute_ate(
            flow=model,
            index=treatment_pos,
            value_a=float(value_first),
            value_b=float(value_second),
            num_samples=1000,
        )

        # Only the outcome's component is reported
        return effects[outcome_pos].item()

    def estimate_ctf_te(self, query, model, index_to_variable):
        """
        Estimate Counterfactual Treatment Effect using the trained model

        Args:
            query: Causal query object from causal-profiler
            model: Trained DeCaFlow model
            index_to_variable: Mapping from indices to variable names

        Returns:
            ctf_te: Estimated counterfactual treatment effect
        """
        # Extract treatment, outcome, and factual variables
        treatment_vars = query.vars["T"]
        outcome_vars = query.vars["Y"]
        factual_vars = query.vars["V_F"]

        # Get treatment values and factual values
        t_values = query.vars_values["T"]
        [t1_value], [t0_value] = t_values
        factual_value = query.vars_values["V_F"]
        y_value = query.vars_values["Y"]

        # Get indices for variables
        treatment_idx = index_to_variable.index(treatment_vars[0].name)
        outcome_idx = index_to_variable.index(outcome_vars[0].name)
        factual_indices = [index_to_variable.index(v.name) for v in factual_vars]

        # Get random samples for conditioning
        samples, _ = model.sample((1000,))

        # For each sample, compute counterfactual under both treatments
        cf_samples_t1 = []
        cf_samples_t0 = []

        # Keep only samples close to factual values (simple approximation)
        # This is a simplified version - in practice, better methods would be used
        for i in range(samples.shape[0]):
            sample = samples[i : i + 1]
            match = True

            # Check if the sample matches the factual values (approximately)
            for idx, val in zip(factual_indices, factual_value):
                if abs(sample[0, idx].item() - val) > 0.1:
                    match = False
                    break

            if match:
                # Compute counterfactuals for both treatment values
                cf_t1, _ = model.compute_counterfactual(
                    factual=sample, index=treatment_idx, value=float(t1_value)
                )
                cf_t0, _ = model.compute_counterfactual(
                    factual=sample, index=treatment_idx, value=float(t0_value)
                )

                cf_samples_t1.append(cf_t1[0, outcome_idx].item())
                cf_samples_t0.append(cf_t0[0, outcome_idx].item())

        if not cf_samples_t1:
            # If no matches found, use interventional distribution instead
            return np.nan

        # For continuous variables, return difference in expectations
        avg_t1 = sum(cf_samples_t1) / len(cf_samples_t1)
        avg_t0 = sum(cf_samples_t0) / len(cf_samples_t0)

        return avg_t1 - avg_t0

    def make_adjacency(self, adjacency_list, num_hidden):
        num_observed = len(adjacency_list)
        total = num_hidden + num_observed
        adj = torch.zeros((total, total), dtype=torch.bool)

        # Hidden variables: no incoming edges
        # But they point to all observed variables
        for h in range(num_hidden):
            for obs in range(num_hidden, total):
                adj[obs, h] = 1  # edge from h to obs -> mark in row of obs

        # Observed variables
        for src, targets in adjacency_list.items():
            for tgt in targets:
                # account for offset due to hidden vars
                adj[tgt + num_hidden, src + num_hidden] = 1

        # Add self-loops
        adj |= torch.eye(total, dtype=torch.bool)

        return adj

    def estimate(self, query, data, graph, index_to_variable):
        """
        Estimate the causal effect for a given query

        Args:
            query: Causal query from causal-profiler
            data: Observational data
            graph: Causal graph (as dictionary)
            index_to_variable: Mapping from indices to variable names

        Returns:
            estimate: Estimated causal effect
        """
        # Convert graph to adjacency matrix
        self.graph = graph
        self.adj = self.make_adjacency(
            graph, 0
        )  # hiddens already exist from evaluate.py

        # Prepare data for training
        train_loader, n_features, data_tensor = self.prepare_data(
            data, index_to_variable[self.num_hidden :]
        )

        # Train model if not already trained
        if self.model is None:
            self.model = self.train_model(train_loader, n_features)

        # Estimate causal effect based on query type
        if query.type.name == "ATE":
            return self.estimate_ate(query, self.model, index_to_variable)
        elif query.type.name == "CTF_TE":
            return self.estimate_ctf_te(query, self.model, index_to_variable)
        else:
            raise ValueError(f"Unsupported query type: {query.type}")
