from __future__ import annotations

from typing import List

import torch
from torch import Tensor, nn
from dowhy.gcm.auto import AssignmentQuality
from dowhy import gcm
import pandas as pd
import networkx as nx
from tqdm import trange

from rescue.models.causal_model.causal_inference import CausalInference

class CausalMeanVarSurrogateNN(nn.Module):
    """
    Differentiable surrogate with a shared trunk and two linear heads.

    The network is meant to be fitted on statistics sampled from a
    structural causal model (SCM), i.e. on the interventional mean
    `E[Y | do(X = x)]` and the interventional standard deviation
    `sqrt(V[Y | do(X = x)])`, so that causal predictions become available
    as a PyTorch module that retains gradients.

    Architecture:
        - Shared trunk: `Linear(d, 64) -> ReLU -> Linear(64, 64) -> ReLU`
        - Mean head: `Linear(64, m)`
        - Spread head (`var_head`): `Linear(64, m)`

    NOTE: Both heads return the raw output of a linear layer. No
    activation is applied to the spread head, so its output is not
    constrained to be positive.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        """
        Build the shared trunk and the two output heads.

        Args:
            input_dim (int): Size `d` of the last dimension of the input.
            output_dim (int): Number `m` of outputs produced by each head.
        """
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim

        hidden_width = 64
        # Layers are created in the order trunk -> mean head -> spread head.
        trunk_layers = [
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk_layers)
        self.mean_head = nn.Linear(hidden_width, output_dim)
        self.var_head = nn.Linear(hidden_width, output_dim)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """
        Run the shared trunk and evaluate both heads.

        Args:
            x (Tensor): Input whose last dimension equals `input_dim`,
                e.g. `N x d` or `B x N x d`.

        Returns:
            tuple[Tensor, Tensor]: `(mean_head output, var_head output)`;
            each has the shape of `x` with the last dimension replaced
            by `output_dim`.

        Raises:
            ValueError: If the last dimension of `x` is not `input_dim`.
        """
        last_dim = x.size(-1)
        if last_dim != self.input_dim:
            raise ValueError(
                f"Expected last dim {self.input_dim}, got {last_dim}"
            )
        features = self.shared(x)
        return self.mean_head(features), self.var_head(features)


class CausalNeuralNet:
    """
    Fit a structural causal model (SCM) and train a neural surrogate on it.

    `__init__` builds and fits a `dowhy.gcm.StructuralCausalModel` from a
    causal graph and tabular data. `train_causal_net` then samples
    interventional statistics from that SCM and fits a
    `CausalMeanVarSurrogateNN` to them.
    """

    def __init__(
        self,
        causal_graph: nx.DiGraph,
        design_variables: List[str], 
        outcome_variables: List[str],
        train_x: Tensor | None = None,
        train_y: Tensor | None = None,
        observational_data: None | pd.DataFrame = None,
        perf_summary_causal_model: bool = False,
        causal_mechanism_quality: AssignmentQuality = AssignmentQuality.BETTER,
        verbose: bool = True,
        ) -> None:

        """
        Assign causal mechanisms and fit the causal model with the data.
        The fitted causal model is then used to generate data from the
        `interventional` distribution of the causal data generating
        mechanism. The generated data is used to train a causal neural
        network to predict the mean and standard deviation.

        When `observational_data` is not provided the online training datapoints
        will be used. This can severely degrade the optimization performance. 
        However, performance will improve as we actively collect samples
        from the Bayesian Optimization loop.   

        `causal_mechanism_quality`

        With "GOOD" quality:
            Numerical:
            - Linear Regressor
            - Linear Regressor with polynomial features
            - Histogram Gradient Boost Regressor

            Categorical:
            - Logistic Regressor
            - Logistic Regressor with polynomial features
            - Histogram Gradient Boost Classifier

        With "BETTER" quality:
            Numerical:
            - Linear Regressor
            - Linear Regressor with polynomial features
            - Gradient Boost Regressor
            - Ridge Regressor
            - Lasso Regressor
            - Random Forest Regressor
            - Support Vector Regressor
            - Extra Trees Regressor
            - KNN Regressor
            - Ada Boost Regressor

            Categorical:
            - Logistic Regressor
            - Logistic Regressor with polynomial features
            - Histogram Gradient Boost Classifier
            - Random Forest Classifier
            - Extra Trees Classifier
            - Support Vector Classifier
            - KNN Classifier
            - Gaussian Naive Bayes Classifier
            - Ada Boost Classifier                 

        Args:
            causal_graph (nx.DiGraph): The learned Directed Acyclic Graph (DAG).
                Must be exactly of type `nx.DiGraph`.
            design_variables (List[str]): Names of input variables. **Order matters.**
                NOTE: Must be same order as pandas.column data to fit the causal model.
            outcome_variables (List[str]): Names of output (objective) variables. 
                **Order matters.**
                NOTE: Must be same order as pandas.column data to fit the causal model.    
            train_x (Tensor): A `N x d`-dim tensor of design variables.
            train_y (Tensor): A `N x m`-dim tensor of objective variables.
            observational_data (pd.DataFrame): This could be data that was used to learn the
                DAG and/or offline observations. If not provided, the online training datapoints
                will be used. If provided together with both `train_x` and `train_y`, 
                the online datapoints are appended to it.
            perf_summary_causal_model (bool): Forwarded to `gcm.fit` as its third positional
                argument. If True, `gcm.fit` is expected to return a summary of the 
                performances of the fitted mechanisms; whatever `gcm.fit` returns is stored 
                in `self.causal_model_summary`.
                NOTE: Check `dowhy` doc.

            causal_mechanism_quality (AssignmentQuality): AssignmentQuality for automatic model 
                selection and accuracy. Controls the type of model used and time spent on selection.
                NOTE: Check `dowhy` doc.

                - AssignmentQuality.GOOD  
                    Evaluates a small set of models.  
                    Model selection speed: Fast  
                    Model training speed: Fast  
                    Model inference speed: Fast  
                    Model accuracy: Medium

                - AssignmentQuality.BETTER  
                    Evaluates a larger set of models.  
                    Model selection speed: Medium  
                    Model training speed: Fast  
                    Model inference speed: Fast  
                    Model accuracy: Good

                - AssignmentQuality.BEST  
                    Uses AutoGluon (AutoML) with default settings.  
                    Model selection speed: Instant  
                    Model training speed: Slow  
                    Model inference speed: Slow–Medium  
                    Model accuracy: Best  
                    Requires the optional `autogluon.tabular` dependency.

            verbose (bool): Whether to show progress bars.

        Raises:
            ValueError: If neither `observational_data` nor both `train_x` and
                `train_y` are given, or if the last dimension of `train_x` /
                `train_y` does not match the number of variable names.
            TypeError: If `causal_graph` is not exactly a `nx.DiGraph`.
         """       

        self.design_variables = design_variables
        self.outcome_variables = outcome_variables
        self.verbose = verbose

        has_online_data = train_x is not None and train_y is not None
        if observational_data is None and not has_online_data:
            raise ValueError(
                "Either `observational_data` or both `train_x` and `train_y` "
                "must be provided."
            )
        if has_online_data:
            if train_x.shape[-1] != len(self.design_variables):
                raise ValueError(
                    f"`train_x` has {train_x.shape[-1]}D input, "
                    f"but expected  {len(self.design_variables)}D "
                    f"based on 'design_variables={self.design_variables}'. "
                    f"Ensure `train_x` data match the order and input dim."
                )          
            if train_y.shape[-1] != len(self.outcome_variables):
                raise ValueError(
                    f"`train_y` has {train_y.shape[-1]}D objective, "
                    f"but expected  {len(self.outcome_variables)}D "
                    f"based on 'outcome_variables={self.outcome_variables}'. "
                    f"Ensure `train_y` data match the order and objective dim."
                )      
        # Exact type comparison: subclasses of nx.DiGraph (for example
        # nx.MultiDiGraph) are rejected as well.
        if type(causal_graph) is not nx.DiGraph:
            raise TypeError(
                f"Expected a networkx.DiGraph, but got {type(causal_graph)}."
            )              

        # Pre-processing the data for causal model fitting
        data = self._build_fit_data(train_x, train_y, observational_data)

        # Causal model fitting from DAG and data.
        # NOTE: this sets a dowhy-wide configuration flag, not a per-instance one.
        gcm.config.show_progress_bars = self.verbose
        self.scm = gcm.StructuralCausalModel(causal_graph)
        # Assign causal mechanisms
        gcm.auto.assign_causal_mechanisms(self.scm, data,
                                          causal_mechanism_quality)
        # Keep whatever `gcm.fit` returns so that a requested performance
        # summary is not silently discarded.
        self.causal_model_summary = gcm.fit(
            self.scm, data, perf_summary_causal_model
        )

    def _online_frame(self, train_x: Tensor, train_y: Tensor) -> pd.DataFrame:
        """Return `{(train_x, train_y)}` as one DataFrame with named columns."""
        # train_x column names: design_variables
        df_x = pd.DataFrame(train_x.detach().cpu().numpy(),
                            columns=self.design_variables)
        # train_y column names: objective variables
        df_y = pd.DataFrame(train_y.detach().cpu().numpy(),
                            columns=self.outcome_variables)
        return pd.concat([df_x, df_y], axis=1)

    def _build_fit_data(
        self,
        train_x: Tensor | None,
        train_y: Tensor | None,
        observational_data: pd.DataFrame | None,
    ) -> pd.DataFrame:
        """Assemble the DataFrame the causal model is fitted on."""
        has_x = train_x is not None
        has_y = train_y is not None
        if has_x and has_y:
            online = self._online_frame(train_x, train_y)
            if observational_data is None:
                # data = {(train_x, train_y)}
                return online
            # data = {observational_data} ∪ {(train_x, train_y)}
            return pd.concat([observational_data, online], ignore_index=True)
        if has_x or has_y:
            # A lone tensor cannot form (x, y) rows. The caller has already
            # verified that `observational_data` exists in this case.
            import warnings

            warnings.warn(
                "Only one of `train_x` / `train_y` was provided; it is "
                "ignored and the causal model is fitted on "
                "`observational_data` only.",
                stacklevel=3,
            )
        return observational_data

    def train_causal_net(
        self,   
        model: CausalMeanVarSurrogateNN,
        device: torch.device,
        x_intervention_val: Tensor,
        dtype: torch.dtype = torch.double,
        batch_processing: int = 10,
        num_intervention: int = 200,
        samples_per_intervention: int = 1000,
        epochs: int = 100,
        n_jobs: int = 4,
        backend: str = "loky",        
    ) -> tuple[nn.Module, float]:
        r"""
        Trains a causal surrogate neural network to predict the
        mean and standard deviation of each objective using samples
        generated from the fitted causal model.

        - The network learns from interventional samples. 
          The targets are the two tensors returned by
          `CausalInference.causal_interventions_parallel`, used here as
          the interventional mean `E[Y | do(X=x)]` and standard
          deviation `sqrt(V[Y | do(X=x)])`.
        - Training is full-batch: every epoch performs one Adam step
          (lr=1e-3) on the sum of the MSE losses of both heads.

        Args:
            model (nn.Module): A `CausalMeanVarSurrogateNN` that
                predicts both mean and standard deviation. It is moved
                in place to `device` / `dtype` before training.
            device (torch.device): Device used for the inputs, the
                sampler and the model.
            x_intervention_val (Tensor): Interventional input samples.
                Required; its last dimension must equal
                `len(design_variables)`.
            dtype (torch.dtype): Data type used for the inputs, the
                sampler and the model.
            batch_processing (int): Number of inputs to process per parallel job.
                A ratio of batch_size/input.shape[0] = 0.05 is recommended.
            num_intervention (int): Only validated to be positive;
                otherwise unused by this method.
            samples_per_intervention (int): Number of samples
                per intervention. Must be at least 10.
            epochs (int): Number of training epochs for the
                neural network.
            n_jobs (int): Number of parallel jobs to run.
                -1 means using all processors.
            backend (str): Backend for parallel processing.
                Options: "threading", "loky".
                For Linux, "loky" is always recommended.
                For input.shape[0] <= 250, "threading" is faster in windows.            

        Returns:
            tuple[nn.Module, float]: The trained model (in eval mode) and
                the loss of the last epoch.

        Raises:
            ValueError: On invalid arguments, a missing
                `x_intervention_val`, or inconsistent tensor shapes.
        """

        # Future me: num_intervention is not needed
        if num_intervention <= 0:
            raise ValueError("`num_intervention` must be a positive integer.")
        if samples_per_intervention < 10:
            raise ValueError("`samples_per_intervention` must be >= 10.")                          
        if epochs <= 0:
            raise ValueError("`epochs` must be a positive integer")
        if x_intervention_val is None:
            raise ValueError(
                "`x_intervention_val` must be provided; automatic sampling "
                "of interventional inputs is not implemented."
            )

        x_intervention_val = x_intervention_val.to(dtype=dtype, device=device)
        # Validate the input dimension before the (expensive) sampling step.
        if x_intervention_val.shape[-1] != len(self.design_variables):
            raise ValueError(
            f"`x_intervention_val` has last dimension {x_intervention_val.shape[-1]}D, "
            f"but expected {len(self.design_variables)}D "
            f"based on `design_variables={self.design_variables}`. "
            f"Possibly a bug in rescue."
            )

        causal_sampler = CausalInference(
                            causal_model=self.scm,
                            design_variables=self.design_variables,
                            outcome_variables=self.outcome_variables,
                            device=device,
                            dtype=dtype
        )             
        y_intervention_outcome, y_std = causal_sampler.causal_interventions_parallel(
                            input=x_intervention_val,
                            samples_per_intervention=samples_per_intervention,
                            batch_processing=batch_processing,
                            backend=backend,
                            n_jobs=n_jobs,
                            verbose=self.verbose
                        )      
        # Check dimensions of the sampled targets
        if y_intervention_outcome.shape[-1] != len(self.outcome_variables):
            raise ValueError(
                f"`y_intervention_outcome` has last dimension {y_intervention_outcome.shape[-1]}D, "
                f"but expected {len(self.outcome_variables)}D "
                f"based on `outcome_variables={self.outcome_variables}`. "
                f"Possibly a bug in rescue."
            )       
        # Ensure consistent sample dimensions
        sample_dims = x_intervention_val.shape[:-1]
        if sample_dims != y_intervention_outcome.shape[:-1] or \
           sample_dims != y_std.shape[:-1]:
            raise ValueError(
            f"Shape mismatch: expected all inputs to have the same sample dimensions, but got:\n"
            f"- x_intervention_val: {x_intervention_val.shape}\n"
            f"- y_intervention_outcome: {y_intervention_outcome.shape}\n"
            f"- y_std: {y_std.shape}"
            )

        # The model must live on the same device / dtype as its inputs;
        # this is a no-op when it already does. Done before the optimizer
        # is created so the optimizer tracks the final parameters.
        model = model.to(device=device, dtype=dtype)

        # Model training
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        loss_fn = nn.MSELoss()
        progress = trange(
            epochs, 
            desc=f"Training {model.__class__.__name__}", 
            unit="epoch",
            disable=not self.verbose
        )
        model.train()
        for _ in progress:
            # Clear gradients first so that any gradient already stored on
            # the parameters cannot leak into the first update.
            optimizer.zero_grad()
            # Forward pass: given x_intervention_val, predict both the
            # interventional mean and its standard deviation.
            pred_mean, pred_std = model(x_intervention_val)
            # MSE loss for both heads
            loss_mean = loss_fn(pred_mean, y_intervention_outcome)
            loss_std = loss_fn(pred_std, y_std)
            loss = loss_mean + loss_std
            loss.backward()
            optimizer.step()
            # Update progress bar
            progress.set_postfix(loss=loss.item())
        # Only eval mode is set here; parameters are not frozen
        # (`requires_grad` is left untouched).
        model.eval()
        return model, loss.item()      