from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, CLIPProcessor, CLIPModel
import torch
import numpy as np
from PIL import Image
import cv2
import os

class Llama():
    """Wrapper around nvidia/Llama3-ChatQA-1.5-8B that classifies an agent's state from text.

    The prompt asks the model to answer 0 (stands normally), 1 (flips over) or
    2 (unknown) for a textual description supplied by the caller.
    """

    # System prompt placed before the rendered conversation.
    SYSTEM_PROMPT = "You are an agent to answer my question."
    # Task instruction prepended to the first user turn only.
    INSTRUCTION = "According to the description, please output 0 if the agent stands normally, and 1 if the agent flips over, if it is unknown output 2."

    def __init__(self):
        model_id = "nvidia/Llama3-ChatQA-1.5-8B"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

    def get_formatted_input(self, messages):
        """Render ``messages`` into the plain-text ChatQA prompt.

        Args:
            messages: list of dicts with ``'role'`` and ``'content'`` keys. Turns
                whose role is ``"user"`` are rendered as ``User:``; every other
                role is rendered as ``Assistant:``.

        Returns:
            str: system prompt, blank line, the turns joined by blank lines, and
            a trailing ``Assistant:`` cue. The task instruction is prepended to
            the first user turn only.

        The input list and its dicts are left unmodified, so the same messages
        can be formatted repeatedly without stacking the instruction.
        """
        turns = []
        instructed = False
        for item in messages:
            is_user = item["role"] == "user"
            content = item["content"]
            if is_user and not instructed:
                # Only the first user turn carries the instruction.
                content = self.INSTRUCTION + " " + content
                instructed = True
            turns.append(("User: " if is_user else "Assistant: ") + content)

        conversation = "\n\n".join(turns) + "\n\nAssistant:"
        return self.SYSTEM_PROMPT + "\n\n" + conversation

    def get_answer(self, msg, max_new_tokens=128):
        """Generate the model's reply to a single user message.

        Args:
            msg: user text (the description to classify).
            max_new_tokens: generation budget; defaults to the original 128.

        Returns:
            str: the decoded continuation only (prompt tokens are stripped),
            with special tokens removed.
        """
        messages = [{"role": "user", "content": msg}]
        formatted_input = self.get_formatted_input(messages)
        tokenized_prompt = self.tokenizer(self.tokenizer.bos_token + formatted_input, return_tensors="pt").to(self.model.device)
        # Stop on either the generic EOS or Llama-3's end-of-turn token.
        terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

        outputs = self.model.generate(
            input_ids=tokenized_prompt.input_ids,
            attention_mask=tokenized_prompt.attention_mask,
            max_new_tokens=max_new_tokens,
            eos_token_id=terminators,
        )

        # Drop the echoed prompt; keep only newly generated tokens.
        response = outputs[0][tokenized_prompt.input_ids.shape[-1]:]
        return self.tokenizer.decode(response, skip_special_tokens=True)

class MiniCPM():
    """Thin wrapper around openbmb/MiniCPM-V-2_6 for asking a question about one image."""

    def __init__(self):
        checkpoint = 'openbmb/MiniCPM-V-2_6'
        # Original note: attention may be 'sdpa' or 'flash_attention_2', not 'eager'.
        vlm = AutoModel.from_pretrained(
            checkpoint,
            trust_remote_code=True,
            attn_implementation='sdpa',
            torch_dtype=torch.bfloat16,
        )
        # Inference mode on the default CUDA device.
        self.model = vlm.eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

    def get_answer(self, img_rgb_array, question):
        """Ask ``question`` about the image in ``img_rgb_array``.

        Args:
            img_rgb_array: array accepted by ``PIL.Image.fromarray``
                (presumably an HxWx3 RGB array — confirm with callers).
            question: the text prompt.

        Returns:
            Whatever the checkpoint's remote-code ``chat`` method returns
            (presumably the answer string).
        """
        # The image travels inside the message content, so `image=` stays None.
        conversation = [{'role': 'user', 'content': [Image.fromarray(img_rgb_array), question]}]
        return self.model.chat(image=None, msgs=conversation, tokenizer=self.tokenizer)

class CLIP():
    """Zero-shot labelling of rendered frames with CLIP and textual intentions.

    ``self.texts`` is a list of intention groups. Each group is a list of
    candidate descriptions that CLIP scores against one another (softmax over
    the group) for every image. The commented-out assignments in ``__init__``
    are the configurations for other experiments.
    """

    def __init__(self, reward_type='onehot'):
        """Load CLIP ViT-L/14 and select the textual intentions.

        Args:
            reward_type: ``'onehot'`` for 0/1 labels or ``'prob'`` for
                probabilities; see ``get_batch_answers``.
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

        # ------------- textual intentions for DMC Cheetah ------------------
        # experiments in Section 5.2.1
        self.texts = [['The simulated two-legs robot flips over', 'The simulated two-legs robot stands normally']]
        # experiments in Section 5.2.2
        #self.texts = [['The underneath plane is Yellow-Orange', 'The underneath plane is Green-Blue']]

        # using multiple intentions simultaneously in DMC Cheetah
        #self.texts = [['The simulated two-legs robot flips over', 'The simulated two-legs robot stands normally'], ['The underneath plane is Yellow-Orange', 'The underneath plane is Green-Blue']]
        #self.texts = [['The simulated two-legs robot flips over', 'The simulated two-legs robot stands normally'], ['The underneath plane is Green-Blue', 'The underneath plane is Yellow-Orange']]

        # ------------- textual intentions for DMC Quadruped ------------------
        # experiments in Appendix 5
        #self.texts = [['The simulated quadruped flips over', 'The simulated quadruped stands normally']]
        # experiments in Section 5.2.2
        #self.texts = [['The underneath plane is Green-Blue', 'The underneath plane is Pink-Purple']]
        #self.texts = [['The underneath plane is Pink-Purple', 'The underneath plane is Green-Blue']]

        # ------------- textual intentions for DMC Humanoid ------------------
        # experiments in Section 5.2.3
        #self.texts = [['The simulated humanoid robot is stretched', 'The simulated humanoid robot is twisted']]


        # ------------- textual intentions for Franka Kitchen ------------------
        #self.texts = [['The simulated robot arm is twisted', 'The simulated robot arm is stretched']]
        #self.texts = [['The simulated white robot arm is in the left part of the scene', 'The simulated white robot arm is in the right part of the scene']]
        #self.texts = [['The simulated robot arm is stretched', 'The simulated robot arm is twisted']]
        self.reward_type = reward_type

    def _group_probs(self, group, images):
        """Return softmax probabilities over ``group`` for each image, shape (num_images, len(group))."""
        inputs = self.processor(text=group, images=images, return_tensors="pt", padding=True).to(self.device)
        logits_per_image = self.model(**inputs).logits_per_image  # image-text similarity scores
        return logits_per_image.softmax(dim=1)

    def get_answer(self, img_rgb_array):
        """Return the best-matching description from every intention group for one image.

        Each group is scored separately. Passing the nested ``self.texts`` to the
        processor in one call would make the tokenizer encode each group as a
        single text *pair* rather than as separate candidates.

        Args:
            img_rgb_array: a single image in a form accepted by ``self.processor``.

        Returns:
            list of str: one description per group in ``self.texts``.
        """
        answers = []
        with torch.no_grad():
            for group in self.texts:
                probs = self._group_probs(group, img_rgb_array)
                answers.append(group[probs[0].argmax().item()])
        return answers

    def get_batch_answers(self, batch_rgb_array, skip=2):
        """Label a batch of frames, running CLIP only on every ``skip``-th frame.

        Args:
            batch_rgb_array: indexable batch of images, presumably
                [batch_size, h, w, c] (confirm with callers).
            skip: stride between frames sent to CLIP. Each sampled label is
                reused for the following ``skip - 1`` frames.

        Returns:
            np.ndarray of length ``len(batch_rgb_array)``:
              - ``'onehot'``: int 1 where every group's argmax is nonzero (the
                second description for two-way groups), else 0.
              - ``'prob'``: product over groups of the probability of each
                group's second description.

        Raises:
            NotImplementedError: if ``self.reward_type`` is not ``'onehot'``
                or ``'prob'``.
            ValueError: if ``skip`` < 1.
        """
        if self.reward_type not in ('onehot', 'prob'):
            raise NotImplementedError
        if skip < 1:
            raise ValueError("skip must be >= 1")

        sampled = batch_rgb_array[::skip]
        per_group = []
        with torch.no_grad():
            for group in self.texts:
                probs = self._group_probs(group, sampled)
                if self.reward_type == 'onehot':
                    per_group.append(probs.argmax(dim=1).cpu().numpy())
                else:
                    per_group.append(probs[:, 1].cpu().numpy())

        if self.reward_type == 'onehot':
            merged_label = np.logical_and.reduce(per_group).astype(int)
        else:
            # logical_and would binarize probabilities; the product is its soft
            # analogue and leaves a single group unchanged.
            merged_label = np.prod(per_group, axis=0)

        # Repeat for the unlabelled frames; trim because the final stride may be partial.
        return np.repeat(merged_label, skip)[:len(batch_rgb_array)]
