"""
References:
1. https://github.com/lucidrains/vector-quantize-pytorch
2. https://huggingface.co/docs/transformers/model_doc/patchtst
3. https://github.com/AI4HealthUOL/SSSD-ECG
4. https://github.com/karpathy/nanoGPT
5. https://github.com/helme/ecg_ptbxl_benchmarking
6. https://github.com/PKUDigitalHealth/HeartLang
7. https://github.com/bakqui/ST-MEM
"""

import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.cuda.amp as amp
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import utils as utils
# from positional_encodings.torch_encodings import PositionalEncoding1D, Summer # (batch, feature, channel)

from transformers import PatchTSTConfig, PatchTSTModel
from decoder.component.mlp import MLP
from decoder.s4.S4D import S4D
from Quantization import Two_scale_manifold_alignment


class PositionEmbedding(nn.Module):
    """
    Standard sinusoidal position embedding.

    Adds a fixed, non-learned table ``pe`` of shape (1, max_len, d_model) to the
    input. Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: feature size of the input (its last dim). Odd sizes are
            supported; the final column is then a sine term with no cosine partner.
        max_len: number of positions the table covers (upper bound on dim 1 of the input).
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column, so
        # the cosine half uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        # A buffer follows .to()/.cuda() and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Return ``x`` plus the embedding of its first ``x.size(1)`` positions.

        Assumes ``x`` is (batch, seq_len, d_model) with seq_len <= max_len.
        """
        return x + self.pe[:, :x.size(1)]
    
    
class S4(nn.Module):
    """
    Stack of residual S4D blocks used as the sequence decoder.

    The input is linearly projected from ``d_input`` to ``d_model`` features.
    It is then run through ``n_layers`` S4D blocks in channel-first layout and
    returned in (batch, seq_len, d_model) layout.
    """
    def __init__(
        self,
        d_input,
        d_model,
        n_layers=6,
        dropout=0.1,
        d_state=64,
    ):
        """
        Args:
            d_input (int): number of input features per time step
            d_model (int): hidden width shared by every S4D block
            n_layers (int): how many S4D blocks to stack
            dropout (float): dropout probability, passed to S4D and used after each block
            d_state (int): SSM state size forwarded to S4D
        """
        super().__init__()

        self.d_input = d_input
        self.d_model = d_model
        self.pre_proj = nn.Linear(d_input, d_model)
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()

        ssm_lr = min(0.001, 0.01)  # evaluates to 0.001
        for depth in range(n_layers):
            self.s4_layers.append(
                S4D(d_model, dropout=dropout, transposed=True, d_state=d_state, lr=ssm_lr)
            )
            # BatchNorm only after even-indexed blocks. Odd slots hold None so the
            # three ModuleLists stay index-aligned for the zip in forward().
            self.norms.append(nn.BatchNorm1d(d_model) if depth % 2 == 0 else None)
            self.dropouts.append(nn.Dropout1d(dropout))

    def forward(self, x):
        """Map (batch, seq_len, d_input) to (batch, seq_len, d_model)."""
        # S4D is built with transposed=True, so features go on dim -2.
        hidden = self.pre_proj(x).transpose(-1, -2)

        for block, bn, drop in zip(self.s4_layers, self.norms, self.dropouts):
            block_out, _ = block(hidden)
            hidden = drop(block_out) + hidden  # residual around each S4D block
            if bn is not None:
                hidden = bn(hidden)

        return hidden.transpose(-1, -2)


class AtomECG(nn.Module):
    """
    AtomECG architecture: PatchTST encoder -> multi-quantizer VQ -> S4 decoder.

    Args:
        num_input_channels: Number of input channels (e.g., 12 ECG leads)
        context_length: Length of the input sequence (e.g., 5000)
        patch_length: Length of each patch
        patch_stride: Stride for patch creation
        num_embeddings: Number of embeddings in the codebook
        embedding_dim: Dimension of each codebook embedding
        vq_heads: Number of quantizers (used in residual VQ)
    """
    def __init__(self,
                 num_input_channels,
                 context_length,
                 patch_length,
                 patch_stride,
                 num_embeddings,
                 embedding_dim,
                 vq_heads,
                 ):
        super(AtomECG, self).__init__()

        # PatchTST parameters
        self.num_input_channels = num_input_channels
        self.context_length = context_length
        self.patch_length = patch_length
        self.patch_stride = patch_stride

        # VQ parameters
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.vq_heads = vq_heads
        # The quantizer below holds `vq_heads` quantizers, so codebook perplexity
        # is reported per quantizer. calculate_perplexity reads this attribute.
        self.vq_form = 'residual'

        # Encoder: PatchTST-based
        self._encoder = PatchTSTModel(PatchTSTConfig(
            num_input_channels=self.num_input_channels,
            context_length=self.context_length,
            patch_length=self.patch_length,
            patch_stride=self.patch_stride,
            channel_attention=True,
            do_mask_input=True,
            channel_consistent_masking=False, # Mask all channels at the same position or not
            mask_type='random',
            random_mask_ratio=0.4,
            num_forecast_mask_patches=124, # (context_length - patch_length) // patch_stride + 1
            mask_value=0,
            use_cls_token=False,
        ))

        self._vq = Two_scale_manifold_alignment(
            dim=self.embedding_dim,
            num_quantizers=self.vq_heads,
            codebook_size=self.num_embeddings,
            rotation_trick=True,
            commitment_weight=0.5,
        )

        # Decoder: S4-based; it consumes one feature per input channel.
        self._decoder = S4(d_input=self.num_input_channels, d_model=64)

        # Normalization and projection layers (per input channel)
        self.BN = nn.BatchNorm2d(self.num_input_channels)
        self.BN1d = nn.BatchNorm1d(self.num_input_channels)

        # Maps each quantized patch vector to 8 output values.
        # NOTE(review): 128 presumably equals the PatchTST hidden size (PatchTSTConfig default) — confirm.
        self._map = nn.Sequential(
            MLP(input_dim=128, output_dim=64, activation='linear', n_activations=1),
            MLP(input_dim=64, output_dim=8, activation='linear', n_activations=1)
        )

        self._projection = MLP(input_dim=64, output_dim=self.num_input_channels, activation='linear', n_activations=1)

    @staticmethod
    def _codebook_perplexity(flat_indices, num_embeddings):
        """Return exp(entropy) of the empirical code distribution of a 1-D index tensor."""
        usage_count = torch.bincount(flat_indices, minlength=num_embeddings)
        probs = usage_count.float() / flat_indices.numel()
        nonzero_probs = probs[probs > 0]  # 0 * log(0) is taken as 0
        return torch.exp(-torch.sum(nonzero_probs * torch.log(nonzero_probs)))

    def calculate_perplexity(self, indices):
        """
        Calculate the perplexity of codebook usage.

        Perplexity measures how uniformly the codebook embeddings are used. A
        higher value is generally better, because it indicates more diverse code usage.

        Args:
            indices: integer code indices. In the multi-quantizer case the
                quantizer index is taken from the LAST dim, e.g. (B, N, heads)
                or (B, C, P, heads). NOTE(review): the last-dim layout is
                assumed from the residual-VQ convention — confirm.

        Returns:
            list of 0-dim tensors: one per quantizer, or a single entry when
            ``vq_heads`` is falsy or ``vq_form`` is not a residual variant.
        """
        # getattr keeps instances built before `vq_form` existed working.
        vq_form = getattr(self, 'vq_form', 'residual')
        if self.vq_heads and vq_form in ['residual', 'residualsim']:
            return [
                self._codebook_perplexity(indices[..., i].reshape(-1), self.num_embeddings)
                for i in range(self.vq_heads)
            ]
        return [self._codebook_perplexity(indices.reshape(-1), self.num_embeddings)]

    def forward(self, x):
        """
        Encode, quantize and reconstruct a batch of signals.

        Args:
            x: (batch_size, sequence_length, num_channels), e.g. (B, 5000, 12).

        Returns:
            tuple ``(vq_loss, x_recon, quantization, indices, perplexity)``.
        """
        # Encode. The encoder output exposes last_hidden_state, loc and scale.
        z = self._encoder(x)
        # z.last_hidden_state shape: (batch_size, num_channels, num_patches, patch_embedding_dim)
        hidden_state = self.BN(z.last_hidden_state)

        # Quantize. quantization shape: (B, num_channels, num_patches, embedding_dim)
        quantization, indices, loss = self._vq(hidden_state)

        # Map each quantized patch to a short run of values and concatenate the
        # patches per channel -> (B, num_channels, new_seq_len)
        quantized_de = self._map(quantization).reshape(x.shape[0], self.num_input_channels, -1)

        # Right-pad by repeating the last value so the decoder sees context_length steps.
        # NOTE(review): a mapped sequence longer than context_length is not truncated.
        if quantized_de.shape[-1] < self.context_length:
            padding_needed = self.context_length - quantized_de.shape[-1]
            padding = quantized_de[:, :, -1:].repeat(1, 1, padding_needed)
            quantized_de = torch.cat((quantized_de, padding), dim=-1)

        quantized_de = self.BN1d(quantized_de)

        # Reconstruct the signal with the S4 decoder, which expects (B, seq_len, channels).
        x_recon = self._decoder(quantized_de.transpose(-1, -2))  # -> (B, seq_len, decoder_d_model)

        # Project back to the original channel dimension
        x_recon = self._projection(x_recon)  # -> (B, seq_len, num_channels)
        # De-standardization with the statistics returned by the encoder
        x_recon = x_recon * z.scale + z.loc

        perplexity = self.calculate_perplexity(indices)
        # Average the quantizer losses; also tolerate vq_heads == 0 / None.
        vq_loss = loss.sum() / self.vq_heads if self.vq_heads else loss.sum()

        return vq_loss, x_recon, quantization, indices, perplexity



class GPTIndexPredictor(nn.Module):
    """
    A GPT-like model that predicts future codebook indices autoregressively.

    Every (batch, channel) pair is treated as an independent sequence whose
    steps are ``num_codebooks``-tuples of codebook indices. Each step is embedded
    as a weighted sum of per-codebook embeddings. The sequence then goes through
    a causal TransformerDecoder stack (attending to itself as memory), and one
    head per codebook emits logits over ``codebook_size`` entries.
    """
    def __init__(self, codebook_size=1024, d_model=512, nhead=8,
                 num_layers=6, num_channels=12, num_codebooks=4,
                 dropout=0.1, use_kv_cache=True):
        super().__init__()
        self.d_model = d_model
        self.num_codebooks = num_codebooks
        self.num_channels = num_channels
        # Passed as `use_cache` by predict(); only controls whether `self.cache`
        # is recorded (see forward()).
        self.use_kv_cache = use_kv_cache

        self.embeddings = nn.ModuleList([
            nn.Embedding(codebook_size, d_model) for _ in range(num_codebooks)
        ])

        self.pos_encoder = PositionEmbedding(d_model)
        self.input_norm = nn.LayerNorm(d_model)
        self.output_norm = nn.LayerNorm(d_model)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model, nhead,
            dim_feedforward=4*d_model,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True      # Pre-norm
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers)

        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(d_model),
                nn.Linear(d_model, codebook_size)
            ) for _ in range(num_codebooks)
        ])
        # Embedded input of the last forward(..., use_cache=True) call; cleared by reset_cache().
        self.cache = None

    def _embed(self, flat_indices):
        """Embed (N, seq_len, num_codebooks) indices into (N, seq_len, d_model)."""
        embeddings = [self.embeddings[i](flat_indices[..., i]) for i in range(self.num_codebooks)]

        # Softmax of equal logits gives uniform weights, i.e. a mean over codebooks.
        weights = F.softmax(torch.tensor([1.0] * self.num_codebooks), dim=0)
        embedded = sum(w * emb for w, emb in zip(weights, embeddings))

        # NOTE(review): the sqrt(d_model) scale is applied after the position
        # embedding is added, and is then largely undone by the following LayerNorm.
        embedded = self.pos_encoder(embedded) * math.sqrt(self.d_model)
        return self.input_norm(embedded)

    def forward(self, x, mask=None, use_cache=False):
        """
        Compute next-index logits for every position of every sequence.

        Args:
            x: integer indices of shape (batch, channels, seq_len, codebooks)
            mask: optional (seq_len, seq_len) additive attention mask; defaults
                to a causal mask.
            use_cache: when True, the embedded input is stored in ``self.cache``.
                nn.TransformerDecoder exposes no per-layer key/value state, so an
                incremental pass that feeds only the newest token cannot reproduce
                the full causal computation. The full pass is therefore always run.

        Returns:
            logits of shape (batch, channels, seq_len, codebooks, codebook_size)
        """
        batch_size, num_channels, seq_len, _ = x.shape

        # reshape (not view) so non-contiguous inputs (e.g. slices) also work.
        flat = x.reshape(-1, seq_len, self.num_codebooks)  # -> (batch*channels, seq_len, codebooks)
        embedded = self._embed(flat)

        if mask is None:
            mask = self.generate_square_subsequent_mask(seq_len).to(x.device)

        output = self.transformer(
            tgt=embedded,
            memory=embedded,
            tgt_mask=mask,
            memory_mask=mask
        )
        if use_cache:
            self.cache = {'memory': embedded}

        output = self.output_norm(output)

        logits = torch.stack([head(output) for head in self.heads], dim=-2)  # -> (batch*ch, seq_len, codebooks, codebook_size)

        return logits.view(batch_size, num_channels, seq_len, self.num_codebooks, -1)

    def generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float mask with -inf above the diagonal and 0 elsewhere."""
        mask = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
        return mask

    def reset_cache(self):
        """Clear the stored cache state before/after a generation run."""
        self.cache = None

    def predict(self, input_indices, predict_steps=63, temperature=1.0, top_k=50):
        """
        Autoregressively sample future indices.

        The model is put in eval mode while sampling. The previous train/eval
        mode is restored afterwards.

        Args:
            input_indices: (batch, ch, seq_len, num_codebooks) integer indices
            predict_steps: Number of steps to predict
            temperature: Sampling temperature
            top_k: Top-k sampling parameter (None = disabled); clamped to codebook_size

        Returns:
            (batch, ch, predict_steps, num_codebooks) sampled indices
        """
        was_training = self.training
        self.eval()
        self.reset_cache()  # Reset cache for new prediction

        try:
            with torch.no_grad():
                current_seq = input_indices
                predictions = []

                for _ in range(predict_steps):
                    logits = self(current_seq, use_cache=self.use_kv_cache)
                    next_logits = logits[:, :, -1]  # (batch, ch, codebooks, codebook_size)

                    if temperature != 1.0:
                        next_logits = next_logits / temperature

                    # top-k: keep the k largest logits, mask out the rest.
                    # topk fails when k exceeds the vocabulary, hence the clamp.
                    if top_k is not None:
                        k = min(top_k, next_logits.size(-1))
                        values, indices = torch.topk(next_logits, k, dim=-1)
                        next_logits = torch.full_like(next_logits, float('-inf'))
                        next_logits.scatter_(-1, indices, values)

                    probs = F.softmax(next_logits, dim=-1)
                    next_indices_list = [
                        torch.multinomial(probs[..., i, :].reshape(-1, probs.size(-1)), 1)
                        for i in range(self.num_codebooks)
                    ]
                    next_indices = torch.cat(next_indices_list, dim=-1)
                    next_indices = next_indices.view(probs.size(0), probs.size(1), self.num_codebooks)

                    predictions.append(next_indices)

                    # Append the predicted indices to the sequence (new seq dim at 2)
                    current_seq = torch.cat([current_seq, next_indices.unsqueeze(2)], dim=2)

                if not predictions:
                    # predict_steps == 0: empty (batch, ch, 0, codebooks) result.
                    return input_indices[:, :, :0]
                return torch.stack(predictions, dim=2)  # (batch, ch, predict_steps, codebooks)
        finally:
            # Reset cache after prediction and restore the caller's mode.
            self.reset_cache()
            self.train(was_training)


def _report_loaded():
    """Print a confirmation that the module's definitions imported cleanly."""
    print('The whole model has been loaded!')


if __name__ == '__main__':
    _report_loaded()
