import os
import torch

from core.config import BaseConfig
from core.utils import make_atari, WarpFrame, EpisodicLifeEnv
from core.dataset import Transforms
from .env_wrapper import AtariWrapper
from .model import EfficientZeroNet

import numpy as np


class AtariConfig(BaseConfig):
    """EfficientZero hyper-parameters and environment factory for Atari games.

    Training, replay, loss and network settings are fixed in ``__init__``. The
    game-dependent parts (``env_name``, ``obs_shape``, ``action_space_size``)
    are filled in later by ``set_game``.
    """

    def __init__(self, env_name):
        """Build the configuration.

        :param env_name: accepted for interface compatibility but not used here;
            the environment name is stored by ``set_game``.
        """
        scale = 0.2
        # Number of data-collecting actors; also scales the replay window below.
        data_multi = 3
        weight_decay = 5e-5
        super(AtariConfig, self).__init__(
            training_steps=20000,  # int(100000 * scale),
            last_steps=0,  # int(20000 * scale),
            freeze_steps=int(2000 * scale),
            test_interval=int(10000 * scale),
            log_interval=100,
            vis_interval=100,
            test_episodes=32,
            checkpoint_interval=10,
            target_model_interval=40,
            refresh_interval=10,
            send_max_version_gap=10000,
            max_version_gap=50,
            save_ckpt_interval=1000,
            # episode caps in raw frames; converted to agent steps below
            max_moves=108000,
            test_max_moves=12000,
            history_length=200,  # 400,
            discount=0.997,
            dirichlet_alpha=0.3,
            value_delta_max=0.01,
            num_simulations=40,  # 25
            reanalyze_num_simulations=40,  # 25
            batch_size=128,
            td_steps=5,
            num_actors=data_multi,
            # network initialization & normalization
            episode_life=True,
            init_zero=True,
            clip_reward=True,
            # storage efficient
            cvt_string=False,
            image_based=True,
            # lr scheduler
            lr_warm_up=1000,
            lr_init=0.4,
            lr_decay_rate=0.1,
            lr_decay_steps=int(100000 * scale * 1.),
            auto_td_steps_ratio=0.3,
            weight_decay=weight_decay,  # 5e-5
            weight_decay_change_step=0.5,
            weight_decay_multi=2.,
            optimizer='sgd',
            # replay window
            # round() yields an exact int; the raw float product is
            # 30000.000000000004, not a transition count.
            start_transitions=round(data_multi * 0.1 * 100 * 1000),
            total_transitions=int(data_multi * 1. * 100 * 1000),
            keep_latest=100 * 100 * 1000,
            transition_num=1,
            # frame skip & stack observation
            frame_skip=4,
            stacked_observations=4,
            # coefficient
            reward_loss_coeff=1,
            value_loss_coeff=0.25,
            policy_loss_coeff=1,
            consistency_coeff=2,
            policy_kl_loss_coeff=0.,
            # reward sum
            lstm_hidden_size=512,
            lstm_horizon_len=5,
            # siamese
            proj_hid=1024,
            proj_out=1024,
            pred_hid=512,
            pred_out=1024,
            # reanalyze batch size
            reanalyze_batch_size=128,
            # GAE
            gae_lambda=0.95,
            init_gae=0,
            # virtual value
            virtual_gamma=0.,)
        # Each agent step spans `frame_skip` frames (see make_atari(skip=...)),
        # so convert per-frame discount and frame-based caps to per-step values.
        self.discount **= self.frame_skip
        self.max_moves //= self.frame_skip
        self.test_max_moves //= self.frame_skip

        self.start_transitions = max(1, self.start_transitions)

        self.bn_mt = 0.1
        self.blocks = 1  # Number of blocks in the ResNet
        self.channels = 64  # Number of channels in the ResNet
        if self.gray_scale:
            self.channels = 32
        self.reduced_channels_reward = 16  # x36 Number of channels in reward head
        self.reduced_channels_value = 16  # x36 Number of channels in value head
        self.reduced_channels_policy = 16  # x36 Number of channels in policy head
        self.resnet_fc_reward_layers = [32]  # Define the hidden layers in the reward head of the dynamic network
        self.resnet_fc_value_layers = [32]  # Define the hidden layers in the value head of the prediction network
        self.resnet_fc_policy_layers = [32]  # Define the hidden layers in the policy head of the prediction network
        self.downsample = True  # Downsample observations before representation network (See paper appendix Network Architecture)

        self.num_groups = 16  # number of groups in GroupNorm
        self.normalization = 'bn'
        self.sync_batch_norm = False

        # training preload
        self.num_preload_batches = 1

        # reset model update interval
        self.reset_model = False
        self.reset_model_interval = 2500

    def visit_softmax_temperature_fn(self, num_moves, trained_steps):
        """Return the softmax temperature applied to MCTS visit counts.

        When ``change_temperature`` is set, the temperature is 1.0 for the first
        half of ``training_steps`` and 0.5 afterwards. Otherwise it is always
        1.0. ``num_moves`` is part of the interface but is not used.
        """
        if not self.change_temperature:
            return 1.0
        return 1.0 if trained_steps < 0.5 * self.training_steps else 0.5

    def set_game(self, env_name, save_video=False, save_path=None, video_callable=None):
        """Bind the config to ``env_name`` and derive observation/action shapes.

        ``obs_shape`` becomes (image_channel * stacked_observations, 96, 96).
        ``action_space_size`` is read from a freshly created environment.
        ``save_video``, ``save_path`` and ``video_callable`` are accepted for
        interface compatibility but are not used.
        """
        self.env_name = env_name
        # gray scale
        if self.gray_scale:
            self.image_channel = 1
        obs_shape = (self.image_channel, 96, 96)
        self.obs_shape = (obs_shape[0] * self.stacked_observations, obs_shape[1], obs_shape[2])

        game = self.new_game()
        self.action_space_size = game.action_space_size

    def get_uniform_network(self, is_trainer=False, is_reanalyze_worker=False, is_data_worker=False):
        """Construct an ``EfficientZeroNet`` from this config.

        When ``sync_batch_norm`` is set, the network's BatchNorm layers are
        converted to ``torch.nn.SyncBatchNorm``.
        """
        net = EfficientZeroNet(
            self.obs_shape,
            self.action_space_size,
            self.blocks,
            self.channels,
            self.reduced_channels_reward,
            self.reduced_channels_value,
            self.reduced_channels_policy,
            self.resnet_fc_reward_layers,
            self.resnet_fc_value_layers,
            self.resnet_fc_policy_layers,
            self.reward_support.size,
            self.value_support.size,
            self.downsample,
            self.inverse_value_transform,
            self.inverse_reward_transform,
            self.lstm_hidden_size,
            bn_mt=self.bn_mt,
            proj_hid=self.proj_hid,
            proj_out=self.proj_out,
            pred_hid=self.pred_hid,
            pred_out=self.pred_out,
            init_zero=self.init_zero,
            state_norm=self.state_norm,
            num_groups=self.num_groups,
            normalization=self.normalization,
            is_trainer=is_trainer,
            is_reanalyze_worker=is_reanalyze_worker,
            is_data_worker=is_data_worker)
        if self.sync_batch_norm:
            net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
        return net

    def new_game(self, seed=None, save_video=False, save_path=None, video_callable=None, uid=None, test=False, final_test=False):
        """Create a wrapped Atari environment for ``self.env_name``.

        :param seed: if given, passed to ``env.seed``.
        :param save_video: wrap the env in ``gym.wrappers.Monitor`` writing to
            ``save_path`` (``video_callable`` and ``uid`` are forwarded).
        :param test: evaluation env. It uses ``test_max_moves`` as the episode
            cap and never applies ``EpisodicLifeEnv``.
        :param final_test: with ``test``, use the full 108000-frame cap instead.
        :returns: an ``AtariWrapper`` around the configured env.
        """
        if test:
            max_moves = 108000 // self.frame_skip if final_test else self.test_max_moves
        else:
            max_moves = self.max_moves
        env = make_atari(self.env_name, skip=self.frame_skip, max_episode_steps=max_moves)

        if self.episode_life and not test:
            env = EpisodicLifeEnv(env, neg_rew_dead=False)
        # NOTE(review): width/height come from obs_shape[1]/[2]; identical for
        # the current 96x96 frames — confirm axis order if that ever changes.
        env = WarpFrame(env, width=self.obs_shape[1], height=self.obs_shape[2], grayscale=self.gray_scale)

        if seed is not None:
            env.seed(seed)

        if save_video:
            from gym.wrappers import Monitor
            env = Monitor(env, directory=save_path, force=True, video_callable=video_callable, uid=uid)
        return AtariWrapper(env, discount=self.discount, cvt_string=self.cvt_string)

    @staticmethod
    def _soft_cross_entropy(prediction, target):
        """Cross-entropy of logits ``prediction`` against target distribution ``target`` (class dim 1)."""
        return -(torch.log_softmax(prediction, dim=1) * target).sum(1)

    def scalar_reward_loss(self, prediction, target):
        """Per-sample cross-entropy loss for the categorical reward head."""
        return self._soft_cross_entropy(prediction, target)

    def scalar_value_loss(self, prediction, target):
        """Per-sample cross-entropy loss for the categorical value head."""
        return self._soft_cross_entropy(prediction, target)

    def set_transforms(self):
        """Create ``self.transforms`` when ``use_augmentation`` is enabled."""
        if self.use_augmentation:
            self.transforms = Transforms(self.augmentation, image_shape=(self.obs_shape[1], self.obs_shape[2]))

    def transform(self, images):
        """Apply the augmentation pipeline built by ``set_transforms``.

        ``self.transforms`` only exists after ``set_transforms`` ran with
        ``use_augmentation`` enabled.
        """
        return self.transforms.transform(images)

    def get_init_avg_return(self, steps=20000):
        """Estimate the mean discounted return of a uniformly random policy.

        Plays whole episodes with uniformly random actions until at least
        ``steps`` environment steps have been taken. Rewards are clipped to
        their sign when ``clip_reward`` is set. The discounted return is
        computed from every visited step, and their mean is stored in
        ``self.avg_return`` and returned.

        Note: reseeds NumPy's global RNG (and the env) with a fixed seed so the
        estimate is reproducible.
        """
        np.random.seed(25536)
        env = self.new_game(25536)
        returns = []
        total_step_count = 0
        while total_step_count < steps:
            rewards = []
            env.reset()
            done = False
            while not done:
                action = np.random.randint(0, env.action_space_size)
                _, ori_reward, done, _ = env.step(action)
                # np.sign already maps every negative reward to -1, so no
                # game-specific clipping is required.
                rewards.append(np.sign(ori_reward) if self.clip_reward else ori_reward)
                total_step_count += 1
            returns.extend(self._discounted_returns(rewards, self.discount))
        self.avg_return = np.mean(returns)
        return self.avg_return

    @staticmethod
    def _discounted_returns(rewards, discount):
        """Return G[i] = sum_k discount**k * rewards[i + k] for every index i.

        Uses the backward recurrence G[i] = rewards[i] + discount * G[i + 1],
        which is O(n) instead of re-summing every suffix (O(n**2)).
        """
        out = np.empty(len(rewards), dtype=np.float64)
        running = 0.0
        for i in range(len(rewards) - 1, -1, -1):
            running = rewards[i] + discount * running
            out[i] = running
        return out

# The scenario name is taken from the environment at import time; fail with an
# actionable message (still a KeyError, as before) when it is missing.
try:
    ATARI_ENV = os.environ['ATARI_SCENARIO']
except KeyError as err:
    raise KeyError(
        "environment variable ATARI_SCENARIO must be set before importing "
        "this Atari config module"
    ) from err
game_config = AtariConfig(ATARI_ENV)
