

# import standard libraries and torch
import os
import sys
import torch
import pickle
import argparse


# import the data, model, and helper functions
from data import WB, CelebA,  multiNLI
from model import resnet_model, BERT_model
from helpers import set_seed
from training import train_model, loss_BCE, train_BERT
from helpers import str_to_bool, Convert

# import the learning-rate scheduler factory (transformers) and the torch optimizers
import transformers
import torch.optim as optim

def main(dataset, dataset_file, model_name, lr, weight_decay, batch_size, n_epochs, early_stopping, optimizer_type, patience, device_type, seeds, workers=0, y_dim=1, data_folder='data/cleaned', model_folder='models/CivilComments', save_at_steps=None, lr_schedule=None):
    """Train one model per seed on the chosen dataset and save each to disk.

    For every seed a fresh model and optimizer are created, the model is
    trained on the 'train' loader and monitored on the 'val' loader, and then
    the model's ``state_dict`` (``.pt``) and a dict of the hyper-parameters
    (``.pkl``) are written to ``model_folder``.

    Args:
        dataset: One of 'WB', 'CelebA' (ResNet models) or 'multiNLI' (BERT).
        dataset_file: File name inside ``data_folder`` holding the data
            (used by 'WB' and 'CelebA'; ignored for 'multiNLI').
        model_name: Passed to ``resnet_model`` and to the training routine.
        lr: Learning rate of the optimizer.
        weight_decay: Weight decay of the optimizer.
        batch_size: Batch size used when creating the loaders.
        n_epochs: Number of epochs passed to the training routine.
        early_stopping: Whether to use early stopping; when truthy the
            training routine is also asked to save the best model.
        optimizer_type: 'Adam', 'AdamW' or 'SGD' (case-insensitive).
        patience: Passed to the training routine as ``orig_patience``.
        device_type: String passed to ``torch.device``.
        seeds: Iterable of seeds; one model is trained and saved per seed.
        workers: Number of workers passed to the loader creation.
        y_dim: Output dimension passed to the model constructor.
        data_folder: Folder containing ``dataset_file``.
        model_folder: Folder where models and parameter files are written
            (created if missing).
        save_at_steps: Passed through to the training routine.
        lr_schedule: None for no scheduler, or 'BERT_lr_schedule' for a
            linear schedule from ``transformers.get_scheduler``.

    Raises:
        ValueError: If ``dataset``, ``optimizer_type`` or ``lr_schedule`` is
            not one of the supported values.
    """
    image_datasets = ('WB', 'CelebA')
    text_datasets = ('multiNLI',)

    # Fail fast on unsupported options. Previously these fell through the
    # if/elif chains and only surfaced later as an unbound-variable error
    # (e.g. 'CivilComments' had a model branch but no data-loading branch).
    if dataset not in image_datasets + text_datasets:
        raise ValueError(
            "Unsupported dataset {!r}; expected one of {}".format(dataset, image_datasets + text_datasets))

    optimizer_key = str(optimizer_type).lower()
    if optimizer_key not in ('adam', 'adamw', 'sgd'):
        raise ValueError(
            "Unsupported optimizer_type {!r}; expected 'Adam', 'AdamW' or 'SGD'".format(optimizer_type))

    if lr_schedule is not None and lr_schedule != 'BERT_lr_schedule':
        raise ValueError(
            "Unsupported lr_schedule {!r}; expected None or 'BERT_lr_schedule'".format(lr_schedule))

    # set the device
    device = torch.device(device_type)
    # str.format (rather than os.path.join) so a missing dataset_file (None)
    # does not raise for datasets that never read it
    dataset_file_full = '{}/{}'.format(data_folder, dataset_file)

    if dataset == 'WB':

        # define the WB dataset object, and load data
        data_obj = WB()
        data = data_obj.load_data(dataset_file_full)
        data_obj.set_data_attributes(data['X_train'], data['y_train'], data['X_val'], data['y_val'], data['X_test'], data['y_test'], device)

        # create the loaders
        data_obj.create_loaders(batch_size, workers, shuffle=True, include_weights=False, train_weights=None, val_weights=None, pin_memory=False)

    elif dataset == 'CelebA':

        # define the CelebA dataset object; the loaders read directly from the h5 file
        data_obj = CelebA()
        data_obj.create_loaders(batch_size, workers, shuffle=True, pin_memory=False, h5_file_path=dataset_file_full,
                                x_key_train='X_train', y_key_train='y_train', x_key_val='X_val', y_key_val='y_val', device=device)

    else:

        # multiNLI: load the tokens and create the loaders
        # NOTE(review): this path is hard-coded and ignores data_folder - confirm whether that is intended
        data_obj = multiNLI()
        data_obj.load_tokens('data/Cleaned/multiNLI')
        data_obj.create_loaders(batch_size, shuffle=True, workers=workers, pin_memory=True, include_test=False)

    uses_resnet = dataset in image_datasets
    save_best_model = bool(early_stopping)

    # go over each seed
    for seed in seeds:

        # seed before the model is created
        set_seed(seed)

        # define the model
        if uses_resnet:
            model_obj = resnet_model(model_name, y_dim)
        else:
            model_obj = BERT_model(y_dim, output_attentions=False, output_hidden_states=False)

        # put the model on the device
        model_obj.to(device)

        # define the optimizer (AdamW is available for every dataset, not only multiNLI)
        if optimizer_key == 'adam':
            optimizer = optim.Adam(model_obj.parameters(), lr=lr, weight_decay=weight_decay)
        elif optimizer_key == 'adamw':
            optimizer = optim.AdamW(model_obj.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            optimizer = optim.SGD(model_obj.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)

        # define the learning-rate scheduler
        if lr_schedule is None:
            scheduler = None
        else:
            # NOTE(review): num_training_steps is the number of epochs, which is only
            # right if the training routine steps the scheduler once per epoch - confirm
            scheduler = transformers.get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=n_epochs)

        # train the model; re-seed so training starts from the same random state
        # regardless of how much randomness model/optimizer construction consumed
        print('Starting training:')
        set_seed(seed)
        train_fn = train_model if uses_resnet else train_BERT
        train_fn(n_epochs, model_obj, data_obj.dict_loaders, 'train', 'val', device, optimizer, loss_BCE, per_step=1, early_stopping=early_stopping, orig_patience=patience, tol=0.001, save_best_model=save_best_model, model_name=model_name, verbose=True, save_at_steps=save_at_steps, scheduler=scheduler)

        # make sure the output folder exists (no error if it already does)
        os.makedirs(model_folder, exist_ok=True)
        param_dict = {'lr': lr, 'weight_decay': weight_decay, 'model_name': model_name, 'seed': seed, 'n_epochs': n_epochs, 'batch_size': batch_size, 'optimizer_type': optimizer_type, 'patience': patience, 'early_stopping': early_stopping}
        model_file_name = '{}_model_seed_{}.pt'.format(dataset, seed)

        # save the model parameters via state_dict
        torch.save(model_obj.state_dict(), model_folder + '/' + model_file_name)

        # save the hyper-parameters; the context manager closes the file handle
        # NOTE(review): the .pkl name contains a space before the seed, unlike the .pt
        # name - likely a typo, kept as-is in case other code loads this exact name
        param_file_name = model_folder + '/{}_model_seed_ {}.pkl'.format(dataset, seed)
        with open(param_file_name, 'wb') as param_file:
            pickle.dump(param_dict, param_file)







if __name__ == "__main__":
    # Command-line entry point: parse the options, convert the string-encoded
    # ones (booleans, lists of ints) and hand everything to main().
    # NOTE(review): the description says 'Dataset preparation' although this
    # script trains models - confirm and update if it was copied from another script
    parser = argparse.ArgumentParser(description='Dataset preparation')
    parser.add_argument('--dataset', type=str, help='The dataset to use')
    parser.add_argument('--dataset_file', type=str, help='The .pkl file containing the dataset')
    parser.add_argument('--model_name', type=str, default='resnet50', help='The name of the model to use')
    parser.add_argument('--lr', type=float, default=0.001, help='The learning rate for the model')
    parser.add_argument('--weight_decay', type=float, default=0.0001, help='The weight decay for the model')
    parser.add_argument('--batch_size', type=int, default=32, help='The batch size for the model')
    parser.add_argument('--n_epochs', type=int, default=10, help='The number of epochs to train the model')
    parser.add_argument('--early_stopping', type=str, default='true', help='Whether to use early stopping')
    # The default must match one of the names main() compares against
    # ('Adam', 'AdamW', 'SGD'); the previous default 'adam' matched none of
    # them, so a run without --optimizer_type crashed on an unbound optimizer.
    parser.add_argument('--optimizer_type', type=str, default='Adam', help='The type of optimizer to use')
    parser.add_argument('--patience', type=int, help='patience')
    parser.add_argument('--device_type', type=str, default='cuda', help='The type of device to use')
    parser.add_argument('--seeds', type=str, default='0', help='seed')
    parser.add_argument('--data_folder', default='data/Cleaned', help= 'folder to get data from')
    parser.add_argument('--model_folder',default='models/CUB', help= 'folder to save models')
    parser.add_argument('--save_at_steps', default=None, help='steps to save at')
    parser.add_argument('--lr_schedule', default=None, help='lr schedule')

    args = parser.parse_args()

    # convert the string-valued options into the types main() expects
    args.early_stopping = str_to_bool(args.early_stopping)
    args.seeds = Convert(args.seeds, int)
    if args.save_at_steps is not None:
        args.save_at_steps = Convert(args.save_at_steps, int)

    # Run the main function
    main(args.dataset, args.dataset_file, args.model_name, args.lr, args.weight_decay, args.batch_size, args.n_epochs, args.early_stopping, args.optimizer_type, args.patience, args.device_type, args.seeds, 
         data_folder=args.data_folder, model_folder=args.model_folder, save_at_steps=args.save_at_steps, lr_schedule=args.lr_schedule)

   

