# generic imports
import time
import numpy as np
from PIL import Image
import os
import sys
import io
from dataclasses import dataclass
import time
import atexit
import logging
import argparse
import torch
# import cv2

from collections import deque
from deployment.src.utils import load_model, transform_images

# custom imports
from agentlace.action import ActionClient, ActionConfig

import time
import numpy as np
from PIL import Image
from io import BytesIO
import yaml


# Names of the actions this client may send to the robot server.
# NOTE(review): Model.reset() also sends a "move_marker" action when is_sim
# is set, which is not listed here — confirm the server accepts it.
_ACTION_KEYS = [
    "action_vw",
    "action_pose",
    "reset",
    "dock",
    "undock",
    "new_goal",
    "q_vals",
]

# Fields expected in every observation returned by the robot server.
_OBSERVATION_KEYS = [
    # Raw sensor
    "image",
    "imu_accel",
    "imu_gyro",
    "odom_pose",
    "linear_velocity",
    "angular_velocity",
    "cliff",
    "crash",
    "stall",
    "keepout",
    "position",
    "orientation",
    "pose_std",
    "action_state_source",
    "last_action_linear",
    "last_action_angular",
]

# Shared client/server contract handed to ActionClient in Model.__init__.
action_config = ActionConfig(
    port_number=1111,
    action_keys=_ACTION_KEYS,
    observation_keys=_OBSERVATION_KEYS,
)

# Per-channel mean / standard deviation used by normalize_image() and undone
# again when debug images are written out. The values are the standard
# ImageNet statistics (presumably in RGB channel order — confirm against the
# image source).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

# NOTE(review): the constants below are not referenced anywhere in the visible
# part of this file; their meaning is inferred from the names and the original
# inline remarks only — confirm before relying on them.
WAYPOINT_SPACING = 0.15
ANGLE_SCALE = 1 # np.pi / 9 (previous value, kept for reference)
X_OFFSET = -1 

MAX_TRAJ_LEN = 100  # 3 times a second * 30 seconds = 90 long
STEPS_TRY = 60
GOAL_DIST = STEPS_TRY // 2  # 4 - 10 # normal around this with 5 std

## START goal_task.py
def pose_distance(
    position: np.ndarray,
    quaternion: np.ndarray,
    goal_position: np.ndarray,
    goal_quaternion: np.ndarray,
    orientation_weight: float = 1.0,
):
    """Return a combined position + orientation distance between two poses.

    The result is ``d_pos + orientation_weight * d_quat`` where ``d_pos`` is
    the Euclidean distance between the positions and ``d_quat`` is the
    rotation angle (radians) between the two quaternions.  All inputs
    broadcast over leading axes, so a single pose can be compared against a
    whole array of goals at once; the last axis holds the vector components.

    Args:
        position: Current position, last axis = coordinates.
        quaternion: Current orientation quaternion (need not be normalized).
        goal_position: Goal position(s), broadcastable against ``position``.
        goal_quaternion: Goal quaternion(s), broadcastable against
            ``quaternion`` (need not be normalized).
        orientation_weight: Multiplier applied to the angular term.

    Returns:
        The weighted distance, with the last axis reduced away.
    """
    # Normalize so the dot product is the cosine of half the rotation angle.
    q1 = quaternion / np.linalg.norm(quaternion, axis=-1, keepdims=True)
    q2 = goal_quaternion / np.linalg.norm(goal_quaternion, axis=-1, keepdims=True)

    # |q1 . q2| treats q and -q as the same rotation.  Rounding can push the
    # value a hair above 1 for (near-)identical orientations, which would make
    # arccos return NaN, so clamp it into arccos' domain first.
    cos_half_angle = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)
    d_quat = 2 * np.arccos(cos_half_angle)

    # Straight-line distance between the positions.
    d_pos = np.linalg.norm(position - goal_position, axis=-1)
    return d_pos + orientation_weight * d_quat

def close_enough(
    position: np.ndarray,
    quaternion: np.ndarray,
    goal_position: np.ndarray,
    goal_quaternion: np.ndarray,
    orientation_weight: float = 1.0,
):
    """Return True when a pose is within fixed tolerances of a goal pose.

    The pose counts as close when every position coordinate is within 1.5 of
    the goal's and the rotation angle between the two quaternions is at most
    45 degrees.

    Args:
        position: Current position.
        quaternion: Current orientation quaternion (need not be normalized).
        goal_position: Goal position, same shape as ``position``.
        goal_quaternion: Goal quaternion (need not be normalized).
        orientation_weight: Unused; kept so the signature mirrors
            ``pose_distance``.

    Returns:
        ``True`` if both the position and the orientation checks pass,
        ``False`` otherwise.
    """
    # Position check: per-axis box of half-width 1.5 around the goal.
    if (np.abs(position - goal_position) > 1.5).any():
        return False

    # Orientation check: angle between the normalized quaternions.
    q1 = quaternion / np.linalg.norm(quaternion, axis=-1, keepdims=True)
    q2 = goal_quaternion / np.linalg.norm(goal_quaternion, axis=-1, keepdims=True)
    # Clamp before arccos: rounding can yield a dot product slightly above 1
    # for near-identical orientations, which would otherwise produce NaN.
    cos_half_angle = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)
    d_quat = 2 * np.arccos(cos_half_angle)
    if np.abs(d_quat) > np.pi / 4:  # more than 45 degrees off
        return False
    return True
    

class TrainingTask:
    """Goal selection and episode bookkeeping on top of a recorded goal dataset.

    The dataset is an ``.npz`` archive that must provide the arrays
    ``data/position``, ``data/orientation`` and ``data/image``, all indexed by
    the same goal index.  Goals are either walked through sequentially
    (``step_by_one=True``) or sampled near the current pose.
    """

    def __init__(self, goal_file: str, step_by_one: bool):
        """Load the goal dataset.

        Args:
            goal_file: Path of the ``.npz`` file holding the goals.
            step_by_one: If True, goals are visited in dataset order starting
                at index 1; otherwise they are sampled near the current pose.
        """
        # Load goal file as npz
        self.goal_data = np.load(goal_file)
        # Arrays already pulled out of the archive; indexing the npz object
        # re-reads the array from disk every time, which is far too slow to do
        # once per control tick.
        self._array_cache = {}
        self.goal_idx = None
        self._goal_base_idx = None
        self.last_reset_time = None

        self.timeout = 300.0
        self.threshold = 0.2
        self.is_first = True
        self.step_by_one = step_by_one

    def _array(self, key: str) -> np.ndarray:
        """Return the dataset array stored under ``key``, reading it only once."""
        # setdefault keeps this working for instances built without __init__.
        cache = self.__dict__.setdefault("_array_cache", {})
        if key not in cache:
            cache[key] = self.goal_data[key]
        return cache[key]

    def update(self, position: np.ndarray, quaternion: np.ndarray, crashed: bool):
        """Evaluate the current pose against the active goal.

        A goal is selected first if none is active yet.

        Args:
            position: Current position.
            quaternion: Current orientation quaternion.
            crashed: Whether the robot currently reports a crash condition.

        Returns:
            A dict with the active ``goal`` and the flags ``reached_goal``,
            ``is_first``, ``is_terminal``, ``is_last``, ``timeout`` and
            ``crash``.
        """
        if self.goal_idx is None:
            self.select_goal_idx(position, quaternion)

        current_goal = self.get_goal()
        goal_position = current_goal["position"]
        goal_quaternion = current_goal["orientation"]

        reached = close_enough(position, quaternion, goal_position, goal_quaternion)
        timeout = time.time() - self.last_reset_time > self.timeout

        # is_first is reported exactly once after each reset().
        was_first = self.is_first
        self.is_first = False
        return {
            "goal": current_goal,
            "reached_goal": reached,
            "is_first": was_first,
            # Terminal = the episode ended by reaching the goal or crashing,
            # not merely by running out of time.
            "is_terminal": (reached or crashed) and not timeout,
            # Last step of the episode for any reason; the caller has to start
            # a new episode afterwards.
            "is_last": reached or crashed or timeout,
            "timeout": timeout,
            "crash": crashed,
        }

    def select_goal_idx(self, position: np.ndarray, quaternion: np.ndarray):
        """Pick the next goal index and restart the timeout clock.

        In ``step_by_one`` mode the pose arguments are ignored and the goal
        simply advances by one (starting at index 1).  Running past the end of
        the dataset is only reported on stdout here; the next ``get_goal``
        call will then fail with ``IndexError``.

        Otherwise a base index is drawn from the (up to) 25 dataset poses
        nearest to the given pose, weighted by ``exp(-distance)``, and the
        goal is placed an exponentially distributed number of steps after it,
        wrapping around the end of the dataset.

        Raises:
            ValueError: In sampling mode, if the dataset is empty or the
                distances do not form a 1-D array over the goals.
        """
        if self.step_by_one:
            if self.goal_idx is None:
                self.goal_idx = 1
                self._goal_base_idx = 0
            else:
                self._goal_base_idx = self.goal_idx
                self.goal_idx += 1

            print("Goal IDX is now", self.goal_idx)
            if self.goal_idx >= self._array("data/position").shape[0]:
                print("ERROR: out of goals")

        else:
            goal_positions = self._array("data/position")
            goal_quaternions = self._array("data/orientation")
            num_goals = len(goal_positions)
            if num_goals == 0:
                raise ValueError("goal dataset is empty")

            # Never ask for more candidates than the dataset holds.
            topk = min(25, num_goals)

            distances = pose_distance(
                position, quaternion, goal_positions, goal_quaternions
            )
            # kth = topk - 1 puts the topk smallest distances in front and is
            # valid even when topk equals the dataset size.
            best_idcs = np.argpartition(distances, topk - 1)[:topk]
            if best_idcs.shape != (topk,):
                raise ValueError(
                    f"expected {topk} candidate goals, got shape {best_idcs.shape}"
                )

            # Softmax over negative distance; subtracting the minimum keeps at
            # least one weight at 1 so the normalization cannot divide by 0.
            nearest = distances[best_idcs]
            probs = np.exp(-(nearest - np.min(nearest)))
            probs /= np.sum(probs)

            self._goal_base_idx = int(np.random.choice(best_idcs, p=probs))
            self.goal_idx = (
                self._goal_base_idx + int(np.random.exponential() * 10)
            ) % num_goals
            assert isinstance(self.goal_idx, int), f"goal_idx is {self.goal_idx} ({type(self.goal_idx)})"

        self.last_reset_time = time.time()

    def reset(self, position, quaternion):
        """Start a new episode and select a new goal.

        Args:
            position: Current position; if empty, a random dataset pose is
                used as the starting pose instead.
            quaternion: Current orientation quaternion.

        Returns:
            The ``(position, quaternion)`` pair the goal was selected from.
        """
        self.is_first = True

        if len(position) == 0:
            start_idx = np.random.randint(0, len(self._array("data/position")))
            position = self._array("data/position")[start_idx]
            quaternion = self._array("data/orientation")[start_idx]

        self.select_goal_idx(position, quaternion)
        return position, quaternion

    def get_goal(self):
        """Return the active goal's image, pose and sampling info.

        Returns:
            A dict with ``image``, ``position`` and ``orientation`` of the
            goal, plus ``sample_info`` describing the base pose the goal was
            derived from and the index ``offset`` between the two.

        Raises:
            ValueError: If no goal has been selected yet.
        """
        if self.goal_idx is None:
            raise ValueError("Goal not selected yet!")

        positions = self._array("data/position")
        orientations = self._array("data/orientation")

        sample_info = {
            "position": positions[self._goal_base_idx],
            "orientation": orientations[self._goal_base_idx],
            "offset": np.float32(self.goal_idx - self._goal_base_idx),
        }

        return {
            "image": self._array("data/image")[self.goal_idx],
            "position": positions[self.goal_idx],
            "orientation": orientations[self.goal_idx],
            "sample_info": sample_info,
        }

    def reset_timer(self):
        """Restart the timeout clock without changing the goal."""
        self.last_reset_time = time.time()


# END goal_task.py

def normalize_image(image: np.ndarray) -> np.ndarray:
    """Scale 0-255 pixel values to [0, 1] and standardize each channel.

    The channel statistics come from IMAGENET_MEAN / IMAGENET_STD and are
    broadcast along the last axis of ``image``.
    """
    unit_range = image / 255.0
    centered = unit_range - IMAGENET_MEAN
    return centered / IMAGENET_STD


# from goal_task.py
def _yaw(quat):
    return np.arctan2(
        2.0 * (quat[3] * quat[2] + quat[0] * quat[1]),
        1.0 - 2.0 * (quat[1] ** 2 + quat[2] ** 2),
    )

class Model():
    """Runs a goal-conditioned policy against a robot over an ActionClient.

    Each tick fetches an observation, checks it against the active goal via
    ``TrainingTask`` and either sends a velocity command predicted by the
    loaded model or ends the episode (resetting the robot on crash/timeout)
    and moves on to the next goal.
    """

    def __init__(self, server_ip: str,
                 save_dir,
                 max_time,
                 goal_dir,
                 is_sim: bool,
                 keep_obs_buffer: bool):
        """Connect to the robot server and load the policy and goal dataset.

        Args:
            server_ip: Address of the robot action server.
            save_dir: Accepted for interface compatibility; not used here.
            max_time: Wall-clock limit in seconds after which ``tick`` stops
                the robot and exits the process, or None for no limit.
            goal_dir: Path of the ``.npz`` goal file given to TrainingTask.
            is_sim: If True, ``reset`` also moves a goal marker.
            keep_obs_buffer: If True, observations are recorded for the
                model's context window; without it no action other than zero
                is ever produced.
        """
        self.max_time = max_time
        self.start_time = time.time()
        self.last_saved = self.start_time
        self.tick_rate = 3
        self.is_sim = is_sim
        self.keep_obs_buffer = keep_obs_buffer

        self.action_client = ActionClient(
            server_ip,
            action_config,
        )

        # Config and checkpoint paths are relative to the working directory.
        with open("train/config/vint.yaml", "r") as f:
            config = yaml.safe_load(f)
        self.model = load_model(
            "vint.pth",
            config,
            "cuda",
        )
        self.model.eval()

        # Only the newest context_size + 1 (observation, goal) pairs are kept.
        self.obs_buffer = deque(maxlen=self.model.context_size + 1)

        # Goals are visited in dataset order (step_by_one=True).
        self.task = TrainingTask(goal_dir, True)

    def run(self):
        """Initialize episode state and call ``tick`` at ``tick_rate`` Hz forever.

        Only returns via ``sys.exit`` inside ``tick`` once ``max_time`` is
        exceeded.
        """
        self.loop_time = 1 / self.tick_rate
        start_time = time.time()

        self.just_crashed = False
        self.latest_action = np.array([0, 0], dtype=np.float32)
        self.traj_len = 0
        self.curr_goal = None  # has image, position, and steps
        self.trajs = 0
        self.timeouts = 0
        self.crashs = 0
        self.start_pos = None
        self.reached = True  # so we start with a new goal

        while True:
            # Sleep off whatever is left of the tick period since the
            # previous tick started.
            elapsed = time.time() - start_time
            if elapsed < self.loop_time:
                time.sleep(self.loop_time - elapsed)
            start_time = time.time()

            self.tick()

    def int_image(self, img):
        """Undo ``normalize_image`` and return the image as uint8."""
        return np.asarray((img * IMAGENET_STD + IMAGENET_MEAN) * 255, dtype=np.uint8)

    @staticmethod
    def _save_debug_image(chw_image, path):
        """Write a normalized channel-first image to ``path`` for inspection."""
        hwc = np.transpose(chw_image, (1, 2, 0))
        pixels = np.clip(hwc * IMAGENET_STD + IMAGENET_MEAN, 0, 1) * 255
        # np.cast was removed in NumPy 2.0; asarray(...).astype is what it did.
        Image.fromarray(np.asarray(pixels).astype(np.uint8)).save(path)

    def get_info_with_context(self):
        """Build the model inputs from the observation buffer.

        Returns:
            ``(obs_imgs, goal_img)`` as produced by ``transform_images``, or
            ``(None, None)`` while the buffer does not yet hold a full
            context window.
        """
        if len(self.obs_buffer) != self.model.context_size + 1:
            return None, None

        obs_list = list(self.obs_buffer)  # each of these has an obs, goal.

        obs_imgs = [
            Image.open(io.BytesIO(obs[0]["image"]))
            for obs in obs_list
        ]
        # The goal stored with the newest observation is the one to pursue.
        goal_img = Image.fromarray(obs_list[-1][1]["int_image"])

        obs_imgs = transform_images(obs_imgs, [85, 64])
        goal_img = transform_images([goal_img], [85, 64])

        # Dump the newest observation and the goal for visual debugging.
        self._save_debug_image(obs_imgs[0, -3:], "test.png")
        self._save_debug_image(goal_img[0], "goal.png")

        assert obs_imgs.shape == (1, 3*(self.model.context_size + 1), 64, 85), obs_imgs.shape
        assert goal_img.shape == (1, 3, 64, 85), goal_img.shape

        return obs_imgs, goal_img

    def reset_goal(self, pos, quat):
        """Select a new goal and announce it to the robot server.

        Args:
            pos: Current robot position.
            quat: Current robot orientation quaternion.
        """
        self.task.reset(pos, quat)
        self.curr_goal = self.task.get_goal()

        goal_image = self.curr_goal["image"]
        if isinstance(goal_image, (bytes, bytearray)):
            # Encoded image: decode to a pixel array, which is what both
            # normalize_image and Image.fromarray below require.
            goal_image = np.asarray(Image.open(io.BytesIO(goal_image)))

        self.curr_goal["int_image"] = goal_image
        self.curr_goal["image"] = normalize_image(goal_image)

        self.curr_goal["yaw"] = _yaw(self.curr_goal["orientation"])

        # The server receives the goal image JPEG-encoded as a uint8 array.
        goal_jpeg = io.BytesIO()
        Image.fromarray(self.curr_goal["int_image"]).save(goal_jpeg, format="JPEG")
        goal_img_bytes = np.frombuffer(goal_jpeg.getvalue(), dtype=np.uint8)

        goal_pose = {
            "position": self.curr_goal["position"].astype(float),
            "orientation": self.curr_goal["orientation"].astype(float),
            "image": goal_img_bytes,
        }

        self.action_client.act("new_goal", goal_pose)

    def reset(self):
        """Ask the robot to reset and block until it reports completion."""
        self.just_crashed = True
        self.traj_len = 0

        print("resetting")
        # The reset action is re-sent until the server stops reporting
        # "running", which makes this call blocking.
        res = {"running": True, "reason": "starting"}
        while res["running"]:
            res = self.action_client.act(
                "reset", {"position": 0, "orientation": 0})
        print("reset done")

        # Time spent resetting must not count against the episode timeout.
        self.task.reset_timer()

        if self.is_sim:
            # NOTE(review): "move_marker" is not among action_config's
            # action_keys — confirm the simulator server accepts it.
            marker_position = [self.curr_goal["position"][0],
                               self.curr_goal["position"][1], 2]
            self.action_client.act(
                "move_marker", {"position": marker_position})

    def take_action(self, obs):
        """Predict and send one velocity command.

        Sends a zero command while the context window is still filling up.

        Args:
            obs: Latest observation; unused, the model input is taken from
                the observation buffer instead.
        """
        self.just_crashed = False
        self.traj_len += 1

        obs_imgs, goal_img = self.get_info_with_context()
        if obs_imgs is not None:
            with torch.no_grad():
                _dist, action = self.model(obs_imgs.cuda(), goal_img.cuda())
                action = action.cpu().numpy()
                action = action[0, 2, :2] # 2 is the middle waypoint
                action[0] = action[0] * 0.05
                action[1] = action[1] * 1
        else:
            action = np.zeros((2,))

        self.action_client.act("action_vw", action)

    def _step_episode(self, obs):
        """Advance the running episode by one observation."""
        if self.keep_obs_buffer:
            self.obs_buffer.append((obs, self.curr_goal))

        self.task_result = self.task.update(
            obs["position"], obs["orientation"], obs["crash"] or obs["keepout"])
        result = self.task_result

        # A crash right after a reset is ignored until one action was taken.
        fresh_crash = result["crash"] and not self.just_crashed
        needs_reset = fresh_crash or result["timeout"]

        if not (result["reached_goal"] or needs_reset):
            self.take_action(obs)
            return

        print(f"\nResetting, reached {self.task_result['reached_goal']}" +
              f" timeout {self.task_result['timeout']} " +
              f" crash {obs['crash']}" +
              f" keepout {obs['keepout']} ")

        if needs_reset:
            self.reset()
            # The pose changed during the reset; wait for a fresh observation.
            obs = self.action_client.obs()
            while obs is None:
                obs = self.action_client.obs()

        self.trajs += 1
        self.crashs += int(self.task_result['crash'])
        self.timeouts += int(self.task_result['timeout'])
        self.reset_goal(obs["position"], obs["orientation"])
        self.task.reset_timer()

    def _enforce_time_limit(self):
        """Stop the robot and exit the process once ``max_time`` has elapsed."""
        if self.max_time is None:
            return
        if time.time() - self.start_time > self.max_time:
            self.action_client.act(
                "action_vw", np.array([0, 0]))
            print(
                f"Killing model deployment after {time.time() - self.start_time} seconds.")
            print(
                f"Completed {self.trajs} trajectories with {self.crashs} crashes and {self.timeouts} timeouts.")
            sys.exit()

    def tick(self):
        """Run one control step: observe, then act or roll over the episode."""
        obs = self.action_client.obs()

        if obs is not None:
            if self.curr_goal is None:
                # Very first observation: just pick the initial goal.
                self.reset_goal(obs["position"], obs["orientation"])
                self.curr_obs_img = np.asarray(Image.open(io.BytesIO(obs["image"])))
            else:
                self._step_episode(obs)

        self._enforce_time_limit()


if __name__ == "__main__":
    # Only warnings and above are shown from library loggers.
    logging.basicConfig(level=logging.WARNING)

    # Command-line options: (flag, value type, help text).
    cli = argparse.ArgumentParser(description='My Python Script')
    for flag, value_type, help_text in (
        ('--data_save_dir', str, 'Where to save collected data'),
        ('--server_ip', str, 'Where to connect to robot server'),
        ('--max_time', int, 'How long to run for'),
        ('--goal_dir', str, "npz to load goals from "),
    ):
        cli.add_argument(flag, type=value_type, help=help_text)
    options = cli.parse_args()

    # Real-robot deployment that keeps the observation context buffer.
    deployment = Model(
        server_ip=options.server_ip,  # "localhost",
        save_dir=options.data_save_dir,
        max_time=options.max_time,
        goal_dir=options.goal_dir,
        is_sim=False,
        keep_obs_buffer=True,
    )
    deployment.run()