import numpy as np
import os
from PIL import Image
from typing import Any, Iterable

import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
import torch.nn.functional as F

# Size handed to TF.resize for the visualization copy of each loaded image;
# torchvision reads a 2-element size as (height, width).
VISUALIZATION_IMAGE_SIZE = (120, 160)
# Default width / height ratio that img_path_to_data center-crops images to.
IMAGE_ASPECT_RATIO = 4 / 3


def rccar_get_image_path(data_folder: str, f: str, time: int):
    """
    Build the path of an rccar frame
    Args:
        data_folder (str): root folder of the dataset
        f (str): sub-folder (joined directly under data_folder)
        time (int): frame index; zero-padded to at least 4 characters
    Returns:
        str: <data_folder>/<f>/<zero-padded time>.jpg
    """
    frame_name = str(time).zfill(4) + ".jpg"
    return os.path.join(data_folder, f, frame_name)


def recon_get_image_path(
    data_folder: str, f: str, time: int, image_type: str = "rgb_left"
):
    """
    Build the path of a recon frame, preferring the .jpeg file when it exists
    Args:
        data_folder (str): root folder of the dataset
        f (str): sub-folder (joined directly under data_folder)
        time (int): frame index; used as the name of the per-frame directory
        image_type (str): base name of the image file inside the frame directory
    Returns:
        str: <data_folder>/<f>/<time>/<image_type>.jpeg if that file exists,
            otherwise the same path with a .jpg extension (not checked for existence)
    """
    # os.path.join only accepts str/bytes components, so convert the frame
    # index once and reuse the directory for both candidate extensions.
    frame_dir = os.path.join(data_folder, f, str(time))
    image_path = os.path.join(frame_dir, f"{image_type}.jpeg")
    # check if image exists
    if not os.path.exists(image_path):
        image_path = os.path.join(frame_dir, f"{image_type}.jpg")
    return image_path


def raw_gs_get_image_path(data_folder: str, f: str, time: int):
    """
    Build the path of a frame in the raw go_stanford folder layout
    Args:
        data_folder (str): root folder of the dataset
        f (str): folder identifier; the folder on disk is "imgall_360view_" + f
        time (int): zero-based frame index (file names use time + 1)
    Returns:
        str: path of the image inside the folder; when the folder name ends in
            "F" the trailing "F" moves from the identifier to the end of the
            file name
    """
    folder_name = f"imgall_360view_{f}"
    # Only the part after the last underscore appears in the file name.
    short_name = folder_name.split("_")[-1]
    frame_number = time + 1
    if folder_name.endswith("F"):
        file_name = f"img_{short_name[:-1]}_{frame_number}F.jpg"
    else:
        file_name = f"img_{short_name}_{frame_number}.jpg"
    return os.path.join(data_folder, folder_name, file_name)


def get_image_path(data_folder: str, f: str, time: int):
    """
    Build the path of a frame stored as <data_folder>/<f>/<time>.jpg
    Args:
        data_folder (str): root folder of the dataset
        f (str): sub-folder (joined directly under data_folder)
        time (int): frame index, used verbatim (no zero padding)
    Returns:
        str: path to the image file
    """
    file_name = str(time) + ".jpg"
    return os.path.join(data_folder, f, file_name)

def get_data_path(data_folder: str, f: str, time: int, data_type: str = "image"):
    """
    Build the path of a per-frame data file of the requested kind
    Args:
        data_folder (str): root folder of the dataset
        f (str): sub-folder (joined directly under data_folder)
        time (int): frame index, used verbatim as the file-name stem
        data_type (str): one of "image", "clip" or "vint"; selects the suffix
    Returns:
        str: <data_folder>/<f>/<time><suffix>
    Raises:
        KeyError: if data_type is not one of the supported kinds
    """
    suffix_by_type = {
        "image": ".jpg",
        "clip": "_L14.pt",
        "vint": "_vint_feat.pt",
    }
    file_name = str(time) + suffix_by_type[data_type]
    return os.path.join(data_folder, f, file_name)


# Maps a dataset name to the function that builds the path of one frame.
# Every function is called as func(data_folder, f, time); the "recon" entry
# additionally accepts an optional image_type argument.
image_path_func = {
    "rccar": rccar_get_image_path,
    "recon": recon_get_image_path,
    "go_stanford": get_image_path,
    "raw_go_stanford": raw_gs_get_image_path,
    "scand": get_image_path,
    "scand_cleaned": get_image_path,
    "racer": get_image_path,
    "uw": get_image_path,
    "uw_cleaned": get_image_path,
    "tartan": get_image_path,
    "tartan_cleaned": get_image_path,
    "carla": get_image_path,
    "arl": get_image_path,
    "husky": get_image_path,
    "carla_cil": get_image_path,
    "carla_intvns": get_image_path,
    "sidewalk_land_individual": get_image_path,
    "eastlot": get_image_path
}

# Average distance between consecutive waypoints, measured per dataset.
# Keys are the same dataset names used in image_path_func.
# NOTE(review): the unit is presumably meters (the name says "metric") - confirm.
metric_waypoint_spacings = {
    "rccar": 0.06,
    "recon": 0.25,
    "go_stanford": 0.12,
    "raw_go_stanford": 0.12,
    "scand": 0.38,
    "scand_cleaned": 0.38,
    "racer": 0.38,
    "uw": 0.35,
    "uw_cleaned": 0.37,
    "tartan": 0.72,
    "tartan_cleaned": 0.79,
    "carla": 1.56,
    "arl": 1.53,
    "husky": 0.19,
    "carla_cil": 1.27,
    "carla_intvns": 1.39,
    "sidewalk_land_individual": 0.63,
    "eastlot": 1.37
}


def yaw_rotmat(yaw):
    """
    Build the 3x3 matrix of a rotation by `yaw` about the z axis
    Args:
        yaw: rotation angle in radians
    Returns:
        np.ndarray: (3, 3) array of dtype object laid out as
            [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]]
    """
    cos_yaw = np.cos(yaw)
    sin_yaw = np.sin(yaw)
    rows = [
        [cos_yaw, -sin_yaw, 0.0],
        [sin_yaw, cos_yaw, 0.0],
        [0.0, 0.0, 1.0],
    ]
    # dtype=object is intentional ("get rid of warning" in the original);
    # presumably yaw is sometimes a 1-element array, which would otherwise
    # make the nested list ragged - TODO confirm before switching to float.
    return np.array(rows, dtype=object)


def rotate_to_local(positions, curr_pos, curr_yaw):
    """
    Express positions relative to a pose located at curr_pos with heading curr_yaw
    Args:
        positions: array whose last dimension holds 2 or 3 coordinates
        curr_pos: position subtracted from every entry of positions
        curr_yaw: yaw angle passed to yaw_rotmat
    Returns:
        (positions - curr_pos) right-multiplied by the yaw rotation matrix
    Raises:
        ValueError: if the last dimension of positions is neither 2 nor 3
    """
    num_dims = positions.shape[-1]
    if num_dims not in (2, 3):
        raise ValueError(
            "positions must have 2 or 3 coordinates in the last dimension, "
            f"got {num_dims}"
        )

    rotmat = yaw_rotmat(curr_yaw)
    if num_dims == 2:
        # Planar points only need the upper-left 2x2 rotation block.
        rotmat = rotmat[:2, :2]

    # Right-multiplying row vectors by R(yaw) applies the inverse rotation,
    # i.e. it maps the translated points into the frame rotated by curr_yaw.
    return (positions - curr_pos).dot(rotmat)


def calculate_deltas(waypoints: torch.Tensor) -> torch.Tensor:
    """
    Convert a sequence of waypoints into per-step differences
    Args:
        waypoints (torch.Tensor): 2-D tensor with one waypoint per row; the
            first waypoint is differenced against an all-zero origin
    Returns:
        torch.Tensor: row i holds waypoints[i] - waypoints[i - 1]; when there
            are more than 2 columns the result is passed through
            calculate_sin_cos before being returned
    """
    num_params = waypoints.shape[1]
    # Create the origin on the same device as the input; otherwise
    # concatenation fails for tensors that are not on the default device.
    origin = torch.zeros(1, num_params, device=waypoints.device)
    prev_waypoints = torch.cat((origin, waypoints[:-1]), dim=0)
    deltas = waypoints - prev_waypoints
    if num_params > 2:
        return calculate_sin_cos(deltas)
    return deltas


def calculate_sin_cos(waypoints: torch.Tensor) -> torch.Tensor:
    """
    Replace the angle column of each waypoint with its cosine and sine
    Args:
        waypoints (torch.Tensor): 2-D tensor with exactly 3 columns; the third
            column is treated as an angle in radians
    Returns:
        torch.Tensor: 4-column tensor [col0, col1, cos(col2), sin(col2)]
    """
    assert waypoints.shape[1] == 3
    leading = waypoints[:, :2]
    heading = waypoints[:, 2]
    # Cast back to the input dtype so the result matches writing the values
    # into a zeros_like buffer of the first two columns.
    trig = torch.stack((torch.cos(heading), torch.sin(heading)), dim=1).to(
        leading.dtype
    )
    return torch.cat((leading, trig), dim=1)

def img_path_to_data(
    path: str, transform: transforms, aspect_ratio: float = IMAGE_ASPECT_RATIO
) -> "tuple[torch.Tensor, torch.Tensor]":
    """
    Load an image from a path, center-crop it and return two versions of it
    Args:
        path (str): path to the image
        transform (transforms): callable applied to the cropped image
        aspect_ratio (float): width / height ratio to crop the image to
    Returns:
        tuple: (viz_img, transf_img) where viz_img is the cropped image resized
            to VISUALIZATION_IMAGE_SIZE and converted to a tensor, and
            transf_img is whatever `transform` returns for the cropped image
    """
    # Image.open is lazy and keeps the file open; the context manager releases
    # the handle once both outputs have been computed.
    with Image.open(path) as img:
        w, h = img.size
        # center_crop takes the output size as (height, width).
        if w > h:
            crop_size = (h, int(h * aspect_ratio))  # keep height, fix the width
        else:
            crop_size = (int(w / aspect_ratio), w)  # keep width, fix the height
        cropped = TF.center_crop(img, crop_size)  # crop to the right ratio
        viz_img = TF.to_tensor(TF.resize(cropped, VISUALIZATION_IMAGE_SIZE))
        transf_img = transform(cropped)
    return viz_img, transf_img

class RandomizedClassBalancer:
    def __init__(self, classes: Iterable) -> None:
        """
        A class balancer that will sample classes randomly, but will prioritize classes that have been sampled less
        Args:
            classes (Iterable): The classes to balance (must be hashable; duplicates collapse)
        """
        # Number of times each class has been returned by sample().
        self.counts = {c: 0 for c in classes}

    def sample(self, class_filter_func=None) -> Any:
        """
        Sample the softmax of the negative logits to prioritize classes that have been sampled less
        Args:
            class_filter_func: optional predicate; only classes for which it
                returns a truthy value are candidates
        Returns:
            Any: the chosen class (its count is incremented), or None when no
                class passes the filter
        """
        if class_filter_func is None:
            keys = list(self.counts)
        else:
            keys = [k for k in self.counts if class_filter_func(k)]
        if not keys:
            return None  # no valid classes to sample

        # The minimum is taken over ALL classes (not just the filtered ones)
        # and is loop-invariant, so compute it once instead of once per key.
        min_count = min(self.counts.values())
        values = [-(self.counts[k] - min_count) for k in keys]
        p = F.softmax(torch.Tensor(values), dim=0).detach().cpu().numpy()
        choice_i = np.random.choice(len(keys), p=p)
        choice = keys[choice_i]
        self.counts[choice] += 1
        return choice

    def __str__(self) -> str:
        """Return one "<class>: <count>" line per class, each newline-terminated."""
        return "".join(f"{c}: {n}\n" for c, n in self.counts.items())