import os
import sys
from datasets import Dataset, load_dataset
from verl.utils.hdfs_io import copy, makedirs
import argparse
import json
import pandas as pd
import numpy as np
from transformers import AutoTokenizer

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.table_utils import encode_table, make_table_all_text
from utils.formula_utils import formula_operator_instruction


# NOTE: the prompt builders below return a fully chat-templated string. For this to take effect, the verl dataset processor must use `chat` as-is instead of calling `apply_chat_template(chat)` on it.
def construct_formula_prompt(dp, tokenizer):
    """Build the chat-formatted prompt that asks the model for a spreadsheet formula.

    Args:
        dp: One example; ``dp['table']`` is serialized with ``encode_table``
            (cell addresses included, in-cell line breaks removed) and
            ``dp['question']`` is appended verbatim.
        tokenizer: Tokenizer whose ``apply_chat_template`` renders the user turn.

    Returns:
        The chat template applied to a single user message (with a generation
        prompt), followed by an assistant prefix that opens a ``<think>`` block.
    """
    table_content = encode_table(dp['table'], with_address=True, remove_break_in_cell=True)
    question = dp['question']

    # Not used at the moment: the system turn is deliberately left out of the
    # message list below. Prepend {"role": "system", "content": system_prompt_str}
    # to `messages` to re-enable it.
    system_prompt_str = 'You are a helpful assistant.'

    user_prompt_str = f"""# Task
You are an expert in writing Spreadsheet formulas given a table and a question.
You first think about the reasoning process in the mind and then provides the user with the answer.
Your task is to generate the correct spreadsheet formula to answer a given question, based on the provided table.

# Spreadsheet Formula Operator List
Below is a JSON list of commonly used formula operators, including their instructions and examples.
{formula_operator_instruction}

# Table
The table is represented as cell-value pairs, where each pair consists of a cell address and its content, separated by a comma (e.g., 'A1,Year').
Multiple cells are separated by a pipe symbol '|' (e.g., 'A1,Year|A2,Profit').
Empty cell of A1 can be represented as 'A1,|A2,Profit'.

Here is the table:
{table_content}

# Response Format
Show your reasoning within <think> </think> tags. Your final output must be in JSON format, enclosed in <answer> </answer> tags. For example:
<think>
[step-by-step reasoning]
</think>
<answer>
{{
    "formula": "=......."
}}
</answer>

# Notes
1. For simple questions, if a direct cell reference is appropriate, simply return the formula as =CellAddress.
2. Construct the formula mainly using the provided operator symbols from the formula operator list.
3. You may either use cell references (cell addresses) in formulas or use the actual cell values directly.
4. Do not use the dollar sign ($) in addresses; use only formats like A1, A2, etc.
5. If a question has multiple answers, concatenate them using ", " as the separator. For example, use the formula `=A1 & ", " & A2 & ", " & A3` to produce a single string like `a, b, c`.
6. The execution result of the generated formula must be the direct final answer to the question.

Here is the question:
{question}
"""

    # Pre-filled start of the assistant turn, so generation begins inside <think>.
    assistant_prefix = 'Let me write the spreadsheet formula with reasoning.\n<think>\n'

    messages = [{"role": "user", "content": user_prompt_str}]
    rendered = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    return rendered + assistant_prefix



def construct_text_prompt(dp, tokenizer):
    """Build the chat-formatted prompt that asks the model for a direct textual answer.

    Args:
        dp: One example; ``dp['table']`` is serialized with ``encode_table``
            (cell addresses included, in-cell line breaks removed) and
            ``dp['question']`` is appended verbatim.
        tokenizer: Tokenizer whose ``apply_chat_template`` renders the user turn.

    Returns:
        The chat template applied to a single user message (with a generation
        prompt), followed by an assistant prefix that opens a ``<think>`` block.
    """
    table_content = encode_table(dp['table'], with_address=True, remove_break_in_cell=True)
    question = dp['question']

    # Not used at the moment: the system turn is deliberately left out of the
    # message list below. Prepend {"role": "system", "content": system_prompt_str}
    # to `messages` to re-enable it.
    system_prompt_str = 'You are a helpful assistant.'

    user_prompt_str = f"""# Task
You are an expert in answering questions given a table.
You first think about the reasoning process in the mind and then provides the user with the answer.
Your task is to generate the correct answer to a given question, based on the provided table.

# Table
The table is represented as cell-value pairs, where each pair consists of a cell address and its content, separated by a comma (e.g., 'A1,Year').
Multiple cells are separated by a pipe symbol '|' (e.g., 'A1,Year|A2,Profit').
Empty cell of A1 can be represented as 'A1,|A2,Profit'.

Here is the table:
{table_content}

# Response Format
Show your reasoning within <think> </think> tags. Your final output must be in JSON format, enclosed in <answer> </answer> tags. For example:
<think>
[step-by-step reasoning]
</think>
<answer>
{{
    "answer": "......."
}}
</answer>

# Notes
1. Use the values from the table in the reasoning process or answer the question.
2. If a question has multiple answers, concatenate them using ", " as the separator, e.g., "a, b, c".
3. Your answer cannot be the spreadsheet formula.

Here is the question:
{question}
"""

    # Pre-filled start of the assistant turn, so generation begins inside <think>.
    assistant_prefix = 'Let me give the answer with reasoning.\n<think>\n'

    messages = [{"role": "user", "content": user_prompt_str}]
    rendered = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    return rendered + assistant_prefix



def _aitqa_record_to_rows(record):
    """Flatten one AIT-QA table record into a rectangular list of row lists.

    ``record["column_header"]`` and ``record["row_header"]`` hold one list of
    header levels per column / row, and ``record["data"]`` holds the body rows.
    Row headers and body are padded to the same row count, and body rows and
    column headers to the same column count. The output is then:

    * ``top_depth`` header rows, each starting with ``left_depth`` empty cells
      followed by that level of every column header ("" where a column has
      fewer levels);
    * one row per body row: its row-header cells (padded with "" to
      ``left_depth``) followed by its body cells.

    The input record is not modified.
    """
    # Shallow-copy so the padding below does not mutate the parsed JSON.
    col_hdr = [list(c) for c in record["column_header"]]
    row_hdr = [list(r) for r in record["row_header"]]
    data = [list(r) for r in record["data"]]

    # Give row headers and body the same number of rows.
    n_rows = max(len(row_hdr), len(data))
    row_hdr += [[] for _ in range(n_rows - len(row_hdr))]
    data += [[] for _ in range(n_rows - len(data))]

    # Give column headers and every body row the same number of columns.
    n_cols = max(len(col_hdr), max((len(r) for r in data), default=0))
    col_hdr += [[] for _ in range(n_cols - len(col_hdr))]
    for r in data:
        r.extend([""] * (n_cols - len(r)))

    left_depth = max((len(r) for r in row_hdr), default=0)
    top_depth = max((len(c) for c in col_hdr), default=0)

    rows = []
    for lvl in range(top_depth):
        header_cells = [col[lvl] if lvl < len(col) else "" for col in col_hdr]
        rows.append([""] * left_depth + header_cells)

    for rh, body in zip(row_hdr, data):
        rows.append(rh + [""] * (left_depth - len(rh)) + body)

    return rows


def load_customized_dataset(data_dir="data/raw_data/aitqa/AITQA/raw_data"):
    """Load AIT-QA tables and questions and return ``(train, test, val)`` lists.

    Reads ``aitqa_tables.jsonl`` and ``aitqa_questions.jsonl`` from ``data_dir``
    (UTF-8, one JSON object per non-blank line). Each table is flattened with
    ``_aitqa_record_to_rows``, passed through ``make_table_all_text``, and
    stored under its ``"id"``.

    Every example is a dict with ``table_id``, ``table`` (the converted table
    looked up by the question's ``table_id``), ``question`` and ``answer``
    (the question's ``answers`` converted to ``str``).

    Args:
        data_dir: Directory containing the two JSONL files. The default
            preserves the original hard-coded relative location.

    Returns:
        ``(train_data, test_data, val_data)``. ``test_data`` holds every
        question; ``train_data`` and ``val_data`` are both just the first two
        test examples.

    Raises:
        KeyError: if a question references a ``table_id`` missing from the
            tables file.
    """
    tables = {}
    with open(os.path.join(data_dir, "aitqa_tables.jsonl"), "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():  # tolerate blank / trailing lines
                continue
            record = json.loads(line)
            tables[record["id"]] = make_table_all_text(_aitqa_record_to_rows(record))

    with open(os.path.join(data_dir, "aitqa_questions.jsonl"), "r", encoding="utf-8") as f:
        questions = [json.loads(line) for line in f if line.strip()]

    test_data = [
        {
            'table_id': q['table_id'],
            'table': tables[q['table_id']],
            'question': q['question'],
            'answer': [str(a) for a in q['answers']],
        }
        for q in questions
    ]
    # Only the questions file is loaded here; train/val are two-example
    # placeholders sliced from the test set.
    train_data = test_data[:2]
    val_data = test_data[:2]

    return train_data, test_data, val_data


if __name__ == '__main__':

    dataset = 'aitqa'
    # Test examples whose prompt is longer than this many tokens are dropped
    # (set to None to disable filtering).
    filter_prompt_length = 8192

    model_paths = {
        'qwen': 'Qwen/Qwen2.5-Coder-7B-Instruct',
        'llama': 'meta-llama/Llama-3.1-8B-Instruct',
    }
    prompt_builders = {
        'formula': construct_formula_prompt,
        'text': construct_text_prompt,
    }

    # Parse the command line once. --local_dir and --model default to None and
    # are resolved per (task, model) below. An explicit value therefore applies
    # to every iteration (an explicit --local_dir is shared by all outputs).
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default=None,
                        help='output root (default: data/processed_data/<task>/<model_name>)')
    parser.add_argument('--hdfs_dir', default=None)
    parser.add_argument('--dataset', type=str, default=dataset)
    parser.add_argument('--model', type=str, default=None,
                        help='tokenizer name/path (default: the model of each iteration)')
    args = parser.parse_args()

    data_source = args.dataset

    # The raw data does not depend on task or model, so parse it only once.
    train_data, test_data, val_data = load_customized_dataset()

    tokenizers = {}

    def get_tokenizer(path):
        """Load the tokenizer for ``path`` once and reuse it across tasks."""
        if path not in tokenizers:
            tokenizers[path] = AutoTokenizer.from_pretrained(path)
        return tokenizers[path]

    def make_map_fn(task, split, tokenizer):
        """Return a ``Dataset.map`` function that turns a raw example into a verl record."""
        build_prompt = prompt_builders[task]

        def process_fn(example, idx):
            return {
                "data_source": f"{task}_{data_source}_{split}",
                "prompt": build_prompt(example, tokenizer),
                "ability": f'{task}_TableQA',
                "reward_model": {
                    "style": "rule",
                    "ground_truth": {
                        'answer': example['answer'],
                    }
                },
                "extra_info": {
                    'split': split,
                    'index': idx,
                    'table_id': example['table_id'],
                    'table': example['table']
                }
            }
        return process_fn

    def build_split(rows, task, split, tokenizer):
        """Convert raw rows into a mapped, shuffled (seed 42) ``Dataset``."""
        ds = Dataset.from_list(rows)
        ds = ds.map(function=make_map_fn(task, split, tokenizer), with_indices=True)
        return ds.shuffle(seed=42)

    def calculate_token_lengths(ds, tokenizer):
        """Return the token count of every ``prompt`` in ``ds``."""
        # Prompts were produced by apply_chat_template and so already contain the
        # template's special tokens (some templates, e.g. Llama 3.x, emit BOS).
        # Don't let the tokenizer add them a second time.
        return [len(tokenizer(d['prompt'], add_special_tokens=False)['input_ids']) for d in ds]

    for task in ['formula', 'text']:
        for model_name in ['qwen', 'llama']:
            print(f"Processing {dataset} with {model_name} for {task} task...")

            model_path = args.model if args.model is not None else model_paths[model_name]
            local_root = (args.local_dir if args.local_dir is not None
                          else f'data/processed_data/{task}/{model_name}')
            tokenizer = get_tokenizer(model_path)

            train_dataset = build_split(train_data, task, 'train', tokenizer)
            test_dataset = build_split(test_data, task, 'test', tokenizer)
            val_dataset = build_split(val_data, task, 'val', tokenizer)

            # Lengths are computed on the shuffled datasets, so list index == row index.
            lengths_list_train = calculate_token_lengths(train_dataset, tokenizer)
            lengths_list_test = calculate_token_lengths(test_dataset, tokenizer)
            lengths_list_val = calculate_token_lengths(val_dataset, tokenizer)

            if filter_prompt_length is not None:
                print(f"Test dataset size before filtering: {test_dataset.num_rows}")
                test_dataset = test_dataset.filter(
                    lambda d, idx: lengths_list_test[idx] <= filter_prompt_length,
                    with_indices=True
                )
                lengths_list_test = [n for n in lengths_list_test if n <= filter_prompt_length]
                print(f"Test dataset size after filtering: {test_dataset.num_rows}")

            print('Final statistics:')
            print(f"Data size of train dataset: {len(train_dataset)}")
            print(f"Data size of test dataset: {len(test_dataset)}")
            print(f"Data size of val dataset: {len(val_dataset)}")

            split_lengths = {
                'train': lengths_list_train,
                'test': lengths_list_test,
                'val': lengths_list_val,
            }
            # Guard against empty splits (e.g. everything filtered out) instead of
            # raising ZeroDivisionError / ValueError.
            for name, lengths in split_lengths.items():
                avg = sum(lengths) / len(lengths) if lengths else 'n/a'
                print(f"Average length of {name} dataset: {avg}")
            for name, lengths in split_lengths.items():
                print(f"Max length of {name} dataset: {max(lengths, default='n/a')}")

            local_dir = os.path.join(local_root, args.dataset)
            hdfs_dir = os.path.join(args.hdfs_dir, args.dataset) if args.hdfs_dir is not None else None

            os.makedirs(local_dir, exist_ok=True)

            train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
            test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
            val_dataset.to_parquet(os.path.join(local_dir, 'val.parquet'))

            print('=' * 100)

            if hdfs_dir is not None:
                makedirs(hdfs_dir)
                copy(src=local_dir, dst=hdfs_dir)
