import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import random
import numpy as np

from pyepo.func.cspoplus import CSPOPlus
from pyepo.metric.regret import cspo_regret 
from copy import deepcopy
from pyepo.data.cspo_dataset import cspo_optDataset
import os
from datetime import datetime

from matplotlib import pyplot as plt
import time
import math
from sklearn.ensemble import RandomForestRegressor
from pyepo.metric.regret import cspo_regret_rf

# Every method trained with the torch ``LinearRegression`` model, i.e. all
# methods except the random-forest baselines ("rf", "is_rf"). Each base method
# appears in three variants: plain, "_T" (truncated training set) and "_is"
# (importance-sampling weighted, also trained on the truncated set).
list_not_rf_methods = [
    base + suffix
    for base in ("mse", "cspo+", "cspo+_ws", "cspo+_mse", "cspo+_mse_ws")
    for suffix in ("", "_T", "_is")
]

class LinearRegression(nn.Module):
    """Single affine layer mapping ``num_feat`` inputs to ``num_item`` outputs.

    ``hidden_dim`` is accepted only so the constructor mirrors the other
    predictor's signature; it has no effect here.
    """

    def __init__(self, num_feat, num_item, hidden_dim=20):
        super().__init__()
        # Attribute name is part of the state-dict keys ("fc1.weight", "fc1.bias").
        self.fc1 = nn.Linear(num_feat, num_item)

    def forward(self, x):
        """Return ``fc1(x)``."""
        return self.fc1(x)


# Define the neural network
class constraint_uncertainty_predictor(nn.Module):
    """Two-layer MLP: ``Linear(num_feat, hidden_dim) -> ReLU -> Linear(hidden_dim, num_item)``.

    In this file it is instantiated fresh per instance and handed to
    ``gen_cspo_problem_set`` as the conformal-prediction regressor.
    """

    def __init__(self, num_feat, num_item, hidden_dim=10):
        super().__init__()
        # Construction order (fc1 then fc2) fixes the RNG draws used for init.
        self.fc1 = nn.Linear(num_feat, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_item)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


def set_seed(seed):
    """Seed Python's ``random``, NumPy and torch (CPU and, if present, all GPUs).

    Also forces cuDNN into deterministic mode and disables its autotuner so
    repeated runs pick the same kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def seed_worker(worker_id):
    """``worker_init_fn``-style hook: seed NumPy and ``random`` from torch's seed.

    The torch initial seed is reduced to 32 bits (the range ``np.random.seed``
    accepts). ``worker_id`` is not used.
    """
    derived_seed = torch.initial_seed() & 0xFFFFFFFF  # same as % 2**32
    for seeder in (np.random.seed, random.seed):
        seeder(derived_seed)

def _format_result_row(fields):
    """Render one results-file row: values joined by ``", "`` plus a newline.

    Each value is rendered with ``f"{value}"`` (``format(value, "")``), the
    same rendering the inline f-strings used, so the file layout is unchanged.
    """
    return ", ".join(f"{value}" for value in fields) + "\n"


def _append_rows(save_file_path, rows):
    """Append already-formatted ``rows`` to the results file at ``save_file_path``."""
    with open(save_file_path, 'a') as file:
        file.writelines(rows)


def _robust_objective_stats(robust_sols, test_weight, test_c, capacity, tol=1e-6):
    """Evaluate robust solutions against the true test weights and costs.

    Args:
        robust_sols: per-instance solutions, iterated row by row.
        test_weight: true weights, indexed as ``test_weight[i, :]``.
        test_c: true costs, indexed as ``test_c[i, :]``.
        capacity: bound compared against ``dot(sol, weight)``.
        tol: feasibility tolerance added to ``capacity``.

    Returns:
        ``(mean, std, infeasible_count)`` of the realized costs
        ``dot(sol, cost)``; mean/std are rounded to 4 decimals. Infeasible
        solutions are counted but their cost is still included in mean/std.
    """
    objs = []
    infeasible_count = 0
    for i, sol in enumerate(robust_sols):
        if np.dot(sol, test_weight[i, :]) > capacity + tol:
            infeasible_count += 1
        objs.append(np.dot(sol, test_c[i, :]))
    objs = np.array(objs)
    return round(objs.mean(), 4), round(objs.std(), 4), infeasible_count


def solve_instances(problem_class, repeat_num, num_data, val_num_data, test_num_data, cp_num_data, params,
                    method_list, learning_rate, save_file_path, num_epochs=20,
                    save_freq=2, num_process=1, cp_alpha=0.1, test_mode=False, solve_ratio=1):
    """Train and evaluate every method in ``method_list`` on ``repeat_num`` instances.

    For each instance (seeded with ``5 + 25 * instance_num``) this builds the
    problem with ``problem_class`` and passes a fresh
    ``constraint_uncertainty_predictor`` plus ``score_function`` to
    ``gen_cspo_problem_set``. It then solves the nominal baselines. For each
    method it trains a predictor, computes the test regret, and records the
    realized cost and infeasibility of the resulting robust solutions.

    Rows are appended to ``save_file_path``: one per validation checkpoint,
    then one summary row with epoch ``-1``.

    Args:
        problem_class: problem generator, called with ``params``, the data
            sizes, ``instance_num``, ``cp_alpha`` and ``test_mode``.
        repeat_num: number of instances to generate.
        num_data, val_num_data, test_num_data, cp_num_data: forwarded to
            ``problem_class``.
        params: object exposing ``num_feat``, ``num_item``, ``capacity``,
            ``weight_deg``, ``noise_width`` and ``cost_deg``.
        method_list: method names. Suffix ``_T``/``_is`` selects the truncated
            training set; ``rf``/``is_rf`` use random forests.
        learning_rate: mapping from method name to learning rate (non-rf only).
        save_file_path: results file, opened in append mode.
        num_epochs: training epochs (also written in rf rows).
        save_freq: validation/logging interval in epochs.
        num_process, solve_ratio: forwarded to ``CSPOPlus``.
        cp_alpha: forwarded to ``problem_class`` and written in every row.
        test_mode: forwarded to ``problem_class``.

    Raises:
        ValueError: unknown method name, or an importance-sampling method
            without truncation.
    """
    num_feat = params.num_feat
    num_item = params.num_item
    capacity = params.capacity
    weight_deg = params.weight_deg
    noise_width = params.noise_width
    cost_deg = params.cost_deg
    rf_methods = ["rf", "is_rf"]

    for instance_num in range(repeat_num):
        set_seed(5+instance_num*25)
        print("="*40)
        print(f"Solving Instance number {instance_num}")
        print("="*40)
        # Generate new instance
        problem_instance = problem_class(params=params, num_data=num_data, val_num_data=val_num_data,
                                         test_num_data=test_num_data, cp_num_data=cp_num_data,
                                         instance_num=instance_num, cp_alpha=cp_alpha, test_mode=test_mode)
        # Fresh conformal-prediction regressor; gen_cspo_problem_set returns trained_cp_model.
        cp_model = constraint_uncertainty_predictor(num_feat, num_item)
        (train_optmodel_list_trunc, loader_train_trunc, val_optmodel_list, loader_val,
         test_optmodel_list, loader_test, no_trunc_train_optmodel_list, loader_no_trunc,
         trained_cp_model, robust_nominal_mean, robust_nominal_std,
         robust_nominal_infeas_count) = problem_instance.gen_cspo_problem_set(cp_model, score_function)

        # Nominal baselines (true weights, then predicted weights).
        nominal_mean, nominal_std = problem_instance.solve_nominal_problem()
        nominal_mean_pred, nominal_std_pred, nominal_infeas_count_pred = \
            problem_instance.solve_nominal_with_predicted_weights(trained_cp_model)
        # Baseline columns shared by every row written for this instance.
        baseline_cols = [nominal_mean, nominal_std, nominal_mean_pred, nominal_std_pred,
                         nominal_infeas_count_pred, robust_nominal_mean, robust_nominal_std,
                         robust_nominal_infeas_count]

        for method_name in method_list:
            print("="*40)
            print("="*40)
            print(f"Training {method_name} model...\n")

            # "_T" and "_is" methods train on the truncated training set.
            truncation = method_name.endswith('_T') or method_name.endswith('_is')
            if truncation:
                print(f"Truncation is used for method :{method_name}")
                train_optmodel_list = train_optmodel_list_trunc
                loader_train = loader_train_trunc
                training_index = problem_instance.training_index
                train_x = problem_instance.train_x[training_index]
                train_c = problem_instance.train_c[training_index]
            else:
                print(f"No truncation is used for method :{method_name}")
                train_optmodel_list = no_trunc_train_optmodel_list
                loader_train = loader_no_trunc
                train_x = problem_instance.train_x
                train_c = problem_instance.train_c
            val_x = problem_instance.val_x
            val_c = problem_instance.val_c

            # Importance Sampling
            print("Importance Sampling size: ")
            importance_weight = problem_instance.importance_sampling()
            importance_weight = torch.from_numpy(importance_weight).float().view(-1, 1)

            # Importance-sampling methods must run on the truncated data. Every
            # "*_is" name is truncated by the rule above, so in practice this only
            # rejects "is_rf". NOTE(review): confirm whether "is_rf" should be
            # routed to the truncated set instead of being rejected.
            if "is" in method_name.split("_") and not truncation:
                raise ValueError("Importance Sampling can only be used with truncation")

            dataset = cspo_optDataset(train_optmodel_list, train_x, train_c)
            if method_name in ["cspo+", "cspo+_T", "cspo+_is", "cspo+_mse", "cspo+_mse_T", "cspo+_mse_is"]:
                loss_function = CSPOPlus(train_optmodel_list, processes=num_process,
                                         solve_ratio=solve_ratio, dataset=dataset)
            elif method_name in ["cspo+_ws", "cspo+_ws_T", "cspo+_ws_is",
                                 "cspo+_mse_ws", "cspo+_mse_ws_T", "cspo+_mse_ws_is"]:
                loss_function = CSPOPlus(train_optmodel_list, processes=num_process, warm_start=True,
                                         solve_ratio=solve_ratio, dataset=dataset)
            elif method_name in ["mse", "mse_T"]:
                loss_function = nn.MSELoss()
            elif method_name == "mse_is":
                # Per-element losses so the trainer can apply importance weights.
                loss_function = nn.MSELoss(reduction="none")
            elif method_name not in rf_methods:
                raise ValueError("Invalid method name")

            # Built even for rf (unused there): nn.Linear initialisation draws
            # from the torch RNG, so skipping it would shift later random streams.
            reg = LinearRegression(num_feat, num_item)
            if torch.cuda.is_available():
                reg = reg.cuda()

            prefix_cols = [instance_num, cp_alpha, num_data, test_num_data, num_feat, num_item,
                           weight_deg, noise_width, cost_deg, capacity, method_name]
            is_rf = method_name in rf_methods
            if is_rf:
                train_loss_epoch_log, val_mse_loss_log, val_cspop_loss_log, val_regret_log, rf_list = train_rf(
                    method_name, train_x, train_c, val_x, val_c, importance_weight, val_optmodel_list, loader_val)
                training_time = 0
                epoch_rows = [
                    _format_result_row(
                        prefix_cols
                        + [num_epochs, round(train_loss_epoch, 4), round(val_mse, 4),
                           round(val_cspop, 4), round(val_regret, 4)]
                        + baseline_cols
                        + [0, 0, 0, truncation, training_time])
                    for train_loss_epoch, val_mse, val_cspop, val_regret
                    in zip(train_loss_epoch_log, val_mse_loss_log, val_cspop_loss_log, val_regret_log)
                ]
            else:
                trainer = (cspo_trainModel_mse_warmstart if method_name.startswith("cspo+_mse")
                           else cspo_trainModel)
                (train_loss_log, train_loss_epoch_log, val_mse_loss_log, val_cspop_loss_log,
                 val_regret_log, best_model_state, training_time_log) = trainer(
                    reg,
                    loss_func=loss_function,
                    method_name=method_name,
                    train_optmodel_list=train_optmodel_list,
                    loader_train=loader_train,
                    val_optmodel_list=val_optmodel_list,
                    loader_val=loader_val,
                    importance_weight=importance_weight,
                    num_epochs=num_epochs,
                    lr=learning_rate[method_name],
                    save_freq=save_freq,
                )
                epoch_rows = [
                    _format_result_row(
                        prefix_cols
                        + [(idx + 1) * save_freq, round(train_loss_epoch, 4), round(val_mse, 4),
                           round(val_cspop, 4), round(val_regret, 4)]
                        + baseline_cols
                        + [0, 0, 0, truncation, elapsed_time])
                    for idx, (train_loss_epoch, val_mse, val_cspop, val_regret, elapsed_time)
                    in enumerate(zip(train_loss_epoch_log, val_mse_loss_log, val_cspop_loss_log,
                                     val_regret_log, training_time_log))
                ]
                # Cumulative training time at the last checkpoint, taken explicitly
                # rather than from a loop variable that may be stale or unbound.
                training_time = training_time_log[-1] if training_time_log else 0
            _append_rows(save_file_path, epoch_rows)

            # Test regret with the best (lowest validation regret) model.
            if is_rf:
                test_regret = cspo_regret_rf(rf_list, test_optmodel_list, loader_test)
            else:
                reg.load_state_dict(best_model_state)
                test_regret = cspo_regret(reg, test_optmodel_list, loader_test)
            print(f"Test Regret: {round(test_regret*100,3)}%\n")

            # Robust objective value on the test set.
            test_x = problem_instance.test_x
            test_c = problem_instance.test_c
            test_weight = problem_instance.test_weights
            if is_rf:
                cp = np.zeros_like(test_c)
                for i in range(test_c.shape[1]):
                    cp[:, i] = rf_list[i].predict(test_x)
            else:
                with torch.no_grad():
                    # Separate name: test_x itself stays the raw array handed to
                    # cspo_optDataset below (as on the rf path and for training).
                    test_x_tensor = torch.tensor(test_x).float()
                    if torch.cuda.is_available():
                        test_x_tensor = test_x_tensor.cuda()
                    cp = reg(test_x_tensor).to("cpu").detach().numpy()
            print("Calculating Robust Objective Value...\n")
            robust_dataset = cspo_optDataset(test_optmodel_list, test_x, cp)
            robust_mean, robust_std, robust_infeasible_count = _robust_objective_stats(
                robust_dataset.sols, test_weight, test_c, capacity)

            print(f"Nominal Mean: {nominal_mean}, Nominal Std: {nominal_std}\n")
            print(f"Nominal with predicted weights Mean: {nominal_mean_pred}, Std: {nominal_std_pred}, Infeasible Count {nominal_infeas_count_pred}\n")
            print(f"Robust Nominal Mean: {robust_nominal_mean}, Robust Nominal Std: {robust_nominal_std}, Infeasible Count: {robust_nominal_infeas_count}\n")
            print(f"Robust Mean: {robust_mean}, Robust Std: {robust_std}, Infeasible Count: {robust_infeasible_count}\n")

            # Summary row: metrics at the checkpoint with the lowest validation regret.
            best_idx = val_regret_log.index(min(val_regret_log))
            summary_row = _format_result_row(
                prefix_cols
                + [-1, round(train_loss_epoch_log[best_idx], 4), round(val_mse_loss_log[best_idx], 4),
                   round(val_cspop_loss_log[best_idx], 4), test_regret]
                + baseline_cols
                + [robust_mean, robust_std, robust_infeasible_count, truncation, training_time])
            _append_rows(save_file_path, [summary_row])
            
# def run_experiments(knapsack_problem,num_repeat, num_data, val_num_data, test_num_data, cp_num_data, params_list, method_list, learning_rate,
                # save_path, file_name,num_epochs=num_epochs, num_process= num_process, alphas=alpha_list, test_mode=test_mode, truncation=truncation,solve_ratio=solve_ratio)


def run_experiments(problem_class, repeat_num, num_data, val_num_data, test_num_data, cp_num_data, params_list,
                    method_list, learning_rate, save_path, file_name, num_epochs=20, num_process=1,
                    alphas=None, test_mode=False, solve_ratio=1, save_freq=2):
    """Run ``solve_instances`` for every params set and alpha, logging to one file.

    Creates ``save_path`` if needed. If ``save_path/file_name`` already exists,
    it is moved to ``save_path/prev_results/<stem>_<YYYYmmdd_HHMMSS>.txt``.
    A fresh file containing only the header line is then written, and every
    run appends its rows to it.

    Args:
        problem_class, repeat_num, num_data, val_num_data, test_num_data,
        cp_num_data, method_list, learning_rate, num_epochs, num_process,
        test_mode, solve_ratio: forwarded to ``solve_instances``.
        params_list: iterable of params objects; each is run for every alpha.
        save_path: output directory.
        file_name: results file name inside ``save_path``.
        alphas: conformal alphas to sweep; ``None`` (default) means ``[0.1]``.
        save_freq: validation/logging interval forwarded to ``solve_instances``
            (default 2, matching that function's own default).
    """
    if alphas is None:
        alphas = [0.1]

    os.makedirs(save_path, exist_ok=True)
    full_path = os.path.join(save_path, file_name)

    # Archive an existing results file instead of overwriting it.
    if os.path.exists(full_path):
        prev_results_folder = os.path.join(save_path, 'prev_results')
        os.makedirs(prev_results_folder, exist_ok=True)
        current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        # splitext keeps dots inside the stem ("res.v2.txt" -> "res.v2").
        stem = os.path.splitext(os.path.basename(file_name))[0]
        new_full_path = os.path.join(prev_results_folder, f"{stem}_{current_time}.txt")
        os.rename(full_path, new_full_path)

    with open(full_path, 'w') as file:
        # Header text is kept byte-identical for downstream readers. Note that
        # per-checkpoint rows hold validation metrics in the "test_*" columns.
        file.write("instance, alpha, num_data, test_num_data, num_feat, num_item," +
                   " weight_deg, noise_width, cost_deg, capacity, Model, epoch, train_loss, test_mse_loss, test_cspop_loss," +
                   " test_regret, nominal_mean, nominal_std, nominal_mean_pred, nominal_std_pred, nominal_infeas_count_pred," +
                   " robust_nominal_mean, robust_nominal_std, robust_nominal_infeas_count, robust_mean, robust_std, robust_infeas_count, truncation, training_time\n")

    for idx, params in enumerate(params_list):
        print(f"Running experiments for params number {idx+1}\n")
        for alpha in alphas:
            solve_instances(problem_class, repeat_num, num_data, val_num_data, test_num_data, cp_num_data, params,
                            method_list, learning_rate, full_path, num_epochs=num_epochs,
                            save_freq=save_freq, num_process=num_process, cp_alpha=alpha,
                            test_mode=test_mode, solve_ratio=solve_ratio)





# ------------------------------------------------------------------------------------------------
# Visualizing Tool
def visLearningCurve(loss_log, loss_log_regret):
    """Show two side-by-side learning curves and call ``plt.show()``.

    Left panel: ``loss_log`` against iterations ("Learning Curve on Training Set").
    Right panel: ``loss_log_regret`` against epochs ("Learning Curve on Test Set"),
    with x ticks every 2 points and the y axis fixed to [0, 0.5].
    """
    _, axes = plt.subplots(1, 2, figsize=(16, 4))
    train_ax, regret_ax = axes

    # Left panel: training loss per iteration.
    train_ax.plot(loss_log, color="c", lw=1)
    train_ax.tick_params(axis="both", which="major", labelsize=12)
    for setter, text in ((train_ax.set_xlabel, "Iters"),
                         (train_ax.set_ylabel, "Loss"),
                         (train_ax.set_title, "Learning Curve on Training Set")):
        setter(text, fontsize=16)

    # Right panel: regret per epoch.
    regret_ax.plot(loss_log_regret, color="royalblue", ls="--", alpha=0.7, lw=1)
    regret_ax.set_xticks(range(0, len(loss_log_regret), 2))
    regret_ax.tick_params(axis="both", which="major", labelsize=12)
    regret_ax.set_ylim(0, 0.5)
    for setter, text in ((regret_ax.set_xlabel, "Epochs"),
                         (regret_ax.set_ylabel, "Regret"),
                         (regret_ax.set_title, "Learning Curve on Test Set")):
        setter(text, fontsize=16)

    plt.show()


# ------------------------------------------------------------------------------------------------
# Define score function for conformal prediction here.
def score_function(predmodel, x, y):
    """Conformal score ``torch.linalg.norm(predmodel(x) - y, ord=1)`` as a Python float.

    The prediction is made in eval mode; the model is then unconditionally put
    back into train mode.

    NOTE(review): for a 1-D residual this is the L1 norm. For a 2-D residual,
    ``torch.linalg.norm(..., ord=1)`` is the *matrix* 1-norm (max absolute
    column sum), e.g. the max absolute entry for a (1, d) residual. Confirm
    that callers pass single 1-D samples.
    """
    predmodel.eval()
    residual = predmodel(x) - y
    predmodel.train()
    return torch.linalg.norm(residual, ord=1).item()


# ------------------------------------------------------------------------------------------------
# train model
def cal_loss(model, loader_val, val_optmodel_list, mse_loss_func, cspop_loss_func):
    """Compute the validation MSE and CSPO+ losses of ``model`` without gradients.

    Args:
        model: predictor; switched to eval mode here and back to train mode on exit.
        loader_val: yields ``(batch_indices, x, c, w, z)`` batches.
        val_optmodel_list: optimization models indexed by ``batch_indices``.
        mse_loss_func: called as ``mse_loss_func(cp, c)``.
        cspop_loss_func: called as ``cspop_loss_func(models, cp, c, w, z)``.

    Returns:
        ``(val_mse_loss, val_cspop_loss)``, each rounded to 4 decimals.

    Each batch contributes its loss value exactly as returned, and the sums are
    divided by the total number of samples.
    NOTE(review): with mean-reduced losses (such as the ``nn.MSELoss()`` the
    trainers pass in), this yields roughly mean / batch_size. Confirm intended.
    """
    model.eval()
    mse_total = 0
    cspop_total = 0
    n_samples = 0
    on_gpu = torch.cuda.is_available()
    with torch.no_grad():
        for batch_indices, x, c, w, z in loader_val:
            models = [val_optmodel_list[idx] for idx in batch_indices.tolist()]
            if on_gpu:
                x, c, w, z = x.cuda(), c.cuda(), w.cuda(), z.cuda()
            cp = model(x)
            cspop_loss = cspop_loss_func(models, cp, c, w, z)
            mse_total += mse_loss_func(cp, c).item()
            cspop_total += cspop_loss.item()
            n_samples += len(x)
    val_mse_loss = round(mse_total / n_samples, 4)
    val_cspop_loss = round(cspop_total / n_samples, 4)
    model.train()
    return val_mse_loss, val_cspop_loss
    

def cspo_trainModel(reg, loss_func, method_name,
                    train_optmodel_list, loader_train,
                    val_optmodel_list, loader_val, importance_weight = None,
                    num_epochs=20, lr=1e-2, save_freq=5, patience=6,
                    ):
    """Train ``reg`` with Adam and keep the state with the lowest validation regret.

    Supported ``method_name`` values and how each batch loss is built:
      * "cspo+", "cspo+_T", "cspo+_ws", "cspo+_ws_T": ``loss_func(models, cp, c, w, z)``
      * "cspo+_is", "cspo+_ws_is": per-sample ``loss_func(..., reduction="none")``
        times ``importance_weight[batch]``, then averaged
      * "mse", "mse_T": ``loss_func(cp, c)``
      * "mse_is": ``loss_func(cp, c)`` times ``importance_weight[batch]``, then averaged

    Every ``save_freq`` epochs, validation regret, MSE and CSPO+ loss are
    recorded and ``ReduceLROnPlateau`` is stepped on the regret.
    ``patience`` is not used.

    Returns:
        ``(train_loss_log, train_loss_epoch_log, val_mse_loss_log,
        val_cspop_loss_log, val_regret_log, best_model_state,
        time_elapsed_log)``. ``train_loss_log`` holds the last batch's loss of
        every epoch. The other logs have one entry per checkpoint;
        ``time_elapsed_log`` is the cumulative training time, excluding
        validation.

    Raises:
        ValueError: unknown ``method_name`` (at the first training batch).
    """
    optimizer = torch.optim.Adam(reg.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, min_lr=1e-4,verbose=True)
    reg.train()

    plain_cspo = ("cspo+", "cspo+_T", 'cspo+_ws', "cspo+_ws_T")
    weighted_cspo = ("cspo+_is", "cspo+_ws_is")
    plain_mse = ("mse", 'mse_T')

    last_batch_losses = []
    epoch_losses = []
    regret_history = []
    mse_history = []
    cspop_history = []
    elapsed_history = []
    val_mse_loss_func = nn.MSELoss()
    val_cspop_loss_func = CSPOPlus(val_optmodel_list,processes=1)

    best_regret = float('inf')
    best_model_state = None

    on_gpu = torch.cuda.is_available()
    if on_gpu and importance_weight is not None:
        print(f'before function import sample {importance_weight.device}')
        importance_weight = importance_weight.cuda()
        print(f'After {importance_weight.device}')

    elapsed = 0
    for epoch in range(num_epochs):
        tick = time.time()
        running_loss = 0
        n_seen = 0
        for batch_indices, x, c, w, z in loader_train:
            batch_indices = batch_indices.tolist()
            selected_models = [train_optmodel_list[idx] for idx in batch_indices]
            if on_gpu:
                x, c, w, z = x.cuda(), c.cuda(), w.cuda(), z.cuda()
            cp = reg(x)

            if method_name in plain_cspo:
                loss = loss_func(selected_models, cp, c, w, z)
            elif method_name in weighted_cspo:
                per_sample = loss_func(selected_models,cp,c,w,z,reduction="none")
                loss = (per_sample * importance_weight[batch_indices,:]).mean()
            elif method_name in plain_mse:
                loss = loss_func(cp,c)
            elif method_name == "mse_is":
                per_sample = loss_func(cp,c)
                loss = (per_sample * importance_weight[batch_indices,:]).mean()
            else:
                raise ValueError("Invalid method name")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_seen += len(x)

        train_loss = round(running_loss/n_seen ,4)
        last_batch_losses.append(loss.item())
        elapsed += time.time() - tick

        # Validation only every save_freq epochs to keep training fast.
        if (epoch+1) % save_freq != 0:
            continue
        epoch_losses.append(train_loss)
        regret = cspo_regret(reg, val_optmodel_list, loader_val)
        regret_history.append(regret)
        val_mse_loss, val_cspop_loss = cal_loss(reg, loader_val, val_optmodel_list,val_mse_loss_func,val_cspop_loss_func)
        mse_history.append(val_mse_loss)
        cspop_history.append(val_cspop_loss)
        elapsed_history.append(elapsed)
        scheduler.step(regret)
        print("Epoch {:2},  Train Loss: {:9.4f}, Val MSE Loss: {:7.4f}, Val CSPOP Loss: {:7.4f}, Validation Regret: {:7.4f}%".format(
            epoch+1, train_loss, val_mse_loss, val_cspop_loss, regret*100))

        if regret < best_regret:
            best_regret = regret
            best_model_state = deepcopy(reg.state_dict())

    print("Total Elapsed Time: {:.2f} Sec. \n".format(elapsed))

    return last_batch_losses, epoch_losses, mse_history, cspop_history, regret_history, best_model_state, elapsed_history

def cspo_trainModel_mse_warmstart(reg, loss_func, method_name,
                    train_optmodel_list, loader_train,
                    val_optmodel_list, loader_val, importance_weight = None,
                    num_epochs=20, lr=1e-2, save_freq=5, patience=6,
                    warmstart_epochs=10, lr_multiplier=5,
                    ):
    """Train ``reg`` with an MSE warm start, then switch to ``loss_func`` (CSPO+).

    During the first ``warmstart_epochs`` epochs the batch loss is MSE between
    ``reg(x)`` and ``c``. From epoch ``warmstart_epochs`` onward it is
    ``loss_func(models, cp, c, w, z)``. At the start of that switch epoch the
    learning rate is set once to ``lr * lr_multiplier``. For "_is" methods the
    per-sample losses (``reduction="none"``) are multiplied by
    ``importance_weight[batch]`` and averaged, in both phases.

    Every ``save_freq`` epochs, validation regret, MSE and CSPO+ loss are
    recorded, ``ReduceLROnPlateau`` is stepped on the regret, and the state
    with the lowest validation regret is kept.

    Args:
        method_name: one of "cspo+_mse", "cspo+_mse_T", "cspo+_mse_ws",
            "cspo+_mse_ws_T", "cspo+_mse_is", "cspo+_mse_ws_is".
        importance_weight: weights indexed by batch indices; required for
            the "_is" methods.
        patience: unused; kept for signature compatibility.
        warmstart_epochs: number of initial MSE-only epochs (default 10).
        lr_multiplier: factor applied to ``lr`` at the switch (default 5).

    Returns:
        ``(train_loss_log, train_loss_epoch_log, val_mse_loss_log,
        val_cspop_loss_log, val_regret_log, best_model_state,
        time_elapsed_log)``. ``train_loss_log`` holds the last batch's loss of
        every epoch. The other logs have one entry per checkpoint;
        ``time_elapsed_log`` is the cumulative training time, excluding
        validation.

    Raises:
        ValueError: unknown ``method_name``, or an "_is" method without
            ``importance_weight``.
    """
    unweighted_methods = ["cspo+_mse", "cspo+_mse_T", 'cspo+_mse_ws', "cspo+_mse_ws_T"]
    weighted_methods = ["cspo+_mse_is", "cspo+_mse_ws_is"]
    if method_name not in unweighted_methods + weighted_methods:
        raise ValueError("Invalid method name")
    use_importance_weight = method_name in weighted_methods
    if use_importance_weight and importance_weight is None:
        raise ValueError("importance_weight is required for importance-sampling methods")

    optimizer = torch.optim.Adam(reg.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3,  min_lr=1e-4,verbose=True)
    reg.train()

    train_loss_log = []
    train_loss_epoch_log = []
    val_regret_log = []
    val_mse_loss_log = []
    val_cspop_loss_log = []
    time_elapsed_log = []
    val_mse_loss_func = nn.MSELoss()
    val_cspop_loss_func = CSPOPlus(val_optmodel_list,processes=1)
    min_val_regret = float('inf')
    best_model_state = None

    if torch.cuda.is_available() and importance_weight is not None:
        print(f'before function import sample {importance_weight.device}')
        importance_weight = importance_weight.cuda()
        print(f'After {importance_weight.device}')

    elapsed = 0
    for epoch in range(num_epochs):
        tick = time.time()
        mse_phase = epoch < warmstart_epochs
        if epoch == warmstart_epochs:
            # Switching from MSE to loss_func: reset the learning rate once.
            # (This overrides any reduction the scheduler made during warm-up.)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr * lr_multiplier
        train_loss, len_data = 0, 0
        for batch_indices, x, c, w, z in loader_train:
            batch_indices = batch_indices.tolist()
            selected_models = [train_optmodel_list[i] for i in batch_indices]
            if torch.cuda.is_available():
                x, c, w, z = x.cuda(), c.cuda(), w.cuda(), z.cuda()
            cp = reg(x)
            if use_importance_weight:
                # Per-sample losses in every epoch (including the switch epoch)
                # so the importance weights apply sample by sample.
                if mse_phase:
                    losses = nn.MSELoss(reduction="none")(cp, c)
                else:
                    losses = loss_func(selected_models, cp, c, w, z, reduction="none")
                loss = (losses * importance_weight[batch_indices, :]).mean()
            elif mse_phase:
                loss = nn.MSELoss()(cp, c)
            else:
                loss = loss_func(selected_models, cp, c, w, z)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            len_data += len(x)

        train_loss = round(train_loss/len_data ,4)
        train_loss_log.append(loss.item())
        elapsed += time.time() - tick

        # Validation only every save_freq epochs to keep training fast.
        if (epoch+1)%save_freq == 0:
            train_loss_epoch_log.append(train_loss)
            regret = cspo_regret(reg, val_optmodel_list, loader_val)
            val_regret_log.append(regret)
            val_mse_loss, val_cspop_loss = cal_loss(reg, loader_val, val_optmodel_list,val_mse_loss_func,val_cspop_loss_func)
            val_mse_loss_log.append(val_mse_loss)
            val_cspop_loss_log.append(val_cspop_loss)
            time_elapsed_log.append(elapsed)
            scheduler.step(regret)
            print("Epoch {:2},  Train Loss: {:9.4f}, Val MSE Loss: {:7.4f}, Val CSPOP Loss: {:7.4f}, Validation Regret: {:7.4f}%".format(epoch+1,
                                                                                                                                 train_loss, val_mse_loss, val_cspop_loss, regret*100))
            if regret < min_val_regret:
                min_val_regret = regret
                best_model_state = deepcopy(reg.state_dict())

    print("Total Elapsed Time: {:.2f} Sec. \n".format(elapsed))

    return train_loss_log, train_loss_epoch_log, val_mse_loss_log, val_cspop_loss_log, val_regret_log, best_model_state, time_elapsed_log



# ------------------------------------------------------------------------------------------------
class EarlyStopping:
    """Signal early stopping when the validation loss stops improving.

    Call the instance once per evaluation with the current validation loss.
    After ``patience`` consecutive calls that do not improve on the best loss
    by at least ``delta``, ``early_stop`` is set to True. This class only
    tracks state; it does not save checkpoints.
    """
    def __init__(self, patience=4, verbose=False, delta=1e-2, trace_func=print):
        """
        Args:
            patience (int): Number of consecutive non-improving calls tolerated
                            before ``early_stop`` is set.
                            Default: 4
            verbose (bool): If True, reports each validation loss improvement
                            through ``trace_func``.
                            Default: False
            delta (float): Minimum decrease in the validation loss that counts
                            as an improvement.
                            Default: 1e-2
            trace_func (function): Function used to emit messages.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        # Best (negated) loss seen so far; None until the first call.
        self.best_score = None
        self.early_stop = False
        # np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.trace_func = trace_func

    def __call__(self, val_loss):
        """Record one validation loss and update the stopping state."""
        # Negate so that a higher score means a better (lower) loss.
        score = -val_loss

        if self.best_score is not None and score < self.best_score + self.delta:
            # Not improved by at least `delta`: count towards patience.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # First call, or a sufficient improvement: reset the counter.
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).')
        self.best_score = score
        self.val_loss_min = val_loss
        self.counter = 0



# ------------------------------------------------------------------------------------------------
## Train Random Forest 
def cal_rf_cspop_loss(rf_list, loader_val, val_optmodel_list, cspop_loss_func):
    """
    Calculates CSPOP loss on the validation set for RF models.

    Args:
        rf_list: One fitted regressor per column of ``c``; ``rf_list[j].predict(x)``
            supplies the predictions for column ``j``.
        loader_val: Iterable of ``(batch_indices, x, c, w, z)`` batches, where
            ``batch_indices`` is a tensor of indices into ``val_optmodel_list``.
        val_optmodel_list: Optimization models, indexed by ``batch_indices``.
        cspop_loss_func: Callable ``(models, cp, c, w, z)`` returning a scalar tensor.

    Returns:
        float: The CSPO+ loss averaged over batches (not over samples).

    Raises:
        ValueError: If ``loader_val`` yields no batches.
    """
    total_loss = 0.0
    num_batches = 0
    for batch_indices, x, c, w, z in loader_val:
        # Optimization models that correspond to the samples in this batch.
        selected_models = [val_optmodel_list[idx] for idx in batch_indices.tolist()]
        # Assemble the predicted cost matrix column by column, one forest per column.
        cp = torch.zeros_like(c)
        for j in range(c.shape[1]):
            cp[:, j] = torch.from_numpy(rf_list[j].predict(x))
        # cspo+ loss
        cspop_loss = cspop_loss_func(selected_models, cp, c, w, z)
        total_loss += cspop_loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("loader_val yielded no batches; cannot average CSPOP loss")
    return total_loss / num_batches

def train_rf(method_name, train_x, train_c, val_x, val_c, importance_weights, val_optmodel_list, loader_val):
    """
    Train one random forest per item (column of ``train_c``) and evaluate the
    resulting predictor once on the validation set.

    Args:
        method_name (str): ``"rf"`` for unweighted fitting or ``"is_rf"`` to fit
            with ``importance_weights`` as per-sample weights.
        train_x, train_c: Training features and item costs; ``train_c[:, i]`` is
            the target of the i-th forest.
        val_x, val_c: Validation features and item costs.
        importance_weights: Per-sample weights (converted with ``.numpy()``);
            only read when ``method_name == "is_rf"``.
        val_optmodel_list: Optimization models for the validation samples.
        loader_val: Validation loader passed to the regret / CSPO+ metrics.

    Returns:
        tuple: ``(train_loss_epoch_log, val_mse_loss_log, val_cspop_loss_log,
        val_regret_log, rf_list)``; each log holds a single entry.

    Raises:
        ValueError: If ``method_name`` is not ``"rf"`` or ``"is_rf"``.
    """
    # Validate up front: an unknown method would otherwise leave every forest
    # unfitted and fail later with an obscure error inside `predict`.
    if method_name not in ("rf", "is_rf"):
        raise ValueError("Invalid method name")
    # Sample weights are only needed (and only converted) for importance-weighted training.
    sample_weight = None
    if method_name == "is_rf":
        sample_weight = importance_weights.numpy().flatten()

    train_loss_epoch_log = []
    val_regret_log = []
    val_mse_loss_log = []
    val_cspop_loss_log = []
    val_cspop_loss_func = CSPOPlus(val_optmodel_list, processes=1)

    # Fit a random forest for each item in the knapsack problem
    num_items = train_c.shape[1]
    num_feats = train_x.shape[1]
    # ceil(num_items/3), capped at the number of input features because
    # scikit-learn rejects an integer max_features larger than n_features.
    # NOTE(review): the usual regression heuristic is ceil(n_features/3); confirm intent.
    max_features = min(math.ceil(num_items / 3), num_feats)
    rf_list = []
    cp_train = np.zeros_like(train_c)
    cp_val = np.zeros_like(val_c)
    for i in range(num_items):
        rf = RandomForestRegressor(n_estimators=100, max_features=max_features, random_state=i)
        # sample_weight=None is equivalent to an unweighted fit.
        rf.fit(train_x, train_c[:, i], sample_weight=sample_weight)
        cp_train[:, i] = rf.predict(train_x)
        cp_val[:, i] = rf.predict(val_x)
        rf_list.append(rf)
    # Calculate Train MSE loss
    train_mse_loss = np.mean((cp_train - train_c) ** 2)
    train_loss_epoch_log.append(train_mse_loss)
    # Calculate regret on the validation set
    val_regret = cspo_regret_rf(rf_list, val_optmodel_list, loader_val)
    val_regret_log.append(val_regret)
    # Calculate validation MSE loss
    val_mse_loss = np.mean((cp_val - val_c) ** 2)
    val_mse_loss_log.append(val_mse_loss)
    # Calculate validation CSPOP loss
    val_cspop_loss = cal_rf_cspop_loss(rf_list, loader_val, val_optmodel_list, val_cspop_loss_func)
    val_cspop_loss_log.append(val_cspop_loss)
    print("Epoch {:2},  Train Loss: {:9.4f}, Val MSE Loss: {:7.4f}, Val CSPOP Loss: {:7.4f}, Validation Regret: {:7.4f}%".format(-1,
                                                                                                                                 train_mse_loss, val_mse_loss, val_cspop_loss, val_regret*100))

    return train_loss_epoch_log, val_mse_loss_log, val_cspop_loss_log, val_regret_log, rf_list
    

