from metaDatasetBaseline import * 
##################################
# A thin training-dataset wrapper that also
# returns each sample's imageID
##################################
class TrainingDataset_Wrapper(torch.utils.data.Dataset):
    """Wrap an ``ImageFolder``-style dataset so items also carry the image ID.

    The wrapper reuses the wrapped dataset's ``loader``, ``samples`` and
    ``transform`` (and ``target_transform`` when the dataset defines one), and
    yields ``(sample, imageID, target)`` triples so the training loop can record
    which images were seen in each batch.

    Args:
        dataset: Object exposing ``loader`` (path -> sample), ``samples`` (a
            sequence of ``(image_path, target)`` pairs) and ``transform``, e.g.
            a ``torchvision.datasets.ImageFolder``.
    """

    def __init__(self, dataset):
        self.loader = dataset.loader
        self.samples = dataset.samples
        self.transform = dataset.transform
        # Not every dataset-like object defines this; ``None`` means "no-op".
        self.target_transform = getattr(dataset, 'target_transform', None)

    def __getitem__(self, index):
        """Return ``(sample, imageID, target)`` for ``self.samples[index]``.

        The imageID is the file name without its final extension, e.g.
        ``'<folder>/150460.jpg'`` -> ``'150460'``.
        """
        image_path, target = self.samples[index]
        # basename/splitext rather than split('/')/split('.'): honours the
        # platform path separator and strips only the last extension, so IDs
        # that themselves contain dots are not truncated.
        imageID = os.path.splitext(os.path.basename(image_path))[0]
        sample = self.loader(image_path)
        # Mirror torchvision's DatasetFolder: a ``None`` transform is a no-op.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, imageID, target

    def __len__(self):
        """Return the number of wrapped samples."""
        return len(self.samples)


def save_influence_results(subset_influence_batch_results, epoch, args):
    """Pickle the accumulated subset-influence results for ``epoch``.

    The file is written to
    ``<args.output_dir>/subset_influence_results_epoch_<epoch>.pkl``.

    Args:
        subset_influence_batch_results: Picklable object to persist.
        epoch: Epoch index used in the file name and the log message.
        args: Namespace providing ``output_dir``.

    Returns:
        Path of the written pickle file (existing callers may ignore it).
    """
    output_dir = args.output_dir
    # Create the directory on demand so a long run does not die at its first
    # save; an empty output_dir means "current directory" and needs nothing.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    pkl_path = os.path.join(
        output_dir, 'subset_influence_results_epoch_{}.pkl'.format(epoch))
    with open(pkl_path, "wb") as pkl_file:
        pickle.dump(subset_influence_batch_results, pkl_file)
    # Announce the save only after the dump succeeded, so the log never
    # claims a save that actually failed.
    print('subset_influence_batch_results saved! epoch[{}]'.format(epoch))
    return pkl_path

    
def _build_train_loader(args):
    """Build the shuffled training DataLoader whose batches also carry image IDs."""
    traindir = os.path.join(args.data, 'train')
    # Standard ImageNet per-channel mean/std.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_wrap = TrainingDataset_Wrapper(train_dataset)
    return torch.utils.data.DataLoader(
        train_dataset_wrap, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, sampler=None)


def _train_step(model, criterion, optimizer, images, target, args):
    """Run one forward/backward/optimizer step on a single batch."""
    # validate() (defined elsewhere) runs after every batch and presumably
    # switches the model to eval mode, as in the PyTorch ImageNet reference
    # script. Re-enter train mode before every step so BatchNorm/Dropout
    # behave as training layers instead of staying in eval mode.
    model.train()
    if args.gpu is not None:
        images = images.cuda(args.gpu, non_blocking=True)
    if torch.cuda.is_available():
        target = target.cuda(args.gpu, non_blocking=True)
    output = model(images)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


def subset_influence_load_data_and_train(model, criterion, optimizer, args):
    """Train ``model`` while recording per-batch out-of-domain predictions.

    After every optimizer step the model is evaluated on the out-of-domain
    validation set. The batch's image IDs and the resulting validation scores
    are accumulated in a results dict, which is pickled at the end of every
    epoch via ``save_influence_results``. The random-split validation set is
    evaluated once per epoch.

    Results dict layout:
        'sample_schedule': one list of image IDs per training batch (appended).
        'batch_results':   ``pred_score_all`` from ``validate`` after each
                           batch (appended).
        'target':          ``target_all`` from the latest ``validate`` call
                           (overwritten every batch).

    Args:
        model, criterion, optimizer: Training objects supplied by the caller.
        args: Namespace; reads ``data``, ``batch_size``, ``workers``, ``gpu``,
            ``start_epoch``, ``epochs`` and ``output_dir``, plus whatever
            ``get_val_loader``/``validate`` need.
    """
    ##################################
    # EMR, out-of-domain val set, containing "other"; evaluated every batch.
    ##################################
    # The dataset object is only needed for the optional per-group report below.
    val_out_of_domain_loader, val_out_of_domain_dataset = get_val_loader(dataset_dir='val_out_of_domain', args=args)

    ##################################
    # EMR, val_sg folder, random split test set; evaluated every epoch.
    ##################################
    val_valsg_loader, _ = get_val_loader(dataset_dir='val_valsg', args=args)

    ##################################
    # Create training dataset
    ##################################
    train_loader = _build_train_loader(args)

    subset_influence_batch_results = {
        'sample_schedule': [],  # append imageIDs
        'target': [],  # overwrite
        'batch_results': [],  # append results from each batch
    }

    for epoch in range(args.start_epoch, args.epochs):
        for batch_id, (images, ImageIDs, target) in enumerate(train_loader):
            _train_step(model, criterion, optimizer, images, target, args)

            ##################################
            # Evaluation after every step
            ##################################
            print('epoch [{}] batch_id [{}]'.format(epoch, batch_id))
            logging.info('epoch [{}] batch_id [{}]'.format(epoch, batch_id))

            # A list of str, e.g., ['150460', '1159344']
            subset_influence_batch_results['sample_schedule'].append(ImageIDs)

            print('out-of-domain val')
            logging.info('out-of-domain val')
            # dumpResult=False skips dumping results to keep this fast; the
            # returned dict is still read for pred_score_all/target_all below.
            _, dump_result_dict = validate(val_out_of_domain_loader, model, criterion, args, dumpResult=False)

            subset_influence_batch_results['batch_results'].append(dump_result_dict['pred_score_all'])
            subset_influence_batch_results['target'] = dump_result_dict['target_all']

            # Report every-group acc, worst-set acc (needs dumpResult=True):
            # report_every_set_acc(val_out_of_domain_dataset, args)

        ##################################
        # Each Epoch: save the accumulated results
        ##################################
        save_influence_results(subset_influence_batch_results, epoch, args)

        ##################################
        # Each Epoch: test on the random-split val set
        ##################################
        print('random split val')
        logging.info('random split val')
        _ = validate(val_valsg_loader, model, criterion, args, dumpResult=False)


if __name__ == '__main__':
    # Script entry point: hand the influence-tracking training routine to the
    # shared ``main`` driver (not defined in this file; presumably provided by
    # the metaDatasetBaseline star import).
    training_fn = subset_influence_load_data_and_train
    main(fn_call_training=training_fn)
