from __future__ import print_function
import carla
import numpy as np
import torch
import json
import time
from srunner.AdditionTools.scenario_operation import ScenarioOperation
from srunner.scenariomanager.carla_data_provider import CarlaDataProvider
from srunner.scenario_dynamic.basic_scenario_dynamic import BasicScenarioDynamic
from srunner.AdditionTools.scenario_utils import calculate_distance_transforms

from srunner.scenario_dynamic.diffusegen.basic_diffusion_scenario import BasicDiffusionCollectionScenarioDynamic
import srunner.scenario_dynamic.diffusegen.carla_utils as carla_utils
import srunner.scenario_dynamic.diffusegen.diffuser.utils as utils


class OppositeVehicleRunningRedLightDynamic(BasicScenarioDynamic):
    """
    This class holds everything required for a scenario,
    in which another vehicle takes priority from the ego
    vehicle, by running a red traffic light (while the ego
    vehicle has green)

    The adversarial vehicle replays a fixed action sequence that is sampled
    once from a diffusion model in ``initialize_actors`` and applied step by
    step in ``update_behavior``.

    This is a single ego vehicle scenario
    """

    # Number of actions sampled from the diffusion model for one episode.
    _ACTION_HORIZON = 128
    # Each sampled action is held for this many consecutive update steps.
    _ACTION_REPEAT = 4

    def __init__(self, world, ego_vehicles, config, randomize=False, debug_mode=False, criteria_enable=True,
                 timeout=180):
        """
        Setup all relevant parameters and create scenario
        and instantiate scenario manager

        ``config.parameters`` is indexed as
        ``[records, load_path, seed, diffusion_step]``.

        :raises RuntimeError: if no traffic light lies ahead of the ego vehicle.
        """

        # Timeout of scenario in seconds
        self.timeout = timeout

        self.actor_speed = 10

        super(OppositeVehicleRunningRedLightDynamic, self).__init__("OppositeVehicleRunningRedLightDynamic",
                                                             ego_vehicles,
                                                             config,
                                                             world,
                                                             debug_mode,
                                                             criteria_enable=criteria_enable)

        self._traffic_light = CarlaDataProvider.get_next_traffic_light(self.ego_vehicles[0], False)
        if self._traffic_light is None:
            # The scenario cannot be set up without this light; fail explicitly
            # instead of crashing on ``None.set_state`` below.
            raise RuntimeError("No traffic light for the given location of the ego vehicle found")

        # Ego vehicle gets green for the whole scenario.
        self._traffic_light.set_state(carla.TrafficLightState.Green)
        self._traffic_light.set_green_time(self.timeout)

        self.scenario_operation = ScenarioOperation(self.ego_vehicles, self.other_actors)
        self.reference_actor = None
        self.trigger_distance_threshold = 35
        self.trigger = False
        self._actor_distance = 110
        self.ego_max_driven_distance = 150

        self.max_waypt = 12
        self.routeplanner = None
        self.waypoints = []
        self.vehicle_front = False
        self.init_model(config.parameters)  # [records, load_path, seed, diffusion_step]
        self.step = 0
        self.obs = None
        self.action = np.array([0, 0]).astype(np.float32)
        self.actions = []
        self.current_obs = None
        self.current_reward = None

    def initialize_actors(self):
        """
        Spawn the adversarial vehicle, hold its traffic light red and sample
        the action sequence it will replay.

        :raises RuntimeError: if no traffic light lies ahead of the adversarial vehicle.
        """
        config = self.config
        self._other_actor_transform = config.other_actors[0].transform
        first_vehicle_transform = carla.Transform(
            carla.Location(config.other_actors[0].transform.location.x,
                           config.other_actors[0].transform.location.y,
                           config.other_actors[0].transform.location.z),
            config.other_actors[0].transform.rotation)

        self.other_actor_transform.append(first_vehicle_transform)
        self.actor_type_list.append("vehicle.audi.tt")
        self.scenario_operation.initialize_vehicle_actors(self.other_actor_transform, self.other_actors,
                                                          self.actor_type_list)
        self.reference_actor = self.other_actors[0]

        # other vehicle's traffic light: held red so that it "runs" the light
        traffic_light_other = CarlaDataProvider.get_next_traffic_light(self.other_actors[0], False)
        if traffic_light_other is None:
            raise RuntimeError("No traffic light for the given location of the other vehicle found")
        traffic_light_other.set_state(carla.TrafficLightState.Red)
        traffic_light_other.set_red_time(self.timeout)

        self.initialize_route_planner()

        adv_actor_id = 0
        self.obs = self.get_obs(adv_actor_id, normalize=True)
        actions = self._sample_actions(self.obs)
        self.actions.extend(carla_utils.postprocess_action(actions[0, i, :])
                            for i in range(self._ACTION_HORIZON))

    def _sample_actions(self, obs):
        """
        Run the diffusion model on ``obs`` with the configured seed.

        The global CPU and current-device CUDA RNG states are restored
        afterwards, even if sampling raises, so that seeding this call does not
        perturb random streams used elsewhere.
        """
        cpu_rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state()
        try:
            torch.manual_seed(self.config.parameters[2])
            return utils.colab.run_diffusion(self.diffusion, obs, action_dim=self.diffusion.action_dim,
                                             horizon=self._ACTION_HORIZON,
                                             select_step=self.config.parameters[3])
        finally:
            torch.set_rng_state(cpu_rng_state)
            torch.cuda.set_rng_state(cuda_rng_state)

    def initialize_route_planner(self):
        """
        Set ``self.target_transform`` to the adversary's spawn location moved
        along its spawn rotation's forward vector, scaled by ``_actor_distance``.
        The spawn rotation is kept.
        """
        forward_vector = self.other_actor_transform[0].rotation.get_forward_vector() * self._actor_distance
        self.target_transform = carla.Transform(carla.Location(self.other_actor_transform[0].location + forward_vector),
                                                self.other_actor_transform[0].rotation)

    def update_behavior(self):
        """
        Advance the adversarial vehicle by one step.

        A new entry of ``self.actions`` is selected every ``_ACTION_REPEAT``
        steps. The index is clamped to the last sampled action once the
        sequence is exhausted. The current action is applied to the adversary,
        the method waits for the next world tick, and then it records the reward.
        """
        adv_actor_id = 0
        self.current_obs = self.get_obs(adv_actor_id, normalize=True)
        if self.step % self._ACTION_REPEAT == 0:
            index = min(self.step // self._ACTION_REPEAT, self._ACTION_HORIZON - 1)
            self.action = self.actions[index]
        self.scenario_operation.apply_control(adv_actor_id, self.action[0], self.action[1])
        self.world.wait_for_tick()
        self.step += 1
        self.current_reward = self.get_current_reward(adv_actor_id)

    def _create_behavior(self):
        pass

    def check_stop_condition(self):
        """
        small scenario stops when actor runs a specific distance

        :returns: True once the adversary is at least ``_actor_distance`` away
            from its spawn transform.
        """
        cur_distance = calculate_distance_transforms(CarlaDataProvider.get_transform(self.other_actors[0]),
                                                     self.other_actor_transform[0])
        return bool(cur_distance >= self._actor_distance)

    def init_model(self, parameters):
        """
        Load the diffusion experiment from ``parameters[1]`` (checkpoint epoch
        ``'500000'``). Keep its dataset and the trainer's ``ema_model``, and
        move that model to CUDA.
        """
        diffusion_experiment = utils.load_diffusion(parameters[1], epoch='500000')

        self.dataset = diffusion_experiment.dataset
        self.diffusion = diffusion_experiment.trainer.ema_model
        self.diffusion.cuda()

    def get_current_reward(self, adv_actor_id):
        """Return the negative planar (x, y) distance between the adversary and the ego vehicle."""
        adv_trans = self.other_actors[adv_actor_id].get_transform()
        adv_x = adv_trans.location.x
        adv_y = adv_trans.location.y

        ego_trans = self.ego_vehicles[0].get_transform()
        ego_x = ego_trans.location.x
        ego_y = ego_trans.location.y

        return - np.linalg.norm(np.array([adv_x - ego_x, adv_y - ego_y]))

    def get_obs(self, adv_actor_id, normalize=True):
        """
        Build the observation vector for the adversarial vehicle.

        The state is ``[lateral_dis, -delta_yaw, speed, vehicle_front]``.
        When ``normalize`` is true, each entry is rescaled with
        ``carla_utils.normalize`` using the fixed ranges below.

        :returns: ``np.float32`` array with 4 entries.
        """
        adv_trans = self.other_actors[adv_actor_id].get_transform()
        adv_x = adv_trans.location.x
        adv_y = adv_trans.location.y
        adv_yaw = adv_trans.rotation.yaw / 180 * np.pi

        target_vector = carla_utils.make_unit_vector(carla.Vector3D(x=self.target_transform.location.x - adv_x, y=self.target_transform.location.y - adv_y, z=0))
        target_location = carla.Location(adv_trans.location + target_vector * 2)
        # Clip so round-off in the unit vector cannot push arcsin out of domain (NaN).
        target_yaw = np.arcsin(np.clip(target_vector.y, -1.0, 1.0)) / np.pi * 180
        if target_vector.x < 0:
            target_yaw = 180 - target_yaw
        target_rotation = carla.Rotation(yaw=target_yaw)
        target_transform = carla.Transform(target_location, target_rotation)
        target_waypoint = [target_transform.location.x, target_transform.location.y, target_transform.rotation.yaw]
        lateral_dis, w = carla_utils.get_preview_lane_dis([target_waypoint]*5, adv_x, adv_y)
        # 2-D cross product w x heading, written out explicitly because
        # np.cross on 2-element vectors is deprecated since NumPy 2.0.
        # Assumes w is a length-2 direction vector — TODO confirm in carla_utils.
        heading = np.array([np.cos(adv_yaw), np.sin(adv_yaw)])
        cross = w[0] * heading[1] - w[1] * heading[0]
        delta_yaw = np.arcsin(np.clip(cross, -1.0, 1.0))
        v = self.other_actors[adv_actor_id].get_velocity()
        speed = np.sqrt(v.x ** 2 + v.y ** 2)
        state = np.array([lateral_dis, -delta_yaw, speed, self.vehicle_front])
        if normalize:
            state[0] = carla_utils.normalize(state[0], -2, 2)
            state[1] = carla_utils.normalize(state[1], -1, 1)
            state[2] = carla_utils.normalize(state[2], -5, 30)
            state[3] = carla_utils.normalize(state[3], 0, 1)

        return state.astype(np.float32)


class SignalizedJunctionLeftTurnDynamic(BasicDiffusionCollectionScenarioDynamic):
    """
    Left turn of the ego (hero) vehicle at a signalized junction.

    Both the ego vehicle's and the oncoming adversarial vehicle's next traffic
    lights are forced to green. The oncoming actor has priority, so the ego
    vehicle has to yield to it.
    """

    def __init__(self, world, ego_vehicles, config, randomize=False, debug_mode=False, criteria_enable=True,
                 timeout=80):
        """
        Store scenario parameters, build the base scenario and hold the ego
        vehicle's next traffic light green for ``timeout``.

        :raises RuntimeError: if no traffic light lies ahead of the ego vehicle.
        """
        # Everything below is assigned before the base-class constructor runs.
        self._world = world
        self._map = CarlaDataProvider.get_map()
        self._target_vel = 12.0
        self.timeout = timeout
        self._actor_distance = 100
        self._traffic_light = None

        super(SignalizedJunctionLeftTurnDynamic, self).__init__(
            "TurnLeftAtSignalizedJunctionDynamic",
            ego_vehicles,
            config,
            world,
            debug_mode,
            criteria_enable=criteria_enable,
        )

        ego_light = CarlaDataProvider.get_next_traffic_light(self.ego_vehicles[0], False)
        self._traffic_light = ego_light
        if ego_light is None:
            raise RuntimeError("No traffic light for the given location found")
        ego_light.set_state(carla.TrafficLightState.Green)
        ego_light.set_green_time(self.timeout)

        self.scenario_operation = ScenarioOperation(self.ego_vehicles, self.other_actors)
        self.reference_actor = None
        self.trigger_distance_threshold = 45
        self.ego_max_driven_distance = 150

        self.set_init_attr()

    def initialize_actors(self):
        """
        Spawn the oncoming adversarial vehicle, hold its traffic light green and
        precompute its actions.

        :raises RuntimeError: if no traffic light lies ahead of the adversarial vehicle.
        """
        spawn = self.config.other_actors[0].transform
        self.other_actor_transform.append(carla.Transform(
            carla.Location(spawn.location.x, spawn.location.y, spawn.location.z),
            spawn.rotation,
        ))
        self.actor_type_list.append("vehicle.audi.tt")
        self.scenario_operation.initialize_vehicle_actors(
            self.other_actor_transform, self.other_actors, self.actor_type_list)

        adversary = self.other_actors[0]
        self.reference_actor = adversary

        adversary_light = CarlaDataProvider.get_next_traffic_light(adversary, False)
        if adversary_light is None:
            raise RuntimeError("No traffic light for the given location found")
        adversary_light.set_state(carla.TrafficLightState.Green)
        adversary_light.set_green_time(self.timeout)

        self.initialize_route_planner()
        self.calculate_init_actions()


class SignalizedJunctionRightTurnDynamic(BasicDiffusionCollectionScenarioDynamic):
    """
    Right turn of the ego (hero) vehicle at a signalized junction.

    The ego vehicle's next traffic light starts red, while the adversarial
    vehicle's next traffic light is held green. The adversarial actor has
    priority, so the ego vehicle has to yield to it.
    """

    def __init__(self, world, ego_vehicles, config, randomize=False, debug_mode=False, criteria_enable=True,
                 timeout=80):
        """
        Store scenario parameters, build the base scenario and configure the
        ego vehicle's next traffic light.

        :raises RuntimeError: if no traffic light lies ahead of the ego vehicle.
        """
        # Everything below is assigned before the base-class constructor runs.
        self._world = world
        self._map = CarlaDataProvider.get_map()
        self._target_vel = 12
        self.timeout = timeout
        self._actor_distance = 100
        self._traffic_light = None

        super(SignalizedJunctionRightTurnDynamic, self).__init__(
            "TurnRightAtSignalizedJunctionDynamic",
            ego_vehicles,
            config,
            world,
            debug_mode,
            criteria_enable=criteria_enable,
        )

        ego_light = CarlaDataProvider.get_next_traffic_light(self.ego_vehicles[0], False)
        self._traffic_light = ego_light
        if ego_light is None:
            raise RuntimeError("No traffic light for the given location found")
        ego_light.set_state(carla.TrafficLightState.Red)
        # NOTE(review): the light is forced red but only its *green* duration is
        # set. This presumably lets the ego get a long green after the default
        # red phase, rather than staying red for the whole run — confirm intent.
        ego_light.set_green_time(self.timeout)

        self.scenario_operation = ScenarioOperation(self.ego_vehicles, self.other_actors)
        self.reference_actor = None
        self.trigger = False
        self.trigger_distance_threshold = 35
        self.ego_max_driven_distance = 150

        self.set_init_attr()

    def initialize_actors(self):
        """
        Spawn the adversarial vehicle, hold its traffic light green and
        precompute its actions.

        :raises RuntimeError: if no traffic light lies ahead of the adversarial vehicle.
        """
        spawn = self.config.other_actors[0].transform
        self.other_actor_transform.append(carla.Transform(
            carla.Location(spawn.location.x, spawn.location.y, spawn.location.z),
            spawn.rotation,
        ))
        self.actor_type_list.append("vehicle.audi.tt")
        self.scenario_operation.initialize_vehicle_actors(
            self.other_actor_transform, self.other_actors, self.actor_type_list)

        adversary = self.other_actors[0]
        self.reference_actor = adversary

        adversary_light = CarlaDataProvider.get_next_traffic_light(adversary, False)
        if adversary_light is None:
            raise RuntimeError("No traffic light for the given location found")
        adversary_light.set_state(carla.TrafficLightState.Green)
        adversary_light.set_green_time(self.timeout)

        self.initialize_route_planner()
        self.calculate_init_actions()


class NoSignalJunctionCrossingRouteDynamic(BasicDiffusionCollectionScenarioDynamic):
    """
    Crossing scenario at a junction without traffic lights.

    A single adversarial vehicle (``vehicle.audi.tt``) is spawned at the first
    configured other-actor transform and driven by precomputed actions. No
    traffic lights are manipulated by this scenario.
    """

    def __init__(self, world, ego_vehicles, config, randomize=False, debug_mode=False, criteria_enable=True,
                 timeout=60):
        """
        Store scenario parameters and build the base scenario.

        :param timeout: scenario timeout in seconds.
        """
        self.timeout = timeout
        self.actor_speed = 10

        super(NoSignalJunctionCrossingRouteDynamic, self).__init__(
            "NoSignalJunctionCrossing",
            ego_vehicles,
            config,
            world,
            debug_mode,
            criteria_enable=criteria_enable,
        )

        self.scenario_operation = ScenarioOperation(self.ego_vehicles, self.other_actors)
        self.reference_actor = None
        self.trigger = False
        self.trigger_distance_threshold = 35
        self._actor_distance = 110
        self.ego_max_driven_distance = 150

        self.set_init_attr()

    def initialize_actors(self):
        """Spawn the adversarial vehicle and precompute its actions."""
        spawn = self.config.other_actors[0].transform
        self._other_actor_transform = spawn
        self.other_actor_transform.append(carla.Transform(
            carla.Location(spawn.location.x, spawn.location.y, spawn.location.z),
            spawn.rotation,
        ))
        self.actor_type_list.append("vehicle.audi.tt")
        self.scenario_operation.initialize_vehicle_actors(
            self.other_actor_transform, self.other_actors, self.actor_type_list)
        self.reference_actor = self.other_actors[0]

        self.initialize_route_planner()
        self.calculate_init_actions()


