from typing import List

import numpy as np
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import optim

from ._utils import complete_masking
from ._visiumformer_spatial import (
    CosineWarmupScheduler,
    VisiumformerSpatial,
)

# Token index used for the CLS token that FinetuneStamp.on_after_batch_transfer
# prepends to every tokenized gene sequence when the `pool` hyperparameter is 'cls'.
CLS_TOKEN = 2

class Adapter(nn.Module):
    """Bottleneck MLP used as a lightweight feature adapter.

    Two bias-free linear layers, each followed by a ReLU: the first maps
    ``c_in`` features down to ``c_in // reduction`` and the second maps them
    back up to ``c_in``, so the output has the same width as the input.

    Args:
        c_in: Number of input (and output) features.
        reduction: Divisor applied to ``c_in`` to get the bottleneck width.
    """

    def __init__(self, c_in, reduction=2):
        super().__init__()
        bottleneck = c_in // reduction
        squeeze = nn.Linear(c_in, bottleneck, bias=False)
        expand = nn.Linear(bottleneck, c_in, bias=False)
        # Keep the linear -> relu -> linear -> relu ordering so the parameters
        # stay registered as ``fc.0.weight`` and ``fc.2.weight``.
        self.fc = nn.Sequential(
            squeeze,
            nn.ReLU(inplace=True),
            expand,
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through the bottleneck; the final ReLU makes the output non-negative."""
        return self.fc(x)

class FinetuneStamp(pl.LightningModule):
    def __init__(self,
                 spot_config: dict,
                 visual_config: dict,
                 dim_output: int,
                 temperature: float,
                 extract_layers: List[int],
                 function_layers: str,
                 lr: float,
                 warmup: int,
                 max_epochs: int,
                 pool: str = 'mean',
                 without_context: bool = True,
                 margin: float = 0.5,
                 p: int = 2,
                 eps: float = 1e-6,
                 ):
        """Build the gene (spot) and image (patch) towers and freeze both backbones.

        Args:
            spot_config (dict): settings for the ``VisiumformerSpatial`` gene
                backbone. Keys read here: ``dim_model``, ``nheads``,
                ``dim_feedforward``, ``nlayers``, ``dropout``, ``batch_first``,
                ``n_tokens``, ``context_length``, ``autoregressive``, ``pool``,
                ``learnable_pe``, ``spatial_aware`` and ``pretrained_path``
                (``None`` skips loading pretrained gene weights).
            visual_config (dict): settings for the image backbone. Keys read
                here: ``pretrained_path`` (state dict loaded strictly into the
                timm ViT) and ``model_name``.
            dim_output (int): width of the shared embedding space produced by
                both projection heads.
            temperature (float): stored on ``self.temperature``; the loss in
                this class uses the learnable ``logit_scale`` instead.
            extract_layers (List[int]): indices of the gene-encoder layers whose
                outputs are collected in ``encode_gene``.
            function_layers (str): how the collected layer outputs are combined:
                ``"mean"``, ``"sum"`` or ``"concat"``.
            lr (float): learning rate passed to AdamW.
            warmup (int): ``warmup`` argument of ``CosineWarmupScheduler``.
            max_epochs (int): ``max_epochs`` argument of
                ``CosineWarmupScheduler`` (the scheduler is stepped per
                optimizer step).
            pool (str): when ``'cls'``, a CLS token is prepended to every gene
                sequence in ``on_after_batch_transfer``.
            without_context (bool): when True, ``encode_gene`` skips the first
                three token positions before averaging over the sequence.
            margin (float): only stored as a hyperparameter in this class.
            p (int): only stored as a hyperparameter in this class.
            eps (float): only stored as a hyperparameter in this class.
        """
        super().__init__()
        self.spot_backbone = VisiumformerSpatial(dim_model=spot_config['dim_model'],
                                                nheads=spot_config['nheads'],
                                                dim_feedforward=spot_config['dim_feedforward'],
                                                nlayers=spot_config['nlayers'],
                                                dropout=spot_config['dropout'],
                                                batch_first=spot_config['batch_first'],
                                                n_tokens=spot_config['n_tokens'],
                                                context_length=spot_config['context_length'],
                                                autoregressive=spot_config['autoregressive'],
                                                pool=spot_config['pool'],
                                                learnable_pe=spot_config['learnable_pe'],
                                                spatial_aware=spot_config['spatial_aware'],
                                                masking_p=0.0)

        self.spot_backbone.hparams.masking_p = 0.0

        # Width of the pooled gene feature returned by `encode_gene`. With
        # "concat" the outputs of every extracted layer are joined along the
        # feature axis, so the adapter and projection must be sized for
        # dim_model * (number of layers actually extracted); "mean" and "sum"
        # keep the backbone width.
        spot_feature_dim = self.spot_backbone.hparams.dim_model
        if function_layers == "concat":
            n_extracted = sum(
                1 for i in range(len(self.spot_backbone.encoder.layers))
                if i in extract_layers
            )
            spot_feature_dim = spot_feature_dim * n_extracted
        self.spot_projection = nn.Linear(spot_feature_dim, dim_output)

        if spot_config['pretrained_path'] is not None:
            # NOTE(review): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            checkpoint = torch.load(spot_config['pretrained_path'], map_location='cpu')
            model_state_dict = self.spot_backbone.state_dict()
            # Build a new state_dict that keeps only the entries whose name and
            # shape match the freshly constructed backbone.
            filtered_state_dict = {
                k: v for k, v in checkpoint['state_dict'].items()
                if k in model_state_dict and model_state_dict[k].shape == v.shape
            }
            # Load the filtered state_dict; strict=False tolerates the skipped keys.
            self.spot_backbone.load_state_dict(filtered_state_dict, strict=False)
            print("Did not load the following keys:", set(model_state_dict.keys()) - set(filtered_state_dict.keys()))

        visual_model = timm.create_model("vit_large_patch16_224", img_size=224, patch_size=16, init_values=1e-5, num_classes=0, dynamic_img_size=True)
        # NOTE(review): same trust caveat as above for this torch.load call.
        visual_model.load_state_dict(torch.load(visual_config['pretrained_path'], map_location="cpu"), strict=True)
        from models._prompt_learner import VisionPromptLearnerUni
        self.visual_backbone = VisionPromptLearnerUni(visual_model)
        self.patch_projection = nn.Linear(visual_model.embed_dim, dim_output)
        self.patch_adapter = Adapter(visual_model.embed_dim, 4)

        self.spot_adapter = Adapter(spot_feature_dim, 4)

        self.visual_backbone.train()
        self.visual_backbone_name = visual_config['model_name']

        self.temperature = temperature
        # Learnable log-scale for the similarity logits, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.save_hyperparameters(ignore=['backbone'])
        self.freeze_backbone()
    
    def freeze_backbone(self):
        """Turn off gradients for both backbones and remember which parameters are frozen.

        Prints the total and trainable parameter counts plus the names of the
        trainable parameters, then stores the names of every parameter that
        does not require gradients in ``self.frozen_keys``.
        """
        print("Freezing backbone")
        for backbone in (self.spot_backbone, self.visual_backbone):
            for weight in backbone.parameters():
                weight.requires_grad = False

        n_total = 0
        n_trainable = 0
        for weight in self.parameters():
            count = weight.numel()
            n_total += count
            if weight.requires_grad:
                n_trainable += count
        print(f"Total parameters: {n_total}")
        print(f"Trainable parameters: {n_trainable}")

        # Split parameter names by trainability in a single pass.
        trainable_names = []
        frozen_names = []
        for name, weight in self.named_parameters():
            if weight.requires_grad:
                trainable_names.append(name)
            else:
                frozen_names.append(name)
        print(f"Trainable layers including: {trainable_names}")
        self.frozen_keys = frozen_names
        
    def encode_gene(self, batch):
        """Encode the tokenized genes of a batch into one feature vector per sample.

        The batch is passed through ``complete_masking`` with masking
        probability 0.0, embedded, and run through every encoder layer of the
        gene backbone. Outputs of the layers listed in
        ``hparams.extract_layers`` are combined according to
        ``hparams.function_layers`` and then averaged over the sequence axis.

        Args:
            batch: mapping handed to ``complete_masking``; its result must
                provide ``'masked_indices'`` and ``'attention_mask'``.

        Returns:
            Tensor with the sequence axis averaged out (presumably
            ``batch x dim``, or ``batch x (dim * n_layers)`` for "concat").

        Raises:
            ValueError: if ``hparams.function_layers`` is not one of
                ``"mean"``, ``"sum"``, ``"concat"``, or if none of the encoder
                layers is selected by ``hparams.extract_layers``.
        """
        function_layers = self.hparams.function_layers
        if function_layers not in ("mean", "sum", "concat"):
            # Fail early with a clear message instead of hitting an unbound
            # local after the whole encoder has already been run.
            raise ValueError(
                f"Unsupported function_layers {function_layers!r}; "
                "expected 'mean', 'sum' or 'concat'."
            )

        backbone = self.spot_backbone
        # NOTE(review): the `+ 5` offset on n_tokens is passed through unchanged;
        # its meaning is defined by `complete_masking` - confirm there.
        batch = complete_masking(batch, 0.0, backbone.hparams.n_tokens + 5)
        masked_indices = batch['masked_indices'].to(backbone.device)
        attention_mask = batch['attention_mask'].to(backbone.device)
        token_embedding = backbone.embeddings(masked_indices)

        if backbone.hparams.learnable_pe:
            pos_embedding = backbone.positional_embedding(backbone.pos.to(token_embedding.device))
            embeddings = backbone.dropout(token_embedding + pos_embedding)
        else:
            embeddings = backbone.positional_embedding(token_embedding)

        # Collect the output of every requested encoder layer.
        hidden_repr = []
        for i, layer in enumerate(backbone.encoder.layers):
            embeddings = layer(
                embeddings,
                is_causal=backbone.autoregressive,
                src_key_padding_mask=attention_mask,
            )  # presumably bs x seq_len x dim
            if i in self.hparams.extract_layers:
                hidden_repr.append(embeddings)

        if not hidden_repr:
            raise ValueError(
                f"extract_layers={self.hparams.extract_layers!r} does not select any of the "
                f"{len(backbone.encoder.layers)} encoder layers."
            )

        if function_layers == "mean":
            transformer_output = torch.stack(hidden_repr, dim=-1).mean(dim=-1)
        elif function_layers == "sum":
            transformer_output = torch.stack(hidden_repr, dim=-1).sum(dim=-1)
        else:  # "concat": join the layer outputs along the feature axis
            transformer_output = torch.cat(hidden_repr, dim=2)

        if self.hparams.without_context:
            # Skip the first three positions before pooling
            # (presumably context tokens - TODO confirm against the tokenizer).
            return transformer_output[:, 3:, :].mean(1)
        return transformer_output.mean(1)
            
    def encode_visual(self, batch):
        """Run ``batch['images']`` through the visual backbone.

        The backbone is expected to return exactly two values; only the first
        one (the patch features) is returned, the second is discarded.
        """
        patch_features, _unused = self.visual_backbone(batch['images'])
        return patch_features
    
    def forward(self, batch):
        """Embed the gene and image sides of a batch into the shared space.

        Each backbone feature is blended half-and-half with the output of its
        adapter and then linearly projected.

        Returns:
            Tuple ``(spot_embeddings, patch_embeddings)`` (not normalised).
        """
        gene_feats = self.encode_gene(batch)
        image_feats = self.encode_visual(batch)

        gene_adapted = self.spot_adapter(gene_feats)
        image_adapted = self.patch_adapter(image_feats)

        # Equal-weight mix of the raw backbone feature and its adapted version.
        gene_mixed = 0.5 * gene_feats + 0.5 * gene_adapted
        image_mixed = 0.5 * image_feats + 0.5 * image_adapted

        return self.spot_projection(gene_mixed), self.patch_projection(image_mixed)
    
    def training_step(self, batch, *args, **kwargs):
        """Return the symmetric patch<->spot contrastive loss, logged as ``train_loss``."""
        return self._contrastive_step(batch, 'train_loss')

    def validation_step(self, batch, *args, **kwargs):
        """Return the symmetric patch<->spot contrastive loss, logged as ``val_loss``."""
        return self._contrastive_step(batch, 'val_loss')

    def _contrastive_step(self, batch, log_name):
        """Compute and log the contrastive loss shared by training and validation.

        Both embedding sets are L2-normalised, their pairwise cosine
        similarities are scaled by ``exp(self.logit_scale)``, and the i-th
        patch is treated as the positive for the i-th spot (and vice versa).
        The loss is the average of the two cross-entropies.

        Args:
            batch: batch passed straight to ``self.forward``.
            log_name: key under which the loss is logged.

        Returns:
            Scalar loss tensor.
        """
        spot_embeddings, patch_embeddings = self.forward(batch)
        spot_embeddings = F.normalize(spot_embeddings, dim=-1)
        patch_embeddings = F.normalize(patch_embeddings, dim=-1)

        # Cosine similarity as logits, patch <-> spot.
        logit_scale = self.logit_scale.exp()
        logits_per_patch = logit_scale * patch_embeddings @ spot_embeddings.t()
        logits_per_spot = logits_per_patch.t()

        # Matching pairs sit on the diagonal; build the targets on the same
        # device as the logits so the loss never mixes devices.
        targets = torch.arange(
            logits_per_patch.shape[0],
            device=logits_per_patch.device,
            dtype=torch.long,
        )
        loss = (
            F.cross_entropy(logits_per_patch, targets) +
            F.cross_entropy(logits_per_spot, targets)
        ) / 2

        self.log(log_name, loss, sync_dist=True, prog_bar=True, reduce_fx='mean')

        return loss
    
    def on_after_batch_transfer(self, batch, dataloader_idx: int):
        """Optionally prepend the CLS token, then clip gene tokens to the context length.

        When ``hparams.pool`` is ``'cls'`` a column filled with ``CLS_TOKEN``
        is put in front of ``batch['tokenized_gene']``. In every case the
        sequence is then truncated to the gene backbone's ``context_length``.
        The batch is modified in place and returned.
        """
        data_key = 'tokenized_gene'
        tokens = batch[data_key]

        if self.hparams.pool == 'cls':
            # One CLS entry per row, placed at the start of the sequence.
            cls_column = torch.full(
                (tokens.shape[0], 1),
                CLS_TOKEN,
                dtype=torch.int32,
                device=tokens.device,
            )
            tokens = torch.cat((cls_column, tokens), dim=1)

        max_len = self.spot_backbone.hparams.context_length
        batch[data_key] = tokens[:, :max_len]

        return batch
    
    def configure_optimizers(self):
        """Create AdamW over the trainable parameters plus a per-step warmup scheduler.

        Returns:
            ``([optimizer], [scheduler_config])`` where the scheduler config
            asks for the scheduler to be stepped on the ``'step'`` interval.
        """
        trainable = [weight for weight in self.parameters() if weight.requires_grad]
        optimizer = optim.AdamW(trainable, lr=self.hparams.lr, weight_decay=0.001)

        scheduler = CosineWarmupScheduler(
            optimizer,
            warmup=self.hparams.warmup,
            max_epochs=self.hparams.max_epochs,
        )
        scheduler_config = {'scheduler': scheduler, 'interval': 'step'}
        return [optimizer], [scheduler_config]

        
    def initialize_weights(self):
        """Re-draw every parameter whose name contains ``'weight'`` from N(0, 0.02).

        NOTE: the name check is the only filter, so this also overwrites
        parameters that do not require gradients (e.g. frozen backbone weights).
        """
        for name, param in self.named_parameters():
            if 'weight' not in name:
                continue
            init.normal_(param, mean=0.0, std=0.02)
    
    @staticmethod
    def cross_entropy(preds, targets, reduction='none'):
        """Cross-entropy between logits and (possibly soft) target distributions.

        Declared as a static method so it can be called both on the class and
        on an instance; the function takes no ``self``.

        Args:
            preds: logits; log-softmax is taken over the last dimension.
            targets: target weights with the same shape as ``preds``.
            reduction: ``'none'`` returns the per-row loss (summed over
                dimension 1), ``'mean'`` returns its mean.

        Returns:
            The loss tensor, reduced as requested.

        Raises:
            ValueError: if ``reduction`` is neither ``'none'`` nor ``'mean'``
                (previously this silently returned ``None``).
        """
        loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
        if reduction == "none":
            return loss
        if reduction == "mean":
            return loss.mean()
        raise ValueError(
            f"Unsupported reduction {reduction!r}; expected 'none' or 'mean'."
        )
        
    def on_save_checkpoint(self, checkpoint):
        """Drop the frozen parameters from the checkpoint before it is written.

        Every name recorded in ``self.frozen_keys`` is removed from
        ``checkpoint['state_dict']`` (missing names are ignored), so only the
        remaining entries are saved. The same checkpoint object is returned.
        """
        for frozen_name in self.frozen_keys:
            checkpoint['state_dict'].pop(frozen_name, None)
        return checkpoint