import os
import torch
import torch.nn as nn
from torch import Tensor
import numpy as np
from typing import Optional
import transformers
from transformers.utils import ModelOutput
from transformers import PreTrainedModel, TrainingArguments, Qwen2VLForConditionalGeneration, PretrainedConfig, LlavaForConditionalGeneration, MllamaForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass, field
from peft import PeftModel, LoraModel, LoraConfig, get_peft_model
from sklearn.linear_model import Ridge
import joblib

class VLMRewardConfig(PretrainedConfig):
    """Configuration object for ``VLMRewardModel``.

    Declares no fields of its own: every keyword argument is handed unchanged
    to ``PretrainedConfig.__init__``.
    """

    def __init__(self, **kwargs):
        super(VLMRewardConfig, self).__init__(**kwargs)
    
@dataclass
class RewardArgs(TrainingArguments):
    """``TrainingArguments`` plus two reward-model options.

    Field order matters for dataclass inheritance, so the two extra fields are
    kept in their original order after the inherited ones.
    """

    # Optional identifier of the vision tower; defaults to None.
    vision_tower: Optional[str] = field(
        default=None,
        metadata=dict(help="The vision tower to use."),
    )
    # Maximum input length; defaults to 4096.
    max_length: Optional[int] = field(
        default=4096,
        metadata=dict(help="The maximum length of the input."),
    )


@dataclass
class VLMRewardModelOutput(ModelOutput):
    """Output container for the VLM reward model's forward pass.

    Every field defaults to ``None``. Declaration order is significant because
    ``ModelOutput``'s tuple/integer indexing follows it (skipping ``None``
    values), so do not reorder these fields.
    """

    # Training loss; the visible VLMRewardModel.forward never sets it.
    loss: Optional[Tensor] = None
    # Logits produced by the backbone model.
    logits: Optional[Tensor] = None
    # Reward score(s) for the batch.
    rewards: Optional[Tensor] = None
    # Backbone final-layer hidden state; its exact shape depends on the producer.
    last_hidden_state: Optional[Tensor] = None

class VLMRewardModel(PreTrainedModel):
    """Scalar reward model on top of a (vision-)language backbone.

    The backbone class is chosen by substring match on ``args.model_name_or_path``
    and can optionally be wrapped with LoRA adapters. A linear head maps the
    final-layer hidden state at the last sequence position to one logit, which is
    squashed with a sigmoid. If a Ridge model path is given, that regressor's
    prediction on the same feature is returned instead of the head's output.
    """

    def __init__(self, args, config: VLMRewardConfig, ridge_model_path: Optional[str] = None):
        """Build the backbone, optional LoRA wrapper, reward head and Ridge model.

        Args:
            args: Run configuration. Must expose ``model_name_or_path``,
                ``use_peft``, ``peft_checkpoint_dir`` and ``checkpoint_dir``; when
                ``use_peft`` is truthy also ``lora_target_modules``, ``lora_r``,
                ``lora_alpha`` and ``lora_dropout``.
            config: Config passed to ``PreTrainedModel.__init__``.
            ridge_model_path: Optional path to a ``joblib`` file whose object
                exposes ``predict``; when given it replaces the head in ``forward``.
        """
        super(VLMRewardModel, self).__init__(config)
        self.config = config
        self.args = args
        self.model_name_or_path = args.model_name_or_path
        self.use_peft = args.use_peft
        self.peft_checkpoint_dir = args.peft_checkpoint_dir
        self.checkpoint_dir = args.checkpoint_dir

        # Load the backbone model. First match wins; "llava" must be tested
        # before the broader "llama" substring.
        name = self.model_name_or_path.lower()
        if "qwen" in name:
            self.backbone_model = Qwen2VLForConditionalGeneration.from_pretrained(self.model_name_or_path)
        elif "llava" in name:
            self.backbone_model = LlavaForConditionalGeneration.from_pretrained(self.model_name_or_path)
        elif "llama" in name:
            # NOTE(review): any name containing "llama" (including text-only
            # Llama checkpoints) is loaded as Mllama here — confirm intended.
            self.backbone_model = MllamaForConditionalGeneration.from_pretrained(self.model_name_or_path)
        elif "internlm" in name:
            self.backbone_model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, trust_remote_code=True)
            # Tokenizer is attached to the backbone; presumably used by its
            # remote-code helpers (e.g. interleav_wrap_chat) — confirm.
            self.backbone_model.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True)
        else:
            self.backbone_model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, trust_remote_code=True)

        # Initialize or load LoRA adapters if requested.
        if self.use_peft:
            self.lora_config = LoraConfig(
                target_modules=args.lora_target_modules,
                r=args.lora_r,
                lora_alpha=args.lora_alpha,
                lora_dropout=args.lora_dropout,
            )
            if self.peft_checkpoint_dir is not None:
                # NOTE(review): PeftModel.from_pretrained defaults to
                # is_trainable=False, so a resumed adapter is loaded frozen, and
                # the explicit config is used instead of the checkpoint's saved
                # adapter config. Confirm both are intended.
                self.backbone_model = PeftModel.from_pretrained(
                    self.backbone_model,
                    config=self.lora_config,
                    model_id=self.peft_checkpoint_dir,
                )
            else:
                self.backbone_model = get_peft_model(self.backbone_model, self.lora_config)

        # Reward head: hidden width -> one logit.
        self.hidden_size = self._infer_hidden_size(self.backbone_model.config)
        self.reward_head = nn.Linear(self.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

        # Optional pre-trained Ridge regressor.
        # SECURITY: joblib.load unpickles arbitrary objects; only load trusted files.
        self.ridge_model = None
        if ridge_model_path:
            self.ridge_model = joblib.load(ridge_model_path)

    @staticmethod
    def _infer_hidden_size(backbone_config, default=4096):
        """Return the backbone's hidden width used to size ``reward_head``.

        Checks ``hidden_size``, ``d_model`` and ``dim`` (first non-None wins) on
        the config itself, then on its ``text_config`` sub-config. Composite
        vision-language configs such as Llava and Mllama store the language
        model's width there. Falls back to ``default`` when nothing is found.
        """
        for cfg in (backbone_config, getattr(backbone_config, "text_config", None)):
            if cfg is None:
                continue
            for attr in ("hidden_size", "d_model", "dim"):
                value = getattr(cfg, attr, None)
                if value is not None:
                    return value
        return default

    def forward(self, output_hidden_states=True, human_score=None, **kwargs):
        """Score a batch using the backbone's final-layer hidden states.

        Two input modes are supported:

        * ``samples`` present in ``kwargs`` (a dict with ``text_input`` and
          ``image``). Inputs are built with the backbone's
          ``interleav_wrap_chat``, a remote-code helper (presumably
          InternLM-XComposer's — confirm). Only the tensor entries of the
          result are moved to ``self.device`` and passed to the backbone.
        * Otherwise every keyword (e.g. ``input_ids``, ``attention_mask``,
          ``pixel_values``, ``image_grid_thw``, ``labels``) is forwarded to the
          backbone unchanged. A caller-supplied ``return_dict`` is dropped
          because it is always forced to ``True``.

        Args:
            output_hidden_states: Ignored; hidden states are always requested.
            human_score: Ignored; accepted so it is not forwarded to the backbone.
            **kwargs: Backbone inputs, or ``samples`` as described above.

        Returns:
            VLMRewardModelOutput with ``loss`` left as ``None``.

            * Without a Ridge model: ``rewards`` is ``sigmoid(reward_head(h))``
              with shape (batch,), and ``last_hidden_state`` is the
              last-position feature ``h`` with shape (batch, hidden).
            * With a Ridge model: ``rewards`` is its float32 prediction on ``h``,
              and ``last_hidden_state`` is the full final-layer sequence.
        """
        if "samples" in kwargs:
            samples = kwargs["samples"]
            wrapped, _, _ = self.backbone_model.interleav_wrap_chat(
                query=samples["text_input"], image=samples["image"].bfloat16()
            )
            backbone_inputs = {
                k: v.to(self.device) for k, v in wrapped.items() if torch.is_tensor(v)
            }
        else:
            backbone_inputs = dict(kwargs)
            # return_dict is forced below; a duplicate keyword would raise TypeError.
            backbone_inputs.pop("return_dict", None)

        outputs = self.backbone_model(
            **backbone_inputs,
            return_dict=True,
            output_hidden_states=True,
        )

        last_hidden_state = outputs.hidden_states[-1]
        logits = outputs.logits
        # Adds an exact zero that depends on the logits. Presumably this keeps the
        # LM head in the autograd graph (avoiding DDP "unused parameters"
        # errors) — confirm.
        last_hidden_state = last_hidden_state + 0.0 * torch.mean(logits)

        # NOTE(review): always uses position -1. That is the last real token only
        # with left padding (or no padding); with right padding it is a pad
        # token. Confirm the tokenizer's padding side.
        last_hidden_state_at_the_end = last_hidden_state[:, -1, :]
        last_hidden_state_at_the_end = last_hidden_state_at_the_end.type_as(
            self.reward_head.weight
        )
        rewards = self.reward_head(last_hidden_state_at_the_end).squeeze(-1)
        rewards = self.sigmoid(rewards)

        if self.ridge_model is not None:
            # .float() first: numpy has no bfloat16, so .numpy() fails on bf16 tensors.
            ridge_input = last_hidden_state_at_the_end.detach().float().cpu().numpy()
            ridge_rewards = self.ridge_model.predict(ridge_input)
            ridge_rewards = torch.tensor(ridge_rewards, dtype=torch.float32, device=self.device)
            # NOTE(review): unlike the head path, this returns the full final-layer
            # sequence as last_hidden_state. Kept for compatibility; confirm
            # which shape downstream code expects.
            return VLMRewardModelOutput(
                rewards=ridge_rewards, logits=logits, last_hidden_state=last_hidden_state
            )

        return VLMRewardModelOutput(
            rewards=rewards, logits=logits, last_hidden_state=last_hidden_state_at_the_end
        )

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``module.gradient_checkpointing`` to ``value``.

        Presumably the hook used by older ``transformers`` gradient-checkpointing
        APIs — confirm against the pinned version.
        """
        module.gradient_checkpointing = value
        

