import os
import sys
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time

# Module-level configuration shared by the functions in this file.
#
# NOTE: a ``global`` statement at module scope has no effect in Python, so no
# such declarations are kept here; a function that needs to rebind one of
# these names declares ``global`` itself.

# Not referenced by the solver functions in this part of the file;
# presumably sample counts used by callers -- confirm against the caller.
trainmarkNum = 210
testmarkNum = 90

# Number of covering constraints built per model (the ``j`` index).
facility_num = 10
# Number of integer decision variables per model (the ``i`` index).
ERU_num = 50
# Module-level penalty value; rebound through ``set_penaltyTerm``.
penaltyTerm = 0
# Column count of the matrix built by ``gen_G``; tied to ``ERU_num``.
var_num = ERU_num
#max_for_each_ERU = 50

def mkdir(default_path, folder_name):
    """Create the directory ``default_path/folder_name`` if it is missing.

    Missing parent directories are created as well. Calling this for a
    directory that already exists is a no-op.

    Args:
        default_path: Base directory.
        folder_name: Name (or relative sub-path) of the folder to create
            below ``default_path``.

    Raises:
        FileExistsError: If the target path exists but is not a directory.
        OSError: If the directory cannot be created.
    """
    # ``exist_ok=True`` replaces the separate existence check, so a directory
    # created concurrently between "check" and "create" cannot cause a failure.
    os.makedirs(os.path.join(default_path, folder_name), exist_ok=True)


def set_penaltyTerm(p):
    """Rebind the module-level ``penaltyTerm`` to ``p``.

    Args:
        p: New value stored in the module global ``penaltyTerm``.
    """
    # ``global`` is required so the assignment rebinds the module-level name
    # instead of creating a function-local variable.
    global penaltyTerm
    penaltyTerm = p


def gen_G(avail_matrices):
    """Build the negated ``facility_num x var_num`` coefficient matrix.

    Entry ``[j][i]`` of the result is ``-avail_matrices[i * facility_num + j]``
    for ``i < ERU_num``; any remaining columns stay at zero.

    Args:
        avail_matrices: Flat, integer-indexable sequence holding
            ``facility_num`` consecutive values per variable index ``i``.

    Returns:
        A float ``numpy`` array of shape ``(facility_num, var_num)``.
    """
    coeffs = np.zeros((facility_num, var_num))
    for unit in range(ERU_num):
        # Values belonging to one variable are stored contiguously.
        offset = unit * facility_num
        for site in range(facility_num):
            coeffs[site, unit] = avail_matrices[offset + site]
    return -coeffs


def solve_IP_in_matrix_form(c, G, coverage_full, n_instance):
    """Solve ``n_instance`` integer programs given in matrix form.

    Instance ``k`` is the model::

        minimise    sum_i c[i] * x[i]
        subject to  sum_i G[j][i] * x[i] <= -coverage_full[k * facility_num + j]
                    for every j in range(facility_num)

    over ``ERU_num`` integer variables ``x`` (created with gurobipy's default
    variable bounds). Only the right-hand side differs between instances.

    Args:
        c: Objective coefficients; must provide ``.tolist()``.
        G: Constraint matrix with ``facility_num`` rows; must provide
            ``.tolist()``.
        coverage_full: Flat, integer-indexable sequence holding
            ``facility_num`` values per instance, laid out instance after
            instance.
        n_instance: Number of instances to solve.

    Returns:
        ``numpy`` array of length ``n_instance`` with each model's ``objVal``.

    Raises:
        Whatever gurobipy raises when ``objVal`` cannot be read, e.g. when a
        model has no solution.
    """
    cost = c.tolist()
    rows = G.tolist()

    obj_list = []
    for num in range(n_instance):
        start = num * facility_num
        # Negated slice of this instance's right-hand side as plain floats.
        neg_coverage = [-float(coverage_full[start + j]) for j in range(facility_num)]

        m = gp.Model()
        m.setParam('OutputFlag', 0)
        x = m.addVars(ERU_num, vtype=GRB.INTEGER, name='x')

        m.setObjective(x.prod(cost), GRB.MINIMIZE)
        for j in range(facility_num):
            m.addConstr(x.prod(rows[j]) <= neg_coverage[j])

        m.optimize()
        # Only the objective value is needed; the variable values are not read.
        obj_list.append(m.objVal)

    return np.array(obj_list)


def actual_obj(cost, avail_matrices, coverage_full, n_instance):
    """Return the optimal objective of the covering model for each instance.

    Instance ``k`` is the model::

        minimise    sum_i cost[i] * x[i]
        subject to  sum_i avail_matrices[i * facility_num + j] * x[i]
                        >= coverage_full[k * facility_num + j]
                    for every j in range(facility_num)

    over ``ERU_num`` integer variables ``x`` (created with gurobipy's default
    variable bounds). Only the right-hand side differs between instances.

    Args:
        cost: Objective coefficients; must provide ``.tolist()``.
        avail_matrices: Flat coefficient container holding ``facility_num``
            consecutive values per variable; must provide ``.tolist()``.
        coverage_full: Flat, integer-indexable sequence holding
            ``facility_num`` values per instance, laid out instance after
            instance.
        n_instance: Number of instances to solve.

    Returns:
        ``numpy`` array of length ``n_instance`` with each model's ``objVal``.

    Raises:
        Whatever gurobipy raises when ``objVal`` cannot be read, e.g. when a
        model has no solution.
    """
    cost = cost.tolist()
    avail_matrices = avail_matrices.tolist()

    obj_list = []
    for num in range(n_instance):
        start = num * facility_num
        # This instance's right-hand side as plain floats.
        coverage = [float(coverage_full[start + j]) for j in range(facility_num)]

        m = gp.Model()
        m.setParam('OutputFlag', 0)
        x = m.addVars(ERU_num, vtype=GRB.INTEGER, name='x')

        m.setObjective(x.prod(cost), GRB.MINIMIZE)
        for j in range(facility_num):
            m.addConstr(
                gp.quicksum(avail_matrices[i * facility_num + j] * x[i] for i in range(ERU_num))
                >= coverage[j]
            )

        m.optimize()
        # Only the objective value is needed; the variable values are not read.
        obj_list.append(m.objVal)

    return np.array(obj_list)


def get_stage1_IP_sol(cost, avail_matrices, coverage):
    """Solve the single-instance covering model and return its solution.

    The model is::

        minimise    sum_i cost[i] * x[i]
        subject to  sum_i avail_matrices[i * facility_num + j] * x[i] >= coverage[j]
                    for every j in range(facility_num)

    over ``ERU_num`` integer variables ``x`` (created with gurobipy's default
    variable bounds).

    Args:
        cost: Objective coefficients; must provide ``.tolist()``.
        avail_matrices: Flat coefficient container holding ``facility_num``
            consecutive values per variable; must provide ``.tolist()``.
        coverage: Right-hand side with ``facility_num`` entries; must provide
            ``.tolist()``.

    Returns:
        Tuple ``(x_sol, stage1_solvable)``. ``x_sol`` is an integer ``numpy``
        array of length ``ERU_num``. ``stage1_solvable`` is ``True`` when a
        solution was read from the solver; otherwise ``x_sol`` is all zeros
        and ``stage1_solvable`` is ``False``.
    """
    cost = cost.tolist()
    avail_matrices = avail_matrices.tolist()
    coverage = coverage.tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(ERU_num, vtype=GRB.INTEGER, name='x')

    m.setObjective(x.prod(cost), GRB.MINIMIZE)
    for j in range(facility_num):
        m.addConstr(
            gp.quicksum(avail_matrices[i * facility_num + j] * x[i] for i in range(ERU_num))
            >= coverage[j]
        )

    x_sol = np.zeros(ERU_num, dtype='i')
    try:
        m.optimize()
        for i in range(ERU_num):
            # Solver values are integral only up to a tolerance (e.g.
            # 2.9999999); round instead of truncating through the int cast.
            x_sol[i] = int(round(x[i].x))
    except (AttributeError, gp.GurobiError):
        # Reading ``.x`` fails when the model has no solution (infeasible,
        # unbounded, ...). Report the failure through the flag, which was
        # previously left at True in this branch.
        return np.zeros(ERU_num, dtype='i'), False
    return x_sol, True


def correction_single_obj(cost, avail_matrices, true_coverage, pred_coverage, penaltyTerm):
    """Return the total cost of a two-stage solve with a penalised correction.

    Stage 1 solves the covering model against ``pred_coverage``::

        minimise    sum_i cost[i] * x[i]
        subject to  sum_i avail_matrices[i * facility_num + j] * x[i] >= pred_coverage[j]

    Stage 2 keeps the stage-1 solution ``x`` fixed and chooses additional
    integer amounts ``y`` priced at ``(1 + penaltyTerm) * cost``::

        minimise    sum_i cost[i] * x[i] + sum_i (1 + penaltyTerm) * cost[i] * y[i]
        subject to  sum_i avail_matrices[i * facility_num + j] * (y[i] + x[i]) >= true_coverage[j]

    Both models use ``ERU_num`` integer variables (created with gurobipy's
    default variable bounds) and ``facility_num`` constraints.

    Args:
        cost: Objective coefficients; must support ``(1 + penaltyTerm) * cost``
            and ``.tolist()``.
        avail_matrices: Flat coefficient container holding ``facility_num``
            consecutive values per variable; must provide ``.tolist()``.
        true_coverage: Right-hand side used in stage 2.
        pred_coverage: Right-hand side used in stage 1.
        penaltyTerm: Relative surcharge on stage-2 costs. This argument
            shadows the module-level ``penaltyTerm``; only the argument is
            used here.

    Returns:
        The stage-2 model's ``objVal`` (stage-1 cost plus penalised
        correction cost), or ``0`` when the smallest entry of
        ``pred_coverage`` is not ``>= 0``, in which case no model is solved.

    Raises:
        Whatever gurobipy raises when a solution attribute cannot be read,
        e.g. when one of the models has no solution.
    """
    stage2_cost = ((1 + penaltyTerm) * cost).tolist()
    cost = cost.tolist()
    avail_matrices = avail_matrices.tolist()

    # Written as ``not (... >= 0)`` so a NaN minimum is skipped as well,
    # exactly like the original ``if min(...) >= 0`` guard.
    if not min(pred_coverage) >= 0:
        return 0

    # Stage 1: cover the predicted requirement at nominal cost.
    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(ERU_num, vtype=GRB.INTEGER, name='x')

    m.setObjective(x.prod(cost), GRB.MINIMIZE)
    for j in range(facility_num):
        m.addConstr(
            gp.quicksum(avail_matrices[i * facility_num + j] * x[i] for i in range(ERU_num))
            >= pred_coverage[j]
        )

    m.optimize()
    s1_x_sol = np.zeros(ERU_num, dtype='i')
    for i in range(ERU_num):
        # Solver values are integral only up to a tolerance (e.g. 2.9999999);
        # round instead of truncating so stage 2 sees the intended decision.
        s1_x_sol[i] = int(round(x[i].x))

    # Stage 2: top up the fixed stage-1 decision to meet the true requirement.
    m2 = gp.Model()
    m2.setParam('OutputFlag', 0)
    y = m2.addVars(ERU_num, vtype=GRB.INTEGER, name='y')

    # Cost already committed in stage 1 enters the objective as a constant.
    fixed_cost = np.dot(s1_x_sol, cost)
    m2.setObjective(fixed_cost + y.prod(stage2_cost), GRB.MINIMIZE)

    for j in range(facility_num):
        m2.addConstr(
            gp.quicksum(
                avail_matrices[i * facility_num + j] * (y[i] + s1_x_sol[i])
                for i in range(ERU_num)
            )
            >= true_coverage[j]
        )

    m2.optimize()
    return m2.objVal
