from typing import List, Tuple, Dict
import os
import copy
import re
import json
import argparse
from tqdm import tqdm
from collections import Counter, defaultdict

from rich.console import Console

from utils import fixed_seed, normalize_answer


# Shared rich console; remove_unreasonable_sample logs its before/after counts through it.
console = Console()

class TokenF1Calculator:
    """Token-level precision/recall/F1 between a prediction and candidate strings.

    Both methods are stateless; they are declared ``@staticmethod`` so they
    work whether called through the class (``TokenF1Calculator.compute(...)``)
    or through an instance.
    """

    @staticmethod
    def _prec_recall_f1_score(pred_items, gold_items):
        """Return ``(precision, recall, f1)`` of the bag-of-tokens overlap.

        The overlap is the multiset intersection of the two token lists, so a
        repeated token counts as many times as it occurs on the side with
        fewer occurrences. Returns ``(0, 0, 0)`` when nothing overlaps, which
        also covers either list being empty.
        """
        common = Counter(gold_items) & Counter(pred_items)
        num_same = sum(common.values())
        if num_same == 0:
            return 0, 0, 0

        precision = 1.0 * num_same / len(pred_items)
        recall = 1.0 * num_same / len(gold_items)
        f1 = (2 * precision * recall) / (precision + recall)
        return precision, recall, f1

    @staticmethod
    def compute(pred: str, context: List[str]) -> Tuple[float, float, float, int, int, int]:
        """Score ``pred`` against every string in ``context`` and keep the best.

        ``pred`` and each context string are passed through
        ``normalize_answer`` and split on whitespace before scoring.

        Args:
            pred: predicted utterance.
            context: candidate strings to compare against.

        Returns:
            ``(max_precision, max_recall, max_f1, precision_index,
            recall_index, f1_index)``. Each index is the position in
            ``context`` of the first candidate attaining that maximum, so the
            three indices may differ. When ``context`` is empty every score is
            0 and every index is -1.
        """
        if not context:
            # max() of an empty list would raise; report "no match" instead.
            return 0, 0, 0, -1, -1, -1

        pred_tokens = normalize_answer(pred).split()
        scores = [
            TokenF1Calculator._prec_recall_f1_score(pred_tokens, normalize_answer(c).split())
            for c in context
        ]

        precision = [ele[0] for ele in scores]
        recall = [ele[1] for ele in scores]
        f1 = [ele[2] for ele in scores]

        # list.index returns the first position, so ties resolve to the
        # earliest context entry.
        max_p = max(precision)
        max_r = max(recall)
        max_f1 = max(f1)
        return (
            max_p,
            max_r,
            max_f1,
            precision.index(max_p),
            recall.index(max_r),
            f1.index(max_f1),
        )

class TurnF1Calculator:
    """Turn-level precision/recall/F1 against a single gold turn index."""

    @staticmethod
    def compute(pred_indices: List[int], gold_index: int) -> Tuple[float, float, float, int]:
        """Score predicted turn indices against one gold turn index.

        There is exactly one gold turn, so on a hit recall is 1 and precision
        is ``1 / len(pred_indices)``. Duplicate entries in ``pred_indices``
        each count toward that denominator.

        Returns:
            ``(precision, recall, f1, hit)`` where ``hit`` is 1 when
            ``gold_index`` appears in ``pred_indices``; otherwise all four
            values are 0.
        """
        if gold_index not in pred_indices:
            return 0, 0, 0, 0

        precision = 1.0 / len(pred_indices)
        recall = 1.0
        f1 = (2 * precision * recall) / (precision + recall)
        return precision, recall, f1, 1

def parse_args():
    """Parse command-line options for the image-sharing evaluation.

    ``--model_name`` is required: the rest of the script branches on whether
    it contains ``"text"``, which fails on ``None``.
    """
    parser = argparse.ArgumentParser(description="evaluating the ability of image-sharing behavior")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--template_type", type=str, default="image-sharing-v2")
    parser.add_argument(
        "--model_name", type=str, required=True,
        help="model name; names containing 'text' use the 'task1' result keys",
    )
    parser.add_argument("--eval_result_dir", type=str, default="./parsed_results")
    parser.add_argument("--file_version", type=str, default="v1.0")
    parser.add_argument("--report_save_dir", type=str, default=None)
    parser.add_argument("--datatype", type=str, default="test")
    parser.add_argument(
        "--rounding_step", type=int, default=0,
        help="prefix of result keys/files for non-text models",
    )
    parser.add_argument(
        "--limit_num", type=int, default=-1,
        help="max parsed utterances per instance to evaluate (-1 = no limit)",
    )

    return parser.parse_args()

def load_json(datadir):
    """Load and return the JSON document stored at path ``datadir``.

    The file is decoded as UTF-8 explicitly so that non-ASCII content
    (e.g. results dumped with ``ensure_ascii=False``) loads the same way
    regardless of the platform's default locale encoding.
    """
    with open(datadir, 'r', encoding='utf-8') as f:
        return json.load(f)

def check_clear_parsed_count(results, model_name=None, rounding_step=1):
    """Return the percentage (rounded to 3 places) of instances with a non-empty parsed result.

    Args:
        results: list of instance dicts.
        model_name: names containing ``"text"`` read ``task1_parsed_result``;
            all others read ``f"{rounding_step}_task1_parsed_result"``.
        rounding_step: key prefix for non-text models.

    Returns:
        A percentage in ``[0, 100]``; ``0.0`` when ``results`` is empty.
    """
    if "text" in model_name:
        parse_string = "task1"
    else:
        parse_string = f"{rounding_step}_task1"

    if not results:
        # Nothing to measure; avoid dividing by zero below.
        return 0.0

    key = f"{parse_string}_parsed_result"
    clear_cnt = sum(
        1 for instance in tqdm(results, total=len(results)) if len(instance[key]) > 0
    )
    return round(100 * clear_cnt / len(results), 3)

def check_attribute_count(results, attribute=None, model_name=None, rounding_step=1):
    """Count parsed elements whose ``attribute`` is present (not None).

    Only instances with a non-empty parsed result are considered. For each
    such instance, the number of parsed elements with ``ele[attribute] is not
    None`` is recorded.

    Args:
        results: list of instance dicts.
        attribute: key to inspect in each parsed element (e.g. ``"utterance"``).
        model_name: names containing ``"text"`` read ``task1_parsed_result``;
            all others read ``f"{rounding_step}_task1_parsed_result"``.
        rounding_step: key prefix for non-text models.

    Returns:
        ``(total, mean_per_instance, min_per_instance, max_per_instance)``
        with the mean rounded to 3 places; ``(0, 0.0, 0, 0)`` when no instance
        has a non-empty parsed result.
    """
    if "text" in model_name:
        parse_string = "task1"
    else:
        parse_string = f"{rounding_step}_task1"
    key = f"{parse_string}_parsed_result"

    counts = []
    for instance in tqdm(results, total=len(results)):
        parsed_result = instance[key]
        if len(parsed_result) == 0:
            continue
        counts.append(sum(1 for ele in parsed_result if ele[attribute] is not None))

    if not counts:
        # No clearly parsed samples: avoid ZeroDivisionError and min()/max() of [].
        return 0, 0.0, 0, 0

    total = sum(counts)
    return total, round(total / len(counts), 3), min(counts), max(counts)

def find_max_metric(p_val, r_val):
    """Name the larger of two scores as ``'precision'`` or ``'recall'``.

    ``'recall'`` wins only when ``r_val`` is strictly greater than ``p_val``;
    ties go to ``'precision'``.
    """
    return 'recall' if r_val > p_val else 'precision'

def annotate(results, model_name=None, rounding_step=1):
    """Attach best-match token scores to every parsed utterance.

    Each element of the instance's parsed result is scored with
    ``TokenF1Calculator.compute`` against ``instance["all_context"]`` and gains
    the keys ``precision``, ``precision_index``, ``recall``, ``recall_index``,
    ``f1`` and ``f1_index``. Elements and instances are updated in place; the
    returned list holds the same instance objects in input order. Instances
    with an empty parsed result are passed through with a notice printed.
    """
    parse_string = "task1" if "text" in model_name else f"{rounding_step}_task1"
    result_key = f"{parse_string}_parsed_result"

    annotated = []
    for instance in tqdm(results, total=len(results)):
        parsed_result = instance[result_key]
        all_context = instance["all_context"]

        if len(parsed_result) == 0:
            annotated.append(instance)
            print(parsed_result, parse_string)

            print("annotate::no parsed results")
            continue

        scored = []
        for element in parsed_result:
            p, r, f1, p_idx, r_idx, f1_idx = TokenF1Calculator.compute(element["utterance"], all_context)
            # Keyword order fixes the key insertion order seen in dumped JSON.
            element.update(
                precision=p,
                precision_index=p_idx,
                recall=r,
                recall_index=r_idx,
                f1=f1,
                f1_index=f1_idx,
            )
            scored.append(element)

        instance[result_key] = scored
        annotated.append(instance)

    return annotated

def check_turn_accuracy(results, model_name=None, rounding_step=1, limit_num=-1, match_threshold=0.6):
    """Measure how often a predicted utterance lands on the gold image-share turn.

    Each parsed utterance (already scored by ``annotate``) is mapped to a
    dialogue turn via whichever of its precision/recall is larger (ties go to
    precision), provided that score is at least ``match_threshold``. An
    instance is correct when the gold turn ``image_share_turn_idx - 1`` is
    among its mapped turns.

    Args:
        results: annotated instance dicts.
        model_name: names containing ``"text"`` read ``task1_*`` keys; all
            others read ``f"{rounding_step}_task1_*"`` keys.
        rounding_step: key prefix for non-text models.
        limit_num: consider only the first ``limit_num`` parsed utterances of
            each instance; any negative value means no limit.
        match_threshold: minimum precision/recall for an utterance to be
            mapped to a turn (the default 0.6 keeps the original behavior).

    Returns:
        ``(correct_cnt, total_cnt, accuracy_percent, correct_results,
        final_results)``. ``total_cnt`` includes instances with an empty
        parsed result (they count as incorrect), but such instances appear in
        neither result list. Each entry of ``final_results`` is a deep copy of
        the instance whose ``<prefix>_annotated_parsed_results`` holds the
        mapped utterances, each with a ``turn_index``. An empty ``results``
        yields ``(0, 0, 0.0, [], [])``.
    """
    if "text" in model_name:
        parse_string = "task1"
    else:
        parse_string = f"{rounding_step}_task1"
    result_key = f"{parse_string}_parsed_result"

    if not results:
        # Avoid dividing by zero when computing the accuracy.
        return 0, 0, 0.0, [], []

    correct_cnt = 0
    total_cnt = 0
    corrected_results = []
    final_results = []
    for instance in tqdm(results, total=len(results)):
        gold_index = instance["image_share_turn_idx"] - 1
        total_cnt += 1

        if len(instance[result_key]) == 0:
            print('check_turn_accuracy::', instance[result_key])
            continue

        pred_indices = []
        tmp_results = []
        for t, ele in enumerate(instance[result_key]):
            if t == limit_num:
                break
            utter = ele["utterance"]
            max_metric = find_max_metric(ele["precision"], ele["recall"])
            max_value = ele[max_metric]
            if max_value >= match_threshold:
                index = ele[f"{max_metric}_index"]
                pred_indices.append(index)

                copy_e = copy.deepcopy(ele)
                copy_e["turn_index"] = index
                tmp_results.append(copy_e)
            else:
                # Diagnostic dump for utterances that match no turn closely enough.
                for i, e_utt in enumerate(instance["all_context"]):
                    print(i, e_utt)

                print('check_turn_accuracy::', utter)

        new_instance = copy.deepcopy(instance)
        new_instance[f"{parse_string}_annotated_parsed_results"] = tmp_results
        final_results.append(new_instance)

        if gold_index in pred_indices:
            correct_cnt += 1
            corrected_results.append(new_instance)

    return correct_cnt, total_cnt, round(100 * correct_cnt / total_cnt, 3), corrected_results, final_results

def remove_unreasonable_sample(results, model_name=None, rounding_step=1):
    """Drop parsed utterances that are inconsistent with the dialogue's speakers.

    Speaker names are taken from the ``prompt_input`` header
    ``"The following is a dialogue between <spk1> and <spk2>."``. A parsed
    element is dropped when its ``utterance`` text contains either speaker
    name, or when its ``speaker`` is neither name. Instances are updated in
    place and every instance is returned, in order; those with an empty
    parsed result pass through unchanged.

    Raises:
        ValueError: if an instance's ``prompt_input`` lacks the speaker header.
    """
    if "text" in model_name:
        parse_string = "task1"
    else:
        parse_string = f"{rounding_step}_task1"

    ret = []
    before_cnt = 0
    after_cnt = 0
    pattern = re.compile(r'The following is a dialogue between (?P<spk1>.*?) and (?P<spk2>.*?)\..*')
    for instance in tqdm(results, total=len(results)):
        prompt_input = instance["prompt_input"]

        match = pattern.search(prompt_input)
        if match is None:
            # Previously surfaced as "'NoneType' object has no attribute 'group'".
            raise ValueError(
                "remove_unreasonable_sample::speaker header not found in prompt_input: "
                f"{prompt_input[:200]!r}"
            )
        spk1 = match.group("spk1")
        spk2 = match.group("spk2")

        parsed_result = instance[f"{parse_string}_parsed_result"]

        if len(parsed_result) == 0:
            ret.append(instance)
            continue

        tmp_ret = []
        for ele in parsed_result:
            before_cnt += 1
            if (spk1 in ele["utterance"]) or (spk2 in ele["utterance"]):
                print(parsed_result, 'hoy')
                continue
            if ele["speaker"] not in [spk1, spk2]:
                print()
                print()
                print(instance["prompt_input"])
                print(instance[f"{parse_string}_openai_resp"])
                print(parsed_result, 'hap')
                continue
            tmp_ret.append(ele)
            after_cnt += 1

        instance[f"{parse_string}_parsed_result"] = tmp_ret
        ret.append(instance)

    console.log(f"[bold green]Before #: {before_cnt}, After #: {after_cnt}")

    return ret

def aggregate_rationale(results, model_name=None, rounding_step=1):
    """Collect every truthy ``rationale`` from the parsed results, in order.

    Instances whose parsed result is empty are skipped with a printed notice;
    elements whose rationale is falsy (e.g. ``None`` or ``""``) are ignored.
    """
    parse_string = "task1" if "text" in model_name else f"{rounding_step}_task1"
    result_key = f"{parse_string}_parsed_result"

    collected = []
    for instance in tqdm(results, total=len(results)):
        parsed_result = instance[result_key]
        if len(parsed_result) == 0:
            print("aggregate_rationale::No parsed results")
            continue

        # filter(None, ...) keeps only truthy rationales.
        collected.extend(filter(None, (element["rationale"] for element in parsed_result)))

    return collected

def main():
    """Evaluate one model's parsed image-sharing results and write reports.

    Reads ``<eval_result_dir>/<file_version>/<model_name>/<prefix><template_type>_parsed_result.json``,
    where ``<prefix>`` is empty for model names containing ``"text"`` and
    ``"<rounding_step>_"`` otherwise. The parsed utterances are then filtered,
    annotated and scored for turn accuracy, and parse statistics are computed.
    Three files are written under
    ``<report_save_dir or ./reports>/<file_version>/<model_name>/<limit_num>/``:
    a tab-separated report line, the collected rationales (one per line) and
    the annotated results as JSON.
    """
    args = parse_args()
    if args.model_name is None:
        raise SystemExit("--model_name is required")
    fixed_seed(args.seed)

    is_text_model = "text" in args.model_name
    prefix = "" if is_text_model else f"{args.rounding_step}_"

    eval_result_path = os.path.join(
        args.eval_result_dir,
        args.file_version,
        args.model_name,
        f"{prefix}{args.template_type}_parsed_result.json",
    )
    results = load_json(eval_result_path)

    filtered_results = remove_unreasonable_sample(results, args.model_name, args.rounding_step)
    annotated_results = annotate(filtered_results, args.model_name, args.rounding_step)
    crr_cnt, tot_cnt, turn_acc, _corr_results, new_results = check_turn_accuracy(
        annotated_results, args.model_name, args.rounding_step, args.limit_num
    )
    rationale = aggregate_rationale(annotated_results, args.model_name, args.rounding_step)

    # NOTE: remove_unreasonable_sample and annotate modify the instance dicts
    # in place, so these statistics over ``results`` see the filtered parses.
    clear_parsed_ratio = check_clear_parsed_count(results, args.model_name, args.rounding_step)
    utt_cnt, avg_utt_per_dialog, utt_min, utt_max = check_attribute_count(
        results, attribute="utterance", model_name=args.model_name, rounding_step=args.rounding_step
    )
    spk_cnt, avg_spk_per_dialog, spk_min, spk_max = check_attribute_count(
        results, attribute="speaker", model_name=args.model_name, rounding_step=args.rounding_step
    )
    r_cnt, avg_r_per_dialog, r_min, r_max = check_attribute_count(
        results, attribute="rationale", model_name=args.model_name, rounding_step=args.rounding_step
    )

    # % of generated utterances that came with a rationale (0 if none were generated).
    r_ratio = round(100 * r_cnt / utt_cnt, 3) if utt_cnt else 0.0
    report_fields = [
        args.model_name, clear_parsed_ratio, crr_cnt, tot_cnt, turn_acc,
        utt_cnt, avg_utt_per_dialog, utt_min, utt_max,
        spk_cnt, avg_spk_per_dialog, spk_min, spk_max,
        r_cnt, avg_r_per_dialog, r_min, r_max, r_ratio,
    ]
    report_string = [str(ele) for ele in report_fields]

    report_root = args.report_save_dir if args.report_save_dir is not None else './reports'
    report_save_dir = os.path.join(report_root, args.file_version, args.model_name, str(args.limit_num))
    os.makedirs(report_save_dir, exist_ok=True)

    # Existing output names differ by model type ("report" vs "result"); kept as-is
    # so downstream consumers keep finding the files.
    report_name = f"{prefix}{args.template_type}_{'report' if is_text_model else 'result'}.txt"

    with open(os.path.join(report_save_dir, report_name), 'w', encoding='utf-8') as f:
        f.write('\t'.join(report_string))

    with open(os.path.join(report_save_dir, f'{prefix}{args.template_type}_rationale.txt'), 'w', encoding='utf-8') as f:
        f.write('\n'.join(rationale))

    with open(os.path.join(report_save_dir, f'{prefix}{args.template_type}_results.json'), 'w', encoding='utf-8') as f:
        json.dump(new_results, f, ensure_ascii=False, indent='\t')


# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()