import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.ppo.utils import init

from ..ppo.model import NNBase, Flatten


class kukaNet_RSI3(NNBase):
    """Actor-critic trunk fusing an image, a motor/pose vector and a sound feature.

    Observations arrive as a mapping (see ``forward``).  The image goes
    through a VGG-style CNN, is summed with a motor embedding, optionally
    passed through the recurrent core provided by ``NNBase``, and is finally
    summed with a sound/image fusion branch before the actor and critic heads.

    ``num_inputs`` and ``config`` are accepted for interface compatibility
    but are not read by this class.
    """

    def __init__(self, num_inputs, config=None, recurrent=False, recurrentInputSize=128, recurrentSize=128, actionHiddenSize=128):
        super().__init__(recurrent, recurrentInputSize, recurrentSize, actionHiddenSize)

        def zero_bias(tensor):
            # Bias initialiser handed to the project's ``init`` helper.
            return nn.init.constant_(tensor, 0)

        def ortho_linear(fan_in, fan_out):
            # Linear layer with orthogonal weights (gain sqrt(2)) and zero bias.
            return init(nn.Linear(fan_in, fan_out), nn.init.orthogonal_, zero_bias, np.sqrt(2))

        def relu_mlp(*widths):
            # Stack of (Linear, ReLU) pairs joining consecutive widths.
            stages = []
            for fan_in, fan_out in zip(widths, widths[1:]):
                stages += [ortho_linear(fan_in, fan_out), nn.ReLU()]
            return nn.Sequential(*stages)

        def conv3x3(c_in, c_out, stride=1, padding=1):
            # 3x3 convolution followed by ReLU; convolutions keep PyTorch's
            # default initialisation (they are not passed through ``init``).
            return [nn.Conv2d(c_in, c_out, 3, stride=stride, padding=padding), nn.ReLU()]

        def halve():
            # 2x2 max-pool that halves the spatial resolution.
            return [nn.MaxPool2d(2, stride=2)]

        # Shape notes below assume a (3, 96, 96) input image.
        self.imgCNN = nn.Sequential(
            *conv3x3(3, 32),                            # -> (32, 96, 96)
            *conv3x3(32, 32),
            *halve(),                                   # -> (32, 48, 48)
            *conv3x3(32, 64),                           # -> (64, 48, 48)
            *conv3x3(64, 64),
            *halve(),                                   # -> (64, 24, 24)
            *conv3x3(64, 128),                          # -> (128, 24, 24)
            *conv3x3(128, 128),
            *halve(),                                   # -> (128, 12, 12)
            *conv3x3(128, 256, stride=2, padding=0),    # -> (256, 5, 5)
            *conv3x3(256, 128, stride=1, padding=0),    # -> (128, 3, 3)
            Flatten(),
        )

        self.motorMlp = relu_mlp(5, 64, 256)
        self.cnnMlp = relu_mlp(128 * 3 * 3, 512, 256)

        # Projection into, and back out of, the (optional) recurrent core.
        self.imgMotorMlp = relu_mlp(256, 64, recurrentInputSize)
        self.imgMotorMlp2 = relu_mlp(recurrentSize, 256)

        self.soundMlp = relu_mlp(3, 128, 256, 256)
        self.fusionMlp = relu_mlp(256, 512, 256)
        self.mlp_all = relu_mlp(256, 256, 128)

        self.actor = relu_mlp(128, 128, actionHiddenSize)
        self.critic = relu_mlp(128, 128, 128)
        self.critic_linear = ortho_linear(128, 1)

        self.train()

    def forward(self, inputs, rnn_hxs, masks, **kwargs):
        """Compute the value estimate and the actor's hidden features.

        Args:
            inputs: mapping with the keys ``'robot_pose'``, ``'image_feat'``,
                ``'goal_sound_feat'`` and ``'image'``.  ``image_feat`` and
                ``robot_pose`` are concatenated along dim 1 and fed to a
                ``Linear(5, 64)``, so their combined width must be 5;
                ``goal_sound_feat`` feeds a ``Linear(3, 128)``.
            rnn_hxs: recurrent state, forwarded to ``_forward_gru`` when
                ``self.is_recurrent`` is true and returned untouched otherwise.
            masks: forwarded to ``_forward_gru`` together with ``rnn_hxs``.
            **kwargs: accepted and ignored.

        Returns:
            ``(value, hidden_actor, rnn_hxs, {})``.
        """
        obs = inputs
        pose = obs['robot_pose']
        img_feat = obs['image_feat']
        motor_in = torch.cat([img_feat, pose], dim=1)
        sound_in = obs['goal_sound_feat']
        frame = obs['image']

        # Visual embedding (CNN features squeezed to 256-d by cnnMlp).
        visual = self.cnnMlp(self.imgCNN(frame))
        motor = self.motorMlp(motor_in)

        # Visuo-motor branch: additive fusion, then the optional recurrent core.
        core = self.imgMotorMlp(visual + motor)
        if self.is_recurrent:
            core, rnn_hxs = self._forward_gru(core, rnn_hxs, masks)
        core = self.imgMotorMlp2(core)

        # Audio-visual branch: additive fusion of sound and visual embeddings.
        audio_visual = self.fusionMlp(self.soundMlp(sound_in) + visual)

        shared = self.mlp_all(audio_visual + core)

        hidden_critic = self.critic(shared)
        hidden_actor = self.actor(shared)
        value = self.critic_linear(hidden_critic)

        return value, hidden_actor, rnn_hxs, {}


class kukaNet_RSI2(NNBase):
    """Actor-critic trunk operating on a flat observation vector.

    Each observation row is sliced in ``forward`` as
    ``[5 motor values | 3 sound values | 3*96*96 image values]``.  The layer
    stack matches the dictionary-input sibling network: a VGG-style CNN, a
    motor MLP, an optional recurrent core supplied by ``NNBase`` and a
    sound/image fusion branch feeding the actor and critic heads.

    ``num_inputs`` and ``config`` are accepted for interface compatibility
    but are not read by this class.
    """

    def __init__(self, num_inputs, config=None, recurrent=False, recurrentInputSize=128, recurrentSize=128, actionHiddenSize=128):
        super().__init__(recurrent, recurrentInputSize, recurrentSize, actionHiddenSize)

        def _zero(bias):
            # Constant-zero bias initialiser for the project's ``init`` helper.
            return nn.init.constant_(bias, 0)

        def _dense(n_in, n_out):
            # Orthogonally initialised Linear layer (gain sqrt(2), zero bias).
            return init(nn.Linear(n_in, n_out), nn.init.orthogonal_, _zero, np.sqrt(2))

        def _tower(*sizes):
            # Chain of Linear+ReLU pairs between each pair of adjacent sizes.
            modules = []
            for n_in, n_out in zip(sizes, sizes[1:]):
                modules.append(_dense(n_in, n_out))
                modules.append(nn.ReLU())
            return nn.Sequential(*modules)

        # Convolutions use PyTorch's default initialisation.
        cnn_layers = [
            nn.Conv2d(3, 32, 3, stride=1, padding=1), nn.ReLU(),      # (3, 96, 96) -> (32, 96, 96)
            nn.Conv2d(32, 32, 3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),                                # -> (32, 48, 48)
            nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(),     # -> (64, 48, 48)
            nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),                                # -> (64, 24, 24)
            nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.ReLU(),    # -> (128, 24, 24)
            nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),                                # -> (128, 12, 12)
            nn.Conv2d(128, 256, 3, stride=2, padding=0), nn.ReLU(),   # -> (256, 5, 5)
            nn.Conv2d(256, 128, 3, stride=1, padding=0), nn.ReLU(),   # -> (128, 3, 3)
            Flatten(),
        ]
        self.imgCNN = nn.Sequential(*cnn_layers)

        self.motorMlp = _tower(5, 64, 256)
        self.cnnMlp = _tower(128 * 3 * 3, 512, 256)

        # Entry to and exit from the (optional) recurrent core.
        self.imgMotorMlp = _tower(256, 64, recurrentInputSize)
        self.imgMotorMlp2 = _tower(recurrentSize, 256)

        self.soundMlp = _tower(3, 128, 256, 256)
        self.fusionMlp = _tower(256, 512, 256)
        self.mlp_all = _tower(256, 256, 128)

        self.actor = _tower(128, 128, actionHiddenSize)
        self.critic = _tower(128, 128, 128)
        self.critic_linear = _dense(128, 1)

        self.train()

    def forward(self, inputs, rnn_hxs, masks, **kwargs):
        """Compute the value estimate and the actor's hidden features.

        Args:
            inputs: 2-D tensor whose rows are
                ``[motor(5) | sound(3) | image(3*96*96)]``.
            rnn_hxs: recurrent state, forwarded to ``_forward_gru`` when
                ``self.is_recurrent`` is true and returned untouched otherwise.
            masks: forwarded to ``_forward_gru`` together with ``rnn_hxs``.
            **kwargs: accepted and ignored.

        Returns:
            ``(value, hidden_actor, rnn_hxs, {})``.
        """
        n_rows = inputs.size(0)

        # Slice the flat observation into its three modalities.
        motor_in = inputs[:, :5]
        sound_in = inputs[:, 5:8]
        frame = inputs[:, 8:].reshape(n_rows, 3, 96, 96)

        visual = self.cnnMlp(self.imgCNN(frame))
        motor = self.motorMlp(motor_in)

        # Visuo-motor branch, optionally routed through the recurrent core.
        core = self.imgMotorMlp(visual + motor)
        if self.is_recurrent:
            core, rnn_hxs = self._forward_gru(core, rnn_hxs, masks)
        core = self.imgMotorMlp2(core)

        # Audio-visual branch.
        audio_visual = self.fusionMlp(self.soundMlp(sound_in) + visual)

        shared = self.mlp_all(audio_visual + core)

        hidden_critic = self.critic(shared)
        hidden_actor = self.actor(shared)
        value = self.critic_linear(hidden_critic)

        return value, hidden_actor, rnn_hxs, {}


class kukaNet_RSI1(NNBase):
    """Actor-critic trunk with an attention-based sound encoder.

    Each flat observation row is sliced in ``forward`` as::

        [ motor (2) | sound (soundFrames * soundMFCCFeat) | image (3*96*96) | 4 trailing values ]

    The trailing 4 columns are not read by ``forward``.
    NOTE(review): they presumably hold the target for the 4-way ``soundAux``
    head -- confirm against the training code.

    The sound clip is encoded by an MLP, a bidirectional LSTM and an additive
    attention layer.  The result is cached on the module
    (``cached_context`` / ``cached_sound_aux``) so that steps whose sound
    slot is filled with ``inf`` placeholders reuse the latest encoding.

    ``num_inputs`` and ``config`` are accepted for interface compatibility
    but are not read by this class.
    """

    def __init__(self, num_inputs, config=None, recurrent=False, recurrentInputSize=128, recurrentSize=128, actionHiddenSize=128):
        super(kukaNet_RSI1, self).__init__(recurrent, recurrentInputSize, recurrentSize, actionHiddenSize)

        def init_(module):
            # Orthogonal weights with gain sqrt(2) and zero bias.
            return init(module, nn.init.orthogonal_,
                        lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        self.encoderNumLayer = 1   # number of stacked BiLSTM layers
        self.soundFrames = 100     # time steps per sound clip
        self.soundMFCCFeat = 40    # features per time step

        # Filled by ``sound_encoder``; read by ``forward``.
        self.cached_sound_aux = None
        self.cached_context = None

        # Convolutions keep PyTorch's default initialisation.  The CNN output
        # is left un-flattened because ``forward`` needs the (128, 3, 3) map.
        self.imgCNN = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=1, padding=1), nn.ReLU(),  # (3, 96, 96)->(32, 96, 96)
            nn.Conv2d(32, 32, 3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # (32, 96, 96)->(32, 48, 48)
            nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(),  # (32, 48, 48)->(64, 48, 48)
            nn.MaxPool2d(2, stride=2),  # (64, 48, 48)->(64, 24, 24)
            nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(),  # (64, 24, 24)->(64, 24, 24)
            nn.MaxPool2d(2, stride=2),  # (64, 24, 24)->(64, 12, 12)
            nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.ReLU(),  # (64, 12, 12)->(128, 12, 12)
            nn.MaxPool2d(2, stride=2),  # (128, 12, 12)->(128, 6, 6)
            nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.ReLU(),  # (128, 6, 6)->(128, 3, 3)
        )

        self.motorMlp = nn.Sequential(
            init_(nn.Linear(2, 64)), nn.ReLU(),
            init_(nn.Linear(64, 256)), nn.ReLU(),
        )
        self.cnnMlp = nn.Sequential(
            init_(nn.Linear(128 * 3 * 3, 512)), nn.ReLU(),
            init_(nn.Linear(512, 256)), nn.ReLU())

        # Entry to and exit from the (optional) recurrent core.
        self.imgMotorMlp = nn.Sequential(
            init_(nn.Linear(256, 64)), nn.ReLU(),
            init_(nn.Linear(64, recurrentInputSize)), nn.ReLU(),
        )
        self.imgMotorMlp2 = nn.Sequential(
            init_(nn.Linear(recurrentSize, 256)), nn.ReLU(),
        )

        self.soundBiLSTM = nn.LSTM(input_size=512, hidden_size=512,
                                   num_layers=self.encoderNumLayer, batch_first=True, bidirectional=True)

        # Per-frame embedding applied before the BiLSTM.
        self.soundMlp = nn.Sequential(
            init_(nn.Linear(40, 977)), nn.ReLU(),
            init_(nn.Linear(977, 512)), nn.ReLU(),
        )
        # Attention scorer: input is [LSTM output (1024) ; final state (1024)].
        self.attnMlp = nn.Sequential(
            init_(nn.Linear(2048, 512)), nn.ReLU(),
            init_(nn.Linear(512, 512)), nn.ReLU(),
        )
        self.vA = nn.Linear(512, 1)
        self.contextMlp = nn.Sequential(
            init_(nn.Linear(1024, 256)), nn.ReLU(),
            init_(nn.Linear(256, 128)), nn.ReLU(),
        )
        self.contextMlp2 = nn.Sequential(
            init_(nn.Linear(128, 128)), nn.ReLU(),
            init_(nn.Linear(128, 128)), nn.ReLU(),
        )

        # Auxiliary head on the sound context (4 raw outputs, no activation).
        self.soundAux = nn.Sequential(
            init_(nn.Linear(128, 64)), nn.ReLU(),
            init_(nn.Linear(64, 4)),
        )

        self.fusionMlp = nn.Sequential(
            init_(nn.Linear(128 * 9, 512)), nn.ReLU(),
            init_(nn.Linear(512, 256)), nn.ReLU(),
        )

        self.mlp_all = nn.Sequential(
            init_(nn.Linear(256, 256)), nn.ReLU(),
            init_(nn.Linear(256, 128)), nn.ReLU(),
        )

        self.actor = nn.Sequential(
            init_(nn.Linear(128, 128)), nn.ReLU(),
            init_(nn.Linear(128, actionHiddenSize)), nn.ReLU())

        self.critic = nn.Sequential(
            init_(nn.Linear(128, 128)), nn.ReLU(),
            init_(nn.Linear(128, 128)), nn.ReLU())

        self.critic_linear = init_(nn.Linear(128, 1))

        self.train()

    def sound_encoder(self, sound, batchSize):
        """Encode a batch of sound clips and cache the result on the module.

        This is the sound processing architecture used by RSI1: per-frame
        MLP -> BiLSTM -> additive attention over time -> context MLPs.

        Args:
            sound: tensor of shape (batch, frames, 40).
            batchSize: unused; kept so existing callers keep working.

        Side effects:
            Sets ``self.cached_sound_aux`` to a (batch, 4) tensor and
            ``self.cached_context`` to a (batch, 128, 1) tensor.
        """
        frames = self.soundMlp(sound)

        # out: (batch, frames, 2 * hidden); h_n: (2 * num_layers, batch, hidden)
        out, (h_n, _) = self.soundBiLSTM(frames)

        # Final forward/backward states of the *last* layer.  For a single
        # layer this is h_n[0], h_n[1]; negative indices stay correct if
        # ``encoderNumLayer`` is ever raised.
        final_state = torch.cat([h_n[-2], h_n[-1]], dim=1)

        # Broadcast the final state along the time axis the LSTM actually
        # produced, rather than assuming ``self.soundFrames``.
        query = final_state.unsqueeze(1).expand(-1, out.size(1), -1)
        attn_hidden = self.attnMlp(torch.cat([out, query], dim=2))
        score = F.softmax(self.vA(attn_hidden), dim=1)  # (batch, frames, 1)

        # Attention-weighted sum over time: (batch, 1024, 1) -> (batch, 1024).
        # Squeeze only the trailing axis: a bare ``squeeze()`` would also drop
        # the batch axis when batch == 1.
        context = torch.bmm(out.transpose(1, 2), score).squeeze(-1)
        context = self.contextMlp(context)

        self.cached_sound_aux = self.soundAux(context)
        self.cached_context = self.contextMlp2(context).unsqueeze(-1)

    def forward(self, inputs, rnn_hxs, masks, infer=True, **kwargs):
        """Compute the value estimate, actor features and the sound aux output.

        Args:
            inputs: 2-D tensor laid out as described in the class docstring.
            rnn_hxs: recurrent state, forwarded to ``_forward_gru`` when
                ``self.is_recurrent`` is true and returned untouched otherwise.
            masks: forwarded to ``_forward_gru`` together with ``rnn_hxs``.
            infer: when true, the sound encoder is re-run on the whole batch
                unless every sound value is ``inf`` (then the cached encoding
                is reused).  When false, only rows whose sound sum is not
                ``inf`` are encoded and the result is tiled over the batch.
            **kwargs: accepted and ignored.

        Returns:
            ``(value, hidden_actor, rnn_hxs, {'sound_aux': ...})``.

        Raises:
            ValueError: ``infer`` is false and no row carries real sound data.
            RuntimeError: no sound encoding has been cached yet.
        """
        x = inputs
        batchSize = x.size(0)
        sound_end = 2 + self.soundFrames * self.soundMFCCFeat

        motor_imgEmb = x[:, :2]
        sound = x[:, 2:sound_end].reshape(batchSize, self.soundFrames, self.soundMFCCFeat)
        image = x[:, sound_end:-4].reshape(batchSize, 3, 96, 96)

        img3 = self.imgCNN(image)  # (batch, 128, 3, 3)
        img_flatten = self.cnnMlp(torch.flatten(img3, start_dim=1))
        motor = self.motorMlp(motor_imgEmb)
        imageMotor = self.imgMotorMlp(img_flatten + motor)

        if self.is_recurrent:
            imageMotor, rnn_hxs = self._forward_gru(imageMotor, rnn_hxs, masks)

        imageMotorRnn = self.imgMotorMlp2(imageMotor)

        if infer:
            # Assumes every episode has the same length and all envs reset at
            # the same time, so the batch is either all real sound or all
            # placeholders.
            if not torch.isinf(sound).all():
                self.sound_encoder(sound, batchSize)
        else:
            # Data is arranged as
            # [[env1 step1], [env2 step1], [env1 step2], [env2 step2], ...],
            # so tiling the per-env encodings lines them up with every step.
            row_sum = sound.reshape(batchSize, -1).sum(dim=1)
            real_rows = torch.logical_not(row_sum.isinf()).nonzero().view(-1)
            num_real = real_rows.numel()
            if num_real == 0:
                raise ValueError(
                    "forward(infer=False) needs at least one row with non-inf sound data")
            self.sound_encoder(sound[real_rows, :], num_real)
            # Same equal-length / simultaneous-reset assumption as above.
            num_repeats = batchSize // num_real
            self.cached_sound_aux = self.cached_sound_aux.repeat(num_repeats, 1)
            self.cached_context = self.cached_context.repeat(num_repeats, 1, 1)

        if self.cached_context is None:
            raise RuntimeError(
                "no cached sound encoding: the first forward() call received only inf sound placeholders")

        # Broadcast the (batch, 128, 1) sound context over the 9 spatial cells.
        image_reshape = torch.reshape(img3, (batchSize, 128, -1))
        fusion = self.cached_context + image_reshape

        fusion = self.fusionMlp(torch.flatten(fusion, start_dim=1))

        final_fusion = fusion + imageMotorRnn
        x = self.mlp_all(final_fusion)

        hidden_critic = self.critic(x)
        hidden_actor = self.actor(x)

        additional = {
            'sound_aux': self.cached_sound_aux
        }
        return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs, additional
