import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as weight_init
import math


def fanin_init(size, fanin=None):
    """Return a new tensor of shape ``size`` drawn from U(-v, v), v = sqrt(2 / fanin).

    The bound is taken from the He-initialization paper
    (https://arxiv.org/abs/1502.01852). Note that U(-v, v) has variance
    v**2 / 3 = 2 / (3 * fanin), which is smaller than the 2 / fanin of He-normal.

    :param size: target shape (``torch.Size``, tuple or list of ints).
    :param fanin: divisor used for the bound. A falsy value (``None`` or 0)
        falls back to ``size[0]``. For an ``nn.Linear`` weight of shape
        (out_features, in_features) that fallback is the fan-*out*; pass
        ``fanin`` explicitly to use the true fan-in.
    :return: a float tensor of shape ``size`` with values in [-v, v).
    """
    fanin = fanin or size[0]
    v = np.sqrt(2.0 / fanin)
    # torch.empty always interprets an int sequence as a shape. torch.Tensor(size)
    # treats a plain tuple/list as *data* and would return a 1-D tensor instead.
    return torch.empty(tuple(size)).uniform_(-v, v)

class LSTMDyn(nn.Module):
    """Single-layer LSTM followed by a linear read-out of ``numRats`` values.

    The read-out is applied to the ReLU of the LSTM output at the last time step.
    """

    def __init__(self, input_size=3, hidden_layer_size=20, numRats=5):
        """
        :param input_size: number of features per time step.
        :param hidden_layer_size: LSTM hidden size.
        :param numRats: number of outputs of the linear read-out.
        """
        super().__init__()
        # Seeds the *global* torch RNG so parameter initialization is reproducible;
        # this also affects any random draws the caller makes afterwards.
        torch.manual_seed(0)
        self.hidden_layer_size = hidden_layer_size

        self.lstm = nn.LSTM(input_size, hidden_layer_size, batch_first=True)
        self.linearRat = nn.Linear(hidden_layer_size, numRats)

    def forward(self, input_seq):
        """Encode ``input_seq`` and return the read-out at the final time step.

        :param input_seq: tensor of shape (batch, seq_len, input_size)
            (the LSTM is built with ``batch_first=True``).
        :return: tensor of shape (batch, numRats).
        """
        batch = input_seq.shape[0]
        # Zero initial (h, c) on the input's device and dtype. The previous code
        # forced CUDA whenever it was available, which broke CPU-resident models.
        self.hidden_cell = (
            input_seq.new_zeros(1, batch, self.hidden_layer_size),
            input_seq.new_zeros(1, batch, self.hidden_layer_size),
        )
        lstm_out, self.hidden_cell = self.lstm(input_seq, self.hidden_cell)
        return self.linearRat(F.relu(lstm_out[:, -1, :]))



class LSTME(nn.Module):
    """Single-layer LSTM encoder returning the LSTM output at the last time step."""

    def __init__(self, input_size=2, hidden_layer_size=10, state_size=1):
        """
        :param input_size: number of features per time step.
        :param hidden_layer_size: LSTM hidden size (also the width of the encoding).
        :param state_size: unused; kept so existing callers keep working.
        """
        super().__init__()
        # Seeds the *global* torch RNG for reproducible initialization.
        torch.manual_seed(0)
        self.hidden_layer_size = hidden_layer_size

        self.lstm = nn.LSTM(input_size, hidden_layer_size, batch_first=True)

    def forward(self, input_seq, batch_size=None):
        """Encode ``input_seq`` and return the output at the final time step.

        :param input_seq: tensor of shape (batch, seq_len, input_size)
            (the LSTM is built with ``batch_first=True``).
        :param batch_size: batch size used for the zero initial state. When
            omitted it is taken from ``input_seq.shape[0]``. The old fixed
            default of 128 crashed for any other batch size.
        :return: tensor of shape (batch, hidden_layer_size).
        """
        batch = input_seq.shape[0] if batch_size is None else batch_size
        # Zero initial (h, c) on the input's device/dtype instead of forcing CUDA.
        self.hidden_cell = (
            input_seq.new_zeros(1, batch, self.hidden_layer_size),
            input_seq.new_zeros(1, batch, self.hidden_layer_size),
        )
        lstm_out, self.hidden_cell = self.lstm(input_seq, self.hidden_cell)
        return lstm_out[:, -1, :]


class Critic(nn.Module):
    """Critic network producing Q(s, a) from encoded state-action histories.

    Each history in ``prevS_A`` is encoded by ``lstmE`` (width
    ``lstmE.hidden_layer_size``, 10 by default) and by ``lstmRat`` (``numRat`` = 5
    outputs). The two encodings are concatenated with one ``bestPGuess`` feature
    and fed to ``fc1``. ``Z_dim`` must therefore equal 10 + 5 + 1 = 16 with the
    defaults used here.
    """

    def __init__(self, Z_dim, action_dim, h1=50, h2=100, h3=20, eps=0.03, batch_size=128):
        """
        :param Z_dim: input width of ``fc1`` (see class docstring).
        :param action_dim: width of the action concatenated after ``fc1``.
        :param h1: width of ``fc1``'s output.
        :param h2: width of ``fc2``'s output.
        :param h3: unused; kept for signature compatibility.
        :param eps: bound for the uniform init of ``fc3``'s weight.
        :param batch_size: unused; kept for signature compatibility.
        """
        super().__init__()
        # Seeds the *global* torch RNG. The statement order below fixes how the
        # random stream is consumed, so do not reorder layer construction.
        torch.manual_seed(0)

        self.state_dim = Z_dim
        self.action_dim = action_dim

        self.fc1 = nn.Linear(Z_dim, h1)
        self.fc1.weight.data = fanin_init(self.fc1.weight.data.size())

        self.fc2 = nn.Linear(h1 + action_dim, h2)
        self.fc2.weight.data = fanin_init(self.fc2.weight.data.size())

        self.fc3 = nn.Linear(h2, 1)
        self.fc3.weight.data.uniform_(-eps, eps)

        self.relu = nn.ReLU()

        self.lstmE = LSTME(input_size=3)
        self.numRat = 5
        self.lstmRat = LSTMDyn(numRats=self.numRat)

    def forward(self, ZDyn, actionC, prevS_A, bestPGuess, batch_size=1):
        """Return the critic value Q(s, a) together with intermediate features.

        :param ZDyn: unused by this method; kept for signature compatibility.
        :param actionC: action tensor of shape (n, action_dim), concatenated
            along dim 1 after ``fc1``. Assumed to already be on the module's device.
        :param prevS_A: sequence of n per-sample history tensors. Each is
            reshaped to (1, -1, 3), so its element count must be a multiple of 3.
        :param bestPGuess: tensor with n elements, appended as one extra feature.
        :param batch_size: forwarded to ``lstmE``. It must match the encoder's
            batch dimension, which is 1 for each reshaped history.
        :return: ``(q, Z, allRat)``:
            - ``q`` has shape (n, 1);
            - ``Z`` holds the ReLU'd ``fc1`` activations, shape (n, h1);
            - ``allRat`` has shape (n, numRat) and is *not* detached.
        """
        # Follow the module's own device instead of forcing CUDA. The old CPU
        # branch called a non-existent ``self.lstm`` and left ``allZE``/``allRat``
        # undefined, so the critic could not run without a GPU.
        device = self.fc1.weight.device
        n = len(prevS_A)
        allZE = torch.zeros(n, self.lstmE.hidden_layer_size, device=device)
        allRat = torch.zeros(n, self.numRat, device=device)
        for i, seq in enumerate(prevS_A):
            seq = seq.reshape(1, -1, 3).to(device=device, dtype=torch.float32)
            allZE[i, :] = self.lstmE(seq, batch_size=batch_size)
            allRat[i, :] = self.lstmRat(seq)

        guess = bestPGuess.to(device=device, dtype=torch.float32).reshape(-1, 1)
        # allRat is detached here, so gradients from Q do not reach lstmRat via Z.
        Z = torch.cat((allZE, allRat.detach(), guess), dim=1)
        Z = self.relu(self.fc1(Z))
        x = torch.cat((Z, actionC), dim=1)

        x = self.relu(self.fc2(x))
        x = self.fc3(x)

        return x, Z, allRat