import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
mp.set_sharing_strategy('file_system')

import os
import re
import numpy as np
import time
import json
import copy
from tqdm import tqdm

import datasets
from models import get_model
from models.utils import get_transform, initialize_attention_layers
from datasets.common import get_dataloader, maybe_dictionarize
from args import get_args
from utils import *
from visualization import visualize_imgs_and_preds, visualize_kernel_weights, plot_activation_rates


def main():
    """Parse CLI arguments and dispatch to distillation, training or evaluation.

    Any of ``--kd``/``--rd``/``--crd`` selects distillation. Otherwise ``--eval``
    selects evaluation, and everything else trains. With ``args.distributed``,
    distillation and image-input training are spawned as one process per local
    GPU; evaluation and feature-cache training run in the current process.
    """
    args = get_args(verbose=True)
    set_seed(args.seed)

    ngpus_per_node = 0
    if torch.cuda.is_available():
        ngpus_per_node = torch.cuda.device_count()

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ.get("WORLD_SIZE", 1))

    if args.distributed:
        # Distributed data parallel: world_size becomes the total process count
        # (number of nodes times GPUs per node).
        args.world_size = ngpus_per_node * args.world_size

    if args.kd or args.rd or args.crd:
        if args.distributed:
            mp.spawn(distill, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
        else:
            distill(args.gpu, ngpus_per_node, args)
        return

    if args.eval:
        if args.distributed:
            evaluate(args.gpu, ngpus_per_node, args, input_key='images',
                     visualize_predictions=False, visualize_activation_rates=False)
        else:
            evaluate(args.gpu, ngpus_per_node, args, input_key='images',
                     visualize_predictions=False, ensemble=args.ensemble,
                     visualize_activation_rates=True)
        return

    input_key = 'images' if args.feature_cache_dir is None else 'features'
    if args.distributed and input_key == 'images':
        mp.spawn(train, nprocs=ngpus_per_node, args=(ngpus_per_node, args, input_key, False))
    else:
        train(args.gpu, ngpus_per_node, args, input_key=input_key, visualize_predictions=False)


def setup_model(args, load=False, finetune=False, from_scratch=False, include_top=True):
    """Build the model described by ``args`` and optionally restore its weights.

    Args:
        args: Parsed arguments. Reads ``gpu``, ``load_dir`` and
            ``classifier_load_dir`` here, plus whatever ``get_model`` reads.
        load: If true and ``args.load_dir`` is set, load weights from it. A path
            ending in ``.pt`` is first tried as a ``pre_featurizer`` state dict,
            then as a whole-model state dict. Any other path is treated as a
            directory containing ``featurizer.pt`` (and ``classifier.pt`` when
            ``include_top``).
        finetune, from_scratch, include_top: Forwarded to ``get_model``.
            ``include_top`` also gates loading of classifier weights.

    Returns:
        ``(model, device)``, with the model already moved to ``device``.
    """
    if args.gpu is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'

    model = get_model(args, finetune=finetune, from_scratch=from_scratch, include_top=include_top)

    # Checkpoints are loaded on CPU. load_state_dict copies the values into the
    # model's own tensors, and this stops every process from materializing the
    # checkpoint on the GPU it was saved from (or failing on a CPU-only host).
    if load and args.load_dir is not None:
        if args.load_dir.endswith('.pt'):
            checkpoint = remove_prefix_in_checkpoints(torch.load(args.load_dir, map_location='cpu'))
            try:
                model.pre_featurizer.load_state_dict(checkpoint)
                print(f'Loaded pre-featurizer weights from {args.load_dir}')
            except (RuntimeError, AttributeError, KeyError):
                # Not a pre-featurizer checkpoint (key/shape mismatch): treat it
                # as a full-model state dict instead.
                model.load_state_dict(checkpoint)
                print(f'Loaded model weights from {args.load_dir}')
        else:
            featurizer_path = os.path.join(args.load_dir, 'featurizer.pt')
            model.featurizer.load_state_dict(torch.load(featurizer_path, map_location='cpu'))
            if include_top:
                classifier_path = os.path.join(args.load_dir, 'classifier.pt')
                model.classification_head.load_state_dict(torch.load(classifier_path, map_location='cpu'))

    if include_top and args.classifier_load_dir is not None:
        checkpoint = remove_prefix_in_checkpoints(torch.load(args.classifier_load_dir, map_location='cpu'))
        model.classification_head.load_state_dict(checkpoint)
        print(f'Loaded classifier weights from {args.classifier_load_dir}')

    model.to(device)
    return model, device


def train(gpu, ngpus_per_node, args, input_key='features', visualize_predictions=False):
    """Train a classifier on ImageNet, checkpointing and evaluating after every epoch.

    Supports plain fine-tuning, linear probing (``args.lp``) and two-phase LP-FT
    (``args.lp_ft``: the first ``args.epochs // 2`` epochs of the run step only
    ``optimizer_lp``, and the rest step the full optimizer). Training can run on
    one device, under ``nn.DataParallel``, or as one ``DistributedDataParallel``
    process per GPU.

    Args:
        gpu: Local GPU index of this process (may be ``None``). Stored in ``args.gpu``.
        ngpus_per_node: Number of CUDA devices visible on this node.
        args: Parsed command-line arguments. Mutated in place: ``gpu``, ``rank``,
            ``batch_size``/``num_workers`` (divided per GPU under DDP) and
            ``current_epoch``.
        input_key: Batch key fed to the model. ``'images'`` enables the model's
            full forward pass. Any other value passes ``model.pre_featurizer`` to
            ``get_dataloader`` as ``image_encoder``.
        visualize_predictions: Forwarded to the evaluation after the last epoch.
    """
    args.gpu = gpu
    if args.distributed:
        if args.dist_url == 'env://' and args.rank == -1:
            args.rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
        # Convert the node rank into a global process rank.
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    model, device = setup_model(args, load=args.lp or args.resume, finetune=not args.lp)

    # Vanilla data parallel
    if ngpus_per_node > 1 and not args.distributed:
        model = nn.DataParallel(model)
        # Re-expose the wrapped model's attributes that this file reads directly.
        model.input_resolution = model.module.input_resolution
        model.pre_featurizer = model.module.pre_featurizer
        model.featurizer = model.module.featurizer
        model.classification_head = model.module.classification_head

    # Distributed data parallel
    if args.distributed:
        if torch.cuda.is_available():
            if args.gpu is not None:
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs of the current node.
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.num_workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node)
                model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            else:
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                model = nn.parallel.DistributedDataParallel(model)

        model.input_resolution = model.module.input_resolution
        model.pre_featurizer = model.module.pre_featurizer
        model.featurizer = model.module.featurizer
        model.classification_head = model.module.classification_head

    if input_key == 'images':
        aug = not args.lp
        image_encoder = None
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            model.module.set_full_forward(True)
        else:
            model.set_full_forward(True)
    else:
        aug = False
        image_encoder = model.pre_featurizer

    dataset_class = getattr(datasets, 'ImageNetTrain')
    dataset = dataset_class(get_transform(args, model.input_resolution, aug=aug), location=args.data_dir, batch_size=args.batch_size,
                            num_workers=args.num_workers, distributed=args.distributed, pin_memory=args.pin_memory)
    num_batches = len(dataset.train_loader)
    train_loader, train_sampler = get_dataloader(dataset, is_train=True, batch_size=args.batch_size, device=device, image_encoder=image_encoder,
                                  distributed=args.distributed, pin_memory=args.pin_memory)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.ls)

    params = [param for param in model.parameters() if param.requires_grad]
    optimizer = optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay)
    scheduler = cosine_lr_with_warmup(optimizer, args.lr, args.warmup_length, args.epochs * num_batches,
                                      args.lr_warm_restarts, args.restart_epochs * num_batches)

    if args.lp_ft:
        optimizer_lp = optim.AdamW(get_trainable_params(model), lr=args.lr_lp, weight_decay=args.weight_decay)

    args.current_epoch = 0

    if args.resume:
        # Continue numbering after the last integer in the checkpoint path
        # (e.g. ``..._epoch10.pt`` -> 11).
        starting_epoch = int(re.findall(r'\d+', args.load_dir)[-1]) + 1
    else:
        starting_epoch = 1
    last_epoch = args.epochs + starting_epoch - 1
    # LP-FT: the first half of this run's epochs is linear probing.
    lp_epochs = args.epochs // 2 if args.lp_ft else 0

    for epoch in range(starting_epoch, last_epoch + 1):
        epochs_done = epoch - starting_epoch
        lp_phase = epochs_done < lp_epochs
        if args.lp_ft:
            if lp_phase and epochs_done == 0:
                print('Linear probing phase.')
            elif epochs_done == lp_epochs:
                print('Fine-tuning phase.')
        active_optimizer = optimizer_lp if lp_phase else optimizer
        active_params = [p for group in active_optimizer.param_groups for p in group['params']]

        if args.distributed:
            train_sampler.set_epoch(epochs_done)
        model.train()

        for i, batch in enumerate(tqdm(train_loader)):
            start_time = time.time()
            step = i + epochs_done * num_batches
            # The LR schedule only drives the full optimizer. optimizer_lp keeps its constant LR.
            if not lp_phase:
                scheduler(step)

            batch = maybe_dictionarize(batch)
            x = batch[input_key].to(device)
            y = batch['labels'].to(device)
            data_time = time.time() - start_time

            logits = model(x)
            loss = criterion(logits, y)
            # Clear every trainable gradient, not only the active optimizer's.
            # Otherwise, during LP, backbone gradients accumulate across steps.
            optimizer.zero_grad()
            if lp_phase:
                optimizer_lp.zero_grad()
            loss.backward()
            # Clip only what is being updated, so gradients of parameters that
            # are not stepped cannot shrink the update.
            nn.utils.clip_grad_norm_(active_params, 1.0)
            active_optimizer.step()

            batch_time = time.time() - start_time
            if i % args.print_freq == 0:
                percent_complete = 100 * i / len(train_loader)
                print(
                    f"Train Epoch: {epoch}/{args.epochs+starting_epoch-1} [{percent_complete:.0f}% {i}/{len(dataset.train_loader)}]\t"
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True
                )

        # Only one process per node saves and evaluates.
        if not args.distributed or args.rank % ngpus_per_node == 0:
            if args.result_dir is not None:
                model_save_dir = os.path.join(args.result_dir, 'models')
                os.makedirs(model_save_dir, exist_ok=True)
                model_path = os.path.join(model_save_dir, f'model_{args.exp_name}_epoch{epoch}.pt')
                print('Saving model to', model_save_dir)
                torch.save(model.state_dict(), model_path)

            args.current_epoch = epoch
            # Compare against the run's actual last epoch so resumed runs also
            # get the final (optionally ensembled) evaluation.
            if epoch == last_epoch:
                evaluate(gpu, ngpus_per_node, args, model=model, input_key=input_key, visualize_predictions=visualize_predictions)
                if args.ensemble:
                    evaluate(gpu, ngpus_per_node, args, model=model, input_key=input_key, ensemble=True)
            else:
                evaluate(gpu, ngpus_per_node, args, model=model, input_key=input_key)


def distill(gpu, ngpus_per_node, args):
    """Distill a fixed teacher model into a student trained from scratch on ImageNet.

    Modes, selected by flags on ``args``:
      * ``args.kd``: logit distillation. ``KDLoss`` on the logits is mixed with
        cross-entropy on the labels, weighted by ``args.distill_weight``.
      * ``args.rd``: representation distillation. ``HintLoss`` is applied to
        the headless model outputs, optionally plus attention-weight
        (``args.attn_distill``) and hidden-state (``args.layer_distill``) terms.
      * ``args.crd``: optimizers are built, but the training loop has no CRD
        step and raises ``NotImplementedError``.

    The teacher is only ever used for forward passes (eval mode, no autograd).

    Args:
        gpu: Local GPU index of this process (may be ``None``). Stored in ``args.gpu``.
        ngpus_per_node: Number of CUDA devices visible on this node.
        args: Parsed command-line arguments. Mutated in place: ``gpu``, ``rank``,
            ``batch_size``/``num_workers`` (divided per GPU under DDP) and
            ``current_epoch``.
    """
    input_key = 'images'
    aug = True

    args.gpu = gpu
    if args.distributed:
        if args.dist_url == 'env://' and args.rank == -1:
            args.rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
        # Convert the node rank into a global process rank.
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # Representation-level distillation compares headless outputs.
    include_top = not (args.rd or args.crd)

    student_model, device = setup_model(args, load=args.resume, finetune=False, from_scratch=True, include_top=include_top)
    teacher_model, _ = setup_model(args, load=False, include_top=include_top)

    if args.oracle_norm_stats:
        # use frozen normalization statistics loaded from the teacher model
        student_model.pre_featurizer.load_norm_stats(teacher_model.pre_featurizer.model)

    if args.attn_init:
        # initialize the attention layers in the student model using the weights of the teacher model
        initialize_attention_layers(teacher_model.pre_featurizer.model.visual.transformer,
                                    student_model.pre_featurizer.model.visual.transformer)

    teacher_model.set_full_forward(True)
    student_model.set_full_forward(True)

    # Vanilla data parallel
    if ngpus_per_node > 1 and not args.distributed:
        student_model = nn.DataParallel(student_model)
        teacher_model = nn.DataParallel(teacher_model)
        student_model.input_resolution = student_model.module.input_resolution
        student_model.pre_featurizer = student_model.module.pre_featurizer
        student_model.featurizer = student_model.module.featurizer
        if include_top:
            student_model.classification_head = student_model.module.classification_head

    # Distributed data parallel (only the student is wrapped; the teacher is not trained)
    if args.distributed:
        if torch.cuda.is_available():
            if args.gpu is not None:
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs of the current node.
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.num_workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node)
                student_model = nn.parallel.DistributedDataParallel(student_model, device_ids=[args.gpu])
            else:
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                student_model = nn.parallel.DistributedDataParallel(student_model)

        student_model.input_resolution = student_model.module.input_resolution
        student_model.pre_featurizer = student_model.module.pre_featurizer
        student_model.featurizer = student_model.module.featurizer
        if include_top:
            student_model.classification_head = student_model.module.classification_head

    dataset_class = getattr(datasets, 'ImageNetTrain')
    dataset = dataset_class(get_transform(args, student_model.input_resolution, aug=aug), location=args.data_dir, batch_size=args.batch_size,
                            num_workers=args.num_workers, distributed=args.distributed, pin_memory=args.pin_memory)
    num_batches = len(dataset.train_loader)
    train_loader, train_sampler = get_dataloader(dataset, is_train=True, batch_size=args.batch_size, device=device, image_encoder=None,
                                  distributed=args.distributed, pin_memory=args.pin_memory)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.ls)

    if args.kd:
        criterion_distill = KDLoss(T=args.kd_temperature)
        params = [param for param in student_model.parameters() if param.requires_grad]
        optimizer = optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay)
        scheduler = cosine_lr_with_warmup(optimizer, args.lr, args.warmup_length, args.epochs * num_batches,
                                          args.lr_warm_restarts, args.restart_epochs * num_batches)
    elif args.rd:
        criterion_distill = HintLoss()
        params_feat = [param for param in student_model.pre_featurizer.parameters() if param.requires_grad] \
            + [param for param in student_model.featurizer.parameters() if param.requires_grad]
        params = params_feat
        optimizer_feat = optim.AdamW(params_feat, lr=args.lr, weight_decay=args.weight_decay)
        scheduler = cosine_lr_with_warmup(optimizer_feat, args.lr, args.warmup_length, args.epochs * num_batches,
                                          args.lr_warm_restarts, args.restart_epochs * num_batches)
    elif args.crd:
        # hyperparameters used in the contrastive representation distillation paper
        opt = {
            's_dim': student_model.pre_featurizer.output_dim,
            't_dim': teacher_model.pre_featurizer.output_dim,
            'feat_dim': 128,
            'nce_k': 16384,
            'nce_t': 0.07,
            'nce_m': 0.5,
            'n_data': len(dataset.train_dataset)
        }
        criterion_distill = CRDLoss(opt, device)
        params_feat = [param for param in student_model.pre_featurizer.parameters() if param.requires_grad] \
            + [param for param in student_model.featurizer.parameters() if param.requires_grad] \
            + [param for param in criterion_distill.embed_s.parameters() if param.requires_grad] \
            + [param for param in criterion_distill.embed_t.parameters() if param.requires_grad]
        params_clf = [param for param in student_model.classification_head.parameters() if param.requires_grad]
        params = params_feat + params_clf
        optimizer_feat = optim.AdamW(params_feat, lr=args.lr, weight_decay=args.weight_decay)
        optimizer_clf = optim.AdamW(params_clf, lr=args.lr, weight_decay=args.weight_decay)
        scheduler = cosine_lr_with_warmup([optimizer_feat, optimizer_clf], args.lr, args.warmup_length, args.epochs * num_batches,
                                          args.lr_warm_restarts, args.restart_epochs * num_batches)
    else:
        raise NotImplementedError

    # Optimizer whose learning rate is reported in the progress log
    # (``optimizer_feat`` does not exist in KD mode).
    log_optimizer = optimizer if args.kd else optimizer_feat

    # The teacher is never optimized, so keep it in inference mode for the whole run.
    teacher_model.eval()

    args.current_epoch = 0

    if args.resume:
        # Continue numbering after the last integer in the checkpoint path.
        starting_epoch = int(re.findall(r'\d+', args.load_dir)[-1]) + 1
    else:
        starting_epoch = 1
    last_epoch = args.epochs + starting_epoch - 1

    for epoch in range(starting_epoch, last_epoch + 1):
        if args.distributed:
            train_sampler.set_epoch(epoch - starting_epoch)

        student_model.train()
        for i, batch in enumerate(tqdm(train_loader)):
            start_time = time.time()
            step = i + (epoch - starting_epoch) * num_batches
            if isinstance(scheduler, list):
                for sched in scheduler:
                    sched(step)
            else:
                scheduler(step)

            batch = maybe_dictionarize(batch)
            x = batch[input_key].to(device)
            y = batch['labels'].to(device)
            data_time = time.time() - start_time

            if args.kd:
                # Teacher outputs are targets only; skipping autograd through the
                # teacher leaves the student's gradients unchanged.
                with torch.no_grad():
                    teacher_logits = teacher_model(x)
                student_logits = student_model(x)
                loss = (1 - args.distill_weight) * criterion(student_logits, y) + args.distill_weight * criterion_distill(student_logits, teacher_logits)
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(params, 1.0)
                optimizer.step()
            elif args.rd:
                if args.attn_distill or args.layer_distill:
                    with torch.no_grad():
                        teacher_feats, teacher_aux_outputs = teacher_model(x)
                    student_feats, student_aux_outputs = student_model(x)
                    loss_feat = args.distill_weight * criterion_distill(student_feats, teacher_feats)
                    loss_attn = 0.
                    loss_layer = 0.
                    if args.attn_distill:
                        teacher_attn_weights = [aux['attn_weights'] for aux in teacher_aux_outputs]
                        student_attn_weights = [aux['attn_weights'] for aux in student_aux_outputs]
                        for teacher_attn_weight, student_attn_weight in zip(teacher_attn_weights, student_attn_weights):
                            loss_attn += criterion_distill(student_attn_weight, teacher_attn_weight)
                        loss_attn = args.attn_distill_weight * loss_attn / len(teacher_attn_weights)
                    if args.layer_distill:
                        teacher_hidden_states = [aux['hidden_states'] for aux in teacher_aux_outputs]
                        student_hidden_states = [aux['hidden_states'] for aux in student_aux_outputs]
                        for teacher_hidden_state, student_hidden_state in zip(teacher_hidden_states, student_hidden_states):
                            loss_layer += criterion_distill(student_hidden_state, teacher_hidden_state)
                        loss_layer = args.layer_distill_weight * loss_layer / len(teacher_hidden_states)
                    loss = loss_feat + loss_layer + loss_attn
                else:
                    with torch.no_grad():
                        teacher_feats = teacher_model(x)
                    student_feats = student_model(x)
                    loss_feat = args.distill_weight * criterion_distill(student_feats, teacher_feats)
                    loss = loss_feat

                optimizer_feat.zero_grad()
                # Backpropagate the total loss so the attention/layer terms
                # actually train the student (identical to loss_feat otherwise).
                loss.backward()
                nn.utils.clip_grad_norm_(params, 1.0)
                optimizer_feat.step()
            else:
                raise NotImplementedError

            batch_time = time.time() - start_time
            if i % args.print_freq == 0:
                percent_complete = 100 * i / len(train_loader)
                aux_str = ""
                if args.rd and (args.attn_distill or args.layer_distill):
                    aux_str += f"\t Repr loss: {loss_feat.item():.6f}"
                    if args.attn_distill:
                        aux_str += f"\t Attn loss: {loss_attn.item():.6f}"
                    if args.layer_distill:
                        aux_str += f"\t Layer loss: {loss_layer.item():.6f}"
                print(
                    f"Train Epoch: {epoch}/{args.epochs+starting_epoch-1} [{percent_complete:.0f}% {i}/{len(dataset.train_loader)}]\t"
                    f"LR: {log_optimizer.param_groups[0]['lr']:.6f}\t Loss: {loss.item():.6f}{aux_str}\t Data (t) {data_time:.3f}\t Batch (t) {batch_time:.3f}",
                    flush=True
                )

        # Save every 10 epochs and always after the last one, so the final model is kept.
        if not args.distributed or args.rank % ngpus_per_node == 0:
            if (epoch % 10 == 0 or epoch == last_epoch) and args.result_dir is not None:
                model_save_dir = os.path.join(args.result_dir, 'models')
                os.makedirs(model_save_dir, exist_ok=True)
                print('Saving model to', model_save_dir)
                model_path = os.path.join(model_save_dir, f'model_{args.exp_name}_epoch{epoch}.pt')
                if args.kd:
                    torch.save(student_model.state_dict(), model_path)
                elif args.rd:
                    torch.save(student_model.pre_featurizer.state_dict(), model_path)

        args.current_epoch = epoch


def evaluate(gpu, ngpus_per_node, args, model=None, input_key='features', visualize_predictions=False, ensemble=False, visualize_activation_rates=False):
    """Evaluate a model on every dataset named in ``args.eval_datasets``.

    Args:
        gpu: GPU index used to choose the input device when ``model`` is given.
        ngpus_per_node: If > 1 and ``model`` is not already wrapped, the model is
            wrapped in ``nn.DataParallel``.
        args: Parsed command-line arguments. Not modified by this function.
        model: Model to evaluate. When ``None``, it is built with
            ``setup_model(args, load=True)``.
        input_key: Batch key fed to the model. ``'images'`` enables the full
            forward pass. Any other value passes ``model.pre_featurizer`` to
            ``get_dataloader`` as ``image_encoder``.
        visualize_predictions: Save a figure of sampled misclassified images.
        ensemble: With ``args.zeroshot_init``, evaluate a 50/50 weight-space
            merge of the zero-shot model and ``model`` (head only when ``args.lp``).
        visualize_activation_rates: Plot activation rates per dataset.

    Returns:
        A dict holding a copy of ``vars(args)``, an ``'ensemble'`` flag and one
        ``'<dataset>:<metric>'`` entry per metric, which is also appended as JSON
        to ``results_<exp_name>.txt`` when ``args.result_dir`` is set. Returns
        ``None`` when ``args.eval_datasets`` is ``None``.
    """
    if args.eval_datasets is None:
        return

    if model is None:
        model, device = setup_model(args, load=True)
    elif gpu is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu'

    model.eval()

    wrapper_types = (nn.DataParallel, nn.parallel.DistributedDataParallel)

    # Vanilla data parallel (skipped when the caller already wrapped the model)
    if ngpus_per_node > 1 and not isinstance(model, wrapper_types):
        model = nn.DataParallel(model)
        model.input_resolution = model.module.input_resolution
        model.pre_featurizer = model.module.pre_featurizer
        model.featurizer = model.module.featurizer
        model.classification_head = model.module.classification_head

    if input_key == 'images':
        image_encoder = None
        if isinstance(model, wrapper_types):
            model.module.set_full_forward(True)
        else:
            model.set_full_forward(True)
    else:
        image_encoder = model.pre_featurizer

    # Copy: writing results into vars(args) itself would mutate args. In
    # particular, info['ensemble'] = False reset args.ensemble, which silently
    # skipped the ensemble evaluation that train() runs after the last epoch.
    info = dict(vars(args))

    if ensemble and args.zeroshot_init:
        print('Using weight-space ensemble.')
        info['ensemble'] = True
        zeroshot_model, _ = setup_model(args, load=False, include_top=True)
        model_eval = copy.deepcopy(model)
        model_eval.eval()
        if args.lp:
            theta_0 = {k: v.clone() for k, v in zeroshot_model.classification_head.state_dict().items()}
            theta_1 = {k: v.clone() for k, v in model_eval.classification_head.state_dict().items()}
            theta = merge(theta_0, theta_1, alpha=0.5)
            model_eval.classification_head.load_state_dict(theta)
        else:
            theta_0 = {k: v.clone() for k, v in zeroshot_model.state_dict().items()}
            theta_1 = {k: v.clone() for k, v in model_eval.state_dict().items()}
            theta = merge(theta_0, theta_1, alpha=0.5)
            model_eval.load_state_dict(theta)
        del zeroshot_model
    else:
        info['ensemble'] = False
        model_eval = model

    for dataset_name in args.eval_datasets:
        print('========== Evaluating on {} =========='.format(dataset_name))
        dataset_class = getattr(datasets, dataset_name)
        dataset = dataset_class(get_transform(args, model_eval.input_resolution), location=args.data_dir, batch_size=args.batch_size,
                                num_workers=args.num_workers, pin_memory=args.pin_memory)

        # Inputs go to the same device as the model for every dataset
        # (no per-loop override to the default CUDA device).
        eval_loader, _ = get_dataloader(dataset, is_train=False, batch_size=args.batch_size, device=device,
                                        image_encoder=image_encoder, pin_memory=args.pin_memory)

        if hasattr(dataset, 'post_loop_metrics'):
            # keep track of labels, predictions and metadata
            all_labels, all_preds, all_metadata = [], [], []

        if visualize_predictions:
            k = 3
            maxlen = 25
            visual_imgs = []
            visual_classnames = []
            visual_topk_preds = []
            figname = '{}-{}'.format(args.exp_name, dataset_name)

        with torch.no_grad():
            top1, correct, n = 0., 0., 0.
            for batch_id, batch in enumerate(tqdm(eval_loader)):
                batch = maybe_dictionarize(batch)
                x = batch[input_key].to(device)
                y = batch['labels'].to(device)
                # Reset per batch so paths from an earlier batch or dataset are never reused.
                image_paths = batch['image_paths'] if 'image_paths' in batch else None

                logits = model_eval(x)
                projection_fn = getattr(dataset, 'project_logits', None)
                if projection_fn is not None:
                    logits = projection_fn(logits, device)

                if hasattr(dataset, 'project_labels'):
                    y = dataset.project_labels(y, device)

                pred = logits.argmax(dim=1, keepdim=True).to(device)
                if visualize_predictions:
                    topk_probs, topk_labels = torch.topk(torch.softmax(logits, dim=1), k=k, dim=1)

                if hasattr(dataset, 'accuracy'):
                    acc1, mask = dataset.accuracy(logits, y, image_paths, args)
                    correct += acc1
                    n += len(mask)
                else:
                    mask = pred.eq(y.view_as(pred))
                    correct += mask.sum().item()
                    n += y.size(0)
                    mask = mask.cpu().numpy().squeeze()

                if hasattr(dataset, 'post_loop_metrics'):
                    all_labels.append(y.cpu().clone().detach())
                    all_preds.append(logits.cpu().clone().detach())
                    metadata = batch['metadata'] if 'metadata' in batch else image_paths
                    all_metadata.extend(metadata)

                # get incorrectly predicted images
                if visualize_predictions:
                    mask = ~mask
                    wrong_imgs = batch['images'].cpu().numpy()[mask]
                    wrong_imgs = unnormalize(np.moveaxis(wrong_imgs, 1, 3), args.model_name)
                    wrong_labels = batch['labels'].cpu().numpy()[mask]
                    wrong_topk_probs = topk_probs[mask]
                    wrong_topk_labels = topk_labels[mask]

                    for i in range(len(wrong_imgs)):
                        append_or_replace([visual_imgs, visual_classnames, visual_topk_preds],
                                          [wrong_imgs[i], dataset.classnames[wrong_labels[i]],
                                           {'classnames': [dataset.classnames[j] for j in wrong_topk_labels[i]],
                                            'probs': wrong_topk_probs[i]}], maxlen, replace_prob=1 - batch_id / len(eval_loader))
            # Guard against an empty evaluation set.
            top1 = correct / n if n > 0 else 0.

            if hasattr(dataset, 'post_loop_metrics'):
                all_labels = torch.cat(all_labels)
                all_preds = torch.cat(all_preds)
                metrics = dataset.post_loop_metrics(all_labels, all_preds, all_metadata, args)
                if 'acc' in metrics:
                    metrics['top1'] = metrics['acc']
            else:
                metrics = {}

        if 'top1' not in metrics:
            metrics['top1'] = top1

        print('{} Top-1 acc: {:.4f}'.format(dataset_name, metrics['top1']))

        for key, val in metrics.items():
            if 'worst' in key or 'f1' in key.lower() or 'pm0' in key or 'pm10' in key:
                print('{} {}: {:.4f}'.format(dataset_name, key, val))
            info[dataset_name + ':' + key] = val

        if visualize_predictions:
            visualize_imgs_and_preds(visual_imgs, visual_classnames, visual_topk_preds,
                filename=os.path.join(args.result_dir, 'wrong_predictions', figname))

        if visualize_activation_rates:
            plot_activation_rates(model, dataset_name, args.model_name.split('_')[-1], save_path=os.path.join(args.result_dir, 'activation_rates', dataset_name))

    if args.result_dir is not None:
        os.makedirs(args.result_dir, exist_ok=True)
        results_filename = os.path.join(args.result_dir, f'results_{args.exp_name}.txt')
        with open(results_filename, 'a+') as f:
            f.write(json.dumps(info, indent=2) + '\n')

    return info


if __name__ == '__main__':
    # Run the CLI dispatcher only when executed as a script, not on import.
    main()
