import numpy as np
import os
import joblib
import gym
from tqdm import tqdm
import traceback
from easydict import EasyDict as edict
import math
from copy import deepcopy
from srunner.scenariomanager.scenarioatomics.atomic_criteria import Status

# Module-wide configuration with attribute-style access.
# ACC_MAX / STEERING_MAX are the scale factors and clamp bounds applied by
# postprocess_action(). The remaining entries are not read anywhere in this
# file; presumably they are consumed by other modules — TODO confirm.
CARLA_CFG = edict({
    'ACC_MAX': 3,
    'STEERING_MAX': 0.3,
    'OBS_TYPE': 0,
    'MAX_EPISODE_LEN': 300,
    'FRAME_SKIP': 4,
})


class DrivingRecord:
    """Collect per-episode trajectories and persist them with joblib.

    A trajectory is a list of the dicts passed to :meth:`add`, terminated by
    the statistics object passed to :meth:`on_episode_end`. Nothing is
    written to disk until :meth:`dump` is called.
    """

    def __init__(self, path, file_name='driving_records.pkl'):
        """
        :param path: directory holding the record file (created if missing)
        :param file_name: name of the record file inside ``path``
        """
        os.makedirs(path, exist_ok=True)
        self.filename = os.path.join(path, file_name)
        # Continue from previously saved trajectories when the file exists.
        already_saved = os.path.exists(self.filename)
        self.trajectories = joblib.load(self.filename) if already_saved else []
        self.current_trajectory = []

    def add(self, data_dict):
        """Append one step to the episode currently being recorded."""
        self.current_trajectory.append(data_dict)

    def on_episode_end(self, stats):
        """Close the current episode, storing ``stats`` as its last element."""
        finished = self.current_trajectory
        finished.append(stats)
        self.trajectories.append(finished)
        # Start the next episode with a fresh list; saving is left to dump().
        self.current_trajectory = []

    def dump(self):
        """Write every finished trajectory to ``self.filename``."""
        joblib.dump(self.trajectories, self.filename)


class SacEnv:
    """Holder for a gym observation space and action space.

    Only the two space attributes are defined; this class has no ``step`` or
    ``reset`` method.
    """

    def __init__(self):
        # 4-dimensional observation, each component assumed to lie in [-1, 1]
        obs_dim = 4
        obs_lim = np.ones(obs_dim, dtype=np.float32)
        # The dtype must be a scalar type matching the bounds. The previous
        # value was the ``np.dtype`` class itself, which is not a data type.
        self.observation_space = gym.spaces.Box(-obs_lim,
                                                obs_lim,
                                                dtype=np.float32)
        # 2-dimensional action, each component assumed to lie in [-1, 1]
        act_dim = 2
        act_lim = np.ones(act_dim, dtype=np.float32)
        self.action_space = gym.spaces.Box(-act_lim, act_lim, dtype=np.float32)


def normalize(x, min_value, max_value):
    """Linearly map ``x`` from [min_value, max_value] onto [-1, 1].

    :param x: value (or array) to rescale
    :param min_value: input value that maps to -1
    :param max_value: input value that maps to 1
    :return: the rescaled value; inputs outside the range are not clipped
    """
    # Fraction of the way through the range, in [0, 1] for in-range input.
    fraction = (x - min_value) / (max_value - min_value)
    # Stretch [0, 1] to [-1, 1].
    return 2 * fraction - 1


def unnormalize(x, min_value, max_value, eps=1e-4):
    """Linearly map ``x`` from [-1, 1] back onto [min_value, max_value].

    :param x: value (or array) in normalized coordinates
    :param min_value: output produced for ``x == -1``
    :param max_value: output produced for ``x == 1``
    :param eps: unused; kept so existing callers that pass it keep working
    :return: the rescaled value; inputs outside [-1, 1] are not clipped
    """
    # [-1, 1] --> [0, 1]
    fraction = (x + 1) / 2.
    span = max_value - min_value
    return fraction * span + min_value


def postprocess_action(action):
    """Scale a two-component action by the configured maxima and clamp it.

    Component 0 is multiplied by ``CARLA_CFG.ACC_MAX`` and component 1 by
    ``CARLA_CFG.STEERING_MAX``; each is then limited to ``[-max, max]``.

    :param action: two-component action (presumably normalized to [-1, 1])
    :return: a new array with the scaled, clamped components
    """
    bounds = (CARLA_CFG.ACC_MAX, CARLA_CFG.STEERING_MAX)
    # The multiplication yields a new array, so the caller's input is untouched.
    scaled = action * np.array([bounds[0], bounds[1]])
    for component, bound in enumerate(bounds):
        upper_limited = min(bound, scaled[component])
        scaled[component] = max(upper_limited, -bound)
    return scaled


def get_preview_lane_dis(waypoints, x, y, idx=2):
    """
    Calculate the signed lateral offset of (x, y) from a preview waypoint,
    measured perpendicular to the waypoint's heading
    :param waypoints: a list of waypoints like [[x0, y0, yaw0], [x1, y1, yaw1], ...],
        with yaw given in degrees
    :param x: x position of vehicle
    :param y: y position of vehicle
    :param idx: index of the waypoint to which the distance is calculated
    :return: a tuple (dis, w) where w is the unit heading vector of the waypoint
        and dis = -(w x vec), with vec pointing from the waypoint to (x, y)
    """
    waypt = waypoints[idx]
    vec = np.array([x - waypt[0], y - waypt[1]])
    yaw = np.deg2rad(waypt[2])
    w = np.array([np.cos(yaw), np.sin(yaw)])
    # 2-D cross product written out by hand. Normalising vec and multiplying
    # by its length again is redundant and gives NaN when (x, y) sits exactly
    # on the waypoint; np.cross on 2-D vectors is also deprecated in NumPy 2.0.
    dis = -(w[0] * vec[1] - w[1] * vec[0])
    return dis, w


def make_unit_vector(vector):
    """Divide ``vector`` by the length of its (x, y) components.

    :param vector: object exposing ``.x`` and ``.y`` and supporting division
        by a scalar
    :return: ``vector / length``; only x and y contribute to the length
    """
    planar = np.array([vector.x, vector.y])
    planar_length = np.linalg.norm(planar)
    return vector / planar_length


def cal_out_of_road_length(sequence):
    """Sum the distance driven over steps that touch an off-road frame.

    A step from frame ``i - 1`` to frame ``i`` is counted when either frame
    is flagged ``'off_road'``, i.e. the flag is also carried over to the
    frame that follows an off-road frame.

    :param sequence: list of per-frame dicts with ``'off_road'`` and
        ``'driven_distance'`` entries
    :return: accumulated ``'driven_distance'`` increments of the counted steps
    """
    off_road = [frame['off_road'] for frame in sequence]

    total_length = 0
    for i in range(1, len(off_road)):
        if off_road[i] or off_road[i - 1]:
            step = sequence[i]['driven_distance'] - sequence[i - 1]['driven_distance']
            total_length += step

    return total_length


def cal_avg_yaw_velocity(sequence):
    """Return the accumulated absolute yaw change of a trajectory, in radians.

    Despite its name, the result is NOT divided by the elapsed time: it is
    the total yaw change, which is what callers currently receive.

    Consecutive yaw differences are folded into [0, 180] degrees, so a
    heading that crosses the +/-180 degree boundary (e.g. 179 -> -179) counts
    as a 2 degree change rather than 358 degrees.

    :param sequence: list of per-frame dicts with an ``'ego_yaw'`` entry in
        degrees
    :return: sum of absolute frame-to-frame yaw changes, in radians
        (0 for sequences shorter than two frames)
    """
    total_yaw_change = 0
    for i in range(1, len(sequence)):
        delta = abs(sequence[i]['ego_yaw'] - sequence[i - 1]['ego_yaw']) % 360
        if delta > 180:
            # Shorter way round the circle.
            delta = 360 - delta
        total_yaw_change += delta
    return total_yaw_change / 180 * math.pi


def get_statistics_of_current_trajectory(sequence):
    """Summarise one episode into safety, task-performance and comfort scores.

    Averages are taken over every frame; the remaining scores are read from
    the final frame or computed by ``cal_out_of_road_length`` and
    ``cal_avg_yaw_velocity``.

    :param sequence: non-empty list of per-frame record dicts
    :return: dict of scores (see the keys below)
    :raises ZeroDivisionError: if ``sequence`` is empty
    """
    n_frames = len(sequence)

    total_distance_to_route = 0
    total_acc = 0
    for time_stamp in sequence:
        total_distance_to_route += time_stamp['distance_to_route']
        # magnitude of the 3-D acceleration vector
        total_acc += math.sqrt(time_stamp['ego_acceleration_x'] ** 2
                               + time_stamp['ego_acceleration_y'] ** 2
                               + time_stamp['ego_acceleration_z'] ** 2)
    avg_distance_to_route = total_distance_to_route / n_frames
    avg_acc = total_acc / n_frames

    first, last = sequence[0], sequence[-1]

    # The route length is extrapolated from the driven distance and the
    # completion percentage. With zero completion it cannot be estimated, so
    # report NaN instead of failing on a division by zero.
    route_complete = last['route_complete']
    if route_complete:
        route_length = last['driven_distance'] / route_complete * 100
    else:
        route_length = float('nan')

    scores = {
        # safety level
        'collision': 1 if last['collision'] == Status.FAILURE else 0,  # 0/1
        'run_red_light': last['run_red_light'],  # int
        'run_stop_sign': last['run_stop'],  # int
        'out_of_road_length': cal_out_of_road_length(sequence),  # length

        # task performance level
        'avg_distance_to_route': avg_distance_to_route,
        'route_completion': route_complete / 100,  # [0, 1]
        'time_spent': last['current_game_time'] - first['current_game_time'],

        # comfort level
        'avg_acceleration': avg_acc,
        'avg_yaw_velocity': cal_avg_yaw_velocity(sequence),
        'lane_invasion': last['lane_invasion'],  # int

        # additional info
        'route_length': route_length,  # length, NaN when route_complete is 0
    }

    return scores


def process_data() -> None:
    """Convert raw per-frame records into per-episode transition lists and save them.

    Loads a mapping of ``data_id -> list of per-frame record dicts`` from
    ``record.pkl`` and, for every episode:

    * collects the frames that carry an adversary observation,
    * determines the episode's action sequence,
    * computes episode statistics from the full frame list,
    * emits one transition per action by sampling every 4th collected frame,
    * appends the transitions followed by the statistics to a ``DrivingRecord``.

    The result is written to ``s5_driving_raw_data.pkl`` once all episodes
    have been processed. Input and output paths are hard-coded. If the output
    file already exists, ``DrivingRecord`` loads it and the new episodes are
    appended to its contents.
    """
    file_name = '/home/carla/output/testing_records/record.pkl'
    dataset = DrivingRecord('/home/carla/output/testing_records', 's5_driving_raw_data.pkl')
    trajectories = joblib.load(file_name)

    for data_id, record_list in tqdm(trajectories.items()):
        # Frame-level transitions of this episode (frames without adv_obs are skipped).
        current_trajectory = []
        # Action sequence of this episode; filled in while scanning the frames.
        current_actions = None
        for timestep, record in enumerate(record_list):
            if record['adv_obs'] is None:
                continue
            # Next observation: the following frame's adv_obs, falling back to the
            # current one at the last frame or when the following frame has none.
            next_timestep = timestep + 1 if timestep + 1 < len(record_list) else timestep
            obs2 = record_list[next_timestep]['adv_obs'] if record_list[next_timestep]['adv_obs'] is not None else record['adv_obs']
            current_trajectory.append({'obs': record['adv_obs'],
                                        'act': np.squeeze(record['adv_action']),
                                        'rew': record['adv_reward'],
                                        'obs2': obs2,
                                        'done': record['adv_done'],
                                        'cost': record['adv_cost']})
            if 'adv_actions' in record:
                # The frame carries a whole action sequence: take it from the first
                # such frame and require every later frame to carry an identical one.
                if current_actions is None:
                    current_actions = record['adv_actions']
                else:
                    equal = [(current_actions[i] == record['adv_actions'][i]).all() for i in range(len(current_actions))]
                    assert sum(equal) == len(equal), (data_id, timestep)
            else:
                # No precomputed sequence: start from 32 slots pre-filled with the
                # final frame's action, then overwrite the slot at this frame's index.
                # NOTE(review): the slot index is the raw frame index, so this
                # presumably relies on such episodes having at most 32 frames — confirm.
                if current_actions is None:
                    current_actions = [record_list[-1]['adv_action']] * 32
                current_actions[timestep] = record['adv_action']
        # Highest frame reward seen (-1e5 acts as the initial lower bound) and
        # whether any frame reported a non-zero cost.
        max_reward = -1e5
        adv_collision = 0
        for timestep, record in enumerate(current_trajectory):
            if record['rew'] > max_reward:
                max_reward = record['rew']
            if record['cost'] != 0:
                adv_collision = 1
        # Episode statistics use the full frame list, including skipped frames.
        stats = get_statistics_of_current_trajectory(record_list)
        stats['max_reward'] = max_reward
        stats['adv_collision'] = adv_collision
        current_obs = []
        current_done = []
        current_cost = []
        # Pad with copies of the last frame so that at least
        # 4 * len(current_actions) + 1 frames exist (no-op if already long enough).
        for i in range(len(current_actions) * 4 - len(current_trajectory) + 1):
            current_trajectory.append(current_trajectory[-1])
        # One entry per action, built from the 4 frames starting at index 4 * timestep:
        # obs is the first frame's, done is true if any frame is done, cost is the
        # minimum over the 4 frames.
        for timestep in range(len(current_actions)):
            idx = timestep * 4
            current_obs.append(current_trajectory[idx]['obs'])
            current_done.append(current_trajectory[idx]['done'] or current_trajectory[idx + 1]['done'] or
                                current_trajectory[idx + 2]['done'] or current_trajectory[idx + 3]['done'])
            current_cost.append(min(current_trajectory[idx]['cost'], current_trajectory[idx + 1]['cost'],
                                current_trajectory[idx + 2]['cost'], current_trajectory[idx + 3]['cost']))
        # Duplicate the last observation so the final transition has an obs2.
        current_obs.append(current_obs[-1])
        # Store the transitions; no per-transition reward is written (max_reward
        # is only available through the episode statistics).
        for i in range(len(current_actions)):
            dataset.add({'obs': current_obs[i],
                         'act': np.squeeze(current_actions[i]),
                         # 'rew': max_reward if i == 0 else 0,
                         'obs2': current_obs[i + 1],
                         'done': current_done[i],
                         'cost': current_cost[i]})
        dataset.on_episode_end(stats)
    # Written once at the end; nothing is saved if an episode raises earlier.
    dataset.dump()


def process_data_other_methods() -> None:
    """Convert raw per-frame records into transition lists using 4-frame mean actions.

    Same overall pipeline as ``process_data`` but:

    * episodes whose id is listed in ``advtraj_ids`` are skipped,
    * the action of each transition is the mean of ``adv_action`` over a
      window of 4 frames rather than a stored action sequence,
    * every episode is emitted with 128 action slots,
    * the result is written to ``s5_driving_raw_data_advsim.pkl``.

    Input and output paths are hard-coded. If the output file already exists,
    ``DrivingRecord`` loads it and the new episodes are appended.
    """
    file_name = '/home/carla/output/testing_records/record.pkl'
    dataset = DrivingRecord('/home/carla/output/testing_records', 's5_driving_raw_data_advsim.pkl')
    trajectories = joblib.load(file_name)

    # Episode ids alternate in blocks of 100 between the two groups.
    # advsim_ids is not read below; only advtraj_ids is used (to skip episodes).
    advsim_ids = list(range(0, 100)) + list(range(200, 300)) + list(range(400, 500)) + list(range(600, 700)) + list(range(800, 900))
    advtraj_ids = list(range(100, 200)) + list(range(300, 400)) + list(range(500, 600)) + list(range(700, 800)) + list(range(900, 1000))

    for data_id, record_list in tqdm(trajectories.items()):
        if data_id in advtraj_ids:
            continue
        current_trajectory = []
        # Most recent window-mean action; stays None if no full window is seen.
        last_action = None
        # 128 action slots, all initially None.
        current_actions = [last_action] * 128
        # Index of the first frame that carries an adversary observation.
        start_timestep = None
        for timestep, record in enumerate(record_list):
            if record['adv_obs'] is None:
                continue
            if start_timestep is None:
                start_timestep = timestep
            # Frame index relative to the first frame with an observation.
            current_timestep = timestep - start_timestep
            # Next observation: the following frame's adv_obs, falling back to the
            # current one at the last frame or when the following frame has none.
            next_timestep = timestep + 1 if timestep + 1 < len(record_list) else timestep
            obs2 = record_list[next_timestep]['adv_obs'] if record_list[next_timestep]['adv_obs'] is not None else record['adv_obs']
            current_trajectory.append({'obs': record['adv_obs'],
                                        'act': np.squeeze(record['adv_action']),
                                        'rew': record['adv_reward'],
                                        'obs2': obs2,
                                        'done': record['adv_done'],
                                        'cost': record['adv_cost']})
            # At the end of each 4-frame window within the first 128 relative frames,
            # store the mean adv_action of the window. Because current_timestep < 128,
            # only slots 0..31 can be written here.
            # NOTE(review): the window reads raw frames timestep-3..timestep from
            # record_list, which presumably all carry a valid adv_action — confirm.
            if current_timestep % 4 == 3 and current_timestep < 128:
                actions = [record_list[timestep-3]['adv_action'], record_list[timestep-2]['adv_action'],
                           record_list[timestep-1]['adv_action'], record_list[timestep]['adv_action']]
                current_actions[current_timestep // 4] = np.mean(actions, axis=0)
                last_action = np.mean(actions, axis=0)
        # Fill every slot that was never written with the last window-mean action.
        for i in range(len(current_actions)):
            if current_actions[i] is None:
                current_actions[i] = last_action
        # Highest frame reward seen (-1e5 acts as the initial lower bound) and
        # whether any frame reported a non-zero cost.
        max_reward = -1e5
        adv_collision = 0
        for timestep, record in enumerate(current_trajectory):
            if record['rew'] > max_reward:
                max_reward = record['rew']
            if record['cost'] != 0:
                adv_collision = 1
        # Episode statistics use the full frame list, including skipped frames.
        stats = get_statistics_of_current_trajectory(record_list)
        stats['max_reward'] = max_reward
        stats['adv_collision'] = adv_collision
        current_obs = []
        current_done = []
        current_cost = []
        # Pad with copies of the last frame so that at least
        # 4 * len(current_actions) + 1 frames exist (no-op if already long enough).
        for i in range(len(current_actions) * 4 - len(current_trajectory) + 1):
            current_trajectory.append(current_trajectory[-1])
        # One entry per action slot, built from the 4 frames starting at index
        # 4 * timestep: obs is the first frame's, done is true if any frame is
        # done, cost is the minimum over the 4 frames.
        for timestep in range(len(current_actions)):
            idx = timestep * 4
            current_obs.append(current_trajectory[idx]['obs'])
            current_done.append(current_trajectory[idx]['done'] or current_trajectory[idx + 1]['done'] or
                                current_trajectory[idx + 2]['done'] or current_trajectory[idx + 3]['done'])
            current_cost.append(min(current_trajectory[idx]['cost'], current_trajectory[idx + 1]['cost'],
                                current_trajectory[idx + 2]['cost'], current_trajectory[idx + 3]['cost']))
        # Duplicate the last observation so the final transition has an obs2.
        current_obs.append(current_obs[-1])
        # Store the transitions; no per-transition reward is written (max_reward
        # is only available through the episode statistics).
        for i in range(len(current_actions)):
            dataset.add({'obs': current_obs[i],
                         'act': np.squeeze(current_actions[i]),
                         # 'rew': max_reward if i == 0 else 0,
                         'obs2': current_obs[i + 1],
                         'done': current_done[i],
                         'cost': current_cost[i]})
        dataset.on_episode_end(stats)
    # Written once at the end; nothing is saved if an episode raises earlier.
    dataset.dump()


# Script entry point: running this file executes process_data() only;
# process_data_other_methods() is not invoked from here.
if __name__ == '__main__':
    process_data()
