import copy
import random
import json
import os
import numpy as np
import torch as th
import gymnasium as gym
import time
from collections import deque
from worker.gae import GAE
from worker.instance import Instance


def concat_state_and_style(states, style_state):
    """Join a state tensor and a style vector along dimension 0.

    Args:
        states: torch tensor holding the observation features.
        style_state: sequence of style values; converted to a float32 tensor
            before joining.

    Returns:
        A new tensor ``th.cat([states, style], dim=0)`` (for 1-D inputs this
        is simply the style values appended after the state values).
    """
    style_tensor = th.from_numpy(np.asarray(style_state, dtype=np.float32))
    return th.cat((states, style_tensor), dim=0)


def calculate_speed_reward(current_speed, target_speed, min_reward=-0.02, max_reward=0.02, max_diff=5):
    """Linearly map the speed gap onto ``[min_reward, max_reward]``.

    A gap of 0 earns ``max_reward``; a gap of ``max_diff`` (or larger) earns
    ``min_reward``. Anything in between is interpolated linearly and the
    result is clamped to the reward interval.
    """
    closeness = 1 - abs(current_speed - target_speed) / max_diff
    raw = min_reward + (max_reward - min_reward) * closeness
    # clamp: gaps beyond max_diff would otherwise drop below min_reward
    return max(min_reward, min(max_reward, raw))


class SamplerRollout(object):
    """Collect fixed-length on-policy trajectories with style-shaped rewards.

    Each agent step queries ``agent.get_model_result`` with the normalized
    observation and a style vector, steps the env, and adds style-dependent
    shaping terms (see ``reward_shaping``) to the env reward. After
    ``traj_len`` agent steps, advantages and returns are filled in via GAE.

    ``config.json`` and ``env_config.json`` are read from the current working
    directory when the rollout is constructed.
    """

    def __init__(
            self,
            sampler,
            agent,
            statistic=None
    ):
        """Load configuration and build the environment.

        Args:
            sampler: owner object; only ``sampler.logger`` is used here.
            agent: object whose ``get_model_result(state, style)`` returns
                ``(action, probs, log_prob, state_value)``.
            statistic: optional metrics sink exposing ``append(key, value)``;
                when ``None`` all metrics are dropped.
        """
        self.sampler = sampler
        self.agent = agent
        self.statistic = statistic

        self.gae = GAE()
        # Start as "done" so the first sample_one_traj call resets the env.
        self.done = True

        with open('config.json') as f:
            json_str = f.read()
        self.json_config = json.loads(json_str)
        self.gamma = self.json_config["gamma"]
        self.tau = self.json_config["tau"]
        self.action_num = self.json_config["action_dim"]
        self.traj_len = self.json_config["traj_len"]
        self.n_sample = self.json_config["n_sample"]
        self.env_name = self.json_config["env_name"]
        self.action_nums = self.json_config['action_dim']
        # When truthy, reward factors are re-sampled each episode from
        # style_dict; otherwise the fixed json_config['style_values'] are used.
        self.use_mas = self.json_config['use_mas']

        with open('env_config.json') as f:
            env_config = json.load(f)

        self.env = gym.make(self.env_name, config=env_config)

        # reward factors, re-sampled by generate_reward_factors() on reset
        self.factor_dict = None

        self.info = {
            'prob': [],
            'styles': [],
        }

        self.current_state = None
        self.prev_info = None
        self.ep_start_time = None
        self.episode_return = 0
        self.episode_length = 0
        self.episode_action_list = []
        self.episode_y_list = []
        self.episode_reward_list = []
        self.episode_speed_reward_list = []
        self.episode_ttc_reward_list = []
        self.episode_change_reward_list = []
        self.episode_left_reward_list = []
        self.episode_right_reward_list = []
        self.episode_mid_reward_list = []

    def generate_reward_factors(self):
        """Sample one set of style reward factors from ``style_dict``.

        For every entry of ``json_config['style_dict']``:
          * ``fine_grit > 0``: draw uniformly from the grid
            ``lower_bound, lower_bound + fine_grit, ..., upper_bound``;
          * otherwise: draw from the continuous range
            ``[lower_bound, upper_bound]``.
        The matching style-state entry is the factor min-max normalized to
        ``[0, 1]``; a degenerate range (lower == upper) maps to
        ``upper_bound / 2``.

        Returns:
            dict with ``'reward_factors'`` (name -> sampled factor) and
            ``'style_state'`` (normalized values in ``style_dict`` order).
        """
        style_dict = self.json_config['style_dict']
        reward_factors = {}
        style_state = []
        for factor_name, factor_info in style_dict.items():
            lower = factor_info['lower_bound']
            upper = factor_info['upper_bound']
            if factor_info['fine_grit'] > 0:
                if lower == upper:
                    reward_factors[factor_name] = lower
                else:
                    # round(), not int(): e.g. (0.3 - 0.0) / 0.1 evaluates to
                    # 2.9999999999999996, and truncating it would silently
                    # produce a coarser grid that is off the fine_grit step.
                    n_points = int(round((upper - lower) / factor_info['fine_grit'])) + 1
                    grit_factor = np.linspace(lower, upper, n_points)
                    reward_factors[factor_name] = random.choice(grit_factor)
            else:
                reward_factors[factor_name] = random.uniform(lower, upper)

            if lower == upper:
                style_value = upper / 2
            else:
                style_value = (reward_factors[factor_name] - lower) / (upper - lower)
            style_state.append(style_value)

        # Built once after the loop so an empty style_dict still yields the
        # keys that get_input() / reward_shaping() look up.
        return {
            'reward_factors': reward_factors,
            'style_state': style_state
        }

    def get_input(self):
        """Build the policy inputs from ``self.current_state``.

        Row 0 (ego) is divided by ``self_norm`` and the remaining rows by
        ``other_norm``, so each observation row must broadcast against these
        9-element vectors. The normalized observation is flattened.

        Returns:
            (states_, style_state): flattened observation tensor and style
            tensor (``factor_dict['style_state']`` when ``use_mas``, else
            ``json_config['style_values']``).
        """
        # get the obs state (copied so normalization does not touch current_state)
        a_obs = copy.deepcopy(self.current_state)
        self_norm = np.array([500, 12, 30, 10, 1, 1, 1, 1, 1])
        other_norm = np.array([200, 12, 30, 10, 1, 1, 1, 1, 1])
        a_obs[0] = a_obs[0] / self_norm
        a_obs[1:] = a_obs[1:] / other_norm
        states_ = a_obs.flatten()
        states_ = th.from_numpy(states_)

        # get the style state
        if self.use_mas:
            style_state = th.from_numpy(np.array(self.factor_dict['style_state']))
        else:
            style_state = th.from_numpy(np.array(self.json_config['style_values']))

        return states_, style_state

    def process_action(self, action):
        """Extract a plain action from the agent's action output.

        With more than one entry, returns ``action[0][0]``; otherwise squeezes
        dim 1 of ``action[0]`` and converts it to numpy.
        """
        if len(action) > 1:
            action = action[0][0]
        else:
            action = action[0].squeeze(1).numpy()
        return action

    def step(self, joint_action):
        """Step the env; returns ``(next_state, reward, done, truncated, info)``."""
        next_state, reward, done, truncated, info = self.env.step(joint_action)
        return next_state, reward, done, truncated, info

    def _record_stat(self, key, value):
        """Forward a metric to ``self.statistic`` when a sink was provided."""
        if self.statistic is not None:
            self.statistic.append(key, value)

    def _same_lane_ttcs(self, x_speed):
        """Return approximate time-to-collision for same-lane vehicles ahead.

        Rows 1.. of ``self.current_state`` are other vehicles; one counts when
        ``|row[1]| < 0.2`` and ``row[0] > 0`` (presumably ego-relative lateral
        / longitudinal offsets — TODO confirm against the env observation
        config). TTC is ``row[0] / x_speed``. Returns ``[]`` when
        ``x_speed <= 0`` because the ratio is then infinite or negative.
        """
        if x_speed <= 0:
            return []
        return [
            other[0] / x_speed
            for other in self.current_state[1:]
            if abs(other[1]) < 0.2 and other[0] > 0
        ]

    def reward_shaping(self, action):
        """Compute the style-dependent shaping terms for ``action``.

        Reads the ego row of ``self.current_state`` (the state the action was
        chosen in). Factors come from the sampled ``factor_dict`` when
        ``use_mas`` is set, otherwise from ``json_config['style_values']`` in
        the order speed, ttc, change, left, right, mid.

        Returns:
            tuple ``(speed, ttc, change, left, right, mid)`` of rewards.
        """
        ego = self.current_state[0, :]
        y_pos = ego[1]
        x_speed = ego[2]

        if self.use_mas:
            factors = self.factor_dict['reward_factors']
            speed_factor = factors['speed']
            ttc_factor = factors['ttc']
            change_factor = factors['change']
            left_factor = factors['left']
            right_factor = factors['right']
            mid_factor = factors['mid']
        else:
            values = self.json_config['style_values']
            speed_factor = values[0]
            ttc_factor = values[1]
            change_factor = values[2]
            left_factor = values[3]
            right_factor = values[4]
            mid_factor = values[5]

        # target speed is 20 + 10 * speed_factor
        speed_reward = calculate_speed_reward(x_speed, speed_factor * 10 + 20)

        ttc_car = self._same_lane_ttcs(x_speed)
        if ttc_factor > 0:
            if ttc_car:
                ttc_reward = (min(ttc_car) - 2) * ttc_factor * 0.01
            else:
                # nothing ahead in this lane: flat bonus
                ttc_reward = ttc_factor * 0.02
        else:
            ttc_reward = 0

        # actions 0 and 2 are the ones counted as left/right in the stats
        if action in [0, 2]:
            change_reward = (change_factor - 0.5) * 0.02
        else:
            change_reward = 0

        # lateral-position bonuses (presumably leftmost / rightmost / middle
        # lanes of a 0..12 y range — TODO confirm lane geometry)
        left_reward = 0
        right_reward = 0
        mid_reward = 0
        if y_pos <= 0.5 and left_factor > 0:
            left_reward = left_factor * 0.02
        if y_pos >= 11.5 and right_factor > 0:
            right_reward = right_factor * 0.02
        if 0.5 < y_pos < 11.5 and mid_factor > 0:
            mid_reward = mid_factor * 0.02

        return speed_reward, ttc_reward, change_reward, left_reward, right_reward, mid_reward

    def reset(self):
        """Clear per-episode bookkeeping, re-sample reward factors, reset env.

        Returns:
            ``(current_state, info)`` from ``env.reset()``.
        """
        self.done = False
        self.episode_return = 0
        self.episode_length = 0
        self.episode_action_list = []
        self.episode_y_list = []
        self.episode_reward_list = []
        self.episode_speed_reward_list = []
        self.episode_ttc_reward_list = []
        self.episode_change_reward_list = []
        self.episode_left_reward_list = []
        self.episode_right_reward_list = []
        self.episode_mid_reward_list = []

        self.factor_dict = self.generate_reward_factors()

        current_state, info = self.env.reset()
        return current_state, info

    def _record_episode_stats(self, style):
        """Log end-of-episode summaries and snapshot ``info`` into ``prev_info``."""
        self._record_stat("reward_speed", np.sum(self.episode_speed_reward_list))
        self._record_stat("reward_ttc", np.sum(self.episode_ttc_reward_list))
        self._record_stat("reward_change", np.sum(self.episode_change_reward_list))
        self._record_stat("reward_left", np.sum(self.episode_left_reward_list))
        self._record_stat("reward_right", np.sum(self.episode_right_reward_list))
        self._record_stat("reward_mid", np.sum(self.episode_mid_reward_list))

        # action ids 0..4 in this fixed order
        n_actions = len(self.episode_action_list)
        ratio_keys = ("act_left_ratio", "act_idle_ratio", "act_right_ratio",
                      "act_fast_ratio", "act_slow_ratio")
        for act_id, key in enumerate(ratio_keys):
            self._record_stat(key, self.episode_action_list.count(act_id) / n_actions)

        n_pos = len(self.episode_y_list)
        ys = self.episode_y_list
        self._record_stat("lane_left_ratio", len([y for y in ys if y <= 0.2]) / n_pos)
        self._record_stat("lane_right_ratio", len([y for y in ys if y >= 11.8]) / n_pos)
        self._record_stat("lane_mid_ratio", len([y for y in ys if 0.3 < y < 11.7]) / n_pos)

        # record the style state
        self.info['styles'] = list(style.numpy())
        self.prev_info = copy.deepcopy(self.info)
        self._record_stat("ep_return", self.episode_return)
        self._record_stat("ep_length", self.episode_length)
        self._record_stat("ep_time", time.time() - self.ep_start_time)

        print("Episode len: %d, Total reward: %f" % (self.episode_length, self.episode_return))

    def sample_one_traj(self):
        """Roll out ``traj_len`` agent steps and annotate them with GAE.

        Env resets do not count towards ``traj_len``; episodes continue across
        calls.

        Returns:
            (memory, results, episode_return, prev_info): the ``Instance``
            list with ``advantage`` / ``q_value`` filled in; one ``1`` per
            episode finished during this call; the return accumulated so far
            in the current (possibly unfinished) episode; and the ``info``
            snapshot of the last finished episode, or ``None`` if none did.
        """
        self.prev_info = None
        memory = []
        results = []
        num_steps = 0
        while num_steps < self.traj_len:
            if self.done:
                self.sampler.logger.debug('reset env')
                self.current_state, info = self.reset()
                self.ep_start_time = time.time()
                continue

            state, style = self.get_input()  # (s)
            action, probs, log_prob, state_value = self.agent.get_model_result(state, style)
            self._record_stat("state_max_prob", probs.max().numpy())
            self.episode_action_list.append(action[0])

            x_speed = self.current_state[0, :][2]
            self._record_stat("speed", x_speed)
            y_pos = self.current_state[0, :][1]
            self.episode_y_list.append(y_pos)

            # 10 caps the logged TTC when no same-lane vehicle is ahead
            min_ttc = min([10] + self._same_lane_ttcs(x_speed))
            self._record_stat("ttc", min_ttc)

            next_state, env_reward, done, truncated, info = self.step(action[0])
            # shaping uses the pre-step state: current_state is updated below
            speed_reward, ttc_reward, change_reward, left_reward, right_reward, mid_reward = \
                self.reward_shaping(action[0])
            reward = \
                env_reward + speed_reward + ttc_reward + change_reward + left_reward + right_reward + mid_reward

            # Accumulate on every step so the per-episode sums logged at the
            # end of the episode cover all steps, not only the terminal one.
            self.episode_speed_reward_list.append(speed_reward)
            self.episode_ttc_reward_list.append(ttc_reward)
            self.episode_change_reward_list.append(change_reward)
            self.episode_left_reward_list.append(left_reward)
            self.episode_right_reward_list.append(right_reward)
            self.episode_mid_reward_list.append(mid_reward)

            self.current_state = next_state
            self.episode_return += reward
            self.episode_length += 1
            self.done = done or truncated
            if self.done:
                self._record_episode_stats(style)
                results.append(1)

            num_steps += 1
            instance = Instance(state=state.numpy(),  # (s,)
                                style=style.numpy(),
                                action=np.array([action.item()], dtype=np.float32),  # scalar
                                is_done=1 - int(self.done),  # 1 = continuing, 0 = terminal (GAE mask)
                                reward=reward,  # scalar
                                old_state_value=state_value.item(),  # scalar
                                old_log_prob=log_prob.cpu().numpy())  # (action_head_num, )
            memory.append(instance)

        # bootstrap from the state reached after the last collected step
        state_, style = self.get_input()
        _, _, _, state_value = self.agent.get_model_result(state_, style)  # (1, 1)
        bootstrap_value = state_value.squeeze().numpy()  # (1,)

        rewards_list = [j.reward for j in memory]
        mask_list = [j.is_done for j in memory]
        value_list = [j.old_state_value for j in memory]
        advantages, returns = self.gae.estimate_advantages(
            rewards_list, mask_list, value_list, self.gamma, self.tau, bootstrap_value)

        for index, item in enumerate(memory):
            item.advantage = advantages[index]
            item.q_value = returns[index]

        return memory, results, self.episode_return, self.prev_info
