import logging
import os
import time
from agentlace.action import ActionClient
from agentlace.data.data_store import QueuedDataStore
from agentlace.trainer import TrainerClient
import chex
import numpy as np
from PIL import Image
from io import BytesIO
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU') 

from collections import deque

import argparse

from multinav.deploy.train.agent import Agent
from multinav.deploy.common.trainer_bridge_common import (
    make_action_config,
    make_trainer_config,
)

from multinav.deploy.robot.tasks.goal_task import TrainingTask
from multinav.deploy.train.data_config import get_config
from pprint import pprint

# Per-channel mean and standard deviation applied to images after their pixel
# values have been scaled by 1/255 (see normalize_image).
# NOTE(review): values match the commonly used ImageNet RGB statistics --
# confirm the channel order of the images fed in.
IMAGENET_MEAN = tf.constant([0.485, 0.456, 0.406])
IMAGENET_STD = tf.constant([0.229, 0.224, 0.225])

# Divides goal displacements in get_relative_position and multiplies the first
# action component in TrainActor.take_action.
# Presumably the distance between waypoints in metres -- TODO confirm units.
WAYPOINT_SPACING = 0.25
# Multiplier applied to the second action component before it is sent to the robot.
ANGLE_SCALE = 1
# Subtracted from the first action component before it is scaled by WAYPOINT_SPACING.
X_OFFSET = -1

def normalize_image(image: tf.Tensor) -> tf.Tensor:
    """
    Standardize an image with the IMAGENET_MEAN / IMAGENET_STD statistics.

    The input is cast to float32 and divided by 255, then the per-channel
    mean is subtracted and the result is divided by the per-channel standard
    deviation. The output is therefore NOT confined to [0, 1].
    """
    unit_scaled = tf.cast(image, tf.float32) / 255.0
    centered = unit_scaled - IMAGENET_MEAN
    return centered / IMAGENET_STD

# from dataset_transforms.py 
def get_relative_position(base_position, base_yaw, goal_position, is_negative):
    """
    Express the goal relative to the base pose, in waypoint units.

    The displacement (goal_position - base_position) is divided by
    WAYPOINT_SPACING and rotated by -base_yaw, i.e. into the frame whose
    heading is base_yaw. Where `is_negative` is true the rotated vector is
    replaced by a random normal sample (mean 0, stddev 10).

    Returns a pair:
      * goal_vector: the (possibly randomized) relative goal vector.
      * goal_vector_magdir: goal_vector divided by its norm, concatenated with
        the inverse of that norm along the last axis. The norm is floored at
        0.2 so the division stays bounded near the goal.
    """
    min_goal_norm = 0.2

    # Collapse a trailing singleton dimension on the yaw, if present.
    if base_yaw.ndim > 0:
        base_yaw = tf.squeeze(base_yaw, axis=-1)

    scaled_offset = (goal_position - base_position) / WAYPOINT_SPACING

    cos_yaw = tf.cos(base_yaw)
    sin_yaw = tf.sin(base_yaw)
    top_row = tf.stack([cos_yaw, sin_yaw], axis=0)
    bottom_row = tf.stack([-sin_yaw, cos_yaw], axis=0)
    rotation_matrix = tf.stack([top_row, bottom_row], axis=0)

    rotated = tf.matmul(rotation_matrix, scaled_offset[..., None])
    goal_vector = rotated[..., 0]

    # The random sample is always drawn, even when it ends up unused.
    random_goal_vectors = tf.random.normal(
        goal_vector.shape, mean=0, stddev=10, dtype=goal_vector.dtype
    )
    goal_vector = tf.where(is_negative, random_goal_vectors, goal_vector)

    raw_norm = tf.norm(goal_vector, axis=-1, keepdims=True)
    goal_vector_norm = tf.maximum(raw_norm, min_goal_norm)

    direction = goal_vector / goal_vector_norm
    inverse_norm = 1 / goal_vector_norm
    goal_vector_magdir = tf.concat([direction, inverse_norm], axis=-1)

    return goal_vector, goal_vector_magdir

# from goal_task.py
def _yaw(quat):
    return np.arctan2(
        2.0 * (quat[3] * quat[2] + quat[0] * quat[1]),
        1.0 - 2.0 * (quat[1] ** 2 + quat[2] ** 2),
    )


class TrainActor:
    """Robot-side actor for online training.

    On every tick the actor pulls an observation from the robot, updates the
    goal task, pushes the step into a local data store shared with the trainer
    client, and then either sends the next action to the robot or finishes the
    trajectory and picks a new goal.
    """

    def __init__(self, trainer_ip: str, robot_ip: str, goal_dir: str, seed: int):
        """Connect to the trainer and the robot, then build the task and agent.

        Args:
            trainer_ip: Address passed to the TrainerClient.
            robot_ip: Address passed to the ActionClient.
            goal_dir: Goal source handed to TrainingTask.
            seed: Seed handed to the Agent.
        """
        # Set up Agentlace clients
        self.local_data_store = QueuedDataStore(capacity=10)
        self.trainer = TrainerClient(
            "online_data",
            trainer_ip,
            make_trainer_config(),
            self.local_data_store,
            wait_for_server=True,
        )

        self.robot = ActionClient(
            robot_ip,
            make_action_config(),
        )
        self.tick_rate = 3  # target number of ticks per second
        self.crash_history = deque(maxlen=10)  # outcomes of the latest trajectories

        # Per-episode state, filled in by run()/tick()/reset_goal().
        self.curr_goal = None
        self.curr_obs_image = None
        self.task_result = None

        # Task/actor setup
        self.model_config = self.trainer.request("get-model-config", {})
        print("Got model config from trainer")
        pprint(self.model_config)
        self.task = TrainingTask(goal_dir, step_by_one=False)

        self.agent = Agent(self.model_config, seed)

    def run(self):
        """Wait for the first parameters from the trainer, then tick forever.

        The loop sleeps as needed so consecutive ticks start at least
        1 / self.tick_rate seconds apart. This method does not return.
        """
        received_params = False

        def _update_actor(data):
            nonlocal received_params
            received_params = True
            self.agent.update_params(data["params"])

        self.trainer.recv_network_callback(_update_actor)
        self.trainer.start_async_update(interval=1)

        loop_time = 1 / self.tick_rate
        start_time = time.time()

        self.curr_goal = None

        # Wait to receive initial params
        while not received_params:
            time.sleep(1.0)
            print("Waiting for initial params...")

        while True:
            # Delay so that ticks are spaced by at least loop_time
            new_start_time = time.time()
            elapsed = new_start_time - start_time
            if elapsed < loop_time:
                time.sleep(loop_time - elapsed)
            start_time = time.time()
            self.tick()

    def save(self, obs):
        """Insert the current step into the local data store.

        Observations whose "action_state_source" is "DoResetState" are
        skipped. Note that `obs` is modified in place: its "image" and
        "action_state_source" entries are converted to tensors and a "goal"
        entry describing the current goal is added.
        """
        if obs["action_state_source"] == "DoResetState":
            return

        obs["image"] = tf.convert_to_tensor(obs["image"])  # get it as raw byte array
        obs["action_state_source"] = tf.convert_to_tensor(obs["action_state_source"])

        obs["goal"] = {
            "image": self.curr_goal["img_bytes"],
            "position": self.curr_goal["position"],
            "orientation": self.curr_goal["orientation"],
            "reached": self.task_result["reached_goal"],
            "sample_info": self.curr_goal["sample_info"],
        }

        formatted_obs = {
            "observation": obs,
            "action": tf.concat([obs["last_action_linear"], obs["last_action_angular"]], axis=0),
            "is_first": self.task_result["is_first"],
            "is_last": self.task_result["is_last"],
            "is_terminal": self.task_result["is_terminal"],
        }

        self.local_data_store.insert(formatted_obs)

    def take_action(self, obs):
        """Predict an action from the current/goal images and send it to the robot.

        Both images are resized to 64x64 when they are not already that size.
        For the "gc_cql" agent the critic values are also sent via get_qs().
        """
        if self.curr_obs_image.shape != (64, 64, 3):
            resized_obs_img = np.array(tf.image.resize(self.curr_obs_image, [64, 64]))
        else:
            resized_obs_img = np.array(self.curr_obs_image)

        if self.curr_goal["image"].shape != (64, 64, 3):
            resized_goal_img = np.array(tf.image.resize(self.curr_goal["image"], [64, 64]))
        else:
            resized_goal_img = np.array(self.curr_goal["image"])

        # Internal sanity checks on the arrays handed to the agent.
        assert resized_obs_img.shape == (64, 64, 3), f"actual obs img size {resized_obs_img.shape}"
        assert resized_goal_img.shape == (64, 64, 3), f"actual goal img size {resized_goal_img.shape}"

        action = self.agent.predict(obs_image=resized_obs_img, goal_image=resized_goal_img)

        if self.model_config.agent_name == "gc_cql":
            self.get_qs(obs, resized_obs_img, resized_goal_img, action)
        # Shift and scale the raw action before sending it; can tinker with this...
        action_scaled = np.array([WAYPOINT_SPACING * (action[0] - X_OFFSET), action[1] * ANGLE_SCALE])
        self.robot.act("action_vw", action_scaled)

    def get_qs(self, obs, obs_img, goal_img, action):
        """Evaluate the critic for `action` and send the two values to the robot.

        The values are only sent when they form a shape-(2,) array.
        """
        # Only the magnitude/direction encoding is used as proprioception.
        _, proprio = get_relative_position(
            obs["position"][:2],
            _yaw(obs["orientation"]),
            self.curr_goal["position"][:2],
            is_negative=False,
        )

        batched_obs = {"image": np.array([obs_img]), "proprio": np.array([proprio])}
        batched_goal = {"image": np.array([goal_img])}

        predicted_qs = self.agent.actor.forward_critic(
            observations=(batched_obs, batched_goal),
            actions=np.array([action]),
            rng=self.agent.rng,
            train=False,
        )

        formatted_qs = np.array([predicted_qs[0][0], predicted_qs[1][0]])
        if formatted_qs.shape == (2, ):
            self.robot.act("q_vals", formatted_qs)

    def reset_goal(self, pos, quat, reached):
        """Reset the task, fetch a new goal, and announce it to the robot.

        The goal dict gains "int_image" (the un-normalized image), a
        normalized "image", "yaw", and "img_bytes" (the JPEG-encoded image).
        """
        self.task.reset(pos, quat, reached)
        self.curr_goal = self.task.get_goal()

        if isinstance(self.curr_goal["image"], str):  # need to decode it!
            self.curr_goal["image"] = tf.io.decode_image(self.curr_goal["image"], expand_animations=False)

        self.curr_goal["int_image"] = self.curr_goal["image"]
        self.curr_goal["image"] = normalize_image(self.curr_goal["image"])
        self.curr_goal["yaw"] = _yaw(self.curr_goal["orientation"])

        goal_jpeg = BytesIO()
        # np.asarray is a no-op for ndarrays and converts a decoded tensor into
        # something PIL can consume.
        Image.fromarray(np.asarray(self.curr_goal["int_image"])).save(goal_jpeg, format="JPEG")
        jpeg_bytes = goal_jpeg.getvalue()
        goal_img_bytes_pose = np.frombuffer(jpeg_bytes, dtype=np.uint8)
        self.curr_goal["img_bytes"] = tf.constant(jpeg_bytes, dtype=tf.string)

        goal_pose = {
            "position": self.curr_goal["position"].astype(float),
            "orientation": self.curr_goal["orientation"].astype(float),
            "image": goal_img_bytes_pose,
        }

        self.robot.act("new_goal", goal_pose)

    def handle_traj_end(self, obs):
        """Finish the current trajectory and start a new one.

        Records the outcome, runs the crash-recovery maneuver (or asks a human
        for help after too many crashes), picks a new goal from a fresh
        observation, reports the outcome to the trainer, and restarts the
        task timer.
        """
        print(f"\nResetting, reached {self.task_result['reached_goal']}" +
                          f" timeout {self.task_result['timeout']} " +
                          f" crash {obs['crash']}" +
                          f" keepout {obs['keepout']} ")

        if self.task_result["timeout"]:
            self.crash_history.append("timeout")
        elif self.task_result["crash"]:
            self.crash_history.append("crash")
        elif self.task_result["reached_goal"]:
            self.crash_history.append("reach")

        if self.task_result["timeout"] or self.task_result["crash"]:
            # NOTE(review): nothing in this class appends "keepout" to the
            # history, so that term is currently always 0.
            if self.crash_history.count("crash") + self.crash_history.count("keepout") > 8:
                self.robot.act("action_vw", np.array([0.0, 0.0]))
                # Wait for the human to press enter
                input("Crashed too many times, fix the robot and hit ENTER to continue:")
                self.crash_history.clear()
            elif self.task_result["crash"]:
                self.reset(obs)

        # Fetch a fresh observation: the robot may have moved during recovery.
        obs = self.robot.obs()
        while obs is None:
            obs = self.robot.obs()

        self.reset_goal(obs["position"], obs["orientation"], self.task_result["reached_goal"])
        self.trainer.request("send-stats",
                            {"reach": self.task_result["reached_goal"],
                                "timeout": self.task_result["timeout"],
                                "crash": self.task_result["crash"]})
        self.task.reset_timer()

    def reset(self, obs):
        """Run a recovery maneuver chosen from the crash flags in `obs`.

        The maneuver is a list of twists and a duration for each, sent with
        the "reset" action repeatedly (every 0.3 s) until the robot reports
        it is no longer running. If none of the "keepout", "crash_left",
        "crash_right" or "crash_center" flags is set there is no maneuver to
        run, so a warning is logged and the method returns.
        """
        if obs["keepout"]:
            twists = [np.array([-0.3, 0]), np.array([0.0, -0.5])]
            time_per_twist = [3.0, 1.2]
        elif obs["crash_left"]:
            twists = [np.array([-0.2, 0]), np.array([0.0, -0.5])]
            time_per_twist = [1.0, 1.0]
        elif obs["crash_right"]:
            twists = [np.array([-0.2, 0]), np.array([0.0, 0.5])]
            time_per_twist = [1.0, 1.0]
        elif obs["crash_center"]:
            # Turn in a random direction after backing up.
            twists = [np.array([-0.2, 0]), np.random.choice([-1, 1]) * np.array([0.0, 0.5])]
            time_per_twist = [1.0, 1]
        else:
            # Previously this fell through with `twists` unbound and raised
            # UnboundLocalError, killing the actor loop.
            logging.warning(
                "reset() called without a keepout/crash flag set; skipping recovery maneuver"
            )
            return

        res = {"running": True, "reason": "starting"}

        while res is None or res["running"]:
            res = self.robot.act(
                "reset", {"twists": twists, "time_per_twist": time_per_twist})
            time.sleep(0.3)

    def tick(self):
        """
        Run one control step; run() calls this at self.tick_rate.

        Does nothing when no observation is available. The first tick that has
        an observation only picks the initial goal.
        """
        # GET OBSERVATION
        obs = self.robot.obs()
        if obs is None:
            return

        # MAKE SURE WE HAVE A GOAL
        if self.curr_goal is None:  # gotta get that first goal
            self.reset_goal(obs["position"], obs["orientation"], 0)
            return

        # GET IMAGE FROM OBSERVATION
        decoded_image = tf.io.decode_image(obs["image"], expand_animations=False)
        # Normalize the same way as the goal image for consistent agent inputs.
        self.curr_obs_image = np.array(normalize_image(decoded_image))

        # UPDATE TASK
        self.task_result = self.task.update(obs["position"], obs["orientation"], obs["crash"] or obs["keepout"])

        # SAVE INFO
        self.save(obs)

        # HANDLE END
        if self.task_result["reached_goal"] or self.task_result["timeout"] or self.task_result["crash"]:
            self.handle_traj_end(obs)
        else:
            self.take_action(obs)


if __name__ == "__main__":
    # Command-line entry point: parse the connection settings and run the actor.
    parser = argparse.ArgumentParser(description='Train Actor')
    parser.add_argument('--trainer_ip', type=str, help='Where are we training?')
    parser.add_argument('--robot_ip', type=str, help='Where is the robot?')
    parser.add_argument('--goal_dir', type=str, help="npz to load goals from")
    # The help text used to be a copy of --goal_dir's.
    parser.add_argument('--seed', type=int, help="random seed passed to the agent")

    args = parser.parse_args()

    # NOTE(review): the result is not used below; the call is kept in case
    # get_config has side effects -- confirm and remove if it does not.
    data_config = get_config("gnm")
    TrainActor(args.trainer_ip, args.robot_ip, args.goal_dir, args.seed).run()