import os
from typing import Dict, List, Union

import torch
import torch.nn.functional as F
import transformers
from peft import LoraConfig, get_peft_model
from together import Together
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

from utils.misc import preprocess_tensor

# Keep the console quiet: from here on transformers only reports errors.
transformers.logging.set_verbosity(transformers.logging.ERROR)

# Single source of truth for supported models:
# (short name, Hugging Face Hub repository id, context window in tokens).
_MODEL_SPECS = (
    ("Meta-Llama-3.1-8B-Instruct", "meta-llama/Meta-Llama-3.1-8B-Instruct", 131072),
    ("Qwen2-7B-Instruct", "Qwen/Qwen2-7B-Instruct", 32768),
    ("Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct", 32768),
    ("Qwen2-0.5B-Instruct", "Qwen/Qwen2-0.5B-Instruct", 32768),
)

# Short model name -> Hub repository id (used as the download fallback).
model_hf = {short_name: hub_id for short_name, hub_id, _ in _MODEL_SPECS}

# Short model name -> maximum context window.
llm_context_window = {short_name: window for short_name, _, window in _MODEL_SPECS}


class ActorLM(torch.nn.Module):
    """Local chat model that generates the actor's replies.

    Two backends are supported, selected by ``transformers_mode``: Hugging Face
    ``transformers`` (``AutoModelForCausalLM.generate``) and a ``vllm`` engine.
    Both sample with the same settings (up to 512 new tokens, temperature 0.7,
    top-k 50, top-p 0.95) and return only the reply text.
    """

    _SUPPORTED_MODES = ("transformers", "vllm")

    def __init__(self, transformers_mode="transformers", llm="Qwen2-0.5B-Instruct", context_window=896,
                 device="cuda" if torch.cuda.is_available() else "cpu", model_dir=None):
        """Load the tokenizer and the generation backend.

        Args:
            transformers_mode: backend name, ``"transformers"`` or ``"vllm"``.
            llm: short model name. Weights are read from ``f"{model_dir}/{llm}"``.
                In transformers mode, if that fails, they are downloaded from the
                Hub id ``model_hf[llm]`` and cached under that same directory.
            context_window: stored as ``self.max_length``; not used by this class.
            device: device the transformers model is moved to, or passed to vllm.
            model_dir: checkpoint root directory; ``None`` means ``"model_hub"``.

        Raises:
            ValueError: if ``transformers_mode`` is not a supported backend.
        """
        super().__init__()
        self.llm = llm
        self.max_length = context_window
        self.device = device
        self.mode = transformers_mode
        # Fail fast on an unknown backend, before any weights are loaded.
        if self.mode not in self._SUPPORTED_MODES:
            raise ValueError(f"only support transformers or vllm mode, got {self.mode}.")

        model_dir = "model_hub" if model_dir is None else model_dir
        local_path = f'{model_dir}/{self.llm}'

        if self.mode == "transformers":
            try:
                tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
                model = AutoModelForCausalLM.from_pretrained(local_path, trust_remote_code=True)
            except Exception:
                # No usable local checkpoint: fetch from the Hub and cache it under local_path.
                # (Exception rather than a bare except, so Ctrl-C / SystemExit still propagate.)
                tokenizer = AutoTokenizer.from_pretrained(model_hf[self.llm],
                                                          cache_dir=local_path,
                                                          trust_remote_code=True)
                model = AutoModelForCausalLM.from_pretrained(model_hf[self.llm],
                                                             cache_dir=local_path,
                                                             trust_remote_code=True)
            self.tokenizer = tokenizer
            # Outside the try: a device error (e.g. out of memory) is reported as-is
            # instead of silently triggering a Hub download.
            self.model = model.to(self.device)
        else:  # "vllm"
            self.tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
            self.engine = LLM(
                model=local_path,
                device=self.device,
                gpu_memory_utilization=0.5,
            )

    def forward(self, messages: List[Dict]):
        """Generate a reply to a chat conversation.

        Args:
            messages: chat messages, rendered with the tokenizer's chat template
                (a generation prompt is appended).

        Returns:
            The sampled reply text. In transformers mode, special tokens are stripped.

        Raises:
            ValueError: if ``self.mode`` was changed to an unsupported value.
        """
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False,
                                                    add_generation_prompt=True)
        if self.mode == "transformers":
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            # Scores are not requested: only the text is returned, and keeping
            # one vocab-sized score tensor per generated step is costly.
            with torch.no_grad():
                sequences = self.model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_k=50,
                    top_p=0.95,
                )
            # generate() returns prompt + continuation; keep only the continuation.
            prompt_length = inputs.input_ids.shape[1]
            generated_ids = sequences[0][prompt_length:]
            return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        elif self.mode == "vllm":
            sampling_params = SamplingParams(
                max_tokens=512,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
            )
            outputs = self.engine.generate(prompt, sampling_params, use_tqdm=False)

            return outputs[0].outputs[0].text

        else:
            raise ValueError(f"only support transformers or vllm mode, got {self.mode}.")


class ActorLM_API:
    """Actor that generates replies through the Together chat-completions API."""

    def __init__(self, llm="Qwen2-0.5B-Instruct", context_window=896, api_key=None):
        """Create the API client.

        Args:
            llm: short model name; the requested Together model id is
                ``model_hf[llm] + "-Turbo"``.
                NOTE(review): confirm a ``-Turbo`` variant exists on Together
                for every entry of ``model_hf``.
            context_window: stored as ``self.context_window``; not used by this class.
            api_key: Together API key. ``None`` (the default) reads the
                ``TOGETHER_API_KEY`` environment variable.
        """
        self.llm = model_hf[llm] + "-Turbo"
        self.context_window = context_window

        # Credentials must never live in source code: the key that used to be
        # hard-coded here is exposed in version control and should be revoked.
        if api_key is None:
            api_key = os.environ.get('TOGETHER_API_KEY')
        self.client = Together(api_key=api_key)

    def forward(self, messages: List[Dict]):
        """Send ``messages`` as a chat completion and return the reply text.

        Sampling mirrors ``ActorLM`` (512 max tokens, temperature 0.7,
        top-p 0.95, top-k 50), with no repetition penalty and no streaming.
        """
        response = self.client.chat.completions.create(
            model=self.llm,
            messages=messages,
            max_tokens=512,
            temperature=0.7,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1,
            # Llama-3-style end-of-turn / end-of-message tokens.
            stop=["<|eot_id|>","<|eom_id|>"],
            stream=False
        )
        return response.choices[0].message.content


class InstructLM(torch.nn.Module):
    """LoRA-wrapped causal LM plus a scalar value head.

    ``forward`` samples a reply and, for every generated step, returns the
    log-probabilities of the ``top_k`` highest-scoring tokens (with their ids)
    and a value estimate computed from the last-layer hidden state.
    """

    # log(1e-10) in float32; used to replace -inf log-probabilities (prevents nan downstream).
    _LOG_PROB_FLOOR = float(torch.log(torch.tensor(1e-10)))

    def __init__(self, llm="Qwen2-0.5B-Instruct", context_window=896,
                 device="cuda" if torch.cuda.is_available() else "cpu", model_dir=None):
        """Load the base model, attach the value head and wrap the model with LoRA.

        Args:
            llm: short model name. Weights are read from ``f"{model_dir}/{llm}"``;
                if that fails they are downloaded from the Hub id ``model_hf[llm]``
                and cached under that directory.
            context_window: stored as ``self.max_length``; not used by this class.
            device: device the model and the value head are placed on.
            model_dir: checkpoint root directory; ``None`` means ``"model_hub"``.
        """
        super().__init__()
        self.llm = llm
        self.max_length = context_window
        self.device = device
        model_dir = "model_hub" if model_dir is None else model_dir
        local_path = f'{model_dir}/{self.llm}'
        try:
            tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(local_path, trust_remote_code=True)
        except Exception:
            # No usable local checkpoint: fetch from the Hub and cache it under local_path.
            # (Exception rather than a bare except, so Ctrl-C / SystemExit still propagate.)
            tokenizer = AutoTokenizer.from_pretrained(model_hf[self.llm],
                                                      cache_dir=local_path,
                                                      trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(model_hf[self.llm],
                                                         cache_dir=local_path,
                                                         trust_remote_code=True)
        self.tokenizer = tokenizer
        # Outside the try so a device error is not mistaken for a missing checkpoint.
        self.model = model.to(self.device)
        self.top_k = 50
        self.vocab_size = self.model.config.vocab_size
        # Must live on the same device as the hidden states it consumes in forward().
        self.value_head = torch.nn.Linear(self.model.config.hidden_size, 1).to(self.device)

        # Freeze the base weights; only the value head (and the LoRA adapters
        # added below) are meant to be trained.
        self.model.requires_grad_(False)
        self.value_head.requires_grad_(True)

        # target_modules is left unset, so peft chooses the adapted layers itself.
        # NOTE(review): confirm peft has defaults for every architecture in model_hf.
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
        )

        self.model = get_peft_model(self.model, lora_config)

    def actor_parameters(self):
        """Return the parameters of the (LoRA-wrapped) language model.

        The value head is a sibling module, not part of ``self.model``, so it is
        never included here. Frozen base weights are included; optimizers skip
        them because they receive no gradient.
        """
        return self.model.parameters()

    def critic_parameters(self):
        """Returns the parameters of the value head only."""
        return self.value_head.parameters()

    def forward(self, messages: List[Dict]):
        """Sample a reply to ``messages`` and score every generated step.

        Returns:
            ``(generated_text, log_probs, values)``:

            * ``generated_text``: the decoded reply, special tokens stripped.
            * ``log_probs``: one row per generated step. The first ``top_k``
              columns are log-softmax values over that step's ``top_k`` largest
              processed scores. The last ``top_k`` columns are the matching
              token ids, promoted to the log-prob dtype by ``torch.cat``.
            * ``values``: the value-head output for each generated step.

            Both tensors are detached, moved to CPU and passed through
            ``preprocess_tensor`` (the original author notes it pads to length
            512 with inf; see utils.misc).
        """
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False,
                                                    add_generation_prompt=True)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_k=self.top_k,
                top_p=0.95,
                output_scores=True,
                return_dict_in_generate=True,
                output_hidden_states=True
            )

        # sequences hold prompt + continuation; keep only the continuation.
        generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        log_probs = self._topk_log_probs(outputs.scores)
        values = self._step_values(outputs.hidden_states)

        log_probs, values = preprocess_tensor(log_probs.detach().cpu()), preprocess_tensor(values.detach().cpu())

        return generated_text, log_probs, values

    def _topk_log_probs(self, scores):
        """Return a (steps, 2 * top_k) tensor: top-k log-probs followed by their token ids."""
        # scores: one (batch, vocab) tensor per generated step; batch is 1 here
        # (a single prompt), so concatenating on dim 0 gives (steps, vocab).
        top = torch.topk(torch.cat(scores, dim=0), k=self.top_k, dim=-1)
        log_probs = F.log_softmax(top.values, dim=-1)
        # Tokens filtered out by top-p keep a -inf score; floor them so later
        # arithmetic does not produce nan.
        log_probs = log_probs.masked_fill(log_probs == -float('inf'), self._LOG_PROB_FLOOR)
        return torch.cat((log_probs, top.indices), dim=1)

    def _step_values(self, hidden_states):
        """Apply the value head to the last-layer, last-position hidden state of each step."""
        # hidden_states: one entry per generated step, each a per-layer tuple of
        # (batch, seq, hidden) tensors. The first step covers the whole prompt and
        # later steps only the newly fed token; [0, -1] selects the final position
        # in both cases, including a 1-token prompt.
        last_layer = torch.stack([step[-1][0, -1] for step in hidden_states], dim=0)
        # Match the head's dtype in case the model runs in reduced precision.
        return self.value_head(last_layer.to(self.value_head.weight.dtype))


class DoubleLM(torch.nn.Module):
    """Holds an actor LM and an instruct LM and routes each call to one of them."""

    def __init__(self, actor_lm: Union[ActorLM, ActorLM_API], instruct_lm: InstructLM) -> None:
        super().__init__()
        # Registration order is kept (actor first): it fixes submodule order in state_dict.
        self.actor_lm = actor_lm
        self.instruct_lm = instruct_lm

    def forward(self, messages: List[Dict], mode='actor'):
        """Delegate ``messages`` to the sub-model selected by ``mode``.

        ``mode='actor'`` returns ``actor_lm.forward``'s result, and
        ``mode='instruct'`` returns ``instruct_lm.forward``'s result.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'actor':
            target = self.actor_lm
        elif mode == 'instruct':
            target = self.instruct_lm
        else:
            raise ValueError("Only support 'actor' and 'instruct' modes!")
        return target.forward(messages=messages)