import torch
import torch.nn as nn
from transformers import Qwen2Model

class Relation_Classifier(nn.Module):
    """Two-layer feed-forward head: ``Linear -> ReLU -> Linear``.

    Args:
        input_size: Width of the incoming feature vectors.
        hidden_size: Width of the hidden layer.
        output_size: Number of output logits.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the raw (un-activated) logits for ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

    def _initialize_weights(self):
        """Re-initialise both linear layers (Xavier-uniform weights, zero biases).

        This is not invoked by ``__init__``; callers that want this scheme
        have to call it explicitly.
        """
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

class Qwen_relevence_atten_model_span(nn.Module):
    """Relation classifier built on top of span embeddings from an LLM encoder.

    For every example the encoder's last hidden states are read at two token
    positions of the subject span and two of the object span; each pair is
    averaged, passed through a shared MLP, concatenated and fed to a
    ``Relation_Classifier``. Training uses ``BCEWithLogitsLoss``.

    Args:
        input_size: Input width of the classifier. With the active fusion
            method (``"mlp(emb)"``) the classifier receives the subject and
            object features concatenated, i.e. ``2 * encoder hidden size``.
        hidden_size: Hidden width of the classifier.
        output_size: Number of relation logits.
        llm: Optional pre-built encoder. It must expose ``config.hidden_size``,
            ``config.num_attention_heads`` and ``device``. When ``None`` a
            pretrained encoder is loaded.
    """

    # The only fusion strategy that ``forward`` implements.
    _SUPPORTED_FUSION = "mlp(emb)"

    def __init__(self, input_size, hidden_size, output_size=24, llm=None):
        super().__init__()
        # ``is None`` rather than truthiness: a module's truth value may
        # depend on ``__len__`` and must not trigger a reload.
        if llm is None:
            # NOTE(review): a Qwen3 checkpoint is loaded through the
            # Qwen2Model class here - confirm this pairing is intended.
            self.embedding = Qwen2Model.from_pretrained(
                "Qwen/Qwen3-0.6B",
                attn_implementation="eager"
            ).to(torch.bfloat16)
        else:
            self.embedding = llm

        # Both trainable heads live on the encoder's device so that forward
        # never mixes devices.
        device = self.embedding.device

        self.classifier = Relation_Classifier(
            input_size, hidden_size, output_size
        ).to(device).to(torch.bfloat16)
        self.emb_dim = self.embedding.config.hidden_size
        self.emb_attn_head = self.embedding.config.num_attention_heads
        emb_dim = self.emb_dim

        self.fused_method = "mlp(emb)"
        # Every variant is Linear -> ReLU -> Linear; only the outer widths differ.
        if self.fused_method == "mlp(last_token||attention)||emb":
            mlp_in, mlp_out = emb_dim + self.emb_attn_head, emb_dim // 2
        elif self.fused_method == "mlp(emb)":
            mlp_in, mlp_out = emb_dim, emb_dim
        else:
            mlp_in, mlp_out = emb_dim, emb_dim // 2
        self.mlp = nn.Sequential(
            nn.Linear(mlp_in, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, mlp_out)
        ).to(device).to(torch.bfloat16)

        self.loss_fn = nn.BCEWithLogitsLoss()
        print("using span only model")

    def get_span_and_content_emb(self, embs, subject_spans, object_spans, attn, context_window=5):
        """Return one embedding per subject span and per object span.

        Args:
            embs: Hidden states of shape ``(batch, seq_len, dim)``.
            subject_spans: Integer tensor of shape ``(batch, k)`` with
                ``k >= 2`` holding token indices; only the first two columns
                contribute to the result.
            object_spans: Same layout as ``subject_spans``.
            attn: Accepted for interface compatibility; currently unused.
            context_window: Accepted for interface compatibility; currently
                unused.

        Returns:
            Tuple ``(sub_embs, obj_embs)``, each of shape ``(batch, dim)``:
            the mean of the hidden states at the first two indices of the
            corresponding span.
        """
        last_pos = embs.size(1) - 1
        dim = embs.size(-1)

        def _mean_of_endpoints(spans):
            # Clamp so out-of-range indices cannot make gather fail.
            spans = torch.clamp(spans, 0, last_pos)
            index = spans.unsqueeze(-1).expand(-1, -1, dim)
            picked = embs.gather(dim=1, index=index)
            return (picked[:, 0, :] + picked[:, 1, :]) / 2

        return _mean_of_endpoints(subject_spans), _mean_of_endpoints(object_spans)

    def forward(self, input_ids, attention_mask, span_info, labels):
        """Encode the batch, classify each subject/object pair and compute the loss.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            span_info: One mapping per example with the keys
                ``'subj_tok_span'`` and ``'obj_tok_span'``, each a sequence of
                token indices (at least two).
            labels: Targets for ``BCEWithLogitsLoss`` with the same shape as
                the logits (list or tensor). Integer targets are cast to the
                logits' dtype. ``None`` skips the loss.

        Returns:
            ``{"loss": loss, "logits": logits}``; ``loss`` is ``None`` when
            ``labels`` is ``None``.

        Raises:
            NotImplementedError: If ``self.fused_method`` is anything other
                than ``"mlp(emb)"``.
        """
        # Fail fast, before the expensive encoder call, instead of hitting an
        # undefined variable later on.
        if self.fused_method != self._SUPPORTED_FUSION:
            raise NotImplementedError(
                f"fused_method {self.fused_method!r} is not implemented in forward; "
                f"only {self._SUPPORTED_FUSION!r} is supported"
            )

        # Go through __call__ so registered hooks run.
        outputs = self.embedding(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )
        embs = outputs.last_hidden_state
        attn = outputs.attentions[-1]

        all_subj_span = torch.tensor(
            [rel['subj_tok_span'] for rel in span_info],
            dtype=torch.long, device=embs.device,
        )
        all_obj_span = torch.tensor(
            [rel['obj_tok_span'] for rel in span_info],
            dtype=torch.long, device=embs.device,
        )

        sub_embs, obj_embs = self.get_span_and_content_emb(
            embs, all_subj_span, all_obj_span, attn=attn
        )

        # The same MLP is applied to both spans before concatenation.
        cat_embs = torch.cat([self.mlp(sub_embs), self.mlp(obj_embs)], dim=-1)
        out = self.classifier(cat_embs)

        loss = None
        if labels is not None:
            # Follow the logits' device rather than assuming CUDA.
            relation_label = torch.as_tensor(labels, device=out.device)
            if not relation_label.is_floating_point():
                # BCEWithLogitsLoss needs floating-point targets.
                relation_label = relation_label.to(out.dtype)
            loss = self.loss_fn(out, relation_label)
        return {"loss": loss, "logits": out}