import os
import sys
import numpy as np
import random
import pandas as pd
import math, time
import itertools
from sklearn import preprocessing
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
import datetime
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.utils.data as data_utils
from torch.utils.data.dataset import Dataset
from sklearn.preprocessing import StandardScaler
import gurobipy as gp
import logging
import copy
from collections import defaultdict
import joblib
import gurobipy as gp
from gurobipy import GRB
import linear_relax as LP_relax_file
import ip_model_both_tricks as ip_model_wholeFile
from ip_model_both_tricks import IPOfunc
from ip_model_both_tricks import solve_LP

# ---- Experiment-wide constants ---------------------------------------------
# Feature columns per sample; also used as the network's input width below.
featureNum = 8
# Epochs trained with the plain MSE loss before Intopt.fit switches losses.
warm_start_epoch_criterion = 20
# Epoch budget handed to Intopt as `epochs` for each training attempt.
stop_epoch_criterion = 40

# Problem dimensions are owned by the LP-relaxation module.
trainmarkNum = LP_relax_file.trainmarkNum
facility_num = LP_relax_file.facility_num
ERU_num = LP_relax_file.ERU_num
var_num = LP_relax_file.var_num


# ---- Command line: <penalty> <startmark> <endmark> --------------------------
pT = float(sys.argv[1])
LP_relax_file.set_penaltyTerm(pT)
penaltyTerm = LP_relax_file.penaltyTerm
startmark = int(sys.argv[2])
endmark = int(sys.argv[3])
warmstart_val = 0

# Learning rates for the MSE stage and the decision-loss stage; both depend on
# whether the penalty is below 2.
lr_for_mse, lr_for_intpo = (1e-4, 1e-6) if pT < 2 else (1e-3, 1e-7)

# ---- Data locations ----------------------------------------------------------
# The dataset folder lives next to (not inside) the current working directory.
prior_folder = os.path.abspath(os.path.dirname(os.getcwd()))
_dataset_dirname = 'data(facility_num=' + str(facility_num) + ',ERU_num' + str(ERU_num) + ')/'
dataset_path = os.path.join(prior_folder, _dataset_dirname)
LP_relax_file.mkdir(dataset_path, 'Combination')
store_file_path = os.path.join(dataset_path, 'Combination/')
    
def weight_init(m):
    """Initialise the parameters of a single layer in place.

    Usable as the callback of ``nn.Module.apply``; layers of any other type
    are left untouched.

    * ``nn.Linear``: Xavier-normal weights, zero bias.
    * ``nn.Conv2d``: Kaiming-normal weights (fan-out, ReLU gain); bias is
      not modified.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.

    Parameters that a layer does not have (``bias=False`` on a linear layer,
    ``affine=False`` on batch norm) are ``None`` and are skipped rather than
    passed to ``nn.init``, which would fail on them.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

def make_fc(num_layers, num_features, num_targets=1,
            activation_fn=nn.ReLU, intermediate_size=2*featureNum, regularizers=True):
    """Build a fully connected network as an ``nn.Sequential``.

    The stack is: an input ``Linear`` + activation, ``num_layers - 2`` hidden
    ``Linear`` + activation pairs, and an output ``Linear`` that is *also*
    followed by an activation.

    Args:
        num_layers: Total layer count; values of 2 or less add no hidden pair.
        num_features: Width of the input layer.
        num_targets: Width of the output layer.
        activation_fn: Zero-argument factory for the activation module.
        intermediate_size: Width of every hidden layer.
        regularizers: Accepted for interface compatibility; not used.

    Returns:
        The assembled ``nn.Sequential`` module.
    """
    # Modules are created strictly front-to-back so that parameter
    # initialisation draws random numbers in the same order as the layers.
    stack = [nn.Linear(num_features, intermediate_size), activation_fn()]
    stack.extend(
        module
        for _ in range(num_layers - 2)
        for module in (nn.Linear(intermediate_size, intermediate_size), activation_fn())
    )
    stack += [nn.Linear(intermediate_size, num_targets), activation_fn()]
    return nn.Sequential(*stack)
        

class MyCustomDataset():
    """Minimal indexable dataset pairing each feature entry with its target."""

    def __init__(self, feature, value):
        # Both containers are stored as given; no copy or validation is made.
        self.feature, self.value = feature, value

    def __len__(self):
        # The sample count is taken from the targets, not from the features.
        sample_count = len(self.value)
        return sample_count

    def __getitem__(self, idx):
        sample_x = self.feature[idx]
        sample_y = self.value[idx]
        return sample_x, sample_y


class Intopt:
    """Two-stage trainer for a small fully connected predictor.

    Stage 1 (epochs below ``warm_start_epoch_criterion``) fits the network
    with a plain MSE loss.  Stage 2 switches the learning rate to
    ``lr_for_intpo`` and back-propagates a decision loss through ``IPOfunc``.

    Besides its constructor arguments the class reads the module-level
    settings ``warm_start_epoch_criterion``, ``lr_for_intpo``,
    ``trainmarkNum``, ``facility_num`` and ``penaltyTerm`` (checkpoint file
    name only).  It reports progress through the module globals
    ``stop_epoch`` (last epoch that processed a batch) and
    ``bestTrainCorrReg`` (lowest epoch loss seen so far), which must already
    exist when ``fit`` is called.
    """

    def __init__(self, cost, A, b, avail_matrices, neg_G, penaltyTerm, n_features, num_layers=5, smoothing=False, thr=0.1, max_iter=None, method=1, mu0=None,
                 damping=1e-7, target_size=1, epochs=8, optimizer=optim.Adam,
                 batch_size=facility_num, **hyperparams):
        """Store the problem data and build the network and its optimizer.

        Args:
            cost, A, b, avail_matrices, neg_G: NumPy problem data forwarded to
                ``IPOfunc`` / ``LP_relax_file`` (their layout is defined by
                those modules).
            penaltyTerm: Penalty factor used in the decision loss and in
                ``val_loss``.
            n_features: Input width of the network.
            num_layers: Layer count passed to ``make_fc``.
            smoothing, thr, max_iter, damping: Forwarded to ``IPOfunc``.
            method, mu0, target_size: Stored but not read by this class.
            epochs: Maximum number of training epochs.
            optimizer: Optimizer factory, called as
                ``optimizer(parameters, **hyperparams)``.
            batch_size: Batch size of every ``DataLoader`` created here.
            **hyperparams: Extra keyword arguments for the optimizer
                (e.g. ``lr``).
        """
        self.cost = cost
        self.A = A
        self.b = b
        self.avail_matrices = avail_matrices
        self.neg_G = neg_G
        self.penaltyTerm = penaltyTerm

        self.target_size = target_size
        self.n_features = n_features
        self.damping = damping
        self.num_layers = num_layers

        self.smoothing = smoothing
        self.thr = thr
        self.max_iter = max_iter
        self.method = method
        self.mu0 = mu0

        self.batch_size = batch_size
        self.hyperparams = hyperparams
        self.epochs = epochs

        self.model = make_fc(num_layers=self.num_layers, num_features=n_features)

        # Keep the factory so the optimizer can be rebuilt when the model is
        # re-initialised during stage 2.
        self._optimizer_factory = optimizer
        self.optimizer = optimizer(self.model.parameters(), **hyperparams)

    def _reset_model(self):
        """Replace the network with a fresh one and re-bind the optimizer.

        The optimizer only updates the parameter objects it was constructed
        with, so swapping ``self.model`` without rebuilding it would leave the
        new network untrained.  The learning rates currently in effect are
        carried over to the new optimizer.
        """
        current_lrs = [group['lr'] for group in self.optimizer.param_groups]
        self.model = make_fc(num_layers=self.num_layers, num_features=self.n_features)
        self.optimizer = self._optimizer_factory(self.model.parameters(), **self.hyperparams)
        for group, lr in zip(self.optimizer.param_groups, current_lrs):
            group['lr'] = lr

    def _mse_epoch(self, train_df, criterion, e):
        """Run one warm-start epoch and return the summed batch MSE."""
        global stop_epoch
        train_dl = data_utils.DataLoader(train_df, batch_size=self.batch_size, shuffle=False)
        total_loss = 0
        for batch_feature, batch_value in train_dl:
            self.optimizer.zero_grad()
            op = self.model(batch_feature).squeeze()
            loss = criterion(op, batch_value)
            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            stop_epoch = e
        return total_loss

    def _decision_epoch(self, train_df, cur_tov, e):
        """Run one decision-loss epoch and return the mean batch loss.

        ``cur_tov`` is indexed by batch position; it is presumably the
        per-instance true optimal value returned by ``val_loss`` -- confirm
        against ``LP_relax_file.actual_obj``.
        """
        global stop_epoch
        train_dl = data_utils.DataLoader(train_df, batch_size=self.batch_size, shuffle=False)
        total_loss = 0
        instance_num = 0

        for batch_feature, batch_value in train_dl:
            self.optimizer.zero_grad()
            op = self.model(batch_feature).squeeze()
            # Predictions must be strictly positive and finite before they are
            # handed to the solver layer; otherwise start from a new network.
            while torch.min(op) <= 0 or torch.isnan(op).any() or torch.isinf(op).any():
                self._reset_model()
                op = self.model(batch_feature).squeeze()

            c_torch = torch.from_numpy(self.cost).float()
            G_torch = torch.from_numpy(self.neg_G).float()
            A_torch = torch.from_numpy(self.A).float()
            b_torch = torch.from_numpy(self.b).float()

            x_s2 = IPOfunc(cost=self.cost, avail_matrices=self.avail_matrices, c=c_torch, A=A_torch, b=b_torch, G=G_torch, true_req=batch_value, penalty=self.penaltyTerm, max_iter=self.max_iter, thr=self.thr, damping=self.damping,
                           smoothing=self.smoothing)(op)
            # Read from the solver module right after the call above;
            # presumably the stage-1 solution it just computed -- confirm in
            # ip_model_both_tricks.
            x_s1 = ip_model_wholeFile.x_s1
            obj_pen_torch = torch.from_numpy(self.penaltyTerm * self.cost).float()
            loss = (x_s2 * c_torch).sum() + (obj_pen_torch * (x_s2 - x_s1)).sum() - cur_tov[instance_num]

            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            instance_num = instance_num + 1
            stop_epoch = e
        return total_loss / instance_num

    def fit(self, feature, value):
        """Train the network on ``feature`` / ``value``.

        Training stops early when the epoch loss changes by less than 1e-2,
        or, in stage 2, as soon as it fails to decrease.  Whenever an epoch
        loss beats the module-level ``bestTrainCorrReg`` the model weights
        are written to ``'Combination-<penalty>-model.pkl'`` in the current
        working directory.

        Args:
            feature: Training inputs, indexable per sample.
            value: Training targets (a torch tensor; also passed to
                ``val_loss``).
        """
        global bestTrainCorrReg

        logging.info("Intopt")
        train_df = MyCustomDataset(feature, value)

        criterion = nn.MSELoss(reduction='mean')
        grad_list = np.zeros(self.epochs)
        cur_tov = np.zeros(trainmarkNum)
        for e in range(self.epochs):
            each_epoch_start = time.time()
            if e < warm_start_epoch_criterion:
                total_loss = self._mse_epoch(train_df, criterion, e)
                if e == warm_start_epoch_criterion - 1:
                    print("{} ->".format(total_loss))
                    # Evaluate on the data passed to fit(), not on
                    # module-level variables.
                    curr_PReg, cur_tov = self.val_loss(feature, value)
                    curr_avgPReg = np.mean(curr_PReg)
                    print("curr_loss: ", curr_avgPReg)
            else:
                if e == warm_start_epoch_criterion:
                    # First decision-loss epoch: switch to its learning rate.
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = lr_for_intpo
                total_loss = self._decision_epoch(train_df, cur_tov, e)
                print("Epoch{} ::loss {} ".format(e, total_loss), end=" ")

            logging.info("EPOCH Ends")
            each_epoch_end = time.time()
            if e >= warm_start_epoch_criterion:
                print("time: ", each_epoch_end - each_epoch_start)
            grad_list[e] = total_loss
            if e > 0 and abs(grad_list[e] - grad_list[e-1]) < 1e-2:
                break
            if e >= warm_start_epoch_criterion and grad_list[e] >= grad_list[e-1]:
                break
            if grad_list[e] < bestTrainCorrReg:
                bestTrainCorrReg = grad_list[e]
                # The file name uses the module-level penalty so it matches
                # the name the driver script loads the checkpoint from.
                torch.save(self.model.state_dict(), 'Combination-' + str(penaltyTerm) + '-model.pkl')

    def val_loss(self, feature, value):
        """Evaluate the current model batch by batch.

        For every batch the first ``facility_num`` true and predicted values
        are passed to ``LP_relax_file.correction_single_obj``.  The model is
        put in eval mode for the duration and returned to train mode after.

        Args:
            feature: Inputs, indexable per sample.
            value: Targets as a torch tensor.

        Returns:
            A pair ``(abs(real_obj - corrected), real_obj)`` where
            ``real_obj`` is the result of ``LP_relax_file.actual_obj`` and
            ``corrected`` holds one ``correction_single_obj`` result per
            batch.
        """
        test_instance = len(value) / self.batch_size
        real_obj = LP_relax_file.actual_obj(self.cost, self.avail_matrices, value, n_instance=int(test_instance))

        valid_df = MyCustomDataset(feature, value)
        valid_dl = data_utils.DataLoader(valid_df, batch_size=self.batch_size, shuffle=False)
        corr_obj_list = []

        self.model.eval()
        try:
            # Pure evaluation: no autograd graph is needed.
            with torch.no_grad():
                for batch_feature, batch_value in valid_dl:
                    op = self.model(batch_feature).squeeze()
                    true_req = np.array([float(batch_value[i]) for i in range(facility_num)])
                    pred_req = np.array([float(op[i]) for i in range(facility_num)])
                    corr_obj_list.append(
                        LP_relax_file.correction_single_obj(self.cost, self.avail_matrices, true_req, pred_req, self.penaltyTerm)
                    )
        finally:
            self.model.train()

        return abs(real_obj - np.array(corr_obj_list)), real_obj


# All-zero A / b arrays handed to Intopt (presumably placeholder equality
# constraints for the solver layer -- confirm in ip_model_both_tricks).
A_data = np.zeros((2, var_num))
b_data = np.zeros(2)
print("**** Combination ****")
testTime = 10
# One slot per instance id.  The loop below indexes this by the absolute id
# `testi`, so the row must reach `endmark` even when that exceeds `testTime`.
recordBest = np.zeros((1, max(testTime, endmark)))


print("facility_num: ", facility_num, "ERU_num: ", ERU_num, "penalty: ", penaltyTerm)

for testi in range(startmark, endmark):
    print(testi)
    stop_epoch = 0
    cost = np.loadtxt(os.path.join(dataset_path,'cost/cost(' + str(testi) + ').txt'))
    avail_matrices = np.loadtxt(os.path.join(dataset_path,'avail_matrices/avail_matrices(' + str(testi) + ').txt'))
    neg_G = LP_relax_file.gen_G(avail_matrices)

    # Column 0 is skipped, the next `featureNum` columns are the features and
    # the column after them is the target.
    trainData = np.loadtxt(os.path.join(dataset_path,'train/train(' + str(testi) + ').txt'))
    x_train = trainData[:, 1:featureNum+1]
    y_train = trainData[:, featureNum+1]
    feature_train = torch.from_numpy(x_train).float()
    value_train = torch.from_numpy(y_train).float()
    warmstart_val = np.mean(y_train)

    testData = np.loadtxt(os.path.join(dataset_path,'test/test(' + str(testi) + ').txt'))
    x_test = testData[:, 1:featureNum+1]
    y_test = testData[:, featureNum+1]
    feature_test = torch.from_numpy(x_test).float()
    value_test = torch.from_numpy(y_test).float()

    damping = 1e-7
    thr = 1e-3
    lr = lr_for_mse   # NN
    # Read and updated by Intopt.fit to decide when to checkpoint the model.
    bestTrainCorrReg = float("inf")
    max_retrain_time = 10
    start_time = time.time()
    # Retrain from scratch until a run gets at least two epochs past the
    # warm-start stage, or the retry budget is used up.  `stop_epoch` is
    # written by Intopt.fit.
    while stop_epoch < warm_start_epoch_criterion + 2 and max_retrain_time > 0:
        max_retrain_time = max_retrain_time - 1
        clf = Intopt(cost, A_data, b_data, avail_matrices, neg_G, penaltyTerm, damping=damping, lr=lr, n_features=featureNum, thr=thr, epochs=stop_epoch_criterion)
        clf.fit(feature_train, value_train)
        train_rslt, train_tov = clf.val_loss(feature_train, value_train)
        avgTrainCorrReg = np.mean(train_rslt)
        trainHSD_rslt = ' train: ' + str(np.mean(train_rslt))
        print(trainHSD_rslt)

        # Always true on the final pass of the loop, so `end_time` is defined
        # by the time it is used below.
        if stop_epoch >= warm_start_epoch_criterion or max_retrain_time <= 0:
            end_time = time.time()
            bestTrainCorrReg = avgTrainCorrReg

    # Evaluate the checkpoint written by Intopt.fit rather than the last
    # in-memory model.
    clfBest = Intopt(cost, A_data, b_data, avail_matrices, neg_G, penaltyTerm, damping=damping, lr=lr, n_features=featureNum, thr=thr, epochs=6)
    clfBest.model.load_state_dict(torch.load('Combination-' + str(penaltyTerm) + '-model.pkl'))

    value = clfBest.model(feature_test).squeeze()
    value = value.detach().numpy()
    # Columns: instance index (row // facility_num), true value, prediction.
    pred_value = np.column_stack((
        np.arange(value.size) // facility_num,
        value_test.numpy()[:value.size],
        value,
    ))
    np.savetxt(os.path.join(store_file_path,'Combination-' + str(penaltyTerm) + '(' + str(testi) + ').txt'), pred_value, fmt="%.2f")

    val_rslt, test_tov = clfBest.val_loss(feature_test, value_test)
    HSD_rslt = str(testi) + ' MSE: ' + str(mean_squared_error(value_test, value)) + ' PReg: ' + str(np.mean(val_rslt))  + ' TOV: ' + str(np.mean(test_tov)) + ' Time: ' + str(end_time - start_time)
    print(HSD_rslt)
    recordBest[0][testi] = np.sum(val_rslt)

print(recordBest)
