from collections import OrderedDict

import numpy as np
import torch
from torch.nn import Module
from torch.nn import init
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import MSELoss as MSELoss

from train.reinforcment_learning.utils.lstm import LSTMLayer
from train.behavioral_cloning.models.LightCnnLstmOrigActs48 import CnnNet

class _Flatten(Module):
    """Collapse every dimension after the first (batch) axis into one."""

    def forward(self, x):
        return x.flatten(start_dim=1)


class RR_LSTM_ARCH(Module):
    """A model for predicting the return from the past state and action.

    For every time step a CNN encodes the state, the action vector (tiled
    ``duplication`` times) is concatenated to that encoding, the sequence is
    run through an LSTM, and a linear layer maps each hidden state to one
    scalar output.
    """

    def __init__(self, input_lstm, lstm_size=128, num_actions=10,
                 duplication=10, bias_mean=0, device=None):
        """
        Args:
            input_lstm: size of the per-step state feature vector produced by
                the CNN (presumably the flattened CnnNet output size -- TODO
                confirm against CnnNet).
            lstm_size: number of LSTM units.
            num_actions: length of the per-step action vector.
            duplication: how many times the action vector is tiled before it
                is concatenated to the state features.
            bias_mean: currently unused (only referenced by the commented-out
                bias initialization below).
            device: stored as ``self.device``; not used inside this class.
        """
        super(RR_LSTM_ARCH, self).__init__()
        self.duplication = duplication
        self.device = device
        self.cnn = CnnNet(in_frames=3)

        # lstm
        def get_init(mean: float = 0, std: float = 1):
            """Return an initializer that fills a tensor in place from N(mean, std^2)."""
            def _init(*args, **kwargs):
                return init.normal_(*args, mean=mean, std=std, **kwargs)
            return _init

        # JAM -> TODO ig and ci bias initialization should be high variance, for resources managing
        # Per-step LSTM input: state features followed by the tiled action vector.
        feature_size = input_lstm + num_actions * duplication
        self.lstm_size = lstm_size
        self.lstm = LSTMLayer(feature_size, lstm_size,
                              w_ci=(get_init(mean=0, std=0.1), False),
                              w_ig=(False, get_init(mean=0, std=0.1)),
                              w_og=False,
                              b_ci=get_init(mean=0),
                              b_ig=get_init(mean=-3),
                              b_og=False,
                              a_ci=lambda x: torch.tanh(x) * 8,  # * 4.,
                              a_out=lambda x: torch.tanh(x))

        # prediction layer
        # JAM -> TODO we need to initialize the bias to the mean of the lessons buffer
        self.final = torch.nn.Linear(lstm_size, 1)
        # self.final.bias.data.normal_(bias_mean, 1)

    @property
    def resnet_size(self):
        """Output feature size of ``self.resnet.fc``.

        NOTE(review): ``self.resnet`` is never assigned in this class (the
        state encoder is ``self.cnn``), so this property raises AttributeError
        unless ``resnet`` is attached externally -- confirm whether it is dead code.
        """
        return self.resnet.fc.out_features

    def reset_parameters(self, output_bias=None):
        """Re-initialize the LSTM and the output layer.

        Args:
            output_bias: value for the output-layer bias. Either a tensor
                broadcastable to the bias shape or a plain Python number.
                ``None`` sets the bias to zero.
        """
        # lstm
        self.lstm.__reset_parameters__()
        # output layer
        init.kaiming_uniform_(self.final.weight, nonlinearity='linear')
        if output_bias is None:
            init.zeros_(self.final.bias)
        else:
            # copy_ needs a tensor source; converting also lets plain numbers work.
            with torch.no_grad():
                self.final.bias.copy_(
                    torch.as_tensor(output_bias, dtype=self.final.bias.dtype))

    def load_weights(self, filename, map_location=None):
        """Load weights from a checkpoint dict that has a ``'state_dict'`` entry.

        A strict load is tried first. If that raises RuntimeError, the
        checkpoint entries are merged over the current state dict and loaded
        non-strictly. Any remaining RuntimeError is printed as a warning and
        not raised.

        Args:
            filename: path passed to ``torch.load``.
            map_location: forwarded to ``torch.load`` (e.g. ``'cpu'``).
                ``None`` is ``torch.load``'s own default.
        """
        checkpoint = torch.load(filename, map_location=map_location)
        try:
            self.load_state_dict(checkpoint['state_dict'])
        except RuntimeError:
            # if state dict is missing values try initializing partial state
            state = self.state_dict()
            state.update(checkpoint['state_dict'])
            try:
                self.load_state_dict(state, strict=False)
            except RuntimeError as e:
                print("WARNING: {}".format(str(e)))

    def forward(self, states, actions):
        """Predict one output per time step for a batch of sequences.

        Args:
            states: tensor of shape [B, T, *state_shape], where state_shape has
                three dimensions (see ``infer_leading_dims(states, 3)``).
            actions: per-step action tensor, presumably [B, T, num_actions]
                -- TODO confirm.

        Returns:
            ``self.final`` applied to every LSTM hidden state; the last axis
            has size 1 (presumably [B, T, 1]).
        """
        # Scale by 1/T: when the sequence length is not fixed this keeps the
        # LSTM cell state from saturating. Done out of place so the caller's
        # tensor is left unmodified (and integer inputs are promoted to float).
        states = states / states.shape[1]
        _, T, B, shape = infer_leading_dims(states, 3)
        s = self.cnn(states.view(B * T, *shape))
        s_reshaped = s.view(B, T, -1)

        # tile the action vector along its last axis and append it to the state features
        actions = actions.repeat(1, 1, self.duplication)
        features = torch.cat((s_reshaped, actions), dim=2)

        # return_all_seq_pos=True presumably yields the hidden state of every
        # time step (needed for per-step outputs) -- confirm in LSTMLayer.
        hs, _ = self.lstm.forward(features, return_all_seq_pos=True)
        return self.final(hs)


class RR_LSTM(Module):
    """Trainer that fits a return-prediction model and redistributes reward.

    ``model`` is called as ``model(states, actions)`` and optimized with Adam
    on batches sampled from ``buffer``. Its per-step predictions are turned
    into per-step reward differences by ``redistribute_reward``.
    """

    def __init__(self, model, buffer, return_scaling=1., loss_weight=1.,
                 lr=1e-3, l2_decay=0., num_actions=10, episode_len=512):
        """
        Args:
            model: module mapping (states, actions) to per-step logits. Its
                parameters are the ones optimized. ``update`` reads
                ``model.device`` to place input tensors.
            buffer: sample source. ``update`` calls ``buffer.sample(...)`` and
                assigns into ``buffer.loss[indices]``.
            return_scaling: training targets are ``R / return_scaling - 1``
                and predictions are mapped back with ``(p + 1) * return_scaling``.
                A falsy value disables both transforms.
            loss_weight: weight of the auxiliary (all-steps) loss.
            lr: Adam learning rate.
            l2_decay: Adam weight decay.
            num_actions: stored as ``self.n_actions``.
            episode_len: upper cap applied to sequence lengths in ``losses``.
        """
        super().__init__()
        self.return_scaling = return_scaling
        self.loss_weight = loss_weight
        self.n_actions = num_actions

        self.buffer = buffer
        # "critic" is the return-prediction model being trained
        self.critic = model
        self.mse = torch.nn.MSELoss(reduction='none')
        self.optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=l2_decay)

        self.rr_count = 0
        self.lstm_updates = 0
        self.avg_prediction = 0.0
        self.episode_len = episode_len

    def redistribute_reward(self, states, actions):
        """Turn per-step return predictions into per-step rewards.

        Sigmoid predictions, mapped back through ``return_scaling``, are
        differenced along time. Step 0 keeps the first prediction and step t
        gets ``pred[t] - pred[t-1]``. The last step is dropped. An exponential
        running average (factor 0.9) of the step-0 reward is updated and then
        subtracted from step 0.

        NOTE(review): ``avg_prediction`` holds one value per batch row, so
        consecutive calls presumably need the same batch size -- confirm.

        Returns:
            Tensor with the time axis shortened by one (presumably [B, T - 1]).
        """
        logits = self.critic(states, actions)
        prediction = torch.sigmoid(logits)

        # transform back (inverse of the target scaling used in update)
        if self.return_scaling:
            prediction = (prediction + 1) * self.return_scaling

        reward_diffs = prediction[:, 1:, 0] - prediction[:, :-1, 0]
        redistributed = torch.cat([prediction[:, :1, 0], reward_diffs], dim=1)
        self.rr_count += 1

        new_reward = redistributed[:, :-1]
        # Detach the running average: otherwise it keeps the autograd graph of
        # every earlier call alive, and that graph grows with each call.
        self.avg_prediction = (self.avg_prediction * 0.9 + new_reward[:, 0] * 0.1).detach()
        new_reward[:, 0] -= self.avg_prediction
        return new_reward

    def losses(self, logits, returns, lengths):
        """Compute the main and auxiliary MSE losses for a padded batch.

        Args:
            logits: model outputs of shape [B, T] or [B, T, 1]; passed through
                a sigmoid before comparison.
            returns: target returns, broadcastable to [B, T].
            lengths: B per-sequence valid lengths, capped at ``self.episode_len``.

        Returns:
            (main_loss, aux_loss). ``main_loss`` is the mean squared error at
            the last valid step of each sequence. ``aux_loss`` is the mean
            squared error over all valid (non-padded) steps.
        """
        pred = torch.sigmoid(logits)
        # Drop only a trailing singleton output axis; a bare squeeze() would
        # also drop the batch axis when B == 1.
        if pred.dim() == 3:
            pred = pred.squeeze(-1)
        batch_size, seq_len = pred.shape
        continuous_loss = self.mse(pred, returns.expand_as(pred))

        # Lengths are capped at episode_len for both masking and final-step selection.
        capped = torch.as_tensor(lengths, dtype=torch.long,
                                 device=pred.device).clamp(max=self.episode_len)
        steps = torch.arange(seq_len, device=pred.device)
        valid_mask = steps.unsqueeze(0) < capped.unsqueeze(1)
        aux_loss = continuous_loss.masked_select(valid_mask).mean()

        # correcting for padding in the end: take the loss at the last valid
        # step of each row (a length of 0 indexes -1, i.e. the last column)
        rows = torch.arange(batch_size, device=pred.device)
        final_losses = continuous_loss[rows, capped - 1]
        main_loss = final_losses.mean()
        return main_loss, aux_loss

    def update(self, writer, stop_loss, rudder_pretraining=False, batch_size=128,
               max_steps=None):
        """Train on buffer samples until the running loss average is low.

        Each step samples a batch, computes ``main + loss_weight * aux``, and
        takes one optimizer step. Pretraining samples with
        ``randomize=True, balance=False``; otherwise sampling uses
        ``randomize=False, balance=True``.

        Args:
            writer: object with ``add_scalar(tag, value, step)``, called every
                5 steps with the global update counter as the step.
            stop_loss: training stops once the running average is
                <= ``stop_loss``. The average starts at 0.1, so a
                ``stop_loss`` >= 0.1 performs no steps.
            rudder_pretraining: selects the sampling mode described above.
            batch_size: number of sequences per step.
            max_steps: optional cap on the number of steps in this call.
                ``None`` (default) trains until the stop criterion is met.
        """
        i = 0
        loss_average = 0.1
        print("Training Till loss is very low!")
        while loss_average > stop_loss:
            if max_steps is not None and i >= max_steps:
                break
            self.lstm_updates += 1
            # collect inputs
            if rudder_pretraining:
                states, actions, rewards, lengths, indices = self.buffer.sample(batch_size, randomize=True,
                                                                                balance=False)
            else:
                states, actions, rewards, lengths, indices = self.buffer.sample(batch_size, randomize=False,
                                                                                balance=True)
            device = self.critic.device
            # torch.tensor copies, so later in-place ops never touch buffer data
            states = torch.tensor(states, dtype=torch.float, device=device)
            actions = torch.tensor(actions, dtype=torch.float, device=device)
            lengths = lengths[:, 0]
            returns = np.sum(rewards, 1, keepdims=True)
            if self.return_scaling:
                returns = (returns / self.return_scaling - 1)
            returns = torch.tensor(returns, dtype=torch.float, device=device)
            # forward pass
            self.rr_count += 1
            # squeeze(-1) removes only the output axis, keeping B even when B == 1
            logits = self.critic(states, actions).squeeze(-1)
            returns = returns.expand_as(logits)
            main_loss, aux_loss = self.losses(logits, returns, lengths)
            lstm_loss = main_loss + self.loss_weight * aux_loss
            # record the batch's auxiliary loss for every sampled index
            self.buffer.loss[indices] = aux_loss.item()

            # backward pass
            self.optimizer.zero_grad()
            lstm_loss.backward()
            self.optimizer.step()

            # logging & stop criterion
            loss_np = lstm_loss.item()
            main_loss_np = main_loss.item()
            aux_loss_np = aux_loss.item()

            # directly adapt to big changes in loss
            loss_average -= 0.1 * (loss_average - (main_loss_np + aux_loss_np) / 2)
            if main_loss_np > loss_average * 2:
                loss_average = loss_np

            i += 1
            if i % 5 == 0:
                msg = "step {:03d} --- "
                msg += "main loss: {:.5f}, aux loss: {:.5f}, "
                msg += "loss: {:.5f}, avg loss: {:.5f}"
                print(msg.format(i, main_loss_np, aux_loss_np, loss_np, loss_average))
                # global counter, so repeated update() calls don't overwrite earlier points
                step = self.lstm_updates
                writer.add_scalar('rudder/mainloss', main_loss_np, step)
                writer.add_scalar('rudder/auxloss', aux_loss_np, step)
                writer.add_scalar('rudder/loss', loss_np, step)
                writer.add_scalar('rudder/lossavg', loss_average, step)


def infer_leading_dims(tensor, dim):
    """Split a tensor's shape into leading (batch/time) dims and trailing dims.

    Param 'dim': number of non-leading (trailing) dimensions in tensor.
    Returns:
    lead_dim: int --number of leading dims found (0, 1 or 2).
    T: int --size of the second leading dim if there are two leading dims, o/w 1.
    B: int --size of the first leading dim if there are one or two leading dims, o/w 1.
    shape: the trailing `dim` sizes (an empty torch.Size when dim == 0).
    """
    lead_dim = tensor.dim() - dim
    assert lead_dim in (0, 1, 2), (
        "expected 0, 1 or 2 leading dims, got {} (tensor.dim()={}, dim={})".format(
            lead_dim, tensor.dim(), dim))
    if lead_dim == 2:
        # batch-major layout: [B, T, ...]
        B, T = tensor.shape[:2]
    else:
        T = 1
        B = 1 if lead_dim == 0 else tensor.shape[0]
    # slice from lead_dim rather than -dim: shape[-0:] would return the whole shape
    shape = tensor.shape[lead_dim:]
    return lead_dim, T, B, shape
