import torch
import torch.optim as optim

import numpy as np
from collections import deque

from pytorch_rl import policies
from pytorch_rl.utils import ImgToTensor

from src.networks import *
from src.utils import MSEMasked


class MemoryArch:
    """
    Base class for all baseline memory architectures.

    Subclasses must implement ``write``, ``reconstruct`` and ``content``.
    Every other method has a default that assumes the memory has no
    trainable parameters: the hooks do nothing, and ``parameters`` /
    ``tosave`` return None.
    """
    def reset(self):
        """Reset the memory; must be called at the end of every episode."""
        pass

    def detach(self):
        """Stop gradient flow through the memory (for learnt methods)."""
        pass

    def write(self, glimpse, obs_mask, minus_obs_mask):
        """
        Write a new observation to the memory architecture.

        :param glimpse: a (batchsize, *obs_shape) input to be stored in memory
        :param obs_mask: mask of where the observation is taken from the state
        :param minus_obs_mask: complement (negative) of ``obs_mask``
        :return: subclass-defined (learnt memories may return a write cost)
        :raises NotImplementedError: always, in the base class
        """
        raise NotImplementedError

    def reconstruct(self):
        """
        Attempt to reconstruct the environment state from memory.

        :raises NotImplementedError: always, in the base class
        """
        raise NotImplementedError

    def content(self):
        """
        Return the current content (state) of the memory.

        :raises NotImplementedError: always, in the base class
        """
        raise NotImplementedError

    def step(self, action=None):
        """
        Advance the memory by one environment step, for memories that have a
        learned model of how unobserved parts of the environment evolve.

        :param action: optional action conditioning the step
        """
        pass

    def assign(self, other):
        """Take over the internal trainable modules of ``other``."""
        pass

    def lossbatch(self, **kwargs):
        """
        Train the memory architecture on a batch of observations, if applicable.

        :param kwargs: arguments for training.
        """
        pass

    def to(self, device):
        """
        Move the trainable modules to ``device``.

        :param device: cpu or gpus
        """
        pass

    def share_memory(self):
        """Share the trainable modules across threads/processes."""
        pass

    def parameters(self):
        """Return all trainable parameters (None when there are none)."""
        return None

    def tosave(self):
        """Return all trainable modules that should be saved (None when there are none)."""
        return None

    def load(self, path):
        """Load all trainable modules from ``path``."""
        pass

class DynamicMap(MemoryArch):
    """
    Learned spatial memory map with write, blend, step and reconstruction models.

    A glimpse of the environment is encoded by ``write_model``, blended with the
    current map by ``blend_model`` and written only where the observation mask
    is set. ``step_model`` advances the map given an action and
    ``reconstruction_model`` decodes the map back to an environment-sized state.

    ``mode`` selects how ``lossbatch`` rewards the glimpse agent:
    'variational', 'environmental', 'l2reward' or 'follow'.
    """
    def __init__(self, size, channels, env_size, env_channels, batchsize, device, nb_actions, mode):
        """
        :param size: side length of the square memory map
        :param channels: number of channels of the memory map
        :param env_size: side length of the square environment state
        :param env_channels: number of channels of the environment state
        :param batchsize: batch dimension the map is allocated with in ``reset``
        :param device: device the map and one-hot actions are moved to
        :param nb_actions: number of discrete actions (width of the one-hot action)
        :param mode: glimpse-agent reward mode used by ``lossbatch``
        :raises ValueError: if (env_size, size) is not a supported combination
        """
        self.size = size
        self.channels = channels
        self.env_size = env_size
        self.env_channels = env_channels
        self.batchsize = batchsize
        self.device = device
        self.nb_actions = nb_actions
        self.mode = mode

        if env_size == size:
            self.write_model = MapWrite(env_channels, channels)
            self.reconstruction_model = MapReconstruct(channels, env_channels)
            self.step_model = MapStepConditional(channels, channels, nb_actions)
        elif env_size == 84 and size == 21:
            self.write_model = MapWrite_84_21(env_channels, channels)
            self.reconstruction_model = MapReconstruction_21_84(channels, env_channels)
            self.step_model = MapStepResidualConditional(channels, channels, nb_actions)
        else:
            raise ValueError('Unexpected values for size/env_size')
        self.blend_model = MapBlend(channels * 2, channels)
        if self.mode == 'variational':
            # q is trained by its own optimizer inside lossbatch, which is why
            # it is not included in parameters()
            self.q = VariationalFunction_x_x(channels, nb_actions)
            self.q_optimizer = optim.Adam(self.q.parameters(), lr=1e-4)

    def reset(self):
        """
        Reset the map to all zeros at the beginning of an episode.
        """
        self.map = torch.zeros((self.batchsize, self.channels, self.size, self.size)).to(self.device)

    def detach(self):
        """
        Detach the underlying map, stopping gradient flow into earlier steps.
        """
        self.map = self.map.detach()

    def maskobs2maskmap(self, mask, minus_mask):
        '''
        Convert an observation mask at environment resolution to map resolution.

        :param mask: observation mask at environment resolution
        :param minus_mask: complement of ``mask``
        :return: (map_mask, minus_map_mask), both detached from the graph
        :raises ValueError: if the map is not 1x, 1/2x or 1/4x the environment size
        '''
        def pool(x):
            return torch.nn.functional.max_pool2d(x, 2, stride=2)

        if self.size == self.env_size:
            minus_map_mask = minus_mask
            map_mask = mask
        elif self.size == self.env_size // 2:
            # Max-pooling the complement marks a map cell as unobserved if any
            # pixel it covers is unobserved (assuming binary masks), so only
            # fully observed cells are overwritten.
            minus_map_mask = pool(minus_mask)
            map_mask = 1 - minus_map_mask
        elif self.size == self.env_size // 4:
            minus_map_mask = pool(pool(minus_mask))
            map_mask = 1 - minus_map_mask
        else:
            raise ValueError('Unsupported map size %s for env size %s'
                             % (self.size, self.env_size))
        map_mask = map_mask.detach()
        minus_map_mask = minus_map_mask.detach()
        return map_mask, minus_map_mask

    def write(self, glimpse, obs_mask, minus_obs_mask):
        """
        Store an incoming (masked) observation into the memory map.

        :param glimpse: observation to store, at environment resolution
        :param obs_mask: mask of where the observation was taken from the state
        :param minus_obs_mask: complement of ``obs_mask``
        :return: mean absolute value of the masked write (the cost of writing)
        """
        # what to write
        w = self.write_model(glimpse)
        w = self.blend_model(w, self.map.clone())

        # write: keep the old map outside the observed area, replace it inside
        map_mask, minus_map_mask = self.maskobs2maskmap(obs_mask, minus_obs_mask)
        # out-of-place so the blend output is never mutated if autograd saved it
        w = w * map_mask
        self.map = self.map * minus_map_mask + w

        # cost of writing
        return w.abs().mean()

    def step(self, action=None):
        """
        Advance the map by one step using the learned dynamics model.

        :param action: action conditioning passed to ``step_model``
            (``lossbatch`` passes a one-hot action)
        :return: mean absolute value of the predicted update (the cost of stepping)
        """
        dynamic = self.step_model(self.map, action)
        cost = dynamic.abs().mean()
        # residual update of the map
        self.map = self.map + dynamic
        return cost

    def reconstruct(self):
        """
        Attempt to reconstruct the entire state image from the current map.
        """
        return self.reconstruction_model(self.map)

    def content(self):
        """Return the current memory map tensor."""
        return self.map

    def lossbatch(self, state_batch, action_batch, reward_batch,
                  glimpse_agent, training_metrics, locs=None):
        """
        Unroll the map over a sequence and accumulate its training loss.

        At each step t:
        1. The glimpse agent picks a location, or ``locs[t]`` is used.
        2. The pre-write (post-step) reconstruction is scored on the observed area.
        3. The masked observation is written.
        4. The post-write reconstruction is scored.
        5. The map is stepped with ``action_batch[t]``.
        6. The glimpse agent is rewarded according to ``self.mode``.

        In 'variational' mode the q network is also trained here with its own
        optimizer.

        :param state_batch: states, indexed as (seq_len, batch, ...)
        :param action_batch: integer actions, indexed as (seq_len, batch)
        :param reward_batch: rewards, indexed as (seq_len, batch)
        :param glimpse_agent: provides step / norm_and_clip / create_attn_mask / reward
        :param training_metrics: mapping of metric name -> object with ``update``
        :param locs: optional per-step glimpse locations; when None the glimpse
            agent chooses them
        :return: differentiable map loss summed over the sequence
        :raises ValueError: if ``self.mode`` is unknown
        """
        mse = MSEMasked()
        mse_unmasked = torch.nn.MSELoss()
        cost_weight = 0.01  # weight of the write/step magnitude penalties
        total_write_loss = 0
        total_step_loss = 0
        total_post_write_loss = 0
        total_post_step_loss = 0
        overall_reconstruction_loss = 0
        min_overall_reconstruction_loss = 1.
        total_q_loss = 0
        # initialize map
        self.reset()
        # reconstruction of the empty map, scored against the first state
        post_step_reconstruction = self.reconstruct()
        # map state fed to q; it must exist before the first variational update
        s = self.map.clone().detach()
        loss = 0
        seq_len = state_batch.size(0)
        batch_size = state_batch.size(1)
        # replicating the last reward as attention takes reward from NEXT step
        reward_batch = torch.cat((reward_batch, reward_batch[-1, :].unsqueeze(dim=0)), dim=0)
        logsoftmax = torch.nn.LogSoftmax(dim=1)
        CEloss = torch.nn.CrossEntropyLoss()
        for t in range(seq_len):
            # pick locations of attention (identity check: locs may be a tensor)
            if locs is None:
                loc = glimpse_agent.step(self.map.detach(), random=False)
            else:
                loc = glimpse_agent.norm_and_clip(locs[t].cpu().numpy(), unraveled=True)
            obs_mask, minus_obs_mask = glimpse_agent.create_attn_mask(loc)
            post_step_loss = mse(post_step_reconstruction, state_batch[t], obs_mask)
            reconstruction_loss = mse_unmasked(post_step_reconstruction, state_batch[t]).mean()
            # write new observation to map
            obs = state_batch[t] * obs_mask
            write_cost = self.write(obs, obs_mask, minus_obs_mask)
            # post-write reconstruction loss
            post_write_reconstruction = self.reconstruct()
            post_write_loss = mse(post_write_reconstruction, state_batch[t], obs_mask).mean()
            # step forward the internal map
            actions = action_batch[t].unsqueeze(dim=1)
            onehot_action = torch.zeros(batch_size, self.nb_actions).to(self.device)
            onehot_action.scatter_(1, actions, 1)
            step_cost = self.step(onehot_action)
            post_step_reconstruction = self.reconstruct()
            # reward glimpse agent
            if self.mode == 'variational':
                loc_indices = np.ravel_multi_index(np.transpose(loc), (self.env_size, self.env_size))
                variational_output = self.q(s, self.map.detach(), onehot_action)
                variational_reward = logsoftmax(variational_output).detach()
                variational_reward = variational_reward[range(self.batchsize), loc_indices]
                glimpse_agent.reward(variational_reward)
                self.q_optimizer.zero_grad()
                q_loss = CEloss(variational_output, torch.LongTensor(loc_indices).to(self.device))
                q_loss.backward()
                self.q_optimizer.step()
                total_q_loss += q_loss.item()
                # save state after step for the next variational reward
                s = self.map.clone().detach()
            elif self.mode == 'environmental':
                glimpse_agent.reward(reward_batch[t+1])
            elif self.mode == 'l2reward':
                glimpse_agent.reward(post_step_loss.detach())
            elif self.mode == 'follow':
                # doesn't matter here the reward as the glimpse agent is fixed
                glimpse_agent.reward(post_step_loss.detach())
            else:
                raise ValueError("Unknown mode for dynamic map")
            post_step_loss = post_step_loss.mean()
            # add up all losses
            loss += cost_weight * (write_cost + step_cost) + post_write_loss + post_step_loss
            total_write_loss += cost_weight * write_cost.item()
            total_step_loss += cost_weight * step_cost.item()
            total_post_write_loss += post_write_loss.item()
            total_post_step_loss += post_step_loss.item()
            reconstruction_value = reconstruction_loss.item()
            overall_reconstruction_loss += reconstruction_value
            if t == 0 or reconstruction_value < min_overall_reconstruction_loss:
                min_overall_reconstruction_loss = reconstruction_value
        # update the training metrics
        training_metrics['map/write_cost'].update(total_write_loss / seq_len)
        training_metrics['map/step_cost'].update(total_step_loss / seq_len)
        training_metrics['map/post_write'].update(total_post_write_loss / seq_len)
        training_metrics['map/post_step'].update(total_post_step_loss / seq_len)
        training_metrics['map/overall'].update(overall_reconstruction_loss / seq_len)
        training_metrics['map/min_overall'].update(min_overall_reconstruction_loss)
        training_metrics['q/loss'].update(total_q_loss / seq_len)
        return loss

    def assign(self, other):
        """Share the trainable modules of ``other`` (by reference)."""
        self.write_model = other.write_model
        self.step_model = other.step_model
        self.reconstruction_model = other.reconstruction_model
        self.blend_model = other.blend_model
        if hasattr(self, 'q'):
            self.q = other.q
        if hasattr(self, 'agent_step_model'):
            self.agent_step_model = other.agent_step_model

    def to(self, device):
        """Move all trainable modules to ``device``."""
        self.write_model.to(device)
        self.reconstruction_model.to(device)
        self.blend_model.to(device)
        self.step_model.to(device)
        if hasattr(self, 'q'):
            self.q.to(device)
        if hasattr(self, 'agent_step_model'):
            self.agent_step_model.to(device)

    def share_memory(self):
        """Move all trainable modules (including q, if present) to shared memory."""
        self.write_model.share_memory()
        self.reconstruction_model.share_memory()
        self.blend_model.share_memory()
        self.step_model.share_memory()
        if hasattr(self, 'q'):
            self.q.share_memory()
        if hasattr(self, 'agent_step_model'):
            self.agent_step_model.share_memory()

    def parameters(self):
        """
        Return the parameters of the map models.

        q is excluded on purpose: it is optimized separately by q_optimizer.
        """
        allparams = list(self.write_model.parameters()) + \
            list(self.reconstruction_model.parameters()) + \
            list(self.blend_model.parameters()) + \
            list(self.step_model.parameters())
        if hasattr(self, 'agent_step_model'):
            allparams += list(self.agent_step_model.parameters())
        return allparams

    def tosave(self):
        """Return the state dicts of all trainable modules, keyed by role."""
        toreturn = {'write': self.write_model.state_dict(),
                    'blend': self.blend_model.state_dict(),
                    'step': self.step_model.state_dict(),
                    'reconstruct': self.reconstruction_model.state_dict()}
        if hasattr(self, 'agent_step_model'):
            toreturn['agent step'] = self.agent_step_model.state_dict()
        if hasattr(self, 'q'):
            toreturn['q'] = self.q.state_dict()
        return toreturn

    def save(self, path):
        """Pickle the whole trainable modules (not state dicts) to ``path``."""
        tosave = {
            'write': self.write_model,
            'blend': self.blend_model,
            'step': self.step_model,
            'reconstruct': self.reconstruction_model}
        if hasattr(self, 'agent_step_model'):
            tosave['agent step'] = self.agent_step_model
        if hasattr(self, 'q'):
            tosave['q'] = self.q
        torch.save(tosave, path)

    def load(self, path):
        """
        Load the trainable modules from ``path``.

        Each entry may be a state dict (as produced by ``tosave``) or a whole
        module (as written by ``save``). The q network is not restored.
        """
        models = torch.load(path, map_location='cpu')

        def as_state_dict(entry):
            # save() pickles modules, tosave() yields state dicts; accept both
            if isinstance(entry, torch.nn.Module):
                return entry.state_dict()
            return entry

        self.write_model.load_state_dict(as_state_dict(models['write']))
        self.blend_model.load_state_dict(as_state_dict(models['blend']))
        self.step_model.load_state_dict(as_state_dict(models['step']))
        self.reconstruction_model.load_state_dict(as_state_dict(models['reconstruct']))
        if hasattr(self, 'agent_step_model'):
            self.agent_step_model.load_state_dict(as_state_dict(models['agent step']))
