import os
import sys

from sklearn.svm import NuSVR
# from sklearn.svm import SVR


# Make this file's directory, its parent, and the sibling ../MIP, ../MVNN and
# ../MVNN/src directories importable. Paths are resolved relative to this
# file's location (via __file__), not the current working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'MIP'))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'MVNN'))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'MVNN', 'src'))


# import math


import gurobipy as gp
import numpy as np
from gurobipy import GRB
from timeit import default_timer as timer
import sys
import os


# from MIP.util_dataset import create_multiple_datasets_enhanced
#
# from MIP.cleanup import timetable_generator

# sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
#
# sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'MIP'))
# sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'MVNN'))
# sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'MVNN', 'src'))

################################


class gurobi_MIP_SVR:
    """Encode a fitted scikit-learn SVR / NuSVR as a Gurobi MIP over binary course bundles.

    The MIP has one binary variable ``a_j`` per course and maximizes the SVR kernel
    expansion ``sum_k dual_coef_k * K(a, sv_k)``. It is subject to budget,
    credit-unit and timetable-overlap constraints. ``solve_mip`` adds the model's
    intercept back onto the objective value.

    Supported kernels: ``'linear'``, ``'poly'`` (degree 1 or 2) and ``'rbf'``.
    """

    # Kernels for which scikit-learn actually uses ``coef0``; for every other
    # kernel the parameter is ignored by the fitted model.
    _COEF0_KERNELS = ('poly', 'sigmoid')

    def __init__(self, model, gamma):
        """
        :param model: fitted scikit-learn SVR-like regressor. This class reads its ``kernel``,
            ``support_vectors_``, ``dual_coef_``, ``intercept_``, ``coef0`` and ``degree``.
        :param gamma: numeric kernel coefficient the model was fitted with. It is passed
            explicitly, presumably because ``model.gamma`` may be the string 'scale' or
            'auto'. It is replaced by 1 for the linear kernel, which has no gamma.
        """
        self.model = model
        if model.kernel == 'linear':
            self.gamma = 1
        else:
            self.gamma = gamma

    def _n_support(self):
        """Return the total number of support vectors as a plain ``int``.

        Reads the row count of ``support_vectors_`` instead of calling
        ``int(n_support_)``. For regressors ``n_support_`` is a one-element array,
        and converting an array with ndim > 0 to a Python scalar is deprecated
        since NumPy 1.25.
        """
        return int(np.shape(self.model.support_vectors_)[0])

    def _kernel_offset(self):
        """Return the additive ``coef0`` term of the model's kernel, or 0 if it has none.

        scikit-learn only uses ``coef0`` in the 'poly' and 'sigmoid' kernels. For
        'linear' and 'rbf' it is ignored, so adding it to the MIP objective would
        make the MIP disagree with ``model.predict``.
        """
        if self.model.kernel in self._COEF0_KERNELS:
            return self.model.coef0
        return 0

    @staticmethod
    def _forbidden_bundle_cut(bundle):
        """Return ``(coefs, rhs)`` of a no-good cut ``sum_m coefs[m] * a_m >= rhs`` excluding exactly *bundle*.

        For a 0/1 bundle ``b``, the cut requires the Hamming distance between
        ``a`` and ``b`` to be at least one:
        ``sum_{m: b_m=1} (1 - a_m) + sum_{m: b_m=0} a_m >= 1``.
        This rearranges to ``sum_m (1 - 2 b_m) a_m >= 1 - sum_m b_m``.
        """
        b = np.asarray(bundle).ravel()
        coefs = [float(1 - 2 * x) for x in b]
        rhs = float(1 - b.sum())
        return coefs, rhs

    def _generate_rbf_kernel_helper_values(self, max_dif_items = 5):
        """Return ``exp(-gamma * d)`` for ``d = 0 .. max_dif_items``.

        For 0/1 vectors the squared Euclidean distance equals the Hamming
        distance ``d``, so these are the possible RBF kernel values.
        """
        return np.exp(-self.gamma * np.arange(max_dif_items+1))

    def _get_indecies(self, k):
        """Split the positions of support vector *k* into those equal to 1 and those equal to 0.

        Entries that are neither 0 nor 1 appear in neither list.

        :return: ``(one_index_list, zero_index_list)``.
        """
        current_support_vector = self.model.support_vectors_[k]
        one_index_list = [i for i, x in enumerate(current_support_vector) if x == 1]
        zero_index_list = [i for i, x in enumerate(current_support_vector) if x == 0]
        return one_index_list, zero_index_list

    def generate_mip(self, course_prices, credit_units, budget, course_timetable, cu_max=5, verbose=False):
        """
        Build the Gurobi MIP maximizing the SVR kernel expansion over binary bundles.

        This mip is written based on Machine Learning-powered Iterative Combinatorial Auctions (Brero et al.)

        :param course_prices: per-course prices. Its length M sets the number of binary variables.
        :param credit_units: per-course credit units (length M).
        :param budget: upper bound on the total price of the chosen bundle.
        :param course_timetable: iterable of days, each an iterable of timeslots, each an
            iterable of course indices. At most one course per timeslot may be chosen.
        :param cu_max: upper bound on the total credit units of the bundle.
        :param verbose: if True, also write the model to ``~/Desktop/SVR_mip.lp``.
        :return: the Gurobi model, also stored as ``self.mip``.
        :raises NotImplementedError: for kernels other than linear/poly/rbf, or poly
            degrees other than 1 and 2.
        """
        kernel = self.model.kernel
        # Fail loudly instead of silently building an MIP with an empty objective.
        if kernel not in ('linear', 'poly', 'rbf'):
            raise NotImplementedError("SVR MIP does not support kernel {!r}".format(kernel))
        if kernel == 'poly' and self.model.degree not in (1, 2):
            raise NotImplementedError(
                "SVR MIP only supports poly kernels of degree 1 or 2, got {}".format(self.model.degree))

        self.a_variables = []
        self.z_variables = []
        self.objective = []
        self.mip = gp.Model("SVR GUROBI MIP")

        # number of items
        M = len(course_prices)
        n_sv = self._n_support()
        dual_coef = self.model.dual_coef_[0]
        support_vectors = self.model.support_vectors_

        self.a_variables = self.mip.addVars(M, name='a_', vtype=GRB.BINARY)

        if kernel in ('linear', 'poly'):
            # Inner part of the kernel: gamma * <a, sv_k> + offset. The offset is
            # coef0 for poly and 0 for linear, and gamma is 1 for linear.
            offset = self._kernel_offset()
            for k in range(n_sv):
                sv = support_vectors[k]
                self.objective.append(self.gamma * gp.quicksum(self.a_variables[j] * sv[j] for j in range(M)) + offset)
            if kernel == 'poly' and self.model.degree == 2:
                terms = (self.objective[k] * self.objective[k] * dual_coef[k] for k in range(n_sv))
            else:
                terms = (self.objective[k] * dual_coef[k] for k in range(n_sv))
            self.mip.setObjective(gp.quicksum(terms), GRB.MAXIMIZE)

        else:  # rbf
            # Implementation based on the Machine Learning-powered Iterative Combinatorial Auctions (Brero et al.)
            # z[k][tau] == 1 iff the Hamming distance between a and support vector k is tau.
            # A Hamming distance never exceeds M, so this range is a safe upper bound.
            max_diff_items = max(2 * cu_max, M)
            n_taus = max_diff_items + 1
            for k in range(n_sv):
                self.z_variables.append(
                    self.mip.addVars(n_taus, name="z_{}_".format(k), vtype=GRB.BINARY))
            rbf_helpers = self._generate_rbf_kernel_helper_values(max_dif_items=max_diff_items)
            # The RBF kernel has no coef0 term (scikit-learn ignores coef0 for 'rbf').
            for k in range(n_sv):
                self.objective.append(gp.quicksum(rbf_helpers[tau] * self.z_variables[k][tau] for tau in range(n_taus)))
            self.mip.setObjective(gp.quicksum(dual_coef[k] * self.objective[k] for k in range(n_sv)), GRB.MAXIMIZE)

            for k in range(n_sv):
                # Exactly one distance indicator is active per support vector.
                self.mip.addConstr(gp.quicksum(self.z_variables[k][tau] for tau in range(n_taus)) == 1, name='rbf_new_variable_{}'.format(k))

                one_index_list, zero_index_list = self._get_indecies(k=k)
                in_vec_not_in_a = gp.quicksum(1 - self.a_variables[i] for i in one_index_list)
                in_a_not_in_vec = gp.quicksum(self.a_variables[i] for i in zero_index_list)

                # Given sum_tau z = 1, this equals sum_tau tau * z[k][tau], i.e. the active distance.
                right_hand_condition = gp.quicksum((tau + 1) * self.z_variables[k][tau] for tau in range(n_taus)) - 1

                self.mip.addConstr(in_vec_not_in_a + in_a_not_in_vec == right_hand_condition, name='ztk_force_{}'.format(k))

        self.mip.addConstr(gp.quicksum(self.a_variables[m] * course_prices[m] for m in range(M)) <= budget, name='budget')

        # add the credit units constraint:
        self.mip.addConstr(gp.quicksum(self.a_variables[m] * credit_units[m] for m in range(M)) <= cu_max, name='credit_units')

        # add the overlapping courses constraint:
        # for any timeslot of any day, the student can only have one of the courses with a lecture on that timeslot
        for day in course_timetable:
            for timeslot in day:
                self.mip.addConstr(gp.quicksum(self.a_variables[m] for m in timeslot) <= 1, name='overlaps')

        if verbose:
            self.mip.write(os.path.expanduser('~/Desktop/SVR_mip.lp'))

        return self.mip

    def add_budget_constraint(self, course_prices, budget):
        """Replace the MIP's 'budget' constraint with one using the new *budget*.

        Must be called after ``generate_mip``. If no 'budget' constraint can be
        found, the new constraint is simply added.
        """
        # Flush pending modifications so constraints added since the last update
        # (e.g. by generate_mip) are visible to the name lookup below.
        self.mip.update()
        try:
            old_budget = self.mip.getConstrByName('budget')
        except gp.GurobiError:
            old_budget = None
        if old_budget is not None:
            self.mip.remove(old_budget)
            self.mip.update()

        self.mip.addConstr(gp.quicksum(self.a_variables[m] * course_prices[m] for m in range(len(course_prices))) <= budget, name='budget')
        return

    def add_forbidden_bundle(self, bundle):
        """Exclude exactly *bundle* (a 0/1 vector over the courses) from the feasible set.

        Uses a no-good cut, so supersets and subsets of *bundle* remain feasible.
        Forbidding the empty bundle just requires at least one course, rather than
        making the model infeasible.
        """
        coefs, rhs = self._forbidden_bundle_cut(bundle)
        self.mip.addConstr(gp.quicksum(self.a_variables[m] * coefs[m] for m in range(len(self.a_variables))) >= rhs, name='alreadyQueried')
        self.mip.update()
        return

    def solve_mip(self, outputFlag=False, verbose=False):
        """Optimize the MIP built by ``generate_mip`` and read back the chosen bundle.

        :param outputFlag: forwarded to Gurobi's ``OutputFlag`` parameter (solver logging).
        :param verbose: if True, print the solve time and the optimal value.
        :return: ``(bundle, value)``. ``bundle`` is a 0/1 numpy array with one entry per
            course. ``value`` is ``model.intercept_`` plus the MIP objective value.
        """
        start = timer()

        self.mip.Params.OutputFlag = outputFlag
        self.mip.optimize()
        end = timer()
        optimal_value = self.model.intercept_ + self.mip.getObjective().getValue()
        if verbose:
            print(f'SVR MIP solved in: {end - start}')
            print(f'The value of the optimal solution is: {optimal_value}')

        # Binary values may come back slightly below 1 due to solver tolerances.
        optimal_schedule = [1 if self.a_variables[i].x >= 0.99 else 0 for i in range(len(self.a_variables))]
        return np.array(optimal_schedule), optimal_value

# NOTE: commented out because it is unclear whether this part is compatible with the server directory structure; it works fine locally.

# if __name__ == '__main__':
#     n_courses = 30
#     timetable = timetable_generator(n_courses=n_courses,
#                                     credit_units=[1 for i in range(n_courses)])
#     data = create_multiple_datasets_enhanced(timetable=timetable,
#                                              num_datasets=10,
#                                              train_size=50,
#                                              test_size=5000,
#                                              validation_size=8,
#                                              complement_range=1,
#                                              number_of_courses=30)
#     (X_train_list, y_train_list, X_validation_list, y_validation_list, X_test_list, y_test_list) = data
#
#     times_to_generate = []
#     times_to_solve = []
#     for i in range(0, len(X_train_list)):
#         print(f'Student number: {i}')
#         X_train, y_train, X_val, y_val, X_test, y_test = X_train_list[i], y_train_list[i], X_validation_list[i], \
#                                                          y_validation_list[i], X_test_list[i], y_test_list[i]
#
#         model = NuSVR(kernel='rbf', degree = 2, gamma='auto')
#         model.fit(X_train, y_train)
#
#         if model.gamma == 'scale':
#             gamma = 1 / (n_courses * X_train.var())
#         elif model.gamma == 'auto':
#             gamma = 1 / n_courses
#
#         mip = gurobi_MIP_SVR(model, gamma)
#         # print(mip.parsed_tree)
#         start = timer()
#         mip.generate_mip(course_prices=np.repeat(0, n_courses), credit_units=np.repeat(1, n_courses), budget=5, cu_max=5, course_timetable=[[]])
#         mid = timer()
#         optimal_schedule, optimal_value = mip.solve_mip(verbose=True)
#         end = timer()
#         print("Difference between model prediction and mip output: {}".format(model.predict(optimal_schedule.reshape(1, -1)) - optimal_value))
#         print("time taken to generate the mip:{}".format(mid-start))
#         print("time taken to solve the mip:{}".format(end-mid))
#         times_to_generate.append(mid-start)
#         times_to_solve.append(end - mid)
#         y_pred = model.predict(X_test)
#         print("Number of items that have higher value: {}".format(sum(y_pred > optimal_value)))
#
#     print('---- Aggregated Results ---')
#     print(f'AVG time taken to generate those {len(times_to_generate)} MIPs: {np.mean(times_to_generate)}')
#     print(f'AVG time taken to solve those {len(times_to_solve)} MIPs: {np.mean(times_to_solve)}')
