import os

from connectivity_utils import transforms
from connectivity_utils import train
from connectivity_utils import data
import breeds, domainnet
# Domains accepted for each dataset; main() checks --domain against these.
# Breeds domains are the keys of breeds.BREEDS_SPLITS_TO_FUNC.
VALID_BREEDS_DOMAINS = breeds.BREEDS_SPLITS_TO_FUNC.keys()
# DomainNet domains allowed with --version sentry (also used by --sentry_filter).
VALID_DOMAINNET_S_DOMAINS = domainnet.SENTRY_DOMAINS
# DomainNet domains allowed with --version full.
VALID_DOMAINNET_DOMAINS = domainnet.VALID_DOMAINS

import argparse
# Command-line interface for the single-domain connectivity experiment.
# Defaults are shown in --help via ArgumentDefaultsHelpFormatter.
parser = argparse.ArgumentParser(description='Test Connectivity of dataset single class',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
### data args ###
parser.add_argument('--dataset_name', type=str, required=True,
                    choices=['breeds', 'domainnet'],
                    help='Which dataset on which to test connectivity.')
parser.add_argument('--data_path', type=str,
                    help='Root path of the data.')
parser.add_argument('--domain', type=str, required=True,
                    help='Name of domain.')
# Which versions are legal depends on --dataset_name; main() enforces the pairing.
parser.add_argument('--version', type=str,
                    choices=['full', 'sentry', 'source', 'target'], required=True,
                    help='Dataset version: source or target for Breeds; '
                         'full or sentry for DomainNet.')
### other args ###
parser.add_argument('--transform', type=str,
                    choices=['imagenet', 'simclr', 'swav-simclr'], default='swav-simclr',
                    help='Choice of data augmentation to use.')
parser.add_argument('--num_pairs', type=int, default=12,
                    help='For each connectivity metric, the number of pairs to randomly select.')
parser.add_argument('--pair_seed', type=int, default=12,
                    help='Seed for choosing pairs for each connectivity metric.')
### training args ####
parser.add_argument('-a', '--arch', default='resnet50', help='Architecture')
parser.add_argument('-j', '--workers', default=2, type=int, help='Number of workers')
parser.add_argument('--epochs', default=100, type=int, help='Number of epochs')
parser.add_argument('-b', '--batch-size', default=96, type=int, help='Batch size')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, dest='lr', help='LR')
parser.add_argument('--momentum', default=0.9, type=float, help='Momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, dest='weight_decay', help='WD')
parser.add_argument('--save-freq', type=int, default=25, help='How often to save')
parser.add_argument('--print-freq', type=int, default=5, help='How often to print (in # batches)')
parser.add_argument('--save_model', action='store_true', help='Whether or not to save model checkpoints')
### if using frozen SwAV weights ###
parser.add_argument('--swav_dir', type=str,
                    help='If provided, will use the checkpoint in this directory.')
parser.add_argument('--swav_ckpt', type=int, default=399,
                    help='The epoch of the checkpoint to use.')
parser.add_argument('--sentry_filter', action='store_true',
                    help='If set, will filter out the small classes (DomainNet SENTRY domains only).')
parser.add_argument('--sentry_ft', type=str,
                    help='If provided, will initialize the model from that SENTRY fine-tune checkpoint')

def _validate_and_get_num_classes(args):
    """Validate the dataset/domain/version combination and return the number of classes.

    Raises:
        ValueError: if ``args.dataset_name``, ``args.domain`` and ``args.version``
            do not form a supported combination.
    """
    if args.dataset_name == 'breeds':
        if args.domain not in VALID_BREEDS_DOMAINS:
            raise ValueError(f'Valid Breeds domains are {list(VALID_BREEDS_DOMAINS)} but received '
                             f'domain {args.domain}.')
        if args.version not in ('source', 'target'):
            raise ValueError('For breeds, the version must be either source or target.')
        # Class counts of the standard BREEDS hierarchies (the count is part of
        # each split's name). The count does not depend on --version.
        breeds_num_classes = {
            'entity13': 13,
            'entity30': 30,
            'living17': 17,
            'nonliving26': 26,
        }
        if args.domain not in breeds_num_classes:
            raise ValueError(f'Number of classes is unknown for Breeds domain {args.domain}; '
                             f'supported domains are {sorted(breeds_num_classes)}.')
        return breeds_num_classes[args.domain]
    if args.dataset_name == 'domainnet':
        if args.version not in ('full', 'sentry'):
            raise ValueError('For domainnet, the version must be either full or sentry.')
        valid = VALID_DOMAINNET_S_DOMAINS if args.version == 'sentry' else VALID_DOMAINNET_DOMAINS
        if args.domain not in valid:
            raise ValueError(f'Valid DomainNet domains for version {args.version} are {valid} but '
                             f'received domain {args.domain}.')
        return domainnet.NUM_CLASSES_DICT[args.version]
    raise ValueError(f'Unsupported dataset: {args.dataset_name}.')


def _build_save_dir(args):
    """Return the checkpoint directory path for this run (does not create it).

    The path encodes dataset, domain, version and transform, plus the frozen
    SwAV checkpoint (``--swav_dir`` takes precedence) or the SENTRY fine-tune
    checkpoint, with '/' replaced by '-' so it stays a single path component.
    """
    save_dir = os.path.join(
        'connectivity_checkpoints',
        'single-domain',
        args.dataset_name,
        f'{args.domain}-{args.version}-{args.transform}'
    )
    if args.swav_dir is not None:
        save_dir += f'-{args.swav_dir.replace("/", "-")}-{args.swav_ckpt}'
    elif args.sentry_ft is not None:
        save_dir += f'-{args.sentry_ft.replace("/", "-")}'
    return save_dir


def _get_off_limits(args):
    """Return the set of class indices to skip; empty unless ``--sentry_filter`` is set.

    Raises:
        ValueError: if ``--sentry_filter`` is set but ``args.domain`` has no
            entry in the DomainNet SENTRY off-limits table.
    """
    if not args.sentry_filter:
        return set()
    off_limits = data.get_domainnet_off_limits(VALID_DOMAINNET_S_DOMAINS)
    try:
        domain_off_limits = off_limits[args.domain]
    except KeyError:
        raise ValueError(f'--sentry_filter requires one of the SENTRY domains '
                         f'{VALID_DOMAINNET_S_DOMAINS}, but received domain {args.domain}.') from None
    return {int(c) for c in domain_off_limits}


def main():
    """Run the single-domain connectivity experiment over random class pairs.

    Parses the command line and validates the dataset/domain/version
    combination. Builds and creates the checkpoint directory. Then, for each
    class pair chosen by ``train.get_classes_to_compare`` (``--num_pairs``
    pairs, seeded by ``--pair_seed``), builds two-class train/test datasets
    and passes them to ``train.main_loop``.

    Raises:
        ValueError: on an unsupported dataset/domain/version combination, or
            when ``--sentry_filter`` is used outside DomainNet SENTRY domains.
    """
    args = parser.parse_args()
    num_classes = _validate_and_get_num_classes(args)
    # Fail fast, before any directories are created or training starts.
    if args.sentry_filter and args.dataset_name != 'domainnet':
        raise ValueError('--sentry_filter is only supported for the domainnet dataset.')

    transform = transforms.get_transforms(args.transform)
    data.fill_default_data_path(args)

    # With a frozen SwAV backbone or a SENTRY fine-tune init, only a linear
    # probe is trained; the flag is carried on args (presumably read by
    # train.main_loop — confirm there).
    args.linear_probe_only = (args.swav_dir is not None) or (args.sentry_ft is not None)
    save_dir = _build_save_dir(args)
    os.makedirs(save_dir, exist_ok=True)

    class_pairs = train.get_classes_to_compare(num_classes, args.num_pairs, args.pair_seed)
    off_limits = _get_off_limits(args)
    for class_1, class_2 in class_pairs:
        print('*' * 10, f'Current class pair: {class_1} and {class_2}.', '*' * 10)
        if class_1 in off_limits or class_2 in off_limits:
            print('Skipping because filtering by dataset size')
            continue
        train_ds, test_ds = data.get_class_datasets(
            args.dataset_name, args.data_path, args.domain, args.version, transform, class_1, class_2
        )
        identifier = f'classes-{class_1}-{class_2}'
        if args.save_model:
            identifier += '-save-model'
        train.main_loop(train_ds, test_ds, save_dir, identifier, args, save_model=args.save_model)

# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
