import json
import torch.nn as nn
from torch.autograd import Variable
import torch
from .base import LitModel as Base
from .utils import build_cnn, build_mlp, cvx_solver
from qpth.qp import QPFunction, QPSolvers
import numpy as np


class LitModel(Base):
    """CNN + LSTM driving policy followed by a BarrierNet (HOCBF-QP) layer.

    Per time step the network predicts a 2-vector ``v_delta`` and two positive
    gains ``p`` in ``(0, 4)``.  A quadratic program then returns the 2-D
    control closest to a reference command, subject to one (softened)
    second-order control-barrier-function constraint that keeps the vehicle
    outside a disk of radius 7.9 centred on the obstacle.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return ``parent_parser`` extended with the base model's arguments.

        This model registers no options of its own.
        """
        parent_parser = Base.add_model_specific_args(parent_parser)
        return parent_parser

    def custom_setup(self):
        """Build the CNN encoder, the LSTM and the two MLP heads (q and p)."""
        lstm_size = 64
        p_size = 2
        q_size = 2
        # Each row is passed positionally to nn.Conv2d:
        # (in_channels, out_channels, kernel_size, stride, padding).
        conv_filters = [[3, 24, 5, 2, 2], [24, 36, 5, 2, 2], [36, 48, 3, 2, 1],
                        [48, 64, 3, 1, 1], [64, 64, 3, 1, 1]]
        dropout = 0.3

        # `np.int` was removed in NumPy 1.24; integer division yields the same
        # value (32) independent of the installed NumPy version.
        lnn_size = lstm_size // 2
        self._cnn = self._build_cnn(conv_filters, dropout=dropout)
        # LSTM input size = number of channels produced by the last conv layer.
        self._lstm = nn.LSTM(conv_filters[-1][1], lstm_size,
                             batch_first=True)  # input_size, hidden_size
        self._linear_q1 = nn.Linear(lstm_size, lnn_size)
        self._linear_p1 = nn.Linear(lstm_size, lnn_size)
        self._linear_q2 = nn.Linear(lnn_size, lnn_size)
        self._linear_p2 = nn.Linear(lnn_size, lnn_size)
        self._linear_q = nn.Linear(lnn_size, q_size)
        self._linear_p = nn.Linear(lnn_size, p_size)

    def forward(self,
                batch,
                rnn_state,
                solver='QP',
                store_intermediate_data=False,
                **kwargs):
        """Run the policy over a sequence and apply the BarrierNet QP.

        Args:
            batch: indexable ``(img_seq, state_seq, obs_seq, ctrl_seq)``.
                ``img_seq`` must be 5-D ``(batch, seq, channel, H, W)`` (see
                the ``view`` below).  After flattening to rows, ``state_seq``
                columns 0..5 are read as ``s, d, mu, v, delta, kappa``,
                ``obs_seq`` columns 0..1 as the obstacle's ``s, d``, and
                ``ctrl_seq`` columns 2:4 serve as the QP reference on the
                batched path.
            rnn_state: ``(h, c)`` LSTM state, e.g. from ``get_initial_state``.
            solver: ``'QP'`` for the batched, differentiable qpth solver, or
                ``'cvxpy'`` to call ``cvx_solver`` on the last time step of
                the first sample only.
            store_intermediate_data: if True, store scalar diagnostics for the
                first row in ``self.intermediate_data``.
            **kwargs: ignored.

        Returns:
            ``(out, rnn_state)``: ``out`` is reshaped to ``(batch, seq, -1)``
            and holds ``[v_delta, qp_solution]`` per row; ``rnn_state`` is the
            updated LSTM state.
        """
        img_seq = batch[0]
        state_seq = batch[1]
        obs_seq = batch[2]
        ctrl_seq = batch[3]

        # NOTE(review): hard-coded mode selector.  With sgn = 4 the
        # `sgn == 0` / `sgn == 1` branches below are reachable only through
        # `solver == 'cvxpy'`.
        sgn = 4
        x = img_seq
        state_lstm = rnn_state
        state = state_seq
        obstacle = obs_seq
        y = ctrl_seq

        nBatch = x.size(0)
        sequence_size = x.size(1)
        channel = x.size(2)
        H = x.size(3)
        W = x.size(4)

        # CNN feature extractor applied to every frame of every sequence.
        x = x.view(nBatch * sequence_size, channel, H, W)
        z = self._cnn(x)
        # Global max-pool over the spatial dimensions H', W'.
        z = z.max(dim=-1)[0].max(dim=-1)[0]
        z = z.view(nBatch, sequence_size, -1)

        # LSTM over the time dimension.
        z, state_lstm = self._lstm(z, state_lstm)

        # Two MLP heads on the flattened (batch * seq) features.
        z = z.reshape(nBatch * sequence_size, -1)
        q = self._linear_q1(z)
        q = torch.tanh(q)
        q = self._linear_q2(q)
        q = torch.tanh(q)
        q = self._linear_q(q)

        p = self._linear_p1(z)
        p = torch.tanh(p)
        p = self._linear_p2(p)
        p = torch.tanh(p)
        p = self._linear_p(p)
        # Squash the barrier gains into (0, 4) so they stay positive.
        p = 4 * torch.sigmoid(p)

        if sgn == 0 or sgn == 1 or solver == 'cvxpy':
            # Keep only the data from the last time step of each sequence.
            q = q.view(nBatch, sequence_size, -1)[:, -1]
            p = p.view(nBatch, sequence_size, -1)[:, -1]
            state = state[:, -1]
            obstacle = obstacle[:, -1]
            y = y[:, -1]
        else:
            # Every (sequence, step) pair becomes an independent QP.
            nBatch = nBatch * sequence_size
            state = state.view(nBatch, -1)
            obstacle = obstacle.view(nBatch, -1)
            y = y.view(nBatch, -1)

        v_delta = q
        # Linear cost term of the QP.  With Q = I the unconstrained optimum
        # of 0.5 u^T Q u + q^T u is u = -q, i.e. the reference command.
        if sgn == 0 or solver == 'cvxpy':
            # Reference = finite difference of the predicted (v, delta)
            # against the current state over 0.1 (presumably a time step in
            # seconds -- confirm).
            q = -((q - state[:, 3:5]) / 0.1)  #.detach()
        else:
            # NOTE(review): the batched path uses ctrl_seq columns 2:4 as the
            # reference, so the QP solution does not depend on the q head.
            q = -y[:, 2:4]

        # Quadratic cost matrix: identity for every QP in the batch.
        Q = torch.eye(2, device=self.device).unsqueeze(0).expand(nBatch, 2, 2)

        # Softened HOCBF -- vehicle state columns.
        s = state[:, 0]
        d = state[:, 1]
        mu = state[:, 2]
        v = state[:, 3]
        delta = state[:, 4]
        kappa = state[:, 5]

        lrf, lr = 0.5, 2  # lrf = lr / (lr + lf)
        # Slip angle and heading-error rate (appears to be a kinematic
        # bicycle model in road-aligned coordinates -- confirm).
        beta = torch.atan(lrf * torch.tan(delta))
        cos_mu_beta = torch.cos(mu + beta)
        sin_mu_beta = torch.sin(mu + beta)
        mu_dot = v / lr * torch.sin(beta) - kappa * v * cos_mu_beta / (
            1 - d * kappa)

        # Obstacle position.
        obs_s = obstacle[:, 0]
        obs_d = obstacle[:, 1]

        # b >= 0 outside a disk of radius 7.9 around the obstacle
        # (7.9 < 8 m used by the MPC, avoiding the set boundary).
        barrier = (s - obs_s)**2 + (d - obs_d)**2 - 7.9**2
        barrier_dot = 2 * (s - obs_s) * v * cos_mu_beta / (
            1 - d * kappa) + 2 * (d - obs_d) * v * sin_mu_beta
        # Control-independent (drift) part of the second derivative of b.
        Lf2b = (2 * (v * cos_mu_beta / (1 - d * kappa))**2
                + 2 * (v * sin_mu_beta)**2
                - 2 * (s - obs_s) * v * sin_mu_beta * mu_dot / (1 - d * kappa)
                + 2 * (s - obs_s) * kappa * v**2 * sin_mu_beta * cos_mu_beta
                / (1 - d * kappa)**2
                + 2 * (d - obs_d) * v * cos_mu_beta * mu_dot)
        # Coefficients of the two controls in the second derivative of b:
        # u1 enters through d(barrier_dot)/dv, u2 through delta via
        # d(beta)/d(delta) = lrf / cos(delta)^2 / (1 + (lrf tan(delta))^2).
        LgLfbu1 = 2 * (s - obs_s) * cos_mu_beta / (1 - d * kappa) + 2 * (
            d - obs_d) * sin_mu_beta
        LgLfbu2 = ((-2 * (s - obs_s) * v * sin_mu_beta / (1 - d * kappa)
                    + 2 * (d - obs_d) * v * cos_mu_beta)
                   * lrf / torch.cos(delta)**2
                   / (1 + (lrf * torch.tan(delta))**2))

        LgLfbu1 = torch.reshape(LgLfbu1, (nBatch, 1))
        LgLfbu2 = torch.reshape(LgLfbu2, (nBatch, 1))

        # HOCBF condition
        #   Lf2b + LgLfb . u + (p0 + p1) * b_dot + p0 * p1 * b >= 0
        # rewritten in the solver's inequality form  G u <= h.
        G = torch.cat([-LgLfbu1, -LgLfbu2], dim=1)
        G = torch.reshape(G, (nBatch, 1, 2)).to(self.device)
        h = torch.reshape(
            Lf2b + (p[:, 0] + p[:, 1]) * barrier_dot
            + p[:, 0] * p[:, 1] * barrier,
            (nBatch, 1)).to(self.device)

        e = torch.Tensor().to(self.device)  # no equality constraints

        if sgn == 0 or solver == 'cvxpy':
            # NOTE(review): only the first sample is solved here, so the
            # torch.cat below requires nBatch == 1 on this path.
            self.p1 = p[0, 0]
            self.p2 = p[0, 1]
            x = cvx_solver(Q[0].double(), q[0].double(), G[0].double(),
                           h[0].double())
            x = np.array([[x[0], x[1]]])
            x = torch.tensor(x).float().to(self.device)
        else:
            x = QPFunction(verbose=-1, solver=QPSolvers.PDIPM_BATCHED)(Q, q, G,
                                                                       h, e, e)
        out = torch.cat([v_delta, x], dim=1)  # return v, delta, a, omega

        if store_intermediate_data:
            # Scalar diagnostics for the first row (index 0) only, so that
            # `.item()` is valid for any batch size.
            self.intermediate_data = {
                'q0': q[0, 0].item(),
                'q1': q[0, 1].item(),
                'p0': p[0, 0].item(),
                'p1': p[0, 1].item(),
                's': s[0].item(),
                'd': d[0].item(),
                'mu': mu[0].item(),
                'v': v[0].item(),
                'delta': delta[0].item(),
                'kappa': kappa[0].item(),
                'obs_s': obs_s[0].item(),
                'obs_d': obs_d[0].item(),
                'barrier': barrier[0].item(),
                'Gz': (G[0, 0] * x[0]).sum().item(),
                'h': h[0, 0].item(),
                # True when the QP solution differs from the unconstrained
                # optimum -q, i.e. the barrier constraint changed the control.
                'active': not torch.allclose(x[0], -q[0]),
            }

        # Restore the (batch, seq, features) layout of the input.
        nBatch, sequence_size = img_seq.shape[:2]
        out = out.reshape((nBatch, sequence_size, -1))
        rnn_state = state_lstm

        return out, rnn_state

    def get_initial_state(self, batch_size):
        """Return zero ``[h0, c0]`` LSTM states on the LSTM's device.

        Each tensor has shape ``(1, batch_size, hidden_size)``, i.e.
        (num_layers, batch, hidden) for the single-layer LSTM built in
        ``custom_setup``.
        """
        device = next(self._lstm.parameters()).device
        shape = (1, batch_size, self._lstm.hidden_size)
        return [
            torch.zeros(shape, dtype=torch.float32).to(device),  # hidden state
            torch.zeros(shape, dtype=torch.float32).to(device),  # cell state
        ]

    def _build_cnn(self, filters, dropout=0., no_act_last_layer=False):
        """Stack Conv2d layers described by ``filters`` into nn.Sequential.

        Args:
            filters: list of argument lists unpacked into ``nn.Conv2d``.
            dropout: if > 0, a Dropout layer follows each activation.
            no_act_last_layer: if True, the last conv gets no
                BatchNorm/ReLU/Dropout.
        """
        layers = []
        last = len(filters) - 1
        for i, filt in enumerate(filters):
            layers.append(nn.Conv2d(*filt))
            if i != last or not no_act_last_layer:
                layers.append(nn.BatchNorm2d(filt[1]))  # filt[1] = out_channels
                layers.append(nn.ReLU())
                if dropout > 0:
                    layers.append(nn.Dropout(p=dropout))
        return nn.Sequential(*layers)