# generic imports
import time
import numpy as np
from PIL import Image
import os
import sys
import io
from dataclasses import dataclass
import tensorflow as tf
import time
import atexit
import tensorflow_datasets as tfds
import logging
import ipdb
import argparse
# import cv2

from multinav.deploy.train.model_config import get_config

import dlimp
# ros imports
from multinav.deploy.common.trainer_bridge_common import (
    make_action_config,
    task_data_format,
)

from multinav.deploy.train.agent import Agent, RandomAgent
from multinav.deploy.robot.tasks.goal_task import TrainingTask

# custom imports
from agentlace.action import ActionClient
from agentlace.data.rlds_writer import RLDSWriter
from agentlace.data.tf_agents_episode_buffer import EpisodicTFDataStore

# jax & jaxrl
import jax
from absl import logging as absl_logging

# Per-channel RGB statistics used to standardize images (normalize_image)
# and to undo that standardization (Model.int_image).
IMAGENET_MEAN = tf.constant([0.485, 0.456, 0.406])
IMAGENET_STD = tf.constant([0.229, 0.224, 0.225])

# Scale between metric positions and the policy's waypoint units:
# get_relative_position divides position deltas by it, and Model.take_action
# multiplies the predicted linear term by it.
WAYPOINT_SPACING = 0.15
ANGLE_SCALE = 1 # multiplier on the predicted angular command (alternative value kept here: np.pi / 9)
X_OFFSET = -1  # subtracted from the predicted linear term before scaling by WAYPOINT_SPACING

# Number of most recent (observation, goal) pairs kept when Model.keep_obs_buffer is set.
BUFFER_LEN = 8

# NOTE(review): MAX_TRAJ_LEN, STEPS_TRY and GOAL_DIST are not referenced in this
# file's visible code; presumably consumed elsewhere -- verify before changing.
MAX_TRAJ_LEN = 100  # NOTE(review): 3 ticks/s * 30 s = 90, but the value is 100 -- confirm which is intended
STEPS_TRY = 60
GOAL_DIST = STEPS_TRY // 2  # = 30; older note said "4 - 10, normal around this with 5 std" -- presumably stale, confirm

def normalize_image(image: tf.Tensor, mean=None, std=None) -> tf.Tensor:
    """
    Convert a [0, 255]-valued image to float32 and standardize it per channel.

    Pixels are first scaled from [0, 255] to [0, 1], then each channel is
    shifted by ``mean`` and divided by ``std`` (ImageNet statistics by
    default). The result is therefore NOT confined to [0, 1].

    Args:
        image: image tensor/array with pixel values in [0, 255]; its last axis
            must broadcast against ``mean``/``std`` (3 channels by default).
        mean: per-channel mean; defaults to ``IMAGENET_MEAN``.
        std: per-channel standard deviation; defaults to ``IMAGENET_STD``.

    Returns:
        float32 tensor with the broadcast shape of ``image`` and the statistics.
    """
    mean = IMAGENET_MEAN if mean is None else mean
    std = IMAGENET_STD if std is None else std
    return (tf.cast(image, tf.float32) / 255.0 - mean) / std


# from dataset_transforms.py 
def get_relative_position(base_position, base_yaw, goal_position, is_negative,
                          goal_dist_threshold=0.2):
    """Express a goal position in the robot's frame, in waypoint units.

    The world-frame delta ``goal_position - base_position`` is divided by
    ``WAYPOINT_SPACING`` and rotated by ``-base_yaw``. Works for a single pose
    or for a batch of poses (leading batch axes on all inputs).

    Args:
        base_position: robot position; last axis is (x, y).
        base_yaw: robot heading in radians; a trailing singleton axis
            (e.g. shape ``(B, 1)`` or ``(1,)``) is squeezed away.
        goal_position: goal position, same layout as ``base_position``.
        is_negative: boolean (broadcastable); where True the true goal vector
            is replaced by a random one drawn from N(0, 10^2) per component.
        goal_dist_threshold: lower bound on the norm used to build the
            direction / inverse-magnitude features, avoiding division by ~0.

    Returns:
        ``(goal_vector, goal_vector_magdir)`` where ``goal_vector`` has shape
        ``(..., 2)`` and ``goal_vector_magdir`` concatenates the unit direction
        and ``1 / norm`` along the last axis, shape ``(..., 3)``.
    """
    # Only squeeze a trailing singleton axis so a batch of yaws (B,) survives.
    if base_yaw.ndim > 0 and base_yaw.shape[-1] == 1:
        base_yaw = tf.squeeze(base_yaw, axis=-1)
    goal_position = (goal_position - base_position) / WAYPOINT_SPACING

    cos_yaw = tf.cos(base_yaw)
    sin_yaw = tf.sin(base_yaw)
    # Build a (..., 2, 2) matrix [[cos, sin], [-sin, cos]] (rotation by -yaw).
    # Stacking on the trailing axes keeps batch dims in front so tf.matmul
    # pairs each pose with its own goal.
    rotation_matrix = tf.stack(
        [
            tf.stack([cos_yaw, sin_yaw], axis=-1),
            tf.stack([-sin_yaw, cos_yaw], axis=-1),
        ],
        axis=-2,
    )
    goal_vector = tf.matmul(rotation_matrix, goal_position[..., None])[..., 0]

    # Dynamic shape: works even when the static shape has unknown dimensions.
    random_goal_vectors = tf.random.normal(
        tf.shape(goal_vector), mean=0, stddev=10, dtype=goal_vector.dtype
    )
    goal_vector = tf.where(is_negative, random_goal_vectors, goal_vector)

    goal_vector_norm = tf.maximum(
        tf.norm(goal_vector, axis=-1, keepdims=True), goal_dist_threshold
    )
    goal_vector_magdir = tf.concat(
        [
            goal_vector / goal_vector_norm,
            1 / goal_vector_norm,
        ],
        axis=-1,
    )
    return goal_vector, goal_vector_magdir

# from goal_task.py
def _yaw(quat):
    return np.arctan2(
        2.0 * (quat[3] * quat[2] + quat[0] * quat[1]),
        1.0 - 2.0 * (quat[1] ** 2 + quat[2] ** 2),
    )

class Model():
    """Drive a robot with a navigation policy and log every step as RLDS data.

    Each tick pulls an observation from the robot server through an agentlace
    ``ActionClient``, stores it together with the current goal in an
    ``EpisodicTFDataStore`` backed by an ``RLDSWriter``, and then either sends
    the next velocity command or -- when the task reports a reached goal, a
    new crash, or a timeout -- runs a recovery maneuver and samples a new goal.
    """

    def __init__(self, server_ip: str,
                 save_dir,
                 checkpoint_dir,
                 checkpoint_step,
                 max_time,
                 goal_dir,
                 action_type: str,
                 is_sim: bool,
                 keep_obs_buffer: bool,
                 deterministic: str,
                 step_by_one: str):
        """Set up the datastore, robot connection, policy and goal task.

        Args:
            server_ip: address of the robot action server (passed to ActionClient).
            save_dir: directory under which a fresh ``0.0.N`` dataset folder is
                created, N being one past the largest existing numeric suffix.
                Created if missing.
            checkpoint_dir: checkpoint location passed to ``Agent.load_checkpoint``.
            checkpoint_step: checkpoint step passed to ``Agent.load_checkpoint``.
            max_time: seconds after which ``tick`` stops the robot and exits;
                ``None`` runs forever.
            goal_dir: goal source passed to ``TrainingTask``.
            action_type: ``"random"``, ``"teleop"``, ``"transformer"``, or a
                config name understood by ``get_config``.
            is_sim: stored on the instance; not otherwise read in this class.
            keep_obs_buffer: keep the last ``BUFFER_LEN`` (obs, goal) pairs,
                needed by the ``"transformer"`` action type.
            deterministic: the string ``"True"`` (or bool ``True``) makes the
                policy act deterministically.
            step_by_one: the string ``"True"`` (or bool ``True``) is forwarded
                to ``TrainingTask`` as True.
        """
        self.max_time = max_time
        self.start_time = time.time()
        self.last_saved = self.start_time
        self.tick_rate = 3  # control-loop frequency in Hz, used by run()
        self.action_type = action_type
        self.is_sim = is_sim
        self.keep_obs_buffer = keep_obs_buffer
        self.deterministic = self._parse_flag(deterministic)
        self.step_by_one = self._parse_flag(step_by_one)

        if self.action_type not in ("random", "teleop"):
            self.config = get_config(self.action_type)

        data_dir = save_dir
        # Every run writes into a brand-new versioned folder so earlier data is never touched.
        os.makedirs(data_dir, exist_ok=True)
        self.version = self._next_version(data_dir)
        self.datastore_path = f"{data_dir}/{self.version}"
        os.makedirs(self.datastore_path)

        self.action_client = ActionClient(
            server_ip,
            make_action_config(),
        )

        # RLDS writer + episodic datastore that forwards finished episodes to it.
        data_spec = task_data_format()
        self.writer = RLDSWriter(
            dataset_name="test",
            data_spec=data_spec,
            data_directory=self.datastore_path,
            version=self.version,
            max_episodes_per_file=100,
        )
        # Close the writer on interpreter exit (also reached via sys.exit in tick)
        # so buffered data is written out.
        atexit.register(self.writer.close)

        self.data_store = EpisodicTFDataStore(
            capacity=1000,
            data_spec=data_spec,
            rlds_logger=self.writer
        )
        print("Datastore set up")

        self.rng = jax.random.PRNGKey(seed=42)

        if self.action_type == "random":
            self.agent = RandomAgent()
        elif self.action_type == "teleop":
            self.agent = None
        else:
            self.agent = Agent(self.config, 42)
            if self.action_type != "transformer":
                # The transformer currently runs with its initial weights; no checkpoint is loaded.
                self.agent.load_checkpoint(checkpoint_dir, checkpoint_step)

        # Always define the attribute so readers can test it for None.
        self.obs_buffer = None
        if self.keep_obs_buffer:
            from collections import deque
            self.obs_buffer = deque(maxlen=BUFFER_LEN)

        self.task = TrainingTask(goal_dir, self.step_by_one)

    @staticmethod
    def _parse_flag(value):
        """Return True iff *value* is the string ``"True"`` (or the bool ``True``)."""
        return str(value) == "True"

    @staticmethod
    def _next_version(data_dir):
        """Return ``"0.0.N"`` with N one past the largest numeric suffix in *data_dir*.

        The suffix is the text after the last ``.`` of each entry name. Entries
        whose suffix is not a decimal integer (``.DS_Store``, ``notes.txt``)
        are ignored instead of crashing. An empty directory yields ``"0.0.1"``.
        """
        suffixes = (name.rsplit(".", 1)[-1] for name in os.listdir(data_dir))
        latest = max((int(s) for s in suffixes if s.isdecimal()), default=0)
        return f"0.0.{latest + 1}"

    def run(self):
        """Reset per-run bookkeeping, then call ``tick`` forever at about ``tick_rate`` Hz.

        The loop only ends through ``sys.exit`` in ``tick`` (``max_time``
        reached) or an exception.
        """
        self.loop_time = 1 / self.tick_rate
        start_time = time.time()

        self.just_crashed = False
        self.latest_action = np.array([0, 0], dtype=np.float32)
        self.traj_len = 0
        self.curr_goal = None  # has image, position, and steps
        self.trajs = 0
        self.timeouts = 0
        self.crashs = 0
        self.start_pos = None
        self.reached = True  # so we start with a new goal

        while True:
            # Sleep off whatever remains of the previous tick's time budget.
            elapsed = time.time() - start_time
            if elapsed < self.loop_time:
                time.sleep(self.loop_time - elapsed)
            start_time = time.time()

            self.tick()

    def int_image(self, img):
        """Undo ImageNet standardization and return a uint8 image.

        Values are clipped to [0, 1] before scaling so out-of-range pixels
        saturate at 0/255 instead of wrapping around in the uint8 cast.
        """
        restored = np.clip(np.asarray(img * IMAGENET_STD + IMAGENET_MEAN), 0.0, 1.0)
        return np.asarray(restored * 255, dtype=np.uint8)

    def get_info_with_context(self):
        """Stack the buffered observations and current goal for a context model.

        Returns:
            ``None`` until the observation buffer holds a full window (or when
            buffering is disabled). Otherwise ``(obs_imgs, goal_img)``: float32
            arrays of shape ``(1, 3 * window, 64, 85)`` and ``(1, 3, 64, 85)``,
            channels-first, ImageNet-standardized, observations ordered oldest
            to newest along the channel axis.
        """
        if self.obs_buffer is None or len(self.obs_buffer) < self.obs_buffer.maxlen:
            return None

        # NOTE(review): the previous code compared against `self.model.context_size + 1`,
        # but `self.model` is never assigned; the buffer length (BUFFER_LEN) is used as
        # the context window instead -- confirm it matches the transformer's context size.
        window = self.obs_buffer.maxlen
        obs_list = list(self.obs_buffer)

        def to_chw(standardized):
            # (H, W, 3) -> resize to 64x85 -> (3, 64, 85)
            resized = tf.image.resize(standardized, [64, 85])
            return np.transpose(np.asarray(resized), (2, 0, 1))

        # `save` replaces obs["image"] with a string tensor after the obs was
        # buffered; tf.io.decode_image accepts both that and raw bytes.
        obs_frames = [
            to_chw(normalize_image(
                tf.io.decode_image(obs["image"], channels=3, expand_animations=False)))
            for obs, _ in obs_list
        ]
        obs_imgs = np.concatenate(obs_frames, axis=0)[None]
        # The goal paired with the newest observation; its "image" is already standardized.
        goal_img = to_chw(obs_list[-1][1]["image"])[None]

        assert obs_imgs.shape == (1, 3 * window, 64, 85), obs_imgs.shape
        assert goal_img.shape == (1, 3, 64, 85), goal_img.shape

        return obs_imgs, goal_img

    def save(self, obs):
        """Record one step (observation, current goal, task flags) in the datastore.

        Mutates *obs* in place: ``image`` and ``action_state_source`` become
        tensors and a ``goal`` entry is added. Also caches the decoded
        observation image in ``self.curr_obs_img`` for ``take_action``.
        """
        self.curr_obs_img = tf.io.decode_image(
            obs["image"], expand_animations=False)
        obs["image"] = tf.convert_to_tensor(
            obs["image"])  # keep the raw encoded bytes
        obs["action_state_source"] = tf.convert_to_tensor(
            obs["action_state_source"])

        obs["goal"] = {
            # JPEG of the unnormalized goal image, encoded once in reset_goal.
            "image": self.curr_goal["img_bytes"],
            "position": self.curr_goal["position"],
            "orientation": self.curr_goal["orientation"],
            "reached": self.task_result["reached_goal"],
            "sample_info": self.curr_goal["sample_info"],
        }

        formatted_obs = {
            "observation": obs,
            "action": tf.concat([obs["last_action_linear"], obs["last_action_angular"]], axis=0),
            "is_first": self.task_result["is_first"],
            "is_last": self.task_result["is_last"],
            "is_terminal": self.task_result["is_terminal"],
        }
        self.data_store.insert(formatted_obs)
        self.last_saved = time.time()

    def reset_goal(self, pos, quat, reached):
        """Sample a new goal from the task, cache its image forms, and send it to the robot.

        Stores the goal in ``self.curr_goal`` with these added keys:
        ``int_image`` (undecoded pixels), ``image`` (standardized),
        ``yaw``, and ``img_bytes`` (JPEG as a tf.string).
        """
        self.task.reset(pos, quat, reached)
        goal = self.task.get_goal()
        self.curr_goal = goal

        # Goal images may arrive encoded; decode either str or bytes payloads.
        if isinstance(goal["image"], (str, bytes)):
            goal["image"] = tf.io.decode_image(
                goal["image"], expand_animations=False)

        goal["int_image"] = goal["image"]
        goal["image"] = normalize_image(goal["image"])
        goal["yaw"] = _yaw(goal["orientation"])

        # np.asarray: PIL cannot read a tf tensor directly.
        goal_jpeg = io.BytesIO()
        Image.fromarray(np.asarray(goal["int_image"])).save(goal_jpeg, format="JPEG")
        jpeg_bytes = goal_jpeg.getvalue()
        goal["img_bytes"] = tf.constant(jpeg_bytes, dtype=tf.string)

        goal_pose = {
            "position": goal["position"].astype(float),
            "orientation": goal["orientation"].astype(float),
            "image": np.frombuffer(jpeg_bytes, dtype=np.uint8),
        }
        self.action_client.act("new_goal", goal_pose)

    def reset(self, obs):
        """Run a scripted recovery maneuver chosen from the collision flags in *obs*.

        Marks the episode as just-crashed and restarts the trajectory counter.
        When none of ``keepout`` / ``crash_left`` / ``crash_right`` /
        ``crash_center`` is set (e.g. a plain timeout) no maneuver is sent.
        Otherwise the ``reset`` action is re-sent every 0.3 s until the server
        reports it is no longer running.
        """
        self.just_crashed = True
        self.traj_len = 0

        if obs["keepout"]:
            twists = [np.array([-0.3, 0]), np.array([0.0, -0.5])]
            time_per_twist = [3.0, 1.2]
        elif obs["crash_left"]:
            twists = [np.array([-0.2, 0]), np.array([0.0, -0.5])]
            time_per_twist = [1.0, 1.0]
        elif obs["crash_right"]:
            twists = [np.array([-0.2, 0]), np.array([0.0, 0.5])]
            time_per_twist = [1.0, 1.0]
        elif obs["crash_center"]:
            # Back up, then turn in a random direction.
            twists = [np.array([-0.2, 0]), np.random.choice([-1, 1]) * np.array([0.0, 0.5])]
            time_per_twist = [1.0, 1]
        else:
            # Previously this fell through with `twists` unbound (UnboundLocalError).
            print("No collision flag set; skipping recovery maneuver.")
            return

        payload = {"twists": twists, "time_per_twist": time_per_twist}
        while True:
            res = self.action_client.act("reset", payload)
            time.sleep(0.3)
            if res is not None and not res["running"]:
                break

    def take_action(self, obs):
        """Compute the next (linear, angular) command and send it as ``action_vw``.

        In ``"teleop"`` mode nothing is sent. For the checkpointed policy the
        raw prediction is mapped to a command as
        ``[WAYPOINT_SPACING * (a0 - X_OFFSET), a1 * ANGLE_SCALE]``.
        """
        self.just_crashed = False
        self.traj_len += 1

        self.rng, _ = jax.random.split(self.rng)

        if self.action_type == "teleop":
            return

        if self.action_type == "random":
            action = self.agent.rand_action()
        elif self.action_type == "transformer":
            context_window_info = self.get_info_with_context()
            if context_window_info is not None:
                action = self.agent.predict(context_window_info)
            else:
                print("Not enough context")
                action = [0., 0.]
        else:
            curr_img = normalize_image(self.curr_obs_img)
            resized_obs_img = np.array(tf.image.resize(curr_img, [64, 64]))
            resized_goal_img = np.array(tf.image.resize(self.curr_goal["image"], [64, 64]))

            assert resized_obs_img.shape == (64, 64, 3), f"actual obs img size {resized_obs_img.shape}"
            assert resized_goal_img.shape == (64, 64, 3), f"actual goal img size {resized_goal_img.shape}"

            action = self.agent.predict(
                obs_image=resized_obs_img, goal_image=resized_goal_img, random=not self.deterministic)
            action = np.array(
                [WAYPOINT_SPACING * (action[0] - X_OFFSET), action[1] * ANGLE_SCALE])

        self.latest_action = action
        self.action_client.act("action_vw", np.array([action[0], action[1]]))

    def tick(self):
        """Run one control step: observe, log, then act or start a new episode.

        On the very first observation only a goal is sampled. Afterwards each
        observation is buffered (if enabled), scored by the task and saved.
        A reached goal, a new crash or a timeout starts a new goal (with a
        recovery maneuver first for crash/timeout); otherwise the policy acts.
        Once ``max_time`` seconds have elapsed the robot is stopped and the
        process exits.
        """
        obs = self.action_client.obs()
        if obs is None:
            return

        if self.curr_goal is None:
            self.reset_goal(obs["position"], obs["orientation"], False)
            self.curr_obs_img = tf.io.decode_image(
                obs["image"], expand_animations=False)
            return

        if self.keep_obs_buffer:
            self.obs_buffer.append((obs, self.curr_goal))

        self.task_result = self.task.update(
            obs["position"], obs["orientation"], obs["crash"])
        self.save(obs)

        result = self.task_result
        # A crash only counts once: reset() sets just_crashed until the next action.
        new_crash = result["crash"] and not self.just_crashed
        if result["reached_goal"] or new_crash or result["timeout"]:
            print(f"\nResetting, reached {result['reached_goal']}" +
                  f" timeout {result['timeout']} " +
                  f" crash {obs['crash']}" +
                  f" keepout {obs['keepout']} ")

            if new_crash or result["timeout"]:
                self.reset(obs)
                # The robot moved during recovery; wait for a fresh observation.
                obs = self.action_client.obs()
                while obs is None:
                    obs = self.action_client.obs()

            self.trajs += 1
            self.crashs += int(result['crash'])
            self.timeouts += int(result['timeout'])
            self.reset_goal(obs["position"], obs["orientation"], result["reached_goal"])
            self.task.reset_timer()
        else:
            self.take_action(obs)

        if self.max_time is not None and time.time() - self.start_time > self.max_time:
            self.action_client.act("action_vw", np.array([0, 0]))
            print(
                f"Killing model deployment after {time.time() - self.start_time} seconds.")
            print(
                f"Completed {self.trajs} trajectories with {self.crashs} crashes and {self.timeouts} timeouts.")
            sys.exit()


if __name__ == "__main__":

    # Quiet TensorFlow / absl chatter so the per-episode prints stay readable.
    tf.get_logger().setLevel("WARNING")
    logging.basicConfig(level=logging.WARNING)
    absl_logging.set_verbosity("WARNING")

    parser = argparse.ArgumentParser(description='My Python Script')
    # Required: without it the path silently became the literal string "None".
    parser.add_argument('--data_save_dir', type=str, required=True,
                        help='Where to save collected data')
    parser.add_argument('--checkpoint_load_dir', type=str,
                        help='Where to load model checkpoint from')
    parser.add_argument('--checkpoint_load_step', type=int,
                        help='Which checkpoint to load')
    parser.add_argument('--server_ip', type=str,
                        help='Where to connect to robot server')
    parser.add_argument('--max_time', type=int,
                        help='How long to run for')
    parser.add_argument('--goal_dir', type=str,
                        help="npz to load goals from ")
    parser.add_argument('--action_type', type=str,
                        help="What type of model do you wanna use?")
    # Boolean flags are passed as strings; only the exact value "True" enables them.
    parser.add_argument('--deterministic', type=str, default="False",
                        help="Should the action be deterministic")
    parser.add_argument('--step_by_one', type=str, default="False",
                        help="If the goal should be selected as NEXT in loop")
    args = parser.parse_args()

    Model(server_ip=args.server_ip,
          save_dir=args.data_save_dir,
          checkpoint_dir=args.checkpoint_load_dir,
          checkpoint_step=args.checkpoint_load_step,
          max_time=args.max_time,
          goal_dir=args.goal_dir,
          action_type=args.action_type,
          is_sim=False,
          keep_obs_buffer=True,
          deterministic=args.deterministic,
          step_by_one=args.step_by_one,
          ).run()