import argparse
from email.policy import default
import time
import sys
import os
import json
import wandb
import random
from os import path, makedirs

import torch
from torch import optim
from torch.backends import cudnn
from algorithms.initializer import initialize_algorithm
from models.initializer import Identity

from transforms import initialize_transform
from SSL.loader import TwoCropsTransform
from SSL.model_factory import SimSiam, SimCLR
from SSL.criterion import SimSiamLoss, NT_XentLoss
from SSL.validation import Validation
from SSL.cfg import get_cfg
from SSL.utils import *
from wilds.common.data_loaders import get_ssl_train_loader, get_train_loader
from wilds.common.grouper import CombinatorialGrouper

import wilds
from utils import ParseKwargs, parse_bool, set_seed, SysLog, BatchLogger, get_model_module
import configs.supported as supported

# Root directory for run logs: the default of --log_dir and the base of the
# checkpoint paths built in evaluate_algorithm / initialize_encoder.
LOG_DIR = './logs'

def get_args(dataset=None, algorithm='simsiam'):
    """Build the command-line parser for SSL training and evaluation.

    Args:
        dataset: default value for ``--dataset``.
        algorithm: default value for ``--algorithm``.

    Returns:
        The configured ``argparse.ArgumentParser``; the caller is expected to
        call ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser(description='SSL Training')
    parser.add_argument('-a', '--arch', metavar='ARCH', choices=supported.models)
    parser.add_argument('--model', choices=supported.models)

    parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ')

    parser.add_argument('--feat-dim', type=int, help='feature dimension')
    parser.add_argument('--num-proj-layers', type=int, default=2, help='number of projection layer')
    parser.add_argument('--batch-size', type=int, help='batch_size')
    parser.add_argument('--num-workers', type=int, default=8, help='num of workers to use')
    parser.add_argument('--n_epochs', type=int, help='number of training epochs')
    parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
    parser.add_argument('--loss-version', default='simplified', type=str,
                        choices=['simplified', 'original'],
                        help='do the same thing but simplified version is much faster. ()')
    # Passed to NT_XentLoss for simclr; the help text used to say "learning rate".
    parser.add_argument('--temperature', type=float, default=0.05, help='temperature of the NT-Xent (simclr) loss')
    parser.add_argument('--img-shuffle', default=None, choices=['block', 'pixel'])
    parser.add_argument('--n_blocks', type=int, help='number of blocks for block shuffling')

    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--eval-only', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--eval-ckpt', type=str, default='', choices=['best', 'last', 'init', 'resume'])
    parser.add_argument('--eval-set', default='val')
    parser.add_argument('--train-set', default='ds_train')
    parser.add_argument('--minority-weight', type=float)

    parser.add_argument('--eval-knn', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--knn-k', default=5, type=int, help='k for k-NN monitor (default: 5)')
    parser.add_argument('--linreg_c', type=float)
    parser.add_argument('--eval-lin', type=parse_bool, const=True, nargs='?', default=True)
    parser.add_argument('--eval-mi', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--eval-layer-wise', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--eval-layer-wise-lr', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--eval-layer-wise-knn', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--eval-group-alignment', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--eval-layers-num', default=-1, type=int, help='number of final layers to evaluate on')
    parser.add_argument('--eval-freq', default=10, type=int, help='evaluate with the monitor every m epoch (default: 10)')
    parser.add_argument('--eval-grouping', type=parse_bool, const=True, nargs='?', default=True)
    parser.add_argument('--eval-train', type=parse_bool, const=True, nargs='?', default=True) # linear evaluation for train set
    parser.add_argument('--eval-spur', type=parse_bool, const=True, nargs='?', default=False) # linear evaluation for spuriously correlated attribute instead of true labels
    parser.add_argument('--eval-gridsearch', type=parse_bool, const=True, nargs='?', default=False) # do gridsearch or no
    parser.add_argument('--eval-sc', type=parse_bool, const=True, nargs='?', default=True) # use standardscaler while evaluation
    parser.add_argument('--eval-fullgrid', type=parse_bool, const=True, nargs='?', default=False) # do gridsearch or no
    parser.add_argument('--eval-connectivity', type=parse_bool, const=True, nargs='?', default=False) # eval alpha beta gamma
    parser.add_argument('--log-features', type=parse_bool, const=True, nargs='?', default=False) # visualize features
    parser.add_argument('--eval-norm', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--infer-group', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--shuffle-train', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--get-features', type=parse_bool, const=True, nargs='?', default=True)
    parser.add_argument('--save-features', type=parse_bool, const=True, nargs='?', default=False)

    parser.add_argument('--save-freq', default=50, type=int, help='save model frequency')
    parser.add_argument('--resume', default='last_checkpoint', type=str, help='path to latest checkpoint')

    parser.add_argument('--lr', type=float, help='learning rate')
    parser.add_argument('--lr_min', type=float, help='minimum learning rate')
    parser.add_argument('--weight-decay', type=float, help='weight decay')
    parser.add_argument('--head-weight-decay', type=float, default=None, help='changed weight decay for head')
    # NOTE(review): parsed as float although it is a layer count; kept as-is so
    # existing invocations keep parsing the same way.
    parser.add_argument('--head-layers', type=float, default=0, help='number of layers we consider the -head- (0 if the projector only) ')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    # Model
    parser.add_argument('--width', type=int, default=-1, help='width for vwresnet models')
    parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
        help='keyword arguments for model initialization passed as key1=value1 key2=value2')
    parser.add_argument('--use_pretrained', type=parse_bool, const=True, nargs='?', default=False, help='Whether to fine-tune the trained model in pretrained_dir.')

    parser.add_argument('--model_aug', type=parse_bool, const=True, nargs='?', default=True)
    parser.add_argument('--final_aug', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--update_aug', default='epoch', type=str, choices=['batch', 'epoch'])
    parser.add_argument('--stop_grad', type=parse_bool, const=True, nargs='?', default=False)  # for simclr
    parser.add_argument('--prune_method', type=str, default='threshold', choices=['-', 'global', 'threshold'])
    parser.add_argument('--prune_criterion', type=str, default='magnitude', choices=['magnitude'])
    parser.add_argument('--prune_perc', type=float, default=0)
    parser.add_argument('--prune_inc_perc', type=float, default=0.5)
    parser.add_argument('--prune_layer_th', type=int, default=-1)
    parser.add_argument('--reinit_method', type=str, default='-', choices=['-', 'threshold'])
    parser.add_argument('--reinit_layer_th', type=int, default=-1)

    parser.add_argument('--aug-epoch', type=int, default=-1, help='start model augmentations after a certain number of epochs')

    # wilds ds
    # Dataset
    parser.add_argument('--dataset', choices=wilds.supported_datasets + ['cifar10'], default=dataset)
    parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')
    parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})
    parser.add_argument('--frac', type=float, default=1.0,
                        help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.')
    parser.add_argument('--train-frac', type=float, default=1.0)
    parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?')

    # Loaders
    parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={})
    parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')
    parser.add_argument('--ssl_uniform_over_groups', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--ssl-train-set', default='train', choices=['val', 'train', 'balanced_train', 'ds_train', 'us_train'])
    parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?')
    parser.add_argument('--n_groups_per_batch', type=int)
    parser.add_argument('--eval_loader', choices=['standard'], default='standard')
    parser.add_argument('--algorithm', default=algorithm, choices=supported.ssl_algorithms + ['ERM'])
    parser.add_argument('--adapt', type=parse_bool, const=True, nargs='?', default=False)
    parser.add_argument('--weights-model', default='', choices=['', 'simsiam-in', 'simsiam-base', 'erm-in', 'erm-base', 'clip', 'simclr-in', 'barlow-in']) # evaluate representations of any other 'algorithm'-like model
    parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})
    parser.add_argument('--log_dir', default=LOG_DIR)
    parser.add_argument('--no-hparam', type=parse_bool, const=True, nargs='?', default=True) # have hparams in dir name
    parser.add_argument('--use-wandb', type=parse_bool, const=True, nargs='?', default=True)
    parser.add_argument('--wandb-name', default='train-ssl')
    parser.add_argument('--entity', default='kimia')

    parser.add_argument('--output_dir', default='./outs')
    parser.add_argument('--experiment', default='experiment')
    return parser

def setup_run(config, args):
    """Seed the RNGs, optionally start a wandb run, and tee stdout to a log file.

    Args:
        config: run configuration; reads ``seed``, ``use_wandb``, ``name``,
            ``entity`` and ``log_dir``.
        args: parsed CLI namespace; reads ``seed`` and ``wandb_name``.

    Side effects:
        Replaces ``sys.stdout`` with a ``SysLog`` wrapping the original stream
        and ``<config.log_dir>/log.txt`` (opened for writing, truncating it).
    """
    seed = args.seed
    if seed is not None:
        set_seed(config.seed)
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True

    print(vars(args))

    if config.use_wandb:
        wandb.init(
            name=config.name,
            project=args.wandb_name,
            entity=config.entity,
            settings=wandb.Settings(start_method='thread'),
        )
        wandb.config.update(config)

    # Second logger: the file handle is not closed here because it is handed
    # to SysLog, which takes over sys.stdout.
    log_path = os.path.join(config.log_dir, 'log.txt')
    log_file = open(log_path, 'w')
    sys.stdout = SysLog(sys.stdout, log_file)

def get_dataset(config, train=True, download=False):
    """Build the two-view SSL training subset and its grouper.

    Args:
        config: run configuration; reads ``dataset``, ``split_scheme``,
            ``dataset_kwargs``, ``groupby_fields``, ``transform_ssl`` and
            ``ssl_train_set``.
        train: unused; kept for interface compatibility.
        download: unused; kept for interface compatibility.

    Returns:
        ``(subset, grouper)``: the ``config.ssl_train_set`` subset of the WILDS
        dataset with a ``TwoCropsTransform`` applied, and a
        ``CombinatorialGrouper`` over ``config.groupby_fields``.
    """
    wilds_dataset = wilds.get_dataset(
        dataset=config.dataset,
        root_dir='',
        split_scheme=config.split_scheme,
        **config.dataset_kwargs)
    grouper = CombinatorialGrouper(dataset=wilds_dataset, groupby_fields=config.groupby_fields)

    base_transform = initialize_transform(
        transform_name=config.transform_ssl,
        config=config,
        dataset=wilds_dataset,
        is_training=True)
    print('SSL Transform: ', base_transform)

    # The SSL transform is wrapped in TwoCropsTransform before being attached to the subset.
    two_view_transform = TwoCropsTransform(base_transform)
    print(two_view_transform)

    ssl_subset = wilds_dataset.get_subset(
        config.ssl_train_set, frac=1., transform=two_view_transform)
    return ssl_subset, grouper

def setup_dataset(config, args):
    """Create the SSL training set and, unless only evaluating, its loader.

    Args:
        config: run configuration; reads ``eval_only`` and
            ``ssl_uniform_over_groups`` (plus whatever ``get_dataset`` reads).
        args: parsed CLI namespace; reads ``batch_size`` and ``num_workers``.

    Returns:
        ``(train_set, train_loader, train_grouper)``; ``train_loader`` is
        ``None`` when ``config.eval_only`` is set.
    """
    train_set, train_grouper = get_dataset(config)

    train_loader = None
    if not config.eval_only:
        # config.loader_kwargs is not forwarded to this loader.
        train_loader = get_ssl_train_loader(
            loader='standard',
            dataset=train_set,
            batch_size=args.batch_size,
            uniform_over_groups=config.ssl_uniform_over_groups,
            grouper=train_grouper,
            drop_last=True,
            num_workers=args.num_workers,
            pin_memory=True,
        )

    return train_set, train_loader, train_grouper

def setup_training(config, args, train_loader):
    """Create the model, loss, optimizer and LR scheduler for a run.

    Args:
        config: run configuration; reads ``eval_only``, ``eval_ckpt``,
            ``weights_model`` and ``algorithm``.
        args: parsed CLI namespace; reads the optimisation hyper-parameters and
            is mutated (``feat_dim``, ``gpu``, ``num_workers``).
        train_loader: training loader; only its length is used, to size the
            simclr cosine schedule.

    Returns:
        ``(model, model_module, criterion, optimizer, scheduler)``.
        ``model_module`` is the model without its ``DataParallel`` wrapper.
        For ``eval_only`` runs the last three entries are ``None``; for an
        ``eval_only`` run of an algorithm other than simsiam/simclr all five
        are ``None`` and ``config.device`` is set instead.

    Raises:
        AssertionError: unsupported algorithm outside of ``eval_only``.
        ValueError: ``--head-weight-decay`` is set with a ``--head-layers``
            value other than 0, 1 or 2.
    """
    if config.eval_only and config.eval_ckpt == 'init':
        from models.pretrained import get_enc
        model = get_enc(config, config.weights_model, get_rep_dim=False).cuda() # only the encoder
        criterion = None

    elif args.algorithm == 'simsiam':
        args.feat_dim = 2048
        model = SimSiam(config)
        criterion = SimSiamLoss(args.loss_version)
    elif args.algorithm == 'simclr':
        args.feat_dim = 128
        model = SimCLR(config)
        criterion = NT_XentLoss(args.loss_version, temperature=args.temperature)
    else:
        assert config.eval_only
        config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return None, None, None, None, None

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model).cuda()
            model_module = model.module
            if criterion is not None:
                criterion = criterion.cuda()
        else:
            args.gpu = 0 if args.gpu is None else args.gpu
            torch.cuda.set_device(args.gpu)
            model = model.cuda(args.gpu)
            model_module = model
            if criterion is not None:
                criterion = criterion.cuda(args.gpu)
        cudnn.benchmark = True
        device_count = torch.cuda.device_count()
    else:
        device_count = 1
        model_module = model
    args.num_workers = device_count * 4

    if config.eval_only:
        return model, model_module, None, None, None

    if args.head_weight_decay:
        head_wd = args.head_weight_decay
        backbone = model_module.encoder[0]
        # encoder[1] always gets the head weight decay.
        optim_params = [{'params': model_module.encoder[1].parameters(), 'weight_decay': head_wd}]
        if args.algorithm == 'simsiam':
            optim_params.append({'params': model_module.predictor.parameters()})
        if args.head_layers == 0:
            # NOTE(review): the whole backbone also receives the head weight
            # decay here, although the CLI help says 0 means "projector only".
            # Kept as-is; confirm the intent before changing.
            optim_params.append({'params': backbone.parameters(), 'weight_decay': head_wd})
        elif args.head_layers == 1:
            # NOTE(review): only layer1..layer4 are registered in this branch;
            # any other backbone parameters are not optimised. Confirm.
            optim_params += [
                {'params': backbone.layer4.parameters(), 'weight_decay': head_wd},
                {'params': backbone.layer3.parameters()},
                {'params': backbone.layer1.parameters()},
                {'params': backbone.layer2.parameters()}]
        elif args.head_layers == 2:
            optim_params += [
                {'params': backbone.layer4.parameters(), 'weight_decay': head_wd},
                {'params': backbone.layer3.parameters(), 'weight_decay': head_wd},
                {'params': backbone.layer1.parameters()},
                {'params': backbone.layer2.parameters()}]
        else:
            # Previously this fell through and left the backbone out of the
            # optimizer entirely, without any warning.
            raise ValueError(
                f'unsupported --head-layers value {args.head_layers!r}; expected 0, 1 or 2')
    else:
        optim_params = model.parameters()
    optimizer = optim.SGD(optim_params,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    scheduler = None
    if config.algorithm == 'simclr':
        from torch.optim.lr_scheduler import LambdaLR

        def get_lr(step, total_steps, lr_max, lr_min):
            """Compute learning rate according to cosine annealing schedule."""
            return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

        # cosine annealing lr (eval_only has already returned above, so a
        # scheduler is always built here).
        # NOTE(review): LambdaLR multiplies the optimizer's initial lr by the
        # value returned from lr_lambda, while get_lr returns an absolute
        # learning rate, so the effective lr is args.lr * get_lr(...). Left
        # unchanged because fixing it would alter training; confirm.
        scheduler = LambdaLR(
            optimizer,
            lr_lambda=lambda step: get_lr(
                step,
                args.n_epochs * len(train_loader),
                args.lr,
                1e-3))

    return model, model_module, criterion, optimizer, scheduler

def resume_model(config, args, model, optimizer):
    """Load a checkpoint into ``model``/``optimizer`` and return the start epoch.

    Evaluation runs (``config.eval_only``) load
    ``<log_ckpts_dir>/ckpt_<eval_ckpt>.pth.tar``; for ``eval_ckpt == 'resume'``
    they fall back to ``<log_ckpts_dir>/ckpt.pth.tar``, and for a
    ``simsiam-base`` / ``init`` run they load ``args.resume`` instead.
    Training runs resolve ``args.resume`` (``'last'`` and ``'last_checkpoint'``
    map to files in ``config.log_ckpts_dir``) and, when no checkpoint file
    exists, optionally load pretrained weights from ``args.pretrained_dir``.

    Args:
        config: run configuration.
        args: parsed CLI namespace; ``args.resume`` is rewritten to the
            resolved checkpoint path.
        model: model to load weights into.
        optimizer: optimizer to load state into.

    Returns:
        The epoch reported by the loaded checkpoint, or 1 when nothing was
        loaded or only pretrained weights were loaded.

    Raises:
        AssertionError: an evaluation run has no checkpoint file to load.
    """
    start_epoch = 1

    if config.eval_ckpt == 'init' and config.weights_model != 'simsiam-base':
        return start_epoch

    if config.eval_only:
        eval_path = path.join(config.log_ckpts_dir, f'ckpt_{config.eval_ckpt}.pth.tar')
        if config.weights_model == 'simsiam-base' and config.eval_ckpt == 'init':
            eval_path = args.resume
        if not path.isfile(eval_path):
            fallback_path = ''
            if config.eval_ckpt == 'resume':
                fallback_path = path.join(config.log_ckpts_dir, 'ckpt.pth.tar')
            if path.isfile(fallback_path):
                eval_path = fallback_path
            else:
                # Only report a missing path when the fallback is missing too.
                print('eval path does not exist: ', fallback_path, eval_path)
        assert path.isfile(eval_path), f'No checkpoint to evaluate from: {eval_path}'

        start_epoch, model, optimizer = load_checkpoint(model, optimizer, eval_path)
        # Report the file that was actually loaded, not args.resume.
        print("Loaded checkpoint '{}' (epoch {})".format(eval_path, start_epoch))
        return start_epoch

    print('in resume model, ', args.resume)
    if args.resume is None:
        return start_epoch

    if args.resume == 'last':
        args.resume = path.join(config.log_ckpts_dir, 'ckpt_last.pth.tar')
    # A literal file named 'last_checkpoint' takes precedence over the alias.
    if not path.isfile(args.resume) and args.resume == 'last_checkpoint':
        args.resume = path.join(config.log_ckpts_dir, 'ckpt.pth.tar')

    if path.isfile(args.resume):
        start_epoch, model, optimizer = load_checkpoint(model, optimizer, args.resume)
        print("Loaded checkpoint '{}' (epoch {})".format(args.resume, start_epoch))
        return start_epoch

    print("No checkpoint found at '{}'".format(args.resume))
    if config.use_pretrained:
        pretrained_epoch, model, optimizer = load_pretrained_checkpoint(
            model, optimizer, args.pretrained_dir)
        print("Loaded pretrained checkpoint '{}' (epoch {})"
              .format(args.pretrained_dir, pretrained_epoch))
        # Training from pretrained weights still starts at epoch 1.
    return start_epoch

def evaluate_best(config, model, validation):
    """Look up the best checkpoint of this run and optionally evaluate it.

    Args:
        config: run configuration; reads ``log_ckpts_dir``, ``eval_only`` and
            ``eval_ckpt``.
        model: model that receives the best weights when evaluating the best
            checkpoint (loaded with ``strict=False``).
        validation: object whose ``eval(epoch)`` runs the evaluation.

    Returns:
        ``(terminate, best_acc, best_path_dir)``. ``terminate`` is True only
        when the best checkpoint was loaded and evaluated; ``best_acc`` is the
        checkpoint's ``'top1_acc'`` or 0.0 when there is no checkpoint;
        ``best_path_dir`` is the path of the best-checkpoint file.
    """
    best_acc = 0.0
    best_path_dir = path.join(config.log_ckpts_dir, 'ckpt_best.pth.tar')
    evaluate_best_ckpt = config.eval_only and config.eval_ckpt == 'best'

    if not os.path.exists(best_path_dir):
        if evaluate_best_ckpt:
            print(f'no checkpoint found in dir: {best_path_dir}')
        return False, best_acc, best_path_dir

    # Fall back to CPU so the checkpoint can be inspected without a GPU.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(best_path_dir, map_location=device)
    best_acc = checkpoint['top1_acc']
    epoch = checkpoint['epoch']
    print(f'best acc found in previous checkpoints: {best_acc}')

    terminate = False
    if evaluate_best_ckpt:
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        validation.eval(epoch)
        terminate = True

    return terminate, best_acc, best_path_dir

def initialize_wilds_datasets(config, logger):
    """Build evaluation subsets, loaders and CSV loggers for every split.

    Args:
        config: run configuration; reads ``dataset``, ``split_scheme``,
            ``dataset_kwargs``, ``transform``, ``groupby_fields``,
            ``train_set``, ``frac``, ``batch_size``, ``loader_kwargs``,
            ``log_dir`` and ``use_wandb``.
        logger: logger handed to ``log_group_data``.

    Returns:
        ``(full_dataset, datasets)`` where ``datasets[split]`` holds the keys
        ``'dataset'``, ``'loader'``, ``'split'``, ``'name'``, ``'verbose'`` and
        ``'eval_logger'``.
    """
    from wilds.common.data_loaders import get_eval_loader
    from collections import defaultdict

    full_dataset = wilds.get_dataset(
        dataset=config.dataset,
        root_dir='',
        split_scheme=config.split_scheme,
        **config.dataset_kwargs)

    # Every split is evaluated with the same (non-training) transform.
    eval_transform = initialize_transform(
        transform_name=config.transform,
        config=config,
        dataset=full_dataset,
        is_training=False)
    print(eval_transform)

    train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=config.groupby_fields)

    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        # The 'train' entry is backed by the subset named in config.train_set.
        subset_name = config.train_set if split == 'train' else split
        # Only train and val are verbose (and mirrored to wandb when enabled).
        verbose = split in ('train', 'val')

        entry = datasets[split]
        entry['dataset'] = full_dataset.get_subset(
            subset_name,
            frac=config.frac,
            transform=eval_transform)
        entry['loader'] = get_eval_loader(
            loader='standard',
            dataset=entry['dataset'],
            grouper=train_grouper,
            batch_size=config.batch_size,
            **config.loader_kwargs)

        # Set fields
        entry['split'] = split
        entry['name'] = full_dataset.split_names[split]
        entry['verbose'] = verbose
        entry['eval_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_eval.csv'),
            mode='w',
            use_wandb=(config.use_wandb and verbose))

    from utils import log_group_data
    log_group_data(datasets, train_grouper, logger)

    return full_dataset, datasets

def evaluate_algorithm(config, args, train_grouper):  # for when algorithm is not supported by ssl models
    """Evaluate a non-SSL algorithm, then evaluate its features.

    Runs the standard ``evaluate`` over all splits, replaces the model's final
    ``fc`` layer with ``Identity`` and runs ``Validation`` on the result.

    Args:
        config: run configuration; mutated (``evaluate_all_splits``,
            ``progress_bar``, ``save_pred``, ``eval_ckpt``).
        args: parsed CLI namespace, forwarded to ``Validation``.
        train_grouper: grouper forwarded to the algorithm and ``Validation``.

    Returns:
        Whatever ``Validation.eval(-1)`` returns.

    Raises:
        FileNotFoundError: the expected checkpoint file is missing (not
            checked for ``weights_model == 'erm-in'``).
    """
    if config.weights_model == 'erm-in':
        assert config.model_kwargs['pretrained'] == True
        assert config.algorithm == 'ERM'
    from utils import Logger
    logger = Logger(os.path.join(config.log_dir, 'log.txt'), 'w')

    _, datasets = initialize_wilds_datasets(config, logger)
    algorithm = initialize_algorithm(config, datasets, train_grouper)

    ckpt_path_dir = path.join(f'{LOG_DIR}/{config.dataset}_{config.algorithm}_{config.model}',
                            f'{config.dataset}_seed:{config.seed}_epoch:{config.eval_ckpt}_model.pth')

    print(ckpt_path_dir)
    if config.weights_model != 'erm-in':
        if os.path.exists(ckpt_path_dir):
            from utils import load
            load(algorithm, ckpt_path_dir, True)
        else:
            print(f'no checkpoint found in dir: {ckpt_path_dir}')
            import errno
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), ckpt_path_dir)

    from utils import evaluate
    config.evaluate_all_splits = True
    config.progress_bar = False
    config.save_pred = False

    evaluate(
        algorithm=algorithm,
        datasets=datasets,
        epoch=-1,
        general_logger=logger,
        config=config,
        is_best=False)
    algorithm.eval()

    # Strip the classification head so the model outputs representations.
    # A wrapped model (e.g. DataParallel) exposes `fc` only through `.module`,
    # in which case the direct lookup raises AttributeError.
    model = algorithm.model
    try:
        model_module = model
        rep_dim = model.fc.in_features
    except AttributeError:
        model_module = model.module
        rep_dim = model_module.fc.in_features
    model_module.fc = Identity(rep_dim)

    validation = Validation(args, config, model_module, rep_dim=rep_dim, grouper=train_grouper, eval_set=config.eval_set, train_set=config.train_set)
    config.eval_ckpt = 'last'
    res = validation.eval(-1)
    return res

def initialize_encoder(config, model_module):
    """Load encoder weights selected by ``config.weights_model`` into ``model_module``.

    Args:
        config: run configuration; reads ``weights_model``, ``model_kwargs``,
            ``dataset``, ``model`` and ``seed``.
        model_module: unwrapped encoder that receives the state dict.

    Returns:
        None. Nothing is loaded for ``''``, ``'clip'``, ``'simclr-in'``,
        ``'barlow-in'`` and ``'erm-in'``.

    Raises:
        FileNotFoundError: the ``erm-base`` checkpoint does not exist.
        KeyError: ``config.weights_model`` is not a handled value.
    """
    if config.weights_model in ['', 'clip', 'simclr-in', 'barlow-in']:  # already handled
        return

    if config.weights_model == 'erm-in':
        assert config.model_kwargs['weights'] is not None
        return

    if config.weights_model == 'erm-base':
        algorithm = 'ERM'
        ckpt_path_dir = path.join(f'/h/kimia/robust-ssl/log_erm/{config.dataset}_{algorithm}_{config.model}',
                        f'{config.dataset}_seed:{config.seed}_epoch:last_model.pth')
        print(f'weight model loaded model from {ckpt_path_dir}')
        if not os.path.exists(ckpt_path_dir):
            print(f'no checkpoint found in dir: {ckpt_path_dir}')
            import errno
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), ckpt_path_dir)
        # Fall back to CPU so loading also works on hosts without a GPU.
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        state_dict = torch.load(ckpt_path_dir, map_location=device)['algorithm']
        state_dict = {k.replace('model.module.', ''): v for k, v in state_dict.items()}
        # Drop the classification head; only the encoder weights are loaded.
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc')}
        model_module.load_state_dict(state_dict)
        return

    # loading encoder only
    if config.weights_model == 'simsiam-in':
        model_path = 'pt_models/SSL/simsiam_modified'
        ckpt_path_dir = f'{LOG_DIR}/{model_path}.pth'
    elif config.weights_model == 'simsiam-base':
        # todo: assert adaptor
        pt_algorithm = 'simsiam'
        dir_name = f'{config.dataset}_{pt_algorithm}_{config.model}'
        model_name = 'ckpts/ckpt_last.pth.tar'
        ckpt_path_dir = f'{LOG_DIR}/{dir_name}/{model_name}'
    else:
        # The exception used to be constructed but never raised, so execution
        # continued and failed on the undefined checkpoint path instead.
        raise KeyError(f'weights model {config.weights_model} not handled')

    print(f'weight model loaded model from {ckpt_path_dir}')
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(ckpt_path_dir, map_location=device)
    best_acc = checkpoint['top1_acc'] if 'top1_acc' in checkpoint else 0
    print(f'acc found in weights_model checkpoint: {best_acc}')
    # Strip wrapper prefixes so the keys match the bare encoder.
    state_dict = {k.replace('module.encoder.0.', '').replace('module.encoder.', '').replace('model.module.', ''): v for k, v in checkpoint['state_dict'].items()}
    # Keys still starting with 'module.' or '1.' are not part of encoder[0]; skip them.
    state_dict = {k: v for k, v in state_dict.items() if not k.startswith(('module.', '1.'))}
    model_module.load_state_dict(state_dict)

def get_featurizer(config, args, model, model_module, optimizer):
    """Pick the module used to extract representations and its output size.

    Args:
        config: run configuration; reads ``eval_only``, ``eval_ckpt``,
            ``weights_model`` and ``algorithm``.
        args: parsed CLI namespace (``eval_ckpt`` and ``resume`` are printed).
        model: possibly wrapped model.
        model_module: the unwrapped model.
        optimizer: optimizer handed to ``resume_model``.

    Returns:
        ``(featurizer, rep_dim, start_epoch)``; ``start_epoch`` is -1 when
        evaluating initial weights (``eval_only`` with ``eval_ckpt == 'init'``).
    """
    if config.eval_only and config.eval_ckpt == 'init':
        # Externally initialised weights: the model itself is the featurizer.
        initialize_encoder(config, model_module)
        if config.weights_model != 'clip':
            rep_dim = get_model_module(model).fc.in_features
        else:
            rep_dim = model_module.attnpool.c_proj.out_features
        return model, rep_dim, -1

    start_epoch = resume_model(config, args, model, optimizer)
    print('start epoch ', start_epoch)
    print('resumed model from ckpt: ', args.eval_ckpt, args.resume)

    encoder = model_module.encoder
    # For simsiam/simclr only the first element of the encoder is used.
    featurizer = encoder[0] if config.algorithm in ['simsiam', 'simclr'] else encoder

    if torch.cuda.device_count() > 1:
        featurizer = torch.nn.DataParallel(featurizer).cuda()
    rep_dim = model_module.rep_dim
    featurizer.eval()
    return featurizer, rep_dim, start_epoch

def eval_grid(config, validation, featurizer, train_grouper, grid=['all']):
    """Extract features with ``featurizer`` and run the grid-search evaluation.

    Args:
        config: run configuration; reads ``train_set``, ``eval_set``,
            ``minority_weight``, ``batch_size``, ``eval_fullgrid`` and
            ``eval_norm``.
        validation: provides ``train_dataloader``, ``val_dataloader`` and
            ``extra_vals``.
        featurizer: module used to compute the features; put in eval mode.
        train_grouper: grouper forwarded to the helpers.
        grid: search mode forwarded to ``eval_grid_search`` unless
            ``config.eval_fullgrid`` is set. The list default is never mutated.
    """
    from SSL.validation import save_features, eval_grid_search, delete_features

    featurizer.eval()
    loaders = {
        config.train_set: validation.train_dataloader,
        config.eval_set: validation.val_dataloader,
    }
    train_dataset = validation.train_dataloader.dataset
    print('og count ', train_dataset.og_group_counts)

    if config.minority_weight is not None:
        # Swap in a training loader built with the requested minority weight.
        loaders[config.train_set] = get_train_loader(
            'standard',
            dataset=train_dataset,
            batch_size=config.batch_size,
            grouper=train_grouper,
            minority_weight=config.minority_weight)

    for name, entry in validation.extra_vals.items():
        loaders[name] = entry['loader']

    # choose none if you don't want to log embeddings
    feature_dict = save_features(config, featurizer, loaders, grouper=train_grouper)

    search_mode = ['all', 'noreg'] if config.eval_fullgrid else grid
    eval_grid_search(config, mode=search_mode, grouper=train_grouper, feature_dict=feature_dict)
    if config.eval_norm:
        eval_grid_search(config, mode=['noreg', 'svm'], grouper=train_grouper)

    delete_features(config, loaders.keys())

def eval_connectivity(config, train_grouper, featurizer):
    """Compute group connectivities and average them per metric.

    Args:
        config: run configuration; reads ``dataset``, ``split_scheme``,
            ``dataset_kwargs``, ``transform_ssl`` and ``use_wandb``.
        train_grouper: grouper used to name the groups of each pair.
        featurizer: module forwarded to ``get_connectivities``.

    Returns:
        Dict mapping each metric name (``'alpha'``, ``'beta'``, ``'gamma'``)
        to the mean connectivity of its group pairs; also logged to wandb when
        ``config.use_wandb`` is set.
    """
    full_dataset = wilds.get_dataset(dataset=config.dataset, root_dir='', split_scheme=config.split_scheme, **config.dataset_kwargs)
    ssl_transform = initialize_transform(
        transform_name=config.transform_ssl,
        config=config,
        dataset=full_dataset,
        is_training=True)
    from SSL.validation import get_connectivities
    connectivity = get_connectivities(config, full_dataset, train_grouper, featurizer, ssl_transform)
    a, b, g = 'alpha', 'beta', 'gamma'

    import pandas as pd
    # The metric labels are positional: this presumes get_connectivities
    # returns exactly six group pairs in this order - TODO confirm.
    df = pd.DataFrame({
        'metric': [a, b, g, g, b, a],
        'connectivity': [v[0] for v in connectivity.values()],
        'groups': [(train_grouper.group_field_str(k[0]), train_grouper.group_field_str(k[1]))
                   for k in connectivity],
    })
    print(df)
    # Select the column explicitly; `.agg('connectivity')` only worked because
    # pandas resolves an unknown aggregation name as an attribute lookup.
    d = df.groupby('metric')['connectivity'].mean().to_dict()
    if config.use_wandb:
        wandb.log(d)

    return d


def evaluate_model(config, args, model, validation, epoch=0):
    """Run ``validation.eval`` for ``epoch`` and return its result.

    Args:
        config: unused; kept for interface compatibility.
        args: unused; kept for interface compatibility.
        model: unused; kept for interface compatibility.
        validation: object whose ``eval(epoch)`` returns ``(is_best, res)``.
        epoch: epoch passed to the evaluation (default 0). It used to be
            overwritten with 0, so a caller-supplied value was ignored.

    Returns:
        The second element of the tuple returned by ``validation.eval``.
    """
    _, res = validation.eval(epoch)
    return res

def finish_run(args):
    """Write the ``done.out`` marker file into ``args.output_dir``.

    The file contains the single word ``done``.
    """
    marker_path = f'{args.output_dir}/done.out'
    with open(marker_path, 'w', encoding='utf-8') as marker:
        marker.write('done')

def main():
    """Entry point: parse arguments, then evaluate or train an SSL model.

    Evaluation modes (non-SSL algorithm, full-grid / norm evaluation,
    connectivity, plain ``eval_only``) each write the ``done.out`` marker and
    return early. Otherwise the model is trained from ``start_epoch`` to
    ``args.n_epochs`` with periodic validation and checkpointing.
    """
    parser = get_args()
    args = parser.parse_args()
    config, args = get_cfg(args)
    print('printing shuffle train: ', config.shuffle_train)
    setup_run(config, args)
    train_set, train_loader, train_grouper = setup_dataset(config, args)
    print(train_set, train_loader)
    model, model_module, criterion, optimizer, scheduler = setup_training(config, args, train_loader)
    if config.eval_only and config.algorithm not in supported.ssl_algorithms:
        evaluate_algorithm(config, args, train_grouper)
        finish_run(args)
        return

    featurizer, rep_dim, start_epoch = get_featurizer(config, args, model, model_module, optimizer)
    if config.eval_only and config.final_aug:
        print('augmenting the featurizer')
        from SSL.model_aug import Augmenter
        augmenter = Augmenter(config, copy=False)
        augmenter.augment(featurizer)

    validation = Validation(args, config, featurizer, rep_dim=rep_dim, grouper=train_grouper, eval_set=config.eval_set, train_set=config.train_set, extra_evals=config.extra_vals)

    if config.adapt and not config.eval_only:
        # Evaluate a copy of the featurizer with the adapter dropped as a baseline.
        from copy import deepcopy
        initial_model = deepcopy(featurizer)
        from SSL.model_factory import drop_adapter
        drop_adapter(config, initial_model)
        val2 = Validation(args, config, initial_model, rep_dim=model_module.rep_dim, grouper=train_grouper, eval_set=config.eval_set, train_set=config.train_set)
        val2.eval(-1)

    if not config.eval_only:
        # best_acc / best_path_dir are only needed (and only defined) for training.
        terminate, best_acc, best_path_dir = evaluate_best(config, model, validation)
        if terminate:
            finish_run(args)
            return

    if config.eval_fullgrid or config.eval_norm:
        # Guarded like every other wandb call: wandb is only initialised when
        # config.use_wandb is set.
        if config.use_wandb:
            wandb.log({'eval_epoch': start_epoch})
        eval_grid(config, validation, featurizer, train_grouper, grid=None)
        if not config.eval_connectivity:
            finish_run(args)
            return

    if config.eval_connectivity:
        eval_connectivity(config, train_grouper, featurizer)
        finish_run(args)
        return

    if config.eval_only:
        evaluate_model(config, args, model, validation)
        finish_run(args)
        return

    # stop model augmentations if they should be applied after a certain epoch
    true_model_aug = config.model_aug
    if config.aug_epoch != -1:
        assert config.model_aug
        config.model_aug = False

    acc = 0
    # Keeps `epoch` defined for the final validation when the loop body never runs.
    epoch = start_epoch
    for epoch in range(start_epoch, args.n_epochs+1):
        if true_model_aug and epoch > config.aug_epoch:
            config.model_aug = True

        adjust_learning_rate(optimizer, epoch, args)
        print("Training...")

        # train for one epoch
        train_loss = train(config, train_loader, model, criterion, optimizer, epoch, args, scheduler=scheduler)
        if config.use_wandb:
            wandb.log({'Loss/train': train_loss})

        if config.model_aug and config.update_aug == 'epoch':
            model_module.update_tr_encoder()

        if epoch % args.eval_freq == 0:
            is_best, res = validation.eval(epoch)
            acc = res['lin'][0]
            # A new best is only recorded after epoch 50.
            if acc > best_acc and epoch > 50:
                best_acc = acc

                save_checkpoint(args, epoch, model, optimizer, best_acc,
                                best_path_dir,
                                'Saving the best model!')

                with open(f'{config.log_ckpts_dir}/config.json', 'w') as fp:
                    json.dump(vars(config), fp,  indent=4)

        # save the model
        if epoch % args.save_freq == 0:
            save_checkpoint(args, epoch, model, optimizer, acc,
                            path.join(config.log_ckpts_dir, 'ckpt.pth.tar'),
                            'Saving...')
            if epoch % 100 == 0:
                if config.dataset in ['bgchallenge', 'hard_imagenet']:
                    save_checkpoint(args, epoch, model, optimizer, acc,
                                    path.join(config.log_ckpts_dir, f'ckpt_{epoch}.pth.tar'),
                                    'Saving...')
            # NOTE(review): this is the only place that uses config.ckpts_dir
            # rather than config.log_ckpts_dir - confirm it is intentional.
            save_checkpoint(args, epoch, model, optimizer, acc,
                            path.join(config.ckpts_dir, f'ckpt_last.pth.tar'),
                            'Saving...')

    print('Best accuracy:', best_acc)
    # final validation
    validation.eval(epoch)
    # save model
    save_checkpoint(args, epoch, model, optimizer, acc,
                    path.join(config.log_ckpts_dir, 'ckpt_last.pth.tar'),
                    'Saving the model at the last epoch.')
    # The periodic checkpoint is removed once the final one has been written.
    if path.exists(path.join(config.log_ckpts_dir, 'ckpt.pth.tar')):
        os.remove(path.join(config.log_ckpts_dir, 'ckpt.pth.tar'))

    finish_run(args)
    return

def train(config, train_loader, model, criterion, optimizer, epoch, args, scheduler=None):
    """Train ``model`` for one epoch and return the average loss.

    Args:
        config: run configuration; reads ``model_aug`` and ``update_aug``.
        train_loader: iterable of batches whose first element holds the two
            augmented views.
        model: model called as ``model(view_1, view_2)``.
        criterion: loss applied to the model output.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch, used only in the progress prefix.
        args: parsed CLI namespace; reads ``print_freq``.
        scheduler: optional LR scheduler, stepped once per batch.

    Returns:
        The running average of the loss over the epoch.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, losses],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    # `model` is only wrapped in DataParallel on multi-GPU hosts; unwrap it
    # when needed so update_tr_encoder() also works on a plain module.
    model_module = getattr(model, 'module', model)

    end = time.time()
    for i, data in enumerate(train_loader):
        images = data[0]
        images[0] = images[0].cuda(non_blocking=True)
        images[1] = images[1].cuda(non_blocking=True)

        if config.model_aug and config.update_aug == 'batch':
            model_module.update_tr_encoder()

        # compute output
        outs = model(images[0], images[1])
        loss = criterion(outs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # record the loss and measure elapsed time
        losses.update(loss.item(), images[0].size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    return losses.avg


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
