import copy
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from Models.Attention import *
from einops import rearrange, repeat

def Encoder_factory(config):
    """Build an ``SPR`` model from *config*.

    ``config['num_labels']`` is passed to ``SPR`` as ``num_classes``; the whole
    *config* dict is forwarded unchanged.
    """
    num_classes = config['num_labels']
    return SPR(config, num_classes=num_classes)

class SPR(nn.Module):
    """Patch-based transformer encoder with a masked-representation pre-training path.

    Inputs are ``(B, C, L)`` tensors: ``C`` channels (``config['Data_shape'][1]``,
    the number of EEG channels) of ``L`` time points, where ``L`` must be a multiple
    of ``config['patch_size']``. Every channel is cut into ``T = L // patch_size``
    patches that are linearly embedded; the channel axis is folded into the batch
    axis, so channels are encoded independently as ``B * C`` sequences.

    Args:
        config: configuration dict. This class reads 'Data_shape', 'emb_size',
            'device', 'patch_size', 'chunk_size', 'mask_ratio', 'sampling_rate',
            'num_heads' and 'pre_layers'; ``Encoder(config)`` reads further keys.
        num_classes: output size of the classification head used by ``forward``.
    """

    def __init__(self, config, num_classes):
        super().__init__()
        channel_size = config['Data_shape'][1]  # number of EEG channels
        self.emb_size = config['emb_size']  # d_x
        # NOTE(review): 1000 is presumably the maximum number of positions the
        # embedding supports (T patches + 1 CLS token) — confirm in Models.Attention.
        self.PositionalEncoding = PositionalEmbedding(1000, self.emb_size)
        # Kept for backward compatibility; tensors are created on the input's device.
        self.device = config['device']
        self.patch_size = config['patch_size']
        self.chunk_size = config['chunk_size']
        self.mask_ratio = config['mask_ratio']
        self.sampling_rate = config['sampling_rate']  # stored, not used in this class
        # One learnable CLS token per channel, prepended to that channel's patches.
        self.CLS = nn.Parameter(torch.randn(1, channel_size, 1, self.emb_size))
        # Token given to the predictor at every patch position during pre-training.
        self.rep_mask_token = nn.Parameter(torch.randn(1, 1, self.emb_size))

        self.contex_encoder = Encoder(config)
        self.Predictor = Predictor(self.emb_size, config['num_heads'], 1, config['pre_layers'])
        self.out = nn.Linear(self.emb_size, channel_size)

        # Classification head sized by the ``num_classes`` argument
        # (Encoder_factory passes config['num_labels'] here).
        self.pred_head = nn.Linear(self.emb_size, num_classes)
        self.Norm = nn.LayerNorm(self.emb_size)
        self.Norm2 = nn.LayerNorm(self.emb_size)
        self.attention_pool = AttentionAggregation(self.emb_size, num_heads=config['num_heads'], dropout=0.)
        self.linear_embed = nn.Linear(config['patch_size'], self.emb_size)

    def chunk_masking(self, shape, mask_ratio, device, chunk_size=0):
        """Build a boolean mask in which ``True`` marks masked positions.

        About ``len_keep = int(T * (1 - mask_ratio))`` positions per row stay visible
        (``False``). They are laid out as ``num_chunks`` contiguous chunks of
        ``chunk_size`` tokens. Each chunk lives in its own window of width
        ``T // num_chunks`` and starts at a random offset inside that window.
        When ``chunk_size <= 0`` or ``chunk_size > len_keep``, a single chunk of
        ``len_keep`` tokens is used instead. The second case previously produced
        ``num_chunks == 0`` and a ZeroDivisionError.

        Args:
            shape: ``(B, T)`` size of the mask.
            mask_ratio: fraction of positions to mask.
            device: device the mask is created on.
            chunk_size: length of each visible chunk (0 means one chunk).

        Returns:
            ``torch.bool`` tensor of shape ``(B, T)``.
        """
        B, T = shape
        len_keep = int(T * (1 - mask_ratio))

        if chunk_size <= 0 or chunk_size > len_keep:
            num_chunks = 1
            chunk_size = len_keep
        else:
            num_chunks = len_keep // chunk_size

        # num_chunks * chunk_size <= T, so chunk_size <= stride and chunks never overlap
        # or run past T.
        stride = T // num_chunks
        max_offset = max(stride - chunk_size, 0)

        offsets = torch.randint(0, max_offset + 1, (B, num_chunks), device=device)
        starts = offsets + torch.arange(num_chunks, device=device) * stride  # (B, num_chunks)
        positions = starts.unsqueeze(-1) + torch.arange(chunk_size, device=device)
        positions = positions.reshape(B, -1)

        mask = torch.ones((B, T), dtype=torch.bool, device=device)
        mask.scatter_(1, positions, False)
        return mask

    def _embed_patches(self, x):
        """Cut ``(B, C, L)`` input into patches and embed them.

        Returns a ``(B, C, L // patch_size, emb_size)`` tensor of
        linear-embedded, GELU-activated, LayerNorm-ed patches.
        """
        B, C, T = x.shape
        assert T % self.patch_size == 0, f"Time series length should be divisible by patch_size, not {T} % {self.patch_size}"
        x = x.reshape(B, C, T // self.patch_size, self.patch_size)
        patches = F.gelu(self.linear_embed(x))
        return self.Norm(patches)

    def _encode_with_cls(self, x):
        """Embed ``x``, prepend the per-channel CLS token and run the context encoder.

        Returns the encoder output at the CLS position, shaped ``(B, C, emb_size)``.
        """
        patches = self._embed_patches(x)
        B, C, T, _ = patches.shape
        cls_tokens = self.CLS.expand(B, -1, -1, -1)
        patches = torch.cat([cls_tokens, patches], dim=2).reshape(B * C, T + 1, -1)
        patches = patches + self.PositionalEncoding(patches)
        patches = self.Norm2(patches)
        patches = self.contex_encoder(patches)
        return patches[:, 0, :].reshape(B, C, -1)

    def linear_prob(self, x):
        """Return gradient-free ``(B, emb_size)`` features: CLS outputs averaged over channels."""
        with torch.no_grad():
            return self._encode_with_cls(x).mean(1)

    def pretrain_forward(self, x):
        """Masked-representation pre-training step.

        Masked patch embeddings are replaced with N(0, 0.02) noise before encoding.
        The predictor receives the encoder output together with position-encoded
        ``rep_mask_token`` queries, and its outputs at the masked positions are
        projected by ``self.out``.

        Returns:
            ``[predictions, mask]``: ``predictions`` is
            ``(num_masked, channel_size)``; ``mask`` is a ``(B * C, T)`` bool tensor,
            ``True`` where a patch was masked.
        """
        patches = self._embed_patches(x)
        B, C, T, _ = patches.shape
        patches = patches.reshape(B * C, T, -1).contiguous()

        # Build the mask on the input's device. The config device can go stale once
        # the model is moved or replicated.
        mask = self.chunk_masking(
            shape=(B * C, T),
            mask_ratio=self.mask_ratio,
            device=x.device,
            chunk_size=self.chunk_size,
        )
        # Noise matches the dtype/device of ``patches``, as index assignment requires.
        noise = patches.new_empty((int(mask.sum()), self.emb_size)).normal_(mean=0.0, std=0.02)
        patches[mask] = noise
        patches = patches + self.PositionalEncoding(patches)
        patches = self.Norm2(patches)

        rep_contex = self.contex_encoder(patches)
        rep_mask_token = self.rep_mask_token.repeat(B * C, T, 1)
        rep_mask_token = rep_mask_token + self.PositionalEncoding(rep_mask_token)
        rep_mask_prediction = self.Predictor(rep_contex, rep_mask_token)
        rep_mask_prediction = self.out(rep_mask_prediction[mask])

        return [rep_mask_prediction, mask]

    def forward(self, x):
        """Classify ``x``: per-channel CLS outputs are pooled and fed to ``pred_head``."""
        cls_outputs = self._encode_with_cls(x)  # (B, C, emb_size)
        pooled = self.attention_pool(cls_outputs)
        return self.pred_head(pooled)

class Encoder(nn.Module):
    """Sequential stack of ``config['layers']`` TransformerBlock modules.

    Each block is built as ``TransformerBlock(emb_size, num_heads, 4 * emb_size,
    True, dropout)`` and is called with ``mask=None``.
    """

    def __init__(self, config):
        super().__init__()
        width = config['emb_size']
        heads = config['num_heads']
        hidden = width * 4  # feed-forward width
        depth = config['layers']
        drop = config['dropout']
        use_res_parameter = True
        self.TRMs = nn.ModuleList([
            TransformerBlock(width, heads, hidden, use_res_parameter, drop)
            for _ in range(depth)
        ])

    def forward(self, x):
        """Apply every block in order (no attention mask) and return the result."""
        hidden_states = x
        for block in self.TRMs:
            hidden_states = block(hidden_states, mask=None)
        return hidden_states
    
class Predictor(nn.Module):
    """Stack of ``layers`` CrossAttnTRMBlock modules.

    ``forward(rep_visible, x)`` threads ``x`` through every block in order and
    passes the same ``rep_visible`` tensor to each one.
    """

    def __init__(self, d_model, attn_heads, enable_res_parameter, layers):
        super().__init__()
        ffn_width = d_model * 4
        blocks = []
        for _ in range(layers):
            blocks.append(CrossAttnTRMBlock(d_model, attn_heads, ffn_width, enable_res_parameter))
        self.layers = nn.ModuleList(blocks)

    def forward(self, rep_visible, x):
        """Return ``x`` after it has passed through every block."""
        out = x
        for block in self.layers:
            out = block(rep_visible, out)
        return out


def count_parameters(model):
    """Return the total number of scalar elements in *model*'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total