import os
import numpy as np

from connectivity_utils import transforms
from connectivity_utils import train
from connectivity_utils import data
import breeds, domainnet
# Breeds split names accepted for --domain_1 (the keys of breeds.BREEDS_SPLITS_TO_FUNC).
VALID_BREEDS_DOMAINS = breeds.BREEDS_SPLITS_TO_FUNC.keys()
# DomainNet domains accepted when --version is 'sentry'.
VALID_DOMAINNET_S_DOMAINS = domainnet.SENTRY_DOMAINS
# DomainNet domains accepted when --version is 'full'.
VALID_DOMAINNET_DOMAINS = domainnet.VALID_DOMAINS

import argparse
# Command-line interface. Arguments are registered in the order they should
# appear in --help; defaults are shown via ArgumentDefaultsHelpFormatter.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description='Test Connectivity of dataset between domains',
)

# --- dataset selection ---
parser.add_argument(
    '--dataset_name',
    required=True,
    type=str,
    choices=['breeds', 'domainnet'],
    help='Which dataset on which to test connectivity.',
)
parser.add_argument(
    '--data_path',
    type=str,
    help='Root path of the data.',
)
parser.add_argument(
    '--domain_1',
    required=True,
    type=str,
    help='Name of domain 1.',
)
parser.add_argument(
    '--domain_2',
    type=str,
    help='Name of domain 2 for domainnet.',
)
parser.add_argument(
    '--version',
    type=str,
    default='sentry',
    choices=['full', 'sentry'],
    help='For domainnet, which version to use.',
)

# --- augmentation and class sampling ---
parser.add_argument(
    '--transform',
    type=str,
    default='swav-simclr',
    choices=['imagenet', 'simclr', 'swav-simclr'],
    help='Choice of data augmentation to use.',
)
parser.add_argument(
    '--num_pairs',
    type=int,
    default=12,
    help='For each connectivity metric, the number of pairs to randomly select.',
)
parser.add_argument(
    '--pair_seed',
    type=int,
    default=14,
    help='Seed for choosing pairs for each connectivity metric.',
)
parser.add_argument(
    '--single_seed',
    type=int,
    default=16,
    help='For selecting single classes to compare.',
)

# --- training hyperparameters ---
parser.add_argument('-a', '--arch', help='Architecture', default='resnet50')
parser.add_argument('-j', '--workers', help='Number of workers', type=int, default=2)
parser.add_argument('--epochs', help='Number of epochs', type=int, default=100)
parser.add_argument('-b', '--batch-size', help='Batch size', type=int, default=96)
parser.add_argument(
    '--lr', '--learning-rate',
    dest='lr',
    type=float,
    default=0.1,
    help='LR',
)
parser.add_argument('--momentum', help='Momentum', type=float, default=0.9)
parser.add_argument(
    '--wd', '--weight-decay',
    dest='weight_decay',
    type=float,
    default=1e-4,
    help='WD',
)
parser.add_argument('--save-freq', help='How often to save', type=int, default=25)
parser.add_argument(
    '--print-freq',
    type=int,
    default=5,
    help='How often to print (in # batches)',
)
parser.add_argument('--save_model', action='store_true')

# --- optional pretrained checkpoints (SwAV or SENTRY) ---
parser.add_argument(
    '--swav_dir',
    type=str,
    help='If provided, will use the checkpoint in this directory.',
)
parser.add_argument(
    '--swav_ckpt',
    type=int,
    default=399,
    help='The epoch of the checkpoint to use.',
)
parser.add_argument(
    '--sentry_ft',
    type=str,
    help='If provided, will use the SENTRY checkpoint',
)

# --- class filtering ---
parser.add_argument(
    '--same_class_only',
    action='store_true',
    help='Whether to compute only for same class.',
)
parser.add_argument(
    '--range_start',
    type=int,
    help='Minimum class label to use (inclusive).',
)
parser.add_argument(
    '--range_end',
    type=int,
    help='Maximum class label to use (uninclusive).',
)
parser.add_argument(
    '--sentry_filter',
    action='store_true',
    help='If set, will filter out the small classes',
)

def main():
    """Run the between-domain connectivity experiments chosen on the command line.

    Parses the module-level ``parser``, validates the dataset/domain arguments,
    builds the checkpoint directory, and then calls ``train.main_loop`` once
    per sampled class (same-class comparison) and, unless ``--same_class_only``
    is set, once per sampled class pair (different-class comparison).

    Raises:
        ValueError: If the dataset name, DomainNet version, or either domain
            is invalid.
        KeyError: If a valid Breeds domain has no entry in the class-count
            table below.
    """
    args = parser.parse_args()

    # validate dataset args
    if args.dataset_name == 'breeds':
        # Breeds uses a single split name, so both "domains" share it.
        args.domain_2 = args.domain_1
        if args.domain_1 not in VALID_BREEDS_DOMAINS:
            raise ValueError(f'Valid Breeds domains are {VALID_BREEDS_DOMAINS} but received '
                             f'domain 1 as {args.domain_1}.')
        # Only these two splits have a class count here; any other valid
        # split name raises KeyError.
        num_classes = {
            'entity30': 30,
            'living17': 17
        }[args.domain_1]
    elif args.dataset_name == 'domainnet':
        if args.version not in ['full', 'sentry']:
            raise ValueError('For domainnet, the version must be either full or sentry.')
        if args.domain_1 == args.domain_2:
            raise ValueError('Cannot provide the same domain for 1 and 2.')
        valid = VALID_DOMAINNET_S_DOMAINS if args.version == 'sentry' else VALID_DOMAINNET_DOMAINS
        if (args.domain_1 not in valid) or (args.domain_2 not in valid):
            # Report both domains: either one (including an omitted
            # --domain_2, which is None) may be the invalid one.
            raise ValueError(f'Valid DomainNet domains for version {args.version} are {valid} but '
                             f'received domains {args.domain_1} and {args.domain_2}.')
        num_classes = domainnet.NUM_CLASSES_DICT[args.version]
    else:
        raise ValueError(f'Unsupported dataset: {args.dataset_name}.')

    # get transforms
    transform = transforms.get_transforms(args.transform)
    data.fill_default_data_path(args)

    # file saving code
    save_dir = os.path.join(
        'connectivity_checkpoints',
        'between-domains',
        args.dataset_name,
        f'{args.domain_1}-{args.domain_2}-{args.version}-{args.transform}'
    )
    # A supplied checkpoint (SwAV or SENTRY) sets the linear_probe_only flag,
    # which is passed on to train.main_loop via args.
    args.linear_probe_only = (args.swav_dir is not None) or (args.sentry_ft is not None)
    if args.linear_probe_only:
        # Encode the checkpoint source in the directory name; the SwAV
        # directory takes precedence when both are given.
        if args.swav_dir is not None:
            save_dir += f'-{args.swav_dir.replace("/", "-")}-{args.swav_ckpt}'
        else:
            save_dir += f'-{args.sentry_ft.replace("/", "-")}'
    os.makedirs(save_dir, exist_ok=True)

    if args.sentry_filter:
        # off_limits is indexed by domain name below; its values are
        # converted to plain ints for the membership tests.
        off_limits = data.get_domainnet_off_limits(VALID_DOMAINNET_S_DOMAINS)
        off_limits = { key: list(map(int, value)) for key, value in off_limits.items() }

    # do same-class
    print('*' * 10, 'Comparing same-class different-domain', '*' * 10)
    prng = np.random.RandomState(args.single_seed)
    classes = prng.choice(num_classes, size=args.num_pairs, replace=False)
    # Unset range bounds arrive as None; replace them with sentinels so the
    # chained comparison below accepts every class label.
    if args.range_start is None:
        args.range_start = -1
    if args.range_end is None:
        args.range_end = float('inf')
    for class_idx in classes:
        if args.sentry_filter:
            if (class_idx in off_limits[args.domain_1]) or (class_idx in off_limits[args.domain_2]):
                print(f'Skipping class {class_idx} due to class size filtering.')
                continue
        if args.range_start <= class_idx < args.range_end:
            print('*' * 10, f'Current class: {class_idx}.', '*' * 10)
            train_ds, test_ds = data.get_diff_domain_datasets(
                args.dataset_name, args.data_path, args.domain_1, args.domain_2,
                args.version, transform, class_idx, class_idx
            )
            identifier = f'same-class-{class_idx}'
            if args.save_model:
                identifier += '-save-model'
            train.main_loop(train_ds, test_ds, save_dir, identifier, args, save_model=args.save_model)

    if args.same_class_only:
        # Finish normally instead of calling the interactive-shell `exit`
        # helper; the process exit status is still 0 when run as a script.
        return

    # do different-class
    print('*' * 10, 'Comparing different-class different-domain', '*' * 10)
    class_pairs = train.get_classes_to_compare(num_classes, args.num_pairs, args.pair_seed)
    for class_1, class_2 in class_pairs:
        print('*' * 10, f'Current class pair: {class_1} and {class_2}.', '*' * 10)
        if args.sentry_filter:
            if (class_1 in off_limits[args.domain_1]) or (class_2 in off_limits[args.domain_2]):
                print(f'Skipping classes {class_1}, {class_2} due to class size filtering.')
                continue
        train_ds, test_ds = data.get_diff_domain_datasets(
            args.dataset_name, args.data_path, args.domain_1, args.domain_2,
            args.version, transform, class_1, class_2
        )
        identifier = f'different-class-{class_1}-{class_2}'
        if args.save_model:
            identifier += '-save-model'
        train.main_loop(train_ds, test_ds, save_dir, identifier, args, save_model=args.save_model)

# Run the experiments only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
