import numpy as np
import pandas as pd

import torch
import torch.nn as nn

from einops import rearrange

def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if exists(val):
        return val
    return d

def modulate(x, shift, scale):
    """Apply an affine modulation to `x`.

    The input is multiplied by `1 + scale` (so a scale of zero leaves it
    unchanged) and then offset by `shift`.
    """
    gain = 1 + scale
    return x * gain + shift

def apply_mask(x, mask=None):
    """Multiply `x` by `mask`; return `x` untouched when no mask is given."""
    if mask is not None:
        return x * mask
    return x

def build_mlp(input_dim: int,
              hidden_dim: int,
              output_dim: int,
              num_layers: int,
              dropout_rate: float = 0.0,
              activation: nn.Module = None,
              final_activation: nn.Module = None) -> nn.Sequential:
    """Assemble a feed-forward network as an ``nn.Sequential``.

    Layout by ``num_layers`` (the number of ``nn.Linear`` layers):

    * ``0``  -> a single ``nn.Identity`` (all dimensions are ignored),
    * ``1``  -> one ``Linear(input_dim, output_dim)`` (``hidden_dim`` unused),
    * ``>1`` -> ``num_layers - 1`` blocks of ``Linear -> activation -> Dropout``
      of width ``hidden_dim`` followed by ``Linear(hidden_dim, output_dim)``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of every hidden layer.
        output_dim: Size of the last dimension of the output.
        num_layers: Number of linear layers; must be non-negative.
        dropout_rate: Dropout probability applied after each hidden activation.
        activation: Module placed after every hidden linear layer. The very
            same instance is inserted at each position of the returned
            network. ``None`` (the default) creates a fresh ``nn.ReLU()`` for
            this call, so separately built networks never share a module
            object.
        final_activation: Optional module appended after the last layer.

    Returns:
        The assembled ``nn.Sequential``.

    Raises:
        ValueError: If ``num_layers`` is negative.
    """
    # A negative depth used to fall through to the multi-layer branch and
    # silently produce a two-layer network.
    if num_layers < 0:
        raise ValueError(f"num_layers must be non-negative, got {num_layers}")

    # Instantiate the default here rather than in the signature: a module
    # created in the signature is built once at import time and shared by
    # every network constructed with default arguments.
    if activation is None:
        activation = nn.ReLU()

    if num_layers == 0:
        layers = [nn.Identity()]
    elif num_layers == 1:
        layers = [nn.Linear(input_dim, output_dim)]
    else:
        # Input sizes of the hidden blocks: input_dim first, hidden_dim after.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1)
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Linear(in_features, out_features),
                activation,
                nn.Dropout(dropout_rate),
            ])
        layers.append(nn.Linear(hidden_dim, output_dim))

    if final_activation is not None:
        layers.append(final_activation)

    return nn.Sequential(*layers)

def positional_encoding_1d(length, d_model):
    """Build a sinusoidal positional-encoding table.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        length: Number of positions (rows); positions are ``0 .. length - 1``.
        d_model: Encoding width (columns). Odd widths are supported: the last
            sine column then has no cosine partner.

    Returns:
        A float32 tensor of shape ``(length, d_model)``.
    """
    position = torch.arange(length, dtype=torch.float32).unsqueeze(1)  # (length, 1)
    div_term = torch.exp(-torch.arange(0, d_model, 2, dtype=torch.float32) * (torch.log(torch.tensor(10000.0)) / d_model))
    angles = position * div_term  # (length, ceil(d_model / 2))

    pe = torch.zeros((length, d_model), dtype=torch.float32)
    pe[:, 0::2] = torch.sin(angles)
    # There are only d_model // 2 odd columns; for an odd d_model that is one
    # fewer than the number of frequencies, so trim before assigning.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    return pe  # (length, d_model)

def load_roi_coordinates(roi, normalize=False):
    """Load the coordinates of an ROI set from its CSV file.

    The file is ``./data2/roi/coordinates/{roi}_coordinates.csv`` when the
    ROI name contains ``"450"`` and ``./data/roi/coordinates/{roi}_coordinates.csv``
    otherwise (both relative to the current working directory). The columns
    ``R``, ``A`` and ``S`` are read, in that order.

    Args:
        roi: Name of the ROI set; used to build the file name.
        normalize: When true, rescale every column independently to
            ``[-1, 1]`` using that column's min and max. A column whose values
            are all identical maps to ``-1`` instead of producing NaN from a
            zero-width range.

    Returns:
        A float32 tensor of shape ``(num_rows, 3)``.
    """
    data_root = './data2' if "450" in roi else './data'
    coord_file = f'{data_root}/roi/coordinates/{roi}_coordinates.csv'

    coords = pd.read_csv(coord_file)
    # `.values` always yields a NumPy array, so no further type dispatch is needed.
    xyz_coordinates = torch.tensor(coords[['R', 'A', 'S']].values, dtype=torch.float32)

    if normalize:
        col_min = xyz_coordinates.min(dim=0).values
        col_max = xyz_coordinates.max(dim=0).values
        span = col_max - col_min
        # Guard against 0 / 0 on a constant axis: treat its range as 1.
        span = torch.where(span == 0, torch.ones_like(span), span)
        xyz_coordinates = (xyz_coordinates - col_min) / span * (1. - (-1.)) + (-1.)

    return xyz_coordinates

def get_roi_coordinates_positions(roi, hidden_dim, max_len=100):
    """Look up a sinusoidal encoding for each coordinate axis of every ROI.

    The un-normalized coordinates returned by `load_roi_coordinates` are cast
    to integers and used as row indices into a `(max_len, hidden_dim)` table
    from `positional_encoding_1d`. The three per-axis encodings are
    concatenated along the feature dimension.

    Returns:
        A tensor with one row per ROI and `3 * hidden_dim` columns.
    """
    table = positional_encoding_1d(max_len, hidden_dim)

    # NOTE(review): the cast truncates toward zero; a negative coordinate would
    # index from the end of the table and one >= max_len raises IndexError --
    # confirm the CSV coordinates lie in [0, max_len).
    axis_index = load_roi_coordinates(roi, normalize=False).long()

    per_axis = [table[axis_index[:, axis]] for axis in range(3)]
    return torch.cat(per_axis, dim=1)

# Value marking padded entries; get_attn_key_pad_mask compares key sequences against it.
PAD = 0

def get_attn_key_pad_mask(seq_k, seq_q):
    """Build a mask that flags padded key positions for every query position.

    `seq_k` is a 2-D `(batch, len_k)` tensor; entries equal to `PAD` are
    treated as padding. Only the length of dimension 1 of `seq_q` is used.

    Returns:
        A boolean `(batch, len_q, len_k)` tensor, True where the key is padding.
    """
    is_pad = seq_k == PAD
    n_query = seq_q.size(1)

    # Insert the query axis and broadcast the key mask across it: b x lq x lk
    return is_pad[:, None].expand(-1, n_query, -1)

def get_self_attention_mask(mask):
    """Turn a per-position validity mask into a pairwise blocking mask.

    `mask` is presumably a boolean `(B, T)` tensor that is True at valid
    positions. The result has shape `(B, T, T)` and is True wherever the
    row position or the column position (or both) is invalid, i.e. at the
    entries a caller would fill out before the softmax.
    """
    invalid_rows = ~mask[..., None]   # (B, T, 1)
    invalid_cols = ~mask[:, None]     # (B, 1, T)

    # De Morgan: not (row valid and col valid) == row invalid or col invalid
    return invalid_rows | invalid_cols

def get_cross_attention_mask(query_mask, key_mask):
    """Combine query and key validity masks into a pairwise blocking mask.

    `query_mask` and `key_mask` are presumably boolean `(B, T_q)` and
    `(B, T_k)` tensors that are True at valid positions. The result has shape
    `(B, T_q, T_k)` and is True wherever the query or the key (or both) is
    invalid, ready to be used with a masked fill.
    """
    invalid_query = ~query_mask[..., None]   # (B, T_q, 1)
    invalid_key = ~key_mask[:, None]         # (B, 1, T_k)

    # De Morgan: not (query valid and key valid) == query invalid or key invalid
    return invalid_query | invalid_key

def get_subsequent_mask(seq, current=True):
    """Build the causal mask used for masked self-attention.

    Only the sizes of the first two dimensions of `seq` (batch and sequence
    length) and its device are used, so empty batches and zero-length
    sequences are handled as well.

    Args:
        seq: Tensor with at least two dimensions, `(batch, length, ...)`.
        current: If true, a position may attend to itself, so only strictly
            later positions are flagged. If false, the position itself is
            flagged too.

    Returns:
        A uint8 tensor of shape `(batch, length, length)` on `seq.device`,
        where entry `[b, i, j]` is 1 when position `j` must be hidden from
        position `i`. The batch dimension is an expanded view, not a copy.
    """
    sz_b, len_s = seq.size(0), seq.size(1)

    # diagonal=1 keeps the main diagonal visible; diagonal=0 masks it as well.
    diagonal = 1 if current else 0
    subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8),
        diagonal=diagonal)

    return subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls

def forward_fill_3d(t: torch.Tensor) -> torch.Tensor:
    """Forward-fill zeros along the time axis of a `(batch, time, dim)` tensor.

    Every entry equal to zero is replaced by the most recent non-zero entry
    at an earlier time step of the same batch item and feature. Zeros that
    precede the first non-zero entry take the value at time step 0.

    Args:
        t: A 3-D tensor `(batch, time, dim)`; zeros are treated as missing.

    Returns:
        A new tensor with the same shape, dtype and device as `t`.
    """
    n_batch, t_dim, n_dim = t.shape

    # Index tensors are created on t's device so the function also works for
    # inputs that do not live on the CPU.
    steps = torch.arange(t_dim, device=t.device).view(1, t_dim, 1).expand(n_batch, t_dim, n_dim)

    # Missing entries point at step 0 so the running maximum below carries the
    # index of the last observed step forward in time.
    last_seen = torch.where(t == 0, torch.zeros_like(steps), steps)
    idx = last_seen.cummax(dim=1).values

    return t.gather(1, idx)

def fill_zero_padding(data, mask):
    """Scatter compact rows back to the positions flagged by `mask`.

    For every batch item, the `t` rows of `data` are written, in order, to
    the non-zero positions of the corresponding row of `mask`; all other
    positions of the output stay zero.

    Args:
        data: Tensor of shape `(B, t, D)` holding the compact rows.
        mask: Tensor of shape `(B, T)`; each row must contain exactly `t`
            non-zero (True) entries.

    Returns:
        A tensor of shape `(B, T, D)` with the dtype and device of `data`.

    Raises:
        AssertionError: If a row of `mask` does not have exactly `t` non-zero
            entries. The exception type is kept for backward compatibility
            but is raised explicitly, so the check also runs under ``-O``.
    """
    B, t, D = data.shape
    _, T = mask.shape

    valid = mask != 0
    counts = valid.sum(dim=1)
    mismatched = counts != t
    if mismatched.any():
        first_bad = int(torch.nonzero(mismatched)[0])
        raise AssertionError(
            f"Number of True positions ({int(counts[first_bad])}) doesn't match data size ({t})")

    transformed = torch.zeros(B, T, D, device=data.device, dtype=data.dtype)
    # Boolean-mask assignment visits the True positions in row-major order,
    # which matches the (batch, step) order of the flattened data rows.
    transformed[valid] = data.reshape(B * t, D)
    return transformed

class Decoder(nn.Module):
    """MLP that maps a state vector back to the observation space.

    The network is created with `build_mlp`: it takes inputs of size
    `args.state_dim`, produces outputs of size `args.input_dim`, and uses
    the smaller of the two as its hidden width. The number of layers is read
    from `args.get("n_layer_obs_decoder", 2)`.
    """

    def __init__(self, args):
        super().__init__()

        depth = args.get("n_layer_obs_decoder", 2)

        self.ld = args.state_dim   # input (state) dimension
        self.od = args.input_dim   # output (observation) dimension

        width = min(self.ld, self.od)
        self.decoder = build_mlp(self.ld, width, self.od, depth)

    def forward(self, input):
        """Run `input` through the decoder MLP and return the result."""
        return self.decoder(input)
    

class TemporalEmbedder(nn.Module) :
    """Project observations and add a sinusoidal encoding of their time stamps.

    Observations are mapped to `dim_embedding` features by an MLP built with
    `build_mlp`, and a sinusoidal encoding of the scalar time of each step is
    added. With `mask_fill_mode="token"`, masked-out steps are replaced by a
    learned token plus the same time encoding; in any other mode the mask
    passed to `forward` is ignored.

    Args:
        dim_input: Feature size of the raw observations.
        dim_embedding: Size of the produced embedding; must be even.
        n_layer_obs_embedder: Number of linear layers of the projection MLP.
        scale_factor: Base of the geometric progression of wavelengths.
        mask_fill_mode: `"token"` enables the learned mask token.
    """

    def __init__(self, dim_input, dim_embedding, n_layer_obs_embedder=1, scale_factor=10000, mask_fill_mode="zero") :

        super().__init__()

        assert dim_embedding % 2 == 0, "Dimension d must be even."

        self.projection = build_mlp(dim_input, dim_embedding, dim_embedding, n_layer_obs_embedder)
        self.mask_fill_mode = mask_fill_mode

        if mask_fill_mode == "token" :
            # Learned stand-in for the embedding of a masked-out step.
            self.mask_token = nn.Parameter(torch.randn(dim_embedding))

        self.d = dim_embedding
        self.scale_factor = scale_factor

    def scalar_to_positional_encoding(self, x):
        """Encode every scalar in `x` as `self.d` interleaved sin/cos features.

        Feature `2k` is `sin(x / scale_factor ** (2k / d))` and feature
        `2k + 1` is the matching cosine.

        Args:
            x: Tensor of any shape holding scalar values (e.g. time stamps).

        Returns:
            A tensor of shape `x.shape + (self.d,)` on `x.device`, in the
            default floating-point dtype.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well
        x_flat = x.reshape(-1, 1)

        positions = torch.arange(0, self.d // 2, device=x.device)
        div_term = self.scale_factor ** (2 * positions / self.d)

        # Shape: [N, d//2]
        angles = x_flat / div_term

        # Interleave sin and cos terms -> [N, d]
        encoding = torch.zeros(x_flat.shape[0], self.d, device=x.device)
        encoding[..., 0::2] = torch.sin(angles)
        encoding[..., 1::2] = torch.cos(angles)

        # Restore the original leading dimensions
        return encoding.view(*x.shape, self.d)

    def forward(self, x, times, mask=None) :
        """Embed observations `x` taken at `times`.

        Args:
            x: Observations of shape `(B, T, dim_input)`.
            times: Scalar time per step; its encoding is added to the
                projected observations, so presumably of shape `(B, T)`.
            mask: Optional `(B, T)` mask, True (non-zero) where the
                observation is kept. Only used when `mask_fill_mode` is
                `"token"`.

        Returns:
            A tensor of shape `(B, T, dim_embedding)`.
        """
        B, T, D = x.shape

        x = self.projection(x.reshape(B * T, D)).reshape(B, T, -1)
        pos_embeddings = self.scalar_to_positional_encoding(times)
        x = x + pos_embeddings

        if mask is not None and self.mask_fill_mode == "token" : # context mode & mask is needed (otherwise, no need to mask)
            # torch.where needs a boolean condition
            mask_exp = mask.unsqueeze(-1).bool()
            # The token lives in embedding space, so its length is
            # dim_embedding, not the input feature size D.
            mask_tok = self.mask_token.view(1, 1, -1)
            x = torch.where(mask_exp, x, mask_tok + pos_embeddings)

        return x