


import torch
import torch.nn as nn
import torch.nn.functional as F
from lgdea.models.prototype_model import DiagnosticSemanticPrototypeModel
from lgdea.models.stage2_model import Stage2LesionSemanticModel


class Stage3RelationPropagationModel(nn.Module):
    """Stage-3 model: propagate image/report relations over prototype graphs.

    The constructor rebuilds the stage-1 and stage-2 models from their
    checkpoints and borrows their sub-modules (text encoder, prototypes,
    projections, vision encoder, prototype head, lesion queries and query
    attention).  ``forward`` turns reports and images into distributions over
    the shared prototypes, builds one similarity graph per modality from those
    distributions, smooths the supplied image/report pairing matrix over both
    graphs, and returns pooled embeddings that
    ``compute_weighted_infonce_loss`` can contrast using the smoothed matrix
    as soft targets.
    """

    def __init__(self, cfg, stage1_checkpoint_path=None, stage2_checkpoint_path=None):
        """Assemble the model from a stage-1 and a stage-2 checkpoint.

        Args:
            cfg: Configuration object.  ``cfg.model``, ``cfg.model.text`` and
                ``cfg.model.vision`` are read.  ``cfg.model.num_queries`` is
                overwritten in place when the value can be recovered from the
                stage-2 checkpoint.
            stage1_checkpoint_path: Path of the stage-1 checkpoint (required).
            stage2_checkpoint_path: Path of the stage-2 checkpoint (required).

        Raises:
            ValueError: If either checkpoint path is missing or empty.
        """
        super().__init__()

        self.embedding_dim = cfg.model.text.embedding_dim
        self.num_prototypes = getattr(cfg.model, 'num_prototypes', 64)
        self.temperature = getattr(cfg.model, 'temperature', 0.1)
        self.propagation_alpha = getattr(cfg.model, 'propagation_alpha', 0.5)
        self.propagation_steps = getattr(cfg.model, 'propagation_steps', 2)
        self.info_nce_temperature = getattr(cfg.model, 'info_nce_temperature', 0.07)
        self.normalize_entropy = getattr(cfg.model, 'normalize_entropy', True)

        # Validate both paths before any expensive model construction or
        # checkpoint deserialisation happens.
        if not stage1_checkpoint_path:
            raise ValueError("stage1_checkpoint_path must be provided")
        if not stage2_checkpoint_path:
            raise ValueError("stage2_checkpoint_path must be provided")

        # ---- Stage 1: text encoder, prototypes and text projection ----
        stage1_model = DiagnosticSemanticPrototypeModel(cfg)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(stage1_checkpoint_path, map_location='cpu')
        self._load_checkpoint_weights(stage1_model, checkpoint, "Stage1")

        self.text_encoder = stage1_model.text_encoder
        self.prototypes = stage1_model.prototypes
        self.text_projection = stage1_model.projection

        freeze_bert = getattr(cfg.model.text, 'freeze_bert', False)
        if not freeze_bert:
            for param in self.text_encoder.parameters():
                param.requires_grad = True
            print("Stage3: Text encoder is trainable")
        else:
            # NOTE(review): nothing is frozen here; this branch presumably
            # relies on the stage-1 model having frozen its encoder itself
            # -- confirm.
            print("Stage3: Text encoder is frozen")

        # ---- Stage 2: vision encoder, lesion queries and prototype head ----
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(stage2_checkpoint_path, map_location='cpu')
        num_queries_from_ckpt = self._num_queries_from_checkpoint(checkpoint)

        if num_queries_from_ckpt is not None:
            # The stage-2 model is built from cfg, so the query count stored in
            # the checkpoint has to be written into cfg before construction.
            cfg.model.num_queries = num_queries_from_ckpt
            print(f"Using num_queries={num_queries_from_ckpt} from Stage2 checkpoint")
        else:
            # NOTE(review): this value is only reported, it is not written
            # back to cfg; the stage-2 model resolves its own default.
            num_queries = getattr(cfg.model, 'num_queries', 8)
            if num_queries is None:
                num_queries = 8
            print(f"Using num_queries={num_queries} from config (could not read from checkpoint)")

        stage2_model = Stage2LesionSemanticModel(cfg, stage1_checkpoint_path=stage1_checkpoint_path)
        self._load_checkpoint_weights(stage2_model, checkpoint, "Stage2")

        self.vision_encoder = stage2_model.vision_encoder
        self.vision_projection = stage2_model.vision_projection
        self.prototype_head = stage2_model.prototype_head
        self.num_queries = stage2_model.num_queries
        self.queries = stage2_model.queries
        self.query_attention = stage2_model.query_attention

        freeze_cnn = getattr(cfg.model.vision, 'freeze_cnn', False)
        if not freeze_cnn:
            # freeze_cnn is False in the config: allow the image encoder to train.
            for param in self.vision_encoder.parameters():
                param.requires_grad = True
            print("Stage3: Vision encoder is trainable")
        else:
            # NOTE(review): as for the text encoder, nothing is frozen here.
            print("Stage3: Vision encoder is frozen")

    @staticmethod
    def _load_checkpoint_weights(module, checkpoint, tag):
        """Load ``checkpoint`` into ``module`` non-strictly and report mismatches.

        When the checkpoint has a ``'state_dict'`` entry, a leading ``'model.'``
        is stripped from its keys (presumably the attribute name of a training
        wrapper -- confirm); otherwise the checkpoint itself is used as the
        state dict.

        Args:
            module: Module receiving the weights.
            checkpoint: Mapping returned by ``torch.load``.
            tag: Short label used in the diagnostic message.
        """
        prefix = 'model.'
        if 'state_dict' in checkpoint:
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in checkpoint['state_dict'].items()
            }
        else:
            state_dict = checkpoint

        result = module.load_state_dict(state_dict, strict=False)

        # strict=False hides key mismatches (e.g. a wrong prefix would load
        # nothing at all), so make them visible.
        missing = getattr(result, 'missing_keys', None) or []
        unexpected = getattr(result, 'unexpected_keys', None) or []
        if missing or unexpected:
            print(f"Stage3: {tag} checkpoint loaded with {len(missing)} missing "
                  f"and {len(unexpected)} unexpected keys")

    @staticmethod
    def _num_queries_from_checkpoint(checkpoint):
        """Recover the lesion-query count stored in a stage-2 checkpoint.

        ``checkpoint['hyper_parameters']`` is consulted first (either
        ``['model']['num_queries']`` or a top-level ``'num_queries'``).  If
        that yields nothing, the count is taken from the first dimension of the
        first 2-D state-dict entry whose key contains ``'queries'``.

        Returns:
            The number of queries, or ``None`` when it cannot be determined.
        """
        num_queries = None

        if 'hyper_parameters' in checkpoint:
            hyper_params = checkpoint['hyper_parameters']
            if isinstance(hyper_params, dict):
                model_params = hyper_params.get('model')
                if 'model' in hyper_params and isinstance(model_params, dict):
                    if 'num_queries' in model_params:
                        num_queries = model_params['num_queries']
                elif 'num_queries' in hyper_params:
                    num_queries = hyper_params['num_queries']

        if num_queries is not None:
            return num_queries

        state_dict = checkpoint.get('state_dict', checkpoint)
        for key, value in state_dict.items():
            # '.num_queries' entries hold the count itself, not the query
            # tensor, so they are skipped.
            if 'queries' not in key or key.endswith('.num_queries'):
                continue
            queries_shape = value.shape
            if len(queries_shape) == 2:
                num_queries = queries_shape[0]
                print(f"Inferred num_queries={num_queries} from checkpoint queries shape")
                break

        return num_queries

    def forward(self, images, input_ids, attention_mask, token_type_ids, num_evidences, paired_matrix):
        """Compute prototype distributions, the propagated relation matrix and global embeddings.

        Args:
            images: Image batch (first dimension NI), or ``None``.
            input_ids: Token ids of shape [NT, max_evidences, max_length], or ``None``.
            attention_mask: Attention mask with the same shape as ``input_ids``.
            token_type_ids: Token type ids with the same shape as ``input_ids``.
            num_evidences: Number of valid evidences per report, shape [NT].
            paired_matrix: Initial image/report relation matrix, shape [NI, NT].

        Returns:
            dict with keys ``'image_prototype_dist'`` ([NI, K] or ``None``),
            ``'text_prototype_dist'`` ([NT, K] or ``None``),
            ``'relation_matrix'`` ([NI, NT]; ``paired_matrix`` unchanged when
            one modality is absent), ``'image_global_emb'`` and
            ``'text_global_emb'`` (``None`` for an absent modality).
        """
        NI = images.shape[0] if images is not None else 0
        NT = input_ids.shape[0] if input_ids is not None else 0

        # NOTE(review): each encoder runs twice per call (once for the
        # prototype distribution, once for the global embedding).  Sharing the
        # encoder outputs would halve the cost but changes behaviour when the
        # encoders are stochastic (e.g. dropout), so it is left as is.
        text_prototype_dist = None
        if NT > 0:
            text_prototype_dist = self._compute_text_prototype_distribution(
                input_ids, attention_mask, token_type_ids, num_evidences
            )  # [NT, K]

        image_prototype_dist = None
        if NI > 0:
            image_prototype_dist = self._compute_image_prototype_distribution(images)  # [NI, K]

        if NI > 0 and NT > 0:
            ATT = self._build_adjacency_matrix(text_prototype_dist)  # [NT, NT]
            AII = self._build_adjacency_matrix(image_prototype_dist)  # [NI, NI]
            ST = self._degree_normalize(ATT)  # [NT, NT]
            SI = self._degree_normalize(AII)  # [NI, NI]
            relation_matrix = self._relation_propagation(
                paired_matrix, SI, ST, self.propagation_alpha, self.propagation_steps
            )  # [NI, NT]
        else:
            # Propagation needs both graphs; fall back to the raw pairing.
            relation_matrix = paired_matrix

        image_global_emb = None
        if NI > 0:
            image_global_emb = self._compute_image_global_embedding(images, image_prototype_dist)

        text_global_emb = None
        if NT > 0:
            text_global_emb = self._compute_text_global_embedding(
                input_ids, attention_mask, token_type_ids, num_evidences, text_prototype_dist
            )

        return {
            'image_prototype_dist': image_prototype_dist,
            'text_prototype_dist': text_prototype_dist,
            'relation_matrix': relation_matrix,
            'image_global_emb': image_global_emb,
            'text_global_emb': text_global_emb
        }

    def _encode_evidences(self, input_ids, attention_mask, token_type_ids):
        """Encode every evidence sentence and project it to ``embedding_dim``.

        Args:
            input_ids: Token ids of shape [NT, max_evidences, max_length].
            attention_mask: Same shape as ``input_ids``.
            token_type_ids: Same shape as ``input_ids``.

        Returns:
            Projected evidence embeddings of shape [NT, max_evidences, embedding_dim].
        """
        NT, max_evidences, max_length = input_ids.shape

        # reshape (not view) so non-contiguous inputs are accepted as well.
        # The text encoder returns three values; only the second is used.
        _, evidence_embeddings_flat, _ = self.text_encoder(
            ids=input_ids.reshape(-1, max_length),
            attn_mask=attention_mask.reshape(-1, max_length),
            token_type=token_type_ids.reshape(-1, max_length)
        )

        evidence_embeddings_flat = self.text_projection(evidence_embeddings_flat)
        return evidence_embeddings_flat.reshape(NT, max_evidences, self.embedding_dim)

    @staticmethod
    def _masked_evidence_mean(values, num_evidences):
        """Average ``values`` over the first ``num_evidences[i]`` slots of each row.

        The divisor is the number of slots actually kept, so a report that
        declares more evidences than there are slots is still averaged
        correctly.  Rows without any valid evidence yield zeros.

        Args:
            values: Tensor of shape [N, max_evidences, C].
            num_evidences: Tensor of shape [N] with the valid-evidence counts.

        Returns:
            Tensor of shape [N, C].
        """
        max_evidences = values.shape[1]
        # Build the mask on the feature device so counts may live elsewhere.
        counts = num_evidences.to(values.device)
        slots = torch.arange(max_evidences, device=values.device)
        mask = (slots.unsqueeze(0) < counts.unsqueeze(1)).float().unsqueeze(-1)  # [N, max_evidences, 1]

        # 1e-8 avoids division by zero for rows with no valid evidence.
        return (values * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8)

    def _compute_text_prototype_distribution(self, input_ids, attention_mask, token_type_ids, num_evidences):
        """Return the per-report distribution over prototypes, shape [NT, K].

        Each evidence is softly assigned to the prototypes with a
        temperature-scaled softmax over cosine similarities; the assignments of
        the valid evidences of a report are then averaged.
        """
        evidence_embeddings = self._encode_evidences(input_ids, attention_mask, token_type_ids)

        evidence_norm = F.normalize(evidence_embeddings, p=2, dim=-1)
        prototype_norm = F.normalize(self.prototypes, p=2, dim=-1)

        similarity = torch.matmul(evidence_norm, prototype_norm.t()) / self.temperature
        prototype_assignments = F.softmax(similarity, dim=-1)  # [NT, max_evidences, K]

        return self._masked_evidence_mean(prototype_assignments, num_evidences)  # [NT, K]

    def _attend_lesion_queries(self, token_features, batch_size):
        """Cross-attend the learned queries over ``token_features``.

        Args:
            token_features: Key/value features handed to ``query_attention``.
            batch_size: Number of images; the queries are expanded to it.

        Returns:
            The first output of ``query_attention`` (one embedding per query
            and image).
        """
        queries_expanded = self.queries.unsqueeze(0).expand(batch_size, -1, -1)  # [NI, L, vision_embed_dim]

        leis, _ = self.query_attention(
            query=queries_expanded,
            key=token_features,
            value=token_features
        )
        return leis

    def _compute_image_prototype_distribution(self, images):
        """Return the per-image distribution over prototypes, shape [NI, K].

        The lesion queries attend over the fourth output of
        ``vision_encoder.vit_forward``; each attended query is projected,
        scored by the prototype head and soft-maxed, and the query
        distributions are averaged.
        """
        NI = images.shape[0]
        # NOTE(review): this path attends over the 4th vit_forward output while
        # _compute_image_global_embedding attends over the 3rd -- confirm the
        # difference is intentional.
        _, _, _, patch_features = self.vision_encoder.vit_forward(images)

        leis = self._attend_lesion_queries(patch_features, NI)

        leis_projected = self.vision_projection(leis)
        leis_logits = self.prototype_head(leis_projected)  # [NI, L, K]
        leis_prototype_dist = F.softmax(leis_logits / self.temperature, dim=-1)  # [NI, L, K]

        return leis_prototype_dist.mean(dim=1)  # [NI, K]

    def _build_adjacency_matrix(self, prototype_dist):
        """Return the non-negative cosine-similarity matrix [N, N] of the rows of ``prototype_dist``."""
        dist_norm = F.normalize(prototype_dist, p=2, dim=-1)
        adjacency = torch.matmul(dist_norm, dist_norm.t())  # [N, N]

        return adjacency.clamp(min=0)

    def _degree_normalize(self, adjacency):
        """Row-normalise ``adjacency`` by its row sums (degrees)."""
        degree = adjacency.sum(dim=-1, keepdim=True)  # [N, 1]
        degree = degree + 1e-8  # avoid division by zero

        return adjacency / degree  # [N, N]

    def _relation_propagation(self, Y, SI, ST, alpha, steps):
        """Smooth the relation matrix over the image and text graphs.

        Iterates ``P <- alpha * SI @ P @ ST + (1 - alpha) * Y`` for ``steps``
        rounds, starting from ``P = Y``.

        Args:
            Y: Initial relation matrix [NI, NT]; bool/integer matrices are
                converted to the dtype of ``SI``.
            SI: Normalised image graph [NI, NI].
            ST: Normalised text graph [NT, NT].
            alpha: Weight of the propagated term.
            steps: Number of propagation rounds.

        Returns:
            The propagated relation matrix [NI, NT] (a copy of ``Y`` when
            ``steps`` is 0).
        """
        # matmul cannot mix float graphs with an integer/bool pairing matrix.
        if not torch.is_floating_point(Y):
            Y = Y.to(SI.dtype)
        Y = Y.to(SI.device)

        P = Y.clone()  # P(0) = Y

        for _ in range(steps):
            P = alpha * torch.matmul(torch.matmul(SI, P), ST) + (1 - alpha) * Y

        return P

    def _compute_image_global_embedding(self, images, image_prototype_dist):
        """Return one embedding per image, shape [NI, embedding_dim].

        The lesion queries attend over the third output of
        ``vision_encoder.vit_forward``; the attended queries are projected and
        averaged.  ``image_prototype_dist`` is accepted for interface
        compatibility but is not used.
        """
        NI = images.shape[0]
        _, _, feat_8, _ = self.vision_encoder.vit_forward(images)

        leis = self._attend_lesion_queries(feat_8, NI)

        leis_projected = self.vision_projection(leis)  # [NI, L, embedding_dim]

        return leis_projected.mean(dim=1)  # [NI, embedding_dim]

    def _compute_text_global_embedding(self, input_ids, attention_mask, token_type_ids, num_evidences,
                                       text_prototype_dist):
        """Return one embedding per report, shape [NT, embedding_dim].

        The projected evidence embeddings of a report are averaged over its
        valid evidences.  ``text_prototype_dist`` is accepted for interface
        compatibility but is not used.
        """
        evidence_embeddings = self._encode_evidences(input_ids, attention_mask, token_type_ids)

        return self._masked_evidence_mean(evidence_embeddings, num_evidences)  # [NT, embedding_dim]

    def compute_weighted_infonce_loss(self, image_global_emb, text_global_emb, relation_matrix):
        """Symmetric InfoNCE loss with ``relation_matrix`` as soft targets.

        Every image/report pair contributes its negative log-probability
        (image-to-text and text-to-image), weighted by the non-negative part of
        ``relation_matrix`` and normalised by the total weight.

        Args:
            image_global_emb: Image embeddings [NI, D], or ``None``.
            text_global_emb: Report embeddings [NT, D], or ``None``.
            relation_matrix: Pair weights [NI, NT]; negative entries count as
                0 and bool/integer matrices are converted to float.

        Returns:
            Scalar loss tensor.  A constant 0 is returned when either
            embedding is ``None``.
        """
        if image_global_emb is None or text_global_emb is None:
            # Take the device from whichever tensor is available;
            # relation_matrix itself may be None here.
            device = None
            for candidate in (relation_matrix, image_global_emb, text_global_emb):
                if candidate is not None:
                    device = candidate.device
                    break
            return torch.tensor(0.0, device=device)

        image_emb_norm = F.normalize(image_global_emb, p=2, dim=-1)
        text_emb_norm = F.normalize(text_global_emb, p=2, dim=-1)

        similarity_matrix = torch.matmul(image_emb_norm, text_emb_norm.t()) / self.info_nce_temperature  # [NI, NT]

        weights = relation_matrix.to(similarity_matrix.device)
        if not torch.is_floating_point(weights):
            weights = weights.float()
        weights = torch.clamp(weights, min=0.0)  # [NI, NT]
        # The same total weight normalises both directions.
        weight_total = weights.sum() + 1e-8

        # image -> text: softmax over the reports of each image (rows).
        log_probs_i2t = F.log_softmax(similarity_matrix, dim=1)  # [NI, NT]
        loss_i2t = -(weights * log_probs_i2t).sum() / weight_total

        # text -> image: softmax over the images of each report (columns).
        log_probs_t2i = F.log_softmax(similarity_matrix, dim=0)  # [NI, NT]
        loss_t2i = -(weights * log_probs_t2i).sum() / weight_total

        return (loss_i2t + loss_t2i) / 2.0




