import os
import torch
import time
import numpy as np
import scipy.stats as stats

import gurobipy as gp
from gurobipy import GRB

import torch.optim as optim

class PlannerMIP(object):
    """Plan an action sequence by solving a mixed-integer program with Gurobi.

    The wrapped ``model`` is read in ``trajectory_optimization`` and must
    expose:

    * ``model.model.parameters()`` -- torch parameters, consumed as
      consecutive (weight, bias) pairs, one pair per layer;
    * ``model.mask`` -- one tensor per ReLU layer whose first row labels each
      neuron: 0 = output forced to zero, 1 = regular ReLU, 2 = identity;
    * ``model.lb`` / ``model.ub`` -- indexed ``[layer][neuron]`` and used as
      the big-M constants of the regular ReLU neurons (presumably lower/upper
      bounds on the pre-activation -- confirm against the model code).
    """

    def __init__(self, model):
        self.model = model

    def trajectory_optimization(self,
                                state_cur,  # current state, [n_his, state_dim]
                                obs_goal,   # goal, [state_dim]
                                act_seq,    # initial action sequence, [-1, action_dim]
                                n_look_ahead,   # number of look ahead steps/horizon
                               ):
        """Solve for the actions that bring the final predicted state to the goal.

        The network is unrolled ``n_look_ahead`` times inside one Gurobi
        model; each step predicts a residual that is added to the most recent
        state. The objective is the squared distance between the last
        predicted state and ``obs_goal``. The solver time limit is 20 s.

        Args:
            state_cur: state history, ``[n_his, state_dim]``.
            obs_goal: goal state, ``[state_dim]``.
            act_seq: action sequence; only its first ``n_his - 1`` rows are
                used, as the fixed action history.
            n_look_ahead: number of future steps to plan.

        Returns:
            dict with keys
            ``'action_sequence'``  -- planned actions, ``[n_look_ahead, action_dim]``;
            ``'observation_sequence'`` -- predicted states, ``[n_look_ahead, state_dim]``;
            ``'reward'`` -- the optimal objective value (a cost that was minimized);
            ``'varSol'`` -- solution values of every named variable, by name.
        """
        state_goal = obs_goal
        N = n_look_ahead

        params = [p.data.cpu().numpy() for p in self.model.model.parameters()]
        mask = [p.data.cpu().numpy() for p in self.model.mask]
        n_relu = len(mask)

        lb = self.model.lb
        ub = self.model.ub

        n_his = state_cur.shape[0]
        state_dim = state_cur.shape[-1]
        action_dim = act_seq.shape[-1]

        act_his = act_seq[:n_his - 1]

        ### build up the MIP

        m = gp.Model("mip")
        m.Params.TimeLimit = 20

        # state variable: history followed by the N predicted states, flattened
        x = m.addMVar(shape=(n_his + N) * state_dim, vtype=GRB.CONTINUOUS,
                      lb=-GRB.INFINITY, ub=GRB.INFINITY, name='x')

        # action variable
        # a0 carries the box bounds of the free (future) actions; `a` is the
        # full flattened action sequence including the fixed history.
        # NOTE(review): a0 has one entry per look-ahead step while the slice
        # of `a` it is tied to has N * action_dim entries, so the shapes only
        # agree when action_dim == 1 -- confirm before using multi-dimensional
        # actions.
        a0 = m.addMVar(shape=(N), vtype=GRB.CONTINUOUS, lb=0.0, ub=1.0, name='a0')
        a = m.addMVar(shape=(n_his + N - 1) * action_dim, vtype=GRB.CONTINUOUS,
                      lb=-GRB.INFINITY, ub=GRB.INFINITY, name='a')

        varDict = {'x': x, 'a0': a0, 'a': a}

        m.addConstr(a[(n_his - 1) * action_dim::1] == a0)

        # binary variables, indexed as b[t][idx_layer]
        b = []

        # constrain the initial condition
        print('state_cur', state_cur)
        m.addConstr(x[:n_his * state_dim] == state_cur.reshape(-1))
        if n_his > 1:
            # with a single history frame there is no action history to pin
            m.addConstr(a[:(n_his - 1) * action_dim] == act_his.reshape(-1))

        for t in range(N):

            b.append([])

            # sliding window of n_his states and actions feeding step t
            x_cur = x[t * state_dim:(t + n_his) * state_dim]
            a_cur = a[t * action_dim:(t + n_his) * action_dim]

            # network input: states followed by actions
            name = 's_%d_%d' % (t, 0)
            s = m.addMVar(shape=x_cur.shape[0] + a_cur.shape[0], vtype=GRB.CONTINUOUS,
                          lb=-GRB.INFINITY, ub=GRB.INFINITY, name=name)
            varDict[name] = s

            m.addConstr(s[:x_cur.shape[0]] == x_cur)
            m.addConstr(s[x_cur.shape[0]:] == a_cur)

            # passing through the neural network layer by layer
            for idx_layer in range(n_relu):
                layer_mask = mask[idx_layer][0]
                mask_pos = layer_mask == 2
                mask_neg = layer_mask == 0
                mask_others = layer_mask == 1

                # activation before relu
                ss = params[idx_layer * 2] @ s + params[idx_layer * 2 + 1]
                n_neuron = ss.shape[0]

                name = 'y_%d_%d' % (t, idx_layer + 1)
                y = m.addMVar(shape=n_neuron, vtype=GRB.CONTINUOUS,
                              lb=-GRB.INFINITY, ub=GRB.INFINITY, name=name)
                varDict[name] = y

                m.addConstr(y == ss)

                # activation after relu and mask
                name = 's_%d_%d' % (t, idx_layer + 1)
                s = m.addMVar(shape=n_neuron, vtype=GRB.CONTINUOUS,
                              lb=-GRB.INFINITY, ub=GRB.INFINITY, name=name)
                varDict[name] = s

                # An empty selection would contribute no constraints, so it is
                # skipped instead of being handed to the solver API.
                # positive/identity neuron
                if mask_pos.any():
                    m.addConstr(s[mask_pos] == y[mask_pos])
                # negative neuron
                if mask_neg.any():
                    m.addConstr(s[mask_neg] == 0.)

                name = 'b_%d_%d' % (t, idx_layer + 1)
                b_layer = m.addMVar(shape=n_neuron, vtype=GRB.BINARY, name=name)
                b[t].append(b_layer)
                varDict[name] = b_layer

                # regular relu: big-M encoding of s = max(y, 0), where b = 1
                # selects the active branch (s == y) and b = 0 forces s == 0
                for idx_neuron in np.flatnonzero(mask_others).tolist():
                    lb_cur = lb[idx_layer][idx_neuron]
                    ub_cur = ub[idx_layer][idx_neuron]
                    b_cur = b_layer[idx_neuron]
                    m.addConstr(s[idx_neuron] <= y[idx_neuron] - lb_cur * (1 - b_cur))
                    m.addConstr(s[idx_neuron] >= y[idx_neuron])
                    m.addConstr(s[idx_neuron] <= ub_cur * b_cur)
                    m.addConstr(s[idx_neuron] >= 0.)

            # calculate the residual with the output layer, i.e. the
            # (weight, bias) pair right after the n_relu hidden layers
            # (indices 8 and 9 for a four-hidden-layer network)
            delta = params[2 * n_relu] @ s + params[2 * n_relu + 1]

            # calculate the next state: latest state plus predicted residual
            x_last = x[(t + n_his - 1) * state_dim:(t + n_his) * state_dim]
            x_next = x[(t + n_his) * state_dim:(t + n_his + 1) * state_dim]

            m.addConstr(delta + x_last == x_next)

        # objective: squared distance between the final state and the goal
        name = 'residual'
        residual = m.addMVar(shape=state_goal.shape[0], vtype=GRB.CONTINUOUS,
                             lb=-GRB.INFINITY, ub=GRB.INFINITY, name=name)
        varDict[name] = residual
        m.addConstr(residual == x[-state_dim:] - state_goal)
        m.setObjective(residual @ residual, GRB.MINIMIZE)

        m.optimize()

        # NOTE: if Gurobi stops without any feasible solution (time limit hit
        # first, or an infeasible model), the `.X` / `objVal` reads below fail.
        a_sol = a.X.reshape(n_his + N - 1, action_dim)
        action_seq_future = a_sol[n_his - 1:]

        x_sol = x.X.reshape(n_his + N, state_dim)
        obs_seq_best = x_sol[n_his:]

        reward_best = m.objVal

        varSol = {key: var.X for key, var in varDict.items()}

        return {'action_sequence': action_seq_future,   # [n_roll, action_dim]
                'observation_sequence': obs_seq_best,   # [n_roll, obs_dim]
                'reward': reward_best,
                'varSol': varSol}
