from __future__ import print_function
import carla
import numpy as np
import torch
import json
import time
import random
from srunner.AdditionTools.scenario_operation import ScenarioOperation
from srunner.scenariomanager.carla_data_provider import CarlaDataProvider
from srunner.scenario_dynamic.basic_scenario_dynamic import BasicScenarioDynamic
from srunner.AdditionTools.scenario_utils import calculate_distance_transforms

import srunner.scenario_dynamic.diffusegen.carla_utils as carla_utils
import srunner.scenario_dynamic.diffusegen.diffuser.utils as utils
import srunner.scenario_dynamic.diffusegen.diffuser.sampling as sampling
from srunner.scenario_dynamic.diffusegen.diffuser.models.diffusion import default_sample_fn


class BasicDiffusionScenarioDynamic(BasicScenarioDynamic):
    """Scenario whose first adversarial actor follows an open-loop action plan
    sampled once from a value-guided diffusion policy.

    ``self.config.parameters`` is read positionally as
    ``[diffusion_load_dir, value_load_dir, guide, seed, n_guide_steps, noise]``
    (index 5 is not read by this class).
    """

    # Number of planned actions kept from the sampled trajectory; also passed to
    # the policy as its sampling horizon.
    _PLAN_HORIZON = 128
    # Each planned action is held for this many calls to ``update_behavior``.
    _ACTION_REPEAT = 4

    def set_init_attr(self):
        """Initialise per-scenario state and load the diffusion policy."""
        self.max_waypt = 12
        self.routeplanner = None
        self.waypoints = []
        # Fourth observation feature (see ``get_obs``); not updated in this class.
        self.vehicle_front = False
        # Guided n-step sampler when the ``guide`` parameter is truthy, plain
        # sampler otherwise. Must be set before ``init_model``, which reads it.
        self.sample_fn = sampling.n_step_guided_p_sample if self.config.parameters[2] else default_sample_fn
        self.init_model()
        self.step = 0
        self.obs = None
        self.action = np.array([0, 0]).astype(np.float32)
        self.actions = []

    def initialize_route_planner(self):
        """Set ``self.target_transform`` to a point ``self._actor_distance`` ahead.

        The offset direction is the forward vector of
        ``self.other_actor_transform[0]``; it is added to the configured spawn
        location of ``self.config.other_actors[0]``, whose rotation is reused.
        No route planner object is created here.
        """
        forward_vector = self.other_actor_transform[0].rotation.get_forward_vector() * self._actor_distance
        self.target_transform = carla.Transform(carla.Location(self.config.other_actors[0].transform.location + forward_vector),
                                                self.config.other_actors[0].transform.rotation)

    def _create_behavior(self):
        """No behavior tree is used; the actor is driven by ``update_behavior``."""
        pass

    def check_stop_condition(self):
        """Return True once the first other actor has travelled at least
        ``self._actor_distance`` from its initial transform."""
        cur_distance = calculate_distance_transforms(CarlaDataProvider.get_transform(self.other_actors[0]), self.other_actor_transform[0])
        return bool(cur_distance >= self._actor_distance)

    def init_model(self):
        """Load the diffusion and value checkpoints and build the guided policy.

        Checkpoint epochs are hard-coded ('500000' for the diffusion model,
        '98000' for the value model). All models are moved to CUDA and put in
        eval mode. The resulting policy is stored in ``self.policy``.
        """
        diffusion_experiment = utils.load_diffusion('/', 'carla', self.config.parameters[0], epoch='500000')
        value_experiment = utils.load_diffusion('/', 'carla', self.config.parameters[1], epoch='98000')
        utils.check_compatibility(diffusion_experiment, value_experiment)
        diffusion = diffusion_experiment.ema
        diffusion.cuda()
        diffusion.eval()
        value_function = value_experiment.ema
        value_function.cuda()
        value_function.eval()
        guide_config = utils.Config('scenario_dynamic.diffusegen.diffuser.sampling.ValueGuide', model=value_function, verbose=False)
        guide = guide_config()
        guide.cuda()
        guide.eval()

        policy_config = utils.Config(
            'scenario_dynamic.diffusegen.diffuser.sampling.GuidedPolicy',
            guide=guide,
            scale=0.1,
            diffusion_model=diffusion,
            normalizer=None,
            preprocess_fns=[],
            ## sampling kwargs
            sample_fn=self.sample_fn,
            n_guide_steps=self.config.parameters[4],
            t_stopgrad=2,
            scale_grad_by_std=True,
            verbose=False,
            horizon=self._PLAN_HORIZON,
        )

        self.policy = policy_config()

    def calculate_init_actions(self):
        """Sample the whole action plan once from the current observation.

        Sampling is seeded with ``self.config.parameters[3]`` so the plan is
        reproducible. The global torch CPU RNG state and the RNG states of all
        CUDA devices are saved before seeding and restored afterwards, even if
        sampling raises, so the seed does not leak into the rest of the program.
        """
        adv_actor_id = 0
        self.obs = self.get_obs(adv_actor_id, normalize=True)
        cpu_rng_state = torch.get_rng_state()
        # torch.manual_seed reseeds every CUDA device, so save all of them.
        cuda_rng_states = torch.cuda.get_rng_state_all()
        torch.manual_seed(self.config.parameters[3])
        try:
            _, samples = self.policy({0: self.obs}, verbose=False)
        finally:
            torch.set_rng_state(cpu_rng_state)
            torch.cuda.set_rng_state_all(cuda_rng_states)
        actions = samples.actions
        # Rebuild the plan instead of extending it, so a repeated call does not
        # leave a stale plan at the front of the list.
        self.actions = [carla_utils.postprocess_action(actions[0, i, :])
                        for i in range(self._PLAN_HORIZON)]

    def update_behavior(self):
        """Apply the current planned action to the adversarial actor and tick once.

        A new action is taken from the plan every ``_ACTION_REPEAT`` steps. Once
        the plan is exhausted, the last action is held.
        """
        adv_actor_id = 0
        if self.step % self._ACTION_REPEAT == 0:
            plan_index = min(self.step // self._ACTION_REPEAT, len(self.actions) - 1)
            self.action = self.actions[plan_index]
        self.scenario_operation.apply_control(adv_actor_id, self.action[0], self.action[1])
        self.world.wait_for_tick()
        self.step += 1

    def get_obs(self, adv_actor_id, normalize=True):
        """Build the 4-feature observation for an adversarial actor.

        Features are:
        - lateral distance returned by ``carla_utils.get_preview_lane_dis``;
        - negated heading error (radians);
        - planar speed from ``get_velocity()``;
        - ``self.vehicle_front``.

        The lane reference is a synthetic waypoint 2 units from the actor,
        pointing toward ``self.target_transform``. When ``normalize`` is true,
        each feature is rescaled with ``carla_utils.normalize`` using the bounds
        (-2, 2), (-1, 1), (-5, 30) and (0, 1) respectively.

        Returns:
            float32 numpy array of length 4.
        """
        adv_trans = self.other_actors[adv_actor_id].get_transform()
        adv_x = adv_trans.location.x
        adv_y = adv_trans.location.y
        adv_yaw = adv_trans.rotation.yaw / 180 * np.pi

        # Synthetic waypoint: 2 units along the unit vector toward the target.
        target_vector = carla_utils.make_unit_vector(carla.Vector3D(x=self.target_transform.location.x - adv_x, y=self.target_transform.location.y - adv_y, z=0))
        target_location = carla.Location(adv_trans.location + target_vector * 2)
        # Clip so rounding in the unit vector cannot push arcsin out of its
        # domain (which would yield NaN).
        target_yaw = np.arcsin(np.clip(target_vector.y, -1.0, 1.0)) / np.pi * 180
        if target_vector.x < 0:
            target_yaw = 180 - target_yaw
        target_rotation = carla.Rotation(yaw=target_yaw)
        target_transform = carla.Transform(target_location, target_rotation)
        target_waypoint = [target_transform.location.x, target_transform.location.y, target_transform.rotation.yaw]
        lateral_dis, w = carla_utils.get_preview_lane_dis([target_waypoint]*5, adv_x, adv_y)
        # z-component of the 2-D cross product w x (cos yaw, sin yaw). It is
        # written out because np.cross on 2-element vectors is deprecated since
        # NumPy 2.0. Assumes w is a 2-element direction vector, as the previous
        # np.cross usage required. Clipped to keep arcsin in its domain.
        cross = w[0] * np.sin(adv_yaw) - w[1] * np.cos(adv_yaw)
        delta_yaw = np.arcsin(np.clip(cross, -1.0, 1.0))
        v = self.other_actors[adv_actor_id].get_velocity()
        speed = np.sqrt(v.x ** 2 + v.y ** 2)
        state = np.array([lateral_dis, -delta_yaw, speed, self.vehicle_front])
        if normalize:
            state[0] = carla_utils.normalize(state[0], -2, 2)
            state[1] = carla_utils.normalize(state[1], -1, 1)
            state[2] = carla_utils.normalize(state[2], -5, 30)
            state[3] = carla_utils.normalize(state[3], 0, 1)

        return state.astype(np.float32)


class BasicDiffusionCollectionScenarioDynamic(BasicDiffusionScenarioDynamic):
    """Variant that samples the adversarial plan with an unguided diffusion model.

    On every step it also stores the observation (``current_obs``) and a reward
    (``current_reward``). ``self.config.parameters`` is read positionally as
    ``[records, load_path, seed, diffusion_step]``; index 0 is not read by this
    class.
    """

    def set_init_attr(self):
        """Initialise per-scenario state and load the diffusion model."""
        self.max_waypt = 12
        self.routeplanner = None
        self.waypoints = []
        # Fourth observation feature (see ``get_obs``); not updated in this class.
        self.vehicle_front = False
        self.init_model()
        self.step = 0
        self.obs = None
        self.action = np.array([0, 0]).astype(np.float32)
        self.actions = []
        # Filled by ``update_behavior`` on every step.
        self.current_obs = None
        self.current_reward = None

    def init_model(self):
        """Load the epoch-'500000' diffusion checkpoint from ``parameters[1]``.

        Stores the experiment's dataset in ``self.dataset`` and its EMA model,
        moved to CUDA and put in eval mode, in ``self.diffusion``.
        """
        diffusion_experiment = utils.load_diffusion(self.config.parameters[1], epoch='500000')

        self.dataset = diffusion_experiment.dataset
        self.diffusion = diffusion_experiment.trainer.ema_model
        self.diffusion.cuda()
        # Inference only; the parent class likewise puts its models in eval mode.
        self.diffusion.eval()

    def calculate_init_actions(self):
        """Sample the whole 128-step action plan once from the current observation.

        Sampling is seeded with ``self.config.parameters[2]``, and
        ``parameters[3]`` is forwarded as ``select_step``. The global torch CPU
        RNG state and the RNG states of all CUDA devices are restored afterwards,
        even if sampling raises.
        """
        adv_actor_id = 0
        self.obs = self.get_obs(adv_actor_id, normalize=True)
        cpu_rng_state = torch.get_rng_state()
        # torch.manual_seed reseeds every CUDA device, so save all of them.
        cuda_rng_states = torch.cuda.get_rng_state_all()
        torch.manual_seed(self.config.parameters[2])
        try:
            actions = utils.colab.run_diffusion(self.diffusion, self.obs, action_dim=self.diffusion.action_dim, horizon=128, select_step=self.config.parameters[3])
        finally:
            torch.set_rng_state(cpu_rng_state)
            torch.cuda.set_rng_state_all(cuda_rng_states)
        # Rebuild (not extend) the plan; 128 matches the horizon requested above.
        self.actions = [carla_utils.postprocess_action(actions[0, i, :]) for i in range(128)]

    def update_behavior(self):
        """Record the pre-step observation, advance the actor one planned step,
        then record the post-step reward."""
        adv_actor_id = 0
        self.current_obs = self.get_obs(adv_actor_id, normalize=True)
        # The parent's stepping logic (action selection, control, tick, step
        # counter) is exactly what this method used to duplicate inline.
        super(BasicDiffusionCollectionScenarioDynamic, self).update_behavior()
        self.current_reward = self.get_current_reward(adv_actor_id)

    def get_current_reward(self, adv_actor_id):
        """Return the negated planar (x, y) distance between the adversarial
        actor and the first ego vehicle."""
        adv_trans = self.other_actors[adv_actor_id].get_transform()
        adv_x = adv_trans.location.x
        adv_y = adv_trans.location.y

        ego_trans = self.ego_vehicles[0].get_transform()
        ego_x = ego_trans.location.x
        ego_y = ego_trans.location.y

        return - np.linalg.norm(np.array([adv_x - ego_x, adv_y - ego_y]))
