import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.ppo.utils import init

from ..ppo.model import NNBase, Flatten


class ai2thorNet_RSI3(NNBase):
    """Actor-critic network fusing image, motor, occupancy and goal-sound features.

    ``forward`` reads ``'image'``, ``'image_feat'``, ``'goal_sound_feat'`` and
    ``'occupancy'`` from its ``inputs`` dict. The image, motor and occupancy
    embeddings are summed and, when ``recurrent`` is set, passed through the
    GRU provided by ``NNBase`` (``_forward_gru``). In parallel, the goal-sound
    embedding is added to the image embedding and fed through ``fusionMlp``.
    The two branches are summed and fed to shared actor/critic heads.

    Args:
        num_inputs: unused; kept for interface compatibility.
        config: unused; kept for interface compatibility.
        recurrent, recurrentInputSize, recurrentSize, actionHiddenSize:
            forwarded to ``NNBase``. ``imgMotorMlp`` emits
            ``recurrentInputSize`` features and ``imgMotorMlp2`` consumes
            ``recurrentSize``, so with ``recurrent=False`` the two must match.
    """

    def __init__(self, num_inputs, config=None, recurrent=False, recurrentInputSize=128, recurrentSize=128,
                 actionHiddenSize=128):
        super(ai2thorNet_RSI3, self).__init__(recurrent, recurrentInputSize, recurrentSize, actionHiddenSize)

        # Image encoder; shape comments assume a (3, 96, 96) input -> 128 * 3 * 3 flattened features.
        self.imgCNN = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=1, padding=1), nn.ReLU(),  # (3, 96, 96)->(32, 96, 96)
            nn.Conv2d(32, 32, 3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # (32, 96, 96)->(32, 48, 48)
            nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(),  # (32, 48, 48)->(64, 48, 48)
            nn.MaxPool2d(2, stride=2),  # (64, 48, 48)->(64, 24, 24)
            nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(),  # (64, 24, 24)->(64, 24, 24)
            nn.MaxPool2d(2, stride=2),  # (64, 24, 24)->(64, 12, 12)
            nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.ReLU(),  # (64, 12, 12)->(128, 12, 12)
            nn.MaxPool2d(2, stride=2),  # (128, 12, 12)->(128, 6, 6)
            nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.ReLU(),  # (128, 6, 6)->(128, 3, 3)
            Flatten()
        )

        # Passes orthogonal_ weight init, zero bias init and gain sqrt(2) to ``init``.
        # Only layers wrapped in init_() use it; the conv stacks and the
        # occupancy / exi layers keep PyTorch's default initialisation.
        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
                               constant_(x, 0), np.sqrt(2))

        self.occupancyCNNMLP = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=2, padding=1), nn.ReLU(),  # (1, 9, 9)->(64, 5, 5)
            nn.Conv2d(64, 32, 3, stride=2, padding=1), nn.ReLU(),  # (64, 5, 5)->(32, 3, 3)
            Flatten(),
            nn.Linear(32 * 9, 128), nn.ReLU(),
            nn.Linear(128, 256), nn.ReLU())

        # Embeds the 3-feature 'image_feat' input.
        self.motorMlp = nn.Sequential(
            init_(nn.Linear(3, 64)), nn.ReLU(),
            init_(nn.Linear(64, 256)), nn.ReLU(),
        )
        self.cnnMlp = nn.Sequential(
            init_(nn.Linear(128 * 3 * 3, 512)), nn.ReLU(),
            init_(nn.Linear(512, 256)), nn.ReLU())

        # Projects image+motor+occupancy to the recurrent input size ...
        self.imgMotorMlp = nn.Sequential(
            init_(nn.Linear(256, 64)), nn.ReLU(),
            init_(nn.Linear(64, recurrentInputSize)), nn.ReLU(),
        )
        # ... and lifts the (recurrent) output back to 256 features.
        self.imgMotorMlp2 = nn.Sequential(
            init_(nn.Linear(recurrentSize, 256)), nn.ReLU(),
        )
        # Embeds the 3-feature 'goal_sound_feat' input.
        self.soundMlp = nn.Sequential(
            init_(nn.Linear(3, 128)), nn.ReLU(),
            init_(nn.Linear(128, 256)), nn.ReLU(),
            init_(nn.Linear(256, 256)), nn.ReLU(),
        )

        # Kept so existing checkpoints still load. forward() does not evaluate it,
        # because its output was never used or returned.
        self.exi = nn.Sequential(
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 1),
        )

        self.fusionMlp = nn.Sequential(
            init_(nn.Linear(256, 512)), nn.ReLU(),
            init_(nn.Linear(512, 256)), nn.ReLU(),
        )

        self.mlp_all = nn.Sequential(
            init_(nn.Linear(256, 256)), nn.ReLU(),
            init_(nn.Linear(256, 128)), nn.ReLU(),
        )

        self.actor = nn.Sequential(
            init_(nn.Linear(128, 128)), nn.ReLU(),
            init_(nn.Linear(128, actionHiddenSize)), nn.ReLU())

        self.critic = nn.Sequential(
            init_(nn.Linear(128, 128)), nn.ReLU(),
            init_(nn.Linear(128, 128)), nn.ReLU())

        self.critic_linear = init_(nn.Linear(128, 1))

        self.train()

    def forward(self, inputs, rnn_hxs, masks, **kwargs):
        """Compute the value estimate and actor features.

        Args:
            inputs: dict with ``'image'`` (fed to ``imgCNN``), ``'image_feat'``
                (3 features per sample), ``'goal_sound_feat'`` (3 features
                per sample) and ``'occupancy'`` (fed to ``occupancyCNNMLP``).
            rnn_hxs: recurrent hidden state, passed to ``_forward_gru``.
            masks: episode masks, passed to ``_forward_gru``.
            **kwargs: ignored; accepted for call compatibility.

        Returns:
            ``(value, hidden_actor, rnn_hxs, additional)``. ``value`` has one
            column, ``hidden_actor`` has ``actionHiddenSize`` features, and
            ``additional`` is an empty dict.
        """
        occupancy = self.occupancyCNNMLP(inputs['occupancy'])
        image_flatten = self.cnnMlp(self.imgCNN(inputs['image']))
        motor = self.motorMlp(inputs['image_feat'])
        imageMotor = self.imgMotorMlp(image_flatten + motor + occupancy)

        if self.is_recurrent:
            imageMotor, rnn_hxs = self._forward_gru(imageMotor, rnn_hxs, masks)

        imageMotorRnn = self.imgMotorMlp2(imageMotor)

        sound = self.soundMlp(inputs['goal_sound_feat'])
        fusion = self.fusionMlp(sound + image_flatten)
        features = self.mlp_all(fusion + imageMotorRnn)

        hidden_critic = self.critic(features)
        hidden_actor = self.actor(features)

        additional = {}
        return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs, additional


class ai2thorNet_RSI2(NNBase):
    """Actor-critic network fusing image, motor, occupancy and goal-sound features.

    Layer-for-layer identical to ``ai2thorNet_RSI3`` in this module.
    ``forward`` reads ``'image'``, ``'image_feat'``, ``'goal_sound_feat'`` and
    ``'occupancy'`` from its ``inputs`` dict. The image, motor and occupancy
    embeddings are summed and, when ``recurrent`` is set, passed through the
    GRU provided by ``NNBase`` (``_forward_gru``). In parallel, the goal-sound
    embedding is added to the image embedding and fed through ``fusionMlp``.
    The two branches are summed and fed to shared actor/critic heads.

    Args:
        num_inputs: unused; kept for interface compatibility.
        config: unused; kept for interface compatibility.
        recurrent, recurrentInputSize, recurrentSize, actionHiddenSize:
            forwarded to ``NNBase``. ``imgMotorMlp`` emits
            ``recurrentInputSize`` features and ``imgMotorMlp2`` consumes
            ``recurrentSize``, so with ``recurrent=False`` the two must match.
    """

    def __init__(self, num_inputs, config=None, recurrent=False, recurrentInputSize=128, recurrentSize=128,
                 actionHiddenSize=128):
        super(ai2thorNet_RSI2, self).__init__(recurrent, recurrentInputSize, recurrentSize, actionHiddenSize)

        # Image encoder; shape comments assume a (3, 96, 96) input -> 128 * 3 * 3 flattened features.
        self.imgCNN = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=1, padding=1), nn.ReLU(),  # (3, 96, 96)->(32, 96, 96)
            nn.Conv2d(32, 32, 3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # (32, 96, 96)->(32, 48, 48)
            nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(),  # (32, 48, 48)->(64, 48, 48)
            nn.MaxPool2d(2, stride=2),  # (64, 48, 48)->(64, 24, 24)
            nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(),  # (64, 24, 24)->(64, 24, 24)
            nn.MaxPool2d(2, stride=2),  # (64, 24, 24)->(64, 12, 12)
            nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.ReLU(),  # (64, 12, 12)->(128, 12, 12)
            nn.MaxPool2d(2, stride=2),  # (128, 12, 12)->(128, 6, 6)
            nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.ReLU(),  # (128, 6, 6)->(128, 3, 3)
            Flatten()
        )

        # Passes orthogonal_ weight init, zero bias init and gain sqrt(2) to ``init``.
        # Only layers wrapped in init_() use it; the conv stacks and the
        # occupancy / exi layers keep PyTorch's default initialisation.
        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
                               constant_(x, 0), np.sqrt(2))

        self.occupancyCNNMLP = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=2, padding=1), nn.ReLU(),  # (1, 9, 9)->(64, 5, 5)
            nn.Conv2d(64, 32, 3, stride=2, padding=1), nn.ReLU(),  # (64, 5, 5)->(32, 3, 3)
            Flatten(),
            nn.Linear(32 * 9, 128), nn.ReLU(),
            nn.Linear(128, 256), nn.ReLU())

        # Embeds the 3-feature 'image_feat' input.
        self.motorMlp = nn.Sequential(
            init_(nn.Linear(3, 64)), nn.ReLU(),
            init_(nn.Linear(64, 256)), nn.ReLU(),
        )
        self.cnnMlp = nn.Sequential(
            init_(nn.Linear(128 * 3 * 3, 512)), nn.ReLU(),
            init_(nn.Linear(512, 256)), nn.ReLU())

        # Projects image+motor+occupancy to the recurrent input size ...
        self.imgMotorMlp = nn.Sequential(
            init_(nn.Linear(256, 64)), nn.ReLU(),
            init_(nn.Linear(64, recurrentInputSize)), nn.ReLU(),
        )
        # ... and lifts the (recurrent) output back to 256 features.
        self.imgMotorMlp2 = nn.Sequential(
            init_(nn.Linear(recurrentSize, 256)), nn.ReLU(),
        )
        # Embeds the 3-feature 'goal_sound_feat' input.
        self.soundMlp = nn.Sequential(
            init_(nn.Linear(3, 128)), nn.ReLU(),
            init_(nn.Linear(128, 256)), nn.ReLU(),
            init_(nn.Linear(256, 256)), nn.ReLU(),
        )

        # Kept so existing checkpoints still load. forward() does not evaluate it,
        # because its output was never used or returned.
        self.exi = nn.Sequential(
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 1),
        )

        self.fusionMlp = nn.Sequential(
            init_(nn.Linear(256, 512)), nn.ReLU(),
            init_(nn.Linear(512, 256)), nn.ReLU(),
        )

        self.mlp_all = nn.Sequential(
            init_(nn.Linear(256, 256)), nn.ReLU(),
            init_(nn.Linear(256, 128)), nn.ReLU(),
        )

        self.actor = nn.Sequential(
            init_(nn.Linear(128, 128)), nn.ReLU(),
            init_(nn.Linear(128, actionHiddenSize)), nn.ReLU())

        self.critic = nn.Sequential(
            init_(nn.Linear(128, 128)), nn.ReLU(),
            init_(nn.Linear(128, 128)), nn.ReLU())

        self.critic_linear = init_(nn.Linear(128, 1))

        self.train()

    def forward(self, inputs, rnn_hxs, masks, **kwargs):
        """Compute the value estimate and actor features.

        Args:
            inputs: dict with ``'image'`` (fed to ``imgCNN``), ``'image_feat'``
                (3 features per sample), ``'goal_sound_feat'`` (3 features
                per sample) and ``'occupancy'`` (fed to ``occupancyCNNMLP``).
            rnn_hxs: recurrent hidden state, passed to ``_forward_gru``.
            masks: episode masks, passed to ``_forward_gru``.
            **kwargs: ignored; accepted for call compatibility.

        Returns:
            ``(value, hidden_actor, rnn_hxs, additional)``. ``value`` has one
            column, ``hidden_actor`` has ``actionHiddenSize`` features, and
            ``additional`` is an empty dict.
        """
        occupancy = self.occupancyCNNMLP(inputs['occupancy'])
        image_flatten = self.cnnMlp(self.imgCNN(inputs['image']))
        motor = self.motorMlp(inputs['image_feat'])
        imageMotor = self.imgMotorMlp(image_flatten + motor + occupancy)

        if self.is_recurrent:
            imageMotor, rnn_hxs = self._forward_gru(imageMotor, rnn_hxs, masks)

        imageMotorRnn = self.imgMotorMlp2(imageMotor)

        sound = self.soundMlp(inputs['goal_sound_feat'])
        fusion = self.fusionMlp(sound + image_flatten)
        features = self.mlp_all(fusion + imageMotorRnn)

        hidden_critic = self.critic(features)
        hidden_actor = self.actor(features)

        additional = {}
        return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs, additional


class ai2thorNet_RSI1(NNBase):
    """Actor-critic network fusing image, occupancy and goal-sound inputs.

    The image is encoded by ``imgBranch`` into a 128-channel map. With the
    (3, 96, 96) input assumed by its layer comments, that map is 3 x 3, which
    is what ``fusionMlp`` / ``exi`` (128 * 9 inputs) require. The network has
    two branches:

    * The flattened map plus the occupancy embedding goes through
      ``imgMotorMlp``, the optional ``NNBase`` GRU and ``imgMotorMlp2``.
    * A per-channel sound context of shape (B, 128, 1) is broadcast-added to
      the (B, 128, 9) map, flattened, and fed through ``fusionMlp``.

    The two branches are summed and fed to shared actor/critic heads.

    The sound context comes from one of two sources:

    * ``x['soundLabel']`` through ``soundMlp``, when ``config.RLUseSoundLabel``
      is set.
    * Otherwise, raw sound features through a CNN + bidirectional GRU
      encoder. Its result is cached on the module for the inference path.

    Args:
        num_inputs: unused; kept for interface compatibility.
        config: object providing ``RLUseSoundLabel`` and ``taskNum``.
        recurrent, recurrentInputSize, recurrentSize, actionHiddenSize:
            forwarded to ``NNBase``. ``imgMotorMlp`` emits
            ``recurrentInputSize`` features and ``imgMotorMlp2`` consumes
            ``recurrentSize``, so with ``recurrent=False`` the two must match.
    """

    def __init__(self, num_inputs, config=None, recurrent=False, recurrentInputSize=128, recurrentSize=128, actionHiddenSize=128):
        super(ai2thorNet_RSI1, self).__init__(recurrent, recurrentInputSize, recurrentSize, actionHiddenSize)

        self.config = config

        self.encoderNumLayer = 1
        if config.RLUseSoundLabel:
            # forward() feeds x['soundLabel'] (taskNum features) to soundMlp.
            self.soundFrames = 1
            self.soundMFCCFeat = self.config.taskNum
        else:
            # forward() reshapes x['goal_sound'] to (B, 600, 40) for the CNN + GRU encoder.
            self.soundFrames = 600
            self.soundMFCCFeat = 40

        # Last sound encoding produced on the inference path (raw-sound mode);
        # reused on steps whose goal sound is entirely inf.
        self.cached_sound_aux = None
        self.cached_context = None

        # Image encoder; shape comments assume a (3, 96, 96) input.
        self.imgBranch = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=1, padding=1), nn.ReLU(),  # (3, 96, 96)->(32, 96, 96)
            nn.Conv2d(32, 32, 3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # (32, 96, 96)->(32, 48, 48)
            nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(),  # (32, 48, 48)->(64, 48, 48)
            nn.MaxPool2d(2, stride=2),  # (64, 48, 48)->(64, 24, 24)
            nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(),  # (64, 24, 24)->(64, 24, 24)
            nn.MaxPool2d(2, stride=2),  # (64, 24, 24)->(64, 12, 12)
            nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.ReLU(),  # (64, 12, 12)->(128, 12, 12)
            nn.MaxPool2d(2, stride=2),  # (128, 12, 12)->(128, 6, 6)
            nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.ReLU(),  # (128, 6, 6)->(128, 3, 3)
        )

        # Passes orthogonal_ weight init, zero bias init and gain sqrt(2) to ``init``.
        # Only layers wrapped in init_() use it.
        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
                               constant_(x, 0), np.sqrt(2))

        self.occupancyCNNMLP = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=2, padding=1), nn.ReLU(),  # (1, 9, 9)->(64, 5, 5)
            nn.Conv2d(64, 32, 3, stride=2, padding=1), nn.ReLU(),  # (64, 5, 5)->(32, 3, 3)
            Flatten(),
            nn.Linear(32 * 9, 128), nn.ReLU(),
            nn.Linear(128, 256), nn.ReLU())

        self.cnnMlp = nn.Sequential(
            init_(nn.Linear(128 * 3 * 3, 512)), nn.ReLU(),
            init_(nn.Linear(512, 256)), nn.ReLU())

        self.imgMotorMlp = nn.Sequential(
            init_(nn.Linear(256, 64)), nn.ReLU(),
            init_(nn.Linear(64, recurrentInputSize)), nn.ReLU(),
        )
        self.imgMotorMlp2 = nn.Sequential(
            init_(nn.Linear(recurrentSize, 256)), nn.ReLU(),
        )

        # Auxiliary head on the flattened image map; returned as 'inSight_aux'.
        self.inSight = nn.Sequential(
            nn.Linear(128 * 3 * 3, 128), nn.ReLU(),
            nn.Linear(128, self.config.taskNum),
        )

        if config.RLUseSoundLabel:
            self.soundMlp = nn.Sequential(nn.Linear(self.config.taskNum, 128), nn.ReLU(),
                                          nn.Linear(128, 128), nn.ReLU(), )
        else:  # raw sound features
            # RSI2 AI2Thor sound encoder: CNN over the (600, 40) features, then a bidirectional GRU.
            # (Creation order rnn -> cnn is kept so seeded initialisation is unchanged.)
            self.rnn = torch.nn.GRU(input_size=64 * 7, hidden_size=512, batch_first=True, bidirectional=True)
            self.cnn = nn.Sequential(
                nn.Conv2d(1, 64, (11, 11), stride=(2, 2), padding=(5, 5)), nn.ReLU(),  # (1, 600, 40)->(64, 300, 20)
                nn.Conv2d(64, 64, (11, 5), stride=(2, 2), padding=(5, 5)), nn.ReLU(),  # (64, 300, 20)->(64, 150, 13)
                nn.Conv2d(64, 64, (7, 3), stride=(2, 2), padding=(1, 1)), nn.ReLU(),  # (64, 150, 13)->(64, 73, 7)
            )
            # Concatenated final forward/backward GRU states -> 128-d sound context.
            self.fc = nn.Sequential(nn.Linear(2 * 512, 128), nn.ReLU(),
                                    nn.Linear(128, 128), nn.ReLU(), )

            # Auxiliary head on the sound context; returned as 'sound_aux'.
            self.soundAux = nn.Sequential(
                init_(nn.Linear(128, 64)), nn.ReLU(),
                init_(nn.Linear(64, self.config.taskNum)),
            )

        # Consumes the flattened (128, 9) sound+image map.
        self.fusionMlp = nn.Sequential(
            init_(nn.Linear(128 * 9, 512)), nn.ReLU(),
            init_(nn.Linear(512, 256)), nn.ReLU(),
        )

        self.mlp_all = nn.Sequential(
            init_(nn.Linear(256, 256)), nn.ReLU(),
            init_(nn.Linear(256, 128)), nn.ReLU(),
        )

        # Auxiliary single-output head on the flattened sound+image map; returned as 'exi_aux'.
        self.exi = nn.Sequential(
            nn.Linear(128 * 9, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 1),
        )

        self.actor = nn.Sequential(
            init_(nn.Linear(128, 128)), nn.ReLU(),
            init_(nn.Linear(128, actionHiddenSize)), nn.ReLU())

        self.critic = nn.Sequential(
            init_(nn.Linear(128, 128)), nn.ReLU(),
            init_(nn.Linear(128, 128)), nn.ReLU())

        self.critic_linear = init_(nn.Linear(128, 1))

        self.train()

    def forward(self, inputs, rnn_hxs, masks, infer=True):
        """Compute the value estimate, actor features and auxiliary outputs.

        Args:
            inputs: dict with ``'image'``, ``'occupancy'``, and the sound input:
                ``'soundLabel'`` when ``config.RLUseSoundLabel`` is set,
                otherwise ``'goal_sound'``, which is reshaped to
                (B, soundFrames, soundMFCCFeat).
            rnn_hxs: recurrent hidden state, passed to ``_forward_gru``.
            masks: episode masks, passed to ``_forward_gru``.
            infer: only affects raw-sound mode.
                True (acting): the sound encoder runs only when the goal sound
                is not entirely ``inf``, and its result is cached on the
                module. Otherwise the cached encoding is reused.
                False (training batches laid out as
                ``[env1 step1, env2 step1, env1 step2, env2 step2, ...]``):
                only rows whose sound is not ``inf`` are encoded, and the
                result is tiled over the batch. This path leaves the inference
                cache untouched.
                Both paths assume every environment resets at the same time.

        Returns:
            ``(value, hidden_actor, rnn_hxs, additional)``. ``additional`` holds
            ``'inSight_aux'`` (taskNum outputs) and ``'exi_aux'`` (one output
            per sample). In raw-sound mode it also holds ``'sound_aux'``.

        Raises:
            RuntimeError: inference path, if no sound encoding is cached yet.
            ValueError: training path, if the batch has no non-``inf`` sound
                rows or its size is not a multiple of their count.
        """
        x = inputs

        occupancy = self.occupancyCNNMLP(x['occupancy'])
        image = x['image']
        batchSize = image.size(0)

        img3 = self.imgBranch(image)
        img3_flatten = torch.flatten(img3, start_dim=1)
        inSight = self.inSight(img3_flatten)

        img_flatten = self.cnnMlp(img3_flatten)
        imageMotor = self.imgMotorMlp(img_flatten + occupancy)

        if self.is_recurrent:
            imageMotor, rnn_hxs = self._forward_gru(imageMotor, rnn_hxs, masks)

        imageMotorRnn = self.imgMotorMlp2(imageMotor)

        sound_aux = None
        if self.config.RLUseSoundLabel:
            self.cached_context = torch.unsqueeze(self.soundMlp(x['soundLabel']), -1)
            context = self.cached_context
        else:
            sound = torch.reshape(x['goal_sound'], (batchSize, self.soundFrames, self.soundMFCCFeat))
            if infer:
                sound_aux, context = self._inference_sound_context(sound, batchSize)
            else:
                sound_aux, context = self._training_sound_context(sound, batchSize)

        # (B, 128, 1) sound context broadcast-added to every spatial cell of the image map.
        image_reshape = torch.reshape(img3, (batchSize, 128, -1))
        fusion = torch.flatten(context + image_reshape, start_dim=1)
        exi = self.exi(fusion)[:, 0]
        fusion = self.fusionMlp(fusion)

        features = self.mlp_all(fusion + imageMotorRnn)

        hidden_critic = self.critic(features)
        hidden_actor = self.actor(features)

        additional = {
            'inSight_aux': inSight,
            'exi_aux': exi
        }
        if not self.config.RLUseSoundLabel:
            additional['sound_aux'] = sound_aux

        return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs, additional

    def _inference_sound_context(self, sound, batchSize):
        """Refresh the sound cache when new sound is present; return ``(sound_aux, context)``."""
        # An all-inf sound tensor skips the encoder, and the cached encoding is reused.
        # Mixed batches (some rows inf) are encoded as a whole, as before.
        if not torch.isinf(sound).all():
            self.sound_encoder2(sound, batchSize)
        if self.cached_context is None:
            raise RuntimeError(
                "ai2thorNet_RSI1: no cached sound encoding; the first inference step "
                "must provide a 'goal_sound' that is not entirely inf")
        return self.cached_sound_aux, self.cached_context

    def _training_sound_context(self, sound, batchSize):
        """Encode a training batch's sound rows and tile them over the batch.

        Rows whose summed sound is ``inf`` are skipped. The remaining rows are
        encoded once and repeated ``batchSize // count`` times, matching the
        ``[env1 step1, env2 step1, env1 step2, ...]`` layout. The inference
        cache is not modified, so acting after an update still sees
        per-environment encodings of the right size.
        """
        row_sum = torch.sum(sound.reshape(batchSize, -1), dim=1)
        valid_idx = torch.logical_not(row_sum.isinf()).nonzero().view(-1)
        num_valid = valid_idx.numel()
        if num_valid == 0 or batchSize % num_valid != 0:
            raise ValueError(
                "ai2thorNet_RSI1: training batch of size %d has %d non-inf sound rows; "
                "expected a positive divisor of the batch size" % (batchSize, num_valid))
        sound_aux, context = self._encode_sound2(sound[valid_idx])
        num_repeats = batchSize // num_valid
        return sound_aux.repeat(num_repeats, 1), context.repeat(num_repeats, 1, 1)

    def sound_encoder(self, sound, batchSize):
        """Legacy attention-based sound encoder used by RSI1; not called by ``forward``.

        It needs ``soundMlp``, ``soundBiLSTM``, ``attnMlp``, ``vA``,
        ``contextMlp``, ``contextMlp2`` and ``soundAux``. This class never
        creates most of them, so calling it on a plain instance raises
        ``AttributeError`` unless they are provided elsewhere.
        Writes ``cached_sound_aux`` and ``cached_context``. ``batchSize`` is unused.
        """
        sound = self.soundMlp(sound)

        # Forward propagate LSTM
        # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        out, (h_n, c_n) = self.soundBiLSTM(sound)

        # attention
        h_n_concat = torch.cat([h_n[0], h_n[1]], dim=1)
        h_n_concat_expand = h_n_concat.unsqueeze(1).repeat([1, self.soundFrames, 1])
        h_n_out_concat = torch.cat([out, h_n_concat_expand], dim=2)
        h_n_out_concat = self.attnMlp(h_n_out_concat)
        score = F.softmax(self.vA(h_n_out_concat), dim=1)
        # squeeze only the trailing singleton so a batch of 1 keeps its batch dimension.
        context = torch.bmm(torch.transpose(out, dim0=1, dim1=2), score).squeeze(-1)
        context = self.contextMlp(context)

        self.cached_sound_aux = self.soundAux(context)
        self.cached_context = torch.unsqueeze(self.contextMlp2(context), -1)

    def _encode_sound2(self, sound):
        """Run the CNN + bidirectional-GRU sound encoder; return ``(sound_aux, context)``.

        ``sound`` is (B, frames, 40). ``context`` is (B, 128, 1), ready to be
        broadcast over the image map.
        """
        cnn_out = self.cnn(torch.unsqueeze(sound, 1))  # (B, 64, 73, 7) for 600 x 40 input
        # (B, 64, T, 7) -> (B, T, 64 * 7): one GRU step per time row.
        cnn_out = torch.flatten(torch.transpose(cnn_out, dim0=1, dim1=2), start_dim=2)
        _, h_n = self.rnn(cnn_out)  # h_n: (2, B, 512) - final forward and backward states
        sound_feat = self.fc(torch.cat((h_n[0, :, :], h_n[1, :, :]), dim=1))
        return self.soundAux(sound_feat), torch.unsqueeze(sound_feat, -1)

    def sound_encoder2(self, sound, batchSize):  # RSI2 AI2Thor
        """Encode ``sound`` and store the result in ``cached_sound_aux`` / ``cached_context``.

        ``batchSize`` is unused; kept for interface compatibility.
        """
        self.cached_sound_aux, self.cached_context = self._encode_sound2(sound)
