# based on the code from https://github.com/bigcode-project/bigcode-evaluation-harness
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, StoppingCriteria, StoppingCriteriaList
import torch
from datasets import load_dataset
from code_eval import compute_code_eval
from math import ceil
from collections import defaultdict
from typing import List, Optional
import argparse
from utils import HumanEvalTask, MBPPTask, EndOfFunctionCriteria, TooLongFunctionCriteria
from tqdm import tqdm
import os
import copy
import json
from multiprocessing import Pool
import numpy as np
from gph import ripser_parallel


def get_gen_len(generated_tokens, gen_code_processed, tokenizer):
    """Return how many leading tokens are needed to cover ``gen_code_processed``.

    Prefixes ``generated_tokens[:k]`` are decoded for ``k = 0, 1, ...`` (special
    tokens skipped) and the first ``k`` whose decoded text is at least as long
    as ``gen_code_processed`` is returned. Only lengths are compared, not the
    text itself. If no proper prefix is long enough, the total token count is
    returned.
    """
    target_chars = len(gen_code_processed)
    n_tokens = len(generated_tokens)
    covering_prefixes = (
        k
        for k in range(n_tokens)
        if len(tokenizer.decode(generated_tokens[:k], skip_special_tokens=True)) >= target_chars
    )
    return next(covering_prefixes, n_tokens)


def get_probs(prompt_len, outputs):
    """Return the model probability of every generated token.

    Args:
        prompt_len: number of prompt tokens at the start of each sequence; the
            logits of generation step ``pos`` score position ``prompt_len + pos``.
        outputs: ``generate`` output with ``sequences`` (n_seq, total_len) and
            ``logits``, one (n_seq, vocab) tensor per generation step.

    Returns:
        float64 array shaped like ``outputs.sequences``. Generated positions hold
        softmax(logits)[token]; prompt positions keep the value 1.0.
        Per the HF docs, ``output_logits`` returns *unprocessed* logits, so these
        probabilities do not reflect temperature / top-p.
    """
    sequences = outputs.sequences
    probs = np.ones(sequences.shape)
    for pos, step_logits in enumerate(outputs.logits):
        step_probs = torch.softmax(step_logits, dim=1)
        col = prompt_len + pos
        tokens = sequences[:, col].to(step_probs.device)
        # One gather per step instead of one .item() (host/device sync) per sequence.
        chosen = step_probs.gather(1, tokens.unsqueeze(1)).squeeze(1)
        # .float() because numpy has no bfloat16; fp16/bf16 -> fp32 is exact.
        probs[:, col] = chosen.detach().float().cpu().numpy()
    return probs


def merge_attentions(attn):
    """Assemble per-step generation attentions into one dense attention tensor.

    Args:
        attn: ``attn[step][layer]`` tensors shaped (n_seq, n_heads, q_len, k_len),
            as returned by ``generate(..., output_attentions=True)`` — presumably
            the full prompt at step 0 and a single query row afterwards (KV cache).

    Returns:
        float64 array (n_seq, n_layers, n_heads, seq_len, seq_len), where
        ``seq_len`` is the key length of the last step. Each step's query rows
        are placed at the last ``q_len`` positions of its ``k_len``-long context.
        Entries never written stay 0.
    """
    n_layers = len(attn[0])
    n_seq, n_heads = attn[0][0].shape[:2]
    seq_len = attn[-1][-1].shape[-1]
    total_attn = np.zeros((n_seq, n_layers, n_heads, seq_len, seq_len))
    for step_attn in attn:
        # Do not squeeze(): it would drop the batch axis (batch_size == 1),
        # the head axis (single head) or the query axis and misalign the stack.
        # .float() because numpy cannot represent bfloat16.
        attn_step = np.stack(
            [x.detach().cpu().float().numpy() for x in step_attn], axis=1
        )  # (n_seq, n_layers, n_heads, q_len, k_len)
        q_len, k_len = attn_step.shape[-2:]
        total_attn[:, :, :, k_len - q_len:k_len, :k_len] = attn_step
    return total_attn


def compute_mtd_matrix(attn_tensor, prompt_len, kind):
    """Compute per-(layer, head) persistent-homology summaries of attention.

    For every head the attention matrix ``A`` is turned into the distance matrix
    ``1 - (A + A.T)``. Then one block is collapsed to distance 0:
      * ``kind == 0``: the prompt x prompt block (``[:prompt_len, :prompt_len]``);
      * ``kind == 1``: the response x response block (``[prompt_len:, prompt_len:]``);
      * any other value: nothing is collapsed.
    The diagonal is zeroed and ``ripser_parallel`` is run on the precomputed
    distances.

    Args:
        attn_tensor: array of shape (n_layers, n_heads, n_tokens, n_tokens).
        prompt_len: number of prompt tokens at the start of the sequence.
        kind: which block to collapse (see above).

    Returns:
        ``(mtd0_mat, mtd1_mat)``, two float arrays of shape (n_layers, n_heads).
        ``mtd0`` is the sum of all finite entries of the H0 diagram, and
        ``mtd1`` is the total persistence (death - birth) of the H1 diagram,
        or 0 when H1 is empty.
    """
    n_layers, n_heads = attn_tensor.shape[:2]
    mtd0_mat = np.zeros((n_layers, n_heads))
    mtd1_mat = np.zeros((n_layers, n_heads))

    for layer, head in np.ndindex(n_layers, n_heads):
        head_attn = attn_tensor[layer, head]
        # A + A.T allocates a fresh array, so attn_tensor is never modified.
        dist = 1 - (head_attn + head_attn.T)
        if kind == 0:
            dist[:prompt_len, :prompt_len] = 0.
        elif kind == 1:
            dist[prompt_len:, prompt_len:] = 0.
        np.fill_diagonal(dist, 0)

        diagrams = ripser_parallel(dist, metric='precomputed', n_threads=-1)['dgms']
        h0, h1 = diagrams[0], diagrams[1]
        mtd0_mat[layer, head] = np.sum(h0[h0 < np.inf])
        mtd1_mat[layer, head] = np.sum(h1[:, 1] - h1[:, 0]) if h1.shape[0] else 0

    return mtd0_mat, mtd1_mat


def get_stopping_criteria(task, tokenizer):
    """Build the stopping criteria for one ``model.generate`` call on ``task``.

    Side effect: when the task has stop words, ``tokenizer.eos_token`` is added
    to ``task.stop_words``, at most once. ``generate_task`` checks
    ``tokenizer.eos_token in task.stop_words`` to decide how to decode outputs,
    so the mutation is intentional.

    All criteria are created with a start length of 0. ``generate_task`` updates
    the ``EndOfFunctionCriteria`` start length to the prompt length before
    generating.

    Returns:
        A possibly empty list of ``EndOfFunctionCriteria`` and/or
        ``TooLongFunctionCriteria`` instances.
    """
    stopping_criteria = []
    eos = tokenizer.eos_token
    # Called once per generated batch on the same task object: without the
    # membership check eos was appended again on every call and stop_words
    # grew without bound.
    if task.stop_words and eos and eos not in task.stop_words:
        task.stop_words.append(eos)
    if hasattr(task, "check_fn"):
        stopping_criteria.append(
            EndOfFunctionCriteria(0, task.stop_words, tokenizer, task.check_fn)
        )
    elif task.stop_words:
        stopping_criteria.append(
            EndOfFunctionCriteria(0, task.stop_words, tokenizer)
        )
    # NOTE(review): the length argument of TooLongFunctionCriteria is left at 0
    # by the caller — confirm that is intended for tasks defining a multiplier.
    if getattr(task, "max_length_multiplier", None):
        stopping_criteria.append(
            TooLongFunctionCriteria(0, task.max_length_multiplier)
        )
    return stopping_criteria


def _generation_kwargs(args):
    """Return the sampling/length kwargs for ``model.generate`` for ``args.task``.

    HumanEval bounds the total length (prompt + completion) via ``max_length``.
    MBPP bounds only the completion via ``max_new_tokens``.

    Raises:
        ValueError: if ``args.task`` is neither ``"humaneval"`` nor ``"mbpp"``.
    """
    gen_kwargs = {
        "do_sample": args.do_sample,
        "temperature": args.temperature,
        "top_p": args.top_p,
    }
    if args.task == "humaneval":
        gen_kwargs.update(max_length=args.max_length_generation, max_new_tokens=None)
    elif args.task == "mbpp":
        gen_kwargs.update(max_length=None, max_new_tokens=args.max_length_generation)
    else:
        raise ValueError(f"Unsupported task {args.task!r}, choose from: humaneval, mbpp")
    return gen_kwargs


def _pooled_mtd(pool, layer_chunks, prompt_len, kind):
    """Run ``compute_mtd_matrix`` on each layer chunk in ``pool`` and re-stack along layers."""
    parts = pool.starmap(compute_mtd_matrix, [(chunk, prompt_len, kind) for chunk in layer_chunks])
    return np.vstack([part[0] for part in parts]), np.vstack([part[1] for part in parts])


def _attention_features(merged_attn, generated_examples, prompt_len, args, pool):
    """Compute the attention-based features of every generated sequence.

    Sequence ``i`` uses ``merged_attn[i]`` cropped to its first ``gen_len`` tokens.

    Returns:
        Six lists with one entry per sequence:
        ``(diag_prompt, diag_response, mtd0_kind0, mtd1_kind0, mtd0_kind1, mtd1_kind1)``.
        ``diag_*`` are 1-D arrays with, for each (layer, head) in layer-major
        order, the sum of the attention diagonal over prompt / response positions.
        ``mtd*`` are the (n_layers, n_heads) matrices of ``compute_mtd_matrix``.
    """
    diag_prompt, diag_response = [], []
    mtd0_kind0, mtd1_kind0, mtd0_kind1, mtd1_kind1 = [], [], [], []
    for i, example in enumerate(generated_examples):
        gen_len = example['gen_len']
        attn_i = merged_attn[i, :, :, :gen_len, :gen_len]
        # Split over the layer axis: each worker handles a contiguous slice of layers.
        layer_chunks = np.array_split(attn_i, args.num_of_workers)

        mtd0, mtd1 = _pooled_mtd(pool, layer_chunks, prompt_len, kind=0)
        mtd0_kind0.append(mtd0)
        mtd1_kind0.append(mtd1)
        mtd0, mtd1 = _pooled_mtd(pool, layer_chunks, prompt_len, kind=1)
        mtd0_kind1.append(mtd0)
        mtd1_kind1.append(mtd1)

        # (n_layers, n_heads, gen_len): attention of each token to itself.
        self_attn = np.diagonal(attn_i, axis1=2, axis2=3)
        diag_prompt.append(self_attn[..., :prompt_len].sum(axis=-1).reshape(-1))
        diag_response.append(self_attn[..., prompt_len:].sum(axis=-1).reshape(-1))
    return diag_prompt, diag_response, mtd0_kind0, mtd1_kind0, mtd0_kind1, mtd1_kind1


def generate_task(model, tokenizer, task, dataset, sample_num, args, pool=None):
    """Sample ``args.batch_size`` completions for ``dataset[sample_num]``.

    Each completion becomes a deep copy of the dataset entry extended with
    ``prompt``, ``output``, ``prompt_len``, ``generation`` and ``gen_len``. MBPP
    entries also get ``valid_code`` and ``out_check``.

    Args:
        pool: multiprocessing pool used for the MTD features; required when
            ``args.compute_features`` is set.

    Returns:
        ``(code_gens, generated_examples)``. When ``args.compute_features`` is set,
        it returns ``(code_gens, generated_examples, probs, diag_prompt, diag_response,
        mtd0_kind0, mtd1_kind0, mtd0_kind1, mtd1_kind1)`` instead.

    Raises:
        ValueError: if ``args.task`` is neither ``"humaneval"`` nor ``"mbpp"``.
    """
    gen_kwargs = _generation_kwargs(args)
    stopping_criteria = get_stopping_criteria(task, tokenizer)
    if stopping_criteria:
        gen_kwargs["stopping_criteria"] = StoppingCriteriaList(stopping_criteria)

    prompt = args.prefix + task.get_prompt(dataset[sample_num])
    # Only HumanEval passes an explicit max_length to the tokenizer.
    tokenizer_kwargs = {"max_length": args.max_length_generation} if args.task == "humaneval" else {}
    tokenized_prompt = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt", **tokenizer_kwargs)
    # Always bound: it is stored in every example and used by the features,
    # not only by the stop-word criteria.
    prompt_len = tokenized_prompt['input_ids'].shape[-1]
    for criterion in stopping_criteria:
        if isinstance(criterion, EndOfFunctionCriteria):
            # Look for stop words only in the generated part, not in the prompt.
            criterion.start_length = prompt_len

    with torch.no_grad():
        outputs = model.generate(
            input_ids=tokenized_prompt['input_ids'].to(model.device),
            attention_mask=tokenized_prompt['attention_mask'].to(model.device),
            num_return_sequences=args.batch_size,
            output_logits=args.compute_features,
            output_attentions=args.compute_features,
            return_dict_in_generate=args.compute_features,
            **gen_kwargs,
        )
    sequences = outputs['sequences'] if args.compute_features else outputs
    generated_tokens = sequences.cpu().numpy()

    keep_special_tokens = tokenizer.eos_token in task.stop_words
    code_gens = []
    generated_examples = []
    for tokens in generated_tokens:
        if keep_special_tokens:
            # eos is a stop word, so it is kept in the decoded text (presumably for
            # postprocess_generation to split on, as in bigcode-evaluation-harness);
            # only a leading BOS token is dropped.
            trimmed = tokens[1:] if tokens[0] == tokenizer.bos_token_id else tokens
            gen_code = tokenizer.decode(trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False)
        else:
            # Without eos among the stop words decode plain text (bigcode harness behaviour).
            gen_code = tokenizer.decode(tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        gen_code_processed = task.postprocess_generation(gen_code, sample_num)
        code_gens.append(gen_code_processed)

        new_example = copy.deepcopy(dataset[sample_num])
        new_example['prompt'] = prompt
        new_example['output'] = tokenizer.decode(tokens, skip_special_tokens=True)
        new_example['prompt_len'] = prompt_len
        new_example['generation'] = gen_code_processed
        if args.task == "humaneval":
            new_example['gen_len'] = get_gen_len(tokens, gen_code_processed, tokenizer)
        else:  # mbpp (validated by _generation_kwargs)
            valid_code = task.get_valid_output(gen_code, sample_num)
            new_example['valid_code'] = valid_code
            new_example['gen_len'] = get_gen_len(tokens, valid_code, tokenizer)
            new_example['out_check'] = tokenizer.decode(tokens[:new_example['gen_len']], skip_special_tokens=True)
        generated_examples.append(new_example)

    if args.compute_features:
        probs = get_probs(prompt_len, outputs)
        merged_attn = merge_attentions(outputs['attentions'])

    # Release the device-side generation outputs before the (slow, CPU-only)
    # MTD computation instead of holding them until the end.
    del outputs, sequences
    torch.cuda.empty_cache()

    if not args.compute_features:
        return code_gens, generated_examples
    features = _attention_features(merged_attn, generated_examples, prompt_len, args, pool)
    return (code_gens, generated_examples, probs) + features


def _load_model_and_tokenizer(args):
    """Load ``args.model`` (on CUDA when available, else CPU) and its tokenizer.

    Raises:
        ValueError: for an unsupported ``args.precision`` or a tokenizer that has
            neither an eos nor a bos token.
    """
    dict_precisions = {"fp16": torch.float16, "bf16": torch.bfloat16}
    if args.precision not in dict_precisions:
        raise ValueError(
            f"Non valid precision {args.precision}, choose from: fp16, bf16"
        )
    print(f"Loading model in {args.precision}")
    model_kwargs = {"torch_dtype": dict_precisions[args.precision]}
    if args.compute_features:
        # Eager attention materialises the attention weights requested via
        # output_attentions (fused SDPA/flash kernels do not return them).
        model_kwargs["attn_implementation"] = "eager"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs).to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if not tokenizer.eos_token:
        if not tokenizer.bos_token:
            raise ValueError("No eos_token or bos_token found")
        tokenizer.eos_token = tokenizer.bos_token
        print("bos_token used as eos_token")
    try:
        tokenizer.pad_token = tokenizer.eos_token
    except AttributeError:
        print("Not setting pad_token to eos_token")
    return model, tokenizer


def _make_task(args):
    """Instantiate the HumanEval or MBPP task object for ``args``."""
    new_line_arg = 'qwen' in args.model.lower()
    print('New line: ', new_line_arg)
    if args.task == "humaneval":
        return HumanEvalTask(new_line_arg)
    if args.task == "mbpp":
        return MBPPTask(new_line_arg, args.two_shot)
    raise ValueError(f"Unsupported task {args.task!r}, choose from: humaneval, mbpp")


def generate_main(args):
    """Generate completions for every task sample and write them to disk.

    For each dataset sample and each batch index ``seed`` in
    ``[start_batch, start_batch + n_batches)``, the RNG is seeded with ``seed``,
    ``args.batch_size`` completions are sampled, and every completion is
    appended as one JSON line to ``args.gen_path``. With ``args.compute_features``,
    each per-completion feature is also saved with ``np.save`` into its
    ``*_path`` directory, under the name
    ``task_id_<id>_seed_<seed>_gen_num_<gen_num>``.

    Raises:
        ValueError: for an unsupported task or precision, or a tokenizer without
            eos/bos tokens.
    """
    # Fail fast, before the (expensive) model load.
    if args.task not in ("humaneval", "mbpp"):
        raise ValueError(f"Unsupported task {args.task!r}, choose from: humaneval, mbpp")

    model, tokenizer = _load_model_and_tokenizer(args)
    task = _make_task(args)
    dataset = task.get_dataset()

    gen_dir = os.path.dirname(args.gen_path)
    if gen_dir:  # a bare filename has no directory to create
        os.makedirs(gen_dir, exist_ok=True)

    # Same order as the feature tuple returned by generate_task after
    # (code_gens, generated_examples).
    feature_dirs = []
    if args.compute_features:
        feature_dirs = [
            args.probs_path,
            args.diag_prompt_path,
            args.diag_response_path,
            args.mtd0_kind0_path,
            args.mtd1_kind0_path,
            args.mtd0_kind1_path,
            args.mtd1_kind1_path,
        ]
        for out_dir in feature_dirs:
            os.makedirs(out_dir, exist_ok=True)

    pool = Pool(args.num_of_workers) if args.compute_features else None
    try:
        for sample_num in tqdm(range(len(dataset)), desc='Generating'):
            for seed in range(args.start_batch, args.start_batch + args.n_batches):
                set_seed(seed)
                results = generate_task(model, tokenizer, task, dataset, sample_num, args, pool)
                gen_examples = results[1]
                for j, example in enumerate(gen_examples):
                    example['gen_num'] = seed * args.batch_size + j
                    example['seed'] = seed

                if args.compute_features:
                    features = results[2:]
                    for j, example in enumerate(gen_examples):
                        task_id = example['task_id']
                        if args.task == "humaneval":
                            task_id = task_id.replace('/', '_')
                        filename = 'task_id_{}_seed_{}_gen_num_{}'.format(task_id, seed, example['gen_num'])
                        for out_dir, feature in zip(feature_dirs, features):
                            np.save(os.path.join(out_dir, filename), feature[j])

                with open(args.gen_path, 'a', encoding='utf-8') as fw:
                    for ex in gen_examples:
                        fw.write(json.dumps(ex) + '\n')
    finally:
        # Previously the worker processes were never shut down.
        if pool is not None:
            pool.close()
            pool.join()
    

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prefix', type=str, default="")
    parser.add_argument('--do_sample', action="store_true")
    parser.add_argument('--temperature', type=float, default=0.8)
    parser.add_argument('--top_p', type=float, default=0.95)

    parser.add_argument('--model', type=str, required=True, help="model name or path")
    parser.add_argument('--task', type=str, default="humaneval", choices=["humaneval", "mbpp"])
    parser.add_argument('--max_length_generation', type=int, default=512)
    parser.add_argument('--precision', type=str, default='fp16', choices=["fp16", "bf16"])

    parser.add_argument('--compute_features', action="store_true")
    parser.add_argument('--gen_path', type=str, required=True, help="output path of your generation")
    parser.add_argument('--probs_path', type=str, help="output path of generation probs")
    parser.add_argument('--diag_prompt_path', type=str, help="output path of generation diag prompt sum")
    parser.add_argument('--diag_response_path', type=str, help="output path of generation diag response sum")
    parser.add_argument('--mtd0_kind0_path', type=str, help="output path of generation mtd0 kind 0")
    parser.add_argument('--mtd1_kind0_path', type=str, help="output path of generation mtd1 kind 0")
    parser.add_argument('--mtd0_kind1_path', type=str, help="output path of generation mtd0 kind 1")
    parser.add_argument('--mtd1_kind1_path', type=str, help="output path of generation mtd1 kind 1")

    # Defaults replace None, which previously crashed later (range()/arithmetic on None).
    parser.add_argument('--batch_size', type=int, default=1, help="number of sequences to generate")
    parser.add_argument('--n_batches', type=int, default=1, help="number of batches per task")
    parser.add_argument('--start_batch', type=int, default=0, help="number of start batch")
    parser.add_argument('--num_of_workers', type=int, default=1, help="number of workers")
    parser.add_argument('--two_shot', action="store_true", help="two-shot for MBPP")
    args = parser.parse_args()

    # Every feature output directory is needed once --compute_features is set.
    if args.compute_features:
        feature_path_args = [
            'probs_path', 'diag_prompt_path', 'diag_response_path',
            'mtd0_kind0_path', 'mtd1_kind0_path', 'mtd0_kind1_path', 'mtd1_kind1_path',
        ]
        missing = ['--' + name for name in feature_path_args if getattr(args, name) is None]
        if missing:
            parser.error("--compute_features requires: " + ", ".join(missing))
    print(args)

    generate_main(args)
