import torch
import tqdm
import json
import time
import asyncio
import os
from importlib import import_module
from transformers import StoppingCriteria

class KeyWordsCriteria(StoppingCriteria):
    """Stopping criterion that fires once every row of the batch ends with a stop sequence.

    ``stop_id_sequences`` is a list of token-id lists. A row counts as finished
    when its trailing tokens equal any one of those lists. Generation is halted
    only when *all* rows are finished, so rows that finished earlier may keep
    producing tokens after their stop sequence.
    """

    def __init__(self, stop_id_sequences):
        assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
        self.stop_sequences = stop_id_sequences

    def _row_is_finished(self, row):
        # A row is finished if its tail matches any stop sequence exactly.
        return any(row[-len(stop):].tolist() == stop for stop in self.stop_sequences)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return all(
            self._row_is_finished(input_ids[row_idx])
            for row_idx in range(input_ids.shape[0])
        )
    
    
def encode_with_prompt_completion_format_eval(example, max_seq_length):
    '''
    Join an example's 'input' and 'output' fields into a single string.

    A single space is inserted between the two only when neither side already
    provides whitespace (space, newline or tab) at the boundary.

    Args:
        example: mapping with string fields 'input' and 'output'.
        max_seq_length: accepted for interface compatibility; not used.

    Returns:
        The concatenated text.
    '''
    prompt = example['input']
    boundary_whitespace = (' ', '\n', '\t')
    needs_space = not (prompt.endswith(boundary_whitespace) or example['output'].startswith(boundary_whitespace))
    separator = ' ' if needs_space else ''
    return prompt + separator + example['output']
    
@torch.no_grad()
def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, add_special_tokens=True, disable_tqdm=False, **generation_kwargs):
    '''
    Generate completions for every prompt, one batch at a time.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: matching tokenizer. It should pad on the left so that
            generation continues from the end of every prompt.
        prompts: list of prompt strings.
        batch_size: number of prompts per ``model.generate`` call.
        stop_id_sequences: optional list of token-id lists. The batch stops
            once every row ends with one of them. Tokens from the first stop
            sequence onwards are then replaced by the pad token in each output.
        add_special_tokens: forwarded to the tokenizer.
        disable_tqdm: suppress the progress bar.
        **generation_kwargs: forwarded to ``model.generate``
            (``num_return_sequences`` is also read here).

    Returns:
        A flat list of ``len(prompts) * num_return_sequences`` strings, in
        prompt order. If generation for a batch raises, the error is printed
        and that batch contributes empty strings.
    '''
    generations = []
    progress = None if disable_tqdm else tqdm.tqdm(total=len(prompts), desc="Generating Completions")

    num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i+batch_size]
        # Record the count before the try body; the fallback below and the
        # progress bar must not depend on anything computed inside it.
        num_batch_prompts = len(batch_prompts)
        tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens)
        batch_input_ids = tokenized_prompts.input_ids
        attention_mask = tokenized_prompts.attention_mask

        if model.device.type == "cuda":
            # Use the model's own device, not just the current CUDA device.
            batch_input_ids = batch_input_ids.to(model.device)
            attention_mask = attention_mask.to(model.device)

        try:
            batch_outputs = model.generate(
                input_ids=batch_input_ids,
                attention_mask=attention_mask,
                stopping_criteria=[KeyWordsCriteria(stop_id_sequences)] if stop_id_sequences else None,
                **generation_kwargs
            )

            # The stopping criterion is applied at batch level: if other rows are not
            # finished, the whole batch keeps generating. Some outputs therefore still
            # contain the stop sequence, so everything from it onwards is blanked with
            # the pad token.
            if stop_id_sequences:
                for output_idx in range(batch_outputs.shape[0]):
                    for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]):
                        if any(batch_outputs[output_idx, token_idx: token_idx+len(stop_sequence)].tolist() == stop_sequence for stop_sequence in stop_id_sequences):
                            batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id
                            break

            # Remove the prompt from the output. The prompt is decoded the same way as
            # the output (rather than truncating token ids directly) so that special
            # tokens and leading spaces are handled identically. Some tokenizers
            # (e.g., llama) don't add a space token before the first token, and spaces
            # matter for some tasks (e.g., code completion).
            decoded_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
            decoded_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
            # Repeat each prompt to line up with its num_return_sequences outputs.
            decoded_prompts = [prompt for prompt in decoded_prompts for _ in range(num_return_sequences)]
            batch_generations = [
                output[len(prompt):] for prompt, output in zip(decoded_prompts, decoded_outputs)
            ]
        except Exception as e:
            # Best-effort: report the failure and keep the output aligned with the prompts.
            print("Error when generating completions for batch:")
            print(batch_prompts)
            print("Error message:")
            print(e)
            print("Use empty string as the completion.")
            batch_generations = [""] * num_batch_prompts * num_return_sequences

        generations += batch_generations

        if progress is not None:
            progress.update(num_batch_prompts)

    if progress is not None:
        progress.close()

    assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
    return generations


@torch.no_grad()
def get_next_word_predictions(model, tokenizer, prompts, candidate_token_ids=None, batch_size=1, return_token_predictions=False, add_special_tokens=True, disable_tqdm=False):
    '''
    Predict the most likely next token after each prompt.

    The prediction is read from the logits at the final position of each row.
    That position holds the last prompt token only when padding is on the
    left, so ``tokenizer.padding_side`` is set to "left" before anything is
    tokenized. The setting is left in place afterwards.

    Args:
        model: causal LM whose forward output exposes ``.logits``.
        tokenizer: matching tokenizer.
        prompts: list of prompt strings.
        candidate_token_ids: optional list of token ids. If given, predictions
            are indices into this list and ``probs`` keeps only these columns.
        batch_size: number of prompts per forward pass.
        return_token_predictions: return token strings instead of indices.
        add_special_tokens: forwarded to the tokenizer.
        disable_tqdm: suppress the progress bar.

    Returns:
        (predictions, probs): one prediction per prompt, plus per prompt the
        list of next-token probabilities. These come from a full-vocabulary
        softmax; when restricted to the candidate columns they are not
        renormalised.
    '''
    predictions, probs = [], []
    # Must happen before the first batch is tokenized. Previously it ran after,
    # so the first batch could be right-padded and read pad-position logits.
    tokenizer.padding_side = "left"
    progress = None if disable_tqdm else tqdm.tqdm(total=len(prompts), desc="Getting Predictions")

    # Loop-invariant: the candidate token strings only need converting once.
    candidate_tokens = None
    if return_token_predictions and candidate_token_ids is not None:
        candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_token_ids)

    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i: i+batch_size]
        tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens)
        batch_input_ids = tokenized_prompts.input_ids
        attention_mask = tokenized_prompts.attention_mask

        if model.device.type == "cuda":
            batch_input_ids = batch_input_ids.to(model.device)
            attention_mask = attention_mask.to(model.device)

        batch_logits = model(input_ids=batch_input_ids, attention_mask=attention_mask).logits[:, -1, :]
        batch_probs = torch.softmax(batch_logits, dim=-1)
        if candidate_token_ids is not None:
            batch_probs = batch_probs[:, candidate_token_ids]
        # Plain Python ints, usable both as list indices and as token ids.
        batch_prediction_indices = torch.argmax(batch_probs, dim=-1).tolist()
        if return_token_predictions:
            if candidate_tokens is not None:
                predictions += [candidate_tokens[idx] for idx in batch_prediction_indices]
            else:
                predictions += tokenizer.convert_ids_to_tokens(batch_prediction_indices)
        else:
            predictions += batch_prediction_indices
        probs += batch_probs.tolist()

        if progress is not None:
            progress.update(len(batch_prompts))

    if progress is not None:
        progress.close()

    assert len(predictions) == len(prompts), "number of predictions should be equal to number of prompts"
    return predictions, probs


@torch.no_grad()
def eval_nli_task(batch, model, tokenizer):
    '''
    Count how many examples in a batch get their label predicted correctly.

    Each label string is mapped to a single token id: the last token of
    ``' ' + label``. The prediction is the label whose token has the highest
    next-token probability at the final position of each input row.

    NOTE(review): reading position -1 assumes the tokenizer pads on the left —
    confirm for the tokenizer in use. All examples are assumed to share the
    choice list of the first example.

    Args:
        batch: mapping with "input" (prompt strings), "output" (gold label
            strings) and "choices". "choices" is presumably a collated
            structure where entry j holds the j-th label for every example —
            TODO confirm against the data loader.
        model: causal LM whose forward output exposes ``.logits``.
        tokenizer: matching tokenizer.

    Returns:
        A float tensor with the number of correct predictions.
    '''
    device = model.device
    # Take the first example's copy of each label string.
    label_strings = [per_example_labels[0] for per_example_labels in batch["choices"]]

    gold = torch.tensor(
        [label_strings.index(answer) for answer in batch["output"]],
        dtype=torch.long,
    ).to(device)

    label_token_ids = torch.tensor(
        [tokenizer(' ' + label, add_special_tokens=False)['input_ids'][-1] for label in label_strings],
        dtype=torch.long,
    ).to(device)
    encoded = tokenizer(list(batch["input"]), padding=True, return_tensors="pt").to(device)

    logits = model(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        use_cache=False,
    ).logits
    last_position_probs = torch.softmax(logits[:, -1, :], dim=-1)
    predicted = last_position_probs[:, label_token_ids].argmax(dim=-1).detach()

    return (predicted == gold).sum().float()

@torch.no_grad()
def score_qa_task(model, tokenizer, scoring_examples, batch_size=1, aggregation="mul", disable_tqdm=False):
    '''
    Score multiple-choice QA examples and return the accuracy.

    Each scoring example is a dict, which contains the following keys:
    - input: the prompt / question text fed to the model
    - choices: list of candidate answer strings
    - output: the correct answer; its stripped value must appear in `choices`
      (read only when the examples carry no "label" key)
    - label (optional): index of the correct answer in `choices`; if the first
      example has this key, the examples are used as-is

    The model is run on each prompt alone. Each choice is scored by gathering
    the model's raw logits for the choice's tokens, and the choice with the
    highest logsumexp of those values is the prediction.

    Side effect: sets tokenizer.pad_token to eos_token when no pad token exists.

    Returns:
        float: fraction of examples whose prediction equals the label.

    NOTE(review): issues visible in the code below, left unchanged here:
    - `aggregation` is accepted but never used.
    - Only the prompt is passed through the model, and choice tokens are never
      fed back in. So only the first choice token is scored from a position
      that actually precedes it. Later tokens of multi-token choices are
      scored from padding positions or dropped (right padding), or from
      earlier prompt positions (left padding). This looks correct only for
      single-token choices — confirm.
    - Scores are raw logits, not log-probabilities. Masked positions are set
      to 0, which still contributes exp(0) to the logsumexp.
    - An empty `scoring_examples` raises IndexError at `scoring_examples[0]`;
      an `output` not found in `choices` raises ValueError.
    '''

    # add pad tokens to the tokenizer if it doesn't have one
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Build (input, choices, label) records; the label is the index of the
    # stripped gold output within `choices`.
    unrolled_examples = []
    if "label" not in scoring_examples[0]:
        for scoring_example in scoring_examples:
            input = scoring_example["input"]
            choices = scoring_example["choices"]
            label = choices.index(scoring_example["output"].strip())
            unrolled_examples.append({
                "input": input,
                "choices": choices,
                "label": label
            })
    else:
        print("Skipping unrolling scoring examples because they already contain labels.")
        unrolled_examples = scoring_examples

    if not disable_tqdm:
        progress = tqdm.tqdm(total=len(unrolled_examples), desc="Scoring QA")

    accuracies = []
    for i in range(0, len(unrolled_examples), batch_size):
        batch_prompts = [example["input"] for example in unrolled_examples[i:i+batch_size]]
        tokenized_batch = tokenizer(batch_prompts, padding="longest", return_tensors="pt")
        if model.device.type == "cuda":
            tokenized_batch = {
                key: value.cuda() for key, value in tokenized_batch.items()
            }
        # One forward pass over the prompts only (choices are not included).
        outputs = model(**tokenized_batch)

        for example_idx, (prompt, example) in enumerate(zip(batch_prompts, unrolled_examples[i:i+batch_size])):

            tokenized_prompt = tokenizer(example["input"], padding=False, return_tensors="pt").input_ids.squeeze(0) # (prompt_length, )
            tokenized_choices = tokenizer(example["choices"], padding="longest", return_tensors="pt").input_ids # (num_choices, choice_length)
            # Drop the first token of each choice. The original author assumed it is a space
            # token; with a BOS-adding tokenizer it would be BOS, and with left-padded choices
            # of unequal length it may be a pad token instead — TODO confirm.
            tokenized_choices = tokenized_choices[:, 1:]

            tokenized_choices = tokenized_choices.to(model.device)

            choices_mask = tokenized_choices != tokenizer.pad_token_id # (num_choices, num_choice_tokens)
            # Same logits row repeated once per choice (expand creates a view, not a copy).
            output_logit = outputs.logits[example_idx, :, :].unsqueeze(0).expand(tokenized_choices.shape[0], -1, -1) # (num_choices, padded_seq_len, vocab_size)

            if tokenizer.padding_side == "right":
                # Start at the last real prompt position (which predicts the first choice token).
                # The slice can hold one more position than there are choice tokens, and it is
                # clipped by the padded sequence length.
                completion_logits = output_logit[:, len(tokenized_prompt)-1:len(tokenized_prompt)+len(tokenized_choices[0]), :] # (num_choices, num_tokens, vocab_size)

                tokenized_choices = tokenized_choices[:, :completion_logits.shape[1]] # (num_choices, num_tokens)
                choices_mask = choices_mask[:, :completion_logits.shape[1]] # (num_choices, num_tokens)

            else:
                # Left padding: take the last num_choice_tokens positions of the prompt.
                start_index = -len(tokenized_choices[0])

                # Adjust the slicing of completion_logits to get logits for the tokenized choices
                completion_logits = output_logit[:, start_index:, :]  # (num_choices, num_tokens, vocab_size)

                # Adjust the slicing of tokenized_choices and choices_mask to match the shape of completion_logits
                tokenized_choices = tokenized_choices[:, start_index:]  # (num_choices, num_tokens)
                choices_mask = choices_mask[:, start_index:]  # (num_choices, num_tokens)

            # Select, per position, the raw logit of the choice token, giving shape (num_choices, num_tokens).
            # NOTE(review): these are logits, not log-probabilities (no log_softmax is applied).
            completion_log_probs = torch.gather(completion_logits, dim=-1, index=tokenized_choices.unsqueeze(-1)).squeeze(-1) # (num_choices, num_tokens)
            # Zero out padding positions (zero still contributes exp(0) to the logsumexp below).
            completion_log_probs[~choices_mask] = 0
            # Reduce each choice to one score with logsumexp over its token positions.
            completion_log_probs = torch.logsumexp(completion_log_probs, dim=-1) # (num_choices, )
            pred = torch.argmax(completion_log_probs).item() # Python int: index of the best choice

            label = example["label"]
            accuracies.append(int(pred == label))

        if not disable_tqdm:
            progress.update(len(batch_prompts))

    return sum(accuracies) / len(accuracies)


@torch.no_grad()
def score_completions(model, tokenizer, scoring_examples, batch_size=1, aggregation="sum", disable_tqdm=False):
    '''
    Score candidate completions by their token log-probabilities under the model.

    Each scoring example is a dict, which contains the following keys:
    - prompt: the prompt to score
    - completions: a list of completions to score

    Every (prompt, completion) pair is scored independently. The completion is
    appended to the prompt, with a single space inserted unless the prompt
    already ends with a space or newline. The model is run on the joined text,
    and the log-softmax probabilities of the completion tokens are aggregated.

    Args:
        model: causal LM whose forward output exposes ``.logits``.
        tokenizer: matching tokenizer; both left and right padding are handled.
        scoring_examples: list of dicts as described above.
        batch_size: number of (prompt, completion) pairs per forward pass.
        aggregation: "sum", "mean" or "max" over the completion token log-probs.
        disable_tqdm: suppress the progress bar.

    Returns:
        dict mapping prompt -> {completion: score}. Examples that share the
        same prompt string are merged under one key.

    Raises:
        ValueError: if ``aggregation`` is not supported. This is checked
            before any model call.
    '''
    aggregators = {
        "sum": lambda log_probs: log_probs.sum(),
        "mean": lambda log_probs: log_probs.mean(),
        "max": lambda log_probs: log_probs.max(),
    }
    # Fail fast instead of after the first (possibly expensive) forward pass.
    if aggregation not in aggregators:
        raise ValueError("Invalid aggregation method: {}".format(aggregation))
    aggregate = aggregators[aggregation]

    # unroll the scoring examples into one record per (prompt, completion) pair
    unrolled_examples = [
        {"prompt": scoring_example["prompt"], "completion": completion}
        for scoring_example in scoring_examples
        for completion in scoring_example["completions"]
    ]

    progress = None if disable_tqdm else tqdm.tqdm(total=len(unrolled_examples), desc="Scoring Completions")

    scores = []
    for i in range(0, len(unrolled_examples), batch_size):
        batch = unrolled_examples[i:i+batch_size]
        batch_prompts = [example["prompt"] for example in batch]
        # endswith() also handles an empty prompt; indexing prompt[-1] would raise IndexError.
        batch_examples = [
            (example["prompt"] if example["prompt"].endswith(("\n", " ")) else example["prompt"] + " ")
            + example["completion"]
            for example in batch
        ]
        tokenized_batch = tokenizer(batch_examples, padding="longest", return_tensors="pt")
        if model.device.type == "cuda":
            tokenized_batch = {
                key: value.to(model.device) for key, value in tokenized_batch.items()
            }
        outputs = model(**tokenized_batch)

        for example_idx, (prompt, example) in enumerate(zip(batch_prompts, batch_examples)):
            tokenized_prompt = tokenizer(prompt, padding=False, return_tensors="pt").input_ids.squeeze(0)
            tokenized_example = tokenizer(example, padding=False, return_tensors="pt").input_ids.squeeze(0)
            # Assumes the prompt's tokens are a prefix of the joined text's tokens.
            completion_ids = tokenized_example[len(tokenized_prompt):]

            # get the logits for the entire example, removing the padding logits
            if tokenizer.padding_side == "right":
                example_logits = outputs.logits[example_idx, :len(tokenized_example), :]
            else:
                example_logits = outputs.logits[example_idx, -len(tokenized_example):, :]

            # The logits at position t predict token t + 1, hence the shift left by one.
            completion_logits = example_logits[len(tokenized_prompt)-1:len(tokenized_example)-1, :]
            completion_log_probs = torch.log_softmax(completion_logits, dim=-1)[range(len(completion_ids)), completion_ids]

            scores.append(aggregate(completion_log_probs).item())

        if progress is not None:
            progress.update(len(batch_examples))

    if progress is not None:
        progress.close()

    # roll up the scores: prompt -> {completion: score}
    rolled_up_scores = {}
    for unrolled_example, score in zip(unrolled_examples, scores):
        rolled_up_scores.setdefault(unrolled_example["prompt"], {})[unrolled_example["completion"]] = score

    return rolled_up_scores



def load_hf_lm_and_tokenizer(
        model_name_or_path, 
        tokenizer_name_or_path=None, 
        device_map="auto", 
        torch_dtype="auto",
        load_in_8bit=False, 
        convert_to_half=False,
        gptq_model=False,
        use_fast_tokenizer=True,
        padding_side="left",
    ):
    '''
    Load a causal LM and its tokenizer for evaluation.

    Args:
        model_name_or_path: model id or local path.
        tokenizer_name_or_path: tokenizer id/path; defaults to model_name_or_path.
        device_map: passed to from_pretrained. When falsy, the model is loaded
            normally and moved to CUDA if it is available.
        torch_dtype: passed to from_pretrained (plain, non-quantized path only).
        load_in_8bit: load the model with 8-bit weights.
        convert_to_half: cast to fp16 after loading (plain, non-quantized path only).
        gptq_model: load a GPTQ-quantized model through auto_gptq on cuda:0.
        use_fast_tokenizer: request the fast tokenizer. If loading fails, falls
            back to the default tokenizer class.
        padding_side: padding side set on the tokenizer ("left" suits batched generation).

    Returns:
        (model, tokenizer). The model is in eval mode and the tokenizer always
        has a pad token (eos is used when none is defined).
    '''
    from transformers import AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM, GPTNeoXForCausalLM

    if gptq_model:
        from auto_gptq import AutoGPTQForCausalLM
        model_wrapper = AutoGPTQForCausalLM.from_quantized(
            model_name_or_path, device="cuda:0", use_triton=True
        )
        model = model_wrapper.model  
    elif load_in_8bit:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, 
            device_map=device_map, 
            load_in_8bit=True
        )
    else:
        if device_map:
            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map=device_map, torch_dtype=torch_dtype)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch_dtype)
            if torch.cuda.is_available():
                model = model.cuda()
        if convert_to_half:
            model = model.half()
    model.eval()

    if not tokenizer_name_or_path:
        tokenizer_name_or_path = model_name_or_path
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=use_fast_tokenizer)
    except Exception:
        # Some tokenizers (e.g., GPTNeoXTokenizer) don't have both a slow and a fast version,
        # so fall back to the default one. Catch Exception rather than everything so that
        # KeyboardInterrupt / SystemExit still propagate.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    # set padding side (left is what batch generation needs)
    tokenizer.padding_side = padding_side
    # set pad token to eos token if pad token is not set (as is the case for llama models)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # For OPT and Pythia models, set tokenizer.model_max_length to model.config.max_position_embeddings
    # to avoid an out-of-range embedding index.
    if isinstance(model, GPTNeoXForCausalLM) or isinstance(model, OPTForCausalLM):
        tokenizer.model_max_length = model.config.max_position_embeddings
        print("Set tokenizer.model_max_length to model.config.max_position_embeddings: {}".format(model.config.max_position_embeddings))
        
    return model, tokenizer


def dynamic_import_function(function_path):
    '''
    Resolve a dotted path such as "package.module.func" to the object it names.

    Everything before the last dot is imported as a module, and the final
    component is looked up on that module as an attribute.
    '''
    module_name, attribute_name = function_path.rsplit(".", 1)
    return getattr(import_module(module_name), attribute_name)
 