
from pathlib import Path
import json
import numpy as np
from generate_rubrics_web import fix_double_quotes
import argparse

# --- Command-line arguments -------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--real_seed", type=int, default=42)
parser.add_argument("--num_samples", type=int, default=200)
parser.add_argument("--num_points", type=int, default=10)
# Label embedded in the output rubric filename (read as args.prompt_version when
# the output path is built). NOTE(review): default "v1" is a placeholder — confirm.
parser.add_argument("--prompt_version", type=str, default="v1")
args = parser.parse_args()

NUM_SAMPLES = args.num_samples
NUM_POINTS = args.num_points
# NOTE(review): SERVICE is not referenced anywhere in the visible script.
SERVICE = "deepseek"
MODEL_NAME = 'deepseek-reasoner'

# --- Paths -----------------------------------------------------------------
DATA_ROOT_PATH = Path('data/sentiment_analysis/synthetic_data')
OUTPUT_PATH = Path('./data/sentiment_analysis/Lens/rubrics')
# exist_ok=True already tolerates an existing directory; no pre-check needed.
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
REAL_DATA_PATH = Path(f'./data/sentiment_analysis/real_data/balanced_real_seed={args.real_seed}.json')

# System-prompt templates for the similarity and difference rubric passes.
SIM_PROMPT_PATH = Path('./prompt_templates/sentiment_analysis/rubric_compilation/sim.txt')
DIFF_PROMPT_PATH = Path('./prompt_templates/sentiment_analysis/rubric_compilation/diff.txt')
# read_text() opens and closes the file, so no handle is left dangling.
sim_prompt = SIM_PROMPT_PATH.read_text()
diff_prompt = DIFF_PROMPT_PATH.read_text()

print("Synthetic data")

def sample_balanced_data(data, num_samples, label_key="sentiment", seed=42, num_classes=3):
    """
    Return a class-balanced random sample of ``data``.

    data: list of dicts whose ``label_key`` value is a class id (anything
        ``int()`` accepts) in ``range(num_classes)``.
    num_samples: total number of samples requested. Each class contributes
        ``num_samples // num_classes`` items, so any remainder is dropped
        (pass a multiple of ``num_classes`` for an exact count).
    label_key: key holding the class label.
    seed: seed for a private RNG, so identical inputs give identical output and
        the global ``random`` state is left untouched.
    num_classes: number of classes, labelled ``0 .. num_classes - 1``.

    Raises ValueError if any class has fewer than ``num_samples // num_classes``
    items.
    """
    from collections import defaultdict
    import random

    # A private generator seeded with `seed` yields exactly the draws that
    # random.seed(seed) would, without resetting the module-wide RNG.
    rng = random.Random(seed)

    # Group items by integer class label.
    groups = defaultdict(list)
    for item in data:
        groups[int(item[label_key])].append(item)

    samples_per_class = num_samples // num_classes

    balanced_samples = []
    for label in range(num_classes):
        group = groups[label]
        if len(group) < samples_per_class:
            raise ValueError(f"Not enough samples for sentiment {label}: requested {samples_per_class}, available {len(group)}")
        balanced_samples.extend(rng.sample(group, samples_per_class))

    # Interleave classes so the result is not ordered by label.
    rng.shuffle(balanced_samples)
    return balanced_samples

def preprocess_sentiment_analysis(x, real=False):
    """Return the text of a single example.

    Real examples store their text under 'text'; synthetic examples store it
    under 'headline'.
    """
    field = 'text' if real else 'headline'
    return x[field]

# dataset_name (JSON file stem) -> list of synthetic headline texts
syn_db_dataset_map = {}
# Sorted so datasets are loaded, processed and saved in a reproducible order;
# raw glob order depends on the filesystem.
data_paths = sorted(DATA_ROOT_PATH.glob('*.json'))
for path in data_paths:
    dataset_name = path.stem
    with path.open('rt') as f:
        data = json.load(f)
    data = sample_balanced_data(data, NUM_SAMPLES, label_key="sentiment")
    syn_db_dataset_map.setdefault(dataset_name, []).extend(
        preprocess_sentiment_analysis(x, real=False) for x in data
    )
    print(f"Loaded {len(syn_db_dataset_map[dataset_name])} examples for {dataset_name}")

print("-"*100)
print("Real data")

with REAL_DATA_PATH.open('rt') as f:
    data = json.load(f)
# real data is already balanced and sampled, so it is used as-is.
real_dataset = [preprocess_sentiment_analysis(x, real=True) for x in data]
print(f"Loaded {len(real_dataset)} examples")


import json
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser

# os.getenv is used below, before the script's later `import os`; without this
# import the model construction raises NameError.
import os

# NOTE(review): secrets.json must exist and is parsed, but none of its values
# are read before it is deleted below — the API key comes from the
# OPENAI_API_KEY environment variable instead. Confirm this is intended.
with open('./secrets.json') as f:
    secrets = json.load(f)

# NOTE(review): MODEL_NAME names a DeepSeek model, yet no base URL is passed
# here; confirm the endpoint is configured elsewhere (e.g. environment).
model = ChatOpenAI(
    openai_api_key=os.getenv('OPENAI_API_KEY'),
    model_name=MODEL_NAME,
    temperature=0.0
)

del secrets


# Similarity pass: system turn is the sim.txt template; the user turn carries
# the two sample sets as {A} and {B}.
similar_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", sim_prompt),
        ("user", "Samples from A:\n{A}\n\nSamples from B:\n{B}"),
    ]
)
similar_chain = similar_prompt | model

# Difference pass: additionally receives the similarity rubric as
# {similar_points}. NOTE: this rebinds `diff_prompt` from the raw template
# string to the built ChatPromptTemplate (the string is read on the RHS first).
diff_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", diff_prompt),
        ("user", "Similar characteristics between A and B:\n{similar_points}\n\nSamples from A:\n{A}\n\nSamples from B:\n{B}"),
    ]
)
diff_chain = diff_prompt | model

# NOTE: rebinds `parser`, previously the argparse parser (args already parsed).
parser = JsonOutputParser()

import os
from tqdm import auto as tqdm
import numpy as np


fout = f'rubric.{args.prompt_version}.sentiment_analysis.{MODEL_NAME.replace("/", "--")}_num_samples={NUM_SAMPLES}_num_points={NUM_POINTS}_real_seed={args.real_seed}.json'
fout = os.path.join(OUTPUT_PATH, fout)
# Checkpoint written when a run stops part-way. Only the final ".json"
# extension is swapped (str.replace would rewrite every ".json" in the path).
fout_partial = os.path.splitext(fout)[0] + '.partial.json'


def _load_rubrics(path):
    """Read a saved rubric file; return its (sims, synth-from-real, real-from-synth) dicts."""
    print(f"Loading existing rubrics from {path}")
    with open(path, 'rt') as f:
        saved = json.load(f)
    return saved['sims'], saved['diffs_synth_from_real'], saved['diffs_real_from_synth']


def _save_rubrics(path):
    """Write the current module-level rubric dicts to ``path`` as indented JSON."""
    with open(path, 'wt') as f:
        json.dump(dict(sims=sims, diffs_synth_from_real=diffs_synth_from_real, diffs_real_from_synth=diffs_real_from_synth), f, indent=2)


def _invoke_and_parse(chain, **inputs):
    """Invoke a prompt|model chain and parse its reply as JSON after quote repair."""
    content = chain.invoke(inputs).content
    return parser.parse(fix_double_quotes(content))


# Resume from an interrupted run's checkpoint first, then from a finished
# file; otherwise start empty. Each dict maps dataset_name -> parsed rubric.
if os.path.isfile(fout_partial):
    sims, diffs_synth_from_real, diffs_real_from_synth = _load_rubrics(fout_partial)
elif os.path.isfile(fout):
    sims, diffs_synth_from_real, diffs_real_from_synth = _load_rubrics(fout)
else:
    sims = {}
    diffs_synth_from_real = {}
    diffs_real_from_synth = {}

# The real samples are the same for every dataset; encode them once.
real_json = json.dumps(real_dataset)

try:
    # The iterable already includes finished datasets (they are skipped
    # below), so the bar needs no `initial` offset.
    for dataset_name in tqdm.tqdm(syn_db_dataset_map.keys()):
        print(f"Generating rubrics for {dataset_name}...")
        syn_json = json.dumps(syn_db_dataset_map[dataset_name])

        # Differences are conditioned on the similarities, so regenerating the
        # similarities forces both difference rubrics to be regenerated too.
        sims_regenerated = False
        if not sims.get(dataset_name):
            sims[dataset_name] = _invoke_and_parse(similar_chain, feedback='similar to', num=NUM_POINTS, A=real_json, B=syn_json)
            sims_regenerated = True
        else:
            print(f'skipping {dataset_name} sims')

        # NOTE(review): "\n".join assumes the parsed sims are a list of strings — confirm.
        if sims_regenerated or not diffs_synth_from_real.get(dataset_name):
            similar_points = "\n".join(sims[dataset_name])
            diffs_synth_from_real[dataset_name] = _invoke_and_parse(diff_chain, feedback='different from', num=NUM_POINTS, A=real_json, B=syn_json, similar_points=similar_points)
        else:
            print(f'skipping {dataset_name} synth from real')

        # Same prompt with A/B swapped: synthetic samples as A, real as B.
        if sims_regenerated or not diffs_real_from_synth.get(dataset_name):
            similar_points = "\n".join(sims[dataset_name])
            diffs_real_from_synth[dataset_name] = _invoke_and_parse(diff_chain, feedback='different from', num=NUM_POINTS, B=real_json, A=syn_json, similar_points=similar_points)
        else:
            print(f'skipping {dataset_name} real from synth')
except (Exception, KeyboardInterrupt) as e:
    # Checkpoint whatever finished (also on Ctrl-C) so a rerun can resume.
    print(f"Error: {e}")
    _save_rubrics(fout_partial)
    raise

_save_rubrics(fout)
# The complete file supersedes the checkpoint; removing it stops a later run
# from resuming from the stale partial and re-querying the model.
if os.path.isfile(fout_partial):
    os.remove(fout_partial)