import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.ppo.utils import init

from ..ppo.model import NNBase

from models.vggnet import vgg11


class turtlebotNet_RSI3(NNBase):
    """Actor-critic trunk fusing an image, pose/motor features and a sound feature.

    Observations arrive as a mapping with the keys ``'robot_pose'``,
    ``'image_feat'``, ``'goal_sound_feat'`` and ``'image'``.  The image goes
    through ``vgg11``; its output is fed to a Linear layer expecting
    ``128 * 3 * 3`` features.  Besides the usual value / actor-hidden outputs,
    the network emits a scalar auxiliary prediction per sample under the key
    ``'exi_aux'``.

    Args:
        num_inputs: Accepted for interface compatibility; not used here.
        config: Accepted for interface compatibility; not used here.
        recurrent: Forwarded to ``NNBase``.
        recurrentInputSize: Width of the features handed to the recurrent
            step; also forwarded to ``NNBase``.
        recurrentSize: Width of the features coming out of the recurrent
            step; also forwarded to ``NNBase``.
        actionHiddenSize: Width of the actor's output features; also
            forwarded to ``NNBase``.
    """

    def __init__(self, num_inputs, config=None, recurrent=False, recurrentInputSize=128, recurrentSize=128, actionHiddenSize=128):
        super().__init__(recurrent, recurrentInputSize, recurrentSize, actionHiddenSize)

        def ortho(module):
            # Hand the layer to the project's `init` helper together with an
            # orthogonal weight initializer, a zero-constant bias initializer
            # and a gain of sqrt(2).
            return init(module, nn.init.orthogonal_,
                        lambda bias: nn.init.constant_(bias, 0), np.sqrt(2))

        def relu_stack(*widths):
            # Linear -> ReLU for every consecutive pair of widths.  Layers are
            # created strictly in order so parameter initialization consumes
            # random numbers in the same sequence as a hand-written stack.
            stages = []
            for n_in, n_out in zip(widths, widths[1:]):
                stages += [ortho(nn.Linear(n_in, n_out)), nn.ReLU()]
            return nn.Sequential(*stages)

        # Image backbone (randomly initialized, no pretrained weights).
        self.imgCNN = vgg11(pretrained=False, progress=False)

        # Per-modality encoders.
        self.motorMlp = relu_stack(5, 64, 256)
        self.cnnMlp = relu_stack(128 * 3 * 3, 512, 256)

        # Projections into and out of the (optional) recurrent step.
        self.imgMotorMlp = relu_stack(256, 64, recurrentInputSize)
        self.imgMotorMlp2 = relu_stack(recurrentSize, 256)

        self.soundMlp = relu_stack(3, 128, 256, 256)

        # Auxiliary head: default PyTorch initialization, no trailing ReLU.
        self.exi = nn.Sequential(
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 1),
        )

        self.fusionMlp = relu_stack(256, 512, 256)
        self.mlp_all = relu_stack(256, 256, 128)

        # Policy / value heads.
        self.actor = relu_stack(128, 128, actionHiddenSize)
        self.critic = relu_stack(128, 128, 128)
        self.critic_linear = ortho(nn.Linear(128, 1))

        self.train()

    def forward(self, inputs, rnn_hxs, masks, **kwargs):
        """Run the network on one batch of observations.

        Args:
            inputs: Mapping holding ``'robot_pose'``, ``'image_feat'``,
                ``'goal_sound_feat'`` and ``'image'``.  ``image_feat`` and
                ``robot_pose`` are concatenated along dim 1 and must total
                5 features; the sound feature must have 3.
            rnn_hxs: Recurrent hidden state, passed to ``_forward_gru`` when
                the model is recurrent and returned (possibly updated).
            masks: Passed through to ``_forward_gru``.
            **kwargs: Ignored.

        Returns:
            Tuple ``(value, hidden_actor, rnn_hxs, additional)`` where
            ``additional`` is ``{'exi_aux': ...}`` with one scalar per sample.
        """
        obs = inputs
        pose = obs['robot_pose']
        img_feat = obs['image_feat']
        motor_in = torch.cat([img_feat, pose], dim=1)
        sound_in = obs['goal_sound_feat']
        cnn_out = self.imgCNN(obs['image'])

        visual = self.cnnMlp(cnn_out)
        motor = self.motorMlp(motor_in)
        recurrent_feat = self.imgMotorMlp(visual + motor)

        if self.is_recurrent:
            recurrent_feat, rnn_hxs = self._forward_gru(recurrent_feat, rnn_hxs, masks)

        memory = self.imgMotorMlp2(recurrent_feat)

        # Sound and vision are fused additively; the auxiliary head reads the
        # fused features before the fusion MLP is applied.
        audio_visual = self.soundMlp(sound_in) + visual
        exi = self.exi(audio_visual)[:, 0]

        trunk = self.mlp_all(self.fusionMlp(audio_visual) + memory)

        hidden_critic = self.critic(trunk)
        hidden_actor = self.actor(trunk)

        return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs, {'exi_aux': exi}


class turtlebotNet_RSI2(NNBase):
    """Actor-critic trunk operating on a flat observation vector.

    Each observation row is sliced as follows: columns ``0:5`` are the motor
    features, columns ``5:8`` the sound feature, and columns ``8:-1`` the
    image, which is reshaped to ``(3, 96, 96)``.  The final column is not
    used by this network.  The image goes through ``vgg11``; its output is
    fed to a Linear layer expecting ``128 * 3 * 3`` features.  A scalar
    auxiliary prediction per sample is returned under the key ``'exi'``.

    Args:
        num_inputs: Accepted for interface compatibility; not used here.
        config: Accepted for interface compatibility; not used here.
        recurrent: Forwarded to ``NNBase``.
        recurrentInputSize: Width of the features handed to the recurrent
            step; also forwarded to ``NNBase``.
        recurrentSize: Width of the features coming out of the recurrent
            step; also forwarded to ``NNBase``.
        actionHiddenSize: Width of the actor's output features; also
            forwarded to ``NNBase``.
    """

    def __init__(self, num_inputs, config=None, recurrent=False, recurrentInputSize=128, recurrentSize=128, actionHiddenSize=128):
        super().__init__(recurrent, recurrentInputSize, recurrentSize, actionHiddenSize)

        def ortho(module):
            # Hand the layer to the project's `init` helper together with an
            # orthogonal weight initializer, a zero-constant bias initializer
            # and a gain of sqrt(2).
            return init(module, nn.init.orthogonal_,
                        lambda bias: nn.init.constant_(bias, 0), np.sqrt(2))

        def relu_stack(*widths):
            # Linear -> ReLU for every consecutive pair of widths, created
            # strictly in order so initialization draws random numbers in the
            # same sequence as a hand-written stack.
            stages = []
            for n_in, n_out in zip(widths, widths[1:]):
                stages += [ortho(nn.Linear(n_in, n_out)), nn.ReLU()]
            return nn.Sequential(*stages)

        # Image backbone (randomly initialized, no pretrained weights).
        self.imgCNN = vgg11(pretrained=False, progress=False)

        # Per-modality encoders.
        self.motorMlp = relu_stack(5, 64, 256)
        self.cnnMlp = relu_stack(128 * 3 * 3, 512, 256)

        # Projections into and out of the (optional) recurrent step.
        self.imgMotorMlp = relu_stack(256, 64, recurrentInputSize)
        self.imgMotorMlp2 = relu_stack(recurrentSize, 256)

        self.soundMlp = relu_stack(3, 128, 256, 256)

        # Auxiliary head: default PyTorch initialization, no trailing ReLU.
        self.exi = nn.Sequential(
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 1),
        )

        self.fusionMlp = relu_stack(256, 512, 256)
        self.mlp_all = relu_stack(256, 256, 128)

        # Policy / value heads.
        self.actor = relu_stack(128, 128, actionHiddenSize)
        self.critic = relu_stack(128, 128, 128)
        self.critic_linear = ortho(nn.Linear(128, 1))

        self.train()

    def forward(self, inputs, rnn_hxs, masks, **kwargs):
        """Run the network on one batch of flat observations.

        Args:
            inputs: 2-D tensor, one observation per row, laid out as
                described in the class docstring.
            rnn_hxs: Recurrent hidden state, passed to ``_forward_gru`` when
                the model is recurrent and returned (possibly updated).
            masks: Passed through to ``_forward_gru``.
            **kwargs: Ignored.

        Returns:
            Tuple ``(value, hidden_actor, rnn_hxs, additional)`` where
            ``additional`` is ``{'exi': ...}`` with one scalar per sample.
        """
        obs = inputs
        n_rows = obs.size(0)

        # Carve the flat row into its three modalities.
        motor_in = obs[:, :5]
        sound_in = obs[:, 5:8]
        pixels = torch.reshape(obs[:, 8:-1], (n_rows, 3, 96, 96))
        cnn_out = self.imgCNN(pixels)

        visual = self.cnnMlp(cnn_out)
        motor = self.motorMlp(motor_in)
        recurrent_feat = self.imgMotorMlp(visual + motor)

        if self.is_recurrent:
            recurrent_feat, rnn_hxs = self._forward_gru(recurrent_feat, rnn_hxs, masks)

        memory = self.imgMotorMlp2(recurrent_feat)

        # Sound and vision are fused additively; the auxiliary head reads the
        # fused features before the fusion MLP is applied.
        audio_visual = self.soundMlp(sound_in) + visual
        exi = self.exi(audio_visual)[:, 0]

        trunk = self.mlp_all(self.fusionMlp(audio_visual) + memory)

        hidden_critic = self.critic(trunk)
        hidden_actor = self.actor(trunk)

        return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs, {'exi': exi}


class turtlebotNet_RSI1(NNBase):
    """Actor-critic trunk with a BiLSTM + attention sound encoder.

    Each flat observation row is sliced as follows (with
    ``S = soundFrames * soundMFCCFeat``): columns ``0:2`` are the motor
    features, columns ``2:2 + S`` the sound clip (reshaped to
    ``(soundFrames, soundMFCCFeat)``), and columns ``2 + S:-9`` the image
    (reshaped to ``(3, 96, 96)``).  The last 9 columns are not used by this
    network.

    The sound encoder result is cached on the instance
    (``cached_sound_aux`` / ``cached_context``) and reused by ``forward``
    whenever the incoming sound is entirely ``inf``.

    Auxiliary outputs returned by ``forward``: ``'inSight_aux'`` (4 values per
    sample, from the image features), ``'sound_aux'`` (4 values per sample,
    from the sound context) and ``'exi_aux'`` (one scalar per sample).

    Args:
        num_inputs: Accepted for interface compatibility; not used here.
        config: Accepted for interface compatibility; not used here.
        recurrent: Forwarded to ``NNBase``.
        recurrentInputSize: Width of the features handed to the recurrent
            step; also forwarded to ``NNBase``.
        recurrentSize: Width of the features coming out of the recurrent
            step; also forwarded to ``NNBase``.
        actionHiddenSize: Width of the actor's output features; also
            forwarded to ``NNBase``.
    """

    def __init__(self, num_inputs, config=None, recurrent=False, recurrentInputSize=128, recurrentSize=128, actionHiddenSize=128):
        super(turtlebotNet_RSI1, self).__init__(recurrent, recurrentInputSize, recurrentSize, actionHiddenSize)

        # Outputs of the most recent sound_encoder() call; None until the
        # encoder has run at least once.
        self.cached_sound_aux = None
        self.cached_context = None

        self.encoderNumLayer = 1
        self.soundFrames = 100
        self.soundMFCCFeat = 40

        # Convolutional image encoder (default PyTorch initialization).
        self.imgCNN = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=1, padding=1), nn.ReLU(),  # (3, 96, 96)->(32, 96, 96)
            nn.Conv2d(32, 32, 3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # (32, 96, 96)->(32, 48, 48)
            nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(),  # (32, 48, 48)->(64, 48, 48)
            nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # (64, 48, 48)->(64, 24, 24)
            nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.ReLU(),  # (64, 24, 24)->(128, 24, 24)
            nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # (128, 24, 24)->(128, 12, 12)
            nn.Conv2d(128, 256, 3, stride=2, padding=0), nn.ReLU(),  # (128, 12, 12)->(256, 5, 5)
            nn.Conv2d(256, 128, 3, stride=1, padding=0), nn.ReLU(),  # (256, 5, 5)->(128, 3, 3)
        )

        # Passes each layer to the project's `init` helper with an orthogonal
        # weight initializer, a zero-constant bias initializer and gain sqrt(2).
        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        self.motorMlp = nn.Sequential(
            init_(nn.Linear(2, 64)), nn.ReLU(),
            init_(nn.Linear(64, 256)), nn.ReLU(),
        )
        self.cnnMlp = nn.Sequential(
            init_(nn.Linear(128 * 3 * 3, 512)), nn.ReLU(),
            init_(nn.Linear(512, 256)), nn.ReLU(),
        )

        # Projections into and out of the (optional) recurrent step.
        self.imgMotorMlp = nn.Sequential(
            init_(nn.Linear(256, 64)), nn.ReLU(),
            init_(nn.Linear(64, recurrentInputSize)), nn.ReLU(),
        )
        self.imgMotorMlp2 = nn.Sequential(
            init_(nn.Linear(recurrentSize, 256)), nn.ReLU(),
        )

        # Auxiliary heads (default PyTorch initialization).
        self.exi = nn.Sequential(
            nn.Linear(128 * 9, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 1),
        )

        self.inSight = nn.Sequential(
            nn.Linear(128 * 3 * 3, 128), nn.ReLU(),
            nn.Linear(128, 4),
        )

        # Sound encoder: per-frame MLP -> BiLSTM -> additive attention.
        self.soundBiLSTM = nn.LSTM(input_size=512, hidden_size=512,
                                   num_layers=1, batch_first=True, bidirectional=True)

        self.soundMlp = nn.Sequential(
            nn.Linear(40, 977), nn.ReLU(),
            nn.Linear(977, 512), nn.ReLU(),
        )
        self.attnMlp = nn.Sequential(
            nn.Linear(2048, 512), nn.ReLU(),
            nn.Linear(512, 512), nn.ReLU(),
        )
        self.vA = nn.Linear(512, 1)
        self.contextMlp = nn.Sequential(
            nn.Linear(1024, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
        )

        self.contextMlp2 = nn.Sequential(
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
        )

        self.soundAux = nn.Sequential(
            init_(nn.Linear(128, 64)), nn.ReLU(),
            init_(nn.Linear(64, 4)),
        )
        self.fusionMlp = nn.Sequential(
            init_(nn.Linear(128 * 9, 512)), nn.ReLU(),
            init_(nn.Linear(512, 256)), nn.ReLU(),
        )

        self.mlp_all = nn.Sequential(
            init_(nn.Linear(256, 256)), nn.ReLU(),
            init_(nn.Linear(256, 128)), nn.ReLU(),
        )

        # Policy / value heads.
        self.actor = nn.Sequential(
            init_(nn.Linear(128, 128)), nn.ReLU(),
            init_(nn.Linear(128, actionHiddenSize)), nn.ReLU(),
        )

        self.critic = nn.Sequential(
            init_(nn.Linear(128, 128)), nn.ReLU(),
            init_(nn.Linear(128, 128)), nn.ReLU(),
        )

        self.critic_linear = init_(nn.Linear(128, 1))

        self.train()

    def sound_encoder(self, sound, batchSize):
        """Encode a batch of sound clips and cache the result on the instance.

        Nothing is returned; the method overwrites ``self.cached_sound_aux``
        with a ``(batch, 4)`` tensor and ``self.cached_context`` with a
        ``(batch, 128, 1)`` tensor.

        Args:
            sound: Tensor of shape ``(batch, frames, 40)``.
            batchSize: Unused; kept for interface compatibility.
        """
        sound = self.soundMlp(sound)

        # Forward propagate LSTM
        # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        out, (h_n, c_n) = self.soundBiLSTM(sound)

        # Attention: score every time step against the concatenated final
        # hidden states of the two directions.  The sequence length is read
        # from `out` rather than hard-coded.
        h_n_concat = torch.cat([h_n[0], h_n[1]], dim=1)
        h_n_concat_expand = h_n_concat.unsqueeze(1).repeat([1, out.size(1), 1])
        h_n_out_concat = torch.cat([out, h_n_concat_expand], dim=2)
        h_n_out_concat = self.attnMlp(h_n_out_concat)
        score = F.softmax(self.vA(h_n_out_concat), dim=1)

        # (batch, 1024, frames) @ (batch, frames, 1) -> (batch, 1024, 1).
        # Only the trailing singleton is squeezed: a bare squeeze() would also
        # drop the batch axis when batch == 1 and change the cached shapes.
        context = torch.bmm(torch.transpose(out, dim0=1, dim1=2), score).squeeze(-1)
        context = self.contextMlp(context)

        self.cached_sound_aux = self.soundAux(context)
        self.cached_context = torch.unsqueeze(self.contextMlp2(context), -1)

    def forward(self, inputs, rnn_hxs, masks, infer=True, **kwargs):
        """Run the network on one batch of flat observations.

        Args:
            inputs: 2-D tensor, one observation per row, laid out as
                described in the class docstring.
            rnn_hxs: Recurrent hidden state, passed to ``_forward_gru`` when
                the model is recurrent and returned (possibly updated).
            masks: Passed through to ``_forward_gru``.
            infer: When true, the sound encoder is re-run on the whole batch
                unless every sound value is ``inf``, in which case the cached
                result is reused.  When false, only rows whose sound sum is
                not ``inf`` are encoded and the result is tiled to the batch
                size.
            **kwargs: Ignored.

        Returns:
            Tuple ``(value, hidden_actor, rnn_hxs, additional)`` where
            ``additional`` holds ``'inSight_aux'``, ``'sound_aux'`` and
            ``'exi_aux'``.

        Raises:
            RuntimeError: If the cached sound context is needed but the sound
                encoder has never been run.
        """
        x = inputs
        batchSize = x.size(0)
        sound_end = self.soundFrames * self.soundMFCCFeat + 2

        motor_imgEmb = x[:, :2]
        sound = torch.reshape(x[:, 2:sound_end],
                              (batchSize, self.soundFrames, self.soundMFCCFeat))
        image = torch.reshape(x[:, sound_end:-9], (batchSize, 3, 96, 96))

        img3 = self.imgCNN(image)
        img3_flatten = torch.flatten(img3, start_dim=1)
        inSight = self.inSight(img3_flatten)

        img_flatten = self.cnnMlp(img3_flatten)
        motor = self.motorMlp(motor_imgEmb)
        imageMotor = self.imgMotorMlp(img_flatten + motor)

        if self.is_recurrent:
            imageMotor, rnn_hxs = self._forward_gru(imageMotor, rnn_hxs, masks)

        imageMotorRnn = self.imgMotorMlp2(imageMotor)

        if infer:  # assuming every episode has the same length and all env reset at the same time
            if not torch.isinf(sound).all():  # use sound encoder to update the cached result
                self.sound_encoder(sound, batchSize)
        else:  # notice that the data is arranged as [[env1's step 1], [env2's step1], [env1's step2], [env2's step2]...]
            row_sum = torch.sum(sound.view(batchSize, -1), dim=1)
            # Indices of the rows that carry a real (non-inf) sound clip.
            valid_rows = torch.logical_not(row_sum.isinf()).nonzero().view(-1)
            self.sound_encoder(sound[valid_rows, :], valid_rows.size()[0])
            # assuming every episode has the same length and all env reset at the same time
            num_repeats = batchSize // len(valid_rows)
            self.cached_sound_aux = self.cached_sound_aux.repeat(num_repeats, 1)
            self.cached_context = self.cached_context.repeat(num_repeats, 1, 1)

        if self.cached_context is None:
            raise RuntimeError(
                "turtlebotNet_RSI1.forward needs a cached sound context, but the "
                "sound encoder has not been run yet (all sound inputs were inf)."
            )

        # The (batch, 128, 1) sound context broadcasts over the 9 spatial
        # positions of the (batch, 128, 9) image features.
        image_reshape = torch.reshape(img3, (batchSize, 128, -1))
        fusion = self.cached_context + image_reshape

        fusion = torch.flatten(fusion, start_dim=1)
        exi = self.exi(fusion)[:, 0]
        fusion = self.fusionMlp(fusion)

        final_fusion = fusion + imageMotorRnn
        x = self.mlp_all(final_fusion)

        hidden_critic = self.critic(x)
        hidden_actor = self.actor(x)

        additional = {
            'inSight_aux': inSight,
            'sound_aux': self.cached_sound_aux,
            'exi_aux': exi
        }
        return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs, additional
