import torch
import torch.nn as nn
import torch.nn.functional as F
from lgdea.models.vision_model import ImageEncoder
from lgdea.models.text_model import BertEncoder
from lgdea.models.prototype_model import DiagnosticSemanticPrototypeModel


class Stage2LesionSemanticModel(nn.Module):
    """Stage-2 model relating learnable image queries ("leis") to text prototypes.

    A set of learnable queries cross-attends to one feature map returned by the
    vision encoder. Each attended query is projected and scored against
    ``num_prototypes`` prototype slots. Report evidences are encoded with a text
    encoder (optionally taken from a stage-1 ``DiagnosticSemanticPrototypeModel``
    checkpoint) and softly assigned to the same prototypes. This gives a
    text-side prototype distribution per sample.
    """

    @staticmethod
    def _cfg_value(section, name, default):
        """Return ``section.<name>``, or ``default`` when it is missing or None."""
        value = getattr(section, name, default)
        return default if value is None else value

    @staticmethod
    def _load_stage1_model(cfg, checkpoint_path):
        """Build a stage-1 model and load weights from ``checkpoint_path``.

        Accepts either a raw state dict or a dict holding one under
        ``'state_dict'``. In the latter case a leading ``'model.'`` prefix is
        stripped from each key. Loading is non-strict, but missing keys are
        reported with a warning, because they mean those weights stay at their
        random initialisation.
        """
        import warnings

        stage1_model = DiagnosticSemanticPrototypeModel(cfg)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if 'state_dict' in checkpoint:
            prefix = 'model.'
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in checkpoint['state_dict'].items()
            }
        else:
            state_dict = checkpoint

        incompatible = stage1_model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            warnings.warn(
                f"Stage-1 checkpoint '{checkpoint_path}' is missing "
                f"{len(incompatible.missing_keys)} parameter(s) "
                f"(e.g. {incompatible.missing_keys[:5]}); "
                f"{len(incompatible.unexpected_keys)} unexpected key(s) were ignored."
            )
        return stage1_model

    def __init__(self, cfg, stage1_checkpoint_path=None):
        """Build the model.

        Args:
            cfg: Config object. Reads ``cfg.model.text.embedding_dim`` and the
                optional ``cfg.model`` fields ``num_prototypes`` (default 64),
                ``temperature`` (0.1), ``consistency_k`` (5), ``num_queries``
                (32) and ``num_attention_heads`` (8). A field set to ``None``
                falls back to its default. ``cfg`` is also passed to the
                vision/text encoder constructors.
            stage1_checkpoint_path: Optional path to a stage-1 checkpoint. When
                given, the text encoder, prototypes and text projection are
                taken from a ``DiagnosticSemanticPrototypeModel`` loaded from
                it. Otherwise they are freshly initialised.
        """
        super().__init__()

        self.embedding_dim = cfg.model.text.embedding_dim
        self.num_prototypes = self._cfg_value(cfg.model, 'num_prototypes', 64)
        self.temperature = self._cfg_value(cfg.model, 'temperature', 0.1)
        self.consistency_k = self._cfg_value(cfg.model, 'consistency_k', 5)

        self.vision_encoder = ImageEncoder(cfg)
        vision_embed_dim = self.vision_encoder.embed_dim

        if stage1_checkpoint_path:
            stage1_model = self._load_stage1_model(cfg, stage1_checkpoint_path)
            self.text_encoder = stage1_model.text_encoder
            self.prototypes = stage1_model.prototypes
            self.text_projection = stage1_model.projection
        else:
            self.text_encoder = BertEncoder(cfg)
            self.prototypes = nn.Parameter(
                torch.randn(self.num_prototypes, self.embedding_dim)
            )
            nn.init.xavier_uniform_(self.prototypes)
            self.text_projection = nn.Linear(self.embedding_dim, self.embedding_dim)

        # Keep the prototype count in sync with the prototype tensor actually in
        # use. The stage-1 model builds its own, and the image-side head must
        # emit one logit per text-side prototype for the KL terms to line up.
        self.num_prototypes = self.prototypes.shape[0]

        # L learnable queries, living in the vision-token embedding space.
        self.num_queries = self._cfg_value(cfg.model, 'num_queries', 32)
        self.queries = nn.Parameter(
            torch.randn(self.num_queries, vision_embed_dim)
        )
        nn.init.xavier_uniform_(self.queries)

        num_attention_heads = self._cfg_value(cfg.model, 'num_attention_heads', 8)
        self.query_attention = nn.MultiheadAttention(
            embed_dim=vision_embed_dim,
            num_heads=num_attention_heads,
            batch_first=True,
        )

        self.vision_projection = nn.Linear(vision_embed_dim, self.embedding_dim)
        self.prototype_head = nn.Linear(self.embedding_dim, self.num_prototypes)

    def forward(self, images, input_ids, attention_mask, token_type_ids, num_evidences):
        """Compute query embeddings and image/text prototype distributions.

        Args:
            images: Image batch passed to ``vision_encoder.vit_forward``. Its
                first dimension is the batch size B.
            input_ids, attention_mask, token_type_ids: Tokenised evidences,
                shaped [B, max_evidences, max_length].
            num_evidences: Number of valid evidences per sample, shape [B].

        Returns:
            dict with
              ``'leis'``: [B, L, embedding_dim] projected query outputs;
              ``'image_prototype_dist'``: [B, L, K] per-query prototype softmax;
              ``'image_level_dist'``: [B, K] mean of the above over the L queries;
              ``'text_prototype_dist'``: [B, K] evidence-averaged assignments.
        """
        batch_size = images.shape[0]

        # vit_forward returns four feature tensors; only the third is used as
        # attention keys/values. It is presumably [B, N, vision_embed_dim]
        # (batch_first) — TODO confirm against ImageEncoder.
        _, _, feat_8, _ = self.vision_encoder.vit_forward(images)

        queries = self.queries.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, vision_embed_dim]
        leis, _ = self.query_attention(query=queries, key=feat_8, value=feat_8)

        leis_projected = self.vision_projection(leis)  # [B, L, embedding_dim]
        leis_logits = self.prototype_head(leis_projected)  # [B, L, K]
        image_prototype_dist = F.softmax(leis_logits / self.temperature, dim=-1)  # [B, L, K]
        image_level_dist = image_prototype_dist.mean(dim=1)  # [B, K]

        text_prototype_dist = self._compute_text_prototype_distribution(
            input_ids, attention_mask, token_type_ids, num_evidences
        )  # [B, K]

        return {
            'leis': leis_projected,
            'image_prototype_dist': image_prototype_dist,
            'image_level_dist': image_level_dist,
            'text_prototype_dist': text_prototype_dist,
        }

    def _compute_text_prototype_distribution(self, input_ids, attention_mask, token_type_ids, num_evidences):
        """Average the soft prototype assignments of each sample's valid evidences.

        Every evidence is encoded, projected and softly assigned to the
        prototypes by a temperature-scaled softmax over cosine similarities.
        The assignments of the first ``num_evidences[b]`` evidences of sample
        ``b`` (capped at ``max_evidences``) are then averaged. A sample with no
        valid evidence gets an all-zero row.

        Args:
            input_ids, attention_mask, token_type_ids: [B, max_evidences, max_length].
            num_evidences: [B] count of valid (non-padding) evidences.

        Returns:
            [B, K] tensor.
        """
        B, max_evidences, max_length = input_ids.shape

        # reshape (not view) so non-contiguous token tensors are accepted.
        _, evidence_embeddings_flat, _ = self.text_encoder(
            ids=input_ids.reshape(-1, max_length),
            attn_mask=attention_mask.reshape(-1, max_length),
            token_type=token_type_ids.reshape(-1, max_length),
        )
        evidence_embeddings = self.text_projection(evidence_embeddings_flat).reshape(
            B, max_evidences, self.embedding_dim
        )

        evidence_norm = F.normalize(evidence_embeddings, p=2, dim=-1)
        prototype_norm = F.normalize(self.prototypes, p=2, dim=-1)
        similarity = torch.matmul(evidence_norm, prototype_norm.t()) / self.temperature
        prototype_assignments = F.softmax(similarity, dim=-1)  # [B, max_evidences, K]

        positions = torch.arange(max_evidences, device=num_evidences.device)
        valid = (positions.unsqueeze(0) < num_evidences.unsqueeze(1)).float()  # [B, max_evidences]

        # Divide by the number of positions actually kept by the mask, not by
        # the raw count. This keeps each row a proper average even when
        # num_evidences exceeds the padded length. The clamp leaves
        # zero-evidence rows at 0.
        valid_count = valid.sum(dim=1, keepdim=True).clamp(min=1.0)  # [B, 1]
        return (prototype_assignments * valid.unsqueeze(-1)).sum(dim=1) / valid_count

    def compute_distillation_loss(self, image_prototype_dist, text_prototype_dist, num_evidences):
        """KL divergence between text-side and image-level prototype distributions.

        Args:
            image_prototype_dist: [B, L, K] per-query prototype distributions.
                They are averaged over L to get an image-level distribution.
            text_prototype_dist: [B, K] text-side distribution.
            num_evidences: Unused; kept for call-site compatibility.

        Returns:
            Scalar ``F.kl_div(..., reduction='batchmean')``. This is KL(text || image)
            between the transformed distributions, averaged over the batch.
        """
        image_level_dist = image_prototype_dist.mean(dim=1)  # [B, K]

        # NOTE(review): both inputs are already probability distributions, yet
        # each is divided by the temperature and passed through (log_)softmax a
        # second time. The KL is therefore computed between softmax(p / T)
        # transforms, not between the distributions themselves — confirm this
        # is intended.
        log_image_dist = F.log_softmax(image_level_dist / self.temperature, dim=-1)
        text_dist = F.softmax(text_prototype_dist / self.temperature, dim=-1)

        return F.kl_div(log_image_dist, text_dist, reduction='batchmean')

    def compute_lesion_consistency_loss(self, leis, image_prototype_dist):
        """Encourage similar query embeddings to share prototype distributions.

        All B*L embeddings are pooled across the batch. For each one, its
        ``min(consistency_k, B*L - 1)`` most cosine-similar *other* embeddings
        are found. The loss is the similarity-softmax-weighted
        KL(neighbour || current) between their prototype distributions,
        averaged over the B*L embeddings. Neighbour distributions are detached,
        so gradients flow through the current distribution and the neighbour
        weights only.

        Args:
            leis: [B, L, embedding_dim] query embeddings.
            image_prototype_dist: [B, L, K] per-query prototype distributions.

        Returns:
            Scalar loss tensor.
        """
        B, L, embedding_dim = leis.shape
        K = image_prototype_dist.shape[-1]
        num_items = B * L

        leis_norm = F.normalize(leis.reshape(-1, embedding_dim), p=2, dim=-1)
        similarity_matrix = torch.matmul(leis_norm, leis_norm.t())  # [B*L, B*L]

        # Exclude self-matches from the neighbour search.
        self_mask = torch.eye(num_items, device=leis.device, dtype=torch.bool)
        similarity_matrix = similarity_matrix.masked_fill(self_mask, float('-inf'))

        num_neighbors = min(self.consistency_k, num_items - 1)
        topk_similarities, topk_indices = torch.topk(similarity_matrix, k=num_neighbors, dim=-1)  # [B*L, k]
        neighbor_weights = F.softmax(topk_similarities / self.temperature, dim=-1)  # [B*L, k]

        prototype_dist_flat = image_prototype_dist.reshape(-1, K)  # [B*L, K]
        neighbor_dists = prototype_dist_flat[topk_indices].detach()  # [B*L, k, K]

        # Add eps and renormalise both sides so log() and the KL stay finite.
        eps = 1e-8
        current_dist = prototype_dist_flat + eps
        current_dist = current_dist / current_dist.sum(dim=-1, keepdim=True)
        neighbor_dists = neighbor_dists + eps
        neighbor_dists = neighbor_dists / neighbor_dists.sum(dim=-1, keepdim=True)

        log_current = torch.log(current_dist).unsqueeze(1).expand(-1, neighbor_dists.size(1), -1)  # [B*L, k, K]

        # F.kl_div(log q, p) is p * (log p - log q) elementwise, i.e. KL(p || q)
        # with p = neighbour and q = current distribution.
        kl_per_neighbor = F.kl_div(log_current, neighbor_dists, reduction='none').sum(dim=-1)  # [B*L, k]
        weighted_kl = (kl_per_neighbor * neighbor_weights).sum(dim=-1)  # [B*L]

        return weighted_kl.mean()
