import torch
from typing import Optional, Tuple
from dataclasses import dataclass
import torch.nn as nn
from torch.utils.data import Dataset

@dataclass
class PICalibData:
    """Container for the trajectories available for calibration.

    Only ``X`` and ``Y`` are required; every other field is optional and
    defaults to ``None``. Shapes and exact meaning of the fields are not
    fixed by this class -- see the hedged notes below and confirm against
    the code that builds these objects.
    """

    # Required tensors (presumably model inputs and targets -- TODO confirm).
    X: torch.Tensor
    Y: torch.Tensor

    # Optional per-trajectory extras; names suggest errors and time stamps
    # -- TODO confirm against the producer.
    error: Optional[torch.Tensor] = None
    timesteps: Optional[torch.Tensor] = None

    # Optional context tensors (presumably context inputs and some
    # coefficient representation of them -- TODO confirm).
    X_ctx: Optional[torch.Tensor] = None
    X_ctx_coeffs: Optional[torch.Tensor] = None

    # Optional bookkeeping (presumably trajectory lengths and a boolean
    # validity mask -- TODO confirm).
    Traj_len: Optional[torch.Tensor] = None
    mask: Optional[torch.BoolTensor] = None

class TrainLoader(Dataset):
    """Dataset wrapper for sampled transitions.

    Item ``i`` is the tuple ``(X_ctx_true[i], X_ctx_sim[i], errors[i])``.
    The dataset length is the size of the first dimension of
    ``X_ctx_true``; the other tensors are not checked against it.
    """

    def __init__(self, X_ctx_true: torch.Tensor, X_ctx_sim: torch.Tensor,
                 errors: torch.Tensor):
        # The tensors are stored by reference, without copying.
        self._X_ctx_true, self._X_ctx_sim, self._errors = X_ctx_true, X_ctx_sim, errors

    def __len__(self):
        """Return the number of samples (first dimension of ``X_ctx_true``)."""
        leading_dim = self._X_ctx_true.shape[0]
        return leading_dim

    def __getitem__(self, index):
        """Return ``(X_ctx_true[index], X_ctx_sim[index], errors[index])``."""
        columns = (self._X_ctx_true, self._X_ctx_sim, self._errors)
        return tuple(column[index] for column in columns)
    
class ValLoader(Dataset):
    """Dataset wrapper for sampled transitions, with targets and predictions.

    Item ``i`` is the tuple
    ``(X_ctx_true[i], X_ctx_sim[i], errors[i], Y[i], Y_pred[i])``.
    The dataset length is the size of the first dimension of
    ``X_ctx_true``; the other tensors are not checked against it.
    """

    def __init__(self, X_ctx_true: torch.Tensor, X_ctx_sim: torch.Tensor,
                 errors: torch.Tensor, Y: torch.Tensor,
                 Y_pred: torch.Tensor):
        # The tensors are stored by reference, without copying.
        self._X_ctx_true, self._X_ctx_sim = X_ctx_true, X_ctx_sim
        self._errors = errors
        self._Y, self._Y_pred = Y, Y_pred

    def __len__(self):
        """Return the number of samples (first dimension of ``X_ctx_true``)."""
        leading_dim = self._X_ctx_true.shape[0]
        return leading_dim

    def __getitem__(self, index):
        """Return the ``index``-th entry of all five stored tensors, in order."""
        columns = (
            self._X_ctx_true,
            self._X_ctx_sim,
            self._errors,
            self._Y,
            self._Y_pred,
        )
        return tuple(column[index] for column in columns)

class FcModel(nn.Module):
    """Fully connected network of ``nn.Linear`` layers with ReLU activations.

    The layers are assembled into ``self.linear_stack`` in this order
    (bracketed entries are optional)::

        [Dropout] -> Linear(input_dim, hidden[0])
                  -> (ReLU -> [Dropout] -> Linear) for each hidden width
                  -> [Tanh] -> [Dropout]

    With no hidden widths the stack is a single
    ``Linear(input_dim, out_dim)``, still wrapped by the optional leading
    dropout, tanh and trailing dropout.

    Args:
        input_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        hidden: Widths of the hidden layers; empty (or ``None``) means no
            hidden layer.
        dropout: Dropout probability. ``None`` or ``0`` disables every
            dropout layer regardless of the flags below.
        dropout_at_first: Put a dropout layer before the first linear layer.
        dropout_after_last: Put a dropout layer at the very end of the stack.
        dropout_intermediate: Put a dropout layer after every ReLU.
        tanh_after_last: Append a ``Tanh`` after the last linear layer.
    """

    def __init__(self, input_dim, out_dim, hidden: Tuple = (), dropout: float = 0, dropout_at_first=False,
                 dropout_after_last=False, dropout_intermediate=False, tanh_after_last=False) -> None:
        super().__init__()
        self._out_dim = out_dim

        # ``None`` is accepted for both optional settings and means "off".
        if dropout is None:
            dropout = 0
        hidden = () if hidden is None else tuple(hidden)
        use_dropout = dropout > 0

        if hidden:
            # Each hidden width feeds the next hidden width; the last one
            # feeds the output dimension.
            fan_outs = (*hidden[1:], self._out_dim)
            tail = []
            for fan_in, fan_out in zip(hidden, fan_outs):
                tail.append(nn.ReLU())
                if use_dropout and dropout_intermediate:
                    tail.append(nn.Dropout(p=dropout))
                tail.append(nn.Linear(fan_in, fan_out))
            # The input layer is created after the tail on purpose: this
            # keeps the order in which the layers draw their default
            # random initialisation unchanged.
            stack = [nn.Linear(input_dim, hidden[0]), *tail]
        else:
            stack = [nn.Linear(input_dim, self._out_dim)]

        if use_dropout and dropout_at_first:
            stack.insert(0, nn.Dropout(p=dropout))
        if tanh_after_last:
            # Last tanh is important for best performance with Neural CDE.
            stack.append(nn.Tanh())
        if use_dropout and dropout_after_last:
            stack.append(nn.Dropout(p=dropout))
        self.linear_stack = nn.Sequential(*stack)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise every linear layer: Kaiming-uniform weights, zero biases."""
        for layer in self.linear_stack:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight)
                nn.init.constant_(layer.bias, 0)  # Set biases to 0 for stability

    @property
    def output_dim(self):
        """Output dimension passed to the constructor as ``out_dim``."""
        return self._out_dim

    def forward(self, context: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
        """Run ``context`` through the layer stack.

        Args:
            context: Input tensor whose last dimension is ``input_dim``.

        Returns:
            The output of ``self.linear_stack(context)``. When ``context``
            is omitted, ``None`` is returned; this keeps the earlier
            zero-argument call working.
        """
        if context is None:
            return None
        return self.linear_stack(context)