#!/usr/bin/env python3

from __future__ import annotations

from typing import List, Tuple

import torch
from torch import Tensor

import pandas as pd
import networkx as nx
from tqdm import trange

from dowhy.gcm import StructuralCausalModel
from dowhy.gcm.auto import AssignmentQuality

from gpytorch.mlls import ExactMarginalLogLikelihood

from rescue.models.causal_model.map_to_NN import (
    CausalNeuralNet
)

def fit_causal_model(
    model: torch.nn.Module,
    causal_graph: nx.DiGraph,
    design_variables: List[str],
    outcome_variables: List[str],
    device: torch.device,
    x_intervention_val: Tensor,
    dtype: torch.dtype = torch.double,
    observational_data: None | pd.DataFrame = None,
    train_x: Tensor | None = None,
    train_y: Tensor | None = None,
    batch_processing: int = 10,
    n_jobs: int = 4,
    backend: str = "loky",
    num_intervention: int = 200,
    samples_per_intervention: int = 1000,
    epochs: int = 100,
    perf_summary_causal_model: bool = False,
    causal_mechanism_quality: AssignmentQuality = AssignmentQuality.BETTER,
    verbose: bool = True
) -> Tuple[torch.nn.Module, StructuralCausalModel, float]:
    r"""Fit a structural causal model and train a neural surrogate from it.

    Two steps are delegated to ``CausalNeuralNet``:

    1. Constructing ``CausalNeuralNet`` assigns causal mechanisms to
       ``causal_graph`` and fits them to the data, producing an SCM
       (exposed afterwards as ``causal_nn.scm``).
    2. ``CausalNeuralNet.train_causal_net`` uses the fitted causal model
       to generate data from the causal data-generating mechanism and
       trains ``model`` on it.

    The network is intended to learn from interventional samples. Its
    targets are the interventional mean ``E[Y | do(X=x)]`` and a spread
    term written as ``sqrt(V[Y | do(X=x)])``. That expression is a
    standard deviation, not a variance. NOTE(review): confirm which of
    the two ``CausalNeuralNet`` actually trains on.

    When ``observational_data`` is not provided, the online training
    points ``train_x`` / ``train_y`` are used to fit the causal model.
    This can severely degrade optimization performance early on.
    Performance improves as more samples are collected from the
    Bayesian Optimization loop.

    Args:
        model (torch.nn.Module): Network passed to ``train_causal_net``
            as ``model``. Presumably this is the surrogate that gets
            trained; confirm in ``CausalNeuralNet``.
        causal_graph (nx.DiGraph): The learned Directed Acyclic Graph
            (DAG).
        design_variables (List[str]): Names of input variables.
            **Order matters.** Must match the pandas column order of
            the data used to fit the causal model.
        outcome_variables (List[str]): Names of output (objective)
            variables. **Order matters.** Must match the pandas column
            order of the data used to fit the causal model.
        device (torch.device): Device forwarded to ``train_causal_net``.
        x_intervention_val (Tensor): Forwarded to ``train_causal_net``.
            Presumably these are the intervention values used there;
            confirm the expected shape in ``CausalNeuralNet``.
        dtype (torch.dtype): Tensor dtype forwarded to
            ``train_causal_net``. Defaults to ``torch.double``.
        observational_data (pd.DataFrame | None): Data used to learn
            the DAG and/or offline observations. If ``None``, the online
            training points are used instead.
        train_x (Tensor | None): A ``N x d``-dim tensor of design
            variables.
        train_y (Tensor | None): A ``N x m``-dim tensor of objective
            variables.
        batch_processing (int): Batch size for computing interventions.
        n_jobs (int): Number of parallel jobs for computing
            interventions.
        backend (str): Parallel backend: "loky", "multiprocessing", or
            "threading".
        num_intervention (int): Number of interventional input samples
            to generate (Sobol sampling, per the original design notes).
        samples_per_intervention (int): Number of samples drawn per
            intervention.
        epochs (int): Number of training epochs for the neural network.
        perf_summary_causal_model (bool): Forwarded to
            ``CausalNeuralNet``. Presumably toggles a performance
            summary of the fitted causal model.
        causal_mechanism_quality (AssignmentQuality): dowhy quality
            level used when automatically assigning causal mechanisms.
            Defaults to ``AssignmentQuality.BETTER``.
        verbose (bool): Forwarded to ``CausalNeuralNet``. Defaults to
            ``True``.

    Returns:
        Tuple[torch.nn.Module, StructuralCausalModel, float]:
            ``(causal_net, scm, loss)``, where

            - ``causal_net`` is the trained network returned by
              ``train_causal_net``;
            - ``scm`` is the fitted Structural Causal Model
              (``causal_nn.scm``);
            - ``loss`` is the training loss reported by
              ``train_causal_net``.
    """
    # Step 1: assign and fit the causal mechanisms on the available data.
    causal_nn = CausalNeuralNet(
        causal_graph=causal_graph,
        train_x=train_x,
        train_y=train_y,
        design_variables=design_variables,
        outcome_variables=outcome_variables,
        observational_data=observational_data,
        perf_summary_causal_model=perf_summary_causal_model,
        causal_mechanism_quality=causal_mechanism_quality,
        verbose=verbose,
    )
    # Step 2: generate data from the fitted causal model and train the network.
    causal_net, loss = causal_nn.train_causal_net(
        model=model,
        device=device,
        dtype=dtype,
        x_intervention_val=x_intervention_val,
        batch_processing=batch_processing,
        n_jobs=n_jobs,
        backend=backend,
        num_intervention=num_intervention,
        samples_per_intervention=samples_per_intervention,
        epochs=epochs,
    )
    return causal_net, causal_nn.scm, loss


def fit_causal_gp(
    mll: ExactMarginalLogLikelihood,
    epochs: int = 500,
    verbose: bool = True,
    lr: float = 0.1,
) -> ExactMarginalLogLikelihood:
    """Fit GP hyperparameters by maximizing the exact marginal log likelihood.

    Runs ``epochs`` full-batch Adam steps on the negative marginal log
    likelihood of ``mll.model``, evaluated on the model's own training
    data. Afterwards the model and its likelihood are switched to
    evaluation mode. The model is updated in place.

    Args:
        mll (ExactMarginalLogLikelihood): The marginal log likelihood.
            Its ``model`` (and ``model.likelihood``) are optimized.
        epochs (int, optional): Number of optimizer steps.
            Defaults to 500.
        verbose (bool, optional): Show a tqdm progress bar with the
            current loss. Defaults to True.
        lr (float, optional): Adam learning rate. Defaults to 0.1.

    Returns:
        ExactMarginalLogLikelihood: The same ``mll`` object. Its model
        and likelihood are fitted and in eval mode.
    """
    model = mll.model
    model.train()
    model.likelihood.train()
    # Only model.parameters() is passed to the optimizer. This relies on
    # the likelihood being registered as a submodule of the model (as in
    # gpytorch's ExactGP), so the noise parameters are optimized too.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    progress = trange(
        epochs,
        desc=f"Fitting {model.__class__.__name__}",
        unit="epoch",
        disable=not verbose,
    )
    # "Loss" for GPs is the negative marginal log likelihood.
    for _ in progress:
        optimizer.zero_grad()
        # Forward every training input, not just the first one.
        output = model(*model.train_inputs)
        # For batched / multi-output models the MLL has one value per
        # batch; summing yields a scalar (identity for scalar MLLs).
        loss = -mll(output, model.train_targets).sum()
        loss.backward()
        optimizer.step()
        progress.set_postfix(loss=loss.item())
    # Use .eval() rather than assigning `.training = False`. eval() recurses
    # into all submodules (kernels, means, noise) and runs any train()/eval()
    # override defined by the model class.
    model.eval()
    model.likelihood.eval()
    return mll