# -*- coding: utf-8 -*-
"""
Created on Mon Oct 25 13:46:22 2021
"""

import time
from timeit import default_timer as timer
import gurobipy as gp
from gurobipy import GRB
import numpy as np
from pdb import set_trace

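# Gurobi MIP that searches for the course schedule maximizing a polynomial
# (course + pairwise-interaction) regression model's prediction.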


class GUROBI_MIP_POLY_REGRESSION:
    """Gurobi MIP that selects the course schedule maximizing a regression model.

    The wrapped ``model`` must expose ``coef_``. Once flattened, it is laid out as
    ``[one coefficient per course, one coefficient per course pair (i, j), i < j]``,
    with pairs enumerated i ascending, then j ascending. Each product
    ``x_i * x_j`` in the objective is linearized with a binary variable ``y_ij``.
    """

    def __init__(self,
                 model):

        # Regression model; only its ``coef_`` attribute is read (in generate_mip).
        self.model = model
        self.y_variables = []  # binary variables for the pairwise interactions, one per pair i < j
        self.x_variables = []  # binary variables, one per course (1 = course selected)

    def generate_mip(self,
                     course_timetable,
                     credit_units,
                     course_prices,
                     budget,
                     cu_max=5,
                     timeLimit=None,
                     MIPGap=None,
                     verbose=False,
                     ):
        """Build the schedule-optimization MIP and store it in ``self.mip``.

        Args:
            course_timetable: iterable of days. Each day is an iterable of
                timeslots, and each timeslot is an iterable of course indices;
                at most one course per timeslot may be selected.
            credit_units: credit units per course; its length is the number
                of courses.
            course_prices: price per course (same indexing as credit_units).
            budget: upper bound on the total price of the selected courses.
            cu_max: upper bound on the total credit units of the selected courses.
            timeLimit: optional Gurobi ``TimeLimit`` parameter.
            MIPGap: optional Gurobi ``MIPGap`` parameter (0 is honored).
            verbose: if True, write the model to a timestamped ``.lp`` file.

        Raises:
            ValueError: if ``model.coef_`` does not contain exactly one
                coefficient per course followed by one per course pair.
        """
        n_courses = len(credit_units)
        n_pairs = n_courses * (n_courses - 1) // 2

        # Validate before touching any state so a bad model leaves self untouched.
        coefficients = self.model.coef_.reshape(-1)
        if len(coefficients) != n_courses + n_pairs:
            raise ValueError(
                f'model.coef_ has {len(coefficients)} entries, expected '
                f'{n_courses + n_pairs} ({n_courses} course coefficients followed by '
                f'{n_pairs} pairwise-interaction coefficients)')

        self.mip = gp.Model("POLY REGRESSION MIP")
        self.coefficients = coefficients
        interaction_coefficients = coefficients[n_courses:]

        # 0 is a meaningful value for these parameters, so compare with None
        # instead of relying on truthiness.
        if timeLimit is not None:
            self.mip.Params.timeLimit = timeLimit
        if MIPGap is not None:
            self.mip.Params.MIPGap = MIPGap

        # --- Variable declaration ---
        # One binary variable per course.
        self.x_variables = self.mip.addVars(list(range(n_courses)), name="x_", vtype=GRB.BINARY)

        # One binary variable per pair i < j. The list is rebuilt from scratch so a
        # repeated call does not keep variables belonging to a previous model.
        self.y_variables = []
        for i in range(n_courses):
            for j in range(i + 1, n_courses):
                self.y_variables.append(self.mip.addVar(name=f'y_{i}_{j}', vtype=GRB.BINARY))

        # --- Constraints declaration ---
        # Budget constraint (replaceable later via add_budget_constraint).
        self.mip.addConstr(
            gp.quicksum(self.x_variables[i] * course_prices[i] for i in range(n_courses)) <= budget,
            name='budget')

        # Credit-units constraint.
        self.mip.addConstr(
            gp.quicksum(self.x_variables[i] * credit_units[i] for i in range(n_courses)) <= cu_max,
            name='credit_units')

        # Overlapping courses: in every timeslot of every day at most one of the
        # courses lecturing in that slot may be taken. Names are made unique so
        # the written .lp file is unambiguous.
        for d, day in enumerate(course_timetable):
            for t, timeslot in enumerate(day):
                self.mip.addConstr(
                    gp.quicksum(self.x_variables[i] for i in timeslot) <= 1,
                    name=f'overlaps_{d}_{t}')

        # Linearize y_ij = x_i * x_j. Only the direction the maximizing objective
        # would exploit needs a constraint.
        pair_index = 0
        for i in range(n_courses):
            for j in range(i + 1, n_courses):
                coef = interaction_coefficients[pair_index]
                y = self.y_variables[pair_index]
                if coef > 0:
                    # The objective pushes y up: allow y = 1 only if both courses are taken.
                    self.mip.addConstr(
                        self.x_variables[i] + self.x_variables[j] >= 2 * y,
                        name=f'pairwise_interaction_positive_{i}_{j}')
                elif coef < 0:
                    # The objective pushes y down: force y = 1 if both courses are taken.
                    self.mip.addConstr(
                        self.x_variables[i] + self.x_variables[j] <= 1 + y,
                        name=f'pairwise_interaction_negative_{i}_{j}')
                # coef == 0: y has no effect on the objective and needs no link.
                pair_index += 1

        # --- Objective Declaration ---
        course_term = gp.quicksum(self.x_variables[i] * coefficients[i] for i in range(n_courses))
        interaction_term = gp.quicksum(y * c for y, c in zip(self.y_variables, interaction_coefficients))
        self.mip.setObjective(course_term + interaction_term, GRB.MAXIMIZE)

        if verbose:
            self.mip.write('POLY_REGR_mip2_' + '_'.join(time.ctime().replace(':', '-').split(' ')) + '.lp')

    def add_budget_constraint(self, course_prices, budget):
        """Replace the model's 'budget' constraint with a new one.

        Args:
            course_prices: price per course, indexed like ``self.x_variables``.
            budget: new upper bound on the total price of the selected courses.
        """
        # Flush pending modifications so that a 'budget' constraint added since
        # the last update can be found by name.
        self.mip.update()
        old_budget = self.mip.getConstrByName('budget')
        if old_budget is not None:
            self.mip.remove(old_budget)
            self.mip.update()

        self.mip.addConstr(
            gp.quicksum(self.x_variables[i] * course_prices[i] for i in range(len(course_prices))) <= budget,
            name='budget')
        return

    def _optimize_and_extract_schedule(self, outputFlag):
        """Run the solver, record timings, and set ``self.optimal_schedule``.

        The schedule is a list with one 0/1 entry per course.

        Raises:
            ValueError: if the solver finished without any feasible solution.
        """
        self.start = timer()
        self.mip.Params.OutputFlag = outputFlag
        self.mip.optimize()
        self.end = timer()

        self.optimal_schedule = []
        if self.mip.SolCount == 0:
            # Variable values are only available when a solution exists.
            self._print_info()
            raise ValueError('MIP did not solve succesfully!')

        # Binary values may come back slightly below 1 within integrality
        # tolerance, hence the 0.99 threshold.
        self.optimal_schedule = [
            1 if self.x_variables[i].x >= 0.99 else 0
            for i in range(len(self.x_variables))
        ]

    def solve_mip(self,
                  outputFlag=False,
                  verbose=True
                  ):
        """Solve the model and return the 0/1 schedule (one entry per course).

        Raises:
            ValueError: if no feasible solution was found.
        """
        self._optimize_and_extract_schedule(outputFlag)

        if verbose:
            self._print_info()

        return self.optimal_schedule

    def solve_mip_rv(self,
                     outputFlag=False,
                     verbose=True
                     ):
        """Solve the model and return ``(schedule, objective value)``.

        Raises:
            ValueError: if no feasible solution was found.
        """
        self._optimize_and_extract_schedule(outputFlag)

        if verbose:
            self._print_info()

        return self.optimal_schedule, self.mip.getObjective().getValue()

    def _print_info(self):
        """Print the model configuration and the latest solve results.

        Solution-dependent values are reported as 'n/a' when the solver found no
        solution, so this is safe to call on the failure path.
        """
        has_solution = self.mip.SolCount > 0
        print('\n')
        print(*['*'] * 30)
        print()
        print('MIP INFO:')
        print(*['-'] * 30)
        print(f'Name: {self.mip.ModelName}')
        print(f'Goal: {self._model_sense_converter(self.mip.ModelSense)}')
        print(f'Objective: {self.mip.getObjective()}')
        print(f'Number of variables: {self.mip.NumVars}')
        print(f' - Binary {self.mip.NumBinVars}')
        print(f'Number of linear constraints: {self.mip.NumConstrs}')
        print(f'Primal feasibility tolerance for constraints: {self.mip.Params.FeasibilityTol}')
        print(f'Integer feasibility tolerance: {self.mip.Params.IntFeasTol}')
        print(f'Relative optimality gap: {self.mip.Params.MIPGap}')
        print(f'Time Limit: {self.mip.Params.TimeLimit}')
        print('')
        print('MIP SOLUTION:')
        print(*['-'] * 30)
        print(f'Status: {self._status_converter(self.mip.status)}')
        print(f'Elapsed in sec: {self.end - self.start}')
        print(f'Reached Relative optimality gap: {self.mip.MIPGap if has_solution else "n/a"}')
        print(f'Optimal Allocation: {getattr(self, "optimal_schedule", None)}')
        print(f'Objective Value: {self.mip.ObjVal if has_solution else "n/a"}')
        print(f'Number of stored solutions: {self.mip.SolCount}')

    def _status_converter(self, int_status):
        """Map a Gurobi optimization status code to its name.

        Index 0 is not a Gurobi status. Codes outside the table map to
        'UNKNOWN_STATUS_<code>' instead of raising.
        """
        status_table = ['woopsies!', 'LOADED', 'OPTIMAL', 'INFEASIBLE', 'INF_OR_UNBD', 'UNBOUNDED',
                        'CUTOFF', 'ITERATION_LIMIT', 'NODE_LIMIT', 'TIME_LIMIT', 'SOLUTION_LIMIT',
                        'INTERRUPTED', 'NUMERIC', 'SUBOPTIMAL', 'INPROGRESS', 'USER_OBJ_LIMIT',
                        'WORK_LIMIT', 'MEM_LIMIT']
        if 0 <= int_status < len(status_table):
            return status_table[int_status]
        return f'UNKNOWN_STATUS_{int_status}'

    def _model_sense_converter(self, int_sense):
        """Map Gurobi's ModelSense (1 = minimize, -1 = maximize) to a label."""
        if int_sense == 1:
            return 'Minimize'
        elif int_sense == -1:
            return 'Maximize'
        else:
            raise ValueError('int_sense needs to be -1:maximize or 1: minimize')
