import os
import torch
from torch import nn
import numpy as np
import argparse
from collections import defaultdict
import json
from glasso import *
from utils import *
from model import DNNBaseline, AdditiveModel, train, eval_model
from datasets import generate_dataset

def _build_arg_parser():
    """Build the command-line parser for this experiment script.

    Returns:
        argparse.ArgumentParser: parser covering data, simulation, SAM,
        group-lasso, model and training options.
    """
    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument('--data_path', default="/home/users/yhung7/SDAM/data/", type=str, help='Path to the dataset')
    parser.add_argument('--output_dir', default="/home/users/yhung7/SDAM/src/output/", type=str, help='Path to save the output')
    parser.add_argument('--save_dir', default="/home/users/yhung7/SDAM/src/save_model/", type=str, help='Directory for per-repetition ADNN checkpoints')
    parser.add_argument('--model_name', default='sdam', type=str, help='Name of the model to use')
    parser.add_argument('--do_eval', action='store_true', help='Whether to evaluate the model after training')
    parser.add_argument('--data_name', default='main_effect', type=str, help='type of the data to use')
    parser.add_argument('--model_idx', type=int, default=0, help='Index of the model to use')
    parser.add_argument('--repeat', type=int, default=1, help='Number of repetitions for the experiment')
    parser.add_argument('--nlist', nargs='*', type=int, default=[150, 50, 1000], help='Sizes of training, validation, and test datasets')
    parser.add_argument('--nfeature', type=int, default=150, help='Number of features in the dataset')
    parser.add_argument('--UB', type=float, default=2.5, help='Upper bound for the dataset features')
    parser.add_argument('--LB', type=float, default=-2.5, help='Lower bound for the dataset features')
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help='Device to use for training')

    # Simulation Settings
    parser.add_argument(
        "--noise",
        default=1.0,
        type=float,
        help="Noise level for the model",
    )

    # SAM Args
    parser.add_argument('--SAM_iter', type=int, default=50, help='Number of iterations for SAM')
    parser.add_argument('--SAM_tol', type=float, default=1e-6, help='Tolerance for convergence in SAM')
    parser.add_argument('--SAM_ftol', type=float, default=1e-3, help='Tolerance for feature selection in SAM')
    parser.add_argument('--knots', type=int, default=10, help='Number of knots for spline basis')
    parser.add_argument('--nbound', type=int, default=3, help='Number of boundary knots for spline basis')
    parser.add_argument('--degree', type=int, default=3, help=' Degree of the spline basis')
    parser.add_argument('--plot', action='store_true', help='Plot the figure for the component functions')
    parser.add_argument('--glasso_threshold', type=float, default=0.2, help='Threshold for group lasso selection')
    parser.add_argument('--gknots', type=int, default=5, help='Number of knots for spline basis for group lasso')
    parser.add_argument('--giknots', type=int, default=5, help='Number of knots for inter-spline basis for group lasso')

    # Model Args
    parser.add_argument('--hidden_config', nargs='*', type=int, default=[8, 6, 3], help='model architecture')
    parser.add_argument('--num_layers', type=int, default=2, help='Number of layers in the model')
    parser.add_argument('--dropout', type=float, default=0.2, help='Dropout rate for the model')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the optimizer')

    # Training Args
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
    parser.add_argument('--eval_batch_size', type=int, default=64, help='Batch size for evaluation')
    parser.add_argument('--patience', type=int, default=50, help='Early stopping patience')
    parser.add_argument('--epochs', type=int, default=10000, help='Number of training epochs')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay for the optimizer')
    parser.add_argument('--seed', type=int, default=1, help='Random seed for reproducibility')
    return parser


def _merge_groups(groups):
    """Collapse selected groups into the ``index_list`` passed to AdditiveModel.

    Groups of length 1 are kept unchanged and in order. The members of all
    longer (interaction) groups are pooled, de-duplicated, and appended as a
    single trailing group. When there are no interaction groups, nothing is
    appended.

    Args:
        groups: iterable of sized sequences of feature indices, as returned by
            ``get_important_groups`` (exact element type not visible here).

    Returns:
        list: the merged groups.
    """
    merged = []
    interaction_members = []
    for group in groups:
        if len(group) == 1:
            merged.append(group)
        else:
            interaction_members.extend(group)
    if interaction_members:
        merged.append(list(set(interaction_members)))
    return merged


def main():
    """Run the staged SDAM pipeline on every repetition of a saved dataset.

    Data are loaded from ``<data_path>/<data_name>_data.pt``. For each
    repetition ``r`` (the leading dimension of the stored tensors):

    1. Stage I: ``train_SAM`` is run over a fixed grid of alpha values. Its
       ``'opt_var'`` result selects candidate features.
    2. Group lasso: an ``AdditiveInteractionSelector`` is fitted on those
       features plus all their pairwise interactions. Groups returned by
       ``get_important_groups(threshold=--glasso_threshold)`` are merged by
       ``_merge_groups``.
    3. Final stage: an ``AdditiveModel`` over the merged groups is trained
       with ``train`` and scored with ``eval_model`` (skipped with ``--do_eval``).

    Per-repetition runtimes and test errors are written as JSON to
    ``<output_dir>/<model_name>-<data_name>-n<nlist[0]>.json``.
    """
    # Explicit import: this name was previously only available through a
    # project-level star import.
    from itertools import combinations

    args = _build_arg_parser().parse_args()

    args_arc = f"{args.model_name}-{args.data_name}-"
    # The append-mode open below fails if the directory does not exist yet.
    os.makedirs(args.output_dir, exist_ok=True)
    args.log_file = os.path.join(args.output_dir, args_arc + ".txt")
    with open(args.log_file, "a") as f:
        f.write(str(args) + "\n")

    checkpoint = args_arc + str(args.nlist[0]) + ".pt"
    args.checkpoint_path = os.path.join(args.output_dir, checkpoint)

    # Honour --data_path. The default value yields the previously hard-coded path.
    data_file = os.path.join(args.data_path, args.data_name + '_data.pt')
    dataset = torch.load(data_file, weights_only=True)
    X_train, y_train = dataset['X_train'], dataset['y_train']
    X_val, y_val = dataset['X_valid'], dataset['y_valid']
    X_test, y_test = dataset['X_test'], dataset['y_test']

    # The repetition count comes from the saved tensors' leading dimension.
    # This overrides whatever --repeat was given on the command line.
    args.repeat = X_train.size()[0]

    runtimes = torch.zeros(args.repeat)
    mspe = torch.zeros(args.repeat)
    active_dict = {}

    # Alpha grid for Stage I. It is loop-invariant, so build it once.
    alpha_list = torch.tensor(list(torch.arange(0, 1.1, 0.1)) + [1.5, 2, 3, 5])

    for r in range(args.repeat):
        torch.manual_seed(args.seed + r)

        Train_X, Train_y = X_train[r], y_train[r]
        Valid_X, Valid_y = X_val[r], y_val[r]
        Test_X, Test_y = X_test[r], y_test[r]

        ################
        #    Stage I   #
        ################
        print("Running SAM for different alpha values")
        results_I = train_SAM(
            Train_X, Train_y, alpha_list,
            max_iter=args.SAM_iter, nk=args.knots, nb=args.nbound, custom=False,
        )

        ####################
        #    Group-LASSO   #
        ####################
        opt_var = results_I['opt_var']
        opt_df = extract_active_features(Train_X, opt_var)
        interactions = list(combinations(list(opt_df.keys()), 2))
        selector = AdditiveInteractionSelector(n_splines=args.gknots, interaction_splines=args.giknots)
        try:
            selector.fit(opt_df, Train_y, interactions=interactions)
        except Exception as e:
            # Deliberate best-effort fallback, but report why it was needed
            # instead of silently discarding the error.
            print(f"Rep {r+1}: group-lasso fit failed ({e!r}); retrying with HAS_GROUP_LASSO=False")
            selector.fit(opt_df, Train_y, interactions=interactions, HAS_GROUP_LASSO=False)

        SDAM_config = _merge_groups(selector.get_important_groups(threshold=args.glasso_threshold))
        print(SDAM_config)
        active_dict[r] = SDAM_config
        print(f"Rep {r+1} Active set:  {active_dict[r]}")

        if not SDAM_config:
            # NOTE(review): a skipped repetition leaves runtime/MSPE at 0. That
            # zero is included in the summary mean/std below — confirm intended.
            continue

        ##################
        #   Final Stage  #
        ##################
        ADNN = AdditiveModel(
            index_list=SDAM_config,
            hidden_dims=args.hidden_config,
            output_dim=1,
        )

        if args.do_eval:
            # NOTE(review): --do_eval currently skips both training and
            # evaluation, so its metrics stay 0 — confirm intended behaviour.
            continue

        os.makedirs(args.save_dir, exist_ok=True)
        model_path = os.path.join(args.save_dir, f"{args.data_name}ADN{r+1}.pth")
        print("Repetition ", r+1)

        print('Training ADNN...')
        runtimes[r] = train(
            ADNN, Train_X, Train_y.view(-1, 1), Valid_X, Valid_y.view(-1, 1), model_path,
            n_epochs=args.epochs, batch_size=args.batch_size, lr=args.lr,
        )
        mspe[r] = eval_model(ADNN, model_path, Test_X, Test_y.view(-1, 1))

    results = {'Runtime_ADN': runtimes, 'MSE_ADN': mspe}
    result_ = {k: v.tolist() for k, v in results.items()}
    result_log_file = os.path.join(args.output_dir, args_arc + 'n' + str(args.nlist[0]) + ".json")
    # torch.std is NaN when only one repetition exists.
    print(torch.mean(mspe), torch.std(mspe))
    with open(result_log_file, "w") as f:
        json.dump(result_, f)
	

if __name__ == "__main__":
    main()







