import os
import re
import json
import copy
import argparse
import pickle as pc
from tqdm import tqdm

import stanza
from rich.console import Console

from metrics import (
    AvgLenCalculator, 
    Dist1Calculator, 
    Dist2Calculator,
    Dist3Calculator,
    WordNet1Calculator,
    WordNet2Calculator,
    ObjectConsistencyCalculator,
    DialogueConsistencyCalculator,
    LenDistCalculator,
    Ent1Calculator,
    Ent2Calculator,
    Ent3Calculator,
    UniqueUniGramCalculator,
    UniqueBiGramCalculator,
    UniqueTriGramCalculator,
    POSCalculator
)


# Rich console handle; nothing in the visible part of this file uses it.
console = Console()
# English stanza pipeline restricted to the `tokenize` processor, built once
# at import time. parse_caption() relies on it for sentence splitting.
nlp = stanza.Pipeline(lang='en', processors='tokenize')

def load_json(datadir):
    """Read the UTF-8 encoded JSON file at ``datadir`` and return its decoded content."""
    with open(datadir, mode='r', encoding='utf-8') as handle:
        raw = handle.read()
    return json.loads(raw)

def parse_args():
    """Parse the command-line options of the caption evaluation script.

    Returns:
        argparse.Namespace with ``model_name`` (str or None),
        ``do_rebuilding`` (bool flag) and ``rounding_step`` (int, default 0).
    """
    option_table = (
        ('--model_name', {'type': str, 'default': None}),
        ("--do_rebuilding", {'action': "store_true"}),
        ("--rounding_step", {'type': int, 'default': 0}),
    )

    cli = argparse.ArgumentParser(description="evaluating generated captions")
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def parse_caption(caption):
    """Split ``caption`` into sentences with the module-level stanza pipeline.

    The caption is stripped of surrounding whitespace before being tokenized.
    Returns the text of each detected sentence, in order.
    """
    document = nlp(caption.strip())
    sentence_texts = []
    for sent in document.sentences:
        sentence_texts.append(sent.text)
    return sentence_texts

def deidentification(text, speakers):
    """Replace every occurrence of each speaker name in ``text`` with ``'person'``.

    Args:
        text: The string to anonymize.
        speakers: Iterable of speaker-name strings. Empty or ``None`` entries
            are ignored and duplicates are collapsed.

    Returns:
        ``text`` with all speaker names substituted (plain substring
        replacement, case-sensitive).
    """
    # An empty name must be skipped: str.replace('', 'person') would insert
    # 'person' between every character of the text.
    names = {spk for spk in speakers if spk}

    # Longest names first so that a name containing another one ("Anna" vs.
    # "Ann") is replaced as a whole instead of leaving a stray suffix
    # ("persona"). The secondary alphabetical key keeps the order deterministic.
    for spk in sorted(names, key=lambda name: (-len(name), name)):
        text = text.replace(spk, 'person')
    return text

def get_length(caption):
    """Return how many whitespace-separated tokens ``caption`` contains."""
    tokens = caption.split()
    return sum(1 for _ in tokens)

def rebuild_dataset(results, rounding_step=0):
    """Post-process generation results and write them to disk as JSON.

    For every instance a deep copy is extended with:
      * ``correct_sample``: whether the predicted turn index equals
        ``image_share_turn_idx - 1``;
      * ``org_rationale`` / ``deid_rationale``: the rationale before and
        after speaker de-identification (``None`` is passed through);
      * ``org_caption`` / ``deid_caption``: the generated caption before and
        after de-identification;
      * ``first_deid_caption``: the first sentence of the de-identified
        caption.

    Args:
        results: List of instance dicts read from a generation log.
        rounding_step: 0 reads ``task2_openai_resp``; any other value reads
            ``"{rounding_step}_task2_openai_resp"`` and prefixes the output
            file name with the step.

    NOTE: the output directory uses ``args.model_name`` from the module-level
    ``args`` that is only bound when the file runs as a script.
    """
    if rounding_step == 0:
        resp_key = "task2_openai_resp"
        file_name = 'result.json'
    else:
        resp_key = f"{rounding_step}_task2_openai_resp"
        file_name = f'{rounding_step}_result.json'

    dataset = []
    for instance in tqdm(results, total=len(results)):
        new_instance = copy.deepcopy(instance)

        speakers = instance["task2_speakers"]

        # The stored gold index is shifted by one relative to the prediction.
        golden_turn_index = instance["image_share_turn_idx"] - 1
        task1_annotated_result = instance["task1_annotated_result"]
        rationale = task1_annotated_result["rationale"]
        pred_turn_index = task1_annotated_result["turn_index"]

        new_instance["correct_sample"] = golden_turn_index == pred_turn_index

        new_instance["org_rationale"] = rationale
        if rationale is not None:
            rationale = deidentification(rationale, speakers)
        new_instance["deid_rationale"] = rationale

        gen_cap = instance[resp_key]
        new_instance["org_caption"] = gen_cap.strip()

        gen_cap = deidentification(gen_cap, speakers)
        new_instance["deid_caption"] = gen_cap

        parsed_gen_cap = parse_caption(gen_cap)
        # Guard against a caption in which the tokenizer finds no sentence
        # (e.g. an empty response) instead of failing with IndexError.
        if parsed_gen_cap:
            new_instance["first_deid_caption"] = parsed_gen_cap[0]
        else:
            new_instance["first_deid_caption"] = gen_cap.strip()

        dataset.append(new_instance)

    save_dir = os.path.join('./final_dataset/for_model_gen', args.model_name)
    os.makedirs(save_dir, exist_ok=True)

    # ensure_ascii=False emits raw non-ASCII characters, so the encoding must
    # be fixed explicitly rather than left to the platform default.
    with open(os.path.join(save_dir, file_name), 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent='\t')

def parse_gen(text):
    """Pull the "an image of ..." clause out of a generated response.

    The phrase is located case-insensitively in the stripped text. The clause
    runs up to the first period, or to the end of the text when there is
    none; it cannot span a line break.

    Returns:
        ``"An image of <clause>"`` when the phrase is found, otherwise the
        sentinel ``"<NO_PARSED>"``.
    """
    found = re.search(r"(?<=an image of )(.*?)(?=\.|$)", text.strip(), flags=re.I)
    if found is None:
        return "<NO_PARSED>"
    return f"An image of {found.group()}"


def _extract_speakers(prompt):
    """Return the distinct speaker labels found in the ``Dialogue:`` section of ``prompt``."""
    match = re.search(r'(?<=Dialogue:\n)([\s\S]*?)(?=\n\n)', prompt)
    if match is None:
        # Without a dialogue section there is nobody to de-identify.
        print("No match found")
        return []

    names = set()
    for utter in match.group(1).split('\n'):
        name = utter.split(': ')[0]
        # Blank lines yield an empty/whitespace label; substituting such a
        # label would corrupt the whole caption, so it is dropped.
        if name.strip():
            names.add(name)

    # Longest first (then alphabetical) gives a deterministic order in which a
    # label containing another one is replaced as a whole.
    return sorted(names, key=lambda name: (-len(name), name))


def rebuild_dataset_for_photochat(results, rounding_step=0):
    """Post-process PhotoChat generation results and write them to disk as JSON.

    For every instance a deep copy is extended with:
      * ``org_caption``: the response reduced by ``parse_gen`` to its
        "An image of ..." clause (or ``"<NO_PARSED>"``);
      * ``deid_caption`` and ``first_deid_caption``: that caption with the
        speaker labels taken from the ``Dialogue:`` part of
        ``task2_prompt_input`` replaced by ``deidentification``.

    Args:
        results: List of instance dicts read from a generation log.
        rounding_step: 0 reads ``task2_openai_resp``; any other value reads
            ``"{rounding_step}_task2_openai_resp"`` and prefixes the output
            file name with the step.

    NOTE: the output directory uses ``args.model_name`` from the module-level
    ``args`` that is only bound when the file runs as a script.
    """
    if rounding_step == 0:
        resp_key = "task2_openai_resp"
        file_name = 'result.json'
    else:
        resp_key = f"{rounding_step}_task2_openai_resp"
        file_name = f'{rounding_step}_result.json'

    dataset = []
    for instance in tqdm(results, total=len(results)):
        new_instance = copy.deepcopy(instance)

        speakers = _extract_speakers(instance["task2_prompt_input"])

        gen_cap = parse_gen(instance[resp_key])
        new_instance["org_caption"] = gen_cap
        gen_cap = deidentification(gen_cap, speakers)
        new_instance["deid_caption"] = gen_cap

        # parse_gen already cut the caption at the first sentence, so the
        # "first" caption is the de-identified caption itself.
        new_instance["first_deid_caption"] = gen_cap

        dataset.append(new_instance)

    save_dir = os.path.join('./final_dataset/for_photochat', args.model_name, 'wo_restriction')
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, file_name), 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent='\t')

def main(args):
    """Optionally rebuild the PhotoChat result file, then compute and save the report.

    Args:
        args: Namespace with ``model_name`` (str), ``do_rebuilding`` (bool)
            and ``rounding_step`` (int), as produced by ``parse_args``.

    Raises:
        ValueError: if ``args.model_name`` is ``None``.

    Side effects:
        Writes ``report.txt`` (tab-separated values rounded to 4 decimals)
        and ``len_dist.pkl`` (pickled ``LenDistCalculator`` result) under
        ``./final_dataset/report/<model_name>/wo_restriction``.
    """
    if args.model_name is None:
        # Otherwise the substring test below fails with an opaque TypeError.
        raise ValueError("--model_name is required")

    is_text_model = 'text' in args.model_name
    step = args.rounding_step

    # Result files of "text" models carry no rounding-step prefix.
    result_name = 'result.json' if is_text_model else f'{step}_result.json'

    if args.do_rebuilding:
        log_name = 'image-caption-generation-wo-restriction_generation.json'
        if not is_text_model:
            log_name = f'{step}_{log_name}'
        result_dir = os.path.join('./logs/task2_for_photochat/v6', args.model_name, 'test', '42', log_name)
        print(result_dir)
        results = load_json(result_dir)
        rebuild_dataset_for_photochat(results, step)
        # rebuild_dataset_for_photochat names its output by the rounding step
        # alone; read back exactly the file it has just written.
        result_name = 'result.json' if step == 0 else f'{step}_result.json'

    result_dir = os.path.join('./final_dataset/for_photochat', args.model_name, 'wo_restriction', result_name)
    results = load_json(result_dir)

    captions = [instance["first_deid_caption"].replace("An image of ", "") for instance in results]

    pos_stat = POSCalculator().calculate(captions)
    hyp_1 = WordNet1Calculator().calculate(captions)
    hyp_2 = WordNet2Calculator().calculate(captions)
    avg_noncontradiction = DialogueConsistencyCalculator().calculate(results)
    avg_obj_recall = ObjectConsistencyCalculator().calculate(results)
    avglen = AvgLenCalculator().calculate(captions)
    lendist = LenDistCalculator().calculate(captions)
    dist_1 = Dist1Calculator().calculate(captions)
    dist_2 = Dist2Calculator().calculate(captions)
    dist_3 = Dist3Calculator().calculate(captions)
    ent_1 = Ent1Calculator().calculate(captions)
    ent_2 = Ent2Calculator().calculate(captions)
    ent_3 = Ent3Calculator().calculate(captions)
    unigram = UniqueUniGramCalculator().calculate(captions)
    bigram = UniqueBiGramCalculator().calculate(captions)
    trigram = UniqueTriGramCalculator().calculate(captions)

    # Column order of report.txt; only the first element of each dist/ent
    # result is reported.
    report = [
        dist_1[0], dist_2[0], dist_3[0],
        ent_1[0], ent_2[0], ent_3[0],
        unigram, bigram, trigram,
        hyp_1, hyp_2, avglen, avg_obj_recall, avg_noncontradiction,
        pos_stat['NOUN'], pos_stat['PROPN'], pos_stat['VERB'], pos_stat['ADJ'], pos_stat['ADV']
    ]
    report = [str(round(ele, 4)) for ele in report]

    report_save_dir = f'./final_dataset/report/{args.model_name}/wo_restriction'
    os.makedirs(report_save_dir, exist_ok=True)

    with open(os.path.join(report_save_dir, 'report.txt'), 'w', encoding='utf-8') as f:
        f.write('\t'.join(report))

    with open(os.path.join(report_save_dir, 'len_dist.pkl'), 'wb') as f:
        pc.dump(lendist, f)

if __name__ == '__main__':
    # NOTE: `args` is bound at module scope here. rebuild_dataset() and
    # rebuild_dataset_for_photochat() read it as a global to obtain
    # `model_name`, so they only work when the file is run as a script.
    args = parse_args()
    main(args)