# NOTE: each YAML annotation file is deleted once it has been converted to pkl.
from opencood.hypes_yaml.yaml_utils import load_yaml
import os
from tqdm import tqdm
import pickle
import numpy as np
from einops import rearrange
import cv2
import argparse

def data_parser(argv=None):
    """Parse command-line options for the V2XSim preprocessing script.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with:
            dataroot (str): root of the V2XSim dataset. The default '' means
                the split folders are resolved relative to the current
                working directory.
            x_size (int): target image width, default 512.
            y_size (int): target image height, default 512.
    """
    parser = argparse.ArgumentParser(
        description="Preprocess V2XSim data: convert YAML annotations to "
                    "pickle and resize camera images")
    parser.add_argument("--dataroot", default='',
                        type=str, help='root path of V2XSim dataset')
    parser.add_argument("--x_size", default=512,
                        type=int, help="image width")
    parser.add_argument("--y_size", default=512,
                        type=int, help="image height")
    return parser.parse_args(argv)

def _rename_and_reformat_image(cav_path, file_name, x_size, y_size):
    """Rename one camera JPEG to 0-based naming and save a resized RGB copy.

    ``file_name`` is split as ``<timestep>_<suffix>.jpeg``. The last character
    of ``<suffix>`` is read as a 1-based camera index, and the file is renamed
    to ``<timestep>_camera<index-1>.jpeg``. The renamed image is then loaded
    with OpenCV, converted BGR -> RGB, and resized to ``(x_size, y_size)``. The
    result is saved next to the JPEG with a ``.npy`` extension.

    NOTE: the rename is not idempotent. Running the script twice on the same
    folder would turn ``..._camera0.jpeg`` into ``..._camera-1.jpeg``.

    Raises:
        ValueError: if the name does not contain exactly one '_' before the
            first '.', or its last suffix character is not a digit.
        OSError: if OpenCV cannot decode the renamed image.
    """
    old_path_name = os.path.join(cav_path, file_name)
    time_step, cam_idx = file_name.split('.')[0].split('_')
    new_path_name = os.path.join(
        cav_path, time_step + '_camera%d.jpeg' % (int(cam_idx[-1]) - 1))
    os.rename(old_path_name, new_path_name)

    img = cv2.imread(new_path_name)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError('Failed to read image {}'.format(new_path_name))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (x_size, y_size))
    np.save(os.path.splitext(new_path_name)[0] + '.npy', img)


def _convert_yaml_to_pkl(yaml_file):
    """Convert one YAML annotation file to pickle, then delete the YAML.

    Camera keys ``cam1``..``cam4`` are renamed to ``camera0``..``camera3``
    before saving. The pickle is written next to the YAML with the same base
    name and a ``.pkl`` extension.

    Returns:
        The first two entries of ``data['true_ego_pose']``. These are
        presumably the ego x/y position; confirm against the dataset format.
    """
    data = load_yaml(yaml_file)
    for cam_id in range(4):
        # Format the key before the membership test so 'cam1'..'cam4' are
        # actually found and renamed.
        old_key = 'cam%d' % (cam_id + 1)
        if old_key in data:
            data['camera%d' % cam_id] = data.pop(old_key)

    cords = data['true_ego_pose'][:2]
    pkl_file = os.path.splitext(yaml_file)[0] + '.pkl'
    with open(pkl_file, 'wb') as file:
        pickle.dump(data, file)
    # Delete the YAML only after the pickle has been written.
    os.remove(yaml_file)
    return cords


def _process_cav(cav_path, x_size, y_size):
    """Preprocess every file in one CAV folder in place.

    Files are visited in sorted name order and handled by type:
      * JPEGs whose name starts with '.' are deleted.
      * Other JPEGs are renamed and cached as ``.npy``.
      * YAMLs are converted to pickle and deleted.

    Returns:
        A list with one ``true_ego_pose[:2]`` entry per YAML file, in sorted
        file-name order.
    """
    yaml_files = []
    for file_name in sorted(os.listdir(cav_path)):
        file_path = os.path.join(cav_path, file_name)
        if file_name.endswith('.yaml'):
            yaml_files.append(file_path)
        elif file_name.endswith('.jpeg'):
            if file_name.startswith('.'):  # Remove abnormal (hidden) images
                os.remove(file_path)
                print(file_path)
            else:
                _rename_and_reformat_image(cav_path, file_name,
                                           x_size, y_size)
    return [_convert_yaml_to_pkl(yaml_file) for yaml_file in yaml_files]


def _pairwise_distances(all_cords):
    """Return per-timestamp pairwise Euclidean distances between CAVs.

    Args:
        all_cords: nested sequence indexed ``[cav][timestamp][coordinate]``.
            Every CAV must have the same number of timestamps.

    Returns:
        np.ndarray of shape (num_timestamps, num_cavs, num_cavs).
    """
    positions = rearrange(np.stack(all_cords), 'n t c -> t n c')
    pos_i = rearrange(positions, 't i c -> t i 1 c')
    pos_j = rearrange(positions, 't j c -> t 1 j c')
    return np.linalg.norm(pos_i - pos_j, axis=-1)


def main():
    """Preprocess the train/test/val splits under ``--dataroot`` in place.

    For each scenario folder, every CAV sub-folder is preprocessed. The
    pairwise CAV distances per timestamp are then saved as ``distance.npy``
    in the scenario folder.
    """
    opt = data_parser()
    dataroot = opt.dataroot
    x_size = opt.x_size
    y_size = opt.y_size

    for folder in ['train', 'test', 'val']:
        print('Start correct parameters for {} dataset'.format(folder))
        split_path = os.path.join(dataroot, folder)
        scenario_folders = sorted(
            os.path.join(split_path, x) for x in os.listdir(split_path)
            if os.path.isdir(os.path.join(split_path, x)))

        for scenario_folder in tqdm(scenario_folders):
            cav_list = sorted(
                x for x in os.listdir(scenario_folder)
                if os.path.isdir(os.path.join(scenario_folder, x)))
            all_cords = [
                _process_cav(os.path.join(scenario_folder, cav_id),
                             x_size, y_size)
                for cav_id in cav_list]

            # No CAV folders, or no YAML left (e.g. on a re-run after the
            # YAMLs were deleted): there is nothing to stack, so keep any
            # existing distance file instead of crashing.
            if all(len(cords) == 0 for cords in all_cords):
                print('No annotations found in {}, skipping distance.npy'
                      .format(scenario_folder))
                continue
            frame_counts = {len(cords) for cords in all_cords}
            if len(frame_counts) != 1:
                raise ValueError(
                    'CAVs in {} have differing numbers of annotation frames: '
                    '{}'.format(scenario_folder, sorted(frame_counts)))

            distfile = os.path.join(scenario_folder, 'distance.npy')
            np.save(distfile, _pairwise_distances(all_cords))


if __name__ == "__main__":
    main()