import os, sys
import torch
import numpy as np

# Directory containing this file, with symlinks resolved by realpath.
current_dir = os.path.dirname(os.path.realpath(__file__))
# The project root is taken to be the parent of this file's directory.
project_root = os.path.abspath(os.path.join(current_dir, os.pardir))
# Put the project root at the front of sys.path so it is searched first;
# presumably this is what lets the `causal_nf` imports below resolve when the
# package is not installed -- TODO confirm.
sys.path.insert(0, project_root)


from causal_nf.preparators.custom_preparator import CustomPreparator
from causal_nf.models.causal_nf import CausalNFightning
from causal_nf.modules.causal_nf import CausalNormalizingFlow


class MyCausalMethod:
    """Answer causal queries by fitting a causal normalizing flow to the data.

    A fresh model is built and trained on every call to :meth:`estimate`; the
    trained model is then used to answer each query in turn.

    Args:
        device: Device passed to the preparator and used for the model.
        max_epochs: Number of passes over the training dataloader.
        lr: Learning rate of the Adam optimizer.
        scale: Scaling option forwarded to ``CustomPreparator``.
        batch_size: Batch size forwarded to the preparator and to the training
            dataloader (defaults to the previously hard-coded 128).
    """

    # Inference-time constants (previously inline literals).
    _ATE_NUM_SAMPLES = 1000  # samples requested from compute_ate
    _CTF_NUM_SAMPLES = 20000  # samples drawn before filtering on V_F
    _V_F_MATCH_ATOL = 5e-1  # tolerance for "sample matches the factual V_F"
    _Y_MATCH_ATOL = 1e-3  # tolerance for "counterfactual Y equals Y_value"

    def __init__(
        self,
        device="cpu",
        max_epochs=100,
        lr=1e-3,
        scale="default",
        batch_size=128,
    ):
        self.device = device
        self.max_epochs = max_epochs
        self.lr = lr
        self.scale = scale
        self.batch_size = batch_size

    def estimate(self, queries, data, graph, index_to_variable, space=None):
        """Train a model on ``data`` and return one estimate per query.

        Args:
            queries: Iterable of query objects. Each is expected to expose
                ``type``, ``vars`` and ``vars_values`` (see the private
                ``_estimate_*`` helpers for the keys that are read).
            data: Training data, forwarded unchanged to ``CustomPreparator``.
            graph: Adjacency information, forwarded as ``adjacency``.
            index_to_variable: Sequence of variable names; a variable's
                position in it is used as its column index.
            space: Optional object whose ``variable_type`` decides whether the
                variables are treated as discrete.

        Returns:
            A list with one entry per query, in order. A query that raises any
            exception is reported on stdout and yields ``np.nan``.
        """
        discrete = self._is_discrete(space)

        # 1. Prepare data and graph.
        preparator = CustomPreparator(
            data=data,
            adjacency=graph,
            index_to_variable=index_to_variable,
            discrete=discrete,
            device=self.device,
            batch_size=self.batch_size,
            scale=self.scale,
        )

        # 2. Build the model and its Lightning wrapper.
        model = self._build_model(preparator)
        lightning_model = CausalNFightning(preparator, model, plot=False)
        lightning_model.to(self.device)

        # 3. Train (plain loop, no validation).
        self._train(lightning_model, model, preparator)

        # 4. Answer the queries; a failing query must not abort the others.
        user_estimates = []
        for query in queries:
            try:
                estimate = self._estimate_query(
                    query,
                    index_to_variable,
                    model,
                    lightning_model,
                    preparator,
                    discrete,
                )
            except Exception as e:
                print(f"      Query failed: {e}")
                estimate = np.nan
            user_estimates.append(estimate)
        return user_estimates

    @staticmethod
    def _is_discrete(space):
        """Return True when ``space.variable_type`` prints as DISCRETE."""
        if space is None:
            return False
        # Compared by string form, so the enum class itself is not needed here.
        variable_type = str(getattr(space, "variable_type", None))
        return variable_type == "VariableDataType.DISCRETE"

    @staticmethod
    def _as_scalar(value):
        """Return ``value`` as a float, taking the first element of sequences.

        Zero-dimensional inputs (Python ``int``/``float``, NumPy scalars) are
        converted directly. Previously only built-in ``float`` skipped the
        ``[0]`` indexing, so e.g. an ``int`` treatment value raised TypeError.
        """
        if np.ndim(value) == 0:
            return float(value)
        return float(value[0])

    @staticmethod
    def _unique_by_identity(items):
        """Return ``items`` without repeated objects, keeping first-seen order.

        Identity (``id``) is used rather than ``==`` because tensor equality is
        element-wise and cannot serve as a membership test.
        """
        seen_ids = set()
        unique = []
        for item in items:
            if id(item) not in seen_ids:
                seen_ids.add(id(item))
                unique.append(item)
        return unique

    @staticmethod
    def _build_model(preparator):
        """Build a small MAF flow and wrap it in a ``CausalNormalizingFlow``."""
        # Imported lazily, as in the original code.
        from zuko.flows import MAF

        input_dim = preparator.x_dim()
        flow = MAF(
            input_dim,  # features
            0,  # context
            transforms=2,  # number of transforms
            hidden_features=[16, 16],  # hidden features
            base_to_data=False,  # Use default flow direction
            base_distr="normal",  # Use normal distribution as base
            learn_base=False,  # Don't learn base
            activation=torch.nn.ReLU,  # Use ReLU activation
            adjacency=preparator.adjacency(),
        )
        model = CausalNormalizingFlow(flow=flow)
        model.set_adjacency(preparator.adjacency())
        return model

    def _train(self, lightning_model, model, preparator):
        """Optimize the model with Adam for ``self.max_epochs`` epochs.

        Raises:
            RuntimeError: If no trainable parameters are found.
        """
        # Collect parameters from both the wrapper and the underlying flow.
        params = list(lightning_model.parameters())
        if hasattr(model, "flow") and hasattr(model.flow(), "parameters"):
            params += list(model.flow().parameters())
        # The two sources may overlap; hand each parameter to the optimizer
        # only once (PyTorch warns about duplicate parameters in a group).
        params = self._unique_by_identity(params)
        if not params:
            raise RuntimeError(
                "CausalNF model has no trainable parameters. Check model construction."
            )

        optimizer = torch.optim.Adam(params, lr=self.lr)
        dataloader = preparator.get_dataloader_train(batch_size=self.batch_size)
        lightning_model.train()
        for _ in range(self.max_epochs):
            for batch in dataloader:
                optimizer.zero_grad()
                output = lightning_model.forward(batch)
                loss = output["loss"].mean()
                loss.backward()
                optimizer.step()

    def _estimate_query(
        self, query, index_to_variable, model, lightning_model, preparator, discrete
    ):
        """Dispatch ``query`` on the string form of ``query.type``.

        Raises:
            NotImplementedError: If the query type is neither ATE nor Ctf-TE.
        """
        if hasattr(query, "type"):
            query_type = str(query.type).upper()
            # NOTE(review): suffix match -- any type whose name ends in "ATE"
            # is routed to the ATE estimator; confirm this is intended.
            if query_type.endswith("ATE"):
                return self._estimate_ate(
                    query, index_to_variable, model, preparator
                )
            if query_type == "QUERYTYPE.CTF_TE":
                return self._estimate_ctf_te(
                    query,
                    index_to_variable,
                    model,
                    lightning_model,
                    preparator,
                    discrete,
                )
        raise NotImplementedError(
            f"Query type {getattr(query, 'type', None)} not supported yet."
        )

    def _estimate_ate(self, query, index_to_variable, model, preparator):
        """Return the ATE on ``query.vars["Y"][0]`` of ``query.vars["T"][0]``."""
        # query.vars["T"] is a list of Variable objects; only the first is used.
        T_var = query.vars["T"][0]
        T_index = index_to_variable.index(T_var.name)
        T1_value, T0_value = query.vars_values["T"]

        ate = model.compute_ate(
            index=T_index,
            a=self._as_scalar(T1_value),
            b=self._as_scalar(T0_value),
            num_samples=self._ATE_NUM_SAMPLES,
            scaler=preparator.scaler_transform,
        )

        Y_var = query.vars["Y"][0]
        Y_index = index_to_variable.index(Y_var.name)
        # The result is indexed as a vector over all variables except T, so
        # Y's position is computed with T removed -- assumes that is what
        # compute_ate returns; TODO confirm.
        indices = list(range(len(index_to_variable)))
        indices.remove(T_index)
        y_out_index = indices.index(Y_index)
        return (
            ate[y_out_index].item()
            if hasattr(ate, "item")
            else float(ate[y_out_index])
        )

    def _estimate_ctf_te(
        self, query, index_to_variable, model, lightning_model, preparator, discrete
    ):
        """Return the counterfactual treatment effect for ``query``.

        Samples from the model, keeps the samples whose ``V_F`` columns are
        close to the requested factual values, computes the counterfactuals
        under both treatment values and returns either a difference of match
        probabilities (discrete) or a difference of means (continuous).
        Returns ``np.nan`` when no sample matches the factual values.
        """
        Y_var = query.vars["Y"][0]
        Y_index = index_to_variable.index(Y_var.name)
        T_var = query.vars["T"][0]
        T_index = index_to_variable.index(T_var.name)
        V_F_indices = [index_to_variable.index(v.name) for v in query.vars["V_F"]]

        T1_value, T0_value = query.vars_values["T"]
        T1_value = self._as_scalar(T1_value)
        T0_value = self._as_scalar(T0_value)
        V_F_value = query.vars_values["V_F"]
        raw_Y_value = query.vars_values["Y"]
        Y_value = (
            raw_Y_value[0]
            if isinstance(raw_Y_value, (list, np.ndarray))
            else raw_Y_value
        )

        # Sample from the model and map the samples back to the data scale.
        num_samples = self._CTF_NUM_SAMPLES
        with torch.no_grad():
            x_samples = model.sample((num_samples,))["x_obs"]
            x_samples = lightning_model.input_scaler.inverse_transform(
                x_samples, inplace=False
            )

        # Keep samples for which every V_F variable matches its factual value.
        mask = np.ones(num_samples, dtype=bool)
        for idx, v_val in zip(V_F_indices, V_F_value):
            mask &= np.isclose(
                x_samples[:, idx].cpu().numpy(),
                float(v_val),
                atol=self._V_F_MATCH_ATOL,
            )
        if not np.any(mask):
            return np.nan
        x_factual = x_samples[mask]

        # Counterfactuals under do(T=t1) and do(T=t0) for each factual sample.
        x_cf_t1 = model.compute_counterfactual(
            x_factual=x_factual,
            index=T_index,
            value=T1_value,
            scaler=preparator.scaler_transform,
        )
        x_cf_t0 = model.compute_counterfactual(
            x_factual=x_factual,
            index=T_index,
            value=T0_value,
            scaler=preparator.scaler_transform,
        )
        y_cf_t1 = x_cf_t1[:, Y_index].cpu().numpy()
        y_cf_t0 = x_cf_t0[:, Y_index].cpu().numpy()

        if not discrete:
            # Continuous: difference in expectation.
            return float(np.mean(y_cf_t1) - np.mean(y_cf_t0))

        # Discrete: difference in the probability that Y equals Y_value.
        if isinstance(Y_value, (list, np.ndarray)) and len(np.shape(Y_value)) > 0:
            matches_t1 = np.all(
                np.isclose(y_cf_t1, Y_value, atol=self._Y_MATCH_ATOL), axis=-1
            )
            matches_t0 = np.all(
                np.isclose(y_cf_t0, Y_value, atol=self._Y_MATCH_ATOL), axis=-1
            )
        else:
            matches_t1 = np.isclose(y_cf_t1, Y_value, atol=self._Y_MATCH_ATOL)
            matches_t0 = np.isclose(y_cf_t0, Y_value, atol=self._Y_MATCH_ATOL)
        p_t1 = np.mean(matches_t1)
        p_t0 = np.mean(matches_t0)
        return float(p_t1 - p_t0)
