import os
import tempfile
import uuid

import torch
from torch import nn
from transformers import TextStreamer
from transformers import AutoModelForCausalLM, AutoTokenizer

from model.base import LargeMultimodalModel


class Qwen_VL(LargeMultimodalModel):
    """
    Qwen_VL Model
    https://github.com/QwenLM/Qwen-VL/blob/master/eval_mm/evaluate_vqa.py

    Wraps a Qwen-VL checkpoint loaded through ``transformers`` with
    ``trust_remote_code=True``. ``forward_with_probs`` generates a response for
    a single image + text prompt. It also returns per-step logits,
    probabilities and a hidden state.
    """

    def __init__(self, args) -> None:
        """Load the Qwen-VL model and tokenizer.

        Args:
            args: configuration object. This constructor reads
                ``args.model_path``, ``args.temperature`` and
                ``args.num_beams``.
        """
        super().__init__()
        # NOTE(review): ``self.device`` is presumably set by
        # LargeMultimodalModel.__init__ — confirm.
        self.model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            device_map=self.device,
            trust_remote_code=True,
            fp32=True,
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
        # Pad on the left so the prompt ends right where generation starts.
        self.tokenizer.padding_side = "left"
        # Reuse the tokenizer's ``eod_id`` token (also used as EOS below) for padding.
        self.tokenizer.pad_token_id = self.tokenizer.eod_id

        # Prompt template; not referenced elsewhere in this class.
        self.prompt = "<img>{}</img>{}"
        self.model.tie_weights()
        self.model.to(self.device)

        # Generation settings passed to ``self.model.generate``.
        self.temperature = args.temperature
        self.top_p = None
        self.num_beams = args.num_beams
        self.use_cache = True
        self.max_new_tokens = 512

    def forward_with_probs(self, images, prompt):
        """Generate a response for ``prompt``, conditioned on the first image.

        Only ``images[0]`` is used. It is written to a uniquely named
        temporary PNG file because the Qwen-VL query format refers to images
        by path. The file is deleted afterwards, even if generation raises.
        When ``images`` is ``None`` or empty, the prompt is sent as text only.

        Args:
            images: sequence of image objects with a ``save(path)`` method
                (presumably PIL images — TODO confirm against callers).
            prompt: user text. Every literal ``<img>`` is rewritten to
                ``<image>`` before tokenization.

        Returns:
            Tuple ``(response, output_ids, logits, probs, hidden_states)``:
              - response: generated text, special tokens removed and stripped.
              - output_ids: numpy array of the generated token ids (the prompt
                is excluded).
              - logits: numpy float32 array with one score row per generation
                step (``steps * num_beams`` rows under beam search).
              - probs: numpy float32 array, the softmax of ``logits`` rows.
              - hidden_states: float tensor with the last layer's state at the
                last prompt token, taken from the first generation step
                (shape ``(1, hidden_size)`` assuming HF's
                ``(batch, seq, hidden)`` layout — confirm).
        """
        visual_paths = []
        try:
            if images is not None and len(images) > 0:
                # mkstemp creates the file atomically with a unique name, so
                # concurrent runs cannot overwrite each other's image.
                fd, path = tempfile.mkstemp(suffix=".png")
                visual_paths.append(path)
                os.close(fd)
                images[0].save(path)
            return self._generate_with_probs(visual_paths, prompt)
        finally:
            # Best-effort cleanup: a file that is already gone is not an error.
            for visual_path in visual_paths:
                try:
                    os.remove(visual_path)
                except OSError:
                    pass

    def _generate_with_probs(self, visual_paths, prompt):
        """Build the Qwen query from image paths + prompt, generate, and collect outputs."""
        # NOTE(review): presumably this keeps a literal "<img>" in user text
        # from being parsed as Qwen's image tag — confirm.
        prompt = prompt.replace("<img>", "<image>")
        query = [{"image": visual_path} for visual_path in visual_paths]
        query.append({"text": prompt})

        questions = self.tokenizer.from_list_format(query)
        inputs = self.tokenizer(questions, return_tensors="pt")
        prompt_len = inputs.input_ids.shape[-1]

        # transformers' generate() rejects a streamer when num_beams > 1, so
        # only echo tokens to stdout for single-beam decoding.
        single_beam = self.num_beams is None or self.num_beams <= 1
        streamer = (
            TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
            if single_beam
            else None
        )
        pad_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eod_id
        )

        outputs = self.model.generate(
            inputs.input_ids.to(self.device),
            attention_mask=inputs.attention_mask.to(self.device),
            eos_token_id=self.tokenizer.eod_id,
            pad_token_id=pad_token_id,
            do_sample=self.temperature > 0,
            temperature=self.temperature,
            top_p=self.top_p,
            num_beams=self.num_beams,
            max_new_tokens=self.max_new_tokens,
            use_cache=self.use_cache,
            streamer=streamer,
            return_dict_in_generate=True,
            output_hidden_states=True,
            output_scores=True,
        )

        # ``scores`` holds one tensor per generation step; stacking them along
        # dim 0 gives one row per step (per beam).
        step_scores = torch.cat(outputs["scores"], dim=0).float()
        logits = step_scores.cpu().numpy()
        probs = nn.functional.softmax(step_scores, dim=-1).cpu().numpy()

        # Decoder-only output repeats the prompt, so drop it by its length.
        # This stays correct under beam search, where the number of score rows
        # is not the number of generated tokens.
        output_ids = outputs["sequences"][0][prompt_len:]
        response = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        output_ids = output_ids.cpu().numpy()

        # hidden_states[0] is the first generation step (the full prompt).
        hidden_states_all_layers = outputs["hidden_states"][0]
        hidden_states = hidden_states_all_layers[-1][0][[-1]].float()  # last layer, first batch row, last token

        return response, output_ids, logits, probs, hidden_states
