import torch
import torch.nn as nn
from transformers import Qwen2Model

class Relation_Classifier(nn.Module):
    """Two-layer perceptron head: Linear -> ReLU -> Linear.

    Maps features whose last dimension is ``input_size`` to raw scores whose
    last dimension is ``output_size``. No activation is applied to the output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # The attribute names double as state_dict keys (fc1.*, fc2.*); keep
        # them stable so previously saved checkpoints still load.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        return self.fc2(self.relu(self.fc1(x)))

class Qwen_relevence_atten_model_span(nn.Module):
    """Span-based relation classifier on top of a Qwen2 backbone.

    Each subject/object span is reduced to the mean of the hidden states at
    its two boundary positions and projected by a shared MLP. The two span
    vectors are concatenated with the hidden state of the last attended
    token. The concatenation is scored by ``Relation_Classifier`` and trained
    with ``BCEWithLogitsLoss``, which applies an independent sigmoid per
    output logit.

    Args:
        input_size: Input width of the classifier. With the active
            ``"mlp(emb_cls)"`` fusion this must be ``3 * hidden_size`` of the
            backbone: two projected span vectors plus the last-token state.
        hidden_size: Hidden width of the classifier.
        output_size: Number of output logits per example.
        llm: Optional pre-built backbone. When ``None``, a Qwen2 model is
            loaded from the Hub and cast to bfloat16. A supplied backbone
            must expose ``config.hidden_size``,
            ``config.num_attention_heads`` and ``device``. It must also be
            callable, returning an object with ``last_hidden_state`` and
            ``attentions``.
    """

    def __init__(self, input_size, hidden_size, output_size=24, llm=None):
        super().__init__()
        # Only a missing backbone should trigger the download; truthiness of
        # an arbitrary object is not a reliable "was one given" test.
        if llm is None:
            # NOTE(review): confirm this repo id resolves on the Hugging Face
            # Hub; the smallest Qwen2 checkpoint is published as
            # "Qwen/Qwen2-0.5B". Eager attention is presumably chosen so that
            # forward() can request attention weights.
            self.embedding = Qwen2Model.from_pretrained(
                "Qwen/Qwen2-0.6B",
                attn_implementation="eager"
            ).to(torch.bfloat16)
        else:
            self.embedding = llm

        device = self.embedding.device
        self.classifier = Relation_Classifier(input_size, hidden_size, output_size).to(
            device=device, dtype=torch.bfloat16
        )
        self.emb_dim = self.embedding.config.hidden_size
        self.emb_attn_head = self.embedding.config.num_attention_heads
        emb_dim = self.emb_dim

        # Only "mlp(emb_cls)" is active. forward() assumes the MLP maps
        # emb_dim -> emb_dim, so the other branches are not wired up there.
        self.fused_method = "mlp(emb_cls)"
        if self.fused_method == "mlp(last_token||attention)||emb":
            mlp_in, mlp_out = emb_dim + self.emb_attn_head, emb_dim // 2
        elif self.fused_method in ("mlp(emb)", "mlp(emb_cls)"):
            mlp_in, mlp_out = emb_dim, emb_dim
        else:
            mlp_in, mlp_out = emb_dim, emb_dim // 2
        # Put the MLP on the backbone's device, like the classifier, so a
        # backbone that already lives on an accelerator is not paired with a
        # CPU-resident MLP.
        self.mlp = nn.Sequential(
            nn.Linear(mlp_in, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, mlp_out),
        ).to(device=device, dtype=torch.bfloat16)

        self.loss_fn = nn.BCEWithLogitsLoss()
        print("using span only model")

    @staticmethod
    def _boundary_mean(embs, spans):
        """Average the hidden states at columns 0 and 1 of each span row."""
        # Clamp so an out-of-range boundary (e.g. an exclusive end equal to
        # seq_len) selects the last position instead of raising.
        spans = torch.clamp(spans, 0, embs.size(1) - 1)
        index = spans.unsqueeze(-1).expand(-1, -1, embs.size(-1))
        picked = embs.gather(dim=1, index=index)
        return (picked[:, 0, :] + picked[:, 1, :]) / 2

    @staticmethod
    def _last_token_states(embs, attention_mask):
        """Return the hidden state at the last attended position of each row.

        The index is the highest position whose mask entry is non-zero. This
        equals ``mask.sum() - 1`` for right-padded rows and is also correct
        for left-padded rows. A row with no attended token yields index -1
        (the final position). Without a mask, the final position is used.
        """
        if attention_mask is None:
            return embs[:, -1]
        batch_size, seq_len = embs.shape[:2]
        positions = torch.arange(1, seq_len + 1, device=attention_mask.device)
        last_idx = (attention_mask.ne(0) * positions).amax(dim=1) - 1
        rows = torch.arange(batch_size, device=embs.device)
        return embs[rows, last_idx.to(embs.device)]

    def get_span_and_content_emb(self, embs, subject_spans, object_spans, attn, context_window=5):
        """Pool subject and object span vectors from token hidden states.

        Each span is the mean of the hidden states at its two boundary
        indices; tokens strictly between the boundaries are not used.

        Args:
            embs: Token hidden states indexed as ``(batch, seq_len, dim)``.
            subject_spans: Long tensor ``(batch, >=2)``; columns 0 and 1 are
                the subject boundary indices.
            object_spans: Long tensor ``(batch, >=2)``; columns 0 and 1 are
                the object boundary indices.
            attn: Unused; kept for interface compatibility.
            context_window: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(sub_embs, obj_embs)``, each ``(batch, dim)``.
        """
        sub_embs = self._boundary_mean(embs, subject_spans)
        obj_embs = self._boundary_mean(embs, object_spans)
        return sub_embs, obj_embs

    def forward(self, input_ids, attention_mask, span_info, labels):
        """Score relations for a batch and compute the loss when labels are given.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Backbone attention mask (non-zero = attended), or
                ``None``. Also used to locate each row's last attended token.
            span_info: One mapping per batch row with ``'subj_tok_span'`` and
                ``'obj_tok_span'`` token-index pairs. They are used as
                absolute positions in the padded sequence. TODO confirm the
                caller shifts them when left padding is used.
            labels: Targets for ``BCEWithLogitsLoss`` (values in [0, 1]) with
                the same shape as the logits, ``(batch, output_size)``. May be
                a nested list, array or tensor of any numeric dtype. ``None``
                skips the loss.

        Returns:
            ``{"loss": loss, "logits": logits}``. ``logits`` is
            ``(batch, output_size)`` in the classifier's dtype. ``loss`` is a
            float32 scalar, or ``None`` when ``labels`` is ``None``.
        """
        outputs = self.embedding(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )
        embs = outputs.last_hidden_state
        # The attention map is only forwarded to the span pooler (which
        # ignores it); tolerate backbones that do not return attentions.
        attentions = getattr(outputs, "attentions", None)
        attn = attentions[-1] if attentions else None
        device = embs.device

        last_token = self._last_token_states(embs, attention_mask)

        subj_spans = torch.tensor(
            [rel['subj_tok_span'] for rel in span_info], dtype=torch.long, device=device
        )
        obj_spans = torch.tensor(
            [rel['obj_tok_span'] for rel in span_info], dtype=torch.long, device=device
        )
        sub_embs, obj_embs = self.get_span_and_content_emb(embs, subj_spans, obj_spans, attn=attn)

        sub_mlp_embs = self.mlp(sub_embs)
        obj_mlp_embs = self.mlp(obj_embs)
        cat_embs = torch.cat([sub_mlp_embs, obj_mlp_embs, last_token], dim=-1)
        out = self.classifier(cat_embs)

        if labels is None:
            return {"loss": None, "logits": out}
        # Build targets on the logits' device as float32; BCEWithLogitsLoss
        # needs floating-point targets, and integer (e.g. 0/1) labels would
        # otherwise fail. The loss is evaluated on float32 logits for
        # precision.
        target = torch.as_tensor(labels, dtype=torch.float32, device=out.device)
        loss = self.loss_fn(out.float(), target)
        return {"loss": loss, "logits": out}