import json
import os
import sys
import time
import argparse
import pickle
import random
import shutil

import hydra
import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import nocturne
import pdb
from tqdm import tqdm
from scipy.spatial import distance

def match_md_to_nocturne(md_sdc_pos, md_adv_pos, nocturne_dict, tol=0.01):
    """Map an md SDC/adversary pair onto nocturne's two interactive vehicles.

    The nocturne interactive vehicle whose first recorded position is closer
    to ``md_sdc_pos`` is taken as the SDC and the other as the adversary. The
    assignment is accepted only if both md start positions lie within ``tol``
    (Euclidean distance) of their assigned nocturne start positions.

    Args:
        md_sdc_pos: (x, y) start position of the SDC in the md scenario.
        md_adv_pos: (x, y) start position of the adversary in the md scenario.
        nocturne_dict: parsed nocturne scenario; read via
            ``nocturne_dict['interactive_ids']`` and
            ``nocturne_dict['objects'][id]['position'][0]['x' | 'y']``.
        tol: maximum allowed distance between matched start positions
            (default 0.01, the previously hard-coded value).

    Returns:
        ``(nocturne_sdc_id, nocturne_adversary_id, matched)`` where ``matched``
        is a Python bool. The ids are ``None`` when the inputs are rejected up
        front (an md x-coordinate of exactly 0, or fewer than two nocturne
        interactive ids). When only the tolerance check fails, the ids of the
        closest-distance assignment are still returned with ``matched=False``.
    """
    # An x-coordinate of exactly 0 is treated as a missing/invalid md position.
    if md_sdc_pos[0] == 0 or md_adv_pos[0] == 0:
        return None, None, False

    interactive_ids = nocturne_dict['interactive_ids']
    # Need two interactive vehicles to form an SDC/adversary pair; previously
    # this case raised IndexError.
    if len(interactive_ids) < 2:
        return None, None, False

    def start_xy(obj_id):
        """Return the first recorded (x, y) of a nocturne object as an array."""
        first = nocturne_dict['objects'][obj_id]['position'][0]
        return np.array([first['x'], first['y']])

    id_0, id_1 = interactive_ids[0], interactive_ids[1]
    pos_0, pos_1 = start_xy(id_0), start_xy(id_1)

    # Strict `<`: on an exact distance tie the second vehicle becomes the SDC.
    if np.linalg.norm(pos_0 - md_sdc_pos) < np.linalg.norm(pos_1 - md_sdc_pos):
        sdc_id, adv_id, sdc_pos, adv_pos = id_0, id_1, pos_0, pos_1
    else:
        sdc_id, adv_id, sdc_pos, adv_pos = id_1, id_0, pos_1, pos_0

    matched = bool(
        np.linalg.norm(md_sdc_pos - sdc_pos) < tol
        and np.linalg.norm(md_adv_pos - adv_pos) < tol
    )
    return sdc_id, adv_id, matched

def _md_filename_index(md_filenames):
    """Map the scenario id embedded in each md filename to that filename.

    The id is ``name[36:-21]``, i.e. the fixed-width prefix/suffix slicing the
    md filenames are assumed to follow (NOTE(review): confirm this matches the
    naming scheme of the md export). When several filenames share an id the
    first one in ``md_filenames`` order is kept, matching ``list.index``.
    """
    index = {}
    for name in md_filenames:
        index.setdefault(name[36:-21], name)
    return index


def _select_md_agents(md_dict):
    """Return ``(sdc_id, adversary_id)`` for an md scenario, or ``None``.

    A scenario is rejected (``None``) when its ``track_length`` is not 91, when
    the SDC is not among ``objects_of_interest``, or when no other object of
    interest remains to act as the adversary (previously an IndexError). The
    adversary is the first remaining object of interest after removing the
    first occurrence of the SDC id. ``md_dict`` itself is not mutated.
    """
    metadata = md_dict['metadata']
    if metadata['track_length'] != 91:
        return None

    sdc_id = metadata['sdc_id']
    others = list(metadata['objects_of_interest'])
    if sdc_id not in others:
        return None
    others.remove(sdc_id)
    if not others:
        return None
    return sdc_id, others[0]


# Build the planner-evaluation scene index from the WOMD validation-interactive
# split. NOTE(review): the original intent was to split the val set into val
# and test sets with 2500 test scenes; this script keeps every scene that
# passes the filters below and does not cap the count at 2500.
@hydra.main(version_base=None, config_path="/home/ctrl-sim-dev/cfgs/", config_name="config")
def main(cfg):
    """Pair md scenario pickles with nocturne scenes and write an index pickle.

    Nocturne files in ``cfg.nocturne_waymo_val_interactive_folder`` whose name
    contains ``'tfrecord'`` are visited in a seeded random order. Each is
    paired with the md pickle whose embedded id equals the nocturne filename
    minus its last five characters. Pairs whose SDC/adversary start positions
    match (see ``match_md_to_nocturne``) have their md pickle copied to
    ``/scratch/raw_scenes/<count>.pkl``. The mapping ``count -> paths/ids`` is
    pickled to ``/home/ctrl-sim-dev/eval_planner_dict.pkl``.
    """
    md_file_path = '/scratch/md_womd_validation_interactive'
    raw_scenes_dir = '/scratch/raw_scenes'
    output_path = '/home/ctrl-sim-dev/eval_planner_dict.pkl'
    nocturne_folder = cfg.nocturne_waymo_val_interactive_folder

    test_filenames = sorted(
        name for name in os.listdir(nocturne_folder) if 'tfrecord' in name
    )

    seed = 2024
    random.seed(seed)  # Python random module.
    np.random.seed(seed)  # Numpy module.
    torch.manual_seed(seed)  # PyTorch.
    torch.cuda.manual_seed(seed)  # PyTorch, for CUDA.

    # random.shuffle's permutation depends only on length and RNG state, so
    # shuffling a copy of the filenames gives the same order as shuffling indices.
    shuffled_filenames = list(test_filenames)
    random.shuffle(shuffled_filenames)

    # Built once: O(1) lookups instead of a linear list scan per nocturne file.
    md_index = _md_filename_index(os.listdir(md_file_path))

    os.makedirs(raw_scenes_dir, exist_ok=True)
    output_dict = {}
    count = 0

    for filename in tqdm(shuffled_filenames):
        md_name = md_index.get(filename[:-5])  # strip the 5-char extension
        if md_name is None:
            continue

        md_path = os.path.join(md_file_path, md_name)
        # NOTE: pickle.load can execute arbitrary code; only use on trusted data.
        with open(md_path, 'rb') as f:
            md_dict = pickle.load(f)

        agents = _select_md_agents(md_dict)
        if agents is None:
            continue
        md_sdc_id, md_adversary_id = agents

        nocturne_path = os.path.join(nocturne_folder, filename)
        with open(nocturne_path, 'r') as f:
            nocturne_dict = json.load(f)

        tracks = md_dict['tracks']
        nocturne_sdc_id, nocturne_adversary_id, matched = match_md_to_nocturne(
            tracks[md_sdc_id]['state']['position'][0, :2],
            tracks[md_adversary_id]['state']['position'][0, :2],
            nocturne_dict,
        )
        if not matched:
            continue

        new_md_path = os.path.join(raw_scenes_dir, f'{count}.pkl')
        # Direct copy: no shell string (unsafe with odd paths) and raises on
        # failure instead of silently ignoring a non-zero `cp` exit status.
        shutil.copyfile(md_path, new_md_path)

        output_dict[count] = {
            'md_path': new_md_path,
            'nocturne_path': nocturne_path,
            'md_sdc_id': md_sdc_id,
            'md_adversary_id': md_adversary_id,
            'nocturne_sdc_id': nocturne_sdc_id,
            'nocturne_adversary_id': nocturne_adversary_id
        }
        count += 1

    with open(output_path, 'wb') as f:
        pickle.dump(output_dict, f)
    print("Number of valid scenarios:", count)

# Only launch the hydra entry point when run as a script, so importing this
# module (e.g. to reuse match_md_to_nocturne) does not start processing.
if __name__ == "__main__":
    main()