import os

import torch
import requests
from PIL import Image
import numpy as np
import copy
import argparse
import sys

from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates


class ORM:
    """Score how well images match their prompts with a LLaVA yes/no query.

    For each (prompt, image) pair the model is asked whether the image
    accurately represents the prompt. The score is the probability of the
    "yes" token at the first generated step, renormalised against the
    "no" token: ``p(yes) / (p(yes) + p(no))``.

    The checkpoint is loaded on the CPU in ``__init__``; call
    ``load_to_device`` to move it to another device before scoring.
    """

    # Number of generation attempts per (prompt, image) pair before the
    # pair is given a score of 0.
    MAX_RETRIES = 1

    def __init__(self, args):
        """Load the tokenizer, model and image processor.

        Args:
            args: Object exposing ``orm_ckpt_path``, the path of the
                pretrained checkpoint to load.
        """
        pretrained = args.orm_ckpt_path
        print(f"pretrained path:{pretrained}")

        overwrite_config = {"image_aspect_ratio": "pad", "torch_dtype": 'bfloat16'}
        llava_model_args = {
            "multimodal": True,
            "overwrite_config": overwrite_config,
        }

        model_name = "llava_qwen"
        device_map = "cpu"
        self.tokenizer, self.model, self.image_processor, _ = load_pretrained_model(
            pretrained, None, model_name, device_map=device_map, torch_dtype='bfloat16', **llava_model_args
        )
        self.config = self.model.config
        # The model was requested on ``device_map``; record it so that
        # ``__call__`` works even if ``load_to_device`` is never called.
        self.device = device_map

        self.model.eval()

        self.yes_token_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.no_token_id = self.tokenizer.convert_tokens_to_ids("no")

    @property
    def __name__(self):
        """Return the fixed display name ``'ORM'``."""
        return 'ORM'

    def load_to_device(self, load_device):
        """Move the model to ``load_device`` and remember it for inputs."""
        self.device = load_device
        self.model.to(self.device)

    def __call__(self, prompts, images, **kwargs):
        """Score each (prompt, image) pair.

        Args:
            prompts: Iterable of prompt strings.
            images: Iterable of images paired positionally with
                ``prompts``; each must expose ``.size``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A list with one float per pair: ``p(yes) / (p(yes) + p(no))``
            when the model answered "yes" or "no", otherwise ``0.0``.
        """
        return [self._score_pair(prompt, image) for prompt, image in zip(prompts, images)]

    def _score_pair(self, prompt, image):
        """Ask the model about one (prompt, image) pair and return its score."""
        image_tensor = process_images([image], self.image_processor, self.config)[0]
        image_tensor = image_tensor.to(dtype=torch.bfloat16, device=self.device)

        question = (
            f"{DEFAULT_IMAGE_TOKEN} This image is generated by a prompt: {prompt}. "
            "Does this image accurately represent the prompt? "
            "Please answer yes or no without explanation."
        )

        # Build a single-turn conversation with an empty assistant slot.
        conv = copy.deepcopy(conv_templates["qwen_1_5"])
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt_question = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(self.device)

        # Autocast must target the device the model actually lives on.
        device_type = torch.device(self.device).type

        response = ""
        for _ in range(self.MAX_RETRIES):
            with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
                with torch.no_grad():
                    # NOTE(review): only the first generated token and its
                    # scores are consumed below, so max_new_tokens could
                    # probably be much smaller -- confirm before changing.
                    cont = self.model.generate(
                        input_ids,
                        images=[image_tensor],
                        image_sizes=[image.size],
                        do_sample=True,
                        temperature=0.3,
                        max_new_tokens=100,
                        return_dict_in_generate=True,
                        output_scores=True,
                    )

            # First token of the returned sequence; assumes generate()
            # returns only newly generated tokens here -- TODO confirm.
            response = self.tokenizer.convert_ids_to_tokens(cont.sequences[0])[0].lower().strip()
            if response in ("yes", "no"):
                # Scores of the first generation step for the single batch item.
                return self._yes_probability(cont.scores[0][0])

        print("Failed to generate a valid 'yes' or 'no' answer after maximum retries. Reponse:" + response)
        return 0.0

    def _yes_probability(self, first_step_scores):
        """Return ``p(yes) / (p(yes) + p(no))`` from one step's vocabulary scores.

        Args:
            first_step_scores: 1-D tensor of scores over the vocabulary.

        Returns:
            The renormalised "yes" probability, or ``0.0`` when both
            probabilities are zero (e.g. after low-precision underflow).
        """
        probs = torch.nn.functional.softmax(first_step_scores, dim=-1)
        yes_prob = probs[self.yes_token_id].item()
        no_prob = probs[self.no_token_id].item()
        total = yes_prob + no_prob
        if total <= 0:
            return 0.0
        return yes_prob / total
