import jsonlines
import numpy as np
from transformers import AutoTokenizer

# Exclusive upper bound on the number of correct sampled answers a question
# may have and still be kept.
# NOTE(review): presumably 8 answers are sampled per question, so this drops
# questions that every sample answered correctly -- confirm against the
# sampling script.
VALID_COUNT_UPPER_BOUND = 8


def load_jsonl(path):
    """Read every record of the JSON Lines file at ``path`` into a list."""
    with jsonlines.open(path) as reader:
        return list(reader)


def normalize_answer(answer):
    """Return ``answer`` with canonical whitespace and comma spacing.

    Runs of whitespace collapse to a single space (leading and trailing
    whitespace is dropped), then every comma-separated piece is stripped and
    the pieces are re-joined with ``", "``.
    """
    collapsed = " ".join(answer.split())
    return ", ".join(piece.strip() for piece in collapsed.split(","))


def count_valid_samples(original_answer, sampled_answers):
    """Count the sampled answers whose normalized text equals the reference.

    Only the sampled side is normalized; ``original_answer`` is compared
    exactly as given.

    Args:
        original_answer: Reference answer string.
        sampled_answers: Iterable of dicts, each with an ``"answer"`` string.

    Returns:
        Number of entries whose normalized ``"answer"`` equals
        ``original_answer``.
    """
    return sum(
        1
        for sampled in sampled_answers
        if normalize_answer(sampled["answer"]) == original_answer
    )


def main():
    """Filter the SFT samples by sampled-answer accuracy and print statistics.

    Loads ``train_sft.jsonl`` and ``dev_sft.jsonl``, keeps the samples whose
    number of correct sampled answers is at least 1 and below
    ``VALID_COUNT_UPPER_BOUND``, routes the kept samples to the train or dev
    list according to their ``"label"`` field, and prints the resulting
    counts, the observed attribute values, and the longest token lengths.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct-1M")

    samples_train_processed = load_jsonl("train_sft.jsonl")
    samples_dev_processed = load_jsonl("dev_sft.jsonl")
    print("processed", len(samples_train_processed), len(samples_dev_processed))

    sql_complexities = set()
    sql_types = set()
    question_focuses = set()
    samples_train_final = []
    samples_dev_final = []
    lengths_original_answer = []
    lengths_sampled_reasoning = []

    for sample in samples_train_processed + samples_dev_processed:
        original_answer = sample["answer"]
        sampled_answers = sample["sampled_answers"]

        # Token lengths are recorded for every sample, kept or not.
        lengths_original_answer.append(len(tokenizer.encode(original_answer)))
        lengths_sampled_reasoning.extend(
            len(tokenizer.encode(sampled["reasoning"]))
            for sampled in sampled_answers
        )

        valid_count = count_valid_samples(original_answer, sampled_answers)
        if not 1 <= valid_count < VALID_COUNT_UPPER_BOUND:
            continue

        # "Complexity" here is the number of whitespace-separated SQL tokens.
        sql_complexities.add(len(sample["sql"].split()))
        sql_types.add(sample["sql_type"])
        question_focuses.add(sample["focus"])

        # Samples carrying any other label are counted in the statistics
        # above but end up in neither output list.
        label = sample["label"]
        if label == "training":
            samples_train_final.append(sample)
        elif label == "dev":
            samples_dev_final.append(sample)

    print("final data", len(samples_train_final), len(samples_dev_final))  # 1435 82
    # The three comment lines below appear to record the sets printed by an
    # earlier run (complexities, SQL types, question focuses).
    # {4, 5, 8, 9, 10, 11, 12, 13, 14, 15}
    # {'multi_ran_aggregating', 'multi_graph_filtering', 'multi_ran_filtering_foa', 'multi_ran_filtering_foo', 'multi_ran_filtering_ofo'}
    # {'title_word_count', 'title_entire', 'author_count', 'author_relationship', 'author_list', 'reference_count', 'citation_relationship'}
    print(sql_complexities)
    print(sql_types)
    print(question_focuses)

    # max(..., default=0) keeps the report from crashing when there is no
    # data at all (empty input files, or no sampled answers anywhere).
    print(
        "max answer tokens",
        max(lengths_original_answer, default=0),
        max(lengths_sampled_reasoning, default=0),
    )


if __name__ == "__main__":
    main()
