import torch
import torch.nn as nn

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
import copy
from einops import repeat, rearrange

import torch.nn.functional as F
# Torch device string for this module: "cuda" when torch reports an available
# CUDA device, otherwise "cpu". Evaluated once at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"

class Patch_Embeddings_CNN(nn.Module):
    """Embed a 1-D signal as patch tokens, then mix neighbouring patches with a conv.

    Each non-overlapping patch of ``patch_size`` samples is passed through
    LayerNorm -> Linear (no bias, 32 outputs) -> LayerNorm. A width-5
    ``Conv1d`` with padding 2 then runs along the patch axis and maps the
    32 features to ``model_dim``.

    Input:  ``(batch, length)`` with ``length`` divisible by ``patch_size``.
    Output: ``(batch, length // patch_size, model_dim)``.
    """

    def __init__(self, patch_size, model_dim):
        super(Patch_Embeddings_CNN, self).__init__()
        self.ps = patch_size

        # (b, n * ps) -> (b * n, ps): one row per patch.
        self.re1 = Rearrange("b (l ps) -> (b l) ps", ps=patch_size)
        self.strip_2_emb = nn.Sequential(
                                nn.LayerNorm(patch_size),
                                nn.Linear(in_features=patch_size, out_features=32, bias=False),
                                nn.LayerNorm(32))

        self.cnn = nn.Conv1d(in_channels=32, out_channels=model_dim, kernel_size=5, padding=2)

    def forward(self, x):
        _, length = x.shape
        num_patches = length // self.ps

        h = self.re1(x)                  # (b * num_patches, ps)
        # LayerNorm / Linear act on the last axis, so the rows are embedded
        # directly. A bare squeeze() here would also drop the batch axis when
        # b * num_patches == 1 and break the rearrange below.
        h = self.strip_2_emb(h)          # (b * num_patches, 32)
        h = rearrange(h, '(b l) c -> b c l', l=num_patches)
        h = self.cnn(h)                  # (b, model_dim, num_patches)
        return rearrange(h, 'b c l -> b l c')


class Patch_Embeddings(nn.Module):
    """Split a 1-D signal into non-overlapping patches and embed each one.

    Each patch of ``patch_size`` samples goes through
    LayerNorm -> Linear (no bias) -> LayerNorm to become a ``model_dim`` token.

    Input:  ``(batch, length)`` with ``length`` divisible by ``patch_size``.
    Output: ``(batch, length // patch_size, model_dim)``.

    ``in_channels`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, patch_size, model_dim, in_channels):
        super(Patch_Embeddings, self).__init__()
        self.ps = patch_size

        # (b, n * ps) -> (b * n, ps): one row per patch.
        self.re1 = Rearrange("b (l ps) -> (b l) ps", ps=patch_size)
        self.strip_2_emb = nn.Sequential(
                                nn.LayerNorm(patch_size),
                                nn.Linear(in_features=patch_size, out_features=model_dim, bias=False),
                                nn.LayerNorm(model_dim))

    def forward(self, x):
        _, length = x.shape
        num_patches = length // self.ps

        h = self.re1(x)                  # (b * num_patches, ps)
        # LayerNorm / Linear act on the last axis, so the rows are embedded
        # directly. A bare squeeze() here would also drop the batch axis when
        # b * num_patches == 1 (or model_dim == 1) and break the rearrange.
        h = self.strip_2_emb(h)          # (b * num_patches, model_dim)
        return rearrange(h, '(b l) c -> b l c', l=num_patches)
    

class MultiHead_Attention(nn.Module):
    """Multi-head scaled dot-product attention without masking.

    Queries (width ``q_dim``) and keys/values (width ``kv_dim``) are projected
    to ``kv_dim`` and split into ``number_heads`` heads of size
    ``kv_dim // number_heads``. The concatenated heads pass through a final
    ``kv_dim -> kv_dim`` linear layer.

    Args:
        q_dim: feature size of the query input.
        kv_dim: feature size of the key/value inputs and of the output.
        number_heads: number of attention heads; must divide ``kv_dim``.
        do_prob: dropout probability applied to the attention weights and to
            the projected output.

    Raises:
        ValueError: if ``kv_dim`` is not divisible by ``number_heads``.
    """

    def __init__(self, q_dim, kv_dim, number_heads, do_prob):
        super(MultiHead_Attention, self).__init__()
        # Fail at construction instead of deep inside einops at forward time.
        if kv_dim % number_heads != 0:
            raise ValueError(
                f"kv_dim ({kv_dim}) must be divisible by number_heads ({number_heads})")
        self.number_heads = number_heads

        # 1 / sqrt(head_dim)
        self.scale_factor = 1 / ((kv_dim / number_heads) ** 0.5)
        self.att_drop_out = nn.Dropout(do_prob)
        self.output_drop_out = nn.Dropout(do_prob)

        self.block_output = nn.Linear(kv_dim, kv_dim)

        self.split_head = Rearrange('b l (h d) -> b h l d', h=self.number_heads)
        # Keys are split straight into the transposed layout needed for q @ k^T.
        self.split_head_t = Rearrange('b l (h d) -> b h d l', h=self.number_heads)
        self.concat = Rearrange('b h l d -> b l (h d)')

        self.x_to_q = nn.Linear(q_dim, kv_dim)
        self.x_to_k = nn.Linear(kv_dim, kv_dim)
        self.x_to_v = nn.Linear(kv_dim, kv_dim)

    def forward(self, q, k, v):
        """Attend from ``q`` over ``k``/``v``.

        q, k and v have shape (batch_size, seq_len, embedding_dimension); the
        output has q's sequence length and width ``kv_dim``.
        """
        q = self.split_head(self.x_to_q(q))             # (b, h, lq, d)
        k_transpose = self.split_head_t(self.x_to_k(k))  # (b, h, d, lk)
        v = self.split_head(self.x_to_v(v))             # (b, h, lk, d)

        attention = torch.matmul(q, k_transpose)
        attention = attention * self.scale_factor
        attention = self.att_drop_out(attention.softmax(-1))

        output = torch.matmul(attention, v)
        output = self.concat(output)
        output = self.block_output(output)

        return self.output_drop_out(output)



class FeedForwardNet(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The hidden layer is ``wide_factor`` times ``model_dim`` wide; input and
    output widths are both ``model_dim``.
    """

    def __init__(self, model_dim, do_prob, wide_factor=4):
        super().__init__()
        hidden_dim = model_dim * wide_factor
        layers = [
            nn.Linear(model_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(do_prob),
            nn.Linear(hidden_dim, model_dim),
            nn.Dropout(do_prob),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out



class Add_and_Norm(nn.Module):
    """Residual add followed by LayerNorm over the last ``model_dim`` features."""

    def __init__(self, model_dim):
        super().__init__()
        self.norm = nn.LayerNorm(model_dim)

    def forward(self, x, res):
        summed = x + res
        return self.norm(summed)


class EncoderBlock(nn.Module):
    """Post-norm transformer block at width ``model_dim``.

    Self-attention -> residual add & LayerNorm -> feed-forward ->
    residual add & LayerNorm.
    """

    def __init__(self, number_heads, model_dim, do_prob):
        super().__init__()
        self.mh_atten_block = MultiHead_Attention(
            number_heads=number_heads, q_dim=model_dim, kv_dim=model_dim, do_prob=do_prob)
        self.add_norm_mh = Add_and_Norm(model_dim=model_dim)

        self.ffn = FeedForwardNet(model_dim=model_dim, do_prob=do_prob)
        self.add_norm_ffn = Add_and_Norm(model_dim=model_dim)

    def forward(self, x):
        attended = self.add_norm_mh(self.mh_atten_block(x, x, x), x)
        return self.add_norm_ffn(self.ffn(attended), attended)



class Encoder(nn.Module):
    """Apply ``num_blocks`` ``EncoderBlock`` layers one after another."""

    def __init__(self, num_blocks, num_heads, model_dim, do_prob):
        super().__init__()
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.model_dim = model_dim
        self.do_prob = do_prob

        self.net = self.create_net()

    def forward(self, x):
        out = x
        for block in self.net:
            out = block(out)
        return out

    def create_net(self):
        """Build the ``nn.ModuleList`` of encoder blocks from the stored settings."""
        return nn.ModuleList([
            EncoderBlock(number_heads=self.num_heads,
                         model_dim=self.model_dim,
                         do_prob=self.do_prob)
            for _ in range(self.num_blocks)
        ])



class Self_DecoderBlock(nn.Module):
    """Post-norm self-attention block at width ``dec_dim``.

    Self-attention -> residual add & LayerNorm -> feed-forward ->
    residual add & LayerNorm. ``enc_dim`` is accepted but not used here.
    """

    def __init__(self, number_heads, enc_dim, dec_dim, do_prob):
        super().__init__()

        self.self_atten_block = MultiHead_Attention(
            number_heads=number_heads, q_dim=dec_dim, kv_dim=dec_dim, do_prob=do_prob)
        self.self_add_norm = Add_and_Norm(model_dim=dec_dim)

        # Final fully connected sub-layer.
        self.ffn = FeedForwardNet(model_dim=dec_dim, do_prob=do_prob)
        self.add_norm_ffn = Add_and_Norm(model_dim=dec_dim)

    def forward(self, x1):
        # x1: token sequence already projected to dec_dim.
        attended = self.self_add_norm(self.self_atten_block(x1, x1, x1), x1)
        return self.add_norm_ffn(self.ffn(attended), attended)



class Self_Decoder(nn.Module):
    """Project ``enc_dim`` tokens to ``dec_dim`` and run ``num_blocks`` decoder blocks."""

    def __init__(self, num_blocks, num_heads, enc_dim, dec_dim, do_prob):
        super().__init__()
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.enc_dim = enc_dim
        self.dec_dim = dec_dim
        self.do_prob = do_prob

        self.decoder_embed = nn.Linear(enc_dim, dec_dim)
        self.net = self.create_net()

    def forward(self, x1):
        tokens = self.decoder_embed(x1)
        for block in self.net:
            tokens = block(tokens)
        return tokens

    def create_net(self):
        """Build the ``nn.ModuleList`` of ``Self_DecoderBlock`` layers."""
        return nn.ModuleList([
            Self_DecoderBlock(number_heads=self.num_heads,
                              enc_dim=self.enc_dim,
                              dec_dim=self.dec_dim,
                              do_prob=self.do_prob)
            for _ in range(self.num_blocks)
        ])



class Features_Projector(nn.Module):
    """MLP head: Linear -> BatchNorm1d -> ReLU -> Linear.

    ``BatchNorm1d`` normalizes over dim 1, so inputs are presumably
    ``(batch, c_input)``. TODO confirm against callers.
    """

    def __init__(self, c_input, hidden, c_output):
        super().__init__()
        first = nn.Linear(in_features=c_input, out_features=hidden)
        norm = nn.BatchNorm1d(num_features=hidden)
        act = nn.ReLU()
        last = nn.Linear(in_features=hidden, out_features=c_output)
        self.net = nn.Sequential(first, norm, act, last)

    def forward(self, x):
        projected = self.net(x)
        return projected

class Jepa(nn.Module):
    """Masked latent-prediction (JEPA-style) model with an EMA teacher.

    Student path: patch-embed the input, add fixed sinusoidal position
    encodings, hide a random ``mask_ratio`` fraction of the patches, encode
    the visible ones together with a class token, put a learned mask token
    at the hidden positions, and run the predictor over the full sequence.
    Target: the layer-normalized output of a teacher copy of the patch
    embedding / class token / encoder on the *unmasked* input. The loss is
    the L1 error averaged over masked patches only.

    Teacher weights receive no gradients. ``update_moving_average`` moves
    them towards the student weights with decay ``tau``.

    Args:
        enc_num_blocks: number of encoder blocks.
        num_heads: attention heads in the encoder and the predictor.
        model_dim: encoder token width.
        do_prob: dropout probability.
        patch_size: samples per patch.
        in_channels: forwarded to ``Patch_Embeddings``.
        fs, l: the position table holds ``int(fs * l / patch_size + 1)``
            rows, so ``fs * l`` is presumably the number of samples per
            input (e.g. sampling rate times duration). Confirm with callers.
        dec_num_blocks: number of predictor blocks.
        mask_ratio: fraction of patches hidden from the student encoder.
        dec_dim: predictor width.
        tau: EMA decay used by ``update_moving_average`` (default 0.995).
    """

    def __init__(self, enc_num_blocks, num_heads, model_dim, do_prob, patch_size, in_channels, fs, l,
                 dec_num_blocks, mask_ratio, dec_dim, tau=0.995):

        super(Jepa, self).__init__()
        self.model_dim = model_dim
        self.patch_size = patch_size

        self.enc_class_token = nn.Parameter(torch.randn(1, 1, model_dim), requires_grad=True)
        torch.nn.init.normal_(self.enc_class_token, std=.02)

        self.enc_patch_embedding = Patch_Embeddings(patch_size=patch_size, model_dim=model_dim, in_channels=in_channels)

        # Fixed (non-learned) sinusoidal table: row 0 for the class token,
        # rows 1.. for the patches. A non-persistent buffer follows the module
        # through .to()/.cuda() without adding keys to the state_dict.
        pos_enc = self.get_2d_embeddings(max_seq_len=int(fs * l / patch_size + 1))
        self.register_buffer("pos_enc", pos_enc, persistent=False)

        self.encoder = Encoder(num_blocks=enc_num_blocks, num_heads=num_heads, model_dim=model_dim, do_prob=do_prob)
        self.predictor = Self_Decoder(num_blocks=dec_num_blocks, num_heads=num_heads, enc_dim=model_dim, dec_dim=dec_dim, do_prob=do_prob)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        torch.nn.init.normal_(self.mask_token, std=.02)

        self.mask_ratio = mask_ratio
        self.decoder_pred = nn.Linear(dec_dim, model_dim, bias=False)

        self.initialize_weights()

        # Teacher branch: frozen copies taken after initialization, so both
        # branches start from identical weights.
        self.teacher_encoder = copy.deepcopy(self.encoder).requires_grad_(False)
        self.teacher_patch_embedding = copy.deepcopy(self.enc_patch_embedding).requires_grad_(False)
        self.register_buffer("teacher_pos_enc", pos_enc.clone(), persistent=False)
        self.teacher_class_token = copy.deepcopy(self.enc_class_token).requires_grad_(False)

        self.tau = tau

    def forward(self, x1):
        """Return the masked-patch L1 loss for the batch ``x1``.

        NOTE(review): ``x1`` is presumably ``(batch, fs * l)`` so that the
        patch count matches ``pos_enc``. Confirm with callers.
        """
        # Student: embed patches and add their positions (row 0 is the cls slot).
        h1 = self.enc_patch_embedding(x1)
        h1 = h1 + self.pos_enc[:, 1:, :]

        h1_masked, mask, ids_restore = self.random_masking(h1)

        enc_cls_token_1 = repeat(self.enc_class_token, '() p d -> b p d', b=h1_masked.shape[0]) + self.pos_enc[:, 0, :]

        h1_masked = torch.cat([enc_cls_token_1, h1_masked], dim=1)
        h1 = self.encoder(h1_masked)

        # Restore the original patch order, fill hidden slots with the mask
        # token, and re-add positions for the predictor.
        h1 = self.restore_input(h1, ids_restore, mask)
        h1 += self.pos_enc

        h = self.predictor(h1)[:, 1:, :]  # drop the cls position
        h = self.decoder_pred(h)

        with torch.no_grad():
            # NOTE(review): teacher submodules follow self.train()/eval(), so
            # their dropout is active in training mode. Confirm this is intended.
            teacher_cls_token = repeat(self.teacher_class_token, '() p d -> b p d', b=x1.shape[0])
            teacher_h = self.teacher_patch_embedding(x1)

            teacher_h = torch.cat([teacher_cls_token, teacher_h], dim=1)
            teacher_h += self.teacher_pos_enc

            gt = self.teacher_encoder(teacher_h)
            gt = F.layer_norm(gt, (gt.size(-1),))
            gt = gt[:, 1:, :]

        return self.reconstruction_loss(h, gt, mask)

    def get_representations(self, x):
        """Encode ``x`` without masking and return the class-token output ``(batch, model_dim)``."""
        enc_cls_token = repeat(self.enc_class_token, '() p d -> b p d', b=x.shape[0])
        h1 = self.enc_patch_embedding(x)
        h1 = torch.cat([enc_cls_token, h1], dim=1)

        h1 = h1 + self.pos_enc
        h1 = self.encoder(h1)

        return h1[:, 0, :]

    def reconstruction_loss(self, prediction, gt, mask):
        """Mean absolute error over masked patches.

        ``prediction``/``gt`` are ``(N, L, D)`` and ``mask`` is ``(N, L)`` with
        1 for masked patches. Returns 0 (instead of NaN from 0/0) when nothing
        is masked.
        """
        loss = torch.abs(prediction - gt)
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        return (loss * mask).sum() / mask.sum().clamp_min(1)  # mean over removed patches

    def random_masking(self, x):
        """Keep a random ``1 - mask_ratio`` subset of tokens per sample.

        Returns ``(x_masked, mask, ids_restore)``. ``x_masked`` holds the kept
        tokens, ``mask`` is ``(N, L)`` with 0 = keep and 1 = removed in the
        original order, and ``ids_restore`` undoes the shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - self.mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # Sort noise for each sample: small is keep, large is remove.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first subset.
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Binary mask in shuffled order (0 is keep, 1 is remove), then unshuffle it.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def get_2d_embeddings(self, max_seq_len):
        """Return a ``(1, max_seq_len, model_dim)`` sinusoidal position table.

        Despite the name, this is a 1-D encoding over the sequence axis. Even
        columns hold sin and odd columns hold cos. An odd ``model_dim`` is
        supported.
        """
        position_id = torch.arange(0, max_seq_len).unsqueeze(1)
        frequencies = torch.pow(10000., -torch.arange(0, self.model_dim, 2, dtype=torch.float) / self.model_dim)
        positional_encodings = torch.zeros(max_seq_len, self.model_dim)
        positional_encodings[:, 0::2] = torch.sin(position_id * frequencies)
        # There are model_dim // 2 odd columns, one fewer than the frequencies
        # when model_dim is odd.
        positional_encodings[:, 1::2] = torch.cos(position_id * frequencies[: self.model_dim // 2])

        return positional_encodings.unsqueeze(0)

    def restore_input(self, x_masked, ids_restore, mask):
        """Re-insert mask tokens and restore the original patch order.

        ``x_masked`` is ``(N, 1 + kept, D)`` with the cls token first. The
        result is ``(N, 1 + L, D)``. ``mask`` is accepted but not used.
        """
        # Append one mask token per removed patch.
        mask_tokens = self.mask_token.repeat(x_masked.shape[0], ids_restore.shape[1] + 1 - x_masked.shape[1], 1)
        x_ = torch.cat([x_masked[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_masked.shape[2]))  # unshuffle

        return torch.cat([x_masked[:, :1, :], x_], dim=1)  # prepend cls token

    def count_parameters(self):
        """Number of trainable (``requires_grad``) parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def initialize_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm reset to identity."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.no_grad()
    def update_moving_average(self):
        """EMA step: ``teacher = tau * teacher + (1 - tau) * student``, in place.

        Updates the teacher encoder, the teacher patch embedding and the
        teacher class token. ``teacher_pos_enc`` is left alone: both position
        tables are fixed sinusoids that are never trained, so they stay
        identical.
        """
        module_pairs = (
            (self.encoder, self.teacher_encoder),
            (self.enc_patch_embedding, self.teacher_patch_embedding),
        )
        for online_module, target_module in module_pairs:
            for online, target in zip(online_module.parameters(), target_module.parameters()):
                target.mul_(self.tau).add_(online, alpha=1 - self.tau)

        # Update the whole tensor in place. Iterating the (1, 1, D) parameter
        # yields views, and rebinding `.data` on a view leaves the teacher
        # parameter itself unchanged.
        self.teacher_class_token.mul_(self.tau).add_(self.enc_class_token, alpha=1 - self.tau)


