import argparse
import glob
import json
import os
import copy
import time
import sys
import numpy as np
from typing import List, Tuple

import pandas as pd
import torch
import tqdm
import transformers
from sentence_transformers import SentenceTransformer

# import dataeval.TruthfulQA as TruthfulQA
import models
from func.metric import *
from semantic_entropy_util import (
    EntailmentDeberta,
    get_semantic_ids,
    logsumexp_by_id,
    predictive_entropy_rao,
    generate_responses,
    compute_semantic_entropy,
    ModelWrapper,

)


def str2bool(v):
    """Convert a command-line token to a bool.

    Accepts real booleans unchanged and the case-insensitive strings
    yes/true/t/1 and no/false/f/0.

    Raises:
        argparse.ArgumentTypeError: if the value is none of the above.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', '1'):
        return True
    if lowered in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


# NOTE: str2bool has to be defined before the parser is built, because
# `type=str2bool` is evaluated as soon as add_argument() runs.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='llama-8b-instruct')
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--output', type=str, default="output.csv")
# The input data is a CSV file; the prompts live in the column named by --column.
parser.add_argument('--input', type=str, default="input.csv")
parser.add_argument('--column', type=str, default="prompt")
parser.add_argument('--thinking', type=str2bool, default=True)
parser.add_argument('--num_generations_per_prompt', type=int, default=10)
parser.add_argument('--temperature', type=float, default=1)
parser.add_argument('--decoding_method', type=str, default='greedy')
parser.add_argument('--top_p', type=float, default=0.99)
parser.add_argument('--top_k', type=int, default=10)
parser.add_argument('--seed', type=int, default=2023)
parser.add_argument('--nprocess', type=int, default=None)
parser.add_argument('--project_ind', type=int, default=0)
parser.add_argument('--max_new_tokens', type=int, default=2000)
parser.add_argument('--entailment_model', type=str, default='deberta',
                    choices=['deberta'],
                    help='Entailment model to use for semantic equivalence checking')


# Parsed at import time: the rest of the module (including the default value of
# get_generations' max_num_gen_once parameter) reads this global namespace.
args = parser.parse_args()

# load custom dataset
# The CSV is read once into a module-level DataFrame; get_generations iterates
# over its rows and adds the metric columns to this same object.
df = pd.read_csv(args.input)

# apply custom prompt template for Mistral-7B
def reformat_prompt(prompt):
    """Wrap *prompt* in the instruction/response template.

    The result is the instruction header, the prompt followed by a period,
    and the response header, each terminated by a newline.
    """
    header = "### Instruction:\n"
    footer = ".\n### Response:\n"
    return f"{header}{prompt}{footer}"
   


# Module-level accumulators: get_generations appends one entry per input row to
# each list and then writes them into the DataFrame as new columns.
perplexities = []
energies = []
entropies = []
lexical_similarities = []
eigenscores = []
# NOTE: eigenvalues is filled by get_generations but is not written to the
# output DataFrame.
eigenvalues = []
eigenscore_outputs = []
batch_generations = []
eigenindicatorsv2 = []
semantic_entropies = []



def _generate_config(tokenizer):
    """Return ``{'eos_token_id': [...]}`` for the given tokenizer.

    The list holds the token ids obtained for '.', newline, '!' and '?'
    followed by the tokenizer's own ``eos_token_id``. Which element of the
    encoded output is taken depends on the tokenizer class name.

    Raises:
        NotImplementedError: if the tokenizer class is not one of
            LlamaTokenizer, GPT2Tokenizer or PreTrainedTokenizerFast.
    """
    # One id-extractor per supported tokenizer class name.
    extractors = {
        'LlamaTokenizer': lambda text: tokenizer.encode(text)[-1],
        'GPT2Tokenizer': lambda text: tokenizer.encode(text)[1],
        'PreTrainedTokenizerFast': lambda text: tokenizer(text)['input_ids'][-1],
    }
    extract = extractors.get(tokenizer.__class__.__name__)
    if extract is None:
        raise NotImplementedError

    stop_ids = [extract(text) for text in ['.', '\n', '!', '?']]
    stop_ids.append(tokenizer.eos_token_id)
    return {'eos_token_id': stop_ids}


def _prepare_generation_inputs(model_name, prompt, model, tokenizer, args):
    """Tokenize one prompt for the given model family and build its generation config.

    Returns:
        A 4-tuple ``(prompt, generate_inputs, input_length, generation_config)``:
        the prompt (re-templated for Mistral, unchanged otherwise), a mapping of
        keyword arguments to splat into ``model.generate``, the number of prompt
        tokens, and the ``transformers.GenerationConfig`` for this model.

    Raises:
        NotImplementedError: if ``model_name`` matches no supported family.
    """
    # Chat layout used by the chat-template models.
    # NOTE(review): the Qwen path previously referenced a `messages` list that
    # was only ever built in the Llama path; reusing the same system+user
    # layout here -- confirm this is the intended layout for Qwen.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    if model_name == 'llama-8b-instruct':
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)
        # Stop on either the regular EOS token or the end-of-turn token.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]
        generation_config = transformers.GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=terminators
        )
        return prompt, {"input_ids": input_ids}, input_ids.shape[1], generation_config

    if "mistral" in model_name:
        prompt = reformat_prompt(prompt)
        encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
        # Keep ids and mask as separate entries: the mask must still be
        # reachable after the ids have been pulled out of the encoding.
        generate_inputs = {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
        }
        generation_config = transformers.GenerationConfig(
            max_new_tokens=256,  # fixed cap for this family; args.max_new_tokens is not used here
            pad_token_id=tokenizer.eos_token_id
        )
        return prompt, generate_inputs, encoded["input_ids"].shape[1], generation_config

    if "qwen" in model_name:
        text = tokenizer.apply_chat_template(
            messages,
            enable_thinking=args.thinking,
            tokenize=False,
            add_generation_prompt=True
        )
        encoded = tokenizer([text], return_tensors="pt").to(model.device)
        generation_config = transformers.GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            pad_token_id=tokenizer.eos_token_id
        )
        return prompt, encoded, encoded.input_ids.shape[1], generation_config

    raise NotImplementedError(f"Unsupported model for generation: {model_name!r}")


@torch.no_grad()
def get_generations(model_name:str, args, seed=1, old_sequences=None, max_num_gen_once=args.num_generations_per_prompt):
    """Generate responses for every row of the global ``df`` and attach uncertainty metrics.

    For each prompt in ``df[args.column]`` this runs one greedy generation
    (used for perplexity and energy) and ``args.num_generations_per_prompt``
    sampled generations (used for length-normalised entropy, lexical
    similarity and the eigen indicators), then computes semantic entropy from
    a separate set of responses produced through ``ModelWrapper``.

    Args:
        model_name: model identifier; must be 'llama-8b-instruct' or contain
            'mistral' or 'qwen'.
        args: parsed command-line namespace (device, column, sampling options...).
        seed: value passed to ``utils.seed_everything``.
        old_sequences: optional iterable of records with an ``'id'`` key. It is
            indexed by id but not otherwise used by this function.
        max_num_gen_once: maximum number of sampled sequences requested from a
            single ``model.generate`` call.

    Returns:
        The global ``df`` with the metric columns added. The per-row values are
        also appended to the module-level accumulator lists.

    Raises:
        ValueError: if ``args.num_generations_per_prompt`` is smaller than 1.
        NotImplementedError: if ``model_name`` is not a supported model family.
    """
    if args.num_generations_per_prompt < 1:
        # Without at least one sample none of the sampling-based metrics exist.
        raise ValueError("num_generations_per_prompt must be at least 1")

    model, tokenizer = models.load_model_and_tokenizer(model_name, args.device)
    SenSimModel = SentenceTransformer("sentence-transformers/all-roberta-large-v1")
    if args.entailment_model == 'deberta':
        entailment_model = EntailmentDeberta()
    utils.seed_everything(seed)

    if old_sequences is None:
        old_sequences = []
    # Indexed by id; currently nothing below reads this mapping.
    old_sequences = {_['id']: _ for _ in old_sequences}

    for _, row in tqdm.tqdm(df.iterrows(), total=len(df)):
        prompt, generate_inputs, input_length, generation_config = _prepare_generation_inputs(
            model_name, row[args.column], model, tokenizer, args
        )

        # Greedy pass: the single most likely continuation.
        dict_outputs = model.generate(
            **generate_inputs,
            num_beams=1,
            do_sample=False,
            generation_config=generation_config,
            output_hidden_states=True,
            return_dict_in_generate=True,
            output_scores=True
        )
        scores = dict_outputs.scores
        perplexity = get_perplexity_score(scores)
        energy_score = get_energy_score(scores)
        most_likely_generations = dict_outputs.sequences.cpu()[0, input_length:]

        torch.cuda.empty_cache()

        # Sampling pass, in batches of at most max_num_gen_once sequences.
        generations = []
        num_gens = args.num_generations_per_prompt
        while num_gens > 0:
            dict_outputs = model.generate(
                **generate_inputs,
                num_beams=1,
                num_return_sequences=min(max_num_gen_once, num_gens),
                do_sample=True,
                top_p=args.top_p,
                top_k=args.top_k,
                temperature=args.temperature,
                generation_config=generation_config,
                output_hidden_states=True,
                return_dict_in_generate=True,
                output_scores=True
            )

            generation = dict_outputs.sequences[:, input_length:].cpu()
            generations.append(generation)
            # NOTE(review): get_num_tokens treats ids <= 2 as special tokens;
            # confirm this holds for tokenizers whose pad/eos id is larger.
            num_tokens = get_num_tokens(generation)
            scores = dict_outputs.scores
            # These metrics are recomputed for every batch; only the values of
            # the final batch are recorded below.
            predictive_entropy = get_lenghthNormalized_entropy(scores, num_tokens)
            hidden_states = dict_outputs.hidden_states
            eigenIndicator, eigenValue = getEigenIndicator_v0(hidden_states, num_tokens)
            eigenIndicatorv2 = getEigenIndicator_v2(hidden_states, num_tokens)
            num_gens -= len(generation)

        # Batches can differ in length, so pad them to a common width first.
        generations = torch.nested.nested_tensor(generations).to_padded_tensor(tokenizer.eos_token_id)
        generations = generations.reshape(-1, generations.shape[-1])[:args.num_generations_per_prompt]
        best_generated_text = tokenizer.decode(most_likely_generations, skip_special_tokens=True)
        generated_texts = [tokenizer.decode(_, skip_special_tokens=True) for _ in generations]
        lexical_similarity = getLexicalSim(generated_texts)
        eigenIndicatorOutput, eigenValue_O = getEigenIndicatorOutput(generated_texts, SenSimModel)

        # Semantic entropy uses its own set of responses produced via ModelWrapper.
        model_wrapper = ModelWrapper(model, tokenizer, model_name, max_new_tokens=args.max_new_tokens)
        responses = generate_responses(model_wrapper, prompt, args.num_generations_per_prompt, args.temperature)
        semantic_entropy = compute_semantic_entropy(responses, entailment_model, question=prompt)

        torch.cuda.empty_cache()

        # Record this row's values in the module-level accumulators.
        perplexities.append(perplexity)
        energies.append(energy_score)
        entropies.append(predictive_entropy)
        lexical_similarities.append(lexical_similarity)
        eigenscores.append(eigenIndicator)
        eigenvalues.append(eigenValue)
        eigenscore_outputs.append(eigenIndicatorOutput)
        batch_generations.append(generated_texts)
        eigenindicatorsv2.append(eigenIndicatorv2)
        semantic_entropies.append(semantic_entropy)

    # Attach the collected metrics to the DataFrame as new columns.
    df["perplexity"] = perplexities
    df["energy"] = energies
    df["normalized_entropy"] = entropies
    df["lexical_similarity"] = lexical_similarities
    df["eigenscore_original"] = eigenscores
    df["eigenscore_output"] = eigenscore_outputs
    df["batch_generations"] = batch_generations
    df["eigenscore_variant"] = eigenindicatorsv2
    df['semantic_entropy'] = semantic_entropies

    return df


def get_num_tokens(generation, pad_token_id=None):  # generation: num_seq x max(num_tokens)
    """Return, for each sequence, the number of counted token ids plus one.

    Args:
        generation: iterable of token-id sequences (one per generated sample).
        pad_token_id: optional id used to pad finished sequences. When omitted
            (the default) every id greater than 2 is counted, which only suits
            tokenizers whose special tokens occupy ids 0-2. When given, every
            id different from ``pad_token_id`` is counted instead, so callers
            using other tokenizers should pass it explicitly.

    Returns:
        A list with one integer per sequence: the counted ids plus one.
    """
    num_tokens = []
    for ids in generation:
        if pad_token_id is None:
            count = sum(1 for token_id in ids if token_id > 2)
        else:
            count = sum(1 for token_id in ids if token_id != pad_token_id)
        num_tokens.append(count + 1)
    return num_tokens


def main(overwrite=False, continue_from=None, parallel:int=None):
    """Run generation for the configured model and write the results to ``args.output``.

    Args:
        overwrite: when False and the cache directory already contains
            non-partial ``.pkl`` results, print a notice and return without
            generating anything.
        continue_from: optional path to a ``*_partial.pkl`` file. When given,
            the run's arguments are reloaded from the matching ``args*.json``
            next to it and the pickle is passed on as ``old_sequences``.
        parallel: accepted for interface compatibility; not used here.

    Returns:
        None.
    """
    if continue_from:
        fname = os.path.basename(continue_from)
        args.__dict__ = utils.jload(continue_from.replace(fname, 'args'+fname.replace("_partial.pkl", ".json")))
        # NOTE(security): read_pickle executes arbitrary code embedded in the
        # file; only resume from pickles produced by a trusted run.
        old_sequences = pd.read_pickle(continue_from)
        cache_dir = os.path.dirname(continue_from)
        run_id = int(os.path.basename(continue_from).replace("_partial.pkl", ""))
        model_name = args.model
    else:
        old_sequences = []
        model_name = args.model
        if '/' in model_name:
            # Keep the cache folder name free of path separators.
            model_name = model_name.replace('/', '_')
        cache_dir = os.path.join(_settings.GENERATION_FOLDER, f'{model_name}_{args.project_ind}')
        os.makedirs(cache_dir, exist_ok=True)
        old_results = glob.glob(os.path.join(cache_dir, '*.pkl'))
        old_results = [_ for _ in old_results if '_partial' not in _]
        if len(old_results) > 0 and not overwrite:
            print(f'Found {len(old_results)} generations in {cache_dir}.')
            return
        run_id = len(old_results)
        # Persist the arguments of this run next to the cache folder contents.
        with open(os.path.join(cache_dir, f'args{run_id}.json'), 'w') as f:
            json.dump(args.__dict__, f)
    print(f'Generating {args.num_generations_per_prompt} generations per prompt for {model_name} ...')
    # The results are written as CSV to args.output, so report that path.
    print(f"Saving results to {args.output}")
    results = get_generations(model_name, args, seed=args.seed, old_sequences=old_sequences)
    results.to_csv(args.output, index=False)
    return

# Script entry point: run the generation pipeline with the parsed CLI arguments.
if __name__ == '__main__':
    # main() returns None, so task_runner only records that the call completed.
    task_runner = main(parallel=args.nprocess)
