# -*- coding: utf-8 -*-
"""
Created on Mon Oct 25 13:46:22 2021
"""

import time
from timeit import default_timer as timer
import gurobipy as gp
from gurobipy import GRB
import numpy as np

# NEW SUCCINCT MVNN MIP


class GUROBI_MIP2_MVNN_LINEAR:
    """Gurobi MIP that maximizes a linear MVNN over feasible course schedules.

    The MIP has one binary variable per course (``self.y_variables[0]``).
    The objective is the linear function obtained from the first row of
    ``model.layers[0]`` scaled by the first output weight, plus the output
    bias. The constraints are:

    * a credit-unit cap;
    * "at most one course per timeslot" constraints;
    * optionally, a budget constraint (``add_budget_constraint``);
    * optionally, "no-good" cuts that exclude already queried bundles
      (``add_forbidden_bundle``).
    """

    def __init__(self,
                model):
        """Store the MVNN whose (linear) output is maximized.

        Args:
            model: MVNN torch model. It must expose ``output_layer`` (with
                ``weight`` and optional ``bias``) and ``layers[0].weight``.
        """
        # MVNN PARAMETERS
        self.model = model  # MVNN TORCH MODELS
        # y_variables[0]: binary course variables; y_variables[1]: output vars.
        self.y_variables = []

    def generate_mip(self,
                     course_timetable,
                     credit_units,
                     cu_max = 5,
                     timeLimit = None,
                     MIPGap = None,
                     verbose = False,
                     ):
        """(Re)build the Gurobi model into ``self.mip``.

        Args:
            course_timetable: iterable of days. Each day is an iterable of
                timeslots, and each timeslot is an iterable of course indices;
                at most one course per timeslot may be chosen.
            credit_units: per-course credit units. Its length defines the
                number of course variables.
            cu_max: upper bound on the total credit units of a schedule.
            timeLimit: value for Gurobi's ``TimeLimit`` parameter; ``None``
                keeps Gurobi's default.
            MIPGap: value for Gurobi's ``MIPGap`` parameter; ``None`` keeps
                Gurobi's default.
            verbose: if True, write the model to a timestamped ``.lp`` file.
        """
        self.mip = gp.Model("LINEAR MVNN MIP2")
        # Start from a clean list so that rebuilding the MIP never leaves
        # variables of a previously built (discarded) model at index 0.
        self.y_variables = []

        # Compare against None so that legitimate zero values (e.g. MIPGap=0
        # to require proven optimality) are not silently ignored.
        if timeLimit is not None:
            self.mip.Params.timeLimit = timeLimit
        if MIPGap is not None:
            self.mip.Params.MIPGap = MIPGap

        # --- Variable declaration -----
        n_courses = len(credit_units)
        # the "input variables", i.e. the first y level: one binary per course
        course_vars = self.mip.addVars(list(range(n_courses)), name="x_", vtype=GRB.BINARY)
        self.y_variables.append(course_vars)

        layer = self.model.output_layer
        # One continuous variable per output neuron. They are not linked to the
        # course variables by any constraint and are not part of the final
        # objective; they are kept so that y_variables keeps its layout.
        self.y_variables.append(self.mip.addVars(list(range(len(layer.weight.data))), name='y_output_', vtype=GRB.CONTINUOUS, lb=0))

        # Convert torch scalars to plain floats so Gurobi only ever sees
        # numeric coefficients/constants.
        output_weight_0 = float(layer.weight.data[0][0])
        if layer.bias is not None:
            output_bias = float(layer.bias.data[0])
        else:
            output_bias = 0.0

        # add the credit units constraint:
        self.mip.addConstr(gp.quicksum(course_vars[i] * credit_units[i] for i in range(n_courses)) <= cu_max, name='credit_units')

        # Overlapping courses constraint: for any timeslot of any day, the
        # student can only take one of the courses lecturing in that timeslot.
        for day in course_timetable:
            for timeslot in day:
                self.mip.addConstr(gp.quicksum(course_vars[i] for i in timeslot) <= 1, name='overlaps')

        # --- Objective Declaration ---
        # NOTE(review): only row 0 of layers[0] is used (assumes a single
        # hidden neuron — TODO confirm), and layers[0].bias, if any, is not
        # added. A constant does not change the maximizer, but it does shift
        # the objective value returned by solve_mip_rv.
        first_layer_row = self.model.layers[0].weight.data[0]
        coefficients = [float(first_layer_row[i]) * output_weight_0 for i in range(n_courses)]
        self.mip.setObjective(gp.quicksum(coef * course_vars[i] for i, coef in enumerate(coefficients)) + output_bias, GRB.MAXIMIZE)

        if (verbose):
            self.mip.write('LINEAR MVNN_mip2_'+'_'.join(time.ctime().replace(':', '-').split(' '))+'.lp')

    def add_budget_constraint(self, course_prices, budget):
        """Replace any existing 'budget' constraint with sum(price_i * x_i) <= budget.

        Args:
            course_prices: per-course prices, indexed like the course variables.
            budget: upper bound on the total price of the schedule.
        """
        # Flush pending modifications first: under Gurobi's lazy updates a
        # constraint added by a previous call is otherwise not yet visible
        # to getConstrByName.
        self.mip.update()
        previous_budget = self.mip.getConstrByName('budget')
        if previous_budget is not None:  # None when no budget exists yet
            self.mip.remove(previous_budget)
            self.mip.update()

        course_vars = self.y_variables[0]
        self.mip.addConstr(gp.quicksum(course_vars[i] * course_prices[i] for i in range(len(course_prices))) <= budget, name = 'budget')
        return

    def add_forbidden_bundle(self, bundle):
        """Add a no-good cut that excludes exactly the given 0/1 ``bundle``.

        Entries equal to 1 count ``x_m``; every other entry counts ``1 - x_m``.
        """
        course_vars = self.y_variables[0]
        expr = gp.quicksum(course_vars[m] if bit == 1 else (1 - course_vars[m]) for m, bit in enumerate(bundle))
        # For binary x the expression equals len(bundle) only at the forbidden
        # bundle itself, so "<= len(bundle) - 1" cuts off exactly that point.
        # The integral right-hand side is tighter in the LP relaxation than a
        # fractional one like len(bundle) - 0.1.
        self.mip.addConstr(expr <= len(bundle) - 1, name='alreadyQueried')
        self.mip.update()
        return

    def solve_mip(self,
                  outputFlag = False,
                  verbose = True
                  ):
        """Optimize the MIP and return the 0/1 course schedule as a list.

        Args:
            outputFlag: value for Gurobi's ``OutputFlag`` (solver log).
            verbose: if True, print a summary via ``_print_info``.

        Raises:
            ValueError: if Gurobi found no feasible solution.
        """
        self.start = timer()
        self.mip.Params.OutputFlag = outputFlag
        self.mip.optimize()
        self.end = timer()

        self.optimal_schedule = []
        # Check explicitly for an incumbent instead of relying on the
        # attribute error raised when reading .x without a solution.
        if self.mip.SolCount == 0:
            self._print_info()
            raise ValueError('MIP did not solve succesfully!')

        course_vars = self.y_variables[0]
        # Threshold instead of == 1 to absorb integrality tolerance.
        self.optimal_schedule = [1 if course_vars[i].x >= 0.99 else 0 for i in range(len(course_vars))]

        if verbose:
            self._print_info()

        return self.optimal_schedule

    def solve_mip_rv(self,
                  outputFlag = False,
                  verbose = True
                     ):
        """Like ``solve_mip`` but return ``(schedule, objective_value)``."""
        schedule = self.solve_mip(outputFlag=outputFlag, verbose=verbose)
        return schedule, self.mip.getObjective().getValue()

    def _print_info(self):
        """Print model settings and, if available, solution statistics."""
        # ObjVal (and possibly MIPGap) cannot be queried without an
        # incumbent; guard them so this also works on the failure path.
        has_solution = self.mip.SolCount > 0
        print('\n')
        print(*['*']*30)
        print()
        print('MIP INFO:')
        print(*['-']*30)
        print(f'Name: {self.mip.ModelName}')
        print(f'Goal: {self._model_sense_converter(self.mip.ModelSense)}')
        print(f'Objective: {self.mip.getObjective()}')
        print(f'Number of variables: {self.mip.NumVars}')
        print(f' - Binary {self.mip.NumBinVars}')
        print(f'Number of linear constraints: {self.mip.NumConstrs}')
        print(f'Primal feasibility tolerance for constraints: {self.mip.Params.FeasibilityTol}')
        print(f'Integer feasibility tolerance: {self.mip.Params.IntFeasTol}')
        print(f'Relative optimality gap: {self.mip.Params.MIPGap}')
        print(f'Time Limit: {self.mip.Params.TimeLimit}')
        print('')
        print('MIP SOLUTION:')
        print(*['-']*30)
        print(f'Status: {self._status_converter(self.mip.status)}')
        print(f'Elapsed in sec: {self.end - self.start}')
        print(f'Reached Relative optimality gap: {self.mip.MIPGap if has_solution else "n/a"}')
        print(f'Optimal Allocation: {self.optimal_schedule}')
        print(f'Objective Value: {self.mip.ObjVal if has_solution else "n/a"}')
        print(f'Number of stored solutions: {self.mip.SolCount}')
        print('IA Case Statistics:')

    def _status_converter(self, int_status):
        """Map a Gurobi optimization status code to its name.

        Returns ``'UNKNOWN(<code>)'`` for codes not in the table, e.g. codes
        added by newer Gurobi versions.
        """
        status_table = ['woopsies!', 'LOADED', 'OPTIMAL', 'INFEASIBLE', 'INF_OR_UNBD', 'UNBOUNDED', 'CUTOFF', 'ITERATION_LIMIT', 'NODE_LIMIT', 'TIME_LIMIT', 'SOLUTION_LIMIT', 'INTERRUPTED', 'NUMERIC', 'SUBOPTIMAL', 'INPROGRESS', 'USER_OBJ_LIMIT', 'WORK_LIMIT', 'MEM_LIMIT']
        if 0 <= int_status < len(status_table):
            return status_table[int_status]
        return f'UNKNOWN({int_status})'

    def _model_sense_converter(self, int_sense):
        """Map Gurobi's ModelSense (1 / -1) to 'Minimize' / 'Maximize'."""
        if int_sense == 1:
            return 'Minimize'
        elif int_sense == -1:
            return 'Maximize'
        else:
            raise ValueError('int_sense needs to be -1:maximize or 1: minimize')
