#!/usr/bin/env python3

from __future__ import annotations

from typing import List, Tuple

import torch
from torch import Tensor
from dowhy import gcm
from dowhy.gcm import StructuralCausalModel
from tqdm import trange
from joblib import Parallel, delayed
from tqdm import tqdm
import warnings


class CausalInference():
    def __init__(self,
        causal_model: StructuralCausalModel,
        design_variables: List[str], 
        outcome_variables: List[str],
        device: torch.device,
        dtype: torch.dtype = torch.double    
    ) -> None:
        r"""
        Utility class to sample from a fitted causal model,
        supporting both interventional and observational
        settings.

        Let `m` be the number of objectives (outputs), `d` the
        number of input variables, and `N` the number of samples.
        This class generates `X ∈ ℝ^{N×d}` and corresponding
        outputs `Y ∈ ℝ^{N×m}` along with per-output variance
        estimates.

        Args:
            causal_model (StructuralCausalModel): A fitted
                DoWhy-GCM structural causal model.
            design_variables (List[str]): Ordered list of input
                variable names (e.g., ["x1", "x2"]).
                NOTE: **Order matters.** Must be same order as
                pandas.column data to fit the causal model.
            outcome_variables (List[str]): Ordered list of
                output variable names (e.g., ["obj_1", "obj_2"]).
                NOTE: **Order matters.** Must be same order as
                pandas.column data to fit the causal model.
            device (torch.device): Device to which returned
                tensors will be moved.
            dtype (torch.dtype): Desired floating-point
                precision.

        Examples:
            >>> causal_graph = gcm.StructuralCausalModel(networkx.digraph)
            >>> causal_model = Rescue.models.fit_causal_model(pandas.DataFrame, causal_graph)  
            >>> sampler = CausalSampler(causal_model, ["x1", "x2"], ["obj_1", "obj_2"], 
                                        torch.device, torch.dtype)
            >>> X_obs, Y_obs, Y_std_obs = sampler.causal_observations(1000)
            >>> Y_obs.shape  # torch.Size([1000, 2])
            >>> X_obs.shape  # torch.Size([1000, 2])
            >>> X_obs.shape  # torch.Size([1000, 2])
            >>> bounds = torch.stack([torch.zeros(2), torch.ones(2)])
            >>> X_intvn = draw_sobol_samples(bounds=bounds, 
                                             n=500, q=1).squeeze(1).to(device=device, dtype=dtype)
            >>> Y_intvn, Y_std_intv = sampler.causal_interventions(X_intvn, 100, 10)
            >>> X_intvn.shape       # torch.Size([500, 2]
            >>> Y_intvn.shape       # torch.Size([500, 2])
            >>> Y_std_intv.shape    # torch.Size([500, 2])
        """

        self.causal_model = causal_model  
        self.design_variables = design_variables
        self.outcome_variables = outcome_variables
        self.device, self.dtype = device, dtype      

    def causal_interventions(
        self, 
        input: Tensor,   
        samples_per_intervention: int = 1000,
        batch_processing: int = 10,
        verbose: bool = True 
    ) -> Tuple[Tensor, Tensor]:
        r"""
        Computes the interventional mean and variance of the
        target variables using do-calculus with the fitted causal
        model.

        This function estimates:
        - `E[Y | do(X = x)]`: the interventional mean
        - `sqrt(V[Y | do(X = x)])`: the interventional variance
          deviation

        The interventional distribution is computed using
        DoWhy-GCM’s `gcm.interventional_samples`.

        Args:
            input (Tensor): Input tensor of shape `N x d`
                representing `N` points of `d`-dimensional inputs
                to intervene on (i.e., `do(X = x)`).
            samples_per_intervention (int): Number of
                interventional samples to draw per input. Higher
                values improve accuracy but increase cost.
            batch_processing (int): Number of inputs to process
                per batch. Larger values speed up computation but
                require more memory.

        Returns:
            Tuple:
                - `Y`: Tensor of shape `N x m` with
                  interventional means.
                - `Y_std`: Tensor of shape `N x m` with
                  interventional variance.
        """

        if input.shape[1] != len(self.design_variables):
            raise ValueError(
                f"Recevied {input.shape[1]}D input, "
                f"but expected  {len(self.design_variables)}D "
                f"based on 'design_variables={self.design_variables}'. "
                f"Ensure columns match in order and count."
            )  
        if samples_per_intervention <= 0:
            raise ValueError(
                "`samples_per_intervention` must "
                "be a positive integer. "
                f"Received {samples_per_intervention}."
            )
                          
        num_samples = input.shape[0]
        causal_means_tensor = torch.zeros(
            (num_samples, len(self.outcome_variables)), 
            device=self.device,
            dtype=self.dtype
            )
        causal_stds_tensor = torch.zeros(
            (num_samples, len(self.outcome_variables)), 
            device=self.device,
            dtype=self.dtype    
            )
        # Process in batches
        # num_samples/batch_processing
        progress = trange(
            0,
            num_samples,
            batch_processing,
            desc=f"Performing causal interventions",
            unit="batch",
            disable=not verbose
        )   
        for batch_start in progress:
            batch_end = min(batch_start + batch_processing, num_samples)
            batch_values = input[batch_start:batch_end]
            # Create interventions for the entire batch
            interventions_list = [
                {
                    variable: (lambda x, val=val: val) 
                    for variable, val in zip(self.design_variables, 
                                                    values.tolist())
                }
                for values in batch_values
            ]
            # Compute interventional samples for the batch
            batch_samples = [
                gcm.interventional_samples(
                    self.causal_model, interventions, 
                    num_samples_to_draw=samples_per_intervention
                    )
                for interventions in interventions_list
            ]
            # Store mean results for each sample in the batch
            for i, samples in enumerate(batch_samples):
                y_vals = samples[self.outcome_variables]
                causal_means_tensor[batch_start + i] = torch.tensor(
                    y_vals.mean().values, device=self.device
                )
                causal_stds_tensor[batch_start + i] = torch.tensor(
                    y_vals.std().values, device=self.device
                )      
            progress.set_postfix(interventions=f"{batch_end}/{num_samples}")                           
        return causal_means_tensor, causal_stds_tensor


    def causal_interventions_parallel(
        self,
        input: torch.Tensor,
        samples_per_intervention: int = 1000,
        batch_processing: int = 30,
        n_jobs: int = 4,
        backend: str = "loky",
        verbose: bool = True,
    ):
        r"""
        Computes the interventional mean and variance of the
        target variables using batched parallel execution.

        This function estimates:
        - `E[Y | do(X = x)]`: the interventional mean
        - `sqrt(V[Y | do(X = x)])`: the interventional variance
          deviation

        The interventional distribution is computed using
        DoWhy-GCM’s `gcm.interventional_samples`.        

        Args:
            input (Tensor): Input tensor of shape `N x d`.
            samples_per_intervention (int): Number of samples per intervention.
            batch_processing (int): Number of inputs to process per parallel job.
                A ratio of batch_size/input.shape[0] = 0.05 is recommended.
            n_jobs (int): Number of parallel jobs to run.
                -1 means using all processors.
            backend (str): Backend for parallel processing.
                Options: "threading", "loky".
                For Linux, "loky" is always recommended.
                For input.shape[0] <= 250, "threading" is faster in windows.
            verbose (bool): Show simple progress prints.

        Returns:
            Tuple:
                - Y: Tensor of shape `N x m` with interventional means.
                - Y_std: Tensor of shape `N x m` with interventional std deviations.
        """
        if input.shape[1] != len(self.design_variables):
            raise ValueError(
                f"Received {input.shape[1]}D input, "
                f"but expected {len(self.design_variables)}D "
                f"based on 'design_variables={self.design_variables}'. "
                f"Ensure columns match in order and count."
            )

        if samples_per_intervention <= 0:
            raise ValueError(
                "`samples_per_intervention` must be a positive integer. "
                f"Received {samples_per_intervention}."
            )
        
        if samples_per_intervention < 1000:
            warnings.warn(
                "Using less than 1000 samples per intervention may lead to "
                "inaccurate estimates of the interventional mean and variance."
            )

        num_samples = input.shape[0]
        num_objectives = len(self.outcome_variables)

        causal_means_tensor = torch.zeros(
            (num_samples, num_objectives), 
            device=self.device,
            dtype=self.dtype
        )
        causal_stds_tensor = torch.zeros(
            (num_samples, num_objectives), 
            device=self.device,
            dtype=self.dtype
        )

        # Prepare batches
        batches = [
            list(enumerate(input))[i : i + batch_processing]
            for i in range(0, num_samples, batch_processing)
        ]

        # Parallel execution over batches       
        all_results = []
        processed_inputs = 0

        with tqdm(
            total=len(batches),
            desc="Performing causal interventions",
            unit="batch",
            disable=not verbose
        ) as pbar:        

            parallel = Parallel(
                n_jobs=n_jobs,
                backend=backend,
                return_as='generator'
            )

            for batch_results in parallel(
                delayed(self._process_causal_intervention_batch)(
                    batch_idx, 
                    batch,
                    self.design_variables,
                    self.outcome_variables,
                    self.causal_model,
                    self.device,
                    samples_per_intervention
                )
                for batch_idx, batch in enumerate(batches)
            ):
                # Collect results
                all_results.append(batch_results)

                # Update processed inputs live
                processed_inputs += len(batch_results)
                pbar.set_postfix({
                    "interventions": f"{processed_inputs}/{input.shape[0]}"
                })
                pbar.update(1)

        # Flatten results
        for batch_results in all_results:
            for idx, mean, std in batch_results:
                causal_means_tensor[idx] = mean
                causal_stds_tensor[idx] = std
        return causal_means_tensor, causal_stds_tensor
    

    @staticmethod
    def _process_causal_intervention_batch(
        batch_idx: int, 
        batch: List[Tuple[int, Tensor]],
        design_variables: List[str],
        outcome_variables: List[str],
        causal_model: StructuralCausalModel,
        device: torch.device,
        samples_per_intervention: int
    ):
        r""" 
        Process a single batch of interventional samples.
        This function is called in parallel for each batch.

        Args:
            batch_idx (int): Index of the current batch.
            batch (List[Tuple[int, Tensor]]): List of tuples
                containing indices and input values for the batch.
            design_variables (List[str]): List of design variable names.
            outcome_variables (List[str]): List of objective variable names.
            causal_model (StructuralCausalModel): Fitted causal model.
            device (torch.device): Device for tensor operations.
            samples_per_intervention (int): Number of samples per intervention.
        """
        batch_len = len(batch)
        results = [None] * batch_len  # Pre-allocate list

        for i, (idx, values) in enumerate(batch):
            interventions = {
                variable: (lambda x, val=val: val)
                for variable, val in zip(design_variables, values.tolist())
            }
            samples = gcm.interventional_samples(
                causal_model,
                interventions,
                num_samples_to_draw=samples_per_intervention,
            )
            y_vals = samples[outcome_variables]
            mean = torch.tensor(y_vals.mean().values, device=device)
            std = torch.tensor(y_vals.std().values, device=device)
            results[i] = (idx, mean, std)          
        return results
    

    def causal_observations(
        self,
        causal_observational_samples: int = 10000,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        Generates observational samples from a fitted causal
        model and returns input-output tensors.

        This function does not perform any interventions
        (i.e., no `do` operator), and is therefore much faster
        than `causal_interventions`. The mean and variance are
        computed from the observational distribution:
        - Mean: `E[Y | X = x]`
        - Variance: `sqrt(V[Y | X = x])`

        Args:
            causal_observational_samples (int): Number of
                samples to draw from the causal model.

        Returns:
            Tuple:
                - `X`: Tensor of shape `N x d` representing the
                  design inputs.
                - `Y_mean`: Tensor of shape `N x m` representing
                  the objective means.
                - `Y_std`: Tensor of shape `N x m` representing
                  the objective standard deviations.
        """
        # TODO: No validation check on design variables dim 
        if causal_observational_samples <= 0:
            raise ValueError(
                "`causal_observational_samples` must be a "
                "positive integer. "
                f"Received {causal_observational_samples}."
            )        
        # Sample from causal model
        df_samples = gcm.draw_samples(causal_model=self.causal_model, 
                                      num_samples=causal_observational_samples)
        # Ensure correct input/output ordering
        try:
            X = df_samples[self.design_variables].values
            Y = df_samples[self.outcome_variables].values
        except KeyError as e:
            raise ValueError(f"Missing expected column in causal samples: {e}")
        # Convert to torch tensors
        X_tensor = torch.tensor(X, device=self.device, dtype=self.dtype)
        # Conditional mean, E[Y | X]
        Y_tensor = torch.tensor(Y, device=self.device, dtype=self.dtype)
        # Compute per-output std
        # These are marginal variance, not conditioned on X
        # Meaning sqrt[ V[Y] ], not sqrt[ V[Y | X] ]
        # TODO: Make it granular (conditional) with the cost of computation
        #       Instead just use the `causal_interventions`
        Y_std = Y_tensor.std(dim=0, keepdim=True).repeat(Y_tensor.size(0), 1)         
        return X_tensor, Y_tensor, Y_std