from typing import List, Tuple, Dict
import os
import copy
import re
import json
import argparse
from collections import Counter, defaultdict

from rich.console import Console

from utils import fixed_seed, normalize_answer


# Module-wide rich console; main() uses it to log run progress.
console = Console()


class Parser:
    """Parse line-oriented model predictions of the form ``utterance | speaker [| rationale]``."""

    # One prediction per line: an optional "N. " enumeration prefix, an
    # optionally double-quoted utterance, " | ", the speaker, and an optional
    # " | [Rationale: ]<rationale>" tail.
    PATTERN = r'^(?:\d+\.\s+)?\"?(?P<utterance>.*?)\"?\s+\|\s+(?P<speaker>.*?)(?:\s+\|\s+(?:Rationale:\s+)?(?P<rationale>.*?))?$'

    @staticmethod
    def parse(pred: str) -> List[Dict]:
        """Extract every ``utterance | speaker [| rationale]`` line from *pred*.

        Declared as a static method so it works both as ``Parser.parse(x)``
        and when called on an instance.

        Args:
            pred: raw model output; may hold several predictions, one per line.

        Returns:
            A list (possibly empty) of dicts with keys ``'utterance'``,
            ``'speaker'`` and ``'rationale'`` (``None`` when the line has no
            rationale column), in the order the lines appear.
        """
        results = []
        for match in re.finditer(Parser.PATTERN, pred.strip(), re.MULTILINE):
            # Without a rationale column the lazy speaker group can absorb the
            # utterance's closing quote, so drop any double quotes from it.
            speaker = match.group('speaker').replace('"', '')
            results.append({
                'utterance': match.group('utterance'),
                'speaker': speaker,
                'rationale': match.group('rationale'),
            })
        return results

def parse_args():
    """Build the command-line interface and return the parsed arguments.

    ``--model_name`` is mandatory: main() runs substring checks on it and uses
    it as a path component, so leaving it unset would only fail later with a
    ``TypeError``. argparse now rejects the missing flag up front instead.
    """
    parser = argparse.ArgumentParser(description="evaluating the ability of image-sharing behavior")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--template_type", type=str, default="image-sharing-v2")
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--eval_result_dir", type=str, default="./logs")
    parser.add_argument("--file_version", type=str, default="v1.0")
    parser.add_argument("--report_save_dir", type=str, default=None)
    parser.add_argument("--datatype", type=str, default="test")
    parser.add_argument("--rounding_step", type=int, default=1)

    return parser.parse_args()

def load_json(datadir):
    """Load and return the JSON document stored at the file path *datadir*.

    The file is decoded as UTF-8 explicitly, which is the JSON interchange
    encoding. Without this, reading would depend on the platform locale, and
    files containing raw non-ASCII text could fail to load.
    """
    with open(datadir, 'r', encoding='utf-8') as f:
        return json.load(f)

def remove_speaker_from_parsed_utterance(utterance, spk1, spk2):
    """Strip a ``"<speaker>: "`` label from a parsed utterance.

    The two speakers are checked in order (*spk1*, then *spk2*). For the first
    one whose ``"<name>: "`` marker occurs in *utterance*, the text after the
    marker's last occurrence is returned. If neither marker occurs, the
    utterance is returned unchanged.

    Args:
        utterance: utterance text extracted by ``Parser.parse``.
        spk1: first speaker's name.
        spk2: second speaker's name.

    Returns:
        The utterance without the speaker label.
    """
    for speaker in (spk1, spk2):
        marker = f'{speaker}: '
        # Test for the full "<name>: " marker rather than the bare name: a
        # name mentioned in the text (e.g. "Bob: Hi Alice") must not block
        # removal of the other speaker's label. Empty names are skipped so
        # they never degrade into splitting on any ": ".
        if speaker and marker in utterance:
            return utterance.split(marker)[-1]
    return utterance

def parse_completion_gpt(result):
    """Parse a completion-model response and strip speaker labels from it.

    Args:
        result: one record from the generation file. It reads
            ``"task1_openai_resp"`` (the raw model output) and
            ``"prompt_input"`` (used to recover both speaker names).
            ``"image_share_turn_idx"`` is read only to print it when
            nothing could be parsed.

    Returns:
        A tuple ``(result, only_parsed)``:
            - *result* is the same dict, mutated to hold a
              ``'task1_parsed_result'`` list.
            - *only_parsed* is ``[{'pred': <raw response>}]`` followed by
              the parsed entries.

    Raises:
        ValueError: if ``prompt_input`` lacks the
            "The following is a dialogue between X and Y." header.
    """
    pred = result["task1_openai_resp"]
    parsed_result = Parser.parse(pred)

    # Nothing matched Parser.PATTERN: dump the record so it can be parsed by hand.
    if not parsed_result:
        print(result["prompt_input"])
        print(result["image_share_turn_idx"])
        print(pred)
        print()

    match = re.search(
        r'The following is a dialogue between (?P<spk1>.*?) and (?P<spk2>.*?)\..*',
        result["prompt_input"],
    )
    if match is None:
        raise ValueError(
            "could not find speaker names in prompt_input "
            "(expected 'The following is a dialogue between X and Y.')"
        )
    spk1, spk2 = match.group("spk1"), match.group("spk2")

    re_parsed_result = [
        {
            'utterance': remove_speaker_from_parsed_utterance(ele_p["utterance"], spk1, spk2),
            'speaker': ele_p["speaker"],
            'rationale': ele_p["rationale"],
        }
        for ele_p in parsed_result
    ]

    result['task1_parsed_result'] = re_parsed_result

    return result, [{'pred': pred}] + re_parsed_result

def parse_chat_gpt(result, rounding_step=1):
    """Parse the responses of a chat model over several steps and strip speaker labels.

    Args:
        result: one record from the generation file. For each step ``k`` in
            ``1..rounding_step`` it reads ``"<k>_task1_openai_resp"``.
            ``"prompt_input"`` supplies the speaker names, and
            ``"image_share_turn_idx"`` is printed only when a step yields
            nothing parseable.
        rounding_step: number of per-step responses to read; must be >= 1.

    Returns:
        A tuple ``(result, only_parsed)``:
            - *result* is the same dict, mutated to hold a
              ``'<rounding_step>_task1_parsed_result'`` list with the parsed
              entries of all steps, in step order.
            - *only_parsed* is ``[{'pred': <last step's raw response>}]``
              followed by those entries.

    Raises:
        ValueError: if ``rounding_step < 1`` (previously this crashed with
            ``UnboundLocalError``), or if ``prompt_input`` lacks the
            "The following is a dialogue between X and Y." header.
    """
    if rounding_step < 1:
        raise ValueError(f"rounding_step must be >= 1, got {rounding_step}")

    # The speaker names come from the prompt, which is the same for every
    # step, so resolve them once instead of on each iteration.
    match = re.search(
        r'The following is a dialogue between (?P<spk1>.*?) and (?P<spk2>.*?)\..*',
        result["prompt_input"],
    )
    if match is None:
        raise ValueError(
            "could not find speaker names in prompt_input "
            "(expected 'The following is a dialogue between X and Y.')"
        )
    spk1, spk2 = match.group("spk1"), match.group("spk2")

    re_parsed_result = []
    for step in range(1, rounding_step + 1):
        pred = result[f"{step}_task1_openai_resp"]
        parsed_result = Parser.parse(pred)

        # Nothing matched Parser.PATTERN: dump the record so it can be parsed by hand.
        if not parsed_result:
            print(result["prompt_input"])
            print(result["image_share_turn_idx"])
            print(pred)
            print()

        re_parsed_result.extend(
            {
                'utterance': remove_speaker_from_parsed_utterance(ele_p["utterance"], spk1, spk2),
                'speaker': ele_p["speaker"],
                'rationale': ele_p["rationale"],
            }
            for ele_p in parsed_result
        )

    result[f'{rounding_step}_task1_parsed_result'] = re_parsed_result

    # NOTE(review): entries from every step are merged, but only the final
    # step's raw response is kept under 'pred' — confirm this is intended.
    return result, [{'pred': pred}] + re_parsed_result
    
def main():
    """Parse every record of a generation file and write the parsed outputs.

    The input is read from
    ``<eval_result_dir>/<file_version>/<model_name>/<datatype>/<seed>/[<rounding_step>_]<template_type>_generation.json``.
    The ``<rounding_step>_`` prefix is used only for non-"text" (chat) models.

    Two JSON files are written under
    ``<report_save_dir or ./parsed_results>/<file_version>/<model_name>/``:
        - ``[<rounding_step>_]<template_type>_only_parsed_result.json`` holds
          the raw pred plus the parsed entries for each record.
        - ``[<rounding_step>_]<template_type>_parsed_result.json`` holds the
          full records, augmented with their parsed results.
    """
    args = parse_args()
    if not args.model_name:
        # The model name selects the parser and is part of every path below.
        raise SystemExit("--model_name is required")
    fixed_seed(args.seed)

    console.log(f"[bold red]{args.model_name} start!")

    # Models whose name contains "text" are parsed with parse_completion_gpt
    # (one response per record). All others use parse_chat_gpt with
    # `rounding_step` responses, and their files carry that count as a prefix.
    is_completion = "text" in args.model_name
    file_prefix = "" if is_completion else f"{args.rounding_step}_"

    eval_result_dir = os.path.join(
        args.eval_result_dir,
        args.file_version,
        args.model_name,
        args.datatype,
        str(args.seed),
        f"{file_prefix}{args.template_type}_generation.json",
    )

    results = load_json(eval_result_dir)

    total_results = []
    only_parsed_results = []
    for result in results:
        if is_completion:
            result, only_parsed_result = parse_completion_gpt(result)
        else:
            result, only_parsed_result = parse_chat_gpt(result, rounding_step=args.rounding_step)

        only_parsed_results.append(only_parsed_result)
        total_results.append(result)

    # --report_save_dir was previously parsed but ignored. It now sets the
    # output root, which defaults to the historical ./parsed_results.
    report_root = args.report_save_dir or './parsed_results'
    report_save_dir = os.path.join(report_root, args.file_version, args.model_name)
    os.makedirs(report_save_dir, exist_ok=True)

    outputs = (
        (f'{file_prefix}{args.template_type}_only_parsed_result.json', only_parsed_results),
        (f'{file_prefix}{args.template_type}_parsed_result.json', total_results),
    )
    for filename, payload in outputs:
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be written as UTF-8 regardless of the platform's default encoding.
        with open(os.path.join(report_save_dir, filename), 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent="\t")

    
# Run the parsing pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()