import time
import chex
import numpy as np
from PIL import Image
from io import BytesIO


def pose_distance(
    position: np.ndarray,
    quaternion: np.ndarray,
    goal_position: np.ndarray,
    goal_quaternion: np.ndarray,
    orientation_weight: float = 1.0,
):
    """Return a combined position + orientation distance between two poses.

    The result is ``||position - goal_position|| + orientation_weight * angle``,
    where ``angle`` is the rotation angle in radians (in ``[0, pi]``) between
    the two orientations. All reductions run over the last axis, so the
    inputs may carry leading batch dimensions and broadcast against each
    other (e.g. a single pose against a stack of goal poses).

    Args:
        position: Current position; components along the last axis.
        quaternion: Current orientation quaternion along the last axis. It is
            normalized here, so it only needs to be non-zero.
        goal_position: Goal position(s), broadcastable against ``position``.
        goal_quaternion: Goal quaternion(s), broadcastable against
            ``quaternion``. Normalized here as well.
        orientation_weight: Factor applied to the angular term before it is
            added to the positional term.

    Returns:
        The combined distance, with the last axis of the inputs reduced away.
    """
    # Normalize so the dot product below is the cosine of half the rotation angle.
    q1 = quaternion / np.linalg.norm(quaternion, axis=-1, keepdims=True)
    q2 = goal_quaternion / np.linalg.norm(goal_quaternion, axis=-1, keepdims=True)

    # The absolute value treats q and -q as the same rotation. Floating-point
    # rounding can push the dot product of two unit quaternions marginally
    # above 1, where arccos returns NaN, so clip it back into the valid domain.
    alignment = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)
    d_quat = 2 * np.arccos(alignment)

    # Euclidean position distance
    d_pos = np.linalg.norm(position - goal_position, axis=-1)
    return d_pos + orientation_weight * d_quat

def close_enough(
    position: np.ndarray,
    quaternion: np.ndarray,
    goal_position: np.ndarray,
    goal_quaternion: np.ndarray,
    orientation_weight: float = 1.0,
    max_position_error: float = 1.5,
    max_angle_error: float = np.pi / 4,
):
    """Return whether a single pose lies within tolerance of a goal pose.

    The pose is accepted only if every position component is within
    ``max_position_error`` of the goal (a per-axis check, not a Euclidean
    one) and the rotation angle between the two orientations does not
    exceed ``max_angle_error``.

    Args:
        position: Current position.
        quaternion: Current orientation quaternion; normalized here, so it
            only needs to be non-zero.
        goal_position: Goal position, same layout as ``position``.
        goal_quaternion: Goal orientation quaternion; normalized here.
        orientation_weight: Unused; kept so the signature stays parallel to
            ``pose_distance``.
        max_position_error: Largest allowed absolute per-axis position error.
        max_angle_error: Largest allowed rotation angle in radians
            (default: 45 degrees).

    Returns:
        ``True`` if both checks pass, ``False`` otherwise.
    """
    # check position (component-wise)
    if (np.abs(position - goal_position) > max_position_error).any():
        return False

    # check quaternion
    q1 = quaternion / np.linalg.norm(quaternion, axis=-1, keepdims=True)
    q2 = goal_quaternion / np.linalg.norm(goal_quaternion, axis=-1, keepdims=True)
    # Rounding can push |q1 . q2| marginally above 1, which is outside the
    # domain of arccos (NaN + RuntimeWarning); clip it back into [0, 1].
    alignment = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)
    d_quat = 2 * np.arccos(alignment)
    if d_quat > max_angle_error:
        return False

    return True
    

class TrainingTask:
    """Goal selection and episode bookkeeping for a goal-reaching task.

    Goals come from an archive opened with ``np.load`` that provides the
    arrays ``data/position``, ``data/orientation`` and ``data/image``, all
    indexed by goal along their first axis.
    """

    def __init__(self, goal_file: str):
        """Open the goal archive and initialise the episode state.

        Args:
            goal_file: Path handed to ``np.load``.
        """
        # Load goal file as npz
        self.goal_data = np.load(goal_file)
        # An NpzFile re-reads the whole array from disk on every key access,
        # so arrays are memoized here the first time they are requested.
        self._array_cache = {}

        self.goal_idx = None  # index of the current goal, None until selected
        self._goal_base_idx = None  # index of the nearby pose the goal was sampled from
        self.last_reset_time = None  # time.time() of the last goal selection / timer reset

        self.timeout = 45.0  # seconds since last_reset_time before `timeout` is reported
        self.threshold = 0.2  # not read by any method of this class
        self.is_first = True

    def _get_array(self, key: str) -> np.ndarray:
        """Return ``self.goal_data[key]``, reading it from the archive only once."""
        if key not in self._array_cache:
            self._array_cache[key] = self.goal_data[key]
        return self._array_cache[key]

    def update(self, position: np.ndarray, quaternion: np.ndarray, crashed: bool):
        """Evaluate the current pose against the current goal.

        A goal is selected first if none has been chosen yet.

        Args:
            position: Current position.
            quaternion: Current orientation quaternion.
            crashed: Whether the caller detected a crash on this step.

        Returns:
            A dict with the keys ``goal`` (see ``get_goal``), ``reached_goal``,
            ``is_first``, ``is_terminal``, ``is_last``, ``timeout`` and
            ``crash``.
        """
        if self.goal_idx is None:
            self.select_goal_idx(position, quaternion)

        current_goal = self.get_goal()
        reached = close_enough(
            position, quaternion, current_goal["position"], current_goal["orientation"]
        )
        timeout = time.time() - self.last_reset_time > self.timeout

        # `is_first` is reported once and then cleared until the next reset().
        was_first = self.is_first
        self.is_first = False
        return {
            "goal": current_goal,
            "reached_goal": reached,
            "is_first": was_first,
            # Ended by reaching the goal or crashing; never set on a step
            # where the timeout has also expired.
            "is_terminal": (reached or crashed) and not timeout,
            # Ended for any reason: goal reached, crash, or timeout.
            "is_last": reached or crashed or timeout,
            "timeout": timeout,
            "crash": crashed,
        }

    def select_goal_idx(self, position: np.ndarray, quaternion: np.ndarray):
        """Sample a new goal near the given pose and restart the timeout clock.

        A base index is drawn from the (up to) 25 dataset poses closest to the
        given pose, weighted by ``exp(-distance)``. The goal index is that
        base index plus an exponentially distributed offset (mean about 10
        entries), wrapped around the end of the dataset.

        Raises:
            ValueError: If the goal dataset is empty.
        """
        # Find the distance to each point in the dataset, and sample randomly
        # from the `topk` closest ones.
        topk = 25
        goal_positions = self._get_array("data/position")
        goal_quaternions = self._get_array("data/orientation")
        num_goals = len(goal_positions)
        if num_goals == 0:
            raise ValueError("Goal dataset is empty!")

        distances = pose_distance(
            position, quaternion, goal_positions, goal_quaternions
        )

        # argpartition requires kth < len(distances); with a small dataset
        # every entry is a candidate.
        num_candidates = min(topk, num_goals)
        if num_candidates < num_goals:
            best_idcs = np.argpartition(distances, num_candidates)[:num_candidates]
        else:
            best_idcs = np.arange(num_goals)
        chex.assert_shape(best_idcs, [num_candidates])

        # Shifting by the minimum leaves the normalized weights unchanged but
        # prevents every exp() from underflowing to 0 (-> 0/0) when all
        # candidates are far away.
        candidate_distances = distances[best_idcs]
        probs = np.exp(-(candidate_distances - np.min(candidate_distances)))
        probs /= np.sum(probs)

        self._goal_base_idx = int(np.random.choice(best_idcs, p=probs))
        self.goal_idx = (
            self._goal_base_idx + int(np.random.exponential() * 10)
        ) % num_goals
        assert isinstance(self.goal_idx, int), f"goal_idx is {self.goal_idx} ({type(self.goal_idx)})"

        self.last_reset_time = time.time()

    def reset(self, position, quaternion):
        """Start a new episode and select a goal for it.

        Args:
            position: Starting position. If it is ``None`` or empty, a random
                dataset pose is used as the starting pose instead.
            quaternion: Starting orientation; replaced together with
                ``position`` when a random dataset pose is used.

        Returns:
            The ``(position, quaternion)`` pair the goal was selected for.
        """
        self.is_first = True

        if position is None or len(position) == 0:
            goal_positions = self._get_array("data/position")
            start_idx = np.random.randint(0, len(goal_positions))
            position = goal_positions[start_idx]
            quaternion = self._get_array("data/orientation")[start_idx]

        self.select_goal_idx(position, quaternion)
        return position, quaternion

    def get_goal(self):
        """Return the image and pose of the current goal.

        Returns:
            A dict with ``image``, ``position`` and ``orientation`` of the goal
            entry, plus ``sample_info`` holding the ``position`` and
            ``orientation`` of the base entry the goal was sampled from and
            the ``offset`` (as ``np.float32``) from base to goal.

        Raises:
            ValueError: If no goal has been selected yet.
        """
        if self.goal_idx is None:
            raise ValueError("Goal not selected yet!")

        positions = self._get_array("data/position")
        orientations = self._get_array("data/orientation")
        base_idx = self._goal_base_idx

        sample_info = {
            "position": positions[base_idx],
            "orientation": orientations[base_idx],
            # Forward distance from base to goal along the cyclic dataset, so
            # it stays non-negative when the goal index wrapped past the end.
            "offset": np.float32((self.goal_idx - base_idx) % len(positions)),
        }

        return {
            "image": self._get_array("data/image")[self.goal_idx],
            "position": positions[self.goal_idx],
            "orientation": orientations[self.goal_idx],
            "sample_info": sample_info,
        }

    def reset_timer(self):
        """Restart the timeout clock without changing the goal."""
        self.last_reset_time = time.time()
