
from pathlib import Path
import json
import pandas as pd
import numpy as np
from generate_rubrics_web import fix_double_quotes
import argparse

# Command-line options: every flag takes an integer value.
parser = argparse.ArgumentParser()
for _flag, _default in (
    ("--real_seed", 42),
    ("--num_synthetic_samples", 30),
    ("--num_points", 10),
):
    parser.add_argument(_flag, type=int, default=_default)
args = parser.parse_args()

# Run configuration shared by the rest of the script.
SERVICE = "deepseek"
MODEL_NAME = 'deepseek-reasoner'
NUM_POINTS = args.num_points  # forwarded to the prompts as `num`
OUTPUT_PATH = './data/text2sql/Lens/rubrics'  # directory receiving the result JSON


print("Synthetic data")

def preprocess_text2sql(x):
    """Return the question stored on one text2sql record.

    Args:
        x: A record supporting ``x['question']`` lookup.

    Returns:
        The value stored under the ``'question'`` key, unchanged.

    Raises:
        KeyError: If ``x`` is a dict without a ``'question'`` key.
    """
    question = x['question']
    return question
# Draw a fixed-size random sample of questions from every synthetic dataset
# found under ./data/text2sql/data/<db_id>/<dataset_name>.json.
syn_db_dataset_map = {} # db_id -> dataset_name -> dataset
# Sorted so that the seeded draws below are consumed in a stable order:
# glob() yields entries in whatever order the filesystem lists them, which
# would otherwise make the "fixed seed" samples differ between machines.
data_root_paths = sorted(Path('./data/text2sql/data').glob('*_*'))
np.random.seed(42) # just for synthetic data
for db_path in data_root_paths:
    # for each db_id
    db_id = db_path.stem
    syn_db_dataset_map[db_id] = {}
    for dataset_path in sorted(db_path.glob('*.json')):
        dataset_name = dataset_path.stem
        with dataset_path.open('rt') as f:
            data = json.load(f)
        questions = [preprocess_text2sql(record) for record in data]
        # replace=False draws distinct questions; numpy raises ValueError when
        # the file holds fewer than --num_synthetic_samples records.
        sampled = np.random.choice(questions, args.num_synthetic_samples, replace=False).tolist()
        syn_db_dataset_map[db_id].setdefault(dataset_name, []).extend(sampled)
        print(f"Loaded {len(syn_db_dataset_map[db_id][dataset_name])} examples for {dataset_name} in {db_id}")

print("-"*100)
print("Real data")

# Questions from the real files, keyed by the database id recovered from
# each file name (the 'dev_' and '_seed=<seed>' parts are stripped).
real_dataset_map = {} # db_id -> dataset
seed_suffix = f'_seed={args.real_seed}'
for real_path in Path('./data/text2sql/data/real').glob(f'*seed={args.real_seed}.json'):
    db_id = real_path.stem.replace('dev_', '').replace(seed_suffix, '')
    with real_path.open('rt') as fh:
        records = json.load(fh)
    # data is already sampled, so every record is kept
    real_dataset_map[db_id] = [preprocess_text2sql(record) for record in records]
    print(f"Loaded {len(real_dataset_map[db_id])} examples for {db_id}")
assert len(real_dataset_map) == 3, f"Expected 3 real datasets, got {len(real_dataset_map)}"

import json
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser

# `os` must be imported before the ChatOpenAI call below: the script's own
# `import os` only appears further down, so `os.getenv` raised NameError here.
import os

with open('./secrets.json') as f:
    secrets = json.load(f)
# NOTE(review): `secrets` is loaded but never read before being deleted below;
# the API key comes from the OPENAI_API_KEY environment variable instead.
# Confirm whether the key (and an endpoint for SERVICE) should come from here.

# Chat model used for every rubric request; temperature 0.0 is passed through.
# NOTE(review): MODEL_NAME is a DeepSeek model but no base URL is configured
# on this client - confirm the endpoint is supplied through the environment.
model = ChatOpenAI(
    openai_api_key=os.getenv('OPENAI_API_KEY'),
    model_name=MODEL_NAME,
    temperature=0.0
)

del secrets

# Both prompts share one system message and differ only in the user template
# read from disk. read_text() opens, reads and closes each file, so no file
# handle is left open (the previous bare open(...).read() never closed them).
_PROMPT_DIR = Path('./prompt_templates/text2sql/rubric_compilation')
_SYSTEM_MESSAGE = "You are a world class data analyst on database queries in natural language."

similar_prompt = ChatPromptTemplate.from_messages([
    ("system", _SYSTEM_MESSAGE),
    ("user", (_PROMPT_DIR / 'sim.txt').read_text()),
])

diff_prompt = ChatPromptTemplate.from_messages([
    ("system", _SYSTEM_MESSAGE),
    ("user", (_PROMPT_DIR / 'diff.txt').read_text()),
])

# NOTE: this rebinds the module-level name `parser`, which held the argparse
# parser until now; `args` has already been parsed at this point.
parser = JsonOutputParser()

# The chains return the raw model message; parsing is done separately.
similar_chain = similar_prompt | model
diff_chain = diff_prompt | model


# sanity check: the synthetic and real maps must cover the same database ids
syn_db_ids = syn_db_dataset_map.keys()
real_db_ids = real_dataset_map.keys()
assert syn_db_ids == real_db_ids, f"syn_db_dataset_map.keys() != real_dataset_map.keys(): {syn_db_ids} != {real_db_ids}"


import os
from tqdm import auto as tqdm
import numpy as np


# Final output file; its name records the model and the sampling settings.
fout = os.path.join(
    OUTPUT_PATH,
    f'rubric.text2sql.{MODEL_NAME.replace("/", "--")}_num_points={NUM_POINTS}_num_real_samples=30_num_synthetic_samples={args.num_synthetic_samples}_seed={args.real_seed}.json',
)
# Sibling file holding the results of an earlier, interrupted run (if any).
partial_fout = fout.replace('.json', '.partial.json')

if not os.path.isfile(partial_fout):
    # Fresh start: make sure the output folder exists and begin empty.
    os.makedirs(OUTPUT_PATH, exist_ok=True)
    sims = {} # db_id -> dataset_name -> diffs
    diffs_synth_from_real = {} # db_id -> dataset_name -> diffs
    diffs_real_from_synth = {} # db_id -> dataset_name -> diffs
else:
    # Resume: reuse whatever the earlier run managed to save.
    print(f"Found partial results in {partial_fout}\nloading partial results...")
    with open(partial_fout, 'rt') as f:
        saved = json.load(f)
    sims = saved['sims']
    diffs_synth_from_real = saved['diffs_synth_from_real']
    diffs_real_from_synth = saved['diffs_real_from_synth']

db_ids = list(syn_db_dataset_map)

def _render_similar_points(points):
    """Turn a stored `sims` entry into the text passed as `similar_points`.

    A `sims` entry is either the parsed JSON output or, when parsing failed,
    the raw cleaned model output as a string. A string is returned unchanged:
    joining it would iterate its characters and put every character on its
    own line. Any other iterable is joined with newlines, with each item
    converted via ``str`` so that non-string items do not raise.
    """
    if isinstance(points, str):
        return points
    return "\n".join(map(str, points))


def _query_rubric(chain, dataset_name, label, **prompt_vars):
    """Invoke `chain` with `prompt_vars` and parse the reply as JSON.

    Args:
        chain: Chain to invoke; the reply's ``.content`` is used.
        dataset_name: Dataset name, used only in the error message.
        label: Short description of the request, used only in the error message.
        **prompt_vars: Variables handed to the chain's prompt.

    Returns:
        The parsed output, or the cleaned raw output string if parsing fails.
    """
    output = chain.invoke(prompt_vars).content
    cleaned_output = fix_double_quotes(output)
    try:
        return parser.parse(cleaned_output)
    except Exception as e:
        # Best effort: keep the raw text so the response is not lost.
        print(f"Error parsing {dataset_name} {label}: {e}")
        return cleaned_output


# For every (db_id, synthetic dataset) pair, request three rubrics: points
# similar to the real data, then differences in each direction. Entries that
# are already present (loaded from a partial file) are skipped.
try:
    for db_id in tqdm.tqdm(db_ids):
        db_sims = sims.setdefault(db_id, {})
        db_synth_from_real = diffs_synth_from_real.setdefault(db_id, {})
        db_real_from_synth = diffs_real_from_synth.setdefault(db_id, {})
        # Same real questions for every synthetic dataset of this database.
        real_json = json.dumps(real_dataset_map[db_id])
        for dataset_name, dataset in tqdm.tqdm(list(syn_db_dataset_map[db_id].items())):
            synth_json = json.dumps(dataset)
            if dataset_name not in db_sims:
                db_sims[dataset_name] = _query_rubric(
                    similar_chain, dataset_name, 'sims',
                    feedback='similar to', num=NUM_POINTS, A=real_json, B=synth_json,
                )
            else:
                print(f'skipping {dataset_name} sims')
            if dataset_name not in db_synth_from_real:
                # A is the real data, B the synthetic data.
                db_synth_from_real[dataset_name] = _query_rubric(
                    diff_chain, dataset_name, 'synth from real',
                    feedback='different from', num=NUM_POINTS, A=real_json, B=synth_json,
                    similar_points=_render_similar_points(db_sims[dataset_name]),
                )
            else:
                print(f'skipping {dataset_name} synth from real')
            if dataset_name not in db_real_from_synth:
                # Roles swapped: A is the synthetic data, B the real data.
                db_real_from_synth[dataset_name] = _query_rubric(
                    diff_chain, dataset_name, 'real from synth',
                    feedback='different from', num=NUM_POINTS, B=real_json, A=synth_json,
                    similar_points=_render_similar_points(db_sims[dataset_name]),
                )
            else:
                print(f'skipping {dataset_name} real from synth')
except (Exception, KeyboardInterrupt) as e:
    print(f"At {db_id}, {dataset_name}: Error: {e}")
    # save partial results so that the next run can resume from them
    with open(fout.replace('.json', '.partial.json'), 'wt') as f:
        json.dump(dict(sims=sims, diffs_synth_from_real=diffs_synth_from_real, diffs_real_from_synth=diffs_real_from_synth), f, indent=2)
    # Bare raise re-raises the active exception with its original traceback.
    raise
with open(fout, 'wt') as f:
    json.dump(dict(sims=sims, diffs_synth_from_real=diffs_synth_from_real, diffs_real_from_synth=diffs_real_from_synth), f, indent=2)