import os
import numpy as np

from connectivity_utils import transforms, train, data
import breeds, domainnet
# Domain names accepted by main(); checked against --domain_1 / --domain_2.
# Breeds: the keys of breeds.BREEDS_SPLITS_TO_FUNC.
VALID_BREEDS_DOMAINS = breeds.BREEDS_SPLITS_TO_FUNC.keys()
# DomainNet: the first set is used for --version sentry, the second for --version full.
VALID_DOMAINNET_S_DOMAINS, VALID_DOMAINNET_DOMAINS = (
    domainnet.SENTRY_DOMAINS,
    domainnet.VALID_DOMAINS,
)

import argparse
parser = argparse.ArgumentParser(
    description='Distinguish domains',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
_add = parser.add_argument  # short local alias; deleted at the end of this section

# --- data arguments ---
_add('--dataset_name', type=str, required=True, choices=['breeds', 'domainnet'],
     help='Which dataset on which to test connectivity.')
_add('--data_path', type=str, help='Root path of the data.')
_add('--domain_1', type=str, required=True, help='Name of domain 1.')
_add('--domain_2', type=str, help='Name of domain 2 for domainnet.')
_add('--version', type=str, choices=['full', 'sentry'], default='sentry',
     help='For domainnet, which version to use.')
_add('--transform', type=str, choices=['imagenet', 'simclr', 'swav-simclr'],
     default='swav-simclr', help='Choice of data augmentation to use.')

# --- training arguments ---
_add('-a', '--arch', default='resnet50', help='Architecture')
_add('-j', '--workers', type=int, default=2, help='Number of workers')
_add('--epochs', type=int, default=100, help='Number of epochs')
_add('-b', '--batch_size', type=int, default=96, help='Batch size')
_add('--lr', '--learning_rate', dest='lr', type=float, default=0.1, help='Initial LR.')
_add('--momentum', type=float, default=0.9, help='Momentum')
_add('--wd', '--weight_decay', dest='weight_decay', type=float, default=1e-4, help='WD')
_add('--save_freq', type=int, default=25, help='How often to save')
_add('--print_freq', type=int, default=5, help='How often to print (in # batches)')
_add('--continue_from', type=str, help='Which checkpoint to continue from')
_add('--continue_from_epoch', type=int, default=-1,
     help='Which epoch to set the scheduler to')

# --- frozen SwAV weights (main() sets args.linear_probe_only when --swav_dir is given) ---
_add('--swav_dir', type=str,
     help='If provided, will use the checkpoint in this directory.')
_add('--swav_ckpt', type=int, default=399,
     help='The epoch of the checkpoint to use.')

del _add

def _validate_domain_args(args):
    """Check ``--domain_1`` / ``--domain_2`` against the chosen dataset.

    For Breeds only ``domain_1`` is used. ``args.domain_2`` is overwritten in
    place with ``args.domain_1``, so later code can treat both datasets the
    same way.

    Args:
        args: Namespace produced by ``parser.parse_args()``.

    Raises:
        ValueError: if the dataset is unsupported, or if the DomainNet version
            or the domain names are invalid for the chosen dataset.
    """
    if args.dataset_name == 'breeds':
        # Breeds takes a single domain; any --domain_2 value is replaced.
        args.domain_2 = args.domain_1
        if args.domain_1 not in VALID_BREEDS_DOMAINS:
            raise ValueError(f'Valid Breeds domains are {VALID_BREEDS_DOMAINS} but received '
                             f'domain 1 as {args.domain_1}.')
    elif args.dataset_name == 'domainnet':
        # argparse `choices` already enforces this for CLI input; kept as a defensive guard.
        if args.version not in ['full', 'sentry']:
            raise ValueError('For domainnet, the version must be either full or sentry.')
        # --domain_2 is optional at the parser level but required for DomainNet.
        if args.domain_2 is None:
            raise ValueError('For domainnet, --domain_2 must be provided.')
        if args.domain_1 == args.domain_2:
            raise ValueError('Cannot provide the same domain for 1 and 2.')
        valid = VALID_DOMAINNET_S_DOMAINS if args.version == 'sentry' else VALID_DOMAINNET_DOMAINS
        if (args.domain_1 not in valid) or (args.domain_2 not in valid):
            raise ValueError(f'Valid DomainNet domains for version {args.version} are {valid} but '
                             f'received domain 1 as {args.domain_1} and domain 2 as '
                             f'{args.domain_2}.')
    else:
        # Unreachable from the CLI because of `choices`, but kept for safety.
        raise ValueError(f'Unsupported dataset: {args.dataset_name}.')


def main():
    """Parse CLI args, validate them, build the two-domain datasets, and train.

    Outputs go under
    ``data_selection/<dataset_name>/<domain_1>-<domain_2>-<version>-<transform>``.
    When ``--swav_dir`` is given, ``-<swav_dir with '/' replaced by '-'>-<swav_ckpt>``
    is appended to that name. Training is delegated to ``train.main_loop``, which
    receives the full ``args`` namespace.

    Raises:
        ValueError: if the dataset / domain arguments are invalid.
    """
    args = parser.parse_args()
    _validate_domain_args(args)

    # get transforms
    transform = transforms.get_transforms(args.transform)
    # Presumably fills args.data_path when it was not given -- see connectivity_utils.data.
    data.fill_default_data_path(args)

    # file saving code
    # NOTE: the --version value (default 'sentry') appears in the name even for
    # Breeds; it is kept so existing output paths stay the same.
    save_dir = os.path.join(
        'data_selection',
        args.dataset_name,
        f'{args.domain_1}-{args.domain_2}-{args.version}-{args.transform}'
    )
    # Frozen SwAV weights imply training only a linear probe. The flag is
    # presumably read by train.main_loop via `args` -- confirm there.
    args.linear_probe_only = args.swav_dir is not None
    if args.linear_probe_only:
        print('Warning: this code should be updated')
        save_dir += f'-{args.swav_dir.replace("/", "-")}-{args.swav_ckpt}'
    os.makedirs(save_dir, exist_ok=True)

    train_ds, test_ds = data.get_domain_datasets(
        args.dataset_name, args.data_path, args.domain_1, args.domain_2,
        args.version, transform
    )

    identifier = 'result'
    train.main_loop(train_ds, test_ds, save_dir, identifier, args,
                    save_model=True, mpd=True)


if __name__ == '__main__':
    main()
