from pathlib import Path
import json
import pandas as pd
import numpy as np
from typing import Dict, Any, List
import re
import argparse


def load_axtrees():
    """Read the real and synthetic accessibility-tree JSON files.

    Both files live under ``./data/web_agent/sampled/axtree`` (relative to
    the current working directory): ``real/axtree.json`` and
    ``synthetic/axtree.json``.

    Returns:
        A ``(real_axtree_map, syn_axtree_map)`` tuple holding the decoded
        JSON content of the two files, in that order.
    """
    axtree_root = Path('./data/web_agent/sampled/axtree')

    # The real tree is read first, then the synthetic one.
    real_axtree_map = json.loads((axtree_root / 'real' / 'axtree.json').read_text())
    syn_axtree_map = json.loads((axtree_root / 'synthetic' / 'axtree.json').read_text())

    return real_axtree_map, syn_axtree_map


def load_datasets(args):
    """Load the synthetic and real task datasets from disk.

    Synthetic data: every ``*.json`` file in
    ``./data/web_agent/original/synthetic/`` is loaded. The website is taken
    from the ``site=<website>_num_tasks`` part of the file name and the
    dataset name is the file stem.

    Real data: every ``*.json`` file in ``./data/web_agent/sampled/real``
    whose stem contains ``seed=<args.seed>`` is loaded. The website is the
    lower-cased stem up to ``_seed`` and the JSON object's values become the
    list of examples.

    Args:
        args: Namespace providing ``num_websites``, ``total_num_examples``
            (compared against the number of synthetic dataset files) and
            ``seed``. An optional boolean ``cross_domain`` (treated as
            ``False`` when absent) makes every website receive the synthetic
            datasets of all websites.

    Returns:
        ``(syn_dataset_map, real_dataset_map)`` where ``syn_dataset_map``
        maps website -> dataset name -> data and ``real_dataset_map`` maps
        website -> list of examples.

    Raises:
        ValueError: If a synthetic file name has no ``site=..._num_tasks`` part.
        AssertionError: If the website / dataset counts do not match ``args``
            or the synthetic and real website sets differ.
    """
    # path args
    # we use the full synthetic dataset and sampled real dataset
    DATA_ROOT_PATH = Path('./data/web_agent/original/synthetic/')
    REAL_DATA_PATH = Path('./data/web_agent/sampled/real')
    # The flag is optional so callers with an older namespace keep working.
    cross_domain = getattr(args, 'cross_domain', False)

    print("Processing Synthetic data")
    syn_dataset_map = {}  # website -> dataset_name -> data
    # Sorted so the resulting dict order does not depend on filesystem order.
    for path in sorted(DATA_ROOT_PATH.glob('*.json')):
        site_match = re.search(r'site=(.+?)_num_tasks', path.name)
        if site_match is None:
            raise ValueError(
                f"Cannot extract website from synthetic file name: {path.name}"
            )
        website = site_match.group(1)
        with path.open('rt') as f:
            data = json.load(f)
        syn_dataset_map.setdefault(website, {})[path.stem] = data
        print(f"Loaded {len(data)} examples for {website} -> {path.stem}")
    assert len(syn_dataset_map) == args.num_websites, f"Expected {args.num_websites} websites, got {len(syn_dataset_map)}"
    # NOTE: despite the argument name, this counts dataset files, not examples.
    total_num_examples = sum(len(dataset) for dataset in syn_dataset_map.values())
    assert total_num_examples == args.total_num_examples, f"Expected {args.total_num_examples} datasets, got {total_num_examples}"

    if cross_domain:
        # Every website gets the union of all websites' datasets.
        all_datasets = {}
        for datasets in syn_dataset_map.values():
            all_datasets.update(datasets)
        syn_dataset_map = {website: dict(all_datasets) for website in syn_dataset_map}
        assert len(syn_dataset_map) == args.num_websites, f"Expected {args.num_websites} websites, got {len(syn_dataset_map)}"
        # After the merge each website holds ALL datasets, so the check is
        # per website; summing across websites would give
        # num_websites * total and could never equal total_num_examples.
        for datasets in syn_dataset_map.values():
            assert len(datasets) == args.total_num_examples, f"Expected {args.total_num_examples} datasets, got {len(datasets)}"

    print("-"*100)
    print("Processing Real data")

    real_dataset_map = {}  # website -> data
    for path in sorted(REAL_DATA_PATH.glob('*.json')):
        website = path.stem.lower().split('_seed')[0]
        seed_match = re.search(r'seed=(\d+)', path.stem)
        # Files without a seed, or sampled with another seed, are ignored.
        if seed_match is None or int(seed_match.group(1)) != args.seed:
            continue
        seed = int(seed_match.group(1))
        print(f"Processing {website} with seed {seed}")
        with path.open('rt') as f:
            real_dataset = json.load(f)
        real_dataset_map[website] = list(real_dataset.values())
        print(f"Loaded {len(real_dataset_map[website])} examples for {website}")

    assert len(real_dataset_map) == args.num_websites, f"Expected {args.num_websites} websites, got {len(real_dataset_map)}"

    # assert all websites from synthetic and the real data are the same set
    assert set(syn_dataset_map) == set(real_dataset_map), \
        f"Expected the same websites, got {set(syn_dataset_map)} and {set(real_dataset_map)}"

    return syn_dataset_map, real_dataset_map


import json
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser

def load_langchain_model(args):
    """Build the LangChain chat model, the two prompt chains and the parser.

    The user prompts are read from ``sim.txt`` and ``diff.txt`` in
    ``./prompt_templates/web_agent/rubric_compilation`` (relative to the
    current working directory). The API key is taken from the
    ``OPENAI_API_KEY`` environment variable.

    Args:
        args: Namespace providing ``model_name``.

    Returns:
        ``(similar_chain, diff_chain, parser)``: the similarity chain, the
        difference chain (each a prompt piped into the model) and a
        ``JsonOutputParser``. The parser is returned separately and is not
        part of either chain.
    """
    model = ChatOpenAI(
        openai_api_key=os.getenv('OPENAI_API_KEY'),
        model_name=args.model_name,
        temperature=0.0
    )

    system_message = "You are a world class data analyst on analyzing user intents in web navigation tasks."
    template_dir = Path('./prompt_templates/web_agent/rubric_compilation')

    def build_prompt(template_name):
        # read_text closes the file, unlike a bare open(...).read().
        user_template = (template_dir / template_name).read_text()
        return ChatPromptTemplate.from_messages([
            ("system", system_message),
            ("user", user_template)
        ])

    similar_prompt = build_prompt('sim.txt')
    diff_prompt = build_prompt('diff.txt')

    parser = JsonOutputParser()

    similar_chain = similar_prompt | model
    diff_chain = diff_prompt | model

    return similar_chain, diff_chain, parser


def fix_double_quotes(text_output: str) -> str:
    """Hook for repairing double quotes inside quoted spans of *text_output*.

    Every span of the form ``"..."`` that is followed by a comma is located
    and substituted back into the text. Escaping of the span's content is
    currently switched off, so each substitution puts back the same text and
    the input is returned unchanged.
    """
    quoted_before_comma = re.compile(r'"(.*?)"\,')
    for inner in quoted_before_comma.findall(text_output):
        found = '"' + inner + '"'
        # No escaping is applied to the content, so the replacement is
        # identical to what was found.
        replacement = '"' + inner + '"'
        text_output = text_output.replace(found, replacement)

    return text_output

def clean_text_output(text_output: str) -> str:
    """Strip Markdown code fences and a leading ``json`` tag from model output.

    * If the text contains at least two triple-backtick fences, the content
      between the first and second fence is used.
    * If it contains exactly one fence (e.g. the closing fence is missing),
      the text after the fence is used, or the text before it when nothing
      follows the fence.
    * Otherwise the whole text is used.

    The selected text is stripped, a case-insensitive leading ``json`` tag is
    removed, and the result is passed through ``fix_double_quotes``.
    """
    if "```" in text_output:
        parts = text_output.split("```")
        if len(parts) >= 3:
            # Content between the first and second fence.
            text_output = parts[1]
        else:
            # Exactly one fence: keep whichever side actually has content so
            # the stray backticks never reach the JSON parser.
            before, after = parts
            text_output = after if after.strip() else before

    text_output = text_output.strip()

    # Remove the language identifier if the text starts with 'json'.
    if text_output.lower().startswith('json'):
        text_output = text_output[len("json"):].strip()

    return fix_double_quotes(text_output)


import os
from tqdm import auto as tqdm
import numpy as np

def _load_rubric_file(path):
    """Read a rubric JSON file and return its three result maps."""
    with open(path, 'rt') as f:
        rubrics = json.load(f)
    return rubrics['sims'], rubrics['diffs_synth_from_real'], rubrics['diffs_real_from_synth']


def _dump_rubric_file(path, sims, diffs_synth_from_real, diffs_real_from_synth):
    """Write the three result maps to *path* as indented JSON."""
    with open(path, 'wt') as f:
        json.dump(dict(sims=sims, diffs_synth_from_real=diffs_synth_from_real, diffs_real_from_synth=diffs_real_from_synth), f, indent=2)


def _invoke_and_parse(chain, parser, payload, website, dataset_name):
    """Invoke *chain*, clean its text output and parse it.

    Falls back to the cleaned raw text when parsing fails, so that the
    model response is not lost.
    """
    cleaned_output = clean_text_output(chain.invoke(payload).content)
    try:
        return parser.parse(cleaned_output)
    except Exception as e:
        print(f"Error parsing {website} {dataset_name}: {e}")
        print(f"Cleaned output: {cleaned_output}")
        return cleaned_output


def generate_rubrics(
    syn_dataset_map,
    real_dataset_map,
    real_axtree_map,
    syn_axtree_map,
    similar_chain,
    diff_chain,
    parser,
    output_path,
    args
    ):
    """Generate similarity and difference rubrics for every synthetic dataset.

    For each website and each of its synthetic datasets, up to three model
    calls are made: similarities between the real and synthetic data, how the
    synthetic data differs from the real data, and how the real data differs
    from the synthetic data. Results are written as JSON with the keys
    ``sims``, ``diffs_synth_from_real`` and ``diffs_real_from_synth``, each
    mapping website -> dataset name -> parsed output (or the cleaned raw text
    when parsing failed).

    Existing results are reused: the final output file is loaded if present,
    otherwise a ``.partial.json`` sibling; entries whose three values are all
    non-empty are skipped. If the loop is interrupted by an exception or
    Ctrl-C, the progress so far is written to the ``.partial.json`` file and
    the exception is re-raised.

    Args:
        syn_dataset_map: website -> dataset name -> synthetic data.
        real_dataset_map: website -> real data.
        real_axtree_map: website -> accessibility tree of the real site.
        syn_axtree_map: dataset name -> accessibility tree.
        similar_chain: Object whose ``invoke(dict)`` result has ``.content``.
        diff_chain: Same protocol as ``similar_chain``.
        parser: Object with ``parse(str)``.
        output_path: Directory for the rubric file; created if missing.
        args: Namespace providing ``model_name``, ``num_points``,
            ``prompt_version`` and ``seed``.
    """
    fname = f'rubric.webvoyager.{args.model_name.replace("/", "--")}.axtree_points={args.num_points}_{args.prompt_version}_seed={args.seed}.json'
    os.makedirs(output_path, exist_ok=True)
    fout = os.path.join(output_path, fname)
    # Swap only the trailing extension; a global str.replace would also
    # rewrite any other '.json' occurring earlier in the path.
    partial_fout = fout[:-len('.json')] + '.partial.json'

    sims = {}  # website -> dataset_name -> sims
    diffs_synth_from_real = {}  # website -> dataset_name -> diffs
    diffs_real_from_synth = {}  # website -> dataset_name -> diffs

    if os.path.isfile(fout):
        print(f"Loading existing rubrics from {fout}")
        sims, diffs_synth_from_real, diffs_real_from_synth = _load_rubric_file(fout)
    elif os.path.isfile(partial_fout):
        print(f"Loading partial rubrics from {partial_fout}")
        sims, diffs_synth_from_real, diffs_real_from_synth = _load_rubric_file(partial_fout)
    else:
        print(f"No rubrics found, starting from scratch")

    try:
        # No `initial=` offset on the bars: already-finished entries are still
        # iterated (and skipped), so an offset would overshoot the total.
        for website in tqdm.tqdm(syn_dataset_map.keys()):
            site_sims = sims.setdefault(website, {})
            site_synth_from_real = diffs_synth_from_real.setdefault(website, {})
            site_real_from_synth = diffs_real_from_synth.setdefault(website, {})
            print(f"Generating rubrics for {website}...")
            for dataset_name in tqdm.tqdm(syn_dataset_map[website].keys()):
                print(f"Generating rubrics for {dataset_name}...")
                # Skip only when all three results exist and are non-empty.
                if site_sims.get(dataset_name) and site_synth_from_real.get(dataset_name) and site_real_from_synth.get(dataset_name):
                    print(f'skipping {dataset_name} because it already exists\n')
                    continue

                # "" marks a result that still has to be computed.
                site_sims.setdefault(dataset_name, "")
                site_synth_from_real.setdefault(dataset_name, "")
                site_real_from_synth.setdefault(dataset_name, "")

                # Compute similarities
                if site_sims[dataset_name] == "":
                    site_sims[dataset_name] = _invoke_and_parse(
                        similar_chain,
                        parser,
                        dict(
                            feedback='similar to',
                            num=args.num_points,
                            A=json.dumps(real_dataset_map[website]),
                            B=json.dumps(syn_dataset_map[website][dataset_name]),
                            A_tree=json.dumps(real_axtree_map[website]),
                            B_tree=json.dumps(syn_axtree_map[dataset_name])
                        ),
                        website,
                        dataset_name,
                    )

                # Compute differences using the similarities
                current_sims = site_sims[dataset_name]
                if isinstance(current_sims, str):
                    similar_points = current_sims
                elif isinstance(current_sims, list):
                    # Non-string items (e.g. objects) are serialized so the
                    # join cannot fail on them.
                    similar_points = "\n".join(
                        point if isinstance(point, str) else json.dumps(point)
                        for point in current_sims
                    )
                else:
                    print(f"Skipping differences for {website} {dataset_name}: unsupported similarity type {type(current_sims).__name__}")
                    continue

                if site_synth_from_real[dataset_name] == "":
                    # Compute synth from real differences (A = real, B = synthetic)
                    site_synth_from_real[dataset_name] = _invoke_and_parse(
                        diff_chain,
                        parser,
                        dict(
                            feedback='different from',
                            num=args.num_points,
                            A=json.dumps(real_dataset_map[website]),
                            B=json.dumps(syn_dataset_map[website][dataset_name]),
                            similar_points=similar_points,
                            A_tree=json.dumps(real_axtree_map[website]),
                            B_tree=json.dumps(syn_axtree_map[dataset_name])
                        ),
                        website,
                        dataset_name,
                    )

                if site_real_from_synth[dataset_name] == "":
                    # Compute real from synth differences (A = synthetic, B = real).
                    # The trees are swapped together with the datasets so that
                    # A_tree always belongs to A and B_tree to B.
                    # NOTE(review): assumes the diff template pairs A with
                    # A_tree and B with B_tree — confirm against diff.txt.
                    site_real_from_synth[dataset_name] = _invoke_and_parse(
                        diff_chain,
                        parser,
                        dict(
                            feedback='different from',
                            num=args.num_points,
                            A=json.dumps(syn_dataset_map[website][dataset_name]),
                            B=json.dumps(real_dataset_map[website]),
                            similar_points=similar_points,
                            A_tree=json.dumps(syn_axtree_map[dataset_name]),
                            B_tree=json.dumps(real_axtree_map[website])
                        ),
                        website,
                        dataset_name,
                    )
    except (Exception, KeyboardInterrupt) as e:
        print(f"Error: {e}")
        # save partial results (also on Ctrl-C, so finished model calls are kept)
        _dump_rubric_file(partial_fout, sims, diffs_synth_from_real, diffs_real_from_synth)
        raise
    _dump_rubric_file(fout, sims, diffs_synth_from_real, diffs_real_from_synth)

def main(args):
    """Run the rubric pipeline end to end.

    Builds the model chains, loads the datasets and accessibility trees (in
    that order), then generates rubrics into ``./data/web_agent/Lens/rubrics``.
    """
    sim_chain, difference_chain, json_parser = load_langchain_model(args)
    synthetic_data, real_data = load_datasets(args)
    real_trees, synthetic_trees = load_axtrees()

    generate_rubrics(
        syn_dataset_map=synthetic_data,
        real_dataset_map=real_data,
        real_axtree_map=real_trees,
        syn_axtree_map=synthetic_trees,
        similar_chain=sim_chain,
        diff_chain=difference_chain,
        parser=json_parser,
        output_path=Path(f'./data/web_agent/Lens/rubrics'),
        args=args,
    )


if __name__ == "__main__":
    # Command-line entry point: parse the options and run the pipeline.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="deepseek-reasoner")
    parser.add_argument("--num_websites", type=int, default=13)
    parser.add_argument("--total_num_examples", type=int, default=65)
    parser.add_argument("--num_points", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    # The two options below are read by the pipeline (args.cross_domain when
    # loading datasets, args.prompt_version when naming the output file);
    # without them the script fails with AttributeError.
    parser.add_argument(
        "--cross_domain",
        action="store_true",
        help="Give every website the synthetic datasets of all websites.",
    )
    # NOTE(review): "v1" is a placeholder default; it only affects the
    # output file name — confirm the intended value.
    parser.add_argument(
        "--prompt_version",
        type=str,
        default="v1",
        help="Prompt version tag embedded in the rubric file name.",
    )
    args = parser.parse_args()
    main(args)