
import argparse
import builtins
import copy
import os
import random
import shutil
import time
import warnings
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.functional as F
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from dataset import ImageDataset
from easydict import EasyDict as edict
from resnet import resnet18, resnet50,  load_state_dict_from_url, model_urls, wide_resnet50_2
from timm_model_wrapper import TimmWrapper
import timm
from pycls import models as pycls_models
import numpy as np
import json
import torch.optim as optim
from sklearn import metrics
import wilds_datasets


from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


# Values accepted by ``--arch``. 'resnet50' is also the parser default; it is
# listed here so it can be selected explicitly too (argparse never validates a
# default against ``choices``, so without it ``-a resnet50`` would be rejected
# even though it is the default).
model_names = ['regnet', 'resnet', 'resnet_timm', 'meal_v2', 'resnet50']
# Values accepted by ``--dataset``; main_worker builds the domain splits for each.
dataset_names = ['domainnet', 'terraincognita', 'officehome', 'pacs', 'vlcs',  'wilds_fmow']

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('imagenet_train_dir', metavar='DIR',
                    help='path to dataset')
parser.add_argument('imagenet_val_dir', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet50)')
parser.add_argument('--dataset', default='domainnet', choices=dataset_names,
                    help='which dataset to use')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: 32)')
parser.add_argument('--eval_steps', default=-1, type=int,
                    help='number of eval_steps')
parser.add_argument('--epochs', default=5, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--steps', default=5000, type=int, metavar='N',
                    help='number of total steps to run')
parser.add_argument('--linear-steps', default=-1, type=int, metavar='N',
                    help='number of total steps to run')
parser.add_argument('--sma-start-iter', default=100, type=int, metavar='N',
                    help='Where to start model averaging.')
parser.add_argument('--accum-iter', default=1, type=int, metavar='N',
                    help='number of steps between updates')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=32, type=int,
                    metavar='N',
                    help='mini-batch size (default: 32), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--batch-size-list', default=None, type=int, nargs='*', help = 'per dataset batchsize')
parser.add_argument('--resample-batch-size', dest='resample_batch_size', action='store_true')
parser.add_argument('--lr', '--learning-rate', default=5e-5, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--wd', '--weight-decay', default=0.0, type=float,
                    metavar='W', help='weight decay (default: 0.)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str,  metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

# Domain selection: each name must be a domain key of the chosen --dataset
# (main_worker looks them up in its per-dataset ``dataset_dict``).
parser.add_argument('--training_data', default=['sketch', 'real'], type=str, nargs='*', help= 'training subsets')
parser.add_argument('--validation_data', default=['painting'], type=str, nargs='*',  help = 'testing subsets')
parser.add_argument('--validation_mixture', default=None, type=int, nargs='*', help = 'testing subset mixture')


parser.add_argument('--save_name', default='', type=str,
                    help='name of saved checkpoint')
parser.add_argument('--save-dir', default='model_checkpoints/', type=str,
                    help='directory in which checkpoints are saved')
parser.add_argument('--importance-loss-weight', default=0.0, type=float, dest='importance_loss_weight')
parser.add_argument('--alpha', default=0.0, type=float, dest='alpha')
parser.add_argument('--second-order', dest='second_order', action='store_true')
parser.add_argument('--freeze-bn', dest='freeze_bn', action='store_true')
parser.add_argument('--unknown', dest='unknown', action='store_true')
parser.add_argument('--num-mixtures', default=2, type=int,
                    help='number of models mixed')
parser.add_argument('--num-splits', default=0, type=int,
                    help='how many splits of data')
parser.add_argument('--split-idx', default=0, type=int,
                    help='which split of data')
parser.add_argument('--save-freq', default=-1, type=int,
                    help='how often to save checkpoints in steps')
parser.add_argument('--num-checkpoints', default=1, type=int, help='how many to save')
parser.add_argument('--train-val-split', default=-1, type=float, help='how much to split train val')

parser.add_argument('--projection-head', dest='projection_head', action='store_true')
parser.add_argument('--pretrained', dest='pretrained', action='store_true')
parser.add_argument('--wide', dest='wide', action='store_true')
parser.add_argument('--small', dest='small', action='store_true')
# Strong augmentation is ON by default (set_defaults below), so ``--strong-aug``
# alone changes nothing; it is kept for backward compatibility and
# ``--no-strong-aug`` is the switch that actually turns it off.
parser.add_argument('--strong-aug', dest='strong_aug', action='store_true')
parser.add_argument('--no-strong-aug', dest='strong_aug', action='store_false',
                    help='disable the strong training augmentation (on by default)')
parser.add_argument('--sma', dest='sma', action='store_true')
parser.set_defaults(unknown=False)
parser.set_defaults(second_order=False)
parser.set_defaults(strong_aug=True)
parser.set_defaults(pretrained=False)
parser.set_defaults(sma=False)
parser.set_defaults(freeze_bn=False)
parser.set_defaults(resample_batch_size=False)
parser.set_defaults(wide=False)
parser.set_defaults(small=False)
parser.set_defaults(projection_head=False)

# Declared ``global`` in main_worker (presumably the best top-1 accuracy seen so far).
best_acc1 = 0

class MovingAvg:
    """Maintain a running simple (equal-weight) average of a network's parameters.

    ``network_sma`` starts as a deep copy of ``network``. Every call to
    :meth:`update_sma` advances an iteration counter and refreshes the
    parameters of ``network_sma``:

    * while averaging is inactive (``ema`` is False, or fewer than
      ``sma_start_iter`` calls have been made) they are set to a copy of the
      live network's current parameter values;
    * once active, each call folds the live parameters into a cumulative mean.

    Only ``parameters()`` are touched; buffers of ``network_sma`` (e.g.
    BatchNorm running statistics) keep whatever values the deep copy had.

    Args:
        network: the model being trained.
        ema: despite its name, enables the *simple* running average described
            above; when False, ``network_sma`` just mirrors ``network``.
        sma_start_iter: 1-based count of ``update_sma`` calls at which
            averaging begins.
    """

    def __init__ (self, network, ema = False, sma_start_iter=100):
        self.network = network
        self.network_sma = copy.deepcopy(network)
        self.sma_start_iter = sma_start_iter
        self.global_iter = 0
        self.sma_count = 0
        self.ema = ema

    def update_sma(self):
        """Advance one iteration and refresh the parameters of ``network_sma``."""
        self.global_iter += 1
        if self.global_iter >= self.sma_start_iter and self.ema :
            self.sma_count += 1
            for param_q, param_k in zip (self.network.parameters(), self.network_sma.parameters()):
                # param_k holds the mean of the previous sma_count snapshots;
                # fold in the current weights to get the mean of sma_count + 1.
                param_k.data = (param_k.data * self.sma_count + param_q.data )/(1.+ self.sma_count)
        else:
            for param_q,param_k in zip (self.network.parameters(), self.network_sma.parameters()):
                # Clone instead of rebinding ``.data``: plain assignment would
                # make param_k share storage with param_q, so subsequent
                # in-place optimizer steps would silently rewrite the "snapshot"
                # and corrupt the first averaging step. Cloning still places the
                # copy on param_q's device/dtype, like the original assignment.
                param_k.data = param_q.data.clone()


def main():
    """Parse CLI arguments, apply seeding/distributed settings and launch training.

    With ``--multiprocessing-distributed``, one ``main_worker`` process is
    spawned per visible CUDA device and ``args.world_size`` is rescaled from
    "number of nodes" to "total number of processes". Otherwise
    ``main_worker`` runs in this process with ``args.gpu``.

    Raises:
        KeyError: ``--dist-url env://`` without ``--world-size`` and no
            ``WORLD_SIZE`` environment variable.
        RuntimeError: ``--multiprocessing-distributed`` with no visible CUDA
            device (it would otherwise spawn zero workers and exit silently).
    """
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        # numpy is used by this script as well; seed it so numpy-based
        # randomness is reproducible too. np.random.seed only accepts values
        # in [0, 2**32), so fold the (possibly negative) seed into that range.
        np.random.seed(args.seed % (2 ** 32))
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        if ngpus_per_node == 0:
            raise RuntimeError('--multiprocessing-distributed requires at least '
                               'one visible CUDA device, but none were found')
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*args):
            pass
        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))


    # Data loading code
    traindir = args.imagenet_train_dir
    valdir = args.imagenet_val_dir
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]))

    if args.strong_aug:
       train_transform = transforms.Compose([
            # transforms.Resize((224,224)),
            transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
            transforms.RandomGrayscale(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])


       val_transform = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                normalize,
            ])

    else:

        train_transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        val_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])


    if args.dataset == 'domainnet':

        transform_dict = {'sketch': train_transform,
            'real': train_transform,
            'clipart': train_transform,
            'infograph': train_transform,
            'quickdraw': train_transform,
            'painting': train_transform
            }

        #transform_dict[args.validation_data] = val_transform

        train_datasets = []
        val_datasets = []
        train_dataset_lengths = []
        test_dataset = None

        for d in args.validation_data:
            transform_dict[d] = val_transform


        sketch_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/domainnet/sketch', transform = transform_dict['sketch'])
        real_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/domainnet/real', transform = transform_dict['real'])
        clipart_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/domainnet/clipart', transform = transform_dict['clipart'])
        infograph_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/domainnet/infograph', transform = transform_dict['infograph'])
        quickdraw_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/domainnet/quickdraw', transform = transform_dict['quickdraw'])
        painting_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/domainnet/painting', transform = transform_dict['painting'])

        dataset_dict = {'sketch': sketch_dataset,
            'real': real_dataset,
            'clipart': clipart_dataset,
            'infograph': infograph_dataset,
            'quickdraw': quickdraw_dataset,
            'painting': painting_dataset,
            }

        for d in args.training_data:
            train_datasets.append(dataset_dict[d])
            train_dataset_lengths.append(len(dataset_dict[d]))

        for d in args.validation_data:
            val_datasets.append(dataset_dict[d])

        if args.train_val_split > 0:
            datasets_split_train = []
            datasets_split_val = []
            for d in train_datasets:
                lengths = [int(len(d) * args.train_val_split)]
                lengths.append(len(d) - lengths[0])
                train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                train_split.dataset = copy.copy(d)
                train_split.dataset.transform = train_transform
                datasets_split_train.append(train_split)
            for idx, d in enumerate(val_datasets):
                lengths = [int(len(d) * args.train_val_split)]
                if args.validation_mixture:
                    lengths.append(args.validation_mixture[idx])
                    lengths.append(len(d) - args.validation_mixture[idx] - lengths[0])
                    train_split, val_split, _ = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                else:
                    lengths.append(len(d) - lengths[0])
                    train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                val_split.dataset.transform = val_transform
                datasets_split_val.append(val_split)

            train_datasets = datasets_split_train
            test_dataset = datasets_split_val
            num_classes = 345

        else:

            if args.num_splits > 1:
                train_datasets_split = []
                for d in train_datasets:
                    lengths = [len(d)//args.num_splits]*(args.num_splits - 1)
                    lengths.append(len(d) - sum(lengths))
                    train_datasets_split.append(torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))[args.split_idx])
                train_datasets = train_datasets_split


            test_dataset = val_datasets
            num_classes = 345




    elif args.dataset == "terraincognita":

        transform_dict = {
                'location_100': train_transform,
                'location_38': train_transform,
                'location_43': train_transform,
                'location_46': train_transform
                }


        train_datasets = []
        val_datasets = []
        train_dataset_lengths = []
        test_dataset = None

        for d in args.validation_data:
            transform_dict[d] = val_transform



        location_100_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/terra_incognita/terra_incognita/location_100', transform = transform_dict['location_100'])
        location_38_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/terra_incognita/terra_incognita/location_38', transform = transform_dict['location_38'])
        location_43_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/terra_incognita/terra_incognita/location_43', transform = transform_dict['location_43'])
        location_46_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/terra_incognita/terra_incognita/location_46', transform = transform_dict['location_46'])

        dataset_dict = {'location_100': location_100_dataset,
                        'location_38': location_38_dataset,
                        'location_43': location_43_dataset,
                        'location_46': location_46_dataset}

        for d in args.training_data:
            train_datasets.append(dataset_dict[d])
            train_dataset_lengths.append(len(dataset_dict[d]))

        for d in args.validation_data:
            val_datasets.append(dataset_dict[d])


        if args.train_val_split > 0:
            datasets_split_train = []
            datasets_split_val = []
            for d in train_datasets:
                lengths = [int(len(d) * args.train_val_split)]
                lengths.append(len(d) - lengths[0])
                train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                train_split.dataset = copy.copy(d)
                train_split.dataset.transform = train_transform
                datasets_split_train.append(train_split)
            for idx, d in enumerate(val_datasets):
                lengths = [int(len(d) * args.train_val_split)]
                if args.validation_mixture:
                    lengths.append(args.validation_mixture[idx])
                    lengths.append(len(d) - args.validation_mixture[idx] - lengths[0])
                    train_split, val_split, _ = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                else:
                    lengths.append(len(d) - lengths[0])
                    train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                val_split.dataset.transform = val_transform
                datasets_split_val.append(val_split)

            train_datasets = datasets_split_train
            test_dataset = datasets_split_val
            num_classes = 10

        else:

            if args.num_splits > 1:
                train_datasets_split = []
                for d in train_datasets:
                    lengths = [len(d)//args.num_splits]*(args.num_splits - 1)
                    lengths.append(len(d) - sum(lengths))
                    train_datasets_split.append(torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))[args.split_idx])
                train_datasets = train_datasets_split


            test_dataset = val_datasets
            num_classes = 10





    elif args.dataset == "officehome":

        transform_dict = {
                'art': train_transform,
                'clipart': train_transform,
                'product': train_transform,
                'real': train_transform
                }


        train_datasets = []
        val_datasets = []
        train_dataset_lengths = []
        test_dataset = None

        for d in args.validation_data:
            transform_dict[d] = val_transform


        art_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/office_home/Art', transform = transform_dict['art'])
        clipart_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/office_home/Clipart', transform = transform_dict['clipart'])
        product_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/office_home/Product', transform = transform_dict['product'])
        real_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/office_home/Real', transform = transform_dict['real'])

        dataset_dict = {'art': art_dataset,
                        'clipart': clipart_dataset,
                        'product': product_dataset,
                        'real': real_dataset}

        for d in args.training_data:
            train_datasets.append(dataset_dict[d])
            train_dataset_lengths.append(len(dataset_dict[d]))

        for d in args.validation_data:
            val_datasets.append(dataset_dict[d])


        if args.train_val_split > 0:
            datasets_split_train = []
            datasets_split_val = []
            for d in train_datasets:
                lengths = [int(len(d) * args.train_val_split)]
                lengths.append(len(d) - lengths[0])
                train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                train_split.dataset = copy.copy(d)
                train_split.dataset.transform = train_transform
                datasets_split_train.append(train_split)
            for idx, d in enumerate(val_datasets):
                lengths = [int(len(d) * args.train_val_split)]
                if args.validation_mixture:
                    lengths.append(args.validation_mixture[idx])
                    lengths.append(len(d) - args.validation_mixture[idx] - lengths[0])
                    train_split, val_split, _ = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                else:
                    lengths.append(len(d) - lengths[0])
                    train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                val_split.dataset.transform = val_transform
                datasets_split_val.append(val_split)

            train_datasets = datasets_split_train
            test_dataset = datasets_split_val
            num_classes = 65

        else:

            if args.num_splits > 1:
                train_datasets_split = []
                for d in train_datasets:
                    lengths = [len(d)//args.num_splits]*(args.num_splits - 1)
                    lengths.append(len(d) - sum(lengths))
                    train_datasets_split.append(torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))[args.split_idx])
                train_datasets = train_datasets_split


            test_dataset = val_datasets
            num_classes = 65


    elif args.dataset == "pacs":

        transform_dict = {
                'art_painting': train_transform,
                'cartoon': train_transform,
                'photo': train_transform,
                'sketch': train_transform
                }


        train_datasets = []
        val_datasets = []
        train_dataset_lengths = []
        test_dataset = None

        for d in args.validation_data:
            transform_dict[d] = val_transform


        art_painting_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/PACS/art_painting', transform = transform_dict['art_painting'])
        cartoon_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/PACS/cartoon', transform = transform_dict['cartoon'])
        photo_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/PACS/photo', transform = transform_dict['photo'])
        sketch_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/PACS/sketch', transform = transform_dict['sketch'])

        dataset_dict = {'art_painting': art_painting_dataset,
                        'cartoon': cartoon_dataset,
                        'photo': photo_dataset,
                        'sketch': sketch_dataset}

        for d in args.training_data:
            train_datasets.append(dataset_dict[d])
            train_dataset_lengths.append(len(dataset_dict[d]))

        for d in args.validation_data:
            val_datasets.append(dataset_dict[d])

        if args.train_val_split > 0:
            datasets_split_train = []
            datasets_split_val = []
            for d in train_datasets:
                lengths = [int(len(d) * args.train_val_split)]
                lengths.append(len(d) - lengths[0])
                train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                train_split.dataset = copy.copy(d)
                train_split.dataset.transform = train_transform
                datasets_split_train.append(train_split)
            for idx, d in enumerate(val_datasets):
                lengths = [int(len(d) * args.train_val_split)]
                if args.validation_mixture:
                    lengths.append(args.validation_mixture[idx])
                    lengths.append(len(d) - args.validation_mixture[idx] - lengths[0])
                    train_split, val_split, _ = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                else:
                    lengths.append(len(d) - lengths[0])
                    train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                val_split.dataset.transform = val_transform
                datasets_split_val.append(val_split)

            train_datasets = datasets_split_train
            test_dataset = datasets_split_val
            num_classes = 7
        else:

            if args.num_splits > 1:
                train_datasets_split = []
                for d in train_datasets:
                    lengths = [len(d)//args.num_splits]*(args.num_splits - 1)
                    lengths.append(len(d) - sum(lengths))
                    train_datasets_split.append(torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))[args.split_idx])
                train_datasets = train_datasets_split


            test_dataset = val_datasets
            num_classes = 7

    elif args.dataset == "vlcs":

        transform_dict = {
                'caltech101': train_transform,
                'labelme': train_transform,
                'sun09': train_transform,
                'voc2007': train_transform
                }


        train_datasets = []
        val_datasets = []
        train_dataset_lengths = []
        test_dataset = None

        for d in args.validation_data:
            transform_dict[d] = val_transform


        caltech101_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/VLCS/Caltech101', transform = transform_dict['caltech101'])
        labelme_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/VLCS/LabelMe', transform = transform_dict['labelme'])
        sun09_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/VLCS/SUN09', transform = transform_dict['sun09'])
        voc2007_dataset = datasets.ImageFolder('/projectnb/REDACTED/REDACTEDt/data/VLCS/VOC2007', transform = transform_dict['voc2007'])

        dataset_dict = {'caltech101': caltech101_dataset,
                        'labelme': labelme_dataset,
                        'sun09': sun09_dataset,
                        'voc2007': voc2007_dataset}

        for d in args.training_data:
            train_datasets.append(dataset_dict[d])
            train_dataset_lengths.append(len(dataset_dict[d]))

        for d in args.validation_data:
            val_datasets.append(dataset_dict[d])

        if args.train_val_split > 0:
            datasets_split_train = []
            datasets_split_val = []
            for d in train_datasets:
                lengths = [int(len(d) * args.train_val_split)]
                lengths.append(len(d) - lengths[0])
                train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                train_split.dataset = copy.copy(d)
                train_split.dataset.transform = train_transform
                datasets_split_train.append(train_split)
            for idx, d in enumerate(val_datasets):
                lengths = [int(len(d) * args.train_val_split)]
                if args.validation_mixture:
                    lengths.append(args.validation_mixture[idx])
                    lengths.append(len(d) - args.validation_mixture[idx] - lengths[0])
                    train_split, val_split, _ = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                else:
                    lengths.append(len(d) - lengths[0])
                    train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                val_split.dataset.transform = val_transform
                datasets_split_val.append(val_split)

            train_datasets = datasets_split_train
            test_dataset = datasets_split_val
            num_classes = 5
        else:

            if args.num_splits > 1:
                train_datasets_split = []
                for d in train_datasets:
                    lengths = [len(d)//args.num_splits]*(args.num_splits - 1)
                    lengths.append(len(d) - sum(lengths))
                    train_datasets_split.append(torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))[args.split_idx])
                train_datasets = train_datasets_split


            test_dataset = val_datasets
            num_classes = 5


    elif args.dataset == "wilds_fmow":


        train_datasets = []
        val_datasets = []
        train_dataset_lengths = []
        test_dataset = None

        datasets_list = wilds_datasets.get_fmow(train_transform, val_transform, args.validation_data)

        region0_dataset = datasets_list[0]
        region1_dataset = datasets_list[1]
        region2_dataset = datasets_list[2]
        region3_dataset = datasets_list[3]
        region4_dataset = datasets_list[4]
        region5_dataset = datasets_list[5]


        dataset_dict = {'region0': region0_dataset,
                        'region1': region1_dataset,
                        'region2': region2_dataset,
                        'region3': region3_dataset,
                        'region4': region4_dataset,
                        'region5': region5_dataset}

        for d in args.training_data:
            train_datasets.append(dataset_dict[d])
            train_dataset_lengths.append(len(dataset_dict[d]))

        for d in args.validation_data:
            val_datasets.append(dataset_dict[d])


        if args.train_val_split > 0:
            datasets_split_train = []
            datasets_split_val = []
            for d in train_datasets:
                lengths = [int(len(d) * args.train_val_split)]
                lengths.append(len(d) - lengths[0])
                train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                train_split.dataset = copy.copy(d)
                train_split.dataset.transform = train_transform
                datasets_split_train.append(train_split)
            for idx, d in enumerate(val_datasets):
                lengths = [int(len(d) * args.train_val_split)]
                if args.validation_mixture:
                    lengths.append(args.validation_mixture[idx])
                    lengths.append(len(d) - args.validation_mixture[idx] - lengths[0])
                    train_split, val_split, _ = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                else:
                    lengths.append(len(d) - lengths[0])
                    train_split, val_split = torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))
                val_split.dataset.transform = val_transform
                datasets_split_val.append(val_split)

            train_datasets = datasets_split_train
            test_dataset = datasets_split_val
            num_classes = 62
        else:

            if args.num_splits > 1:
                train_datasets_split = []
                for d in train_datasets:
                    lengths = [len(d)//args.num_splits]*(args.num_splits - 1)
                    lengths.append(len(d) - sum(lengths))
                    train_datasets_split.append(torch.utils.data.random_split(d, lengths, torch.Generator().manual_seed(42))[args.split_idx])
                train_datasets = train_datasets_split


            test_dataset = val_datasets
            num_classes = 62




    if args.batch_size_list is None:
        args.batch_size_list = [args.batch_size] * len(train_datasets)


    train_loader = [torch.utils.data.DataLoader(
      train_dataset, batch_size=bs, shuffle=True,
      num_workers=args.workers, pin_memory=True, drop_last=True) for bs, train_dataset in zip(args.batch_size_list,train_datasets)]

    val_loader = torch.utils.data.DataLoader(
      torch.utils.data.ConcatDataset(test_dataset), batch_size=args.batch_size, shuffle=True,
      num_workers=args.workers, pin_memory=True, drop_last=False)






    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    print("=> creating model '{}'".format(args.arch))
    # model = models.__dict__[args.arch]()
    if args.arch == 'resnet50':
        if args.wide:
            model = wide_resnet50_2(pretrained=args.pretrained, num_classes=num_classes, freeze_bn = args.freeze_bn, projection_head = args.projection_head)
        elif args.small:
            model = resnet18(pretrained=args.pretrained, num_classes=num_classes, freeze_bn = args.freeze_bn, projection_head = args.projection_head)
        else:
            model = resnet50(pretrained=args.pretrained, num_classes=num_classes, freeze_bn = args.freeze_bn, projection_head = args.projection_head)
    elif args.arch == 'resnet_timm':
        model = TimmWrapper(timm.create_model('resnet50', pretrained=True, num_classes=num_classes), args.freeze_bn)
    elif args.arch == 'meal_v2':
        model = timm.create_model('resnet50', pretrained=True, num_classes=num_classes)
        state_dict  = {k.split('module.')[-1]:v for (k,v) in torch.load('model_checkpoints/meal_v2/MEALV2_ResNet50_224.pth').items()}
        del state_dict['fc.weight']
        del state_dict['fc.bias']
        model.load_state_dict(state_dict, strict=False)
    else:
        model = pycls_models.regnety("16GF", pretrained=False, cfg_list=("MODEL.NUM_CLASSES", num_classes))
        if args.pretrained:
            checkpoint = torch.load('model_checkpoints/regnet/RegNetY-16GF_dds_8gpu.pyth')
            del checkpoint['model_state']['head.fc.weight']
            del checkpoint['model_state']['head.fc.bias']
            model.load_state_dict(checkpoint['model_state'], strict=False)


    print(model)
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('Num params:{}'.format(params))

    # freeze all layers but the last fc
    # for name, param in model.named_parameters():
    #     if name not in ['fc.weight', 'fc.bias']:
    #         param.requires_grad = False
    # # init the fc layer
    # model.fc.weight.data.normal_(mean=0.0, std=0.01)
    # model.fc.bias.data.zero_()

    # load from pre-trained, before DistributedDataParallel constructor

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    # criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # optimize only the linear classifier
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    # assert len(parameters) == 2  # fc.weight, fc.bias
    # optimizer = torch.optim.SGD(parameters, args.lr,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)

    criterion = nn.CrossEntropyLoss().cuda(args.gpu)



    optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                weight_decay=args.weight_decay)

    linear_parameters = []

    for n, p in model.named_parameters():
        if 'fc' in n:
            linear_parameters.append(p)

    linear_optimizer = torch.optim.Adam(linear_parameters, args.lr, weight_decay = args.weight_decay)
   # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            #args.start_epoch = checkpoint['epoch']
            #best_acc1 = checkpoint['best_acc1']
            # if args.gpu is not None:
            #     # best_acc1 may be from a checkpoint from a different GPU
            #     best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))




    cudnn.benchmark = True


    model = MovingAvg(model, args.sma, args.sma_start_iter)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    epochs = int(args.steps/len(train_loader))
    epoch = 0
    steps = 0
    save_iterate = 0

    #for epoch in range(args.start_epoch, epochs):
    while True:
        if steps > args.steps:
            break
        if args.distributed:
            train_sampler.set_epoch(epoch)

        steps, save_iterate = train(train_loader, val_loader,  model, criterion,  optimizer, linear_optimizer, epoch, args, steps, save_iterate)
        epoch = epoch + 1

    acc1 = validate(val_loader, model, criterion, args)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)

    if not args.multiprocessing_distributed or (args.multiprocessing_distributed
        and args.rank % ngpus_per_node == 0):
        save_name = '{}_{}'.format(args.save_name, save_iterate)
        save_checkpoint({
            'epoch': epoch,
            'arch': args.arch,
            'state_dict': model.network_sma.state_dict(),
            'best_acc1': best_acc1,
            'optimizer' : optimizer.state_dict(),
        }, is_best,save_name=save_name, save_dir = args.save_dir, save_iterate=save_iterate, num_checkpoints = args.num_checkpoints, args=args)




def _next_aux_batches(aux_iter_list, aux_loaders):
    """Draw one batch from every auxiliary loader, restarting exhausted ones.

    ``aux_iter_list[idx]`` is replaced in place by a fresh iterator over
    ``aux_loaders[idx]`` when it raises StopIteration.

    Returns:
        (images_list, targets_list), one entry per auxiliary loader, in order.
    """
    images_list = []
    targets_list = []
    for idx, aux_iter in enumerate(aux_iter_list):
        try:
            aux_images, aux_target = next(aux_iter)
        except StopIteration:
            aux_iter_list[idx] = iter(aux_loaders[idx])
            aux_images, aux_target = next(aux_iter_list[idx])
        images_list.append(aux_images)
        targets_list.append(aux_target)
    return images_list, targets_list


def _resample_batch_sizes(loaders, base_batch_size):
    """Randomly re-partition a fixed total batch size across ``loaders``.

    The total is ``base_batch_size * len(loaders)``; it is cut at
    ``len(loaders) - 1`` random distinct points, and each loader's
    ``batch_sampler.batch_size`` is set to the size of its chunk.
    NOTE(review): uses a single base batch size even when per-loader sizes
    differ (args.batch_size_list) — confirm this is intended.
    """
    total_batch_size = base_batch_size * len(loaders)
    num_splits = len(loaders) - 1  # num chunks -1
    batch_splits = list(range(total_batch_size))[1:]
    perm_splits = np.random.permutation(batch_splits)
    split_idx = np.sort(perm_splits[:num_splits])
    split_batch = np.split(range(total_batch_size), split_idx)
    for loader, chunk in zip(loaders, split_batch):
        loader.batch_sampler.batch_size = len(chunk)


def train(train_loader, val_loader, model, criterion, optimizer, linear_optimizer, epoch, args, steps, save_iterate):
    """Run one pass over the largest training loader, mixing in the others.

    Each iteration takes one batch from the largest loader in ``train_loader``
    (the "main" loader) and one batch from every other loader (restarting any
    that run out). It concatenates them and takes an optimisation step on
    ``model.network``. ``linear_optimizer`` is used while
    ``steps <= args.linear_steps`` and ``optimizer`` afterwards. When
    ``args.save_freq > 0``, the model is validated every ``args.save_freq``
    steps and ``model.network_sma`` is checkpointed.

    Args:
        train_loader: list of DataLoaders, one per training domain.
        val_loader: DataLoader passed to ``validate``.
        model: wrapper exposing ``network``, ``network_sma`` and
            ``update_sma()``.
        criterion: loss applied to ``(logits, target)``.
        optimizer: optimizer over all of the model's parameters.
        linear_optimizer: optimizer used during the first
            ``args.linear_steps`` steps (presumably covers only the
            classifier head).
        epoch: epoch index, used for logging and stored in checkpoints.
        args: parsed command-line arguments.
        steps: global step counter on entry.
        save_iterate: index of the next checkpoint to write.

    Returns:
        Updated ``(steps, save_iterate)``. Returns early as soon as ``steps``
        exceeds ``args.steps``.
    """
    global best_acc1
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # switch to train mode
    model.network.train()
    model.network_sma.train()

    # The largest loader drives the epoch; the remaining ("auxiliary")
    # loaders are cycled so every step sees one batch from each of them.
    train_loader_epoch = train_loader.copy()
    train_loader_main_idx = np.argmax([len(d) for d in train_loader_epoch])
    train_loader_main = train_loader_epoch.pop(train_loader_main_idx)
    aux_iter_list = [iter(aux_loader) for aux_loader in train_loader_epoch]
    end = time.time()

    progress = ProgressMeter(
        len(train_loader_main),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    for i, (images, target) in enumerate(train_loader_main):
        if steps > args.steps:
            return steps, save_iterate
        if steps > args.linear_steps:
            selected_optimizer = optimizer
        else:
            selected_optimizer = linear_optimizer

        steps = steps + 1
        # measure data loading time
        data_time.update(time.time() - end)
        aux_images_list, aux_target_list = _next_aux_batches(aux_iter_list, train_loader_epoch)

        if args.gpu is not None or torch.cuda.is_available():
            images = images.cuda(args.gpu, non_blocking=True)
            aux_images_list = [x.cuda(args.gpu, non_blocking=True) for x in aux_images_list]
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)
            aux_target_list = [t.cuda(args.gpu, non_blocking=True) for t in aux_target_list]

        if args.resample_batch_size:
            _resample_batch_sizes([train_loader_main] + train_loader_epoch, args.batch_size)

        images = torch.concat([images] + aux_images_list, dim=0)
        target = torch.concat([target] + aux_target_list)

        # compute output
        output = model.network(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and step (gradients are accumulated over
        # args.accum_iter iterations, so scale the loss accordingly)
        loss = loss / args.accum_iter
        loss.backward()

        if ((i + 1) % args.accum_iter == 0) or (i + 1 == len(train_loader_main)):
            selected_optimizer.step()
            # Clear gradients of *every* parameter, not just those owned by
            # the selected optimizer. During the linear phase backward() still
            # fills .grad for the backbone. With only
            # linear_optimizer.zero_grad(), those gradients would pile up and
            # be applied all at once by the full optimizer's first step.
            selected_optimizer.zero_grad()
            if selected_optimizer is not optimizer:
                optimizer.zero_grad()
            model.update_sma()
        if not args.freeze_bn:
            # The output is discarded. This presumably exists to update the
            # SMA network's BatchNorm running statistics (it is skipped when
            # freeze_bn is set). That needs no autograd graph.
            with torch.no_grad():
                model.network_sma(images)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

        if args.save_freq > 0 and (steps % args.save_freq == 0):

            # evaluate on validation set
            acc1 = validate(val_loader, model, criterion, args)
            # switch back to train mode (validate() switches to eval)
            model.network.train()
            model.network_sma.train()

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            # NOTE(review): ngpus_per_node is neither a parameter nor a local
            # here; unless it exists as a module-level global, this raises
            # NameError when args.multiprocessing_distributed is set — confirm.
            if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
                save_name = '{}_{}'.format(args.save_name, save_iterate)
                save_checkpoint({
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.network_sma.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer' : optimizer.state_dict(),
                }, is_best,save_name=save_name, save_dir = args.save_dir, save_iterate=save_iterate, num_checkpoints = args.num_checkpoints, args=args)

                save_iterate = save_iterate + 1

    return steps, save_iterate


def validate(val_loader, model, criterion, args):
    """Evaluate ``model.network_sma`` on ``val_loader`` and return top-1 accuracy.

    Both ``model.network`` and ``model.network_sma`` are switched to eval
    mode; the caller is responsible for switching them back. Progress is
    printed every ``args.print_freq`` batches, plus a summary at the end.

    Args:
        val_loader: iterable of ``(images, target)`` batches.
        model: wrapper exposing ``network`` and ``network_sma`` modules.
        criterion: loss applied to ``(logits, target)``.
        args: parsed arguments; reads ``gpu``, ``eval_steps`` and
            ``print_freq``.

    Returns:
        ``top1.avg``: the sample-weighted mean top-1 accuracy in percent
        (integer 0 if no batch was evaluated).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.network.eval()
    model.network_sma.eval()
    end = time.time()
    # Evaluation never calls backward(). Disabling autograd avoids building
    # (and discarding) a graph for every batch.
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            # When eval_steps > 1, stop after batch index eval_steps
            # (i.e. eval_steps + 1 batches are evaluated).
            if i > args.eval_steps and args.eval_steps > 1:
                break
            # Same placement rule as train(): move inputs whenever a GPU is
            # in use, not only when args.gpu is set explicitly.
            if args.gpu is not None or torch.cuda.is_available():
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output with the moving-average network
            output = model.network_sma(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    progress.display_summary()

    return top1.avg

def save_checkpoint(state, is_best,save_name, save_dir, save_iterate = None, num_checkpoints=None, args=None):
    """Serialize ``state`` to ``<save_dir>checkpoint_<save_name>.pth.tar``.

    If ``is_best`` is true, the file is also copied to
    ``<save_dir>model_best_<args.save_name>.pth.tar``. When ``args`` is not
    given, ``save_name`` is used instead of ``args.save_name``; previously
    this path raised AttributeError.

    ``save_dir`` is used as a plain string prefix, so it must end with a path
    separator when it names a directory. ``save_iterate`` and
    ``num_checkpoints`` are accepted for interface compatibility but unused.
    """
    filename = save_dir + 'checkpoint_' + str(save_name) + '.pth.tar'
    torch.save(state, filename)
    if is_best:
        best_name = args.save_name if args is not None else save_name
        best_filename = save_dir + 'model_best_' + str(best_name) + '.pth.tar'
        shutil.copyfile(filename, best_filename)


def map(submission_array, gt_array):
    """Compute per-class average precision and its summaries.

    Note: this name shadows the builtin ``map`` for the rest of the module.

    Args:
        submission_array: tensor of scores indexed ``[sample, class]``.
        gt_array: tensor of labels with the same layout; entries equal to 1
            are positives.

    Returns:
        ``(m_ap, w_ap, m_aps)``:
        - ``m_aps`` is the per-class AP array (NaN for classes with no
          positives).
        - ``m_ap`` is the NaN-ignoring mean of ``m_aps``.
        - ``w_ap`` is ``m_aps`` weighted per class by that class's share of
          the total label sum. It is an array, not a scalar.
    """
    submission_array = submission_array.cpu().detach().numpy()
    gt_array = gt_array.cpu().detach().numpy()
    m_aps = []
    for oc_i in range(submission_array.shape[1]):
        # Rank samples by descending score for this class.
        sorted_idxs = np.argsort(-submission_array[:, oc_i])
        tp = gt_array[:, oc_i][sorted_idxs] == 1
        n_pos = tp.sum()
        if n_pos < 0.1:
            m_aps.append(float('nan'))
            continue
        fp = np.invert(tp)
        t_pcs = np.cumsum(tp)
        f_pcs = np.cumsum(fp)
        prec = t_pcs / (f_pcs + t_pcs).astype(float)
        # AP = mean of the precision values at the ranks of the positives.
        m_aps.append(prec[tp].sum() / n_pos.astype(float))
    m_aps = np.array(m_aps)
    m_ap = np.nanmean(m_aps)
    w_ap = (m_aps * gt_array.sum(axis=0) / gt_array.sum().astype(float))
    return m_ap, w_ap, m_aps

def mAUC(submission_array, gt_array):
    """Return the mean, over classes, of the ROC AUC.

    Args:
        submission_array: tensor of scores indexed ``[sample, class]``.
        gt_array: tensor of ground-truth labels with the same layout.

    Returns:
        NaN-ignoring mean of the per-class AUC values computed with
        sklearn's ``roc_curve`` + ``auc``. Classes whose AUC comes out NaN
        (presumably those with only one label value present) do not
        contribute.
    """
    submission_array = submission_array.cpu().detach().numpy()
    gt_array = gt_array.cpu().detach().numpy()
    m_aucs = []
    for cls in range(submission_array.shape[1]):
        fpr, tpr, _ = metrics.roc_curve(gt_array[:, cls], submission_array[:, cls])
        m_aucs.append(metrics.auc(fpr, tpr))
    return np.nanmean(m_aucs)

# Which statistic AverageMeter.summary() reports: NONE prints nothing,
# AVERAGE / SUM / COUNT print the matching attribute. Values are 0..3.
Summary = Enum('Summary', ['NONE', 'AVERAGE', 'SUM', 'COUNT'], start=0)


class AverageMeter(object):
    """Keep the latest value and a count-weighted running mean of a metric.

    ``fmt`` is the format spec used by ``str()`` for ``val`` and ``avg``;
    ``summary_type`` picks which statistic ``summary()`` reports.
    """
    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        # Every statistic restarts at integer zero.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        spec = self.fmt
        template = f'{{name}} {{val{spec}}} ({{avg{spec}}})'
        return template.format_map(self.__dict__)


    def summary(self):
        """Render the statistic selected by ``summary_type``.

        Raises:
            ValueError: if ``summary_type`` is not a ``Summary`` member.
        """
        templates = (
            (Summary.NONE, ''),
            (Summary.AVERAGE, '{name} {avg:.3f}'),
            (Summary.SUM, '{name} {sum:.3f}'),
            (Summary.COUNT, '{name} {count:.3f}'),
        )
        for kind, template in templates:
            if self.summary_type is kind:
                return template.format_map(self.__dict__)
        raise ValueError('invalid summary type %r' % self.summary_type)

class ProgressMeter(object):
    """Print progress lines of the form ``<prefix>[ batch/total]<meters...>``.

    ``meters`` are objects whose ``str()`` is shown by ``display`` and whose
    ``summary()`` is shown by ``display_summary``.
    """
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        header = self.prefix + self.batch_fmtstr.format(batch)
        # Comprehension on purpose: the builtin map is shadowed in this module.
        print('\t'.join([header] + [str(meter) for meter in self.meters]))

    def display_summary(self):
        print(' '.join([" *"] + [meter.summary() for meter in self.meters]))

    def _get_batch_fmtstr(self, num_batches):
        # Right-align the running batch index to the width of the total.
        width = len(str(num_batches // 1))
        counter = '{:%dd}' % width
        return f'[{counter}/{counter.format(num_batches)}]'




def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy percentages for each ``k`` in ``topk``.

    Args:
        output: score tensor indexed ``[sample, class]``.
        target: class-index tensor with one entry per sample.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element tensor per k: the percentage of samples
        whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Shape (largest_k, n_samples): row r holds each sample's rank-r class.
        top_pred = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = top_pred.eq(target.view(1, -1).expand_as(top_pred))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]


# Script entry point: call main() only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()
