import math
import time
from typing import List, Tuple

import numpy as np
import ray
import torch
from torch.cuda.amp import autocast as autocast
from gymnasium.utils import seeding

from core.config import BaseConfig
from core.mcts import SampledMCTS
from core.game import GameHistory
from core.utils import prepare_observation_lst, concat_with_zero_padding, LinearSchedule

import os

class ReanalyzeWorker(object):

    def __init__(self, rank: int, config: BaseConfig):
        """ReanalyzeWorker for reanalyzing targets
        receive the context from replay buffer and prepare training batches

        Parameters
        ----------
        rank: int
            id of the worker
        """
        self.rank = rank
        self.config = config
        self.np_random, _ = seeding.np_random(config.seed * 2000 + self.rank)
        zero_obs_shape = (config.stacked_observations, config.num_agents, *config.obs_shape)
        if self.config.image_based:
            self.zero_obs = np.zeros(zero_obs_shape, dtype=np.uint8)
        else:
            self.zero_obs = np.zeros(zero_obs_shape, dtype=np.float32)

        self.beta_schedule = LinearSchedule(config.training_steps + config.last_steps,
                                            initial_p=config.priority_prob_beta, final_p=1.0)

        self.device = 'cuda' if (config.reanalyze_on_gpu and torch.cuda.is_available()) else 'cpu'

        self.model = config.get_uniform_network()
        self.model.to(self.device)
        self.model.eval()
        self.last_model_index = -1

    def update_model(self, model_index, weights):
        self.model.set_weights(weights)
        self.last_model_index = model_index

    def _prepare_reward_value_re(
        self,
        indices: List[int],
        games: List[GameHistory],
        game_pos_lst: List[int],
        transitions_collected: int,
        use_adaptive: bool = True,
    ):
        """prepare reward and value targets with reanalyzing

        Parameters
        ----------
        indices: list
            transition index in replay buffer
        games: list
            list of game histories
        game_pos_lst: list
            transition index in game
        transitions_collected: int
            number of collected transitions
        use_adaptive: bool
            HyperNonzero: whether to use AdaptiveNode for action selection in MCTS
        """

        value_obs_lst, rewards_lst, traj_lens, legal_actions_lst = [], [], [], []
        value_mask = []
        td_steps_lst = []

        for idx, game, state_index in zip(indices, games, game_pos_lst):
            traj_len = len(game)
            traj_lens.append(traj_len)
            rewards_lst.append(game.rewards)

            delta_td = (transitions_collected - idx) // self.config.auto_td_steps
            td_steps = self.config.td_steps - delta_td
            td_steps = np.clip(td_steps, 1, self.config.td_steps).astype(np.intc)

            game_obs = game.obs(state_index + td_steps, self.config.num_unroll_steps)

            for current_index in range(state_index, state_index + self.config.num_unroll_steps + 1):
                td_steps_lst.append(td_steps)
                bootstrap_index = current_index + td_steps
                if bootstrap_index < traj_len:
                    value_mask.append(1)
                    legal_actions_lst.append(game.legal_actions[bootstrap_index])
                    beg_index = bootstrap_index - (state_index + td_steps)
                    end_index = beg_index + self.config.stacked_observations
                    obs = game_obs[beg_index:end_index]
                else:
                    value_mask.append(0)
                    legal_actions_lst.append(game.legal_actions[0])
                    obs = self.zero_obs
                value_obs_lst.append(obs)

        batch_values, batch_rewards = [], []

        value_obs_lst = prepare_observation_lst(value_obs_lst, self.config.image_based)
        if self.config.image_based:
            value_obs_tensor = torch.from_numpy(value_obs_lst).to(self.device).float() / 255.0
        else:
            value_obs_tensor = torch.from_numpy(value_obs_lst).to(self.device).float()
        with autocast():
            network_output = self.model.initial_inference(value_obs_tensor)

        if self.config.use_root_value:

            legal_actions_lst = np.asarray(legal_actions_lst)

            search_results = SampledMCTS(self.config, self.np_random).batch_search(
                self.model, network_output, legal_actions_lst, self.device, True, 1.0,
                use_adaptive=use_adaptive)
            value_lst = search_results.value.flatten()

        elif self.config.use_pred_value:
            value_lst = network_output.value.flatten()
        else:
            raise NotImplementedError

        value_lst = value_lst * np.power(self.config.discount, td_steps_lst)
        value_lst = value_lst * value_mask

        value_index = 0
        for traj_len, reward_lst, state_index in zip(traj_lens, rewards_lst, game_pos_lst):
            target_values = []
            target_rewards = []

            for current_index in range(state_index, state_index + self.config.num_unroll_steps + 1):
                bootstrap_index = current_index + td_steps_lst[value_index]
                for i, reward in enumerate(reward_lst[current_index:bootstrap_index]):
                    value_lst[value_index] += reward * self.config.discount ** i

                if current_index < traj_len:
                    target_values.append(value_lst[value_index])
                    target_rewards.append(reward_lst[current_index])
                else:
                    target_values.append(0.0)
                    target_rewards.append(0.0)
                value_index += 1

            batch_rewards.append(target_rewards)
            batch_values.append(target_values)

        batch_rewards = np.asarray(batch_rewards).reshape(self.config.batch_size, self.config.num_unroll_steps + 1)
        batch_values = np.asarray(batch_values).reshape(self.config.batch_size, self.config.num_unroll_steps + 1)
        return batch_rewards, batch_values

    def _prepare_reward_value_non_re(
        self,
        games: List[GameHistory],
        game_pos_lst: List[int],
    ):
        """prepare reward and value without reanalyzing, just return the value in self-play

        Parameters
        ----------
        games: list
            list of game histories
        game_pos_lst: list
            transition index in game
        """
        batch_values, batch_rewards = [], []

        for game, state_index in zip(games, game_pos_lst):
            target_values, target_rewards = [], []
            traj_len = len(game)
            reward_lst = game.rewards
            for current_index in range(state_index, state_index + self.config.num_unroll_steps + 1):
                if current_index < traj_len:
                    target_rewards.append(reward_lst[current_index])
                    bootstrap_index = current_index + self.config.td_steps
                    if bootstrap_index < traj_len:
                        if self.config.use_root_value:
                            bootstrap_value = game.root_values[bootstrap_index]
                        elif self.config.use_pred_value:
                            bootstrap_value = game.pred_values[bootstrap_index]
                        else:
                            raise NotImplementedError
                    else:
                        bootstrap_value = 0
                    for i, reward in enumerate(reward_lst[current_index:bootstrap_index]):
                        bootstrap_value += reward * self.config.discount ** i
                    target_values.append(bootstrap_value)
                else:
                    target_rewards.append(0.)
                    target_values.append(0.)
            batch_rewards.append(target_rewards)
            batch_values.append(target_values)

        batch_rewards = np.asarray(batch_rewards).reshape(self.config.batch_size, self.config.num_unroll_steps + 1)
        batch_values = np.asarray(batch_values).reshape(self.config.batch_size, self.config.num_unroll_steps + 1)
        return batch_rewards, batch_values

    def _prepare_policy_re(
        self,
        games: List[GameHistory],
        game_pos_lst: List[int],
        use_adaptive: bool = True,
    ):
        """prepare policy targets with reanalyzing

        Parameters
        ----------
        games: list
            list of game histories
        game_pos_lst: list
            transition index in game
        """

        policy_obs_lst, legal_actions_lst = [], []
        policy_mask = []

        B, K, N, A, C = (
            len(games),
            self.config.num_unroll_steps,
            self.config.num_agents,
            self.config.action_space_size,
            self.config.sampled_action_times,
        )

        for game, state_index in zip(games, game_pos_lst):
            traj_len = len(game)

            game_obs = game.obs(state_index, K)

            for current_index in range(state_index, state_index + K + 1):
                if current_index < traj_len:
                    policy_mask.append(True)
                    legal_actions_lst.append(game.legal_actions[current_index])
                    beg_index = current_index - state_index
                    end_index = beg_index + self.config.stacked_observations
                    obs = game_obs[beg_index:end_index]
                else:
                    policy_mask.append(False)
                    legal_actions_lst.append(game.legal_actions[0])
                    obs = self.zero_obs
                policy_obs_lst.append(obs)

        policy_obs_lst = prepare_observation_lst(policy_obs_lst, self.config.image_based)
        if self.config.image_based:
            policy_obs_tensor = torch.from_numpy(policy_obs_lst).to(self.device).float() / 255.0
        else:
            policy_obs_tensor = torch.from_numpy(policy_obs_lst).to(self.device).float()
        with autocast():
            network_output = self.model.initial_inference(policy_obs_tensor)

        legal_actions_lst = np.asarray(legal_actions_lst).reshape(B * (K + 1), N, A)

        search_results = SampledMCTS(self.config, self.np_random).batch_search(
            self.model, network_output, legal_actions_lst, self.device, add_noise=True, sampled_tau=1.0,
            use_adaptive=use_adaptive)

        batch_sampled_actions_re, batch_sampled_masks_re = concat_with_zero_padding(search_results.sampled_actions, C)
        batch_sampled_masks_re[~np.asarray(policy_mask)] = False

        batch_sampled_visit_counts_re, _ = concat_with_zero_padding(search_results.sampled_visit_count, C)
        batch_sampled_policies_re = batch_sampled_visit_counts_re / self.config.num_simulations
        batch_sampled_policies_re[~np.asarray(policy_mask)] = 0.
        assert batch_sampled_policies_re[~batch_sampled_masks_re].sum() == 0

        batch_sampled_imp_ratio, _ = concat_with_zero_padding(search_results.sampled_imp_ratio, C)
        batch_sampled_imp_ratio[~batch_sampled_masks_re] = 0.

        batch_sampled_qvalues_re, _ = concat_with_zero_padding(search_results.sampled_qvalues, C)
        batch_root_mcts_values_re = np.expand_dims(search_results.value, axis=-1)
        batch_root_pred_values_re = network_output.value

        batch_adaptive_theta_re = search_results.roots_adaptive_theta

        return (
            batch_sampled_actions_re,
            batch_sampled_policies_re,
            batch_sampled_imp_ratio,
            batch_sampled_masks_re,
            batch_sampled_qvalues_re,
            batch_root_mcts_values_re,
            batch_root_pred_values_re,
            batch_adaptive_theta_re,
        )

    def _prepare_policy_non_re(
        self,
        games: List[GameHistory],
        game_pos_lst: List[int]
    ):
        raise NotImplementedError

    def make_batch(
        self,
        buffer_context: Tuple[List[GameHistory], List[int], List[int], List[float], List[float]],
        transitions_collected: int
    ):
        """prepare the context of a batch

        Parameters
        ----------
        buffer_context : Any
            batch context from replay buffer
        transitions_collected: int
            number of collected transitions

        """

        game_lst, game_pos_lst, indices_lst, weights_lst = buffer_context

        obs_lst, action_lst, mask_lst = [], [], []
        future_return_lst, model_index_lst = [], []
        for game, state_index in zip(game_lst, game_pos_lst):
            _obs = game.obs(state_index, self.config.num_unroll_steps, padding=True)
            _actions = game.actions[state_index:state_index + self.config.num_unroll_steps + 1].tolist()
            _mask = [1] * len(_actions)

            _actions += [[0] * self.config.num_agents for _ in range(self.config.num_unroll_steps + 1 - len(_actions))]
            _mask += [0] * (self.config.num_unroll_steps + 1 - len(_mask))

            obs_lst.append(_obs)
            action_lst.append(_actions)
            mask_lst.append(_mask)
            future_return_lst.append(np.sum(game.rewards[state_index:]))
            model_index_lst.append(game.model_indices[state_index])
        obs_lst = prepare_observation_lst(obs_lst, self.config.image_based)

        inputs_batch = [obs_lst, action_lst, mask_lst, indices_lst, weights_lst]
        for i in range(len(inputs_batch)):
            inputs_batch[i] = np.asarray(inputs_batch[i])

        use_adaptive = self.last_model_index >= self.config.distillation_warmup_steps
        if self.config.use_reanalyze_value:
            batch_rewards, batch_values = self._prepare_reward_value_re(
                indices_lst, game_lst, game_pos_lst, transitions_collected, use_adaptive=use_adaptive)
        else:
            batch_rewards, batch_values = self._prepare_reward_value_non_re(game_lst, game_pos_lst)

        B, K, N, C = (
            len(indices_lst),
            self.config.num_unroll_steps,
            self.config.num_agents,
            self.config.sampled_action_times,
        )

        re_num = math.ceil(B * self.config.revisit_policy_search_rate)

        batch_sampled_actions = np.empty((0, C, N), dtype=np.intc)
        batch_sampled_policies = np.empty((0, C))
        batch_sampled_imp_ratio = np.empty((0, C))
        batch_sampled_masks = np.empty((0, C), dtype=np.bool_)
        batch_sampled_qvalues = np.empty((0, C))
        batch_root_mcts_values = np.empty((0, 1))
        batch_root_pred_values = np.empty((0, 1))

        batch_adaptive_theta = np.empty((0, N, self.config.action_space_size), dtype=np.float32)

        if re_num > 0:

            (
                batch_sampled_actions_re,
                batch_sampled_policies_re,
                batch_sampled_imp_ratio_re,
                batch_sampled_masks_re,
                batch_sampled_qvalues_re,
                batch_root_mcts_values_re,
                batch_root_pred_values_re,
                batch_adaptive_theta_re,
            ) = self._prepare_policy_re(game_lst[:re_num], game_pos_lst[:re_num], use_adaptive=use_adaptive)
            batch_sampled_actions = np.concatenate([batch_sampled_actions, batch_sampled_actions_re])
            batch_sampled_policies = np.concatenate([batch_sampled_policies, batch_sampled_policies_re])
            batch_sampled_imp_ratio = np.concatenate([batch_sampled_imp_ratio, batch_sampled_imp_ratio_re])
            batch_sampled_masks = np.concatenate([batch_sampled_masks, batch_sampled_masks_re])
            batch_sampled_qvalues = np.concatenate([batch_sampled_qvalues, batch_sampled_qvalues_re])
            batch_root_mcts_values = np.concatenate([batch_root_mcts_values, batch_root_mcts_values_re])
            batch_root_pred_values = np.concatenate([batch_root_pred_values, batch_root_pred_values_re])
            batch_adaptive_theta = np.concatenate([batch_adaptive_theta, batch_adaptive_theta_re])

        if re_num < B:
            raise NotImplementedError

        batch_sampled_adv = batch_sampled_qvalues - batch_root_pred_values
        adv_copy = np.array(batch_sampled_adv)
        adv_copy[~batch_sampled_masks] = np.nan
        adv_mean = np.nanmean(adv_copy)
        adv_std = np.nanstd(adv_copy)
        batch_sampled_adv = (batch_sampled_adv - adv_mean) / (adv_std + 1e-5)
        batch_sampled_adv[~batch_sampled_masks] = 0.

        batch_sampled_actions = batch_sampled_actions.reshape(B, K + 1, C, N)
        batch_sampled_policies = batch_sampled_policies.reshape(B, K + 1, C)
        batch_sampled_imp_ratio = batch_sampled_imp_ratio.reshape(B, K + 1, C)
        batch_sampled_adv = batch_sampled_adv.reshape(B, K + 1, C)
        batch_sampled_masks = batch_sampled_masks.reshape(B, K + 1, C)

        A = self.config.action_space_size
        batch_adaptive_theta = batch_adaptive_theta.reshape(B, K + 1, N, A)

        batch_policies = (batch_sampled_actions, batch_sampled_policies,
                          batch_sampled_imp_ratio, batch_sampled_adv, batch_sampled_masks)
        targets_batch = (batch_rewards, batch_values, batch_policies)

        info = {
            'batch_future_return': np.mean(future_return_lst),
            'batch_model_index': np.mean(model_index_lst),
            'target_model_index': self.last_model_index,
            'theta_updated': batch_adaptive_theta,
        }

        batch_data = (inputs_batch, targets_batch, info)
        return batch_data

    def make_renalyze_update(self, game_id: int, game: GameHistory):

        policy_obs_lst = []
        game_obs = game.obs(0, len(game))
        for state_index in range(len(game)):
            policy_obs_lst.append(game_obs[state_index:state_index + self.config.stacked_observations])

        policy_obs_lst = prepare_observation_lst(policy_obs_lst, self.config.image_based)
        if self.config.image_based:
            policy_obs_tensor = torch.from_numpy(policy_obs_lst).to(self.device).float() / 255.0
        else:
            policy_obs_tensor = torch.from_numpy(policy_obs_lst).to(self.device).float()
        with autocast():
            network_output = self.model.initial_inference(policy_obs_tensor)

        legal_actions_lst = game.legal_actions

        use_adaptive = self.last_model_index >= self.config.distillation_warmup_steps
        search_results = SampledMCTS(self.config, self.np_random).batch_search(
            self.model, network_output, legal_actions_lst, self.device, add_noise=True, sampled_tau=1.0,
            use_adaptive=use_adaptive)

        C = self.config.sampled_action_times

        batch_sampled_actions, batch_sampled_masks = concat_with_zero_padding(search_results.sampled_actions, C)
        batch_sampled_visit_counts, _ = concat_with_zero_padding(search_results.sampled_visit_count, C)
        batch_sampled_policies = batch_sampled_visit_counts / self.config.num_simulations
        batch_sampled_qvalues, _ = concat_with_zero_padding(search_results.sampled_qvalues, C)
        batch_root_values = np.asarray(search_results.value)

        game.model_indices = np.array([self.last_model_index for _ in range(len(game))])
        game.sampled_actions = batch_sampled_actions
        game.sampled_policies = batch_sampled_policies
        game.sampled_padding_masks = batch_sampled_masks
        game.sampled_qvalues = batch_sampled_qvalues
        game.root_values = batch_root_values

        return (game_id, game)

@ray.remote
class RemoteReanalyzeWorker(ReanalyzeWorker):

    def __init__(self, rank, config, shared_storage, replay_buffer, batch_storage):
        """Remote variant of the reanalyze worker, wired to the shared services.

        Parameters
        ----------
        rank: int
            id of the worker
        config: Any
            experiment configuration, forwarded to the base class
        shared_storage: Any
            The model storage
        replay_buffer: Any
            Replay buffer
        batch_storage: Any
            The batch storage (batch queue)
        """
        super().__init__(rank, config)
        self.shared_storage, self.replay_buffer, self.batch_storage = (
            shared_storage, replay_buffer, batch_storage)

        # On CUDA machines, turn off the flash and memory-efficient scaled-dot-product
        # attention kernels and keep only the math implementation enabled.
        if torch.cuda.is_available():
            cuda_backend = torch.backends.cuda
            cuda_backend.enable_flash_sdp(False)
            cuda_backend.enable_mem_efficient_sdp(False)
            cuda_backend.enable_math_sdp(True)

    def get_beta(self, trained_steps):
        """Return the value of the beta schedule at ``trained_steps``."""
        schedule = self.beta_schedule
        return schedule.value(trained_steps)

    def run_loop(self):
        """Produce training batches until training has finished.

        Waits for the start signal, then repeatedly refreshes the target model
        when a new target interval has been reached and, while the batch
        storage is below its threshold, builds a batch from a prefetched
        replay-buffer context and pushes it to the batch storage.
        """
        cfg = self.config

        # Phase 1: poll the start signal every 0.1 s.
        while not ray.get(self.shared_storage.get_start_signal.remote()):
            time.sleep(0.1)

        # Request the first context right away so it is prepared in the background.
        step_count = ray.get(self.shared_storage.get_counter.remote())
        first_beta = self.beta_schedule.value(step_count)
        pending_context = self.replay_buffer.prepare_batch_context.remote(cfg.batch_size, first_beta)
        time.sleep(0.1)

        # Phase 2: steady-state batch production.
        while True:
            step_count = ray.get(self.shared_storage.get_counter.remote())
            if step_count >= cfg.training_steps + cfg.last_steps:
                # Training is over: linger for 30 s, then stop.
                time.sleep(30)
                return

            # Refresh the local model once training has entered a later target interval.
            local_interval = self.last_model_index // cfg.target_model_interval
            trained_interval = step_count // cfg.target_model_interval
            if local_interval < trained_interval:
                new_index, new_weights = ray.get(self.shared_storage.get_target_weights.remote())
                self.model.load_state_dict(new_weights)
                self.model.to(self.device)
                self.model.eval()
                self.last_model_index = new_index

            if not (self.batch_storage.get_len() < self.batch_storage.threshold):
                # Enough batches are queued; back off for a second.
                time.sleep(1)
                continue

            context = ray.get(pending_context)
            collected = ray.get(self.replay_buffer.transitions_collected.remote())
            # Ask for the next context before the expensive batch construction starts.
            next_beta = self.beta_schedule.value(step_count)
            pending_context = self.replay_buffer.prepare_batch_context.remote(cfg.batch_size, next_beta)

            self.batch_storage.push(self.make_batch(context, collected))

    def update_loop(self):
        """Refresh stored games with new search results until training has finished.

        Waits for the start signal, then repeatedly refreshes the target model
        when a new target interval has been reached, reanalyzes the game
        prefetched from the replay buffer and sends the updated game back.
        """
        cfg = self.config

        # Phase 1: poll the start signal every 0.1 s, then prefetch the first game.
        while not ray.get(self.shared_storage.get_start_signal.remote()):
            time.sleep(0.1)
        pending_game = self.replay_buffer.prepare_game.remote()
        time.sleep(0.1)

        # Phase 2: steady-state reanalysis.
        while True:
            step_count = ray.get(self.shared_storage.get_counter.remote())
            if step_count >= cfg.training_steps + cfg.last_steps:
                # Training is over: linger for 30 s, then stop.
                time.sleep(30)
                return

            # Refresh the local model once training has entered a later target interval.
            local_interval = self.last_model_index // cfg.target_model_interval
            trained_interval = step_count // cfg.target_model_interval
            if local_interval < trained_interval:
                new_index, new_weights = ray.get(self.shared_storage.get_target_weights.remote())
                self.model.load_state_dict(new_weights)
                self.model.to(self.device)
                self.model.eval()
                self.last_model_index = new_index

            game_id, game = ray.get(pending_game)
            # Ask for the next game before the expensive search starts.
            pending_game = self.replay_buffer.prepare_game.remote()
            refreshed = self.make_renalyze_update(game_id, game)
            # Fire-and-forget: the result of the remote update is not awaited.
            self.replay_buffer.update_game_history.remote(refreshed)
