import torch
import torch.nn as nn
import torch.nn.functional as F
from lgdea.models.text_model import BertEncoder


def _valid_evidence_mask(num_evidences, max_evidences, like):
    """Build a ``(B, max_evidences, 1)`` mask of real (non-padding) slots.

    Slot ``j`` of sample ``i`` is 1 when ``j < num_evidences[i]`` and 0
    otherwise. The mask is created on the device and in the dtype of
    ``like`` so it can be multiplied with ``like`` directly, even when
    ``num_evidences`` lives on another device.
    """
    counts = num_evidences.to(like.device)
    slots = torch.arange(max_evidences, device=like.device)
    valid = slots.unsqueeze(0) < counts.unsqueeze(1)
    return valid.to(like.dtype).unsqueeze(-1)


class DiagnosticSemanticPrototypeModel(nn.Module):
    """Encode evidence texts and summarise them through learnable prototypes.

    Every evidence text is encoded by ``BertEncoder``, linearly projected,
    and softly assigned (temperature-scaled cosine similarity followed by a
    softmax) to a bank of learnable prototype vectors. The assignments are
    used to reconstruct each evidence embedding and to pool one
    representation per report.
    """

    def __init__(self, cfg):
        """Build the encoder, the prototype bank and the projection layer.

        Args:
            cfg: Configuration object. ``cfg.model.text.embedding_dim`` is
                required; ``cfg.model.num_prototypes`` (default 64) and
                ``cfg.model.temperature`` (default 0.1) are optional. The
                whole object is also handed to ``BertEncoder``.
        """
        super().__init__()

        self.embedding_dim = cfg.model.text.embedding_dim
        self.text_encoder = BertEncoder(cfg)

        self.num_prototypes = getattr(cfg.model, 'num_prototypes', 64)
        # The randn values are overwritten by the Xavier init just below;
        # randn is kept so that seeded runs consume the same random numbers.
        self.prototypes = nn.Parameter(
            torch.randn(self.num_prototypes, self.embedding_dim)
        )
        nn.init.xavier_uniform_(self.prototypes)

        # Divides the cosine similarities before the softmax; smaller values
        # give sharper prototype assignments.
        self.temperature = getattr(cfg.model, 'temperature', 0.1)

        self.projection = nn.Linear(self.embedding_dim, self.embedding_dim)

    def forward(self, input_ids, attention_mask, token_type_ids, num_evidences):
        """Embed every evidence and pool them into report representations.

        Args:
            input_ids: Tensor of shape ``(B, max_evidences, max_length)``.
            attention_mask: Tensor reshaped the same way as ``input_ids``.
            token_type_ids: Tensor reshaped the same way as ``input_ids``.
            num_evidences: Tensor of shape ``(B,)`` with the number of real
                evidences per sample; slots at or beyond that index are
                treated as padding.

        Returns:
            dict with
            ``'evidence_embeddings'``: ``(B, max_evidences, embedding_dim)``
            projected encoder outputs (not normalised);
            ``'reconstructed_embeddings'``: same shape, assignment-weighted
            sums of the L2-normalised prototypes;
            ``'report_representations'``: ``(B, embedding_dim)``, the mean
            assignment over the valid evidences mapped back through the
            normalised prototypes (all zeros for a sample with no valid
            evidence).
        """
        B, max_evidences, max_length = input_ids.shape

        # Fold the evidence axis into the batch axis so the encoder sees a
        # plain 2-D batch. reshape (unlike view) also accepts non-contiguous
        # inputs.
        flat_ids = input_ids.reshape(-1, max_length)
        flat_attention = attention_mask.reshape(-1, max_length)
        flat_token_types = token_type_ids.reshape(-1, max_length)

        # The encoder returns three values; only the second one is used
        # here, as the per-text embedding.
        _, flat_embeddings, _ = self.text_encoder(
            ids=flat_ids,
            attn_mask=flat_attention,
            token_type=flat_token_types
        )

        flat_embeddings = self.projection(flat_embeddings)
        evidence_embeddings = flat_embeddings.reshape(
            B, max_evidences, self.embedding_dim
        )

        # Cosine similarity between every evidence and every prototype.
        evidence_norm = F.normalize(evidence_embeddings, p=2, dim=-1)
        prototype_norm = F.normalize(self.prototypes, p=2, dim=-1)

        similarity = torch.matmul(evidence_norm, prototype_norm.t()) / self.temperature
        prototype_assignments = F.softmax(similarity, dim=-1)

        reconstructed_embeddings = torch.matmul(prototype_assignments, prototype_norm)

        mask = _valid_evidence_mask(num_evidences, max_evidences, prototype_assignments)

        # Average over the slots that are actually summed. Dividing by the
        # mask total (not the raw count) keeps the distribution normalised
        # when num_evidences exceeds max_evidences, and the clamp avoids a
        # division by zero for samples without any evidence.
        valid_counts = mask.sum(dim=1).clamp(min=1)
        report_prototype_dist = (prototype_assignments * mask).sum(dim=1) / valid_counts
        report_representations = torch.matmul(report_prototype_dist, prototype_norm)

        return {
            'evidence_embeddings': evidence_embeddings,
            'reconstructed_embeddings': reconstructed_embeddings,
            'report_representations': report_representations
        }

    def compute_reconstruction_loss(self, evidence_embeddings, reconstructed_embeddings, num_evidences):
        """Mean cosine distance between evidences and their reconstructions.

        Args:
            evidence_embeddings: Tensor of shape ``(B, max_evidences, D)``.
            reconstructed_embeddings: Tensor of the same shape.
            num_evidences: Tensor of shape ``(B,)``; only the first
                ``num_evidences[i]`` slots of sample ``i`` contribute.

        Returns:
            Scalar tensor: ``1 - cosine_similarity`` averaged over all valid
            slots in the batch. When there are no valid slots the result is
            0 and stays attached to the autograd graph, so calling
            ``backward()`` on it is still legal.
        """
        mask = _valid_evidence_mask(
            num_evidences, evidence_embeddings.shape[1], evidence_embeddings
        )

        evidence_norm = F.normalize(evidence_embeddings, p=2, dim=-1)
        reconstructed_norm = F.normalize(reconstructed_embeddings, p=2, dim=-1)
        cosine_sim = (evidence_norm * reconstructed_norm).sum(dim=-1, keepdim=True)
        reconstruction_loss = (1 - cosine_sim) * mask

        # Clamping the denominator replaces the Python-level
        # `if total_valid > 0` branch: no host/device sync, and the empty
        # case yields a graph-connected zero instead of a detached constant.
        total_valid = mask.sum().clamp(min=1)
        return reconstruction_loss.sum() / total_valid

