

# import the necessary libraries
import os
import sys
import pickle
import numpy as np
from collections import defaultdict
import argparse
from PIL import Image
import pandas as pd
import sys


# import the data, model, and helper functions
from data import WB, CelebA, multiNLI
from model import resnet_model,  get_predictions_in_batches, BERT_model
from helpers import set_seed
from training import train_model, loss_BCE

# import the necessary libraries from torch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms


# Datasets this script knows how to load.
_SUPPORTED_DATASETS = ('WB', 'CelebA', 'multiNLI')


def _load_data(dataset, dataset_file, device, seed, workers, batch_size):
    """Build the data object for `dataset` and create its train/val/test loaders.

    Every loader is created with ``shuffle=False`` so the predictions computed
    from it stay in the same order as the underlying samples.

    Returns:
        The dataset object; its ``dict_loaders`` holds 'train', 'val' and 'test'.
    """
    if dataset == 'WB':
        data_obj = WB()

        # load the data
        # NOTE(review): pickle.load executes arbitrary code; only use trusted local files.
        with open(dataset_file, 'rb') as f:
            data = pickle.load(f)

        # create group variables
        # NOTE(review): g_* are computed but not otherwise used for WB in this script.
        g_train = data_obj.create_g(data['y_train'], data['c_train'])
        g_val = data_obj.create_g(data['y_val'], data['c_val'])
        g_test = data_obj.create_g(data['y_test'], data['c_test'])

        # set the data attributes
        data_obj.set_data_attributes(data['X_train'], data['y_train'], data['X_val'], data['y_val'],
                                     data['X_test'], data['y_test'], device)

        # create the loaders
        set_seed(seed)
        data_obj.create_loaders(batch_size, workers, shuffle=False, include_weights=False,
                                train_weights=None, val_weights=None, pin_memory=True,
                                include_test=True)

    elif dataset == 'CelebA':
        data_obj = CelebA()

        # load the y and c values
        y_train, c_train, y_val, c_val, y_test, c_test = data_obj.load_y_c(dataset_file)

        # set the y & g values of the data object
        data_obj.y_train = y_train
        data_obj.y_val = y_val
        data_obj.y_test = y_test
        data_obj.g_train = data_obj.create_g(y_train, c_train)
        data_obj.g_val = data_obj.create_g(y_val, c_val)
        data_obj.g_test = data_obj.create_g(y_test, c_test)

        # create the loaders; the X/y arrays are read from dataset_file by key
        data_obj.create_loaders(batch_size, workers, shuffle=False, pin_memory=True,
                                h5_file_path=dataset_file,
                                x_key_train='X_train', y_key_train='y_train',
                                x_key_val='X_val', y_key_val='y_val',
                                x_key_test='X_test', y_key_test='y_test',
                                device=device, include_test=True)

    else:  # 'multiNLI'
        data_obj = multiNLI()
        # tokens come from a fixed location; dataset_file is not used for multiNLI
        data_obj.load_tokens('data/Cleaned/multiNLI')
        # shuffle=False keeps predictions aligned with sample order (as for WB/CelebA)
        data_obj.create_loaders(batch_size, shuffle=False, workers=workers, pin_memory=True,
                                include_test=True)

    return data_obj


def _load_model(dataset, model_path, model_name, y_dim, device):
    """Instantiate the architecture for `dataset`, load weights from `model_path`, freeze it.

    WB and CelebA use ``resnet_model(model_name, y_dim)``; multiNLI uses
    ``BERT_model``. The returned model is on `device`, has all parameters
    with ``requires_grad=False`` and is in eval mode.
    """
    if dataset == 'multiNLI':
        model_obj = BERT_model(y_dim, output_attentions=False, output_hidden_states=False)
    else:
        model_obj = resnet_model(model_name, y_dim)
    model_obj.load_state_dict(torch.load(model_path, map_location=device))
    model_obj.to(device)

    # this ensures these layers are not trained subsequently
    for param in model_obj.parameters():
        param.requires_grad = False

    # Inference only: disable dropout / use running BatchNorm stats. eval() is
    # idempotent, so this is safe even if get_predictions_in_batches also calls it.
    model_obj.eval()
    return model_obj


def main(dataset, dataset_file, model_name, model_folder, device_type, seed, workers=0, y_dim=1, batch_size=32):
    """Compute a trained model's predictions on the train/val/test splits and save them.

    The model weights are read from
    ``models/<model_folder>/<dataset>_model_seed_<seed>.pt`` and the outputs of
    ``get_predictions_in_batches`` for each split are saved as a dict with keys
    'train', 'val', 'test' to
    ``embeddings/<model_folder>/<dataset>_model_seed_<seed>/pred.pt``.

    Args:
        dataset: One of 'WB', 'CelebA', 'multiNLI'.
        dataset_file: For WB, a pickle with X_/y_/c_ entries per split; for
            CelebA, the file passed to ``load_y_c`` and used as ``h5_file_path``.
            Ignored for multiNLI.
        model_name: Architecture name passed to ``resnet_model`` (WB/CelebA only).
        model_folder: Sub-folder of 'models/' holding the trained weights.
        device_type: Device string passed to ``torch.device``.
        seed: Seed used in the model file name (and ``set_seed`` for WB).
        workers: Number of data-loader workers.
        y_dim: Output dimension of the model.
        batch_size: Batch size of the loaders.

    Raises:
        ValueError: If `dataset` is unsupported or the model file does not exist.
    """
    # fail fast, before any expensive data loading
    if dataset not in _SUPPORTED_DATASETS:
        raise ValueError('Unknown dataset {!r}; expected one of {}'.format(dataset, _SUPPORTED_DATASETS))

    model_file_name = '{}/{}_model_seed_{}.pt'.format(model_folder, dataset, seed)
    model_path = 'models/' + model_file_name
    print('model_file_name: ', model_file_name)
    if not os.path.exists(model_path):
        raise ValueError('The model does not exist: {}'.format(model_path))

    # set the device
    device = torch.device(device_type)

    data_obj = _load_data(dataset, dataset_file, device, seed, workers, batch_size)
    model_obj = _load_model(dataset, model_path, model_name, y_dim, device)

    # predictions for each split, in loader order
    pred = {split: get_predictions_in_batches(model_obj, data_obj.dict_loaders[split], device)
            for split in ('train', 'val', 'test')}

    # strip only the trailing '.pt' (model_folder itself may contain '.pt')
    out_dir = 'embeddings/' + model_file_name[:-len('.pt')]
    os.makedirs(out_dir, exist_ok=True)
    torch.save(pred, out_dir + '/pred.pt')


if __name__ == "__main__":
    # Command-line entry point: parse arguments and compute/save model predictions.
    parser = argparse.ArgumentParser(description='Compute and save a trained model\'s predictions on the train/val/test splits')
    parser.add_argument('--dataset', type=str, required=True, choices=['WB', 'CelebA', 'multiNLI'],
                        help='The dataset to use')
    parser.add_argument('--dataset_file', type=str,
                        help='The dataset file (pickle for WB, HDF5 for CelebA; ignored for multiNLI)')
    parser.add_argument('--model_name', type=str, default='resnet50', help='The name of the model architecture (WB/CelebA)')
    parser.add_argument('--model_folder', type=str, default='', help='The folder under models/ containing the trained model')
    parser.add_argument('--device_type', type=str, default='mps', help='The type of device to use')
    parser.add_argument('--seed', type=int, default=42, help='The random seed to use')
    parser.add_argument('--workers', type=int, default=0, help='Number of data-loader workers')
    parser.add_argument('--y_dim', type=int, default=1, help='Output dimension of the model')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for computing predictions')

    args = parser.parse_args()

    # Run the main function
    main(args.dataset, args.dataset_file, args.model_name, args.model_folder, args.device_type, args.seed,
         workers=args.workers, y_dim=args.y_dim, batch_size=args.batch_size)

   

