import json
import torch
import mauve 
import argparse
import numpy as np
import tiktoken
import csv


def parse_args():
    """Define the script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``gen``, ``gt``,
        ``max_length``, ``fix_len``, ``csv`` and ``device_id``.
    """
    cli = argparse.ArgumentParser()

    # Required input files.
    cli.add_argument("--gen", type=str, required=True, help="generated texts file")
    cli.add_argument("--gt", type=str, required=True, help="ground truth texts file")

    # Optional evaluation settings.
    cli.add_argument("--max_length", type=int, default=256)
    cli.add_argument("--fix_len", action="store_true", help="evaluate samples with fix length")
    cli.add_argument("--csv", action="store_true", help="texts file stored in csv")
    cli.add_argument("--device_id", type=int, default=2)

    # NOTE: ``csv`` already defaults to True, so passing ``--csv`` on the
    # command line does not change its value.
    cli.set_defaults(csv=True, fix_len=False)
    return cli.parse_args()

def decode(tokens, tokenizer):
    """Turn a sequence of tokens back into a text string.

    Args:
        tokens: Tokens accepted by ``tokenizer.convert_tokens_to_ids``.
        tokenizer: Object providing ``convert_tokens_to_ids`` and ``decode``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the converted ids.
    """
    return tokenizer.decode(tokenizer.convert_tokens_to_ids(tokens))

def parse_text(reference_text, prediction_text, tokenizer, max_length=128):
    """Truncate a reference/prediction pair to at most ``max_length`` tokens.

    Both texts are encoded with ``tokenizer``, cut to the first ``max_length``
    tokens, and decoded back to text.

    Args:
        reference_text: Ground-truth text.
        prediction_text: Generated text.
        tokenizer: Object providing ``encode(text)`` and ``decode(tokens)``.
        max_length: Maximum number of tokens kept from each text.

    Returns:
        tuple: ``(reference_text, prediction_text, flag)`` where the texts are
        the truncated versions and ``flag`` is True only when *both* texts
        still hold exactly ``max_length`` tokens after truncation (i.e. both
        originals were at least ``max_length`` tokens long).
    """
    reference_tokens = tokenizer.encode(reference_text)[:max_length]
    prediction_tokens = tokenizer.encode(prediction_text)[:max_length]

    # After truncation neither length can exceed max_length, so the shorter
    # one equals max_length only if both texts reached the full length.
    flag = min(len(reference_tokens), len(prediction_tokens)) == max_length

    return tokenizer.decode(reference_tokens), tokenizer.decode(prediction_tokens), flag

def load_csv(filename):
    """Read the first column of every non-empty row of a CSV file.

    Args:
        filename: Path to a comma-delimited CSV file.

    Returns:
        list[str]: The first field of each row, in file order. Blank rows
        are skipped.
    """
    # newline="" lets the csv module handle line endings itself, which keeps
    # newlines embedded in quoted fields intact. Explicit UTF-8 makes decoding
    # independent of the platform's locale.
    with open(filename, "r", newline="", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter=",")
        # csv.reader yields an empty list for blank lines; indexing such a
        # row with [0] would raise IndexError, so those rows are dropped.
        return [row[0] for row in reader if row]

def load_result(in_f, tokenizer):
    """Load references and per-run predictions from a JSON result file.

    The file must contain a list of objects, each with a ``reference_text``
    entry and a ``generated_result`` mapping keyed by the string indices
    ``"0"``, ``"1"``, ... The number of runs is taken from the first object.

    Args:
        in_f: Path to the JSON file.
        tokenizer: Not used by this function.

    Returns:
        tuple: ``(reference_list, all_prediction_list)`` where
        ``all_prediction_list[k]`` holds the k-th prediction of every item,
        aligned with ``reference_list``.
    """
    with open(in_f) as handle:
        records = json.load(handle)

    # References, one per record.
    reference_list = [record['reference_text'] for record in records]

    # Predictions, regrouped from per-record to per-run lists.
    number_of_predictions_per_instance = len(records[0]['generated_result'])
    print ('Number of predictions per instance is {}'.format(number_of_predictions_per_instance))

    all_prediction_list = []
    for run_idx in range(number_of_predictions_per_instance):
        run_key = str(run_idx)
        run_predictions = [record['generated_result'][run_key] for record in records]
        assert len(run_predictions) == len(reference_list)
        all_prediction_list.append(run_predictions)

    return reference_list, all_prediction_list

def evaluate_one_instance(reference_list, prediction_list, tokenizer, max_length, fix_len_flag, device_id,
                          featurize_model_name=None):
    """Compute the MAUVE score (multiplied by 100) for one set of predictions.

    Each reference/prediction pair is truncated with ``parse_text``. Pairs
    whose truncated prediction is empty are dropped, and when
    ``fix_len_flag`` is set, pairs that do not both reach ``max_length``
    tokens are dropped as well.

    Args:
        reference_list: Ground-truth texts.
        prediction_list: Generated texts, indexed in parallel with
            ``reference_list``.
        tokenizer: Tokenizer handed to ``parse_text``.
        max_length: Maximum number of tokens kept per text.
        fix_len_flag: If true, only keep pairs where both texts are
            ``max_length`` tokens long after truncation.
        device_id: Forwarded to ``mauve.compute_mauve``.
        featurize_model_name: Model name or path forwarded to
            ``mauve.compute_mauve``. When None, the module-level
            ``EVAL_GPT2_PATH`` is used if it is defined, otherwise ``"gpt2"``.

    Returns:
        float: ``out.mauve * 100`` from ``mauve.compute_mauve``.

    Raises:
        ValueError: If no pair survives the filtering.
    """
    ref_list, pred_list = [], []
    for idx, one_ref in enumerate(reference_list):
        one_ref, one_pred, is_full_length = parse_text(
            one_ref, prediction_list[idx], tokenizer, max_length=max_length)
        if fix_len_flag and not is_full_length:
            continue
        if one_pred.strip():  # ignore predictions with zero length
            ref_list.append(one_ref)
            pred_list.append(one_pred)

    # Fail early with a clear message instead of handing empty lists to mauve.
    if not pred_list:
        raise ValueError("no valid reference/prediction pairs left to evaluate after filtering")

    # The original code referenced EVAL_GPT2_PATH unconditionally, which is a
    # NameError when the constant is not defined in this module.
    if featurize_model_name is None:
        featurize_model_name = globals().get("EVAL_GPT2_PATH", "gpt2")

    # use gpt2 model as the base model, following the author's implementation:
    # https://github.com/XiangLi1999/ContrastiveDecoding/blob/98cad19349fb08ee95b0f25a661179866f8e2c84/text-generation/eval_script.py#L248
    out = mauve.compute_mauve(p_text=ref_list, q_text=pred_list, device_id=device_id, verbose=False,
        featurize_model_name=featurize_model_name, mauve_scaling_factor=1.0)
    return out.mauve * 100

def measure_mauve(gen_filename, gt_filename, max_length, fix_len_flag, device_id):
    """Score a CSV of generated texts against a CSV of ground-truth texts.

    Args:
        gen_filename: CSV file with the generated texts (read via ``load_csv``).
        gt_filename: CSV file with the ground-truth texts (read via ``load_csv``).
        max_length: Maximum number of tokens kept per text.
        fix_len_flag: Forwarded to ``evaluate_one_instance``.
        device_id: Forwarded to ``evaluate_one_instance``.

    Returns:
        dict: ``mauve_score_list`` (list of stringified scores), plus
        ``mauve_mean`` and ``mauve_std`` (stringified, rounded to 2 decimals).
    """
    tokenizer = tiktoken.get_encoding("gpt2")
    reference_list = load_csv(gt_filename)
    # Kept as a list of prediction sets so several sets could be scored and
    # aggregated; the CSV input provides exactly one.
    all_prediction_list = [load_csv(gen_filename)]

    mauve_score_list = [
        evaluate_one_instance(reference_list, one_prediction_list, tokenizer, max_length, fix_len_flag, device_id)
        for one_prediction_list in all_prediction_list
    ]

    mean = round(np.mean(mauve_score_list), 2)
    std = round(np.std(mauve_score_list), 2)
    return {
        "mauve_score_list": [str(num) for num in mauve_score_list],
        'mauve_mean': str(mean),
        'mauve_std': str(std),
    }

def main():
    """Script entry point: parse the CLI options, compute MAUVE, print both."""
    options = parse_args()
    print(options)
    scores = measure_mauve(
        options.gen,
        options.gt,
        options.max_length,
        options.fix_len,
        options.device_id,
    )
    print(scores)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

