import json
from datasets import DatasetDict, load_from_disk
from tqdm import tqdm
import argparse




# Create a function to transform the data format
def transform_format(example, idx, outputs=None, kept_indices=None):
    """Build one training record from a dataset row and its model output.

    Intended for ``Dataset.map(..., with_indices=True)`` on the filtered
    dataset: ``kept_indices[idx]`` is the position in ``outputs`` of the
    record that belongs to ``example``.

    Args:
        example: Row holding a ``conversations`` JSON string plus ``width``
            and ``height``. ``conversations`` is popped from the mapping.
        idx: Position of ``example`` within the filtered dataset.
        outputs: Model output records with ``response``, ``answer`` and
            ``coordinates`` keys. Defaults to the module-level ``output``
            list loaded by the script entry point.
        kept_indices: Maps ``idx`` to an index into ``outputs``. Defaults to
            the module-level ``correct_i`` list.

    Returns:
        A dict with ``problem`` (prompt text), ``response`` (model response
        whose answer section is replaced by the ground-truth point),
        ``answer`` (ground-truth ``(x, y)``) and ``prediction`` (predicted
        ``(x, y)`` truncated to ints).
    """
    # Fall back to the lists the script builds at module scope.
    if outputs is None:
        outputs = output
    if kept_indices is None:
        kept_indices = correct_i
    output_sample = outputs[kept_indices[idx]]

    # The conversation turns are stored as a JSON-encoded string.
    conversations = json.loads(example.pop('conversations'))

    # The problem is the first turn, but only when it is the human prompt.
    if len(conversations) > 0 and conversations[0]['from'] == 'human':
        problem = conversations[0]['value']
    else:
        problem = ""
    # Keep only the text between the 'Description: ' and 'Answer:' markers.
    problem = problem.split('Description: ')[-1].split('Answer:')[0].strip()

    response_text = output_sample['response']

    # The answer is an "(x, y)" string that the scaling below treats as
    # thousandths of the image width/height.
    answer_str = output_sample['answer']
    scaled = tuple(map(int, answer_str.replace('(', '').replace(')', '').split(',')))
    # Multiply before dividing: `v / 1000 * size` can land just below an
    # exact integer in floating point and be truncated one pixel short.
    answer_x = int(scaled[0] * example['width'] / 1000)
    answer_y = int(scaled[1] * example['height'] / 1000)
    answer_tuple = (answer_x, answer_y)

    # Keep everything before the answer tag and append the ground-truth point.
    center_response = response_text.split('<answer>')[0] + f'<answer>\n({answer_x},{answer_y})\n</answer>'

    # Predicted coordinates, truncated to integers.
    prediction = tuple(map(int, output_sample['coordinates']))

    return {
        'problem': problem,
        'response': center_response,
        'answer': answer_tuple,
        'prediction': prediction,
    }


if __name__ == "__main__":
    # NOTE: this runs at module scope on purpose. transform_format looks up
    # `correct_i` and the loaded model outputs as module-level globals, so
    # do not move this body into a function without passing them explicitly.
    parser = argparse.ArgumentParser(
        description="Keep the samples whose predicted point lies near the "
                    "ground truth and save them in the transformed format."
    )
    parser.add_argument('--output_dir', default='./datasets/output.jsonl',
                        help="Path to the JSONL file of model outputs "
                             "(a file, despite the option name).")
    parser.add_argument('--dataset_dir', default='./datasets/uground_21k',
                        help="Directory of the dataset saved with save_to_disk.")
    parser.add_argument('--save_path', default='./datasets/uground_21k_transformed',
                        help="Directory the transformed dataset is written to.")
    args = parser.parse_args()

    output_dir = args.output_dir
    uground_21k = load_from_disk(args.dataset_dir)
    # One JSON record per line; record i is paired with dataset row i.
    with open(output_dir, 'r', encoding='utf-8') as f:
        output = [json.loads(line) for line in f]

    if len(output) > len(uground_21k):
        raise ValueError(
            f"{output_dir} has {len(output)} records but the dataset only "
            f"has {len(uground_21k)} rows; they must be index-aligned."
        )

    # Only width/height are needed for filtering, so read those two columns
    # once instead of fetching a full row per sample.
    widths = uground_21k['width']
    heights = uground_21k['height']

    # A prediction counts as correct when it lies within `margin` pixels of
    # the ground-truth point on both axes.
    margin = 50
    correct_i = []
    for i, output_sample in enumerate(tqdm(output)):
        prediction = output_sample['coordinates']
        # Skip samples without a usable (x, y) prediction.
        if prediction is None or len(prediction) < 2 or prediction[0] is None or prediction[1] is None:
            continue
        # Parse the "(x, y)" answer string into an int tuple, then scale it
        # from thousandths of the image size to pixels.
        answer = output_sample['answer']
        answer = tuple(map(int, answer.replace('(', '').replace(')', '').split(',')))
        answer = answer[0] / 1000 * widths[i], answer[1] / 1000 * heights[i]
        within_x = answer[0] - margin <= prediction[0] <= answer[0] + margin
        within_y = answer[1] - margin <= prediction[1] <= answer[1] + margin
        if within_x and within_y:
            correct_i.append(i)

    # Filter uground_21k to keep only the correct predictions
    filtered_uground = uground_21k.select(correct_i)
    print(f"Filtered dataset size: {len(filtered_uground)} out of {len(uground_21k)}")

    # Apply the transformation; `with_indices=True` supplies the position
    # within the filtered dataset as the second argument.
    transformed_dataset = filtered_uground.map(transform_format, with_indices=True)

    dataset_dict = DatasetDict({
        'train': transformed_dataset,
    })

    # Save the filtered and transformed dataset
    save_path = args.save_path
    dataset_dict.save_to_disk(save_path)
    print(f"Saved transformed dataset to {save_path}")