import os
from io import BytesIO
from typing import Any
from urllib.request import urlopen

import librosa
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM, AutoConfig, Qwen2AudioForConditionalGeneration


# COT_PROMPT = "The following is a break down on the correctness and usefulness of the assistant's response to my question: "

# Instruction placed in the user turn right after the audio clip. It asks the
# model for one free-text paragraph covering noise, distortion, naturalness and
# continuity. CLoudRewardModel.predict_reward uses it as the default
# `critique_prompt`.
COT_PROMPT='''
Please act as a professional speech quality evaluation expert and provide an objective description of the input audio based on the following four dimensions:
1.Noise: Whether there is background noise, and whether the noise interferes with the comprehension of the speech content.
2.Distortion: Whether the speech contains compression artifacts, electrical noise, or other types of distortion.
3.Naturalness: Whether the speech sounds natural and resembles real human speech.
4.Continuity: Whether the speech is fluent and continuous, or if there are any stutters, dropouts, or discontinuities.
Based on these four aspects, please generate a single, coherent, and objective paragraph that provides a comprehensive description of the overall quality of the audio.
'''

class CLoudRewardModelConfig(PretrainedConfig):
    """
    Configuration class for the audio reward model (``CLoudRewardModel``).

    Args:
        feedback_method: How rewards are produced. Must be ``"vanilla"``
            (score the prompt directly) or ``"teacher"`` (generate a textual
            critique first and score the prompt with the critique filled in).
        base_model_name_or_path: Name or path of the base model whose config
            is loaded to build the backbone. NOTE(review): the default is a
            Llama-3 checkpoint, while ``CLoudRewardModel`` builds a
            ``Qwen2AudioForConditionalGeneration`` from it -- confirm saved
            configs always override this value.
        **kwargs: Additional keyword arguments to be passed to the parent class constructor.

    Raises:
        ValueError: if ``feedback_method`` is not a supported value.
    """

    def __init__(self, feedback_method="vanilla", base_model_name_or_path="meta-llama/Meta-Llama-3-8B", **kwargs):
        supported_methods = ("vanilla", "teacher")
        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        if feedback_method not in supported_methods:
            raise ValueError(
                f"feedback_method must be one of {supported_methods}, got {feedback_method!r}"
            )

        self.feedback_method = feedback_method
        self.base_model_name_or_path = base_model_name_or_path

        super().__init__(**kwargs)
    
class RewardHead(nn.Module):
    """Small MLP that maps decoder hidden states to ``n_labels`` scores.

    Pipeline: dropout -> dense (hidden -> hidden) -> tanh -> dropout ->
    out_proj (hidden -> n_labels). The hidden width and dropout probability
    are read from ``cfg.text_config``.
    """

    def __init__(self, cfg: PretrainedConfig, n_labels: int):
        super().__init__()
        text_cfg = cfg.text_config
        width = text_cfg.hidden_size
        self.dense = nn.Linear(width, width)
        # The head reuses the language model's attention dropout probability.
        self.dropout = nn.Dropout(text_cfg.attention_dropout)
        self.out_proj = nn.Linear(width, n_labels)

    def forward(self, hidden_states: torch.Tensor, **kwargs: Any):
        activated = torch.tanh(self.dense(self.dropout(hidden_states)))
        return self.out_proj(self.dropout(activated))

class CLoudRewardModel(PreTrainedModel):
    """Audio reward model: a Qwen2-Audio backbone plus a scalar ``RewardHead``.

    ``config.feedback_method`` selects how rewards are produced:

    * ``"vanilla"`` -- the reward is read from the backbone's hidden states
      for a conversation whose assistant turn is empty.
    * ``"teacher"`` -- the backbone first generates a textual critique of the
      audio. The reward is then read from the hidden states of the
      conversation with that critique as the assistant turn.

    In both modes the reward comes from a fixed offset from the end of the
    sequence (``_REWARD_TOKEN_OFFSET``). This is only valid when the
    processor pads on the left.
    """

    config_class = CLoudRewardModelConfig

    # Position (from the end of the sequence) whose hidden state is scored.
    # With left padding every row ends at the same index, so one negative
    # index works for the whole batch. The author's notes say the last two
    # tokens are end-of-turn tokens; -2 selects the first of them.
    # TODO(review): the author left open which of the two is the better choice.
    _REWARD_TOKEN_OFFSET = -2

    def __init__(self, config, pretrained_reward_base_model=None):
        """Build the backbone and the reward head.

        Args:
            config: ``CLoudRewardModelConfig`` providing ``feedback_method``
                and ``base_model_name_or_path``.
            pretrained_reward_base_model: Optional, already-built backbone.
                When omitted, a ``Qwen2AudioForConditionalGeneration`` is
                constructed from the config found at
                ``config.base_model_name_or_path``. Its weights are
                presumably populated later by this class's ``from_pretrained``.
        """
        super().__init__(config)
        self.feedback_method = config.feedback_method

        if pretrained_reward_base_model is None:
            reward_base_model_cfg = AutoConfig.from_pretrained(config.base_model_name_or_path)
            self.reward_base_model = Qwen2AudioForConditionalGeneration(reward_base_model_cfg)
        else:
            self.reward_base_model = pretrained_reward_base_model
        # Single-output head sized from the backbone's text config.
        self.reward_head = RewardHead(self.reward_base_model.config, 1)

        # Reuse the backbone's list of modules that must not be split across devices.
        self._no_split_modules = self.reward_base_model._no_split_modules

    def forward(self, input_ids, input_features, attention_mask, feature_attention_mask, labels):
        """Training forward pass (``predict_reward`` is used for inference).

        Qwen2-Audio accepts ``labels`` and computes the language-modelling
        loss itself, so this returns that loss alongside the rewards.

        Returns:
            ``(rewards, logits, loss)``:

            * ``rewards`` holds the reward-head output at the reward position,
              one row per example, shape ``(batch, 1)``.
            * ``logits`` and ``loss`` come from the backbone.
        """
        output = self.reward_base_model(
            input_ids=input_ids,
            input_features=input_features,
            attention_mask=attention_mask,
            feature_attention_mask=feature_attention_mask,
            labels=labels,
            # Keep the KV cache off while training. With it enabled, a non-None
            # past_key_value produced key/value tensors whose shapes did not
            # match the attention inputs.
            use_cache=False,
            output_hidden_states=True,
            return_dict=True,
        )

        # Score only the reward position, matching predict_reward. This does
        # not depend on an attention_mask field being present in the output.
        last_hidden = output.hidden_states[-1]
        rewards = self.reward_head(last_hidden[:, self._REWARD_TOKEN_OFFSET])

        return rewards, output.logits, output.loss

    def prepare_inputs_for_reward(self, user_prompts, assistant_responses, processor, critique_prompt):
        """Render the chat prompts, load the audio and run the processor.

        Args:
            user_prompts: Sequence zipped with ``assistant_responses``. Only its
                length matters (it bounds the batch); its values are not used.
            assistant_responses: Paths of the audio files to evaluate.
            processor: Qwen2-Audio style processor exposing
                ``apply_chat_template``, ``feature_extractor.sampling_rate``
                and ``__call__(text=..., audios=...)``.
            critique_prompt: Instruction text placed after the audio in the user turn.

        Returns:
            The processor's batch encoding, moved to the backbone's device.

        Raises:
            ValueError: if ``self.feedback_method`` is not ``"vanilla"`` or ``"teacher"``.
        """
        if self.feedback_method == "vanilla":
            # The conversation ends with an empty assistant turn and no
            # generation prompt, so the reward can be read without generating.
            assistant_text, add_generation_prompt = "", False
        elif self.feedback_method == "teacher":
            # No assistant turn; ``add_generation_prompt=True`` asks the chat
            # template to open one so the model can generate a critique.
            assistant_text, add_generation_prompt = None, True
        else:
            raise ValueError(
                f"Unsupported feedback_method {self.feedback_method!r}; expected 'vanilla' or 'teacher'."
            )

        conversations = [
            self._build_conversation(audio, critique_prompt, assistant_text)
            for _, audio in zip(user_prompts, assistant_responses)
        ]
        return self._encode_conversations(conversations, processor, add_generation_prompt)

    @torch.inference_mode()
    def predict_reward(
        self,
        user_prompts,
        assistant_responses,
        processor,
        critique_prompt=COT_PROMPT,
        temp=0.0,
        max_tokens=8192,
    ):
        """Score audio files, generating a critique first in ``"teacher"`` mode.

        Args:
            user_prompts: Sequence zipped with ``assistant_responses``. Only its
                length matters; its values are not used.
            assistant_responses: List[str] -- paths of the audio files to score.
            processor: Qwen2-Audio style processor. Rewards are read at a fixed
                offset from the end of the sequence, so its tokenizer must pad
                on the left.
            critique_prompt: str -- instruction placed after the audio in the user turn.
            temp: float -- sampling temperature for critique generation. Any
                value ``<= 0`` selects greedy decoding.
            max_tokens: int -- passed to ``generate`` as ``max_length``, i.e. the
                total length of prompt plus generated critique.

        Returns:
            rewards: List[float] -- one reward per audio file.
            critiques: List[str] -- generated critiques (empty strings in
                ``"vanilla"`` mode).

        Raises:
            ValueError: if ``self.feedback_method`` is not supported.
        """
        reward_model_inputs = self.prepare_inputs_for_reward(user_prompts, assistant_responses, processor, critique_prompt)
        batch_size = reward_model_inputs.input_ids.shape[0]

        if self.feedback_method == "vanilla":
            outputs = self.reward_base_model(
                input_ids=reward_model_inputs.input_ids,
                input_features=reward_model_inputs.input_features,
                attention_mask=reward_model_inputs.attention_mask,
                feature_attention_mask=reward_model_inputs.feature_attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )
            critiques = [""] * batch_size
        else:  # "teacher": prepare_inputs_for_reward has already rejected any other value.
            critiques = self._generate_critiques(reward_model_inputs, processor, temp, max_tokens)

            # Re-encode with each critique as the assistant turn and score that.
            conversations = [
                self._build_conversation(audio, critique_prompt, critique)
                for _, audio, critique in zip(user_prompts, assistant_responses, critiques)
            ]
            inputs = self._encode_conversations(conversations, processor, add_generation_prompt=False)
            outputs = self.reward_base_model(
                **inputs,
                output_hidden_states=True,
                return_dict=True,
            )

        rewards = self.reward_head(outputs.hidden_states[-1][:, self._REWARD_TOKEN_OFFSET])
        return rewards.flatten().tolist(), critiques

    def _generate_critiques(self, model_inputs, processor, temp, max_tokens):
        """Generate one critique per example and return the decoded continuation text."""
        generate_output = self.reward_base_model.generate(
            **model_inputs,
            max_length=max_tokens,
            # Decide the decoding mode explicitly. Otherwise the checkpoint's
            # generation_config decides whether ``temperature`` is honoured, and
            # temp <= 0 could still sample.
            do_sample=temp > 0,
            temperature=temp if temp > 0 else None,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            return_dict_in_generate=True,
            repetition_penalty=1.01,
        )
        # The returned sequences start with the prompt; keep only the new tokens.
        critique_ids = generate_output.sequences[:, model_inputs.input_ids.size(1):]
        return processor.batch_decode(critique_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    @staticmethod
    def _audio_url(path):
        """Return the ``file://`` URL for a local audio path.

        Relative paths are made absolute first. Plain ``'file://' + 'dir/x.wav'``
        would make urllib treat ``dir`` as the URL host and fail to open it.
        """
        if not os.path.isabs(path):
            path = os.path.abspath(path)
        return 'file://' + path

    @staticmethod
    def _build_conversation(audio, critique_prompt, assistant_text=None):
        """Return a chat with one user turn (audio + prompt) and an optional assistant text turn."""
        conversation = [
            {"role": "user", "content": [
                {"type": "audio", "audio_url": CLoudRewardModel._audio_url(audio)},
                {"type": "text", "text": critique_prompt},
            ]},
        ]
        if assistant_text is not None:
            conversation.append(
                {"role": "assistant", "content": [
                    {"type": "text", "text": assistant_text},
                ]}
            )
        return conversation

    @staticmethod
    def _load_audios(conversations, sampling_rate):
        """Decode every audio item of every conversation, in order, resampled to ``sampling_rate``."""
        audios = []
        for conversation in conversations:
            for message in conversation:
                content = message["content"]
                if not isinstance(content, list):
                    continue
                for ele in content:
                    if ele["type"] != "audio":
                        continue
                    # Close the URL handle as soon as the bytes have been read.
                    with urlopen(ele["audio_url"]) as response:
                        raw_bytes = response.read()
                    audios.append(librosa.load(BytesIO(raw_bytes), sr=sampling_rate)[0])
        return audios

    def _encode_conversations(self, conversations, processor, add_generation_prompt):
        """Render conversations with the chat template, load their audio and run the processor."""
        sampling_rate = processor.feature_extractor.sampling_rate
        texts = [
            processor.apply_chat_template(conversation, add_generation_prompt=add_generation_prompt, tokenize=False)
            for conversation in conversations
        ]
        audios = self._load_audios(conversations, sampling_rate)
        inputs = processor(text=texts, audios=audios, return_tensors='pt', padding=True, sampling_rate=sampling_rate)
        return inputs.to(self.reward_base_model.device)


if __name__ == "__main__":
    # Demo: score audio files with a saved reward model and print each critique and reward.
    # Usage: python <this file> MODEL_PATH AUDIO [AUDIO ...] [--processor PATH]
    import argparse
    import os

    from transformers import AutoProcessor

    parser = argparse.ArgumentParser(description="Score audio files with a CLoud audio reward model.")
    parser.add_argument("model", help="Path or hub id of the saved CLoudRewardModel checkpoint.")
    parser.add_argument("audios", nargs="+", help="Audio files to score.")
    parser.add_argument(
        "--processor",
        default=None,
        help="Processor path or hub id (defaults to the model config's base_model_name_or_path).",
    )
    args = parser.parse_args()

    model = CLoudRewardModel.from_pretrained(args.model, device_map="cuda", torch_dtype=torch.float16)
    # predict_reward needs a processor (chat template + audio feature extractor), not a bare tokenizer.
    processor = AutoProcessor.from_pretrained(args.processor or model.config.base_model_name_or_path)
    # Rewards are read at a fixed offset from the end of each sequence, which requires left padding.
    processor.tokenizer.padding_side = "left"

    # The model turns each path into a 'file://' URL, so pass absolute paths.
    audio_paths = [os.path.abspath(path) for path in args.audios]
    # user_prompts are only zipped with the audio list to bound the batch; their values are unused.
    user_prompts = [""] * len(audio_paths)

    rewards, critiques = model.predict_reward(user_prompts, audio_paths, processor)
    for audio_path, reward, critique in zip(audio_paths, rewards, critiques):
        print("Audio:")
        print(audio_path)
        print("Critique:")
        print(critique)
        print("Reward:")
        print(reward)
        print("=" * 100)
