
import os
# Number GPUs by PCI bus order and expose only physical GPU 0 to this process.
# CUDA reads these variables when it initialises, so they are set before torch
# is imported; inside the process that GPU is then addressed as "cuda:0".
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel
from transformers import GPT2Tokenizer
from typing import Tuple, List, Union, Optional
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from torch.nn import CrossEntropyLoss
import numpy as np

# Default device string: the (remapped) first visible GPU when CUDA is usable, else CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with an activation module between them.

    ``sizes`` gives the layer widths in order: ``sizes[0]`` is the input
    feature size and ``sizes[-1]`` the output size. An ``act()`` instance is
    inserted after every linear layer except the last one, so the final
    output is not passed through the activation.

    Args:
        sizes: layer widths, input first.
        bias: whether each ``nn.Linear`` has a bias term.
        act: zero-argument callable returning the activation module.
    """

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        last_index = len(sizes) - 2  # index of the final linear layer
        stack = []
        for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            stack.append(nn.Linear(fan_in, fan_out, bias=bias))
            if idx != last_index:
                stack.append(act())
        self.model = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` (its last dim must equal ``sizes[0]``)."""
        return self.model(x)

class SimpleTransformerClassifier(nn.Module):
    """Small Transformer-encoder sequence classifier with masked mean pooling.

    Token embeddings (or caller-supplied embeddings, see
    ``forward_with_embeddings``) are summed with learned absolute position
    embeddings, passed through dropout and a batch-first
    ``nn.TransformerEncoder``, mean-pooled over non-padding positions, and
    projected to ``num_classes`` logits.

    Args:
        vocab_size: number of rows in the token embedding table.
        embed_dim: model width (``d_model`` of every encoder layer).
        num_layers: number of stacked encoder layers.
        num_heads: attention heads per encoder layer.
        max_length: longest supported sequence; longer inputs are truncated.
        num_classes: number of output logits.
        dropout: dropout probability for the encoder layers and the input embeddings.
    """

    def __init__(self, vocab_size, embed_dim=128, num_layers=2, num_heads=4, max_length=128, num_classes=2, dropout=0.1):
        super(SimpleTransformerClassifier, self).__init__()
        self.embed_dim = embed_dim
        self.max_length = max_length
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_length, embed_dim)

        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification head
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def _truncate(self, sequence, attention_mask):
        """Clip ``sequence`` (and ``attention_mask`` if given) to ``max_length`` along dim 1."""
        # Positions beyond max_length have no row in position_embedding.
        if sequence.size(1) > self.max_length:
            sequence = sequence[:, :self.max_length]
            if attention_mask is not None:
                attention_mask = attention_mask[:, :self.max_length]
        return sequence, attention_mask

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average ``hidden`` over dim 1, weighting each position by ``attention_mask``.

        With no mask this is a plain mean. Otherwise the weighted sum is divided
        by the mask total (clamped to at least 1 so empty rows do not divide by 0).
        """
        if attention_mask is None:
            return hidden.mean(dim=1)
        # Cast to the activations' dtype so e.g. a float64 mask cannot promote
        # the pooled features and break the float32 classifier.
        weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # Zero padded positions before weighting: a fully padded row may come
        # out of the encoder as NaN, and NaN * 0 would still be NaN.
        summed = (hidden.masked_fill(weights == 0, 0.0) * weights).sum(dim=1)
        lengths = weights.sum(dim=1).clamp(min=1)
        return summed / lengths

    def _encode_and_classify(self, embeddings, attention_mask):
        """Add position embeddings, run the encoder, pool, and return class logits."""
        batch_size, seq_len = embeddings.shape[:2]
        positions = torch.arange(0, seq_len, device=embeddings.device).unsqueeze(0).expand(batch_size, seq_len)
        hidden = self.dropout(embeddings + self.position_embedding(positions))
        # nn.TransformerEncoder expects True at positions to ignore.
        padding_mask = (attention_mask == 0) if attention_mask is not None else None
        hidden = self.transformer_encoder(hidden, src_key_padding_mask=padding_mask)
        return self.classifier(self._masked_mean(hidden, attention_mask))

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify token-id sequences.

        Args:
            input_ids: token ids indexed as ``[batch, position]``; truncated to ``max_length``.
            attention_mask: optional tensor shaped like ``input_ids``; zeros mark padding.
            labels: optional class indices; when given, cross-entropy loss is computed.

        Returns:
            dict with ``"logits"`` and, when ``labels`` is given, ``"loss"``.
        """
        input_ids, attention_mask = self._truncate(input_ids, attention_mask)
        logits = self._encode_and_classify(self.token_embedding(input_ids), attention_mask)
        output = {"logits": logits}
        if labels is not None:
            output["loss"] = F.cross_entropy(logits, labels)
        return output

    def forward_with_embeddings(self, embeddings, attention_mask, target_class=None):
        """Classify from precomputed token embeddings instead of token ids.

        Presumably used for gradient-based attribution with respect to the
        input embeddings — TODO confirm with callers.

        Args:
            embeddings: ``[batch, position, embed_dim]`` token embeddings; truncated to ``max_length``.
            attention_mask: tensor shaped ``[batch, position]`` (zeros mark padding) or None.
            target_class: if given, only that class's logit column is returned.

        Returns:
            ``[batch, num_classes]`` logits, or the ``[batch]`` column for ``target_class``.
        """
        embeddings, attention_mask = self._truncate(embeddings, attention_mask)
        logits = self._encode_and_classify(embeddings, attention_mask)
        if target_class is not None:
            return logits[:, target_class]
        return logits


class ClipCaptionModel(nn.Module):
    """GPT-2 caption model conditioned on a projected prefix vector.

    Each example's ``prefix_size``-dim vector is projected by ``clip_project``
    into ``prefix_length`` GPT-2 input embeddings. These are prepended to the
    embedded caption tokens, and the result is fed to ``CustomGPT2LMHeadModel``.

    Args:
        prefix_length: number of prefix embeddings produced per example.
        prefix_size: width of the conditioning vector given to ``forward``.
        device: accepted for backward compatibility; not used by this class.
    """

    def __init__(self, prefix_length: int, prefix_size: int = 512, device=device):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = CustomGPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        projected_size = self.gpt_embedding_size * prefix_length

        if prefix_length > 10:
            # Original note: an MLP with a hidden layer this wide does not fit in memory.
            self.clip_project = nn.Linear(prefix_size, projected_size)
        else:
            self.clip_project = MLP((prefix_size, projected_size // 2, projected_size))

    def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, token_bias_scores=None, g_value=0):
        """Run GPT-2 on ``[prefix embeddings; token embeddings]``.

        Args:
            tokens: caption token ids indexed as ``[batch, position]``.
            prefix: conditioning vectors; ``clip_project(prefix)`` is reshaped to
                ``[-1, prefix_length, gpt_embedding_size]``.
            mask: forwarded as ``attention_mask``; presumably it must cover
                ``prefix_length + tokens.shape[1]`` positions — TODO confirm.
            labels: only checked for None. When given, labels are rebuilt as
                ``prefix_length`` zero ids followed by ``tokens``; the values
                passed in are not used.
            token_bias_scores, g_value: forwarded to ``self.gpt`` only when set to
                non-default values, so a GPT model without these keyword
                arguments works for the default call.

        Returns:
            Whatever ``self.gpt`` returns for this input.
        """
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            # NOTE(review): prefix positions get label 0, which the shifted
            # CrossEntropyLoss in CustomGPT2LMHeadModel does not ignore.
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)

        # CustomGPT2LMHeadModel.forward did not accept these kwargs, so
        # unconditionally passing them made every call fail with TypeError.
        extra_kwargs = {}
        g_is_default = isinstance(g_value, (int, float)) and g_value == 0
        if token_bias_scores is not None or not g_is_default:
            extra_kwargs = {"token_bias_scores": token_bias_scores, "g_value": g_value}
        return self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask, **extra_kwargs)

    def get_dummy_token(self, batch_size: int, device) -> torch.Tensor:
        """Return a ``[batch_size, prefix_length]`` int64 tensor of zeros on ``device``."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

class CustomGPT2LMHeadModel(GPT2LMHeadModel):
    """``GPT2LMHeadModel`` variant that records its final hidden states.

    The forward pass runs ``self.transformer``, then ``lm_head``, then an
    optional shifted next-token cross-entropy loss. On every call it also
    stores the transformer output (the tensor fed to ``lm_head``) in
    ``self.interim_hidden_state``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Final hidden states from the most recent forward(); None before the
        # first call. The tensor is stored without detaching, so if it requires
        # grad it keeps that step's autograd graph alive until overwritten.
        self.interim_hidden_state = None

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_bias_scores: Optional[torch.Tensor] = None,
        g_value: float = 0,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        """Run GPT-2, compute LM logits, and compute the loss when ``labels`` is given.

        Args:
            input_ids, past_key_values, attention_mask, token_type_ids,
                position_ids, head_mask, inputs_embeds, encoder_hidden_states,
                encoder_attention_mask, use_cache, output_attentions,
                output_hidden_states: forwarded unchanged to ``self.transformer``.
            labels: target ids aligned with the inputs. Logits at position ``t``
                are scored against ``labels[..., t + 1]``. ``CrossEntropyLoss``
                defaults apply, so only ``-100`` is ignored.
            return_dict: return a ``CausalLMOutputWithCrossAttentions`` rather
                than a tuple; defaults to ``config.use_return_dict``.
            token_bias_scores: accepted so that callers passing it, such as
                ``ClipCaptionModel.forward``, do not fail with ``TypeError``.
                No logit biasing is implemented, so a non-None value raises.
            g_value: accepted for the same reason; currently has no effect.

        Returns:
            ``(loss?, logits, *transformer extras)`` when ``return_dict`` is
            false; otherwise a ``CausalLMOutputWithCrossAttentions``.

        Raises:
            NotImplementedError: if ``token_bias_scores`` is not None.
        """
        if token_bias_scores is not None:
            raise NotImplementedError(
                "CustomGPT2LMHeadModel does not implement token_bias_scores/g_value logit biasing"
            )

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        hidden_states = transformer_outputs[0]

        # Model-parallel bookkeeping from the upstream implementation. The
        # attribute is read defensively because it may not be defined in
        # every transformers release.
        if getattr(self, "model_parallel", False):
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)
        self.interim_hidden_state = hidden_states

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Labels may sit on a different device than lm_head's output.
            labels = labels.to(lm_logits.device)
            # Predict token t+1 from position t.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
    

