import copy
import time
from collections import deque, namedtuple
from enum import Enum

import carla
import numpy as np
import torch
from ding.envs import BaseEnvTimestep
from ding.torch_utils import to_ndarray, to_tensor

from models import ImplicitSupervisedModel
from envs.di_drive.core.envs.base_carla_env import BaseCarlaEnv
from envs.di_drive.core.simulators import CarlaSimulator
from utils import adapt_order, compute_angle, compute_point_line_dis, compute_cirle, compute_speed

# Driving orders; members are numbered from 1 in the order listed here.
Orders = Enum(
    "Order",
    ["Follow_Lane", "Straight", "Right", "Left", "ChangelaneLeft", "ChangelaneRight"],
)


class ImplicitCarlaEnv(BaseCarlaEnv):
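    '''
    CARLA driving environment for RL training.

    Each observation combines the flattened output of a pretrained
    ``ImplicitSupervisedModel`` encoder (computed over a short history of RGB
    frames) with short histories of speed and steering, navigation targets and
    the current driving order.
    '''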
    def __init__(self, args):
        """
        Connect to the CARLA simulator, load the pretrained image encoder and
        allocate the frame/state history buffers.

        Args:
            args: configuration object. This method reads ``port``, ``host``,
                ``simulator``, ``weather_list``, ``town``, ``n_vehicles``,
                ``n_pedestrians``, ``crop_sky``, ``supervised_model_path``,
                ``nb_action_throttle``, ``nb_action_steering``,
                ``front_camera_height`` and ``front_camera_width`` from it.
        """
        self.args = args
        self.port = args.port
        self.host = args.host
        self._simulator_cfg = self.args.simulator
        self.weather_list = self.args.weather_list
        self.town = self.args.town

        self.n_vehicles = self.args.n_vehicles
        self.n_pedestrians = self.args.n_pedestrians
        self._col_is_failure = True
        # Frame offsets (relative to the newest frame) that are fed to the encoder.
        self.steps_image = [-10, -2, -1, 0]
        self._simulator = None
        self._launched_simulator = False
        self._init_carla_simulator()

        # ``_launched_simulator`` must stay as ``_init_carla_simulator`` left it:
        # resetting it to False here made ``close()`` skip the simulator teardown.
        self.steps = 0
        self._tick = 0
        self.stop_steps = 0
        self.seed()

        crop_sky = self.args.crop_sky
        # NOTE(review): the camera FOV is written into the config after the simulator
        # object has been created; this only takes effect if the simulator reads the
        # config lazily -- confirm against CarlaSimulator.
        camera_fov = 100 if crop_sky else 90
        for obs_item in self._simulator_cfg.obs:
            obs_item.update({'fov': camera_fov})

        model_supervised = ImplicitSupervisedModel(
            len(self.steps_image), len(self.steps_image), 1024, 6 if crop_sky else 3, 4, self.args.crop_sky
        )
        model_supervised.load_state_dict(torch.load(self.args.supervised_model_path))
        model_supervised.eval()

        # Only the encoder part of the supervised model is used by the environment.
        self.encoder = model_supervised.encoder.cuda()
        self.last_conv_downsample = model_supervised.last_conv_downsample.cuda()

        # One extra throttle "level" is the brake action (see ``_action_to_control``).
        self.action_space = (self.args.nb_action_throttle + 1) * self.args.nb_action_steering

        # Number of frames to keep so that the oldest offset in ``steps_image`` is available.
        self.window = max(abs(number) for number in self.steps_image) + 1
        self.State = namedtuple("State", ("image", "speed", "order", "steering"))

        self.RGB_image_buffer = deque([], maxlen=self.window)
        self.state_buffer = deque([], maxlen=self.window)

        # Placeholder history entries used until real frames arrive.
        # NOTE(review): 6144 / 8192 presumably match the flattened encoder output
        # for the cropped / full image -- confirm against the model.
        encoding_size = 6144 if crop_sky else 8192
        blank_state = self.State(np.zeros(encoding_size, dtype=np.float32), -1, -1, 0)

        # RGB frames are stored with the color channel first, as torch expects.
        frame_height = self.args.front_camera_height - 120 if crop_sky else self.args.front_camera_height
        for _ in range(self.window):
            self.state_buffer.append(blank_state)
            self.RGB_image_buffer.append(np.zeros((3, frame_height, self.args.front_camera_width)))

        self.last_steering = 0
        self.last_order = 0

    def _init_carla_simulator(self):
        """Create the ``CarlaSimulator`` for the configured host/port and mark it as launched."""
        print(self.host, self.port)
        simulator_kwargs = {
            'cfg': self._simulator_cfg,
            'client': None,
            'host': self.host,
            'port': self.port,
        }
        self._simulator = CarlaSimulator(**simulator_kwargs)
        self._launched_simulator = True

    def seed(self, seed=None):
        np.random.seed(seed)

    def _action_to_control(self, action):
        steer = ((action % self.args.nb_action_steering) - int(self.args.nb_action_steering / 2)
                 ) * (self.args.max_steering / int(self.args.nb_action_steering / 2))
        if action < int(self.args.nb_action_steering * self.args.nb_action_throttle):
            throttle = (int(action / self.args.nb_action_steering
                            )) * (self.args.max_throttle / (self.args.nb_action_throttle - 1))
            brake = 0
        else:
            throttle = 0
            brake = 1.0

        steer = steer
        throttle = throttle
        if brake < 1.0 / 2:
            brake = 0

        control = {}
        control['steer'] = np.clip(steer, -1.0, 1.0)
        control['throttle'] = np.clip(throttle, 0.0, 1.0)
        control['brake'] = np.clip(brake, 0.0, 1.0)
        self.last_steering = steer
        return control

    def step(self, action):
        """
        Apply one discrete action, advance the simulator and score the result.

        The episode ends on success, on failure (timeout or collision), when the
        lane or angle reward drops below -0.99, or when the car has been standing
        still for too long. Each of the last three cases, as well as failure,
        subtracts 1 from the reward.

        Args:
            action: object exposing ``.item()`` that yields the integer action index.

        Returns:
            ``BaseEnvTimestep(obs, reward, done, info)``; ``info`` carries
            ``final_eval_reward`` once the episode is done.
        """
        control = self._action_to_control(action.item())
        self._simulator.apply_control(control)
        self._simulator.run_step()
        self.new_obs = self._get_obs()

        reward, reward_info, distance, forward_vector, velocity = self.compute_reward(
            self.new_obs, control, self.steps
        )

        done = self.is_success() or self.is_failure()

        # Count ticks spent (almost) standing still; moving again quickly clears the counter.
        if velocity < 0.2 and (self.new_obs['tl_state'] > 1.5 or self.new_obs['tl_dis'] > 23):
            self.stop_steps += 1
        if velocity > 0.3:
            self.stop_steps = max(0, self.stop_steps - 100)

        # Small bonus for keeping up speed inside junctions.
        if self.new_obs['is_junction'] and velocity > 3.0:
            reward += 0.1

        left_lane = reward_info['lane_reward'] < -0.99 or reward_info['angle_reward'] < -0.99
        if left_lane:
            done = True
            reward -= 1
            print("Step: %d, Terminate! (Offlane)" % self.steps)

        stuck = self.stop_steps > 50 and self.new_obs['tl_state'] > 1.5 and control['throttle'] < 0.1
        if stuck:
            done = True
            reward -= 1
            print("Stop: %d, Terminate (Stop)!" % self.steps)

        if self.is_success():
            print("Step: %d, Success !!" % self.steps)

        if self.is_failure():
            print("Step: %d, failure" % self.steps)
            reward -= 1

        # Periodic progress log.
        if self.steps % 20 == 0:
            print(
                "Step: %d, reward: %f, distance: %f, speed: %f, forward: (%f, %f)" %
                (self.steps, reward, distance, velocity, forward_vector.x, forward_vector.y)
            )
            print("Steer: %f, Throttle: %f, Brake: %f" % (control['steer'], control['throttle'], control['brake']))
            print(','.join(["%s: %f " % (name, value) for name, value in reward_info.items()]))
            print('-' * 80)

        self.steps += 1
        self._tick += 1
        self._final_eval_reward += reward

        info = {'final_eval_reward': self._final_eval_reward} if done else {}
        reward = to_tensor([reward]).float()
        obs = to_ndarray({key: self.new_obs[key] for key in ['speed', 'steer', 'image', 'targets', 'order']})
        return BaseEnvTimestep(obs, reward, done, info)

    def info(self):
        return {}

    def __repr__(self):
        return "Carla Env"

    def _get_obs(self):
        """
        Query the simulator and assemble the observation dict for the current tick.

        On top of the sensor, traffic-light and navigation entries provided by the
        simulator, the dict receives:

        - ``is_junction``: whether the nearest driving-lane waypoint is in a junction.
        - ``targets``: distances from the car to the ``node`` and ``target`` points,
          followed by the signed angles between the car heading and their forward
          directions.
        - ``speed`` / ``steer``: values sampled from the state history at the
          offsets in ``self.steps_image``.
        - ``image``: flattened encoder output for the frames at those same offsets.
        - ``order``: driving order derived from ``command`` via ``adapt_order``.

        Side effects: appends the new frame and state to the history buffers and
        updates ``self.last_order`` (printing the order when it changes).

        Returns:
            A deep copy of the observation dict without ``waypoint_list``.
        """
        observations = self._simulator.get_sensor_data()
        location = self._simulator.hero_player.get_location()
        waypoint = self._simulator._map.get_waypoint(location, project_to_road=True, lane_type=carla.LaneType.Driving)
        transform = self._simulator.hero_player.get_transform()
        forward_vector = transform.rotation.get_forward_vector()
        sim_state = self._simulator.get_state()
        observations['tl_dis'] = np.float32(sim_state['tl_dis'])
        observations['tl_state'] = sim_state['tl_state']

        # Vector perpendicular to the heading; its dot product with a direction
        # gives the side on which that direction lies.
        right_vec = [forward_vector.y, -forward_vector.x]

        navigation_buffer = self._simulator.get_navigation()
        for key in navigation_buffer:
            observations[key] = navigation_buffer[key]

        observations['is_junction'] = waypoint.is_junction

        node = observations['node']
        next_node = observations['target']
        node_forward = observations['node_forward']
        next_forward = observations['target_forward']

        node_sign = np.sign(node_forward[0] * right_vec[0] + node_forward[1] * right_vec[1])
        next_sign = np.sign(next_forward[0] * right_vec[0] + next_forward[1] * right_vec[1])
        targets = [
            np.sqrt((location.x - node[0])**2 + (location.y - node[1])**2),
            np.sqrt((location.x - next_node[0])**2 + (location.y - next_node[1])**2),
            node_sign * compute_angle([forward_vector.x, forward_vector.y], node_forward),
            next_sign * compute_angle([forward_vector.x, forward_vector.y], next_forward),
        ]

        rgb = np.array(observations["rgb"].copy())
        if self.args.crop_sky:
            # Drop the top 120 pixel rows of the frame.
            rgb = rgb[120:, :, :]

        # Channel-first layout for torch.
        rgb = np.rollaxis(rgb, 2, 0)
        self.RGB_image_buffer.append(rgb)

        speed = compute_speed(self._simulator.hero_player.get_velocity())

        order = adapt_order(int(observations["command"]))
        if self.last_order != order:
            print("order = ", Orders(order).name)
            self.last_order = order

        # Index of the newest entry in the (always full) history buffers.
        newest = self.window - 1

        np_array_RGB_input = np.concatenate([self.RGB_image_buffer[offset + newest] for offset in self.steps_image])
        torch_tensor_input = (
            torch.from_numpy(np_array_RGB_input).to(dtype=torch.float32).div_(255).unsqueeze(0).cuda()
        )

        with torch.no_grad():
            current_encoding = self.encoder(torch_tensor_input)
            current_encoding = self.last_conv_downsample(current_encoding)

        current_encoding_np = current_encoding.cpu().numpy().flatten()

        current_state = self.State(current_encoding_np, speed, order, self.last_steering)
        self.state_buffer.append(current_state)

        # Sample the speed/steering history at the same offsets as the frames.
        # The offsets used to be hard-coded here ([-10, -2, -1, 0] and "+ 10"),
        # which would silently disagree with ``steps_image`` if that ever changed.
        history = [self.state_buffer[offset + newest] for offset in self.steps_image]
        speeds = [past.speed for past in history]
        steerings = [past.steering for past in history]

        observations['command'] = int(observations['command'])
        observations['speed'] = np.array(speeds).astype(np.float32)
        observations['targets'] = np.array(targets).astype(np.float32)
        observations['steer'] = np.array(steerings).astype(np.float32)
        observations['image'] = current_encoding_np.astype(np.float32)
        observations['order'] = current_state.order
        observations = copy.deepcopy(observations)
        del observations['waypoint_list']
        return observations

    def get_changelane_target_waypoint(self, next_loc, next_ori):
        target_loc = []
        target_loc.append(next_loc[0] + next_ori[0] * 10)
        target_loc.append(next_loc[1] + next_ori[1] * 10)
        return target_loc

    def compute_reward(self, ob, control, steps):
        """
        Compute the shaped reward for the current simulator state.

        The reward is the sum of three terms:

        - ``speed_reward`` in [0, 1]: closeness of the speed to a target speed that
          depends on junction, traffic-light state and traffic-light distance.
        - ``angle_reward`` in [-1, 0]: heading error w.r.t. the reference direction.
        - ``lane_reward`` in [-1, 0]: lateral distance to the reference path
          (saturating at 2 distance units).

        During the first 30 steps the reference is the nearest lane waypoint;
        afterwards it is either the circle through the ``node`` / ``target``
        navigation points (when they describe a turn) or the straight line through
        ``node`` along ``node_forward``.

        NOTE(review): the original code also contained lane-change handling for
        commands 5/6, but its flag was unconditionally reset to False right after
        being set, so it never ran (and its angle branch never assigned the angle).
        Lane-change commands are therefore scored like any other segment; that
        effective behavior is preserved here and the dead branches are removed.

        Args:
            ob: observation dict from ``_get_obs``; reads ``tl_state``, ``tl_dis``,
                ``is_junction``, ``node``, ``target``, ``node_forward`` and
                ``target_forward``.
            control: applied control dict (not used by the computation).
            steps: number of steps taken so far in the episode.

        Returns:
            Tuple ``(total_reward, info, distance, forward_vector, velocity)`` where
            ``info`` holds the three reward terms and ``distance`` is the planner's
            distance to the goal.
        """

        def _normal_slope(direction):
            # Slope of the line perpendicular to ``direction``. copysign keeps the
            # 1e-3 clamp effective when direction[1] is exactly 0; the former
            # ``np.sign(y) * max(|y|, 1e-3)`` collapsed to 0 there and produced
            # inf/NaN slopes. For non-zero y both forms are identical.
            return -direction[0] / np.copysign(max(abs(direction[1]), 1e-3), direction[1])

        hero = self._simulator.hero_player
        location = hero.get_location()
        velocity = compute_speed(hero.get_velocity())
        transform = hero.get_transform()

        distance = self._simulator._planner.distance_to_goal
        forward_vector = transform.rotation.get_forward_vector()
        carla_map = self._simulator._map

        waypoint = carla_map.get_waypoint(location, project_to_road=True, lane_type=carla.LaneType.Driving)

        # NOTE(review): with project_to_road=True this lookup presumably never
        # returns None, which would make the out-of-lane branch unreachable --
        # confirm against the CARLA API before relying on it.
        hero_out_lane = carla_map.get_waypoint(location, project_to_road=True) is None

        tl_state = ob['tl_state']
        tl_dis = ob['tl_dis']
        is_junction = ob['is_junction']

        node = ob['node'].astype(np.float64)
        next_node = ob['target'].astype(np.float64)
        node_forward = ob['node_forward'].astype(np.float64)
        next_forward = ob['target_forward'].astype(np.float64)

        node_location = carla.Location(node[0], node[1], 0)
        node_forward_vector = carla.Vector3D(x=node_forward[0], y=node_forward[1], z=0)

        # Speed reward (speed is in m/s in CARLA): cruise at 5 unless approaching a
        # traffic light, in which case the target ramps down to 0 at 8 units.
        # NOTE(review): tl_state 2/3 presumably mean "no need to stop" -- confirm.
        if is_junction or tl_state == 2 or tl_state == 3 or tl_dis > 23:
            target_speed = 5
        elif tl_dis < 8:
            target_speed = 0
        else:
            target_speed = (tl_dis - 8) / 3

        speed_reward = max(0, 1 - 0.2 * np.abs(velocity - target_speed))

        # The segment is a turn when the two navigation points differ on both axes
        # and their directions are not parallel; the reference path is then a circle.
        circle_flag = (
            abs(node[0] - next_node[0]) > 1e-1 and abs(node[1] - next_node[1]) > 1e-1
            and np.abs(_normal_slope(node_forward) - _normal_slope(next_forward)) > 1e-1
        )
        if circle_flag:
            cx, cy, r = compute_cirle(node, node_forward, next_node, next_forward)

        if hero_out_lane:
            angle_reward = -1
            lane_reward = -1
        else:
            lane_forward_vector = waypoint.transform.rotation.get_forward_vector()
            waypoint_location = waypoint.transform.location

            # Navigation data is unreliable right after a reset, so the first steps
            # fall back to the nearest lane waypoint as reference.
            if steps < 30:
                lane_angle = compute_angle(forward_vector, lane_forward_vector)
                hero_lane_distance = location.distance(waypoint_location)
            elif circle_flag:
                # Compare the heading with both tangent directions of the circle
                # and keep the smaller error.
                angle1 = compute_angle(forward_vector, carla.Vector3D(x=(cy - location.y), y=(location.x - cx), z=0))
                angle2 = compute_angle(forward_vector, carla.Vector3D(x=(location.y - cy), y=(cx - location.x), z=0))
                lane_angle = min(angle1, angle2)
                hero_lane_distance = abs(np.sqrt((cx - location.x)**2 + (cy - location.y)**2) - r)
            else:
                lane_angle = compute_angle(forward_vector, node_forward_vector)
                hero_lane_distance = compute_point_line_dis(
                    node_location.x, node_location.y, node_forward_vector.x, node_forward_vector.y, location.x,
                    location.y
                )

            # Angle reward: 0 when aligned, -1 at (or beyond) a right angle.
            if lane_angle < np.pi / 2:
                lane_angle = lane_angle * 2 / np.pi
                angle_reward = (1 - lane_angle) / (1 + lane_angle) - 1
                angle_reward = max(min(angle_reward, 0), -1)
            else:
                angle_reward = -1

            # Lane reward: 0 on the reference path, -1 at 2 or more units away.
            hero_lane_distance = max(min(hero_lane_distance, 2), 0) / 2.0
            lane_reward = (1 - hero_lane_distance) / (1 + hero_lane_distance) - 1

        total_reward = speed_reward + angle_reward + lane_reward

        info = {
            'speed_reward': speed_reward,
            'angle_reward': angle_reward,
            'lane_reward': lane_reward,
        }

        return total_reward, info, distance, forward_vector, velocity

    def reset(self):
        """
        Start a new episode and return its first observation.

        Clears the counters and history buffers, draws a random weather and a
        random start/target pair (indices below 240 that are at least 5 apart),
        re-initialises the simulator and ticks it once.

        Returns:
            dict with the ``speed``, ``steer``, ``image``, ``targets`` and ``order``
            entries of the first observation, converted with ``to_ndarray``.
        """
        self.steps = 0
        self.stop_steps = 0
        self._tick = 0

        # Refill both history buffers with placeholders.
        # RGB frames are stored with the color channel first, as torch expects.
        if self.args.crop_sky:
            encoding_size = 6144
            frame_height = self.args.front_camera_height - 120
        else:
            encoding_size = 8192
            frame_height = self.args.front_camera_height
        frame_shape = (3, frame_height, self.args.front_camera_width)
        blank_state = self.State(np.zeros(encoding_size, dtype=np.float32), -1, -1, 0)

        self.state_buffer = deque([blank_state] * self.window, maxlen=self.window)
        self.RGB_image_buffer = deque((np.zeros(frame_shape) for _ in range(self.window)), maxlen=self.window)

        self.last_steering = 0
        self.last_order = 0

        # NOTE(review): this re-seeds NumPy's global generator from the clock on
        # every reset, which overrides whatever was passed to ``seed()``.
        np.random.seed(int(time.time()))

        start = np.random.randint(240)
        target = np.random.randint(240)
        while abs(target - start) < 5:
            target = np.random.randint(240)

        weather = np.random.choice(self.weather_list)
        self._simulator.init(
            weather=weather,
            start=start,
            target=target,
            town=self.town,
            n_pedestrians=self.n_pedestrians,
            n_vehicles=self.n_vehicles,
        )
        self._simulator.run_step()

        self._success_dis = 5.0
        self._final_eval_reward = 0.
        self._timeout = self._simulator._planner.timeout
        self.new_obs = self._get_obs()

        first_obs = {key: self.new_obs[key] for key in ['speed', 'steer', 'image', 'targets', 'order']}
        return to_ndarray(first_obs)

    def is_failure(self):
        if self._tick >= self._timeout:
            return True
        elif self._col_is_failure and self._simulator.collided:
            return True
        return False

    def is_success(self):
        return self._simulator._planner.distance_to_goal < self._success_dis

    def close(self):
        if self._launched_simulator:
            self._simulator.__exit__()
            del self._simulator
            self._launched_simulator = False
