import torch
from collections import namedtuple
from copy import deepcopy

from .agent_base import AgentBase
from .replay_memory import ReplayMemoryTraj, ReplayMemory, ReplayMemoryMeta

# One environment step: state, action, reward, next state, done flag, extra info dict.
Transition = namedtuple('Transition', 's a r s_ done info')


class AgentOff():
    """Mixin with off-policy training settings and action selection.

    The host class must provide ``self.eval_mode`` and ``self.learner``, whose
    ``actor`` is callable and exposes a ``sample`` method.
    """

    def __init__(self, CONFIG):
        # Training hyper-parameters copied from the config object.
        self.batch_size = CONFIG.BATCH_SIZE
        self.update_period = CONFIG.UPDATE_PERIOD

    def forward(self, state, append, latent=None):
        """Return actions for ``state`` from the learner's actor.

        In eval mode the actor is called directly. Otherwise ``actor.sample``
        is used with ``get_prob=False``, and only its first output is returned.
        """
        actor = self.learner.actor
        if not self.eval_mode:
            # Training: take the sampled action and drop the second output.
            sampled, _ = actor.sample(state, append=append, latent=latent,
                                      get_prob=False)
            return sampled
        # Evaluation: direct forward pass through the actor.
        return actor(state, append=append, latent=latent)


class AgentStepOff(AgentBase, AgentOff):
    """Off-policy agent trained on single-step ``Transition`` tuples.

    Initializes both parents, reads evaluation settings from the configs and
    provides helpers to store, sample and unpack transitions. This class does
    not create ``self.memory`` itself; subclasses are expected to assign it.
    """

    def __init__(self, venv, CONFIG, CONFIG_ENV):
        AgentBase.__init__(self, venv, CONFIG, CONFIG_ENV)
        AgentOff.__init__(self, CONFIG)

        # Eval
        self.check_opt_freq = CONFIG.CHECK_OPT_FREQ
        self.num_episode_per_eval = CONFIG.NUM_EVAL_EPISODE
        self.success_threshold = CONFIG_ENV.SUCCESS_THRESHOLD

    # === Replay ===
    def sample_batch(self, batch_size=None, recent_size=None, **kwargs):
        """Sample transitions from ``self.memory`` and regroup them by field.

        Args:
            batch_size: number of transitions; defaults to ``self.batch_size``.
            recent_size: forwarded to ``self.memory.sample`` (presumably limits
                sampling to recent entries -- verify against the memory class).
            **kwargs: forwarded to ``self.memory.sample``.

        Returns:
            A ``Transition`` whose fields are tuples with one entry per sample.
        """
        if batch_size is None:
            batch_size = self.batch_size
        # The memory returns a second value that is not needed here.
        transitions, _ = self.memory.sample(batch_size, recent_size, **kwargs)
        # Transpose a sequence of Transitions into one Transition of tuples.
        return Transition(*zip(*transitions))

    def store_transition(self, *args):
        """Wrap ``args`` (s, a, r, s_, done, info) in a Transition and store it."""
        self.memory.update(Transition(*args))

    def unpack_batch(self, batch, get_rnn_hidden=False):
        """Convert a sampled ``Transition`` batch into tensors on ``self.device``.

        Args:
            batch: ``Transition`` whose fields are per-sample sequences, as
                returned by ``sample_batch``.
            get_rnn_hidden: if True, also return the RNN states ``hn``/``cn``
                stored in the first transition's ``info``.

        Returns:
            Tuple ``(non_final_mask, non_final_state_nxt, state, action, reward,
            append, non_final_append_nxt, hn, cn)``. ``append`` and
            ``non_final_append_nxt`` are None unless ``self.use_append``;
            ``hn`` and ``cn`` are None unless ``get_rnn_hidden``.
        """
        # True where the transition is not terminal.
        non_final_mask = torch.tensor(
            [not done for done in batch.done], dtype=torch.bool
        ).view(-1).to(self.device)
        # Next states of the non-terminal transitions only.
        # NOTE(review): torch.cat raises on an empty list, i.e. when every
        # transition in the batch is terminal.
        non_final_state_nxt = torch.cat([
            s_nxt for done, s_nxt in zip(batch.done, batch.s_) if not done
        ]).to(self.device)
        state = torch.cat(batch.s).to(self.device)
        reward = torch.FloatTensor(batch.r).to(self.device)
        action = torch.cat(batch.a).to(self.device)

        # Optional
        append = None
        non_final_append_nxt = None
        if self.use_append:
            append = torch.cat([info['append']
                                for info in batch.info]).to(self.device)
            # Keep only entries of non-terminal transitions, matching
            # non_final_state_nxt.
            non_final_append_nxt = torch.cat([
                info['append_nxt'] for info in batch.info
            ]).to(self.device)[non_final_mask]
        hn = None
        cn = None
        if get_rnn_hidden:
            # Only the first transition's states are used (presumably the
            # initial state of a sampled sequence).
            hn = batch.info[0]['hn'].to(self.device)
            # The cell state is read from its own 'cn' entry, not from 'hn'.
            cn = batch.info[0]['cn'].to(self.device)
        return (non_final_mask, non_final_state_nxt, state, action, reward,
                append, non_final_append_nxt, hn, cn)


class AgentMetaStepOff(AgentStepOff):
    """Step-based off-policy agent with adaptation settings and a meta buffer.

    Reads the adaptation schedule and the number of evaluation tasks from
    ``CONFIG``, and replaces the replay buffer with a ``ReplayMemoryMeta``.
    """

    def __init__(self, venv, CONFIG, CONFIG_ENV):
        super().__init__(venv, CONFIG, CONFIG_ENV)

        # --- Adaptation schedule ---
        self.num_itr = CONFIG.NUM_ITR
        self.num_adapt_task = CONFIG.NUM_ADAPT_TASK
        self.num_adapt_step = CONFIG.NUM_ADAPT_STEP
        # Episodes collected per adaptation step.
        self.num_adapt_episode = CONFIG.NUM_ADAPT_EPISODE
        self.num_adapt_update = CONFIG.NUM_ADAPT_UPDATE
        self.adapt_online_weight = CONFIG.ADAPT_ONLINE_WEIGHT

        # --- Evaluation ---
        self.num_eval_task = CONFIG.NUM_EVAL_TASK

        # Task-aware replay buffer. Presumably it records a task id per
        # transition (passed through ``info``) so batches can be drawn for a
        # specific task -- confirm against ReplayMemoryMeta.
        self.memory = ReplayMemoryMeta(CONFIG.MEMORY_CAPACITY, CONFIG.FILL,
                                       CONFIG.SEED)


class AgentBaseStepOff(AgentStepOff):
    """Step-based off-policy agent backed by a plain ``ReplayMemory``.

    Adds the sampling/optimization schedule, plus helpers to snapshot, reset
    (for adaptation) and restore the learner and replay memory.
    """

    def __init__(self, venv, CONFIG, CONFIG_ENV):
        super().__init__(venv, CONFIG, CONFIG_ENV)

        # Sampling
        self.max_sample_steps = CONFIG.MAX_SAMPLE_STEPS
        self.opt_freq = CONFIG.OPTIMIZE_FREQ
        self.num_update_per_opt = CONFIG.UPDATE_PER_OPT
        self.min_step_b4_opt = CONFIG.MIN_STEPS_B4_OPT

        # Eval
        self.train_curve_interval = CONFIG.TRAIN_CURVE_INTERVAL

        # Memory for single-step transitions
        self.memory = ReplayMemory(CONFIG.MEMORY_CAPACITY, CONFIG.SEED)

    def get_learner_snapshot(self):
        """Capture replay memory, model and optimizer for a later restore.

        Returns:
            Tuple ``(memory, model, optimizer)``. The memory is deep-copied;
            model and optimizer are whatever ``self.learner.get_model()`` and
            ``self.learner.get_optimizer()`` return. NOTE(review): if those are
            live references rather than copies, later training mutates them.
        """
        old_memory = deepcopy(self.memory)
        old_policy = self.learner.get_model()
        old_optimizer = self.learner.get_optimizer()
        return old_memory, old_policy, old_optimizer

    def set_adapt_mode(self, optimizer=None):
        """Enter train mode, replace the replay memory and reset the optimizer.

        Args:
            optimizer: optimizer (or its saved form) passed to
                ``self.learner.restore_optimizer``. If None, a new optimizer is
                created with ``self.learner.build_optimizer()``.
        """
        self.set_train_mode()

        # Reset memory: a new ReplayMemory with the configured capacity/seed.
        self.memory = ReplayMemory(self.config.MEMORY_CAPACITY,
                                   self.config.SEED)

        # Reset optimizer. Compare against the None sentinel explicitly, so
        # any supplied optimizer is restored even if its truth value is False.
        if optimizer is not None:
            self.learner.restore_optimizer(optimizer)
        else:
            self.learner.build_optimizer()

    def restore_learner_snapshot(self, snapshot):
        """Restore memory, model and optimizer from ``get_learner_snapshot``.

        The memory is deep-copied again so that later updates to
        ``self.memory`` do not alter ``snapshot``, which can be reused.
        """
        self.memory = deepcopy(snapshot[0])
        self.learner.restore_model(snapshot[1])
        self.learner.restore_optimizer(snapshot[2])
