import os, time, torch, traceback, shutil
import numpy as np
from UTIL.colorful import *
from config import GlobalConfig
from UTIL.tensor_ops import repeat_at
from ALGORITHM.common.rl_alg_base import RLAlgorithmBase
class AlgorithmConfig:
    # configuration, open to jsonc modification
    gamma = 0.99
    tau = 0.95
    train_traj_needed = 512
    hete_n_alive_frontend = 1
    TakeRewardAsUnity = False
    use_normalization = True
    wait_norm_stable = True
    add_prob_loss = False
    n_focus_on = 2
    n_entity_placeholder = 24

    load_checkpoint = False
    load_specific_checkpoint = ''

    # PPO part
    clip_param = 0.2
    ppo_epoch = 16
    n_pieces_batch_division = 1
    value_loss_coef = 0.1
    entropy_coef = 0.05
    max_grad_norm = 0.5
    clip_param = 0.2
    lr = 1e-4

    # prevent GPU OOM
    prevent_batchsize_oom = False
    gamma_in_reward_forwarding = False
    gamma_in_reward_forwarding_value = 0.99

    net_hdim = 24
    
    dual_conc = True

    n_agent = 'auto load, do not change'

    ConfigOnTheFly = True


    hete_n_net_placeholder = 5
    hete_rollbuffer_size = 6
    hete_rollbuffer_interval = 5
    hete_sel_exclude_frontend = True
    hete_thread_align = False
    
    entity_distinct = 'auto load, do not change'

    policy_resonance = False

    use_avail_act = True
    
    debug = False
    
def str_array_to_num(str_arr):
    """Map each distinct item of *str_arr* to an integer id.

    Ids are handed out in order of first appearance, starting from 0, so
    equal items always receive the same id.

    Args:
        str_arr: iterable of hashable items (type-name strings in this file).

    Returns:
        A list of ints with one entry per input item.
    """
    first_seen = {}
    # setdefault evaluates len(first_seen) before inserting, so a new label
    # gets the next unused id while a known label keeps its existing one
    return [first_seen.setdefault(label, len(first_seen)) for label in str_arr]

def itemgetter(*items):
    """Return a callable that looks up every key in *items* on its argument.

    Differs from operator.itemgetter in two ways: a key that is not
    contained in the object yields None instead of raising, and the result
    is always a tuple, even for zero or one key.
    """
    def g(obj):
        picked = []
        for key in items:
            if key in obj:
                picked.append(obj[key])
            else:
                picked.append(None)
        return tuple(picked)
    return g

class ReinforceAlgorithmFoundation(RLAlgorithmBase):
    def __init__(self, n_agent, n_thread, space, mcv=None, team=None):
        '''
            Build the environment shell, stage planner, policy network, PPO trainer
            and trajectory manager, then optionally load a checkpoint and prepare
            the on-the-fly command file.

            n_agent:   agent count; also stored into AlgorithmConfig.n_agent
            n_thread:  forwarded to the shell wrapper and used as n_env of the trajectory manager
            space:     space description; space['obs_space']['obs_shape'] is read
                       when the scenario is not entity-oriented
            mcv:       forwarded to the base class, ShellEnvWrapper, StagePlanner and PPO
            team:      index into ScenarioConfig.AGENT_ID_EACH_TEAM selecting this team's agents

            NOTE(review): self.device, self.team and self.ScenarioConfig are not assigned
            here; presumably they are provided by RLAlgorithmBase.__init__ - confirm.
        '''
        # sibling modules are imported lazily inside the constructor
        from .shell_env import ShellEnvWrapper, ActionConvertLegacy
        from .hete_net import HeteNet
        super().__init__(n_agent, n_thread, space, mcv, team)
        # publish the agent count through the shared config class
        AlgorithmConfig.n_agent = n_agent
        # the action count is taken from the legacy action dictionary
        n_actions = len(ActionConvertLegacy.dictionary_args)

        # change obs format, e.g., converting dead agent obs into NaN
        self.shell_env = ShellEnvWrapper(n_agent, n_thread, space, mcv, self, AlgorithmConfig, GlobalConfig.ScenarioConfig, self.team)
        # raw observation size comes from the scenario config when entity-oriented, otherwise from `space`
        if self.ScenarioConfig.EntityOriented: rawob_dim = self.ScenarioConfig.obs_vec_length
        else: rawob_dim = space['obs_space']['obs_shape']

        # self.StagePlanner, for policy resonance
        from .stage_planner import StagePlanner
        self.stage_planner = StagePlanner(mcv=mcv)

        # heterogeneous agent types
        # type names of all agents are converted to integer ids, then the ids of this team's agents are selected
        agent_type_list = [a['type'] for a in GlobalConfig.ScenarioConfig.SubTaskConfig.agent_list]
        self.HeteAgentType = str_array_to_num(agent_type_list)
        hete_type = np.array(self.HeteAgentType)[self.ScenarioConfig.AGENT_ID_EACH_TEAM[team]]

        # initialize policy
        self.policy = HeteNet(rawob_dim=rawob_dim, n_action=n_actions, hete_type=hete_type, stage_planner=self.stage_planner)
        self.policy = self.policy.to(self.device)

        # initialize optimizer and trajectory (batch) manager
        from .ppo import PPO
        from .trajectory import BatchTrajManager
        self.trainer = PPO(self.policy, ppo_config=AlgorithmConfig, mcv=mcv)
        # the trajectory manager receives trainer.train_on_traj as its training hook
        self.traj_manager = BatchTrajManager(
            n_env=n_thread, traj_limit=int(GlobalConfig.ScenarioConfig.MaxEpisodeStep),
            trainer_hook=self.trainer.train_on_traj)
        # give the stage planner a handle on the trainer
        self.stage_planner.trainer = self.trainer

        # confirm that reward method is correct
        self.check_reward_type(AlgorithmConfig)

        # load checkpoints if needed
        self.load_model(AlgorithmConfig)

        # enable config_on_the_fly ability
        if AlgorithmConfig.ConfigOnTheFly:
            self._create_config_fly()


    def action_making(self, StateRecall, test_mode):
        '''
            Query the policy for one step of actions.

            Reads 'obs', 'threads_active_flag', 'avail_act', '_hete_pick_',
            '_hete_type_', '_gp_pick_' and '_EpRsn_' from StateRecall (an absent
            key is read as None), runs self.policy.act without gradient tracking,
            and, unless test_mode is set, stores the callback returned by
            self.commit_traj_frag under StateRecall['_hook_'].

            Returns a copy of the action together with StateRecall.
        '''
        # a leftover hook means the previous frame has not been consumed yet
        assert ('_hook_' not in StateRecall)

        # fetch everything needed for this step in one go
        wanted_keys = ('obs', 'threads_active_flag', 'avail_act', '_hete_pick_', '_hete_type_', '_gp_pick_', '_EpRsn_')
        (obs, threads_active_flag, avail_act,
         hete_pick, hete_type, gp_sel_summary, eprsn) = itemgetter(*wanted_keys)(StateRecall)

        # sanity checks on the observation batch
        assert obs is not None, ('Make sure obs is ok')
        assert len(obs) == sum(threads_active_flag), ('check batch size')
        # the available-action mask is mandatory when the config asks for it
        if AlgorithmConfig.use_avail_act:
            assert avail_act is not None

        # repeat eprsn n_agent times along a newly inserted last dimension
        eprsn = repeat_at(eprsn, -1, self.n_agent)
        # indices of the threads flagged as active
        thread_index = np.arange(self.n_thread)[threads_active_flag]

        policy_inputs = dict(
            obs=obs,
            test_mode=test_mode,
            avail_act=avail_act,
            hete_pick=hete_pick,
            hete_type=hete_type,
            gp_sel_summary=gp_sel_summary,
            thread_index=thread_index,
            eprsn=eprsn,
        )
        with torch.no_grad():
            action, value, action_log_prob = self.policy.act(**policy_inputs)

        # vars named like _x_ are aligned, others are not!
        traj_framefrag = dict([
            ("_SKIP_", ~threads_active_flag),
            ("value", value),
            ("hete_pick", hete_pick),
            ("hete_type", hete_type),
            ("gp_sel_summary", gp_sel_summary),
            ("avail_act", avail_act),
            ("actionLogProb", action_log_prob),
            ("obs", obs),
            ("action", action),
        ])

        # the reward is not known yet: leave a hook that completes this frame later
        if not test_mode:
            StateRecall['_hook_'] = self.commit_traj_frag(traj_framefrag, req_hook = True)
        return action.copy(), StateRecall


    def interact_with_env(self, StateRecall):
        '''
            Hand StateRecall to self.shell_env.interact_with_env and return
            its result unchanged.
            NOTE(review): the shell wrapper is constructed with a reference to this
            object, so it presumably calls back into interact_with_env_genuine
            after reshaping the observation - confirm in shell_env.
        '''
        return self.shell_env.interact_with_env(StateRecall)


    def interact_with_env_genuine(self, StateRecall):
        '''
            Produce actions for the given StateRecall by calling action_making,
            using StateRecall['Test-Flag'] as the test_mode switch.
            Training is not triggered from here.
        '''
        is_test_run = StateRecall['Test-Flag']
        return self.action_making(StateRecall, is_test_run)

    def train(self):
        '''
            Run one training round if the trajectory manager reports it can train.

            The collected trajectory pool is trained on and cleared when the stage
            planner allows it, and cleared without training otherwise. Afterwards the
            on-the-fly command file is processed (if enabled) and the stage plan is
            updated. Does nothing when the trajectory manager is not ready.
        '''
        if not self.traj_manager.can_exec_training():
            return

        # the stage planner decides whether the collected pool is used or discarded
        if self.stage_planner.can_exec_trainning():
            consume_pool = self.traj_manager.train_and_clear_traj_pool
        else:
            consume_pool = self.traj_manager.clear_traj_pool
        consume_pool()

        # read configuration
        if AlgorithmConfig.ConfigOnTheFly:
            self._config_on_fly()

        # let the planner advance its own schedule
        self.stage_planner.update_plan()




    def save_model(self, update_cnt, info=None):
        '''
            Save the policy and both optimizer states to '<logdir>/model.pt' and copy
            that file to '<logdir>/history_cpt/model_<update_cnt>[_<info>].pt'.

            update_cnt: counter used to name the history copy
            info:       optional extra suffix for the history copy name
        '''
        logdir = GlobalConfig.logdir
        # exist_ok replaces the former exists()-then-makedirs() pair, which could
        # fail if the directory appeared between the check and the creation
        os.makedirs('%s/history_cpt/' % logdir, exist_ok=True)

        # dir 1: the latest checkpoint, overwritten on every save
        pt_path = '%s/model.pt' % logdir
        print绿('saving model to %s' % pt_path)
        torch.save({
            'policy': self.policy.state_dict(),
            'at_optimizer': self.trainer.at_optimizer.state_dict(),
            'ct_optimizer': self.trainer.ct_optimizer.state_dict(),
        }, pt_path)

        # dir 2: a history copy that is never overwritten by later saves with another name
        # (%-formatting also accepts a non-str `info`, which ''.join rejected)
        info = str(update_cnt) if info is None else '%s_%s' % (update_cnt, info)
        pt_path2 = '%s/history_cpt/model_%s.pt' % (logdir, info)
        shutil.copyfile(pt_path, pt_path2)

        print绿('save_model fin')



    def load_model(self, AlgorithmConfig):
        '''
            Restore the policy and both optimizers from a checkpoint, if requested.

            Nothing happens unless AlgorithmConfig.load_checkpoint is set. The file is
            '<logdir>/model.pt' by default, or '<logdir>/<load_specific_checkpoint>'
            when that option is a non-empty string.
        '''
        if not AlgorithmConfig.load_checkpoint:
            return

        # pick the checkpoint file
        manual_dir = AlgorithmConfig.load_specific_checkpoint
        if manual_dir == '':
            ckpt_dir = '%s/model.pt' % GlobalConfig.logdir
        else:
            ckpt_dir = '%s/%s' % (GlobalConfig.logdir, manual_dir)

        # tensors are mapped to plain 'cpu' when the configured device mentions cpu
        target_device = 'cpu' if 'cpu' in self.device else self.device
        cpt = torch.load(ckpt_dir, map_location=target_device)

        # network weights must match the current architecture exactly
        self.policy.load_state_dict(cpt['policy'], strict=True)

        # optimizer states
        for opt_key in ('at_optimizer', 'ct_optimizer'):
            getattr(self.trainer, opt_key).load_state_dict(cpt[opt_key])

        print黄('loaded checkpoint:', ckpt_dir)


    def process_framedata(self, traj_framedata):
        '''
            Finish a trajectory frame and feed it to the trajectory manager.

            Drops the 'info' and 'Latest-Obs' entries, renames 'done' to '_DONE_' and
            'Terminal-Obs-Echo' to '_TOBS_' (None when absent), filters out paused
            threads via mask_paused_env and passes the result to
            self.traj_manager.feed_traj_framedata.
        '''
        # strip entries that are not arrays
        for unwanted in ('info', 'Latest-Obs'):
            traj_framedata.pop(unwanted, None)

        # with a unified reward, repeat it n_agent times along a new last dimension
        if self.ScenarioConfig.RewardAsUnity:
            shared_reward = traj_framedata['reward']
            traj_framedata['reward'] = repeat_at(shared_reward, insert_dim=-1, n_times=self.n_agent)

        # use the key names recognised by the trajectory manager
        traj_framedata['_DONE_'] = traj_framedata.pop('done')
        traj_framedata['_TOBS_'] = traj_framedata.pop('Terminal-Obs-Echo', None)

        # mask out paused threads, then put the frag into memory
        masked_frame = self.mask_paused_env(traj_framedata)
        self.traj_manager.feed_traj_framedata(masked_frame)

    def mask_paused_env(self, frag):
        '''
            Drop the rows of paused threads from the entries of frag.

            frag['_SKIP_'] marks the paused threads. Every entry whose key does not
            start with '_' and whose length equals self.n_thread is indexed with the
            inverted mask; all other entries are left untouched. frag is modified
            in place and also returned.
        '''
        running = ~frag['_SKIP_']
        if not running.all():
            for key in frag:
                # keys with a leading underscore are skipped
                if key.startswith('_'):
                    continue
                entry = frag[key]
                # only entries with one row per thread can be filtered
                if not hasattr(entry, '__len__'):
                    continue
                if len(entry) != self.n_thread:
                    continue
                frag[key] = entry[running]
        return frag


    def _create_config_fly(self):
        '''
            Remember the path of the command file ('<logdir>/cmd_io.txt') in
            self.input_file_dir and create the file with a one-line header if it
            does not exist yet. An existing file is left untouched.
        '''
        self.input_file_dir = '%s/cmd_io.txt' % GlobalConfig.logdir
        if os.path.exists(self.input_file_dir):
            return
        with open(self.input_file_dir, 'w', encoding='utf8') as f:
            # writelines() adds no line separators, so the header must carry its own
            # newline; otherwise text appended to the file would join the '#' line
            # and be treated as a comment by _config_on_fly
            f.write("# Write cmd at next line: \n")

    def _config_on_fly(self):
        '''
            Execute the pending commands found in the command file.

            Lines starting with '#' and blank lines are kept as they are. Every other
            line is executed as Python source inside this method (so `self` and the
            module globals are visible to it) and then written back as a comment
            prefixed with '# [execute successfully]' or '# [execute failed]', so it is
            not executed again. The file is only rewritten when a command was found.

            SECURITY NOTE: exec() runs arbitrary code, so anyone able to write to
            '<logdir>/cmd_io.txt' can run code in the training process. Keep the log
            directory writable by trusted users only.
        '''
        if not os.path.exists(self.input_file_dir): return

        with open(self.input_file_dir, 'r', encoding='utf8') as f:
            cmdlines = f.readlines()

        cmdlines_writeback = []
        any_change = False

        for cmdline in cmdlines:
            # comments and whitespace-only lines are not commands
            if cmdline.startswith('#') or not cmdline.strip():
                cmdlines_writeback.append(cmdline)
                continue

            any_change = True
            # the last line of a file may lack its newline; terminate the writeback
            # line so text appended later does not end up inside the comment
            terminated = cmdline if cmdline.endswith('\n') else cmdline + '\n'
            try:
                print亮绿('[foundation.py] ------- executing: %s ------'%cmdline)
                exec(cmdline)
                cmdlines_writeback.append('# [execute successfully]\t'+terminated)
            except Exception:
                # only ordinary errors are reported and skipped; KeyboardInterrupt
                # and SystemExit are allowed to propagate
                print红(traceback.format_exc())
                cmdlines_writeback.append('# [execute failed]\t'+terminated)

        if any_change:
            with open(self.input_file_dir, 'w', encoding='utf8') as f:
                f.writelines(cmdlines_writeback)
