from vllm import LLM, SamplingParams
from vllm.sampling_params import SamplingParams
from transformers import AutoProcessor, AutoTokenizer, pipeline
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from trl import AutoModelForCausalLMWithValueHead
from transformers.utils import cached_file
from safetensors import safe_open
from safetensors.torch import load_file
from qwen_vl_utils import process_vision_info
from qwen_omni_utils import process_mm_info
from math import ceil
from tqdm import tqdm
import torch
import logging
import asyncio
import uuid
import ray
import subprocess
import time
import socket
import os
# os.environ['VLLM_ALLOW_LONG_MAX_MODEL_LEN']=1


class BaseVLLM:
    """Shared batching logic for the vLLM-backed actors in this module.

    Subclasses provide ``_init`` (expected to set ``self.llm`` and ``self.sp``)
    and ``process_prompt`` (turns a list of conversations into the inputs
    passed to ``self.llm.generate``).
    """

    async def __init__(self, args, actor_idx, model_path):
        # NOTE(review): an ``async def __init__`` makes a plain ``BaseVLLM(...)``
        # call raise TypeError (``__init__`` would return a coroutine). It is
        # presumably only ever constructed as a Ray actor subclass, which
        # overrides this anyway — confirm before relying on it directly.
        self.args = args
        self.actor_idx = actor_idx
        self.workers = args.workers
        self.tensor_parallel = args.tensor_parallel
        self.model_path = model_path

        self._init()
        logging.getLogger().setLevel(logging.ERROR)

    def _init(self):
        """Hook for subclasses to load processor/model; the base does nothing."""
        pass

    def generate(self, batch_dataset):
        """Run every conversation in the batch through ``self.llm`` in one call.

        Each item of ``batch_dataset`` must expose ``conversations`` (a list).
        On success each item's ``llm_response`` is set to the list of generated
        texts for its own conversations, in order. Outputs are matched to items
        by position, so items sharing an ``id`` no longer receive each other's
        responses.

        If the batch contains no conversations at all, nothing is sent to the
        model and the batch is returned untouched (``llm_response`` unset).

        Failures are best-effort: the exception is logged and the batch is
        returned as-is, so some or all items may lack ``llm_response``.

        Args:
            batch_dataset: re-iterable sequence (it is walked twice) of items.

        Returns:
            The same ``batch_dataset`` object.
        """
        try:
            conversations = []
            # (start, end) slice of ``conversations`` owned by each item.
            spans = []
            for data in batch_dataset:
                start = len(conversations)
                conversations.extend(data.conversations)
                spans.append((start, len(conversations)))

            # Check before templating so an empty batch never reaches the model.
            if not conversations:
                return batch_dataset

            llm_inputs = self.process_prompt(conversations)
            llm_outputs = self.llm.generate(llm_inputs, sampling_params=self.sp)

            for data, (start, end) in zip(batch_dataset, spans):
                # Index (not slice) so a short output list still raises loudly.
                data.llm_response = [
                    llm_outputs[i].outputs[0].text for i in range(start, end)
                ]
            return batch_dataset
        except Exception:
            logging.exception(
                "%s.generate failed (actor_idx=%s); returning batch unchanged",
                type(self).__name__,
                getattr(self, "actor_idx", None),
            )
            return batch_dataset


@ray.remote
class OmniVLLM(BaseVLLM):
    """Ray actor serving an omni-modal model (text plus audio/image/video) via vLLM."""

    async def __init__(self, args, actor_idx, model_path):
        # Same fields as BaseVLLM.__init__, but the root logger level is left alone.
        cfg = args
        self.args = cfg
        self.actor_idx = actor_idx
        self.workers = cfg.workers
        self.tensor_parallel = cfg.tensor_parallel
        self.model_path = model_path
        self._init()

    def _init(self):
        """Create the processor, sampling parameters and vLLM engine."""
        path = self.model_path
        self.processor = AutoProcessor.from_pretrained(path)
        self.sp = SamplingParams(
            n=1,
            temperature=self.args.temperature,
            top_p=self.args.top_p,
            repetition_penalty=1.2,
            max_tokens=4096,
        )
        engine_kwargs = dict(
            model=path,
            dtype="auto",
            max_num_seqs=64,
            max_model_len=32768,
            tensor_parallel_size=self.tensor_parallel,
            limit_mm_per_prompt={"video": 4, "audio": 4, "image": 4},
        )
        self.llm = LLM(**engine_kwargs)

    def process_prompt(self, conversations):
        """Render chat templates and attach any media found in each conversation.

        Returns one dict per conversation with a ``prompt`` key and, only when
        ``process_mm_info`` reported at least one non-None modality, a
        ``multi_modal_data`` dict keyed by ``audio``/``image``/``video``.
        """
        rendered = self.processor.apply_chat_template(
            conversations,
            tokenize=False,
            add_generation_prompt=True,
        )

        def _to_input(text, conv):
            audios, images, videos = process_mm_info(conv, use_audio_in_video=False)
            media = {
                name: payload
                for name, payload in (("audio", audios), ("image", images), ("video", videos))
                if payload is not None
            }
            entry = {"prompt": text}
            if media:
                entry["multi_modal_data"] = media
            return entry

        return [_to_input(text, conv) for text, conv in zip(rendered, conversations)]


@ray.remote
class LanguageVLLM(BaseVLLM):
    """Ray actor serving a text-only chat model via vLLM."""

    async def __init__(self, args, actor_idx, model_path):
        # Same fields as BaseVLLM.__init__, but the root logger level is left alone.
        cfg = args
        self.args = cfg
        self.actor_idx = actor_idx
        self.workers = cfg.workers
        self.tensor_parallel = cfg.tensor_parallel
        self.model_path = model_path
        self._init()

    def _init(self):
        """Create the tokenizer, sampling parameters and vLLM engine."""
        path = self.model_path
        self.processor = AutoTokenizer.from_pretrained(path)
        self.sp = SamplingParams(
            n=1,
            temperature=self.args.temperature,
            top_p=self.args.top_p,
            max_tokens=8192,
        )
        engine_kwargs = dict(
            model=path,
            dtype="auto",
            max_num_seqs=64,
            max_model_len=32768,
            tensor_parallel_size=self.tensor_parallel,
        )
        self.llm = LLM(**engine_kwargs)

    def process_prompt(self, conversations):
        """Render each conversation to a prompt string via the chat template.

        ``enable_thinking=True`` is forwarded to the template; whether it has
        any effect depends on the model's chat template.
        """
        template_kwargs = {
            "tokenize": False,
            "add_generation_prompt": True,
            "enable_thinking": True,
        }
        return self.processor.apply_chat_template(conversations, **template_kwargs)


@ray.remote
class VisionVLLM(BaseVLLM):
    """Ray actor serving a vision-language model (images/videos + text) via vLLM."""

    async def __init__(self, args, actor_idx, model_path):
        cfg = args
        self.args = cfg
        self.actor_idx = actor_idx
        self.workers = cfg.workers
        # NOTE(review): args.tensor_parallel is read but immediately overridden;
        # this actor always runs with tensor parallelism 1. Confirm intent.
        self.tensor_parallel = cfg.tensor_parallel
        self.tensor_parallel = 1
        self.model_path = model_path
        self._init()

    def _init(self):
        """Create the processor, sampling parameters and vLLM engine, then quiet logging."""
        path = self.model_path
        self.processor = AutoProcessor.from_pretrained(path)
        self.sp = SamplingParams(
            n=1,
            temperature=self.args.temperature,
            top_p=self.args.top_p,
            repetition_penalty=1.2,
            max_tokens=4096,
        )
        engine_kwargs = dict(
            model=path,
            dtype="auto",
            max_num_seqs=64,
            max_model_len=32768,
            tensor_parallel_size=self.tensor_parallel,
            limit_mm_per_prompt={"video": 4, "image": 2},
        )
        self.llm = LLM(**engine_kwargs)

        logging.getLogger().setLevel(logging.ERROR)

    def process_prompt(self, conversations):
        """Render chat templates and attach vision inputs for each conversation.

        Every returned dict has ``prompt``, ``multi_modal_data`` (possibly an
        empty dict) and ``mm_processor_kwargs`` (the video kwargs returned by
        ``process_vision_info``).
        """
        rendered = self.processor.apply_chat_template(
            conversations,
            tokenize=False,
            add_generation_prompt=True,
        )

        def _to_input(text, conv):
            images, videos, video_kwargs = process_vision_info(conv, return_video_kwargs=True)
            media = {
                name: payload
                for name, payload in (("image", images), ("video", videos))
                if payload is not None
            }
            return {
                "prompt": text,
                "multi_modal_data": media,
                "mm_processor_kwargs": video_kwargs,
            }

        return [_to_input(text, conv) for text, conv in zip(rendered, conversations)]


@ray.remote
class ScalarRewardWorker:
    """Ray actor that scores conversations with a text-classification reward model.

    The model at ``args.ranking_model`` is loaded into a transformers
    ``"sentiment-analysis"`` (text-classification) pipeline; each
    conversation's reward is the score of the first label returned for it.
    """

    async def __init__(self, args, actor_idx):
        self.args = args
        self.actor_idx = actor_idx
        self.workers = args.workers

        # Do not interpolate args.ranking_model into the message: in this branch
        # the attribute does not exist. AttributeError matches what callers
        # previously observed (and is still an Exception subclass).
        if not hasattr(args, 'ranking_model'):
            raise AttributeError(
                "ScalarRewardWorker requires `args.ranking_model` to be set"
            )
        self.model_path = args.ranking_model

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.llm = pipeline(
            "sentiment-analysis",
            model=self.model_path,
            device="cuda",
            tokenizer=self.tokenizer
        )
        logging.getLogger().setLevel(logging.ERROR)

    def process_prompt(self, conversations):
        """Render each conversation to text with the chat template.

        No generation prompt is appended, so the text ends with the
        conversation's last message (presumably the response being scored).
        """
        llm_inputs = self.tokenizer.apply_chat_template(
            conversations,
            tokenize=False,
        )
        return llm_inputs

    def generate(self, batch_dataset):
        """Score every conversation in the batch and store per-item score lists.

        Each item must expose ``conversations``; afterwards its
        ``llm_response`` is the list of scores for its own conversations, in
        order. Scores are matched by position, so items with duplicate ids or
        with no conversations are handled (the latter get ``[]``).

        Returns:
            The same ``batch_dataset`` object.
        """
        conversations = []
        # (start, end) slice of ``conversations`` owned by each item.
        spans = []
        for data in batch_dataset:
            start = len(conversations)
            conversations.extend(data.conversations)
            spans.append((start, len(conversations)))

        llm_outputs = []
        if conversations:
            llm_inputs = self.process_prompt(conversations)
            pipe_kwargs = {
                # Legacy flag: one list of {label, score} dicts per input, in
                # label order. NOTE: the replacement ``top_k=None`` sorts by
                # score, which would change which entry is ``[0]``.
                "return_all_scores": True,
                # Raw model outputs (no sigmoid/softmax).
                "function_to_apply": "none",
                "batch_size": 1
            }
            llm_outputs = self.llm(llm_inputs, **pipe_kwargs)

        for data, (start, end) in zip(batch_dataset, spans):
            data.llm_response = [
                llm_outputs[i][0]['score'] for i in range(start, end)
            ]
        return batch_dataset


@ray.remote(num_gpus=1)
class ScalarVisionRewardWorker:
    """Ray actor that scores (multi-modal) conversations with a value-head reward model.

    A Qwen2.5-VL backbone is wrapped in trl's ``AutoModelForCausalLMWithValueHead``.
    Value-head weights are loaded from ``value_head.safetensors`` resolved
    against the model path. A conversation's reward is the value-head output at
    the last token position.
    """

    def __init__(self, args, actor_idx):
        self.args = args
        self.actor_idx = actor_idx
        self.workers = args.workers
        self.tensor_parallel = args.tensor_parallel

        # Do not interpolate args.ranking_model into the message: in this branch
        # the attribute does not exist. AttributeError matches what callers
        # previously observed (and is still an Exception subclass).
        if not hasattr(args, 'ranking_model'):
            raise AttributeError(
                "ScalarVisionRewardWorker requires `args.ranking_model` to be set"
            )
        self.model_path = args.ranking_model

        self.processor = AutoProcessor.from_pretrained(self.model_path)

        backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        self.llm = AutoModelForCausalLMWithValueHead.from_pretrained(backbone)

        vhead_file = cached_file(
            path_or_repo_id=self.model_path,
            filename="value_head.safetensors"
        )
        vhead_params = load_file(vhead_file, device="cuda")

        # strict=False: the file presumably holds only value-head tensors, not
        # the backbone weights already loaded above.
        self.llm.load_state_dict(vhead_params, strict=False)
        self.llm.requires_grad_(False)
        self.llm.eval()
        logging.getLogger().setLevel(logging.ERROR)

    def process_prompt(self, conversations):
        """Tokenize conversations (text + images/videos) into CUDA model inputs."""
        text = self.processor.apply_chat_template(
            conversations,
            tokenize=False,
            add_generation_prompt=True,
        )

        image_inputs, video_inputs = process_vision_info(conversations)
        llm_inputs = self.processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        llm_inputs = llm_inputs.to("cuda")
        return llm_inputs

    def _score(self, conversation):
        """Return the value-head output at the last token for one conversation."""
        llm_inputs = self.process_prompt([conversation])
        # Pure inference: make sure no autograd graph is recorded.
        with torch.no_grad():
            outputs = self.llm(**llm_inputs, return_dict=True, use_cache=False)
        # Last element of the wrapper's output is the value tensor, presumably
        # (batch, seq_len); take the final position of the single sequence.
        return outputs[-1][:, -1].item()

    def generate(self, batch_dataset):
        """Score every conversation in the batch, one forward pass each.

        Each item must expose ``conversations``; on success its
        ``llm_response`` is the list of float scores for its own
        conversations, in order. Scores are matched by position, so items
        with duplicate ids get their own scores.

        Failures are best-effort: the exception is logged and the batch is
        returned with no ``llm_response`` assigned by this call.

        Returns:
            The same ``batch_dataset`` object.
        """
        try:
            scores = []
            # (start, end) slice of ``scores`` owned by each item.
            spans = []
            for data in batch_dataset:
                start = len(scores)
                for conv in data.conversations:
                    scores.append(self._score(conv))
                spans.append((start, len(scores)))

            # Assign only after all scoring succeeded.
            for data, (start, end) in zip(batch_dataset, spans):
                data.llm_response = scores[start:end]
            return batch_dataset
        except Exception:
            logging.exception(
                "ScalarVisionRewardWorker.generate failed (actor_idx=%s); "
                "returning batch unchanged",
                getattr(self, "actor_idx", None),
            )
            return batch_dataset
