import time
import os
import logging
import argparse
import random
import json
import torch
import numpy as np
import pandas as pd

from transformers import GenerationConfig
from pynvml import *
import psutil
from scripts.llama_full import *
from utils.misc_utils import *

# Command-line interface for evaluating a Llama checkpoint on one data split.
# --data_size, --r and --density are not used to load the model; main() uses
# them only to choose/key the entry written to --result_path.
parser = argparse.ArgumentParser(
    description='Evaluate a (fine-tuned) Llama model on a held-out split')

parser.add_argument('--dataset_name',
                    default="",
                    type=str,
                    help='Dataset name (glue, raft etc.)')
parser.add_argument('--task',
                    default=None,
                    type=str,
                    help='finetuning task')
parser.add_argument('--model_name',
                    type=str,
                    help='path to load the model from')
parser.add_argument('--use_quantized',
                    action="store_true",
                    help='whether to load quantized base Llama model'
                    )
parser.add_argument('--split',
                    default="valid",
                    type=str,
                    help='valid/test split to evaluate the performance')
parser.add_argument('--max_seq_len',
                    type=int,
                    default=512,
                    help='max sequence length for tokenizer'
                    )
parser.add_argument('--seed',
                    type=int,
                    default=123,
                    help='rand seed'
                    )
parser.add_argument('--data_size',
                    type=int,
                    default=None,
                    help='data size this model was trained on; used for logging results'
                    )
parser.add_argument('--r',
                    type=int,
                    help='rank this model was trained on; used for logging results'
                    )
parser.add_argument('--density',
                    type=float,
                    help='density spiel was trained on; used for logging results'
                    )
parser.add_argument('--result_path',
                    type=str,
                    default="",
                    help='path to log the result to one file.'
                    )
parser.add_argument('--use_flash_attention',
                    action='store_true',
                    help='whether to use flash attentionv2 for inference'
                    )
parser.add_argument('--use_safetensor',
                    action='store_true',
                    help='whether to load the model in safetensor or not'
                    )
parser.add_argument('--on_vector',
                    action="store_true",
                    help='forwarded as on_vector to Llama.get_model_and_tokenizer'
                    )

args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

logger = logging.getLogger("Llama-eval")
curr_datetime = time.strftime("%Y%m%d%H%M%S")

# The log directory is resolved relative to the current working directory.
# Create it up front so basicConfig does not fail with FileNotFoundError.
_LOG_DIR = "../logs"
os.makedirs(_LOG_DIR, exist_ok=True)

# Use a 24-hour clock (%H): the previous %I without %p made AM/PM ambiguous.
logging.basicConfig(filename=f"{_LOG_DIR}/llama_eval_{args.task}_{args.split}_{curr_datetime}.log",
                    format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO)
logger.info(f"""Initializing evaluation script with the following params:
            - model_name {args.model_name}
            - dataset_name {args.dataset_name}
            - task {args.task}
            - use_quantized: {args.use_quantized}
            - split: {args.split}
            - data_size: {args.data_size}
            - r: {args.r}
            - density: {args.density}
            - seed: {args.seed}
            - max_seq_len: {args.max_seq_len}
            """
            )


def generate_response(model, tokenizer, answer_len, user_input):
    """Generate the model's answer for one chat-formatted example.

    Everything from the first assistant header onward is removed from
    ``user_input``. An empty assistant header is then appended, so the model
    writes the assistant turn itself.

    Args:
        model: object exposing a HF-style ``generate`` method.
        tokenizer: tokenizer matching ``model``. Its ``eos_token_id`` is used
            as the padding id.
        answer_len: ``"short"`` (at most 30 new tokens) or ``"long"`` (at most
            500 new tokens).
        user_input: text with the chat template already applied.

    Returns:
        ``outputs[0]`` decoded with special tokens skipped. For decoder-only
        models this presumably includes the prompt text as well as the
        generated answer, so callers must strip the prompt.

    Raises:
        ValueError: if ``answer_len`` is not ``"short"`` or ``"long"``.
    """
    max_new_tokens_by_len = {"short": 30, "long": 500}
    if answer_len not in max_new_tokens_by_len:
        # Previously an unknown value left generation_config unbound and
        # crashed later with UnboundLocalError.
        raise ValueError(
            f"answer_len must be one of {sorted(max_new_tokens_by_len)}, got {answer_len!r}"
        )

    # Drop the reference answer and re-open an empty assistant turn.
    prompt = user_input.split('<|start_header_id|>assistant<|end_header_id|>')[0]
    prompt += "\n     <|start_header_id|>assistant<|end_header_id|>"

    generation_config = GenerationConfig(
        # top_k / temperature only take effect when do_sample=True, which is
        # not enabled here; they are kept for parity with the original config.
        top_k=2,
        temperature=0.1,
        max_new_tokens=max_new_tokens_by_len[answer_len],
        pad_token_id=tokenizer.eos_token_id,
        # The original passed `num_beam`, which is not a GenerationConfig
        # field. It was silently ignored, so decoding fell back to greedy
        # search. `num_beams` enables the intended 4-beam search.
        num_beams=4,
    )

    with torch.no_grad():
        # NOTE(review): if user_input already starts with <|begin_of_text|>,
        # the tokenizer may prepend a second BOS token; consider
        # add_special_tokens=False — confirm against the data format.
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = model.generate(**inputs, generation_config=generation_config)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)


def eval_on_heldout(model, tokenizer, target_key, answer_len, data, split="valid"):
    """Generate a prediction for every example in ``data[split]`` and pair it with its target.

    Args:
        model: forwarded to ``generate_response``.
        tokenizer: forwarded to ``generate_response``.
        target_key: name of the column in ``data[split]`` that holds the
            reference answer.
        answer_len: ``"short"`` or ``"long"``, forwarded to ``generate_response``.
        data: mapping from split name to a column-indexable dataset with a
            ``'text'`` column (presumably a HF ``DatasetDict``).
        split: which split to evaluate.

    Returns:
        A DataFrame (object dtype) with one row per example and the columns
        ``pred`` (raw decoded output), ``target``, and ``pred_clean`` (the
        answer extracted from ``pred``). An empty split still returns these
        columns.
    """
    split_data = data[split]
    # Fetch each column once. Indexing a column inside the loop can rebuild
    # the whole column on every iteration with HF datasets (quadratic).
    texts = split_data['text']
    targets = split_data[target_key]

    rows = []
    for usr_input, target in zip(texts, targets):
        out = generate_response(model=model, tokenizer=tokenizer,
                                answer_len=answer_len, user_input=usr_input)
        print(f"result: {out}, {target}")
        # Keep the text after the last 'assistant\n' marker (dropping the
        # echoed prompt), then the part after the final '####' (if any).
        pred_clean = out.split('assistant\n')[-1].strip().split('####')[-1].strip()
        rows.append({"pred": out, "target": target, "pred_clean": pred_clean})

    return pd.DataFrame(rows, columns=["pred", "target", "pred_clean"], dtype=object)


def record_result(update, path):
    """Merge ``update`` into the JSON object stored at ``path``.

    The file is created if missing. Keys of ``update`` are normalized the way
    ``json`` serializes them (``8`` -> ``"8"``, ``0.01`` -> ``"0.01"``,
    ``None`` -> ``"null"``), so recording the same key again overwrites the
    earlier entry rather than adding a duplicate.

    Args:
        update: dict mapping a result key (e.g. rank, density or data size)
            to a JSON-serializable value.
        path: path of the JSON results file.
    """
    res = {}
    if os.path.exists(path):
        with open(path, 'r', encoding='utf-8') as f:
            content = f.read()
        # Tolerate an empty file (e.g. one left behind by an interrupted run).
        if content.strip():
            res = json.loads(content)

    # Round-trip through json so the keys take their on-disk string form.
    # Matching with str(key) would treat None ("None") and "null" as
    # different keys.
    res.update(json.loads(json.dumps(update)))

    # Write a temp file and atomically swap it in. Rewriting in place without
    # truncating left stale trailing bytes whenever the new content was
    # shorter, which corrupted the JSON.
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(res, f, indent=4)
    os.replace(tmp_path, path)


def count_adapter_param(model, peft_method):
    """Count the elements of all state-dict tensors whose key marks them as adapter weights.

    Args:
        model: a ``torch.nn.Module``.
        peft_method: ``'lora'`` (keys containing ``'lora'``) or ``'spiel'``
            (keys containing ``'sft_delta'``).

    Returns:
        Total number of elements (int) across the matching tensors, or 0 if
        none match.

    Raises:
        ValueError: if ``peft_method`` is not supported.
    """
    keyword_by_method = {'lora': 'lora', 'spiel': 'sft_delta'}
    try:
        kw = keyword_by_method[peft_method]
    except KeyError:
        # Previously an unknown method left kw=None and crashed on `None in k`.
        raise ValueError(
            f"peft_method must be one of {sorted(keyword_by_method)}, got {peft_method!r}"
        ) from None

    # Build the state dict once. Calling model.state_dict() per key rebuilt
    # the whole mapping on every iteration.
    return sum(tensor.numel() for name, tensor in model.state_dict().items() if kw in name)


def main():
    """Load the model and data, evaluate on ``args.split``, and optionally record the accuracy.

    All configuration comes from the module-level ``args``. The exact-match
    accuracy (string comparison of ``target`` vs ``pred_clean``) is logged.
    When ``--result_path`` is set, the accuracy is also merged into that JSON
    file. The entry is keyed by ``--r`` if set, else ``--density`` if set,
    else ``--data_size``.

    Raises:
        ValueError: if the evaluated split contains no examples.
    """
    torch.cuda.empty_cache()

    split = args.split
    task = args.task
    model_ckpt = args.model_name
    dataset_name = args.dataset_name
    on_vector = args.on_vector
    max_seq_len = args.max_seq_len
    use_flash_attention = args.use_flash_attention
    use_safetensor = args.use_safetensor
    use_quantized = args.use_quantized
    seed = args.seed
    result_path = args.result_path
    # Only used below to choose/key the entry written to result_path.
    data_size = args.data_size
    r = args.r
    density = args.density

    # data_size/r/density are passed as None here, and use_peft=False.
    model_cls = Llama(model_name=model_ckpt,
                      dataset_name=dataset_name,
                      data_size=None,
                      task=task,
                      use_quantized=use_quantized,
                      use_peft=False,
                      r=None,
                      density=None,
                      peft_method=None,
                      seed=seed
                      )
    get_mem_usage_stats("At init stage")
    model, tokenizer = model_cls.get_model_and_tokenizer(use_flash_attention=use_flash_attention,
                                                         use_safetensors=use_safetensor,
                                                         on_vector=on_vector)
    print("Model device before: ", next(model.parameters()).device)
    model.eval()
    # NOTE(review): if --use_quantized loads a bitsandbytes 4/8-bit model,
    # transformers may reject `.to()` on it — confirm.
    model.to(device)
    get_mem_usage_stats("After setting the model to device()")
    print("Model device after: ", next(model.parameters()).device)
    print({p.device for p in model.parameters()})
    data = model_cls.get_data(tokenizer=tokenizer, max_seq_len=max_seq_len)

    # Limit the test split to at most 5000 rows (seeded shuffle).
    max_test_rows = 5000
    if data['test'].num_rows > max_test_rows:
        data['test'] = data['test'].shuffle(seed=seed).select(range(max_test_rows))

    prompt = model_cls.load_task_prompt()
    target_key = prompt['assistant_output']
    answer_len = prompt['answer_len']

    get_mem_usage_stats("Before eval_result")
    eval_result = eval_on_heldout(model=model, tokenizer=tokenizer,
                                  target_key=target_key, answer_len=answer_len,
                                  data=data, split=split)
    print(eval_result.head())

    n_examples = eval_result.shape[0]
    if n_examples == 0:
        # Fail clearly instead of with a KeyError or a division by zero.
        raise ValueError(f"No examples were evaluated for split {split!r}; cannot compute accuracy.")
    # Exact match on the string forms of the target and the cleaned prediction.
    matches = eval_result['target'].astype(str) == eval_result['pred_clean'].astype(str)
    accuracy = float(matches.sum() / n_examples)
    logger.info(f"Accuracy evaluation after training: {accuracy}")

    if result_path != "":
        if r:
            adapter_param_count = count_adapter_param(model, peft_method='lora')
            update = {r: {'accuracy': accuracy, 'trainable_params': adapter_param_count}}
        elif density:
            # NOTE(review): halved presumably because each sparse delta is
            # stored as two 'sft_delta' tensors (e.g. values + indices) — confirm.
            adapter_param_count = count_adapter_param(model, peft_method='spiel')/2
            update = {density: {'accuracy': accuracy, 'trainable_params': adapter_param_count}}
        else:
            update = {data_size: {'accuracy': accuracy}}
        record_result(update=update, path=result_path)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
