from pyepo.func import SPOPlus, blackboxOpt, negativeIdentity, perturbedOpt, perturbedFenchelYoung, implicitMLE, NCE, contrastiveMAP, pointwiseLTR, pairwiseLTR, listwiseLTR
from data.data_factory import opt_data, opt_data_load
from models.Pred import PredModel
import torch
from utils.metrics import metric
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from utils.evaluate import calRegret
import torch.nn as nn
# Prefer the interactive Tk backend, but keep the script usable on headless
# machines (e.g. servers without a display). Because pyplot is already imported
# above, matplotlib.use() switches the backend immediately and raises
# ImportError when TkAgg cannot be loaded; in that case keep the default backend.
try:
    matplotlib.use('TkAgg')
except ImportError as err:
    print(f'Could not switch matplotlib to TkAgg ({err}); using backend {matplotlib.get_backend()}.')

# Global figure styling: 9x6 inch figures, Times New Roman serif font, 18 pt
# labels/titles/legend/ticks.
plt.rcParams['figure.figsize'] = (9, 6)
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'Times New Roman'
parameters = {'axes.labelsize': 18, 'axes.titlesize': 18, 'legend.fontsize': 18, 'xtick.labelsize': 18,
              'ytick.labelsize': 18}
plt.rcParams.update(parameters)
# Use an ASCII hyphen-minus instead of the Unicode minus sign in tick labels.
plt.rcParams['axes.unicode_minus'] = False



class Experiment:
    """Evaluate a trained prediction model on the test split and report regret.

    Historically this class read several module-level names created by the
    ``__main__`` block (``predModel_obj``, ``optDataloader_test``, ``optModel``,
    ``d_type``, ``r_name``, ``o_name``). Those globals are still used as
    fallbacks, but each dependency can now also be passed explicitly.
    """

    def __init__(self, epochs, patience, method, mode, lr=0.00005, pred_model=None, device=None):
        """Store run settings and move the prediction model to the target device.

        Args:
            epochs: Number of training epochs (stored only; ``test`` does not use it).
            patience: Early-stopping patience (stored only; ``test`` does not use it).
            method: Label printed by ``test`` and used in output file names.
            mode: Second label used in output file names.
            lr: Learning rate (stored only; ``test`` does not use it).
            pred_model: ``nn.Module`` to evaluate. Defaults to the module-level
                ``predModel_obj`` for backward compatibility with the script.
            device: Target device. Defaults to ``cuda:0`` when CUDA is available,
                otherwise CPU (the original hard-coded ``cuda:0`` crashed on CPU-only hosts).
        """
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.epochs = epochs
        self.patience = patience
        if pred_model is None:
            pred_model = predModel_obj  # legacy path: global assigned by the __main__ block
        self.pred_model_obj = pred_model.to(self.device)
        self.learning_rate = lr
        self.method = method
        self.mode = mode

    def test(self, dataloader=None, opt_model=None, checkpoint_path=None):
        """Load the best checkpoint, evaluate it on the test loader and save regrets.

        For each sample, the predicted cost vector, the true costs and the scalar
        ``z[j]`` are passed to ``calRegret``. The per-sample regrets are summed,
        printed (raw and divided by ``sum(|z|)``) and written to
        ``./results/evaluation/regret_loss_<method>_<mode>_<mae>.csv``.

        Args:
            dataloader: Iterable of batches indexed as ``data[0..3]``
                (features, true costs, ``w``, ``z``). Presumably ``w`` is the true
                optimal solution and ``z`` the true optimal objective; confirm
                against ``data_factory``. Defaults to the global ``optDataloader_test``.
            opt_model: Optimization model handed to ``calRegret``. Defaults to the
                global ``optModel``.
            checkpoint_path: State-dict file to load. Defaults to
                ``./results/model_logs/PredictModel_{d_type}_{r_name}_{o_name}_best_model.pkl``
                built from the globals ``d_type``, ``r_name`` and ``o_name``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        if dataloader is None:
            dataloader = optDataloader_test
        if opt_model is None:
            opt_model = optModel
        if checkpoint_path is None:
            checkpoint_path = './results/model_logs/PredictModel_{}_{}_{}_best_model.pkl'.format(d_type, r_name, o_name)

        model_obj = self.pred_model_obj
        # map_location lets a checkpoint saved on GPU be restored on CPU and vice versa.
        model_obj.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        model_obj.eval()

        loss_regret = 0
        loss_regret_list = []
        optsum = 0
        preds, trues, sols_pred, sols_true, objs, true_objs = [], [], [], [], [], []
        for data in dataloader:
            batch_x = data[0].float().to(self.device)
            batch_y = data[1].float().to(self.device)
            w = data[2].float().to(self.device)
            z = data[3].float().to(self.device)
            # Network forward pass; no autograd graph is needed for evaluation.
            with torch.no_grad():
                model_out = model_obj(batch_x).detach().cpu().numpy()
            true = batch_y.detach().cpu().numpy()
            true_obj = w * batch_y
            # Per-sample regret of the decision induced by the predicted costs.
            for j in range(model_out.shape[0]):
                rege, sol, obj = calRegret(opt_model, model_out[j], true[j], z[j].item())
                loss_regret = loss_regret + rege
                loss_regret_list.append(rege)
                sols_pred.append(sol)  # w(\hat{c})
                objs.append(obj)
            optsum += abs(z).sum().item()

            preds.append(model_out)
            trues.append(true)
            sols_true.append(w.detach().cpu().numpy())
            true_objs.append(true_obj.detach().cpu().numpy())

        # NOTE(review): np.array on per-batch lists assumes every batch has the
        # same size; a smaller last batch makes the list ragged, which recent
        # NumPy rejects. Confirm the test loader's batching.
        preds = np.array(preds)
        trues = np.array(trues)
        if trues.shape[2] == 1:
            trues = trues.squeeze(axis=2)
        sols_pred = np.array(sols_pred)
        sols_true = np.array(sols_true)
        objs = np.array(objs)
        true_objs = np.array(true_objs)

        # NOTE(review): these slices keep only the first sample of the first
        # batch, so the metrics below describe that single sample. Confirm this
        # is intended before comparing MAE/MSE across runs.
        preds = preds[0, 0, :]
        trues = trues[0, 0, :]
        sols_pred = sols_pred[0, :]
        sols_true = sols_true[0, 0, :]
        objs = objs[0, :]
        objs_true = true_objs[0, 0, :]
        mae, mse, rmse, mape, mspe, r2 = metric(preds.reshape(-1, 1), trues.reshape(-1, 1))
        print('MAE,MSE,RMSE,MAPE分别为{},{},{},{}, r2:{}m\n'.format(mae, mse, rmse, mape, r2))
        print('Regret evaluation for {}:'.format(self.method), loss_regret, loss_regret / (optsum + 1e-7))
        print('\n')

        regret_array = np.array(loss_regret_list)
        regret_out = pd.DataFrame(regret_array)
        regret_out.to_csv(f'./results/evaluation/regret_loss_{self.method}_{self.mode}_{mae:.2f}.csv', index=False)
        # Optional exports (disabled); the reshape sizes assume 60 test samples with 18 decision variables.
        # df_out = pd.DataFrame(objs.reshape(60, 18))  # c^T w(\hat{c})
        # df_out.to_csv(f'./results/evaluation/obj_horizon_{self.method}_{self.mode}_{mae:.2f}.csv', index=False)
        # df_out = pd.DataFrame(objs_true.reshape(60, 18))  # c^T w(c)
        # df_out.to_csv(f'./results/evaluation/obj_true_{self.method}.csv', index=False)
        # df_out = pd.DataFrame(preds.reshape(60, 1))
        # df_out.to_csv(f'./results/evaluation/prediction_pred_{self.method}_{self.mode}_{mae:.2f}.csv', index=False)
        # df_out = pd.DataFrame(trues.reshape(60, 1))
        # df_out.to_csv(f'./results/evaluation/prediction_true_{self.method}_{self.mode}_{mae:.2f}.csv', index=False)
        # df_out = pd.DataFrame(sols_pred.reshape(60, 18))
        # df_out.to_csv(f'./results/evaluation/solution_pred_{self.method}_{self.mode}_{mae:.2f}.csv', index=False)
        # df_out = pd.DataFrame(sols_true.reshape(60, 18))
        # df_out.to_csv(f'./results/evaluation/solution_true_{self.method}_{self.mode}_{mae:.2f}.csv', index=False)



if __name__ == "__main__":
    def init_weights(m):
        """Kaiming-normal (fan_in, ReLU) init for Linear/Conv2d weights; zero biases."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    torch.manual_seed(2024)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(2024)
        torch.cuda.manual_seed_all(2024)  # in case there are multiple GPUs
    data_type = ['shortestpath', 'portfolio', 'knapsack']  # , 'alloyproduction'
    # Regret function classes, paired index-wise with regret_name. Only the
    # names are needed for evaluation (checkpoint lookup and output labels).
    regret_type = [implicitMLE, negativeIdentity, contrastiveMAP, SPOPlus, blackboxOpt]  # , implicitMLE, NCE,
    regret_name = ['IMLE', 'NID', 'CMAP', 'SPO', 'DBB']  # 'NCE',
    # Optimizer classes are only needed for training, so evaluation iterates the
    # names below. (The original loop iterated an undefined `optimizer_type`.)
    # optimizer_type = [Adam, Adadelta, Adagrad, AdamW, RMSprop, APGO, APGO, APGO_NoAdaptive, APGO_NoMomentum, APGO]# AARMSpropW,  #SGD, ASGD,
    optimizer_name = ['Adam', 'Adadelta', 'Adagrad', 'AdamW', 'RMSpp', 'APGO', 'APGO_NoProx', 'APGO_NoAdaptive', 'APGO_NoMomentum', 'APGO_NoWeightDecay'] #'APGO', #'SGD', 'ASGD',

    # NOTE: Experiment falls back to the module-level names predModel_obj,
    # optModel, optDataloader_test, d_type, r_name and o_name, so they are
    # deliberately assigned at module scope below.
    for d_type in data_type:
        print('\n')
        torch.cuda.empty_cache()
        print('Dataset:', d_type)
        # The optimization model is constructed inside data_factory so that it
        # stays aligned with the dataset (change made 2024-05-17).
        build_optdata = 0  # 1: regenerate the dataset; 0: load the cached one
        load_data = opt_data if build_optdata == 1 else opt_data_load
        (optModel, optDataloader_train, optDataloader_vali, optDataloader_test,
         dataset_train, dataset_vali, dataset_test, x_shape, c_shape) = load_data(d_type, batch=16)
        print('input size =', x_shape)
        print('output size =', c_shape)
        for r_name in regret_name:
            for o_name in optimizer_name:
                # Fresh prediction model per run; its weights are replaced by the checkpoint in Experiment.test().
                if d_type == 'alloyproduction':
                    print('alloyproduction!')
                predModel_obj = PredModel(input_size=x_shape, output_size=c_shape)
                predModel_obj.apply(init_weights)

                torch.cuda.empty_cache()
                print('Regret:', r_name)
                print('Optimier:', o_name)

                lr = 0.000001 if r_name in ['NID', 'DBB'] else 0.0001
                try:
                    exp = Experiment(epochs=50, patience=5, method=o_name, mode=r_name, lr=lr)
                    exp.test()
                except FileNotFoundError:
                    # No trained checkpoint for this (dataset, regret, optimizer) combination.
                    print(d_type, r_name, o_name, 'No such file')
                except Exception as err:
                    # Keep the sweep going, but report the real cause instead of
                    # mislabelling every failure as a missing file.
                    print(d_type, r_name, o_name, 'failed:', repr(err))

