import torch
import torch.nn as nn
from .components import (
    ContactAreaProto, CoAttentionLayer
)
from ._model_config_base import ConfigBase
from ._lib import ModelLib

@ModelLib.register
class ConceptProtoClassifier(ConfigBase):
    """Two-class scorer combining two chain-vs-target contact scores.

    The "a" and "b" hidden states are first passed through a co-attention
    layer; each co-attended chain is then scored against the "e" hidden
    states by its own ``ContactAreaProto`` module.  The two scores are
    combined with the fixed ``chain_weights`` into a single score ``w`` and
    returned as the pair ``[1 - w, w]``.

    NOTE(review): "a"/"b"/"e" presumably stand for TCR alpha chain, TCR beta
    chain and epitope -- confirm against the data pipeline.
    """

    def __init__(
        self,
        hidden_dim: int,
        sample_distances=[0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
        num_heads=1,
        dropout=0.1,
        chain_weights=[0.5,0.5],
        select_threshold=0.9,
        concept_nums={
            'a_v': 54,
            'a_j': 66,
            'b_v': 44,
            'b_j': 21,
            'tcr_species': 9,
            'mhc_allele': 172,
            'mhc_class': 4
        }
    ):
        """Build the sub-modules.

        Args:
            hidden_dim: Hidden size forwarded to every sub-module.
            sample_distances: Forwarded to both ``ContactAreaProto`` modules.
            num_heads: Number of attention heads forwarded to every sub-module.
            dropout: Dropout rate forwarded to every sub-module.
            chain_weights: Two multipliers applied to the "a" score (index 0)
                and the "b" score (index 1) before they are summed.
            select_threshold: Forwarded to ``AlleleConcepts``.
            concept_nums: Mapping of concept name to a count, forwarded to
                ``AlleleConcepts``.
        """
        super().__init__()
        # The list/dict defaults in the signature are shared between calls, so
        # every sub-module (and this instance) receives its own copy; a later
        # mutation can then never leak into other instances or the defaults.
        self.ae_cap = ContactAreaProto(
            sample_distances=list(sample_distances),
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
        )
        self.be_cap = ContactAreaProto(
            sample_distances=list(sample_distances),
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
        )
        self.coatten = CoAttentionLayer(hidden_size=hidden_dim, num_heads=num_heads, dropout=dropout)
        # NOTE(review): `AlleleConcepts` is not among the names imported at the
        # top of this file -- confirm it is defined or imported elsewhere,
        # otherwise constructing this class raises NameError.
        self.concept = AlleleConcepts(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            select_threshold=select_threshold,
            concept_nums=dict(concept_nums),
        )
        self.chain_weights = list(chain_weights)
        # Cross-entropy with fixed per-class weights for the two classes.
        self.criterion_pc = nn.CrossEntropyLoss(weight=torch.tensor([0.2175, 0.7825]))
        # Parameter initialised to 0.5; it is not read by `forward`.
        self.concept_proto_weight = nn.Parameter(torch.tensor(0.5))

        # Holds the loss of the most recent training-mode forward pass.
        self.loss = None

    def forward(
        self,
        hidden_states_a, attention_mask_a,
        hidden_states_b, attention_mask_b,
        hidden_states_e, attention_mask_e,
        mhc_allele_concepts, mhc_class_concepts,
        tcr_a_v_concepts, tcr_a_j_concepts,
        tcr_b_v_concepts, tcr_b_j_concepts,
        tcr_species_concepts,
        labels
    ):
        """Score the inputs and, in training mode, compute and backpropagate the loss.

        Args:
            hidden_states_a, attention_mask_a: Hidden states and mask of "a".
            hidden_states_b, attention_mask_b: Hidden states and mask of "b".
            hidden_states_e, attention_mask_e: Hidden states and mask of "e".
            mhc_allele_concepts, mhc_class_concepts, tcr_a_v_concepts,
            tcr_a_j_concepts, tcr_b_v_concepts, tcr_b_j_concepts,
            tcr_species_concepts: Accepted for interface compatibility; not
                used by this method.
            labels: Class indices; only read in training mode, where they are
                cast to ``torch.long``.

        Returns:
            ``torch.stack([1 - w, w], dim=1)`` where ``w`` is the weighted sum
            of the two contact scores (presumably shape ``(batch, 2)`` --
            TODO confirm the shape returned by ``ContactAreaProto``).

        Side effects:
            In training mode the weighted cross-entropy loss is stored in
            ``self.loss`` and ``backward()`` is called on it.
        """
        co_a, co_b = self.coatten(
            hidden_states_a,
            hidden_states_b,
            attention_mask_a,
            attention_mask_b,
        )
        w_ae = self.ae_cap(co_a, attention_mask_a, hidden_states_e, attention_mask_e)
        w_be = self.be_cap(co_b, attention_mask_b, hidden_states_e, attention_mask_e)

        # Out-of-place arithmetic: scaling the sub-module outputs in place can
        # invalidate tensors autograd saved for the backward pass.
        w = w_ae * self.chain_weights[0] + w_be * self.chain_weights[1]

        outputs = torch.stack([1 - w, w], dim=1)

        if self.training:
            # Assign the loss directly; there is no earlier term to accumulate onto.
            loss = self.criterion_pc(outputs, labels.to(dtype=torch.long))
            self.loss = loss
            # NOTE(review): backward() inside forward() is unusual; kept because
            # the surrounding training loop may rely on it -- confirm.
            loss.backward()

        return outputs