import os

# Runtime configuration via environment variables. These are assigned before
# `tensorflow` is imported further down so that the library sees them at load.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # hide TensorFlow C++ INFO and WARNING logs
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # disable JAX/XLA GPU memory preallocation
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # expose no GPUs to this process (CPU-only conversion)

import numpy as np
import glob
import os
import tensorflow as tf
from tqdm import tqdm
import cv2
import os
from multiprocessing import Pool, cpu_count
import pickle
from concurrent.futures import ThreadPoolExecutor
import re
from multiprocessing import Pool
import math

from dlimp.utils import read_resize_encode_image, tensor_feature

# Target size handed to read_resize_encode_image for every frame. The axis
# order is defined by that helper; both entries are 128 so it does not matter here.
IMAGE_SIZE = (128, 128)


# NOTE(review): `keys` is not referenced anywhere in this file; possibly dead
# code kept for external importers — confirm before removing.
keys = []


def atoi(text):
    """Return ``int(text)`` for an all-digit string, otherwise ``text`` unchanged."""
    if not text.isdigit():
        return text
    return int(text)


def natural_sort_key(s):
    """Build a key that orders embedded digit runs numerically.

    ``s`` is split on runs of digits (the captured runs are kept), and every
    all-digit piece is converted to ``int`` so "frame2" sorts before "frame10".
    """
    key = []
    for piece in re.split(r"(\d+)", s):
        key.append(int(piece) if piece.isdigit() else piece)
    return key


def rotate_positions(relative_positions, yaws):
    """Rotate a batch of 2-D vectors, each by its own angle.

    Args:
        relative_positions: array of shape (N, 2) holding (x, y) vectors.
        yaws: N rotation angles in radians, in any shape with N elements
            (e.g. (N,) or (N, 1)).

    Returns:
        (N, 2) array whose row i is ``R(yaws[i]) @ relative_positions[i]`` with
        ``R(t) = [[cos t, -sin t], [sin t, cos t]]`` (counter-clockwise).
    """
    # ravel (not squeeze) so that a batch of one stays 1-D; squeeze would
    # collapse it to a 0-d scalar and break the (N, 2, 2) construction below.
    yaws = np.ravel(yaws)
    cos_yaws = np.cos(yaws)
    sin_yaws = np.sin(yaws)

    # Stack of rotation matrices, shape (N, 2, 2):
    #   [i, 0, :] = [cos, -sin],  [i, 1, :] = [sin, cos]
    rotation_matrices = np.stack(
        [
            np.stack([cos_yaws, -sin_yaws], axis=-1),
            np.stack([sin_yaws, cos_yaws], axis=-1),
        ],
        axis=-2,
    )

    # Batched matrix-vector product: out[i] = R[i] @ v[i].
    return np.einsum("ijk,ik->ij", rotation_matrices, relative_positions)


def split_array(arr: np.ndarray, max_chunk_size=64):
    """Splits the array into chunks of size chunk_size or less."""
    num_chunks = math.ceil(len(arr) / max_chunk_size)
    float_chunk_size = len(arr) / num_chunks

    def get_chunk_idx(i: int):
        return int(i * float_chunk_size + 0.5)

    return [arr[get_chunk_idx(i) : get_chunk_idx(i + 1)] for i in range(num_chunks)]


def write_trajectories(output_base_path: str, features: dict):
    """Write a trajectory's features as one TFRecord file per chunk.

    Each feature sequence is split with ``split_array`` (at most 64 steps per
    chunk). Chunk ``i`` is written as a single ``tf.train.Example`` to
    ``f"{output_base_path}_{i}.tfrecord"``. Every example also gets a ``goal``
    feature: the chunk's last image repeated once per step.

    Args:
        output_base_path: path prefix for the output files.
        features: mapping of feature name to per-step sequence. It must contain
            ``"actions"`` and ``"images"``. All sequences are assumed to share
            one length so that their chunks line up (TODO confirm at callers).
    """
    chunked_features = {k: split_array(v) for k, v in features.items()}
    num_chunks = len(chunked_features["actions"])

    for i in range(num_chunks):
        chunk = {k: v[i] for k, v in chunked_features.items()}

        feature = {k: tensor_feature(v) for k, v in chunk.items()}
        chunk_images = chunk["images"]
        feature["goal"] = tensor_feature([chunk_images[-1]] * len(chunk_images))

        example = tf.train.Example(features=tf.train.Features(feature=feature))

        # The context manager closes (and flushes) the file even if write() raises.
        with tf.io.TFRecordWriter(output_base_path + f"_{i}.tfrecord") as writer:
            writer.write(example.SerializeToString())


def _load_float32_traj_data(pkl_path):
    """Load ``traj_data.pkl`` and cast its object/float64 arrays to float32.

    Returns the dict, or None if an object-dtype entry cannot be converted.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only run this converter on trusted dataset directories.
    with open(pkl_path, "rb") as file:
        pkl_data = pickle.load(file)

    for k, v in pkl_data.items():
        if v.dtype == np.object_:
            try:
                pkl_data[k] = v.astype(np.float32)
            except (TypeError, ValueError):
                print(f"Could not convert {k}: {v} to float32")
                return None
        elif v.dtype == np.float64:
            pkl_data[k] = v.astype(np.float32)
    return pkl_data


def _compute_action_features(positions, yaws, scaling_constant):
    """Derive per-step action features from consecutive positions and yaws.

    Position deltas (divided by ``scaling_constant``) and the next-step heading
    unit vector are both rotated by the negated current yaw. Presumably this
    expresses them in the current pose's frame; confirm the yaw convention.
    Returns a dict with one entry per step (N - 1 steps for N poses).
    """
    rel_pos = (positions[1:] - positions[:-1]) / scaling_constant
    # ravel (not squeeze) keeps a single transition 1-D for np.stack(axis=1).
    next_angle = np.ravel(yaws[1:])
    angle = -1 * np.ravel(yaws[:-1])

    unrotated_headings = np.stack([np.cos(next_angle), np.sin(next_angle)], axis=1)
    headings = rotate_positions(unrotated_headings, angle)
    delta_pos = rotate_positions(rel_pos, angle)
    action = np.concatenate([delta_pos, headings], axis=1)

    return {
        "unrotated_headings": unrotated_headings,
        "mask": np.ones(len(action)),
        "unrotated_delta_pos": rel_pos,
        "actions": action,
        "delta_pos": delta_pos,
        "headings": headings,
    }


def create_tfrecord(pair, scaling_constant):
    """Convert one trajectory directory into chunked TFRecord files.

    Args:
        pair: ``(input_path, output_path)``. ``input_path`` is the trajectory
            directory containing ``traj_data.pkl`` and ``*.jpg`` frames.
            ``output_path`` is the base path given to ``write_trajectories``,
            which appends ``_{i}.tfrecord`` per chunk.
        scaling_constant: divisor applied to position deltas (the dataset's
            ``metric_waypoint_spacing`` when run as a script).

    Returns:
        None. The trajectory is skipped (nothing written) if the pickle is
        missing or unconvertible, there are fewer than 21 frames, an image
        fails to load, or the frame and action counts disagree.
    """
    input_path, output_path = pair
    pkl_path = os.path.join(input_path, "traj_data.pkl")
    if not os.path.exists(pkl_path):
        return

    pkl_data = _load_float32_traj_data(pkl_path)
    if pkl_data is None:
        return

    # Sort by transition number so "2.jpg" precedes "10.jpg".
    image_paths = sorted(
        glob.glob(os.path.join(input_path, "*.jpg")), key=natural_sort_key
    )
    if len(image_paths) < 21:
        return

    # Drop the final frame so frames line up with the N - 1 position deltas
    # (assumes one frame per recorded position); it is never decoded.
    try:
        images = [
            read_resize_encode_image(path, IMAGE_SIZE) for path in image_paths[:-1]
        ]
    except Exception as err:  # best-effort: skip unreadable trajectories
        print(f"Could not read images in {input_path}: {err}")
        return

    features = _compute_action_features(
        pkl_data["position"], pkl_data["yaw"], scaling_constant
    )
    if len(images) != len(features["actions"]):
        # Mismatched lengths would pair frames with the wrong actions when chunked.
        print(
            f"Skipping {input_path}: {len(images)} images but "
            f"{len(features['actions'])} actions"
        )
        return
    features["images"] = images

    write_trajectories(output_path, features)


def append_zeros_multi_dim(arr: np.ndarray, num_rows: int = 1) -> np.ndarray:
    """Append zero rows to the end of ``arr`` along its first axis.

    Args:
        arr: array of shape ``(N, *rest)``.
        num_rows: number of zero rows to append (default 1).

    Returns:
        New array of shape ``(N + num_rows, *rest)``. It keeps ``arr``'s dtype;
        the zeros are no longer float64, which used to upcast float32/int
        input. ``arr`` itself is not modified.
    """
    zeros_to_append = np.zeros((num_rows,) + arr.shape[1:], dtype=arr.dtype)
    return np.concatenate([arr, zeros_to_append], axis=0)


if __name__ == "__main__":
    import argparse
    import yaml
    from functools import partial

    parser = argparse.ArgumentParser(
        description="Convert trajectory folders (traj_data.pkl + *.jpg) into chunked TFRecords."
    )
    parser.add_argument("--input_root", "-i", required=True, type=str)
    parser.add_argument("--output_root", "-o", required=True, type=str)
    parser.add_argument("--config_path", "-c", required=True, type=str)
    parser.add_argument(
        "--num_workers",
        type=int,
        default=5,
        help="number of worker processes (default: 5)",
    )

    args = parser.parse_args()
    with open(args.config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # normpath drops a trailing separator; otherwise basename("data/foo/") == "".
    data_name = os.path.basename(os.path.normpath(args.input_root))
    if not config or data_name not in config:
        parser.error(f"dataset '{data_name}' not found in {args.config_path}")
    scaling_constant = config[data_name]["metric_waypoint_spacing"]

    # One sub-directory per trajectory; each maps to an output base path of the same name.
    input_paths = glob.glob(os.path.join(args.input_root, "*/"))

    os.makedirs(args.output_root, exist_ok=True)
    output_paths = [
        os.path.join(args.output_root, os.path.basename(os.path.dirname(path)))
        for path in input_paths
    ]
    input_output_pairs = list(zip(input_paths, output_paths))

    worker = partial(create_tfrecord, scaling_constant=scaling_constant)
    with Pool(processes=args.num_workers) as pool:
        # Workers return None; iterate only to drive the progress bar.
        for _ in tqdm(
            pool.imap(worker, input_output_pairs),
            dynamic_ncols=True,
            total=len(input_output_pairs),
        ):
            pass
