import argparse
from examples.utils.eiil_utils import get_transformed_waterbirds_to_eiil, split_data_opt, transform_envs_to_metadatafile
import configs.supported as supported
from models.initializer import initialize_model
from utils import ParseKwargs
import os
import torch
import pathlib
import pickle

def _parse_args(argv=None):
    """Parse command-line options for EIIL environment inference.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # Previously ``type=bool``: bool('64') is True, so the value could never be
    # set from the CLI. The default of 1 equals int(True), which is what the old
    # flag effectively passed downstream.
    parser.add_argument('--batch_size', type=int, default=1,
        help='batch size used by split_data_opt during environment inference')
    # Previously ``default=True``, which produced file names like
    # 'notjoin_True_eiil.pickle'; the dataset name must be supplied explicitly.
    parser.add_argument('--dataset_name', type=str, required=True)
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
        help='keyword arguments for model initialization passed as key1=value1 key2=value2')
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--seed', type=int, default=0)
    # Not read anywhere in this script; kept so existing command lines still parse.
    parser.add_argument("--n_steps", type=int, default=10000)
    return parser.parse_args(argv)


def _load_envs(config):
    """Return the EIIL-format environments, preferring the pickle cache in log_dir."""
    cache_file = pathlib.Path(config.log_dir) / f'notjoin_{config.dataset_name}_eiil.pickle'
    if cache_file.is_file():
        print(f'------ Loading {config.dataset_name} to EIIL format ------')
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # --log_dir at directories whose contents you trust.
        with open(cache_file, 'rb') as handle:
            return pickle.load(handle)
    print(f'------ Transform {config.dataset_name} to EIIL format ------')
    # NOTE(review): the freshly transformed envs are not written to cache_file
    # here; presumably another step populates that cache — confirm.
    return get_transformed_waterbirds_to_eiil(config.dataset_name)


def _load_reference_model(config):
    """Build the model, load the best checkpoint for this dataset/seed, and move it to GPU."""
    torch.cuda.empty_cache()
    # NOTE(review): the second argument is presumably the output dimension — confirm.
    model = initialize_model(config, 1)
    weight_file = os.path.join(config.log_dir, f'{config.dataset_name}_seed:{config.seed}_epoch:best_model.pth')
    print(f'------ Loading model from {weight_file} ------')
    checkpoint = torch.load(weight_file, map_location='cpu')
    # The checkpoint stores the wrapping algorithm's state, whose keys carry a
    # leading 'model.'. Strip only that leading prefix: a global replace would
    # also mangle inner names such as 'model.model.encoder.*' or 'submodel.*'.
    prefix = 'model.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['algorithm'].items()
    }
    model.load_state_dict(state_dict)
    # NOTE(review): the model is left in train() mode for inference, so any
    # dropout/batch-norm layers behave as in training — confirm this is intended.
    model.cuda().train()
    return model


def main():
    """Infer EIIL environments with a pretrained reference model and save them as metadata."""
    config = _parse_args()

    # Generate EIIL environments
    envs = _load_envs(config)

    print('------ Launch EIIL script to discover environements ------')
    model = _load_reference_model(config)
    pred_envs = split_data_opt(envs, model, batch_size=config.batch_size, join=False)

    # Save metadata file (destination is derived from log_dir by the helper).
    transform_envs_to_metadatafile(pred_envs, config.log_dir)


if __name__ == "__main__":
    main()
