import time
import os
import SMOS
import setproctitle
import torch

import numpy as np
import core.ctree.cytree as cytree

from torch.nn import L1Loss
from torch.cuda.amp import autocast as autocast
from core.mcts import MCTS
from core.game import GameHistory
from core.model import load_model_weights
from core.utils import select_action, prepare_observation_lst, get_gpu_memory

from core.storage_config import StorageConfig
from core.replay_buffer import get_replay_buffer
from core.shared_storage import get_shared_storage, read_weights
from core.meta_data_manager import MetaDataSharedMemoryManager


class DataWorker(object):
    def __init__(self, rank, replay_buffer, storage, smos_client, config, storage_config: StorageConfig):
        """Data Worker for collecting data through self-play
        Parameters
        ----------
        rank: int
            id of the worker
        replay_buffer: Any
            Replay buffer
        storage: Any
            The model storage
        """
        self.rank = rank
        self.config = config
        self.storage_config = storage_config
        self.storage = storage
        self.replay_buffer = replay_buffer
        self.smos_client = smos_client
        self.meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)
        # double buffering when data is sufficient
        self.trajectory_pool = []
        self.pool_size = 1
        self.device = self.config.device
        self.gap_step = self.config.num_unroll_steps + self.config.td_steps
        self.last_model_index = -1

    def put(self, data):
        # put a game history into the pool
        self.trajectory_pool.append(data)

    def len_pool(self):
        # current pool size
        return len(self.trajectory_pool)

    def free(self, done = False):
        # save the game histories and clear the pool
        game_ids = []
        if self.len_pool() >= self.pool_size:
            game_ids = self.replay_buffer.save_pools(self.trajectory_pool, done=done)
            del self.trajectory_pool[:]
        return game_ids

    def put_last_trajectory(self, i, last_game_histories, last_game_priorities, game_histories):
        """put the last game history into the pool if the current game is finished
        Parameters
        ----------
        last_game_histories: list
            list of the last game histories
        last_game_priorities: list
            list of the last game priorities
        game_histories: list
            list of the current game histories
        """
        # pad over last block trajectory
        beg_index = self.config.stacked_observations
        end_index = beg_index + self.gap_step + 1

        pad_obs_lst = game_histories[i].obs_history[beg_index:end_index]

        beg_index = 0
        end_index = beg_index + self.config.num_unroll_steps + 1

        pad_child_visits_lst = game_histories[i].child_visits[beg_index:end_index]

        beg_index = 0
        end_index = beg_index + self.gap_step - 1 + 1

        pad_reward_lst = game_histories[i].rewards[beg_index:end_index]

        beg_index = 0
        end_index = beg_index + self.gap_step + 1 # use num_unroll_steps + td_steps for actions because actions would be used for length in reanalyze

        pad_action_lst = game_histories[i].actions[beg_index:end_index]

        beg_index = 0
        end_index = beg_index + self.gap_step + 1

        pad_root_values_lst = game_histories[i].root_values[beg_index:end_index]

        # pad over and save
        last_game_histories[i].pad_over(pad_obs_lst, pad_action_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst)
        last_game_histories[i].game_over(smos_client=self.smos_client, storage_config=self.storage_config)
        last_game_priorities[i] = last_game_histories[i].compute_priorities(self.config.discount, self.config.td_steps)

        self.put((last_game_histories[i], last_game_priorities[i]))
        game_ids = self.free()
        game_histories[i].last_game_id = game_ids[0]

        # reset last block
        last_game_histories[i] = None
        last_game_priorities[i] = None

    def get_priorities(self, i, pred_values_lst, search_values_lst):
        # obtain the priorities at index i
        if self.config.use_priority and not self.config.use_max_priority:
            pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.device).float()
            search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.device).float()
            priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + self.config.prioritized_replay_eps
        else:
            # priorities is None -> use the max priority for all newly collected data
            priorities = None

        return priorities

    def run(self):
        """Run the self-play loop of this worker until training is reported as finished.

        A data-worker network is built from the config and ``config.p_mcts_num``
        environments are stepped in lockstep. Until the storage start signal is
        received, an all-ones visit distribution is handed to ``select_action``
        (random policy) and no search is run; afterwards every step performs a
        batched MCTS. Game-history blocks that are full or whose episode ended
        are handed to the replay buffer through ``put`` / ``free``.

        Collection is throttled: before the start signal it pauses once the
        replay buffer holds ``config.start_transitions``; after it, it pauses
        while ``trained_steps <= config.freeze_steps`` and whenever the fraction
        of collected transitions runs ahead of the fraction of completed
        training steps.

        The method returns (after sleeping 30 s) once ``storage.get_counter()``
        reaches ``config.training_steps + config.last_steps``.
        """
        # number of parallel mcts
        env_nums = self.config.p_mcts_num
        model = self.config.get_uniform_network(is_data_worker=True)
        model.to(self.device)
        model.set_device(self.device)
        model.eval()

        start_training = False
        # every env gets its own seed derived from the global seed, the worker rank and the env index
        envs = [self.config.new_game(self.config.seed * 10000 + self.rank * 100 + i) for i in range(env_nums)]

        np.random.seed(self.config.seed * 10000 + self.rank + 1)

        # entropy (in bits) of a uniform distribution over `action_space` actions;
        # used below to normalize the logged visit entropy
        def _get_max_entropy(action_space):
            p = 1.0 / action_space
            ep = - action_space * p * np.log2(p)
            return ep
        max_visit_entropy = _get_max_entropy(self.config.action_space_size)
        # 100k benchmark
        # total_transitions: env steps taken by this worker so far
        total_transitions = 0
        # warmup_transitions: snapshot of total_transitions, refreshed until the start signal arrives
        warmup_transitions = 0
        # weights_version: version returned by the last read_weights call, stored with the search stats
        weights_version = 0
        # max transition to collect for this data worker
        max_transitions = self.config.total_transitions // self.config.num_actors
        with torch.no_grad():
            # NOTE(review): the inner `while True` below has no `break` and only leaves via `return`,
            # so this outer loop body runs at most once.
            while True:
                trained_steps = self.storage.get_counter()
                # training finished
                if trained_steps >= self.config.training_steps + self.config.last_steps:
                    time.sleep(30)
                    break

                init_obses = [env.reset() for env in envs]
                dones = np.array([False for _ in range(env_nums)])
                game_histories = [GameHistory(max_length=self.config.history_length,
                                              config=self.config) for _ in range(env_nums)]
                # previous (full) block of each env, kept until it can be padded with the next block
                last_game_histories = [None for _ in range(env_nums)]
                last_game_priorities = [None for _ in range(env_nums)]

                # stack observation windows in boundary: s398, s399, s400, current s1 -> for not init trajectory
                stack_obs_windows = [[] for _ in range(env_nums)]

                # the initial window repeats the first observation `stacked_observations` times
                for i in range(env_nums):
                    stack_obs_windows[i] = [init_obses[i] for _ in range(self.config.stacked_observations)]
                    game_histories[i].init(stack_obs_windows[i])

                # for priorities in self-play
                search_values_lst = [[] for _ in range(env_nums)]
                pred_values_lst = [[] for _ in range(env_nums)]

                # some logs
                # per-env accumulators for the episode currently being played
                eps_ori_reward_lst, eps_reward_lst, eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums), np.zeros(env_nums), np.zeros(env_nums)
                # NOTE(review): step_counter is only incremented, never read
                step_counter = 0

                # aggregated statistics over finished episodes
                self_play_rewards = 0.
                self_play_ori_rewards = 0.
                self_play_moves = 0.
                self_play_episodes = 0.
                self_play_last_rewards = 0.
                self_play_last_moves = 0.
                self_play_last_ori_rewards = 0.

                self_play_rewards_max = - np.inf
                self_play_moves_max = 0

                self_play_visit_entropy = []
                other_dist = {}

                # main stepping loop: one iteration advances every env by one step;
                # it is left only through the `return` below
                while True:
                    if not start_training:
                        # poll for the start signal; keep the warmup snapshot up to date until it arrives
                        start_training = self.storage.get_start_signal()
                        warmup_transitions = total_transitions

                    # get model
                    trained_steps = self.storage.get_counter()
                    self.replay_buffer.update_version(trained_steps)
                    if trained_steps >= self.config.training_steps + self.config.last_steps:
                        # training is finished
                        time.sleep(30)
                        return
                    # if start_training and (total_transitions / max_transitions) > (trained_steps / self.config.training_steps):
                    #     # self-play is faster than training speed or finished
                    #     time.sleep(1)
                    #     continue

                    if start_training:
                        if trained_steps <= self.config.freeze_steps:
                            # freeze for network warmup
                            time.sleep(1)
                            continue
                        else:
                            # compare the fraction of this worker's transition budget already collected
                            # (since the start signal) with the fraction of post-freeze training steps done
                            rollout_ratio = ((total_transitions  - warmup_transitions)/ max_transitions)
                            train_ratio = ((trained_steps - self.config.freeze_steps) / (self.config.training_steps - self.config.freeze_steps))
                            expected_rollout_ratio = train_ratio # (lambda x: 0.3 * x * 2 if x <= 0.5 else 0.3 + 0.7 * (x - 0.5) * 2)(train_ratio)
                            # NOTE(review): this print runs on every loop pass, including the throttled ones
                            print(f"[Selfplay worker] ratio: {rollout_ratio} {train_ratio} {expected_rollout_ratio}")
                            if rollout_ratio > expected_rollout_ratio:
                                # self-play is faster than training speed or finished
                                # print(f'[Self Play Worker] Sleep for 1 sec. Current {total_transitions} transitions at step {trained_steps}.')
                                time.sleep(0.1)
                                continue
                    else:
                        # before the start signal: stop collecting once the buffer holds start_transitions
                        if self.replay_buffer.get_total_len() >= self.config.start_transitions:
                            time.sleep(0.1)
                            continue

                    # set temperature for distributions
                    # (num_moves is always passed as 0, so every env gets the same temperature)
                    _temperature = np.array(
                        [self.config.visit_softmax_temperature_fn(num_moves=0, trained_steps=trained_steps) for env in
                         envs])

                    # update the models in self-play every checkpoint_interval
                    # update the models only after training starts
                    # NOTE(review): the "- 5" delays the index bump until 5 steps past each multiple of
                    # checkpoint_interval — presumably so the new checkpoint is already published; confirm
                    new_model_index = (trained_steps - 5) // self.config.checkpoint_interval
                    if start_training and new_model_index > self.last_model_index:
                        self.last_model_index = new_model_index
                        # update model
                        weights, weights_version = read_weights(self.meta_data_manager, self.smos_client, self.storage, self.storage_config) # self.storage.get_weights()
                        # without reset_model the fetched weights are always loaded; with reset_model they are
                        # loaded only while weights_version <= reset_model_interval or when the version is
                        # within 20 of the next multiple of reset_model_interval
                        set_weights = (not self.config.reset_model)
                        if self.config.reset_model:
                            if (weights_version <= self.config.reset_model_interval) or (self.config.reset_model_interval - weights_version % self.config.reset_model_interval <= 20):
                                set_weights = True
                        if set_weights:
                            model.set_weights(weights)
                            model.to(self.device)
                            model.eval()
                        del weights
                        # note: this message is printed even when the weights were not loaded
                        print(f"[Data worker {self.rank}] Update to model of "
                              f"{new_model_index * self.config.checkpoint_interval}.")

                        # log if more than 1 env in parallel because env will reset in this loop.
                        if env_nums > 1:
                            # mean visit entropy of finished episodes, normalized by the uniform-policy entropy
                            if len(self_play_visit_entropy) > 0:
                                visit_entropies = np.array(self_play_visit_entropy).mean()
                                visit_entropies /= max_visit_entropy
                            else:
                                visit_entropies = 0.

                            # per-episode averages; zero when no episode has finished yet
                            if self_play_episodes > 0:
                                log_self_play_moves = self_play_moves / self_play_episodes
                                log_self_play_rewards = self_play_rewards / self_play_episodes
                                log_self_play_ori_rewards = self_play_ori_rewards / self_play_episodes
                            else:
                                log_self_play_moves = 0
                                log_self_play_rewards = 0
                                log_self_play_ori_rewards = 0

                            self.storage.set_data_worker_logs(self_play_last_moves, log_self_play_moves, self_play_moves_max,
                                                              self_play_last_ori_rewards, self_play_last_rewards, log_self_play_ori_rewards, log_self_play_rewards,
                                                              self_play_rewards_max, _temperature.mean(),
                                                              visit_entropies, 0,
                                                              other_dist)
                            # the max reward is tracked per logging period
                            self_play_rewards_max = - np.inf

                    step_counter += 1
                    for i in range(env_nums):
                        # reset env if finished
                        if dones[i]:

                            # pad over last block trajectory
                            if last_game_histories[i] is not None:
                                self.put_last_trajectory(i, last_game_histories, last_game_priorities, game_histories)

                            # store current block trajectory
                            game_histories[i].game_over(smos_client=self.smos_client,
                                                        storage_config=self.storage_config)
                            priorities = game_histories[i].compute_priorities(self.config.discount, self.config.td_steps)

                            self.put((game_histories[i], priorities))
                            # NOTE(review): done=False is passed although this branch handles a finished
                            # episode — confirm against replay_buffer.save_pools
                            self.free(done=False)

                            # reset the finished env and new a env
                            # (the same env object is closed and then reset)
                            envs[i].close()
                            init_obs = envs[i].reset()
                            game_histories[i] = GameHistory(max_length=self.config.history_length,
                                                            config=self.config)
                            last_game_histories[i] = None
                            last_game_priorities[i] = None
                            stack_obs_windows[i] = [init_obs for _ in range(self.config.stacked_observations)]
                            game_histories[i].init(stack_obs_windows[i])

                            # log
                            self_play_last_rewards = eps_reward_lst[i]
                            self_play_last_ori_rewards = eps_ori_reward_lst[i]
                            self_play_last_moves = eps_steps_lst[i]
                            self_play_rewards_max = max(self_play_rewards_max, eps_reward_lst[i])
                            self_play_moves_max = max(self_play_moves_max, eps_steps_lst[i])
                            self_play_rewards += eps_reward_lst[i]
                            self_play_ori_rewards += eps_ori_reward_lst[i]
                            # mean per-step visit entropy of the finished episode
                            self_play_visit_entropy.append(visit_entropies_lst[i] / eps_steps_lst[i])
                            self_play_moves += eps_steps_lst[i]
                            self_play_episodes += 1

                            # clear the per-episode accumulators of this env
                            pred_values_lst[i] = []
                            search_values_lst[i] = []
                            # end_tags[i] = False
                            eps_steps_lst[i] = 0
                            eps_reward_lst[i] = 0
                            eps_ori_reward_lst[i] = 0
                            visit_entropies_lst[i] = 0

                    # stack obs for model inference
                    stack_obs = [game_history.step_obs() for game_history in game_histories]
                    if self.config.image_based:
                        # image observations are scaled by 1/255
                        stack_obs = prepare_observation_lst(stack_obs)
                        stack_obs = torch.from_numpy(stack_obs).to(self.device).float() / 255.0
                    else:
                        # NOTE(review): this repeats the list comprehension already evaluated above
                        stack_obs = [game_history.step_obs() for game_history in game_histories]
                        stack_obs = torch.from_numpy(np.array(stack_obs)).to(self.device)

                    # initial inference for all envs at once, under autocast when torch AMP is configured
                    if self.config.amp_type == 'torch_amp':
                        with autocast():
                            network_output = model.initial_inference(stack_obs.float(), to_numpy=False)
                    else:
                        network_output = model.initial_inference(stack_obs.float(), to_numpy=False)
                    hidden_state_roots = network_output.hidden_state
                    reward_hidden_roots = network_output.reward_hidden
                    value_prefix_pool = network_output.value_prefix
                    policy_logits_pool = network_output.policy_logits
                    # max-subtracted softmax over the action axis
                    # assumes policy_logits is a numpy array even with to_numpy=False — TODO confirm
                    policy_logits_pool = policy_logits_pool - policy_logits_pool.max(axis=-1).reshape(-1, 1)
                    policy_logits_pool = np.exp(policy_logits_pool)
                    policy_logits_pool = policy_logits_pool / policy_logits_pool.sum(axis=-1).reshape(-1, 1)
                    policy_logits_pool = policy_logits_pool.tolist()

                    if start_training:
                        # one search root per env, prepared with Dirichlet exploration noise
                        roots = cytree.Roots(env_nums, self.config.action_space_size, self.config.num_simulations)
                        noises = [np.random.dirichlet([self.config.root_dirichlet_alpha] * self.config.action_space_size).astype(np.float32).tolist() for _ in range(env_nums)]
                        roots.prepare(self.config.root_exploration_fraction, noises, value_prefix_pool, policy_logits_pool)
                        # do MCTS for a policy
                        MCTS(self.config, self.config.num_simulations).gpu_search(roots, model, hidden_state_roots, reward_hidden_roots, to_numpy=False)

                        roots_distributions = roots.get_distributions()
                        roots_values = roots.get_values()
                    else:
                        # no search before training starts: root values are zeros and
                        # roots_distributions is left undefined (it is not read in that case)
                        roots_values = np.zeros((env_nums,), dtype=np.float32)
                        time.sleep(0.0005)
                    for i in range(env_nums):
                        deterministic = False
                        if start_training:
                            distributions, value, temperature, env = roots_distributions[i], roots_values[i], _temperature[i], envs[i]
                        else:
                            # before starting training, use random policy
                            value, temperature, env = roots_values[i], _temperature[i], envs[i]
                            distributions = np.ones(self.config.action_space_size)

                        action, visit_entropy = select_action(distributions, temperature=temperature, deterministic=deterministic)
                        obs, ori_reward, done, info = env.step(action)
                        # clip the reward
                        if self.config.clip_reward:
                            clip_reward = np.sign(ori_reward)
                            # NOTE(review): np.sign already yields -1 for negative rewards, so the Pong
                            # branch only changes the value's type
                            if self.config.env_name.startswith("Pong"):
                                if clip_reward < 0:
                                    clip_reward = -1
                        else:
                            clip_reward = ori_reward

                        # store data
                        game_histories[i].store_search_stats(distributions, value, version=weights_version)
                        game_histories[i].append(action, obs, clip_reward)

                        eps_reward_lst[i] += clip_reward
                        eps_ori_reward_lst[i] += ori_reward
                        dones[i] = done
                        visit_entropies_lst[i] += visit_entropy

                        eps_steps_lst[i] += 1
                        total_transitions += 1

                        # value pairs for priorities are only recorded once training has started
                        if self.config.use_priority and not self.config.use_max_priority and start_training:
                            pred_values_lst[i].append(network_output.value[i].item())
                            search_values_lst[i].append(roots_values[i])

                        # fresh stack windows
                        del stack_obs_windows[i][0]
                        stack_obs_windows[i].append(obs)

                        # if game history is full;
                        # we will save a game history if it is the end of the game or the next game history is finished.
                        if game_histories[i].is_full():
                            # pad over last block trajectory
                            if last_game_histories[i] is not None:
                                self.put_last_trajectory(i, last_game_histories, last_game_priorities, game_histories)

                            # calculate priority
                            # NOTE(review): pred_values_lst / search_values_lst are cleared only when an episode
                            # ends, not here, and put_last_trajectory overwrites last_game_priorities[i] with
                            # compute_priorities(...) before using it, so this value appears to be unused
                            priorities = self.get_priorities(i, pred_values_lst, search_values_lst)

                            # save block trajectory
                            last_game_histories[i] = game_histories[i]
                            last_game_priorities[i] = priorities

                            # new block trajectory
                            # (it starts from the current sliding window of stacked observations)
                            game_histories[i] = GameHistory(max_length=self.config.history_length,
                                                            config=self.config)
                            game_histories[i].init(stack_obs_windows[i])

def start_data_worker(rank, config, storage_config: StorageConfig):
    """
    Start a data worker. Call this method remotely.

    Picks a GPU for the worker, names the process, connects to the replay
    buffer, the shared storage and SMOS, then runs the worker's self-play
    loop via ``DataWorker.run``.

    Parameters
    ----------
    rank: int
        id of the worker; also used to stagger start-up and to pick the GPU
    config: Any
        experiment configuration; its ``device`` attribute is set here
    storage_config: StorageConfig
        storage configuration providing ``data_worker_visible_devices`` and
        ``smos_connection``
    """
    # set the gpu it resides on
    # workers start 0.1 s apart, ordered by rank
    time.sleep(0.1 * rank)

    # free-memory report, with GPUs that data workers may not use masked as -1 (in place)
    available_memory_list = get_gpu_memory()
    allowed_devices = storage_config.data_worker_visible_devices
    available_memory_list[:] = [
        free_mem if gpu_idx in allowed_devices else -1
        for gpu_idx, free_mem in enumerate(available_memory_list)
    ]

    # the memory report is only used for the warning and the log line below;
    # the GPU itself is chosen round-robin by rank
    max_index = available_memory_list.index(max(available_memory_list))
    if available_memory_list[max_index] < 2000:
        print(f"[Data worker]******************* Warning: Low video ram (max remaining "
              f"{available_memory_list[max_index]}) *******************")

    worker_devices = storage_config.data_worker_visible_devices
    used_gpu_idx = worker_devices[rank % len(worker_devices)]
    # NOTE(review): config.device keeps the global GPU index while CUDA_VISIBLE_DEVICES is narrowed
    # to that single GPU; whether both agree depends on when CUDA is initialized in this process — confirm
    config.device = torch.device(f"cuda:{used_gpu_idx}")
    os.environ['CUDA_VISIBLE_DEVICES'] = str(used_gpu_idx)
    print(f"[Data worker] Data worker {rank} at process {os.getpid()}"
          f" will use GPU {used_gpu_idx}. Remaining memory before allocation {available_memory_list}")

    # set process name
    setproctitle.setproctitle(f"EfficientZero-dataworker{rank}")

    # get storages
    replay_buffer = get_replay_buffer(storage_config=storage_config)
    shared_storage = get_shared_storage(storage_config=storage_config)
    smos_client = SMOS.Client(connection=storage_config.smos_connection)

    # start data worker
    worker = DataWorker(rank=rank, replay_buffer=replay_buffer, storage=shared_storage,
                        smos_client=smos_client, config=config, storage_config=storage_config)
    print(f"[Data worker] Start data worker {rank} at process {os.getpid()}.")
    worker.run()
