import os
from typing import List, Dict, Union

import torch
import torch.nn.functional as F
import transformers
from peft import LoraConfig, get_peft_model
from together import Together
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

from utils.misc import preprocess_tensor

# Only surface transformers log records at ERROR level or above.
transformers.logging.set_verbosity(transformers.logging.ERROR)

# Short model name -> Hugging Face Hub repository id. Used as the download source
# when no local checkpoint is available.
model_hf = {
    "Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    **{f"Qwen2-{size}-Instruct": f"Qwen/Qwen2-{size}-Instruct" for size in ("7B", "1.5B", "0.5B")},
}

# Context-window size (in tokens) for each supported model name.
llm_context_window = {
    "Meta-Llama-3.1-8B-Instruct": 131072,
    **dict.fromkeys(("Qwen2-7B-Instruct", "Qwen2-1.5B-Instruct", "Qwen2-0.5B-Instruct"), 32768),
}


class ActorLM(torch.nn.Module):
    """Local chat model used as the "actor", served by Hugging Face transformers or by vLLM.

    Args:
        transformers_mode: Backend, either ``"transformers"`` or ``"vllm"``.
        llm: Short model name; weights are looked up under ``{model_dir}/{llm}``.
        context_window: Stored as ``self.max_length``; not used for generation here.
        device: Device the model / engine runs on.
        model_dir: Root directory of local checkpoints (defaults to ``"model_hub"``).

    Raises:
        ValueError: If ``transformers_mode`` is neither ``"transformers"`` nor ``"vllm"``.
    """

    def __init__(self, transformers_mode="transformers", llm="Qwen2-0.5B-Instruct", context_window=896,
                 device="cuda" if torch.cuda.is_available() else "cpu", model_dir=None):
        super().__init__()
        self.llm = llm
        self.max_length = context_window
        self.device = device
        self.mode = transformers_mode
        model_dir = "model_hub" if model_dir is None else model_dir
        local_path = f'{model_dir}/{self.llm}'
        if self.mode == "transformers":
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(local_path,
                                                               padding_side='left', trust_remote_code=True)
                self.model = AutoModelForCausalLM.from_pretrained(local_path,
                                                                  trust_remote_code=True).to(self.device)
            except Exception:
                # No usable local checkpoint: download from the Hub, caching under the same
                # directory. (A bare `except:` would also turn Ctrl-C during loading into a download.)
                hub_id = model_hf[self.llm]
                self.tokenizer = AutoTokenizer.from_pretrained(hub_id, cache_dir=local_path,
                                                               padding_side='left', trust_remote_code=True)
                self.model = AutoModelForCausalLM.from_pretrained(hub_id, cache_dir=local_path,
                                                                  trust_remote_code=True).to(self.device)
            # Batched generation pads prompts; some tokenizers (e.g. Llama 3.x) ship without a
            # pad token, which would make `padding=True` fail.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

        elif self.mode == "vllm":
            self.tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
            self.engine = LLM(
                model=local_path,
                device=self.device,
                gpu_memory_utilization=0.5,
            )

        else:
            raise ValueError(f"only support transformers or vllm mode, got {self.mode}.")

    def forward(self, messages):
        """Generate a reply for one conversation or for a batch of conversations.

        Args:
            messages: A conversation (``List[Dict]``) or a batch of them (``List[List[Dict]]``).

        Returns:
            A single reply string, or a list of reply strings in input order.

        Raises:
            ValueError: If ``messages`` is empty or not in one of the two accepted shapes.
        """
        if not messages:
            raise ValueError("Input messages cannot be empty.")

        if isinstance(messages[0], dict):
            return self.forward_single(messages)

        elif isinstance(messages[0], list):
            return self.forward_multi(messages)

        else:
            raise ValueError("Invalid input format. Expected List[Dict] or List[List[Dict]].")

    def forward_single(self, messages: List[Dict]):
        """Return the sampled reply (str) for a single conversation."""
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False,
                                                    add_generation_prompt=True)
        return self._generate([prompt])[0]

    def forward_multi(self, messages: List[List[Dict]]):
        """Return one sampled reply per conversation, in input order."""
        prompts = self.tokenizer.apply_chat_template(messages, tokenize=False,
                                                     add_generation_prompt=True)
        return self._generate(prompts)

    def _generate(self, prompts: List[str]) -> List[str]:
        """Dispatch rendered prompts to the configured backend; one completion per prompt."""
        if self.mode == "transformers":
            return self._generate_hf(prompts)
        if self.mode == "vllm":
            return self._generate_vllm(prompts)
        raise ValueError(f"only support transformers or vllm mode, got {self.mode}.")

    def _generate_hf(self, prompts: List[str]) -> List[str]:
        """Sample completions with Hugging Face ``generate`` (left-padded batch)."""
        # The chat template already contains the model's special tokens (e.g. BOS),
        # so the tokenizer must not add them a second time.
        inputs = self.tokenizer(prompts, padding=True, add_special_tokens=False,
                                return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs.input_ids,
                # Without the mask the model would attend to the left padding of shorter prompts.
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
            )
        # Decoder-only models echo the (padded) prompt; keep only the newly generated tokens.
        new_tokens = outputs.sequences[:, inputs.input_ids.shape[1]:]
        return self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

    def _generate_vllm(self, prompts: List[str]) -> List[str]:
        """Sample completions with the vLLM engine, using the same sampling settings."""
        sampling_params = SamplingParams(
            max_tokens=512,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )
        outputs = self.engine.generate(prompts, sampling_params, use_tqdm=False)
        return [request_output.outputs[0].text for request_output in outputs]


class ActorLM_API:
    """Actor model served remotely through the Together chat-completions API.

    Args:
        llm: Short model name; requests use ``model_hf[llm] + "-Turbo"`` as the model id.
        context_window: Stored as ``self.context_window``; not sent with requests.
        api_key: Together API key. When omitted, it is read from the ``TOGETHER_API_KEY``
            environment variable.
    """

    def __init__(self, llm="Qwen2-0.5B-Instruct", context_window=896, api_key=None):
        self.llm = model_hf[llm] + "-Turbo"
        self.context_window = context_window

        # Credentials must never live in source control. The key that used to be
        # hard-coded here has been exposed and should be revoked.
        if api_key is None:
            api_key = os.environ.get('TOGETHER_API_KEY')
        self.client = Together(api_key=api_key)

    def forward(self, messages):
        """Generate a reply for one conversation or for a batch of conversations.

        Raises:
            ValueError: If ``messages`` is empty or not ``List[Dict]`` / ``List[List[Dict]]``.
        """
        if not messages:
            raise ValueError("Input messages cannot be empty.")

        if isinstance(messages[0], dict):
            return self.forward_single(messages)

        elif isinstance(messages[0], list):
            return self.forward_multi(messages)

        else:
            raise ValueError("Invalid input format. Expected List[Dict] or List[List[Dict]].")

    def forward_single(self, messages: List[Dict]):
        """Send one chat-completion request and return the reply text."""
        response = self.client.chat.completions.create(
            model=self.llm,
            messages=messages,
            max_tokens=512,
            temperature=0.7,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1,
            stop=["<|eot_id|>", "<|eom_id|>"],
            stream=False
        )
        return response.choices[0].message.content

    def forward_multi(self, messages: List[List[Dict]]):
        """Return one reply per conversation; requests are issued sequentially, in order."""
        return [self.forward_single(conversation) for conversation in messages]


class InstructLM(torch.nn.Module):
    """LoRA-wrapped local causal LM that returns sampled text plus per-step top-k log-probs.

    Loads ``{model_dir}/{llm}`` (falling back to downloading ``model_hf[llm]``), freezes the
    base weights, and attaches a LoRA adapter (r=8, alpha=16, dropout=0.1) via ``get_peft_model``.
    """

    def __init__(self, llm="Qwen2-0.5B-Instruct", context_window=896,
                 device="cuda" if torch.cuda.is_available() else "cpu", model_dir=None):
        super().__init__()
        self.llm = llm
        self.max_length = context_window
        self.device = device
        model_dir = "model_hub" if model_dir is None else model_dir
        local_path = f'{model_dir}/{self.llm}'
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(local_path,
                                                           padding_side='left', trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(local_path,
                                                              trust_remote_code=True).to(self.device)
        except Exception:
            # No usable local checkpoint: download from the Hub, caching under the same
            # directory. (A bare `except:` would also turn Ctrl-C during loading into a download.)
            hub_id = model_hf[self.llm]
            self.tokenizer = AutoTokenizer.from_pretrained(hub_id, cache_dir=local_path,
                                                           padding_side='left', trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(hub_id, cache_dir=local_path,
                                                              trust_remote_code=True).to(self.device)
        # Batched generation pads prompts; some tokenizers (e.g. Llama 3.x) have no pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.top_k = 50
        self.vocab_size = self.model.config.vocab_size

        for param in self.model.parameters():
            param.requires_grad = False

        # No target_modules given: peft uses its built-in default modules for the architecture.
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
        )

        self.model = get_peft_model(self.model, lora_config)

    def forward(self, messages):
        """Generate for one conversation (``List[Dict]``) or a batch (``List[List[Dict]]``).

        Raises:
            ValueError: If ``messages`` is empty or not in one of the two accepted shapes.
        """
        if not messages:
            raise ValueError("Input messages cannot be empty.")

        if isinstance(messages[0], dict):
            return self.forward_single(messages)

        elif isinstance(messages[0], list):
            return self.forward_multi(messages)

        else:
            raise ValueError("Invalid input format. Expected List[Dict] or List[List[Dict]].")

    def forward_single(self, messages: List[Dict]):
        """Return ``(text, log_probs)`` for one conversation.

        ``log_probs`` is the (steps, 2 * top_k) tensor from ``_topk_log_probs`` (log-probs, then
        token ids), moved to CPU and passed through ``preprocess_tensor``.
        """
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False,
                                                    add_generation_prompt=True)
        texts, log_probs = self._generate([prompt])
        # presumably pads the step dimension to 512 (see utils.misc.preprocess_tensor)
        return texts[0], preprocess_tensor(log_probs[0].detach().cpu())

    def forward_multi(self, messages: List[List[Dict]]):
        """Return ``(texts, log_probs)`` for a batch; ``log_probs`` stacks per-conversation rows.

        Each row has the same layout as the tensor returned by ``forward_single``.
        """
        prompts = self.tokenizer.apply_chat_template(messages, tokenize=False,
                                                     add_generation_prompt=True)
        texts, log_probs = self._generate(prompts)
        # NOTE(review): a sequence that finishes early in the batch still gets score rows for
        # the remaining steps; those trailing steps carry no meaning for that sequence.
        padded = [preprocess_tensor(row.detach().cpu()) for row in log_probs]
        return texts, torch.stack(padded, dim=0)

    def _generate(self, prompts: List[str]):
        """Sample completions for ``prompts``; return ``(texts, log_probs)`` with batch-first log_probs."""
        # The chat template already contains the model's special tokens (e.g. BOS),
        # so the tokenizer must not add them again. padding=True is required for batches
        # of prompts with different lengths.
        inputs = self.tokenizer(prompts, padding=True, add_special_tokens=False,
                                return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs.input_ids,
                # Without the mask the model would attend to the left padding of shorter prompts.
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_k=self.top_k,
                top_p=0.95,
                pad_token_id=self.tokenizer.pad_token_id,
                output_scores=True,
                return_dict_in_generate=True
            )
        # Decoder-only models echo the (padded) prompt; keep only the newly generated tokens.
        new_tokens = outputs.sequences[:, inputs.input_ids.shape[1]:]
        texts = self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
        return texts, self._topk_log_probs(outputs.scores, self.top_k)

    @staticmethod
    def _topk_log_probs(scores, top_k):
        """Convert per-step scores into top-k log-probabilities plus their token ids.

        Args:
            scores: Sequence of T tensors, each (batch, vocab), one per generation step
                (the ``scores`` returned by ``generate``).
            top_k: Number of highest-scoring tokens kept per step.

        Returns:
            Float tensor of shape (batch, T, 2 * top_k): the first ``top_k`` columns are
            log-softmax values over the kept scores, the last ``top_k`` are the token ids.
        """
        tops = [torch.topk(step, k=top_k, dim=-1) for step in scores]
        values = torch.stack([top.values for top in tops], dim=1)    # (batch, T, top_k)
        indices = torch.stack([top.indices for top in tops], dim=1)  # (batch, T, top_k)
        log_probs = F.log_softmax(values, dim=-1)
        # cut off at log(1e-10) to prevent nan downstream (filtered-out tokens score -inf)
        log_probs = torch.where(log_probs == -float('inf'), torch.log(torch.tensor(1e-10)), log_probs)
        return torch.cat((log_probs, indices.to(log_probs.dtype)), dim=-1)


class DoubleLM(torch.nn.Module):
    """Pair an actor model with an instruct model and route each call to one of them."""

    def __init__(self, actor_lm: Union[ActorLM, ActorLM_API], instruct_lm: InstructLM) -> None:
        super().__init__()
        self.actor_lm = actor_lm
        self.instruct_lm = instruct_lm

    def forward(self, messages:List[Dict], mode='actor'):
        """Forward ``messages`` unchanged to the selected model's ``forward``.

        Args:
            messages: Passed through as-is to the chosen model.
            mode: ``'actor'`` selects ``self.actor_lm``; ``'instruct'`` selects ``self.instruct_lm``.

        Raises:
            ValueError: For any other ``mode``.
        """
        if mode not in ('actor', 'instruct'):
            raise ValueError("Only support 'actor' and 'instruct' modes!")
        target = self.actor_lm if mode == 'actor' else self.instruct_lm
        return target.forward(messages=messages)