import torch.nn as nn
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from qpth.qp import QPFunction, QPSolvers
import numpy as np
from my_classes import test_solver as solver
import kerasncp as kncp
from kerasncp.torch import LTCCell


class CNN_NCP_BN(nn.Module):
    """CNN + NCP (LTC cell) backbone with a differentiable BarrierNet QP head.

    Every frame is encoded by a small CNN (conv + BatchNorm + ReLU blocks), the
    per-frame features are run step by step through an LTC cell wired as a
    Neural Circuit Policy, and the resulting features provide the linear QP
    cost ``q`` and the HOCBF gains ``p``. A two-variable QP with one softened
    second-order CBF constraint (stay outside a disk of radius 7.9 around an
    obstacle) is then solved to produce the output.
    """

    def __init__(
            self,
            conv_filters,
            q_size=2,
            p_size=2,
            use_pq_linear=False,  # If False, use NCP raw output as p and q
            inter_neurons=12,  # Number of inter neurons
            command_neurons=8,  # Number of command neurons
            motor_neurons=4,  # Number of motor neurons
            sensory_fanout=4,  # How many outgoing synapses each sensory neuron has
            inter_fanout=4,  # How many outgoing synapses each inter neuron has
            recurrent_command_synapses=4,  # How many recurrent synapses are in the command neuron layer
            motor_fanin=6,  # How many incoming synapses each motor neuron has
            dropout=0.,
            dev='cpu'):
        """Build the CNN, the NCP-wired LTC cell and (optionally) the p/q heads.

        Args:
            conv_filters: sequence of ``nn.Conv2d`` positional-argument tuples
                (``in_channels, out_channels, kernel_size, ...``). The
                out-channels of the last entry is the LTC input size.
            q_size: size of the linear QP cost ``q``. ``forward`` builds a 2x2
                cost matrix, so this must be 2 for the QP to be consistent.
            p_size: number of HOCBF gains; ``forward`` reads ``p[:, 0]`` and
                ``p[:, 1]``.
            use_pq_linear: if True, map the motor-neuron outputs to ``p`` / ``q``
                with small tanh MLPs; if False, slice the motor outputs
                directly (first ``p_size`` -> p, rest -> q).
            inter_neurons .. motor_fanin: wiring parameters forwarded to
                ``kncp.wirings.NCP``.
            dropout: dropout probability after each conv block (0 disables).
            dev: device the QP tensors are moved to.

        Raises:
            ValueError: if ``use_pq_linear`` is False and
                ``p_size + q_size != motor_neurons``.
        """
        super(CNN_NCP_BN, self).__init__()

        self.device = dev
        # Gains of the first sample, recorded by forward() in the external-solver path.
        self.p1 = 0
        self.p2 = 0

        # Raise rather than assert: assertions vanish under ``python -O``.
        if not use_pq_linear and p_size + q_size != motor_neurons:
            raise ValueError(
                'If use_pq_linear is set to False, p_size + q_size must equal to motor_neurons')

        self._cnn = self._build_cnn(conv_filters, dropout=dropout)
        wiring = kncp.wirings.NCP(
            inter_neurons=inter_neurons,
            command_neurons=command_neurons,
            motor_neurons=motor_neurons,
            sensory_fanout=sensory_fanout,
            inter_fanout=inter_fanout,
            recurrent_command_synapses=recurrent_command_synapses,
            motor_fanin=motor_fanin,
        )

        self._rnn_cell = LTCCell(wiring, conv_filters[-1][1], ode_unfolds=3)
        if use_pq_linear:
            # ``np.int`` was removed in NumPy 1.24; the hidden width is simply 16.
            lnn_size = 32 // 2
            self._linear_q1 = nn.Linear(motor_neurons, lnn_size)
            self._linear_p1 = nn.Linear(motor_neurons, lnn_size)
            self._linear_q2 = nn.Linear(lnn_size, lnn_size)
            self._linear_p2 = nn.Linear(lnn_size, lnn_size)
            self._linear_q = nn.Linear(lnn_size, q_size)
            self._linear_p = nn.Linear(lnn_size, p_size)

        self._use_pq_linear = use_pq_linear
        self._p_size = p_size
        self._q_size = q_size

    def forward(self, x, state, obstacle, sgn, state_ltc=None):
        """Run the network and solve the BarrierNet QP.

        Args:
            x: image batch indexed as (batch, seq, channel, H, W).
            state: vehicle state, presumably (batch, seq, >=6) -- TODO confirm.
                After step selection, columns 0-5 are read as
                s, d, mu, v, delta, kappa.
            obstacle: obstacle state laid out like ``state``; columns 0-1 are
                read as the obstacle's s and d.
            sgn: mode flag. 0 or 1 keeps only the last time step of every
                sequence; any other value solves one QP per (batch, step).
                The batched qpth solver is used while training or when sgn is
                1 or 3; otherwise the external ``solver`` is called on the
                first sample only.
            state_ltc: one-element list holding the LTC hidden state, as
                returned by ``get_initial_state``. None starts from zeros.

        Returns:
            ``(x, state_ltc)``: the QP solution (a batch from qpth, or whatever
            ``solver`` returns in the single-sample path) and the updated LTC
            state list.
        """
        nBatch = x.size(0)
        sequence_size = x.size(1)
        channel, H, W = x.size(2), x.size(3), x.size(4)

        # CNN feature extractor, applied to every frame of every sequence.
        x = x.view(nBatch * sequence_size, channel, H, W)
        z = self._cnn(x)
        z = z.max(dim=-1)[0].max(dim=-1)[0]  # max-pool over the spatial dims H', W'
        z = z.view(nBatch, sequence_size, -1)

        # LTC: step-by-step inference with the RNN cell, keeping every step's output.
        if state_ltc is None:
            # The documented default used to crash on ``None[0]``; start from zeros.
            state_ltc = self.get_initial_state(nBatch)
        hidden = state_ltc[0]  # the LTC cell has a hidden state only, no cell state
        out_ltc = []
        for t in range(sequence_size):
            out_ltc_t, hidden = self._rnn_cell(z[:, t], hidden)
            out_ltc.append(out_ltc_t)
        z = torch.stack(out_ltc, dim=1)
        state_ltc = [hidden]  # back to a list, same format as get_initial_state()

        # Output heads: linear QP cost q and HOCBF gains p.
        z = z.reshape(nBatch * sequence_size, -1)
        if self._use_pq_linear:
            q = torch.tanh(self._linear_q1(z))
            q = torch.tanh(self._linear_q2(q))
            q = self._linear_q(q)

            p = torch.tanh(self._linear_p1(z))
            p = torch.tanh(self._linear_p2(p))
            p = self._linear_p(p)
        else:
            p = z[:, :self._p_size]
            q = z[:, self._p_size:]
        p = 4 * torch.sigmoid(p)  # gains squashed into (0, 4)

        if sgn == 0 or sgn == 1:
            # Only keep the data from the last time step.
            q = q.view(nBatch, sequence_size, -1)[:, -1]
            p = p.view(nBatch, sequence_size, -1)[:, -1]
            state = state[:, -1]
            obstacle = obstacle[:, -1]
        else:
            # One QP per (batch, step) pair.
            nBatch = nBatch * sequence_size
            state = state.view(nBatch, -1)
            obstacle = obstacle.view(nBatch, -1)

        # QP cost: identity over the two decision variables.
        Q = torch.eye(2).unsqueeze(0).expand(nBatch, 2, 2).to(self.device)

        # Softened HOCBF -- vehicle state.
        s = state[:, 0]
        d = state[:, 1]
        mu = state[:, 2]
        v = state[:, 3]
        delta = state[:, 4]
        kappa = state[:, 5]

        lrf, lr = 0.5, 2  # lrf = lr/(lr+lf)
        beta = torch.atan(lrf * torch.tan(delta))
        cos_mu_beta = torch.cos(mu + beta)
        sin_mu_beta = torch.sin(mu + beta)
        one_m_dk = 1 - d * kappa  # denominator shared by the s-direction terms
        mu_dot = v / lr * torch.sin(beta) - kappa * v * cos_mu_beta / one_m_dk

        # Obstacle state, expressed as offsets from the vehicle.
        obs_s = obstacle[:, 0]
        obs_d = obstacle[:, 1]
        ds = s - obs_s
        dd = d - obs_d

        # Radius of the obstacle-covering disk is 7.9 < 8 m (mpc), avoiding the set boundary.
        barrier = ds**2 + dd**2 - 7.9**2
        barrier_dot = 2 * ds * v * cos_mu_beta / one_m_dk + 2 * dd * v * sin_mu_beta
        Lf2b = (2 * (v * cos_mu_beta / one_m_dk)**2 + 2 * (v * sin_mu_beta)**2
                - 2 * ds * v * sin_mu_beta * mu_dot / one_m_dk
                + 2 * ds * kappa * v**2 * sin_mu_beta * cos_mu_beta / one_m_dk**2
                + 2 * dd * v * cos_mu_beta * mu_dot)
        # Input gains of the second derivative. u1 enters like a change in v, and
        # u2 through d(beta)/d(delta); presumably u1 = dv/dt, u2 = d(delta)/dt (confirm).
        LgLfbu1 = 2 * ds * cos_mu_beta / one_m_dk + 2 * dd * sin_mu_beta
        LgLfbu2 = ((-2 * ds * v * sin_mu_beta / one_m_dk + 2 * dd * v * cos_mu_beta)
                   * lrf / torch.cos(delta)**2 / (1 + (lrf * torch.tan(delta))**2))

        # Constraint:  -LgLfb . u <= Lf2b + (p0 + p1) * b_dot + p0 * p1 * b
        G = torch.cat([-LgLfbu1.reshape(nBatch, 1), -LgLfbu2.reshape(nBatch, 1)], dim=1)
        G = G.reshape(nBatch, 1, 2).to(self.device)
        h = torch.reshape(
            Lf2b + (p[:, 0] + p[:, 1]) * barrier_dot + p[:, 0] * p[:, 1] * barrier,
            (nBatch, 1)).to(self.device)

        e = torch.Tensor().to(self.device)  # no equality constraints

        if self.training or sgn == 1 or sgn == 3:
            # verbose = 0 to check accuracy
            x = QPFunction(verbose=-1, solver=QPSolvers.PDIPM_BATCHED)(Q, q, G, h, e, e)
        else:
            self.p1 = p[0, 0]
            self.p2 = p[0, 1]
            x = solver(Q[0].double(), q[0].double(), G[0].double(), h[0].double())

        return x, state_ltc

    def get_initial_state(self, bsize):
        """Return a one-element list with a zero LTC hidden state of shape (bsize, state_size)."""
        device = next(self._rnn_cell.parameters()).device
        return [
            torch.zeros(
                (
                    bsize,  # batch size
                    self._rnn_cell.state_size,  # hidden state size
                ),
                dtype=torch.float32).to(device),
        ]

    def _build_cnn(self, filters, dropout=0., no_act_last_layer=False):
        """Build ``nn.Sequential`` of Conv2d(*filt) [+ BatchNorm2d + ReLU (+ Dropout)] blocks.

        The BatchNorm/ReLU/Dropout tail is omitted after the last conv only when
        ``no_act_last_layer`` is True.
        """
        modules = nn.ModuleList()
        for i, filt in enumerate(filters):
            modules.append(nn.Conv2d(*filt))
            if (i != len(filters) - 1) or (not no_act_last_layer):
                modules.append(nn.BatchNorm2d(filt[1]))  # filt[1] = out_channels
                modules.append(nn.ReLU())
                if dropout > 0:
                    modules.append(nn.Dropout(p=dropout))
        return nn.Sequential(*modules)


class CNN_LSTM_BN(nn.Module):
    """CNN + LSTM backbone with a differentiable QP head enforcing a softened CLF.

    Max-pooled CNN features of every frame are concatenated with
    ``state[..., 3:5]`` and fed to an LSTM. Three tanh-MLP heads produce the
    linear QP cost ``q``, the gains ``p`` and the CLF weights ``w``:
    ``p[:, 0]`` weights the CLF relaxation variable in the cost and
    ``p[:, 1]`` is the rate in the CLF condition. The QP has five variables:
    four controls boxed to [-1, 1] plus the relaxation variable.
    """

    def __init__(self,
                 conv_filters,
                 lstm_size,
                 q_size=4,
                 p_size=2,
                 w_size=4,
                 dropout=0.,
                 dev='cpu'):
        """Build the CNN, the LSTM and the q / p / w heads.

        Args:
            conv_filters: sequence of ``nn.Conv2d`` positional-argument tuples;
                the out-channels of the last entry (+2 state features) is the
                LSTM input size.
            lstm_size: LSTM hidden size; the heads use ``int(lstm_size / 2)``
                hidden units.
            q_size: size of the linear cost on the controls (a zero cost is
                appended for the relaxation variable).
            p_size: number of gains; ``forward`` reads ``p[:, 0]`` and ``p[:, 1]``.
            w_size: number of CLF weights; ``forward`` reads ``w[:, 0..3]``.
            dropout: dropout probability after each conv block (0 disables).
            dev: device the QP tensors are moved to.
        """
        super(CNN_LSTM_BN, self).__init__()

        self.device = dev
        # Gains of the first sample, recorded by forward() in the external-solver path.
        self.p1 = 0
        self.p2 = 0

        # ``np.int`` was removed in NumPy 1.24; the builtin gives the same value.
        lnn_size = int(lstm_size / 2)
        self._cnn = self._build_cnn(conv_filters, dropout=dropout)
        self._lstm = nn.LSTM(conv_filters[-1][1] + 2, lstm_size, batch_first=True)  # input_size, hidden_size
        self._linear_q1 = nn.Linear(lstm_size, lnn_size)
        self._linear_p1 = nn.Linear(lstm_size, lnn_size)
        self._linear_q2 = nn.Linear(lnn_size, lnn_size)
        self._linear_p2 = nn.Linear(lnn_size, lnn_size)
        self._linear_q = nn.Linear(lnn_size, q_size)
        self._linear_p = nn.Linear(lnn_size, p_size)

        self._linear_w1 = nn.Linear(lstm_size, lnn_size)
        self._linear_w2 = nn.Linear(lnn_size, lnn_size)
        self._linear_w = nn.Linear(lnn_size, w_size)

    @staticmethod
    def _tanh_mlp(z, first, second, last):
        """Apply ``last(tanh(second(tanh(first(z)))))``."""
        return last(torch.tanh(second(torch.tanh(first(z)))))

    @staticmethod
    def _clf_cost_matrix(p):
        """Return the batched QP cost ``diag(1, 1, 1, 1, p[:, 0])`` of shape (N, 5, 5).

        Built out of place. The old code wrote into an ``expand``-ed identity,
        which PyTorch rejects for N > 1 when no device copy happens.
        """
        ones = p.new_ones(p.size(0), 4)
        return torch.diag_embed(torch.cat([ones, p[:, :1]], dim=1))

    @staticmethod
    def _append_box_constraints(G, h, bound=1.0):
        """Append ``u_i <= bound`` and ``-u_i <= bound`` rows for QP variables 0..3.

        Args:
            G: (N, m, n_var) inequality matrix.
            h: (N, m) right-hand side.
            bound: box half-width (the original hard-coded 1).

        Returns:
            ``(G, h)`` extended by 8 rows, ordered +u0, -u0, +u1, -u1, ..., -u3.
        """
        n_batch, _, n_var = G.shape
        box = G.new_zeros(8, n_var)
        for i in range(4):
            box[2 * i, i] = 1.0
            box[2 * i + 1, i] = -1.0
        G_all = torch.cat([G, box.unsqueeze(0).expand(n_batch, -1, -1)], dim=1)
        h_all = torch.cat([h, h.new_full((n_batch, 8), bound)], dim=1)
        return G_all, h_all

    def forward(self, x, state, coe, Tt_d, sgn, state_lstm=None):
        """Run the network and solve the CLF QP.

        Args:
            x: image batch indexed as (batch, seq, channel, H, W).
            state: per-step state, indexed as (batch, seq, dim).
                ``state[..., 3:5]`` feeds the LSTM. After step selection,
                columns 0-3 are read as Tp, Tr, Ty, Tt.
            coe: per-step input coefficients; see the NOTE(review) below on its layout.
            Tt_d: desired Tt. Must hold exactly one value per QP sample after
                step selection.
            sgn: mode flag. 0 or 1 keeps only the last time step; any other
                value solves one QP per (batch, step). qpth is used while
                training or when sgn is 1, 2, 3, 20 or 200; otherwise the
                external ``solver`` runs on the first sample only.
            state_lstm: ``(h0, c0)`` as from ``get_initial_state``, or None.

        Returns:
            ``(x, state_lstm)``: the QP solution and the updated LSTM state.
        """
        nBatch = x.size(0)
        sequence_size = x.size(1)
        channel, H, W = x.size(2), x.size(3), x.size(4)

        # CNN feature extractor
        x = x.view(nBatch * sequence_size, channel, H, W)
        z = self._cnn(x)
        z = z.max(dim=-1)[0].max(dim=-1)[0]  # max-pool over spatial dims H', W'
        z = z.view(nBatch, sequence_size, -1)

        # LSTM over [image features, state[..., 3:5]]
        z = torch.cat([z, state[:, :, 3:5]], dim=2)
        z, state_lstm = self._lstm(z, state_lstm)

        # Output heads
        z = z.reshape(nBatch * sequence_size, -1)
        q = self._tanh_mlp(z, self._linear_q1, self._linear_q2, self._linear_q)
        q = 2 * (0.5 - torch.sigmoid(q))  # in (-1, 1)

        p = self._tanh_mlp(z, self._linear_p1, self._linear_p2, self._linear_p)
        # Scale out of place (the old in-place column writes mutated an autograd-tracked
        # tensor): p[:, 0] -> (0, 10000), p[:, 1] -> (0, 10), extra columns unchanged.
        p = torch.cat([10000 * torch.sigmoid(p[:, :1]),
                       10 * torch.sigmoid(p[:, 1:2]),
                       p[:, 2:]], dim=1)

        w = self._tanh_mlp(z, self._linear_w1, self._linear_w2, self._linear_w)
        w = 10 * torch.sigmoid(w)  # CLF weights in (0, 10)

        if sgn == 0 or sgn == 1:
            # Only keep the data from the last time step.
            q = q.view(nBatch, sequence_size, -1)[:, -1]
            p = p.view(nBatch, sequence_size, -1)[:, -1]
            w = w.view(nBatch, sequence_size, -1)[:, -1]
            state = state[:, -1]
            coe = coe[:, -1]
            Tt_d = Tt_d[:, -1]
        else:
            # One QP per (batch, step) pair.
            nBatch = nBatch * sequence_size
            state = state.view(nBatch, -1)
            coe = coe.view(nBatch, -1)
            Tt_d = Tt_d.view(nBatch, -1)
        # One desired Tt per sample. A (N, 1) Tt_d would otherwise broadcast against
        # the (N,) Tt into an (N, N) V and break the reshape of h below.
        Tt_d = Tt_d.reshape(nBatch)

        # QP cost: identity on the four controls, p[:, 0] on the relaxation variable.
        Q = self._clf_cost_matrix(p).to(self.device)

        # Softened CLF -- state
        Tp = state[:, 0]
        Tr = state[:, 1]
        Ty = state[:, 2]
        Tt = state[:, 3]

        Cp = coe[:, 0]
        Cr = coe[:, 1]
        Cy = coe[:, 2]
        Ct = coe[:, 3]

        V = w[:, 0] * Tp**2 + w[:, 1] * Tr**2 + w[:, 2] * Ty**2 + w[:, 3] * (Tt - Tt_d)**2
        LfV = 0
        # NOTE(review): ``(...)[:, k]`` requires Cp, Cr and Cy to broadcast to a 2-D
        # (sample, control) layout, while Ct is used as one value per sample. That does
        # not match ``coe.view(nBatch, -1)`` above. Confirm coe's layout with the caller.
        LgVup = (2 * w[:, 0] * Tp * Cp)[:, 0] + (2 * w[:, 1] * Tr * Cr)[:, 0] + (2 * w[:, 2] * Ty * Cy)[:, 0]
        LgVur = (2 * w[:, 0] * Tp * Cp)[:, 1] + (2 * w[:, 1] * Tr * Cr)[:, 1] + (2 * w[:, 2] * Ty * Cy)[:, 1]
        LgVuy = (2 * w[:, 0] * Tp * Cp)[:, 2] + (2 * w[:, 1] * Tr * Cr)[:, 2] + (2 * w[:, 2] * Ty * Cy)[:, 2]
        # dV/dTt = 2 * w3 * (Tt - Tt_d), matching the (Tt - Tt_d)**2 term of V.
        LgVut = 2 * w[:, 3] * (Tt - Tt_d) * Ct

        LgVup = torch.reshape(LgVup, (nBatch, 1))
        LgVur = torch.reshape(LgVur, (nBatch, 1))
        LgVuy = torch.reshape(LgVuy, (nBatch, 1))
        LgVut = torch.reshape(LgVut, (nBatch, 1))
        LgVud = -torch.ones_like(LgVut)  # coefficient of the relaxation variable

        # CLF row:  LgV . u - delta <= -LfV - p[:, 1] * V
        G = torch.cat([LgVup, LgVur, LgVuy, LgVut, LgVud], dim=1)
        G = torch.reshape(G, (nBatch, 1, 5)).to(self.device)
        h = torch.reshape(-LfV - p[:, 1] * V, (nBatch, 1)).to(self.device)

        # Box constraints |u_i| <= 1 on the four controls.
        G, h = self._append_box_constraints(G, h)
        G = G.to(self.device)
        h = h.to(self.device)

        e = torch.Tensor().to(self.device)  # no equality constraints

        # Zero linear cost on the relaxation variable.
        q = torch.cat([q, q.new_zeros(q.size(0), 1)], dim=1)

        if self.training or sgn in (1, 2, 3, 20, 200):
            x = QPFunction(verbose=-1, solver=QPSolvers.PDIPM_BATCHED)(Q, q, G, h, e, e)
        else:
            self.p1 = p[0, 0]
            self.p2 = p[0, 1]
            x = solver(Q[0].double(), q[0].double(), G[0].double(), h[0].double())
            # NOTE(review): only the first two of the five QP variables are kept here,
            # while the batched branch returns all five. Confirm this is intended.
            x = np.array([[x[0], x[1]]])
            x = torch.tensor(x).float().to(self.device)

        return x, state_lstm

    def get_initial_state(self, bsize):
        """Return zero ``[h0, c0]``, each of shape (1, bsize, hidden_size)."""
        device = next(self._lstm.parameters()).device
        return [
            torch.zeros((
                1,  # num of layers
                bsize,  # batch size
                self._lstm.hidden_size,  # hidden state size
            ), dtype=torch.float32).to(device),
            torch.zeros((
                1,
                bsize,
                self._lstm.hidden_size,  # cell state size
            ), dtype=torch.float32).to(device),
        ]

    def _build_cnn(self, filters, dropout=0., no_act_last_layer=False):
        """Build ``nn.Sequential`` of Conv2d(*filt) [+ BatchNorm2d + ReLU (+ Dropout)] blocks.

        The BatchNorm/ReLU/Dropout tail is omitted after the last conv only when
        ``no_act_last_layer`` is True.
        """
        modules = nn.ModuleList()
        for i, filt in enumerate(filters):
            modules.append(nn.Conv2d(*filt))
            if (i != len(filters) - 1) or (not no_act_last_layer):
                modules.append(nn.BatchNorm2d(filt[1]))  # filt[1] = out_channels
                modules.append(nn.ReLU())
                if dropout > 0:
                    modules.append(nn.Dropout(p=dropout))
        return nn.Sequential(*modules)



        
        
        
