# from MIP.util_dataset import create_multiple_datasets_enhanced
from timeit import default_timer as timer
from gurobipy import GRB
import xgboost
import gurobipy as gp
from sklearn.metrics import r2_score, mean_absolute_error
from scipy.stats import kendalltau
from sklearn.linear_model import LinearRegression, LassoCV
# from MIP.useful_functions import timetable_generator
import random

import numpy as np
import os
import sys

import pandas as pd
import scipy.stats


# In[2]:

import argparse

import random
import numpy as np
import os
import math
import sys
import pickle


from functools import partial

import scipy.stats
import torch
import itertools
from sklearn.metrics import accuracy_score
import torch.nn.functional as F


from sklearn.linear_model import LinearRegression, ElasticNetCV
from sklearn.svm import NuSVR
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import mean_squared_error, make_scorer
import xgboost as xgb
from sklearn.tree import DecisionTreeRegressor

# Make this file's directory and the neighbouring project folders importable.
# Every location is appended at most once, in the order listed below.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
for _search_dir in (
        _THIS_DIR,
        os.path.join(_THIS_DIR, '..'),
        os.path.join(_THIS_DIR, '..', 'MIP'),
        os.path.join(_THIS_DIR, '..', 'MVNN'),
        os.path.join(_THIS_DIR, '..', 'MVNN', 'src'),
):
    if _search_dir not in sys.path:
        sys.path.append(_search_dir)

####################


class Node:
    """
    A single node of a tree rebuilt from an xgboost sub-tree.

    index: identifier of the node
    left:  reference to the left child (find_reachble_leaves reads -1 as "no child")
    right: reference to the right child (find_reachble_leaves reads -1 as "no child")
    """

    def __init__(self, index, left, right):
        # plain value holder: the three arguments are stored unchanged
        self.index, self.left, self.right = index, left, right


def find_reachble_leaves(Tree, node):
    """
    Collect every node that can be reached from ``node`` (``node`` itself included).

    Note that, despite the name, inner nodes are returned as well; callers that
    only want leaves have to filter the result.

    Tree: list of nodes; ``Tree[k].left`` / ``Tree[k].right`` hold the position in
          ``Tree`` of the children of node ``k``, or -1 when there is no child
    node: position in ``Tree`` of the node to start from
    return: list of positions in the order they were discovered. Pending nodes
            are taken from the end of the work list (stack), so the traversal is
            depth-first rather than breadth-first.
    """
    discovered = [node]   # result, kept in discovery order
    seen = {node}         # same content as ``discovered`` for O(1) membership tests
    pending = [node]

    while pending:
        current = pending.pop()
        for child in (Tree[current].left, Tree[current].right):
            if child != -1 and child not in seen:
                seen.add(child)
                discovered.append(child)
                pending.append(child)

    return discovered


class gurobi_MIP_xgboost:

    def __init__(self, model, n_courses):
        self.model = model
        self.n_courses = n_courses
        self.x_variables = []
        self.y_variables = []
        self.objective = []
        self.cu_constraint = []
        self.budget_constraint = []
        self.schedule_constraint = []

        self.parsed_tree = self._parse_xgboost()
        self.splits = self._get_splits()

    def generate_mip(self, credit_units, course_timetable, cu_max=5):
        """
        Build the Gurobi model that maximises the ensemble output over feasible schedules.

        This MIP is written based on the paper "Optimization of Tree Ensembles (Misic 2019)";
        the labels 2b-2f in the comments refer to the constraints of that formulation.

        credit_units: indexable by course (split variable) index; credit units of each course
        course_timetable: iterable of days, each day an iterable of timeslots, each timeslot an
                          iterable of the course indices that have a lecture in that slot
        cu_max: upper bound on the total credit units of the selected courses
        return: the gp.Model (also stored in self.mip)

        A course counts as selected when its x variable is 0, which is why the side
        constraints are written in terms of (1 - x); solve_mip decodes the solution the same way.
        """
        self.mip = gp.Model("xgboost GUROBI MIP")

        self.x_variables = []
        self.y_variables = []
        self.objective = []
        self.cu_constraint = []
        self.budget_constraint = []
        self.schedule_constraint = []

        n_trees = self.model.n_estimators

        # y_variables[t][l]: binary, one per leaf l of tree t
        for t in range(n_trees):
            self.y_variables.append(
                self.mip.addVars(list(self.parsed_tree[t]['leafs_values'].keys()), name="y_{}_".format(t),
                                 vtype=GRB.BINARY))

        # x_variables[k][j]: binary, one per split point j of the k-th split variable.
        # The variable name carries the feature index; solve_mip parses it back.
        for split_variable, split_points in self.splits.items():
            self.x_variables.append(
                self.mip.addVars(list(range(len(split_points))), name="x_{}_".format(split_variable),
                                 vtype=GRB.BINARY))

        # position of every split variable inside self.x_variables (same order as self.splits)
        self.split_variable_index_map = {
            split_variable: position for position, split_variable in enumerate(self.splits)}

        def x_of(split_variable):
            # all x variables that belong to one split variable
            return self.x_variables[self.split_variable_index_map[split_variable]]

        # constraint 2b: exactly one leaf is active in every tree
        for t in range(n_trees):
            self.mip.addConstr(
                gp.quicksum(self.y_variables[t][l] for l in self.parsed_tree[t]['leafs_values'].keys()) == 1)

        # constraint 2e
        for split_variable, split_points in self.splits.items():
            self.mip.addConstr(
                gp.quicksum(x_of(split_variable)[j] for j in range(len(split_points))) <= 1)

        # constraint 2c and 2d: leaves below the left / right branch of a split can only be
        # active when the x variables of the split feature allow that branch
        for t in range(n_trees):
            tree = self.parsed_tree[t]
            for split_position, tree_split_variable in enumerate(tree['feature_indecies']):
                x_total = gp.quicksum(
                    x_of(tree_split_variable)[j] for j in range(len(self.splits[tree_split_variable])))
                self.mip.addConstr(
                    gp.quicksum(self.y_variables[t][l] for l in tree['reachable_leafs_left'][split_position])
                    <= x_total)
                self.mip.addConstr(
                    gp.quicksum(self.y_variables[t][l] for l in tree['reachable_leafs_right'][split_position])
                    <= 1 - x_total)

        # constraint 2f: the x variables of one split variable are ordered along its split points.
        # Each split variable is constrained on its own variables.
        for split_variable, split_points in self.splits.items():
            for j in range(len(split_points) - 1):
                self.mip.addConstr(x_of(split_variable)[j] <= x_of(split_variable)[j + 1])

        # objective: sum of the active leaf values of all trees plus a constant offset
        # NOTE(review): the 0.5 offset presumably mirrors xgboost's default base_score - confirm.
        for t in range(n_trees):
            self.objective.append(
                gp.quicksum(self.y_variables[t][l] * value for l, value in self.parsed_tree[t]['leafs_values'].items()))
        self.mip.setObjective(0.5 + gp.quicksum(self.objective), GRB.MAXIMIZE)

        # cu constraint: credit units of the selected courses must not exceed cu_max
        for split_variable, split_points in self.splits.items():
            self.cu_constraint.append(gp.quicksum(
                (1 - x_of(split_variable)[j]) * credit_units[split_variable] for j in range(len(split_points))))
        self.mip.addConstr(gp.quicksum(self.cu_constraint) <= cu_max, name='cu')

        # for any timeslot of any day, the student can only have one of the courses with a lecture on that timeslot
        for split_variable, split_points in self.splits.items():
            self.schedule_constraint.append(gp.quicksum(
                (1 - x_of(split_variable)[j]) for j in range(len(split_points))))
        for day in course_timetable:
            for timeslot in day:
                self.mip.addConstr(
                    gp.quicksum(self.schedule_constraint[self.split_variable_index_map[i]] for i in timeslot) <= 1,
                    name='overlaps')

        return self.mip

    def add_budget_constraint(self, course_prices, budget):
        """
        (Re)place the budget constraint of the model built by generate_mip.

        A constraint named 'budget' left over from a previous call is removed first, then a
        new one limiting the total price of the selected courses to ``budget`` is added.

        course_prices: indexable by course (split variable) index; price of each course
        budget: upper bound on the summed prices of the selected courses
        """
        self.budget_constraint = []

        try:
            # Gurobi applies model changes lazily; flush them so a 'budget' constraint that
            # was added but not yet processed can be found by name.
            self.mip.update()
            previous = self.mip.getConstrByName('budget')
            if previous is None:
                print('no budget variable')
            else:
                self.mip.remove(previous)
                self.mip.update()
        except gp.GurobiError:
            # best effort: without a previous constraint there is nothing to remove
            print('no budget variable')

        # budget constraint: (1 - x) is 1 for a selected course (solve_mip reads x < 0.99 as selected)
        for split_variable, split_points in self.splits.items():
            x_group = self.x_variables[self.split_variable_index_map[split_variable]]
            self.budget_constraint.append(gp.quicksum(
                (1 - x_group[j]) * course_prices[split_variable] for j in range(len(split_points))))
        self.mip.addConstr(gp.quicksum(self.budget_constraint) <= budget, name='budget')

    def add_forbidden_bundle(self, bundle):
        """
        Add a constraint that rules out every schedule containing all courses of ``bundle``.

        bundle: vector indexed by course index (presumably the 0/1 schedule returned by
                solve_mip - confirm against the caller)

        The entries of ``bundle`` are matched to the x variables through
        split_variable_index_map, so the result does not rely on self.x_variables
        being stored in course order. Only the first x variable of each course is used.
        """
        selected_from_bundle = gp.quicksum(
            (1 - self.x_variables[position][0]) * bundle[course]
            for course, position in self.split_variable_index_map.items())
        # "- 0.1": at least one course of the bundle has to be dropped
        self.mip.addConstr(selected_from_bundle <= np.sum(bundle) - 0.1, name='alreadyQueried')
        self.mip.update()

    def solve_mip(self, outputFlag = False, verbose = False):
        start = timer()

        self.mip.Params.OutputFlag = outputFlag
        self.mip.optimize()
        end = timer()

        if (verbose):
            self.mip.write(os.path.expanduser('~/Desktop/xgboost_mip.lp'))
            print(f'xgboost MIP solved in: {end - start}')
            print(f'The value of the optimal solution is: {self.mip.getObjective().getValue()}')
        # for v in self.mip.getVars():
        #     print('%s %g' % (v.varName, v.x))
        optimal_schedule = []
        for i in range(len(self.x_variables)):
            for j in range(len(self.x_variables[i])):
                if self.x_variables[i][j].x < 0.99:
                    optimal_schedule = optimal_schedule + [int(self.x_variables[i][j].varName.split("_")[1])]

        optimal_schedule_0_1 = list(np.zeros(self.n_courses))
        for i in optimal_schedule:
            optimal_schedule_0_1[i] = 1
        return np.array(optimal_schedule_0_1), self.mip.getObjective().getValue()

    def _parse_xgboost(self):
        """
        The function parses the xgboost into a list of its subtress
        return:
         a dictionary with an element for each subtree of xgboost
         For each subtree a dictionary with the following parameters is returned:
         leafs_values: Value of the leafs
         leafs_count: Number of the leafs on the subtree
         featire_indecies: The indecies of the feautures used for splits in X_train
         feauture_splits: The values over which the feautures split
         reachable_leafs_right: Reachable leafs from the right branch of each splitting feature
         reachable_leafs_left: Reachable leafs from the left branch of each splitting feature

        """
        bstr = self.model.get_booster()
        #get all the subtrees ad their parameters
        param_df = bstr.trees_to_dataframe()
        # number of trees used in xgboost
        n_estimators = self.model.n_estimators
        parsed_tree = dict.fromkeys(list(range(n_estimators)))

        for i in range(n_estimators):
            # Take a single tree
            current_tree = param_df[param_df.Tree == i]
            index_map = dict.fromkeys(current_tree['ID'].values, None)
            index_map_value = 0
            #mapping 'Tree_nodenumber' to 'nodenumber
            for key, value in index_map.items():
                index_map[key] = index_map_value
                index_map_value += 1

            Tree = []
            #separate the nodes that are maked as leafs
            current_tree_leafs = current_tree[current_tree.Feature == 'Leaf'].reset_index()
            # Splitting nodes
            current_tree_splits = current_tree[current_tree.Feature != 'Leaf'].reset_index()

            parsed_tree[i] = dict()

            for j in current_tree.index:
                # parse nodes in the tree
                current_node = current_tree.loc[j, ]
                left = -1 if (type(current_node['Yes']) == float or type(current_node['Yes']) == np.float64) and math.isnan(current_node['Yes']) else int(index_map[current_node['Yes']])
                right = -1 if (type(current_node['No']) == float or type(current_node['No']) == np.float64) and math.isnan(current_node['No']) else int(index_map[current_node['No']])
                Tree = Tree + [Node(index = current_node['Node'], left = left, right = right)]

            leafs_idx = current_tree_leafs['Node'].values

            leafs_index_map = dict.fromkeys(leafs_idx)
            leafs_reverse_index_map = dict()
            k = 0
            for key, value in leafs_index_map.items():
                leafs_index_map[key] = k
                leafs_reverse_index_map[k] = key
                k += 1

            parsed_tree[i]['leafs_values'] = dict.fromkeys(leafs_reverse_index_map.keys())
            parsed_tree[i]['leafs_count'] = len(current_tree_leafs['Node'].values)
            # parsing the value of each leaf
            for key, value in parsed_tree[i]['leafs_values'].items():
                parsed_tree[i]['leafs_values'][key] = float(current_tree_leafs[current_tree_leafs.Node == leafs_reverse_index_map[key]]['Gain'].values)

            parsed_tree[i]['feature_indecies'] = list(current_tree_splits['Feature'].values)
            parsed_tree[i]['feature_indecies'] = [int(x[1:]) for x in parsed_tree[i]['feature_indecies']]
            # The value over which the split node splits
            parsed_tree[i]['feature_splits'] = list(current_tree_splits['Split'].values)

            reachable_leafs_left_list = []
            reachable_leafs_right_list = []

            # Add the reachable leafs from the left branch and right branch
            for j in current_tree_splits.index:
                node = int(current_tree_splits.loc[j, 'Node'])
                reachable_leafs_right_list = reachable_leafs_right_list + [[leafs_index_map[x] for x in find_reachble_leaves(Tree, Tree[node].right) if x in leafs_idx]]
                reachable_leafs_left_list = reachable_leafs_left_list + [[leafs_index_map[x] for x in find_reachble_leaves(Tree, Tree[node].left) if x in leafs_idx]]

            parsed_tree[i]['reachable_leafs_right'] = reachable_leafs_right_list
            parsed_tree[i]['reachable_leafs_left'] = reachable_leafs_left_list

        split_varibles_in_xgboost = []
        for t in range(self.model.n_estimators):
            split_varibles_in_xgboost = split_varibles_in_xgboost + parsed_tree[t]['feature_indecies']
        feautures_not_in_xgboost = list(set(range(self.n_courses)) - set(split_varibles_in_xgboost))
        for i, v in enumerate(feautures_not_in_xgboost):
            parsed_tree[n_estimators + i] = dict()
            parsed_tree[n_estimators + i]['leafs_values'] = [0, 0]
            parsed_tree[n_estimators + i]['leafs_count'] = [2]
            parsed_tree[n_estimators + i]['feature_indecies'] = [v]
            parsed_tree[n_estimators + i]['feature_splits'] = [0.5]
            parsed_tree[n_estimators + i]['reachable_leafs_right'] = [1]
            parsed_tree[n_estimators + i]['reachable_leafs_left'] = [0]

        return parsed_tree

    def _get_splits(self):
        """
        Gather, for every split variable found in self.parsed_tree, the values it splits over.

        return: dictionary mapping the feature index of each split variable to the
                ascending list of its distinct split values
        """
        trees = list(self.parsed_tree.values())

        # every feature index that occurs in any tree, duplicates included
        all_split_variables = [variable for tree in trees for variable in tree['feature_indecies']]
        collected = {variable: [] for variable in set(all_split_variables)}

        # feature_indecies and feature_splits of a tree are aligned position by position
        for tree in trees:
            for position, variable in enumerate(tree['feature_indecies']):
                collected[variable].append(tree['feature_splits'][position])

        # drop repeated split values and order the remaining ones ascending
        return {variable: list(np.sort(list(set(values)))) for variable, values in collected.items()}
