import os
import torch
from torch import nn
import numpy as np
import argparse
from collections import defaultdict
import json
from glasso import *
from utils import *
from model import DNNBaseline, AdditiveModel, train, eval_model
from datasets import generate_dataset

# Simulation scenarios accepted by --data_name.
_DATA_NAMES = (
    'only_main',
    'weak_main',
    'inter_no_overlap',
    'inter_mild_overlap',
    'inter_strong_overlap',
    'only_inter',
)


def _ensure_dir(path):
    """Create directory *path* (with parents) unless it is empty or already exists."""
    if path:
        os.makedirs(path, exist_ok=True)


def _generate_splits(args):
    """Generate one simulated dataset and return its (X, y) pairs for Train, Valid and Test."""
    print("Generating "+args.data_name+" dataset")
    dataset = generate_dataset(args.nlist, args.nfeature, args.UB, args.LB, args.data_name)
    return tuple(
        (dataset[split]['data'], dataset[split]['target'])
        for split in ('Train', 'Valid', 'Test')
    )


def main():
    """Run the simulation experiment configured on the command line.

    For each of ``--repeat`` repetitions the script:

    1. generates a simulated dataset (seeded with ``--seed + r``),
    2. runs ``train_SAM`` (Stage I) and takes its ``'opt_var'`` result,
    3. fits an ``AdditiveInteractionSelector`` on the selected features and
       all their pairwise interactions, and keeps the groups above
       ``--glasso_threshold``,
    4. unless ``--do_eval`` is given, trains a ``DNNBaseline`` and an
       ``AdditiveModel`` and evaluates both on the test split.

    Files written: the argument log (``<output_dir>/<run>.txt``, appended),
    the stacked datasets (``<data_path>/<data_name>-simulation_results.pt``),
    the metrics (``<output_dir>/<run>.json``) and the selected groups
    (``<output_dir>/<run>-feature.json``).

    Raises:
        ValueError: if ``--data_name`` is not a known scenario or
            ``--repeat`` is smaller than 1. Both are checked before any file
            is written.
    """
    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument('--data_path', default="../data/", type=str, help='Path to the dataset')
    parser.add_argument('--output_dir', default="output/", type=str, help='Path to save the output')
    parser.add_argument('--model_name', default='SDAMI', type=str, help='Name of the model to use')
    parser.add_argument('--do_eval', action='store_true', help='Whether to evaluate the model after training')
    # NOTE(review): the default 'main_effect' is not one of _DATA_NAMES, so
    # --data_name currently has to be passed explicitly. The default is left
    # untouched because the intended scenario cannot be determined from here.
    parser.add_argument('--data_name', default='main_effect', type=str, help='type of the data to use')
    parser.add_argument('--model_idx', type=int, default=0, help='Index of the model to use')
    parser.add_argument('--repeat', type=int, default=1, help='Number of repetitions for the experiment')
    parser.add_argument('--nlist', nargs='*', type=int, default=[150, 30, 500], help='Sizes of training, validation, and test datasets')
    parser.add_argument('--nfeature', type=int, default=100, help='Number of features in the dataset')
    parser.add_argument('--UB', type=float, default=2.5, help='Upper bound for the dataset features')
    parser.add_argument('--LB', type=float, default=-2.5, help='Lower bound for the dataset features')
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help='Device to use for training')

    # Simulation Settings
    parser.add_argument(
        "--noise",
        default=1.0,
        type=float,
        help="Noise level for the model",
    )

    # SAM Args
    parser.add_argument('--SAM_iter', type=int, default=100, help='Number of iterations for SAM')
    parser.add_argument('--SAM_tol', type=float, default=1e-6, help='Tolerance for convergence in SAM')
    parser.add_argument('--SAM_ftol', type=float, default=1e-3, help='Tolerance for feature selection in SAM')
    parser.add_argument('--knots', type=int, default=10, help='Number of knots for spline basis')
    parser.add_argument('--nbound', type=int, default=3, help='Number of boundary knots for spline basis')
    parser.add_argument('--degree', type=int, default=3, help=' Degree of the spline basis')
    parser.add_argument('--plot', action='store_true', help='Plot the figure for the component functions')
    parser.add_argument('--glasso_threshold', type=float, default=0.2, help='Threshold for group lasso selection')
    parser.add_argument('--gknots', type=int, default=5, help='Number of knots for spline basis for group lasso')
    parser.add_argument('--giknots', type=int, default=5, help='Number of knots for inter-spline basis for group lasso')

    # Model Args
    parser.add_argument('--hidden_config', nargs='*', type=int, default=[15, 12, 10], help='model architecture')
    parser.add_argument('--num_layers', type=int, default=2, help='Number of layers in the model')
    parser.add_argument('--dropout', type=float, default=0.2, help='Dropout rate for the model')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the optimizer')

    # Training Args
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
    parser.add_argument('--eval_batch_size', type=int, default=64, help='Batch size for evaluation')
    parser.add_argument('--patience', type=int, default=50, help='Early stopping patience')
    parser.add_argument('--epochs', type=int, default=10000, help='Number of training epochs')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay for the optimizer')
    parser.add_argument('--seed', type=int, default=1, help='Random seed for reproducibility')
    args = parser.parse_args()

    # NOTE(review): several parsed options (e.g. --noise, --device,
    # --weight_decay, --patience, --dropout, --plot, --SAM_tol) are never
    # referenced below; confirm whether they should be forwarded.

    # Fail fast, before any file is written or any data is generated.
    if args.data_name not in _DATA_NAMES:
        raise ValueError(
            f"Unknown data name {args.data_name!r}; expected one of: {', '.join(_DATA_NAMES)}"
        )
    if args.repeat < 1:
        raise ValueError(f"--repeat must be at least 1, got {args.repeat}")

    args_arc = f"{args.model_name}-{args.data_name}-n{args.nlist[0]}"
    args.log_file = os.path.join(args.output_dir, args_arc + ".txt")

    # The output directory may not exist on a fresh checkout.
    _ensure_dir(args.output_dir)
    with open(args.log_file, "a") as f:
        f.write(str(args) + "\n")

    args.checkpoint_path = os.path.join(args.output_dir, args_arc + ".pt")

    active_dict = {}
    MSPE_dict = {'ADN': torch.zeros(args.repeat), 'DNN': torch.zeros(args.repeat)}
    Runtime_dict = {'ADN': torch.zeros(args.repeat), 'DNN': torch.zeros(args.repeat)}

    # ------------------------------------------------------------------
    # Pass 1: generate every repetition's data once and store it on disk.
    # ------------------------------------------------------------------
    # NOTE(review): only torch's RNG is seeded. If generate_dataset draws from
    # numpy or `random`, the data saved here and the data regenerated in
    # pass 2 may differ -- confirm.
    stack_keys = ('X_train', 'y_train', 'X_valid', 'y_valid', 'X_test', 'y_test')
    stacks = {key: [] for key in stack_keys}
    for r in range(args.repeat):
        torch.manual_seed(args.seed + r)
        splits = _generate_splits(args)
        flat = (tensor for pair in splits for tensor in pair)
        for key, tensor in zip(stack_keys, flat):
            stacks[key].append(tensor)

    data_dict = {key: torch.stack(parts, dim=0) for key, parts in stacks.items()}
    # os.path.join keeps the default path unchanged and also works when
    # --data_path is given without a trailing separator.
    save_path = os.path.join(args.data_path, f"{args.data_name}-simulation_results.pt")
    _ensure_dir(args.data_path)
    torch.save(data_dict, save_path)

    # Drop every reference to the stacked copies so the memory can be reclaimed.
    del data_dict, stacks, splits, flat

    # ------------------------------------------------------------------
    # Pass 2: regenerate each repetition with the same seed and fit models.
    # The data is regenerated (not reused) so the RNG state seen by the model
    # constructors and by training is the same as it always was.
    # ------------------------------------------------------------------
    root = './save_model/'
    for r in range(args.repeat):
        torch.manual_seed(args.seed + r)
        (train_x, train_y), (valid_x, valid_y), (test_x, test_y) = _generate_splits(args)

        ################
        #    Stage I   #
        ################

        # Run SAM for each alpha value
        print("Running SAM for different alpha values")

        alpha_list = torch.tensor(list(torch.arange(0, 1.1, 0.1)) + [1.5, 2, 3, 5])

        results_I = train_SAM(
            train_x, train_y, alpha_list,
            max_iter=args.SAM_iter, nk=args.knots, nb=args.nbound, custom=False,
        )
        # TODO: plotting of the Stage I components (--plot) is currently disabled.

        ####################
        #    Group-LASSO   #
        ####################

        opt_var = results_I['opt_var']
        opt_df = extract_active_features(train_x, opt_var)
        interactions = list(combinations(list(opt_df.keys()), 2))
        selector = AdditiveInteractionSelector(n_splines=args.gknots, interaction_splines=args.giknots)

        try:
            selector.fit(opt_df, train_y, interactions=interactions)
        except Exception as e:
            # Deliberate best-effort fallback; report the cause instead of
            # hiding it so a silent downgrade is visible in the run output.
            print(f"Rep {r+1}: selector.fit failed ({e!r}); retrying with HAS_GROUP_LASSO=False")
            selector.fit(opt_df, train_y, interactions=interactions, HAS_GROUP_LASSO=False)

        SDAM_config = selector.get_important_groups(threshold=args.glasso_threshold)

        active_dict[r] = SDAM_config
        print(f"Rep {r+1} Active set:  {active_dict[r]}")

        ##################
        #   Final Stage  #
        ##################

        # Both models are always constructed (DNN first) so the RNG draws made
        # during initialisation happen in the same order as before.
        DNN = DNNBaseline(input_dim=train_x.size(1), hidden_dims=args.hidden_config)

        # ADNN Configuration
        ADNN = AdditiveModel(
            index_list=active_dict[r],
            hidden_dims=args.hidden_config,
            output_dim=1,
        )

        # With --do_eval nothing is trained or evaluated here, so the metric
        # tensors keep their initial zeros for this repetition.
        if not args.do_eval:
            _ensure_dir(root)
            y_train_col = train_y.view(-1, 1)
            y_valid_col = valid_y.view(-1, 1)
            y_test_col = test_y.view(-1, 1)
            dnn_path = f"{root}DNN{r+1}.pth"
            adn_path = f"{root}ADN{r+1}.pth"

            print("Repetition ", r+1)
            print("Training DNN...")
            Runtime_dict['DNN'][r] = train(
                DNN, train_x, y_train_col, valid_x, y_valid_col, dnn_path,
                n_epochs=args.epochs, batch_size=args.batch_size, lr=args.lr,
            )
            MSPE_dict['DNN'][r] = eval_model(DNN, dnn_path, test_x, y_test_col)

            print('Training ADNN...')
            Runtime_dict['ADN'][r] = train(
                ADNN, train_x, y_train_col, valid_x, y_valid_col, adn_path,
                n_epochs=args.epochs, batch_size=args.batch_size, lr=args.lr,
            )
            MSPE_dict['ADN'][r] = eval_model(ADNN, adn_path, test_x, y_test_col)

    results = {
        'Runtime_DNN': Runtime_dict['DNN'],
        'Runtime_ADN': Runtime_dict['ADN'],
        'MSE_DNN': MSPE_dict['DNN'],
        'MSE_ADN': MSPE_dict['ADN'],
        # Legacy misspelled key, kept so existing readers of the JSON still work.
        'MSE_ADB': MSPE_dict['ADN'],
    }

    result_ = {k: v.tolist() for k, v in results.items()}
    result_log_file = os.path.join(args.output_dir, args_arc + ".json")
    result_feature_file = os.path.join(args.output_dir, args_arc + "-feature.json")

    with open(result_log_file, "w") as f:
        json.dump(result_, f)

    with open(result_feature_file, "w") as f:
        json.dump(active_dict, f)

# Run the experiment only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or training.
if __name__ == "__main__":
    main()







