from collections import defaultdict, deque
import datetime
import errno
import logging
import math
import models
import numpy as np
import os
import signal
import socket
import subprocess
import sys
import time
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data.dataloader import default_collate
import torchvision
import transforms as T


from datasets.Kinetics400 import Kinetics400
from datasets.UCF101 import UCF101
from datasets.HMDB51 import HMDB51
from datasets.AVideoDataset import AVideoDataset,AVideoDataset2


def dist_collect(x):
    """Gather a tensor from every process and concatenate the results.

    args:
        x: local tensor, shape (mini_batch, ...)
    returns:
        tensor of shape (mini_batch * world_size, ...), i.e. the per-process
        tensors concatenated along dim 0.
    """
    local = x.contiguous()
    # all_gather fills one pre-allocated receive buffer per process.
    gathered = []
    for _ in range(dist.get_world_size()):
        gathered.append(
            torch.zeros_like(local, device=local.device, dtype=local.dtype)
        )
    dist.all_gather(gathered, local)
    return torch.cat(gathered, dim=0)


def dist_collect_other(x, return_before_cat=False):
    """Gather a tensor from every process except the calling one.

    args:
        x: local tensor, shape (mini_batch, ...)
        return_before_cat: if True, return the list of remote tensors
            (one per other process) instead of concatenating them.
    returns:
        tensor of shape (mini_batch * (world_size - 1), ...), or the list of
        per-process tensors when ``return_before_cat`` is True.
    """
    local = x.contiguous()
    buffers = [
        torch.zeros_like(local, device=local.device, dtype=local.dtype)
        for _ in range(dist.get_world_size())
    ]
    dist.all_gather(buffers, local)
    # Drop the entry that belongs to this process; keep the remote ones.
    remote = []
    for rank in range(dist.get_world_size()):
        if rank == dist.get_rank():
            continue
        remote.append(buffers[rank])
    if return_before_cat:
        return remote
    return torch.cat(remote, dim=0)


def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` so that it is silent on non-master processes.

    The replacement accepts an extra ``force`` keyword: ``print(..., force=True)``
    prints regardless of ``is_master``. ``force`` is always removed before the
    call is forwarded to the previously installed ``print``.
    """
    import builtins

    previous_print = builtins.print

    def print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        previous_print(*args, **kwargs)

    builtins.print = print


def SIGTERMHandler(a, b):
    """Signal handler that only reports the signal and otherwise ignores it.

    Both arguments (passed positionally by the ``signal`` machinery) are unused.
    """
    print('received sigterm')


def signalHandler(a, b):
    """Signal handler that records receipt of a signal in the environment.

    Prints the first argument and the current timestamp, then sets the
    ``SIGNAL_RECEIVED`` environment variable to ``'True'``. The second
    argument is unused.
    """
    received_at = time.time()
    print('Signal received', a, received_at, flush=True)
    os.environ['SIGNAL_RECEIVED'] = 'True'


def init_signal_handler():
    """
    Install the handlers for signals sent by SLURM (time limit / pre-emption).

    Resets ``SIGNAL_RECEIVED`` to ``'False'``, stores this process's pid in
    ``MAIN_PID``, then registers ``signalHandler`` for SIGUSR1 and
    ``SIGTERMHandler`` for SIGTERM.
    """
    os.environ.update({
        'SIGNAL_RECEIVED': 'False',
        'MAIN_PID': str(os.getpid()),
    })

    handlers = (
        (signal.SIGUSR1, signalHandler),
        (signal.SIGTERM, SIGTERMHandler),
    )
    for signum, handler in handlers:
        signal.signal(signum, handler)
    print("Signal handler installed.", flush=True)


def trigger_job_requeue(checkpoint_filename):
    """Requeue the current SLURM job so it can resume from a checkpoint, then exit.

    The requeue command is only issued when all of the following hold:
    ``SLURM_PROCID`` is 0, the current pid equals the ``MAIN_PID`` environment
    variable, and ``checkpoint_filename`` exists as a file. In every case the
    process then terminates with exit code 0.

    Args:
        checkpoint_filename: path of the checkpoint the requeued job resumes from.

    Raises:
        RuntimeError: if ``scontrol requeue`` returns a non-zero status.
        SystemExit: always (code 0), unless RuntimeError was raised first.
    """
    print("IN JOB REQUEUE FUNCTION")
    print(checkpoint_filename)
    is_main_process_of_job = (
        int(os.environ['SLURM_PROCID']) == 0
        and str(os.getpid()) == os.environ['MAIN_PID']
    )
    if is_main_process_of_job and os.path.isfile(checkpoint_filename):
        print('time is up, back to slurm queue', flush=True)
        command = 'scontrol requeue ' + os.environ['SLURM_JOB_ID']
        print(command)
        if os.system(command):
            raise RuntimeError('requeue failed')
        print('New job submitted to the queue', flush=True)
    # sys.exit instead of the builtin exit(): the latter is injected by the
    # `site` module for interactive use and is not guaranteed to exist.
    sys.exit(0)


def restart_from_checkpoint(args, ckp_path=None, run_variables=None, **kwargs):
    """
    Restore objects and run variables from a checkpoint file, if one exists.

    Args:
        args: object providing ``output_dir``, ``world_size`` and ``local_rank``.
        ckp_path: checkpoint file; defaults to
            ``<args.output_dir>/checkpoints/checkpoint.pth``.
        run_variables: optional dict; every key that is also present in the
            checkpoint is overwritten in place with the checkpointed value.
        **kwargs: maps a checkpoint key to the object whose ``load_state_dict``
            should receive that entry, e.g. ``model=model``.

    Returns nothing; silently returns when the checkpoint file does not exist.
    """
    from collections import OrderedDict

    path = ckp_path
    if path is None:
        path = os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth')

    print(f'Ckpt path: {path}', flush=True)

    # nothing to restore when the file is absent
    if not os.path.isfile(path):
        return

    print('Found checkpoint in experiment repository', flush=True)

    # in multi-process runs, map tensors onto this process's GPU
    device = "cuda:" + str(args.local_rank) if args.world_size > 1 else None
    checkpoint = torch.load(path, map_location=device)

    for key, target in kwargs.items():
        if key not in checkpoint or target is None:
            print("=> failed to load {} from checkpoint '{}'"
                        .format(key, path))
            continue

        if key == 'model':
            # the target is loaded after distributed wrapping, so every
            # parameter name gets the 'module.' prefix
            state = OrderedDict(
                ('module.' + k, v) for k, v in checkpoint[key].items()
            )
        elif key == 'model_ema':
            # not distributed: names are used unchanged
            state = OrderedDict(
                (k, v) for k, v in checkpoint[key].items()
            )
        else:
            state = checkpoint[key]
        target.load_state_dict(state)
        print("=> loaded {} from checkpoint '{}'"
                    .format(key, path))

    # restore the variables that matter for the run
    if run_variables is None:
        return
    for var_name in run_variables:
        if var_name in checkpoint:
            run_variables[var_name] = checkpoint[var_name]


def save_checkpoint(args, epoch, model, optimizer, lr_scheduler, selflabels=None, ckpt_freq=10):
    """Write the training state for ``epoch`` under ``args.output_dir``.

    The saved dict holds the state dicts of ``model.module``, ``optimizer`` and
    ``lr_scheduler``, the next epoch number (``epoch + 1``), ``args`` and,
    when given, ``selflabels``. Files written (via ``save_on_master``):

    * ``model_weights/model_<epoch>.pth`` when ``epoch`` is a multiple of 10,
    * ``checkpoints/checkpoint.pth`` every call,
    * ``checkpoints/ckpt_<epoch>.pth`` when ``epoch`` is a multiple of ``ckpt_freq``.
    """
    state = dict(
        model=model.module.state_dict(),
        optimizer=optimizer.state_dict(),
        lr_scheduler=lr_scheduler.state_dict(),
        epoch=epoch + 1,
        args=args,
    )
    if selflabels is not None:
        state['selflabels'] = selflabels

    weights_dir = os.path.join(args.output_dir, 'model_weights')
    ckpt_dir = os.path.join(args.output_dir, 'checkpoints')
    mkdir(weights_dir)
    mkdir(ckpt_dir)

    if epoch % 10 == 0:
        save_on_master(state, os.path.join(weights_dir, f'model_{epoch}.pth'))

    # rolling checkpoint, overwritten on every call
    save_on_master(state, os.path.join(ckpt_dir, 'checkpoint.pth'))

    if epoch % ckpt_freq == 0:
        save_on_master(state, os.path.join(ckpt_dir, f'ckpt_{epoch}.pth'))

    if args.global_rank == 0:
        print(f'Saving checkpoint to: {args.output_dir}', flush=True)
        print('Checkpoint saved', flush=True)


def init_distributed_mode(params, make_communication_groups=False):
    """
    Handle single and multi-GPU / multi-node / SLURM jobs.

    Initialize the following attributes on ``params``:
        - n_nodes
        - node_id
        - local_rank
        - global_rank
        - world_size
        - n_gpu_per_node
        - is_slurm_job, is_master, multi_node, multi_gpu
        - job_id, master_addr, master_port (SLURM jobs only)
        - distributed (set to True only when world_size > 1; never set otherwise)

    Three launch modes are distinguished, in this order:
        1. SLURM job: 'SLURM_JOB_ID' is in the environment and
           ``params.debug_slurm`` is falsy. Everything is read from SLURM_*
           environment variables and MASTER_ADDR / MASTER_PORT / WORLD_SIZE /
           RANK are exported for the 'env://' init method.
        2. ``params.local_rank != -1``: ranks are read from the RANK,
           WORLD_SIZE and NGPU environment variables.
        3. Otherwise: a single-process job (world_size == 1).

    When world_size > 1 the NCCL process group is initialised.

    Args:
        params: attribute container; ``debug_slurm`` and ``local_rank`` are
            read, and ``master_port`` is read in modes 2 and 3.
        make_communication_groups: when True (and world_size > 1), build and
            return process groups via ``dist.new_group``.

    Returns:
        A list of process groups when ``make_communication_groups`` is True and
        world_size > 1; otherwise None (implicit).
    """
    params.is_slurm_job = 'SLURM_JOB_ID' in os.environ and not params.debug_slurm
    print("SLURM job: %s" % str(params.is_slurm_job))

    # SLURM job
    # (the `not params.debug_slurm` test repeats what is_slurm_job already encodes)
    if params.is_slurm_job and not params.debug_slurm:

        assert params.local_rank == -1   # on the cluster, handled by SLURM

        # variables that are only echoed for debugging; missing ones print as None
        SLURM_VARIABLES = [
            'SLURM_JOB_ID',
            'SLURM_JOB_NODELIST', 'SLURM_JOB_NUM_NODES', 'SLURM_NTASKS',
            'SLURM_TASKS_PER_NODE',
            'SLURM_MEM_PER_NODE', 'SLURM_MEM_PER_CPU',
            'SLURM_NODEID', 'SLURM_PROCID', 'SLURM_LOCALID', 'SLURM_TASK_PID'
        ]

        PREFIX = "%i - " % int(os.environ['SLURM_PROCID'])
        for name in SLURM_VARIABLES:
            value = os.environ.get(name, None)
            print(PREFIX + "%s: %s" % (name, str(value)))

        # # job ID
        params.job_id = os.environ['SLURM_JOB_ID']

        # number of nodes / node ID
        params.n_nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
        params.node_id = int(os.environ['SLURM_NODEID'])

        # local rank on the current node / global rank
        params.local_rank = int(os.environ['SLURM_LOCALID'])
        params.global_rank = int(os.environ['SLURM_PROCID'])

        # number of processes / GPUs per node
        # (integer division; the sanity check below rejects uneven splits)
        params.world_size = int(os.environ['SLURM_NTASKS'])
        params.n_gpu_per_node = params.world_size // params.n_nodes

        # define master address and master port
        # the first hostname reported by scontrol is used as the master address
        hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames',
            os.environ['SLURM_JOB_NODELIST']])
        params.master_addr = hostnames.split()[0].decode('utf-8')
        # hard-coded port
        params.master_port = 19500

        assert 10001 <= params.master_port <= 20000 or params.world_size == 1
        print(PREFIX + "Master address: %s" % params.master_addr)
        print(PREFIX + "Master port   : %i" % params.master_port)

        # set environment variables for 'env://'
        os.environ['MASTER_ADDR'] = str(params.master_addr)
        os.environ['MASTER_PORT'] = str(params.master_port)
        os.environ['WORLD_SIZE'] = str(params.world_size)
        os.environ['RANK'] = str(params.global_rank)

    # multi-GPU job (local/multi-node) - started with torch.distributed.launch
    elif params.local_rank != -1:

        assert params.master_port == -1

        # read environment variables
        params.global_rank = int(os.environ['RANK'])
        params.world_size = int(os.environ['WORLD_SIZE'])
        params.n_gpu_per_node = int(os.environ['NGPU'])

        # number of nodes / node ID
        params.n_nodes = params.world_size // params.n_gpu_per_node
        params.node_id = params.global_rank // params.n_gpu_per_node

    # local job (single GPU)
    else:
        assert params.local_rank == -1
        assert params.master_port == -1
        params.n_nodes = 1
        params.node_id = 0
        params.local_rank = 0
        params.global_rank = 0
        params.world_size = 1
        params.n_gpu_per_node = 1

    # sanity checks
    assert params.n_nodes >= 1
    assert 0 <= params.node_id < params.n_nodes
    assert 0 <= params.local_rank <= params.global_rank < params.world_size
    assert params.world_size == params.n_nodes * params.n_gpu_per_node

    # define whether this is the master process / if we are in distributed mode
    params.is_master = params.node_id == 0 and params.local_rank == 0
    params.multi_node = params.n_nodes > 1
    params.multi_gpu = params.world_size > 1

    # summary
    PREFIX = "%i - " % params.global_rank
    print(PREFIX + "Number of nodes: %i" % params.n_nodes)
    print(PREFIX + "Node ID        : %i" % params.node_id)
    print(PREFIX + "Local rank     : %i" % params.local_rank)
    print(PREFIX + "Global rank    : %i" % params.global_rank)
    print(PREFIX + "World size     : %i" % params.world_size)
    print(PREFIX + "GPUs per node  : %i" % params.n_gpu_per_node)
    print(PREFIX + "Master         : %s" % str(params.is_master))
    print(PREFIX + "Multi-node     : %s" % str(params.multi_node))
    print(PREFIX + "Multi-GPU      : %s" % str(params.multi_gpu))
    print(PREFIX + "Hostname       : %s" % socket.gethostname())

    # set GPU device
    # (called unconditionally, including for the single-process case)
    torch.cuda.set_device(params.local_rank)

    # initialize multi-GPU
    if params.multi_gpu:
        params.distributed = True

        # 'env://' will read these environment variables:
        # MASTER_PORT - required; has to be a free port on machine with rank 0
        # MASTER_ADDR - required (except for rank 0); address of rank 0 node
        # WORLD_SIZE - required; can be set either here, or in a call to init fn
        # RANK - required; can be set either here, or in a call to init function

        print("Initializing PyTorch distributed ...")
        torch.distributed.init_process_group(
            init_method='env://',
            backend='nccl',
            rank=params.global_rank,
            world_size=params.world_size,
        )
        print("Initialized!")

        if make_communication_groups:
            # with super_classes fixed at 1 there is a single group that
            # contains every rank
            params.super_classes = 1
            params.training_local_world_size = params.world_size // params.super_classes
            params.training_local_rank = params.global_rank % params.training_local_world_size
            params.training_local_world_id = params.global_rank // params.training_local_world_size

            # prepare training groups
            
            training_groups = []
            for group_id in range(params.super_classes):
                # consecutive block of ranks belonging to this group
                ranks = [params.training_local_world_size * group_id + i \
                        for i in range(params.training_local_world_size)]
                training_groups.append(dist.new_group(ranks=ranks))
            return training_groups


def mkdir(path):
    """Create ``path`` (including parents); an already-existing path is not an error.

    Any OSError other than EEXIST is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        if err.errno == errno.EEXIST:
            return
        raise


def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0


def is_main_process():
    """Return True when this process has rank 0."""
    rank = get_rank()
    return rank == 0


def save_on_master(*args, **kwargs):
    """Forward all arguments to ``torch.save`` on the main process; no-op elsewhere."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


def _get_cache_path(dataset, mode, fold, clip_len, steps_between_clips):
    """Return the cache file path for a dataset configuration.

    The file name is the first 10 hex characters of the SHA-1 of the
    concatenated string forms of all arguments, with a ``.pt`` suffix, located
    under ``~/.torch/vision/datasets/<dataset>/`` (with ``~`` expanded).
    """
    import hashlib

    parts = (dataset, mode, fold, clip_len, steps_between_clips)
    key = "".join(str(part) for part in parts)
    digest = hashlib.sha1(key.encode()).hexdigest()
    file_name = digest[:10] + ".pt"
    return os.path.expanduser(
        os.path.join("~", ".torch", "vision", "datasets", dataset, file_name)
    )


class SmoothedValue:
    """Running statistics of a scalar series.

    Keeps the last ``window_size`` values for windowed statistics (median,
    avg, max, value) and a running total / count for the global average.
    """

    def __init__(self, window_size=20, fmt=None):
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value`` once in the window and ``n`` times in the global totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        count, total = packed.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (computed in float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Mean over everything recorded so far: total / count."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)


class MetricLogger(object):
    """Named collection of SmoothedValue meters with periodic progress logging.

    Meters are created on first use and are also reachable as attributes,
    e.g. ``logger.loss`` returns ``logger.meters['loss']``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; tensors are converted with ``.item()``."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails. Read `meters` through
        # __dict__ rather than `self.meters`: on an instance whose __init__ has
        # not run (copy / unpickle / __new__) `self.meters` would re-enter
        # __getattr__ and recurse until RecursionError.
        state = self.__dict__
        meters = state.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in state:
            return state[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Synchronize every meter's global count/total across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None, logger=None, writer=None, mode='train', epoch=0, args=None):
        """Iterate over ``iterable``, timing each step and logging periodically.

        Yields ``(idx, obj)`` for every element. Every ``print_freq`` iterations:

        * when CUDA is NOT available, a progress line (counter, eta, meters,
          iteration time, data time) is emitted through ``print_or_log``;
          when CUDA is available no progress line is emitted by this class;
        * when ``writer`` is truthy, each meter's windowed average and the
          max allocated CUDA memory (in MB) are written with ``add_scalar``
          at step ``epoch * len(iterable) + idx``.

        Args:
            iterable: sized iterable (``len`` is required).
            print_freq: logging period in iterations.
            header: optional prefix of the progress line.
            logger: passed to ``print_or_log``.
            writer: optional object with an ``add_scalar(tag, value, step)`` method.
            mode: prefix of the scalar tags.
            epoch: used to compute the global step.
            args: unused here; kept for signature compatibility.
        """
        if not header:
            header = ''
        num_iters = len(iterable)
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        log_msg = self.delimiter.join([
            header,
            '[{0:' + str(len(str(num_iters))) + 'd}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ])
        MB = 1024.0 * 1024.0
        end = time.time()
        for idx, obj in enumerate(iterable):
            # time spent waiting for the next element
            data_time.update(time.time() - end)
            yield idx, obj
            # full iteration time, including the consumer's work on `obj`
            iter_time.update(time.time() - end)
            if idx % print_freq == 0:
                # console logging is intentionally limited to the non-CUDA case
                if not torch.cuda.is_available():
                    eta_seconds = iter_time.global_avg * (num_iters - idx)
                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                    print_or_log(log_msg.format(
                        idx,
                        num_iters,
                        eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)), logger=logger)
                if writer:
                    step = epoch * num_iters + idx
                    for key in self.meters:
                        writer.add_scalar(
                            f'{mode}/{key}/iter',
                            self.meters[key].avg,
                            step
                        )
                    writer.add_scalar(
                        f'{mode}/memory/iter',
                        torch.cuda.max_memory_allocated() / MB,
                        step
                    )
            end = time.time()


class MetricLoggerFinetune(MetricLogger):
    """MetricLogger variant that also prints progress lines when CUDA is available.

    With CUDA, the progress line (which additionally reports max allocated
    memory) is printed only on the process whose ``args.global_rank`` is 0.
    """

    def log_every(self, iterable, print_freq, header=None, logger=None, writer=None, mode='train', epoch=0, args=None):
        """Iterate over ``iterable``, timing each step and logging periodically.

        Yields ``(idx, obj)`` for every element. Every ``print_freq`` iterations
        a progress line is emitted through ``print_or_log`` and, when ``writer``
        is truthy, meter averages and max CUDA memory (MB) are written with
        ``add_scalar`` at step ``epoch * len(iterable) + idx``.

        Args:
            iterable: sized iterable (``len`` is required).
            print_freq: logging period in iterations.
            header: optional prefix of the progress line.
            logger: passed to ``print_or_log``.
            writer: optional object with an ``add_scalar(tag, value, step)`` method.
            mode: prefix of the scalar tags.
            epoch: used to compute the global step.
            args: optional object with a ``global_rank`` attribute. When it is
                None (the default) the progress line is printed on every
                process instead of failing with AttributeError.
        """
        if not header:
            header = ''
        num_iters = len(iterable)
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        fields = [
            header,
            '[{0:' + str(len(str(num_iters))) + 'd}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        if torch.cuda.is_available():
            fields.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0
        end = time.time()
        for idx, obj in enumerate(iterable):
            # time spent waiting for the next element
            data_time.update(time.time() - end)
            yield idx, obj
            # full iteration time, including the consumer's work on `obj`
            iter_time.update(time.time() - end)
            if idx % print_freq == 0:
                eta_seconds = iter_time.global_avg * (num_iters - idx)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    # args defaults to None: treat "no args" as "print here"
                    if args is None or args.global_rank == 0:
                        print_or_log(log_msg.format(
                            idx,
                            num_iters,
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time), data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB), logger=logger
                        )
                else:
                    print_or_log(log_msg.format(
                        idx,
                        num_iters,
                        eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)), logger=logger)
                if writer:
                    step = epoch * num_iters + idx
                    for key in self.meters:
                        writer.add_scalar(
                            f'{mode}/{key}/iter',
                            self.meters[key].avg,
                            step
                        )
                    writer.add_scalar(
                        f'{mode}/memory/iter',
                        torch.cuda.max_memory_allocated() / MB,
                        step
                    )
            end = time.time()


class MetricLoggerGDT(MetricLogger):
    """MetricLogger variant that reports the total elapsed time after the loop.

    Unlike the base class, ``delimiter`` has no default and must be supplied.
    """

    def __init__(self, delimiter):
        super().__init__(delimiter)

    def log_every(self, iterable, print_freq, header=None, logger=None, writer=None, mode='train', epoch=0):
        """Iterate over ``iterable``, timing each step and logging periodically.

        Yields ``(idx, obj)`` for every element. Every ``print_freq`` iterations
        a progress line is emitted (only when CUDA is unavailable) and, when
        ``writer`` is truthy, meter averages and max CUDA memory (MB) are
        written with ``add_scalar``. After the last element a
        ``'<header> Total time: <h:mm:ss>'`` line is emitted.
        """
        header = header or ''
        started = time.time()
        last_tick = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        counter_width = str(len(str(len(iterable))))
        fields = [
            header,
            '[{0' + ':' + counter_width + 'd' + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        if torch.cuda.is_available():
            fields.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(fields)
        bytes_per_mb = 1024.0 * 1024.0
        for idx, obj in enumerate(iterable):
            data_time.update(time.time() - last_tick)
            yield idx, obj
            iter_time.update(time.time() - last_tick)
            if idx % print_freq == 0:
                remaining = iter_time.global_avg * (len(iterable) - idx)
                eta_string = str(datetime.timedelta(seconds=int(remaining)))
                # the progress line is only emitted when CUDA is unavailable
                if not torch.cuda.is_available():
                    line = log_msg.format(
                        idx,
                        len(iterable),
                        eta=eta_string,
                        meters=str(self),
                        time=str(iter_time),
                        data=str(data_time),
                    )
                    print_or_log(line, logger=logger)
                if writer:
                    step = epoch * len(iterable) + idx
                    for key in self.meters:
                        writer.add_scalar(f'{mode}/{key}/iter', self.meters[key].avg, step)
                    writer.add_scalar(
                        f'{mode}/memory/iter',
                        torch.cuda.max_memory_allocated() / bytes_per_mb,
                        step,
                    )
            last_tick = time.time()
        elapsed = int(time.time() - started)
        total_time_str = str(datetime.timedelta(seconds=elapsed))
        print_or_log('{} Total time: {}'.format(header, total_time_str), logger=logger)

    

class MyFormatter(logging.Formatter):
    """Formatter that selects the format string from the record's level.

    DEBUG, INFO and ERROR records use ``dbg_fmt``, ``info_fmt`` and ``err_fmt``
    respectively; every other level uses the format given to ``__init__``.
    """

    # NOTE(review): these use %(msg)s (the raw message, without %-args applied)
    # rather than %(message)s -- confirm that this is intended.
    err_fmt = "%(asctime)s %(name)s %(module)s: %(lineno)d: %(levelname)s: %(msg)s"
    dbg_fmt = "%(asctime)s %(module)s: %(lineno)d: %(levelname)s:: %(msg)s"
    info_fmt = "%(msg)s"

    def __init__(self):
        super().__init__(fmt="%(asctime)s %(name)s %(levelname)s: %(message)s",
                         datefmt=None,
                         style='%')

    def format(self, record):
        """Format ``record`` with the level-specific format string."""
        # Save the format configured when the formatter was instantiated
        format_orig = self._style._fmt

        # Temporarily replace it with the one customized by logging level
        if record.levelno == logging.DEBUG:
            self._style._fmt = MyFormatter.dbg_fmt
        elif record.levelno == logging.INFO:
            self._style._fmt = MyFormatter.info_fmt
        elif record.levelno == logging.ERROR:
            self._style._fmt = MyFormatter.err_fmt

        try:
            # Call the original formatter class to do the grunt work
            return logging.Formatter.format(self, record)
        finally:
            # Always restore the configured format, even if formatting raised;
            # otherwise the level-specific format would leak into later records.
            self._style._fmt = format_orig


def setup_logger(name, save_dir, is_master, logname="run.log"):
    """Return the logger called ``name``, set to DEBUG level.

    On the master process a stdout handler is attached and, when ``save_dir``
    is truthy, a file handler writing to ``<save_dir>/<logname>``; both use
    ``MyFormatter``. Non-master processes get the logger without any handler.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    if is_master:
        console = logging.StreamHandler(stream=sys.stdout)
        formatter = MyFormatter()
        console.setFormatter(formatter)
        logger.addHandler(console)

        print("Creating logger save dir")
        if save_dir:
            log_file = os.path.join(save_dir, logname)
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        print(f"Finished creating logger: {save_dir}")

    return logger


def print_or_log(message, logger=None):
    """Send ``message`` to ``logger.info`` when a logger is given, else print it (flushed)."""
    if logger is not None:
        logger.info(message)
        return
    print(message, flush=True)


def setup_tbx(save_dir, is_master):
    """Create a TensorBoard ``SummaryWriter`` on the master process.

    Args:
        save_dir: directory passed to ``SummaryWriter``.
        is_master: whether this is the master process.

    Returns:
        A ``SummaryWriter`` on the master process, otherwise None.
    """
    if not is_master:
        return None

    # Imported only where a writer is actually created, so that non-master
    # processes do not need tensorboard to be importable.
    from torch.utils.tensorboard import SummaryWriter

    return SummaryWriter(save_dir)


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    ``output`` is indexed as (sample, class) and ``target`` holds one class
    index per sample. Returns a list with one scalar float32 tensor per k.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # indices of the `largest_k` highest scores per sample, best first
        top_indices = output.topk(largest_k, 1, True, True)[1]
        # transpose so that row j holds every sample's j-th best prediction
        hits = top_indices.t().eq(target[None])

        return [
            hits[:k].flatten().sum(dtype=torch.float32) * (100.0 / num_samples)
            for k in topk
        ]


### MODELS
def load_model(
    model_name='r3d_18', 
    vid_base_arch='r2plus1d_18', 
    aud_base_arch='resnet18', 
    pretrained=False,
    norm_feat=True,
    use_mlp=False,
    mlptype=0,
    headcount=1,
    num_classes=256,
    use_max_pool=False
):
    """Build a model.

    For ``model_name`` in ('r3d_18', 'mc3_18', 'r2plus1d_18') the matching
    ``torchvision.models.video`` constructor is called with ``pretrained``
    only. Any other name builds ``models.AV_GDT`` from the remaining arguments.
    """
    print(f"Loading {model_name}: {vid_base_arch} and {aud_base_arch}, using MLP head: {use_mlp}", flush=True)

    if model_name not in ['r3d_18', 'mc3_18', 'r2plus1d_18']:
        # audio-visual model
        gdt_kwargs = dict(
            vid_base_arch=vid_base_arch,
            aud_base_arch=aud_base_arch,
            pretrained=pretrained,
            norm_feat=norm_feat,
            use_mlp=use_mlp,
            mlptype=mlptype,
            headcount=headcount,
            num_classes=num_classes,
            use_max_pool=use_max_pool,
        )
        return models.AV_GDT(**gdt_kwargs)

    # torchvision video backbone
    print(f"Loading {model_name}, pretrained: {pretrained}")
    constructor = torchvision.models.video.__dict__[model_name]
    return constructor(pretrained=pretrained)


def load_model_parameters(model, model_weights):
    """Copy matching entries of ``model_weights`` into ``model`` in place.

    Every occurrence of ``'module.'`` is removed from each incoming name before
    it is looked up in ``model.state_dict()``. Names without a match are
    reported with a print and skipped.
    """
    own_state = model.state_dict()
    for raw_name, tensor in model_weights.items():
        name = raw_name.replace('module.', '')
        if name not in own_state:
            print("didnt load ", name)
            continue
        own_state[name].copy_(tensor)


def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    Does nothing otherwise (parameters are never un-frozen here).
    """
    if not feature_extracting:
        return
    for parameter in model.parameters():
        parameter.requires_grad = False


def str2bool(v):
    """Parse a boolean from a string, case-insensitively.

    'yes', 'true', 't', '1' map to True; 'no', 'false', 'f', '0' map to False.
    Anything else raises ValueError (the message shows the lower-cased input).
    """
    token = v.lower()
    spellings = (
        (True, ('yes', 'true', 't', '1')),
        (False, ('no', 'false', 'f', '0')),
    )
    for outcome, accepted in spellings:
        if token in accepted:
            return outcome
    raise ValueError('Boolean argument needs to be true or false. '
        'Instead, it is %s.' % token)


def compute_metrics(y, pred, num_classes=10):
    """
    Compute performance metrics from true and predicted labels.

    Args:
        y: True labels; torch tensor, numpy array or array-like. A 2-D input
           is reduced to integer labels with argmax over axis 1.
        pred: Predicted labels; same accepted forms as ``y``.
        num_classes: Per-class accuracy is computed for classes
                     0 .. num_classes - 1.
    Returns:
        metrics: dict with keys 'accuracy' (overall fraction correct),
                 'class_accuracy' (list with one accuracy per class) and
                 'average_class_accuracy' (mean of that list).
    """
    def _as_array(values):
        # torch tensors are moved to CPU first; anything else goes through np.array
        if isinstance(values, torch.Tensor):
            return values.cpu().data.numpy()
        if isinstance(values, np.ndarray):
            return values
        return np.array(values)

    def _as_labels(arr):
        # one-hot / score matrices become integer labels
        return np.argmax(arr, axis=1) if arr.ndim == 2 else arr

    y = _as_array(y)
    pred = _as_array(pred)
    assert isinstance(y, np.ndarray)
    assert isinstance(pred, np.ndarray)

    y = _as_labels(y)
    pred = _as_labels(pred)
    assert y.ndim == 1
    assert pred.ndim == 1

    overall = (y == pred).mean()

    per_class = []
    for label in range(num_classes):
        members = (y == label)
        per_class.append((y[members] == pred[members]).mean())

    return {
        'accuracy': overall,
        'class_accuracy': per_class,
        'average_class_accuracy': np.mean(per_class)
    }


def load_dataset(
    dataset_name='kinetics',
    mode='train',
    fold=1,
    frames_per_clip=30,
    transforms=None,
    clips_per_video=1,
    num_data_samples=None,
    sampletime=1.2,
    subsample=False,
    seed=0,
    model='avc',
    sample_aud_ind=False,
    sample_rate=1,
    train_crop_size=112,
    colorjitter=False,
    dualdata=False,
    synced=True,
    temp_jitter=True,
    center_crop=False,
    target_fps=30,
    aug_audio=[],
    num_sec=1,
    aud_sample_rate=48000,
    aud_spec_type=1,
    use_volume_jittering=False,
    use_temporal_jittering=False,
    z_normalize=False,
):
    """Build the dataset object selected by ``dataset_name``.

    Supported names:
        * 'kinetics', 'kinetics600', 'audioset', 'vggsound': returns an
          ``AVideoDataset`` (or ``AVideoDataset2`` when ``dualdata`` is
          truthy) configured from the audio/video keyword arguments.
        * 'ucf101', 'hmdb51': returns ``UCF101`` / ``HMDB51`` built from
          ``frames_per_clip``, ``transforms``, ``fold`` and ``subsample``.
        * 'kinetics400': returns ``Kinetics400`` built from
          ``frames_per_clip`` and ``transforms``.

    ``mode`` is forwarded as-is to the audio/video datasets and turned into
    ``train=(mode == 'train')`` for the other ones. ``aug_audio`` is only
    passed through, never modified here. ``clips_per_video``,
    ``sampletime``, ``model`` and ``sample_aud_ind`` are accepted for
    call-site compatibility but are not used by this function.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    is_train = mode == 'train'

    if dataset_name in ['kinetics', 'kinetics600', 'audioset', 'vggsound']:
        print("Loading Kinetics custom dataset", flush=True)
        if not dualdata:
            return AVideoDataset(
                ds_name=dataset_name,
                mode=mode,
                num_frames=frames_per_clip,
                seed=seed,
                sample_rate=sample_rate,
                train_crop_size=train_crop_size,
                num_data_samples=num_data_samples,
                colorjitter=colorjitter,
                temp_jitter=temp_jitter,
                center_crop=center_crop,
                aug_audio=aug_audio,
                target_fps=target_fps,
                num_sec=num_sec,
                aud_sample_rate=aud_sample_rate,
                aud_spec_type=aud_spec_type,
                use_volume_jittering=use_volume_jittering,
                use_temporal_jittering=use_temporal_jittering,
                z_normalize=z_normalize
            )
        print("Getting dual dataset")
        # The dual dataset takes `synced` instead of the temporal-jitter /
        # center-crop switches used by the single dataset.
        return AVideoDataset2(
            ds_name=dataset_name,
            mode=mode,
            num_frames=frames_per_clip,
            seed=seed,
            sample_rate=sample_rate,
            train_crop_size=train_crop_size,
            num_data_samples=num_data_samples,
            colorjitter=colorjitter,
            synced=synced,
            target_fps=target_fps,
            aug_audio=aug_audio,
            num_sec=num_sec,
            aud_sample_rate=aud_sample_rate,
            aud_spec_type=aud_spec_type,
            use_volume_jittering=use_volume_jittering,
            use_temporal_jittering=use_temporal_jittering,
            z_normalize=z_normalize,
        )
    if dataset_name == 'ucf101':
        return UCF101(
            frames_per_clip=frames_per_clip,
            step_between_clips=1,
            transform=transforms,
            fold=fold,
            subsample=subsample,
            train=is_train
        )
    if dataset_name == 'hmdb51':
        return HMDB51(
            frames_per_clip=frames_per_clip,
            step_between_clips=1,
            transform=transforms,
            fold=fold,
            subsample=subsample,
            train=is_train
        )
    if dataset_name == 'kinetics400':
        return Kinetics400(
            frames_per_clip=frames_per_clip,
            step_between_clips=1,
            transform=transforms,
            train=is_train
        )
    # The previous `assert ("...")` asserted a non-empty string, which is
    # always true, so unknown names silently produced None.
    raise ValueError("Dataset is not supported: %r" % (dataset_name,))


def correct_hmdb51_name(vid_path):
    """Return a cleaned video name for an HMDB51 video path.

    Takes the last '/'-separated component of ``vid_path``, removes a
    trailing '.mp4' and/or '.avi' extension, and drops the characters
    '(', ')', '&' and ';'.

    Args:
        vid_path: path (or bare file name) of the video, as a string.

    Returns:
        The cleaned name as a string.
    """
    vid_name = vid_path.split('/')[-1]
    # str.strip('.mp4') removes any of the characters '.', 'm', 'p', '4'
    # from BOTH ends (e.g. 'avatar.avi' became 'tar'), so the extension is
    # removed here as a real suffix instead.
    for ext in ('.mp4', '.avi'):
        if vid_name.endswith(ext):
            vid_name = vid_name[:-len(ext)]
    vid_name = ''.join(x for x in vid_name if x not in '()&;')
    return vid_name


def collate_fn(batch):
    """Collate a batch while skipping samples that are ``None``.

    Only the first five fields of every remaining sample are kept before
    handing the batch to ``default_collate``.

    Returns:
        The collated batch, or ``None`` when no sample is left.
    """
    kept = [
        tuple(sample[field] for field in range(5))
        for sample in batch
        if sample is not None
    ]
    if not kept:
        return None
    return default_collate(kept)


def load_optimizer(name, params, lr=1e-4, momentum=0.9, weight_decay=0, model=None):
    """Create a torch optimizer by name.

    Args:
        name: one of 'sgd', 'adam', 'lbfgs', 'adamax', 'adamw', 'rmsprop'.
        params: parameters (or parameter groups) to optimize. Always used
            for 'sgd' and 'adam'.
        lr: learning rate.
        momentum: momentum, used by 'sgd' and 'rmsprop' only.
        weight_decay: weight decay, used by every optimizer except 'lbfgs'.
        model: optional module. For 'lbfgs', 'adamax', 'adamw' and
            'rmsprop' the optimizer is built from ``model.parameters()``
            when a model is given (the historical behavior); when it is
            ``None`` those optimizers fall back to ``params``.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: if ``name`` is not a supported optimizer name.
    """
    def _model_params():
        # These branches used to require `model`; falling back to `params`
        # avoids an AttributeError on the default `model=None`.
        return model.parameters() if model is not None else params

    if name == 'sgd':
        return torch.optim.SGD(
            params,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay
        )
    if name == 'adam':
        print("Loading Adam Optimizer")
        return torch.optim.Adam(
            params,
            lr=lr,
            weight_decay=weight_decay
        )
    if name == 'lbfgs':
        return torch.optim.LBFGS(
            _model_params(),
            lr=lr,
            max_iter=10000
        )
    if name == 'adamax':
        return torch.optim.Adamax(
            _model_params(),
            lr=lr,
            weight_decay=weight_decay
        )
    if name == 'adamw':
        return torch.optim.AdamW(
            _model_params(),
            lr=lr,
            weight_decay=weight_decay
        )
    if name == 'rmsprop':
        return torch.optim.RMSprop(
            _model_params(),
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum
        )
    # The previous `assert("...")` asserted a non-empty string (always
    # true), so an unknown name ended in an UnboundLocalError.
    raise ValueError(
        "Unsupported optimizer %r; expected one of: "
        "sgd, adam, lbfgs, adamax, adamw, rmsprop" % (name,)
    )


def _warmup_batchnorm(args, model, dataset, device, batches=100):
    """
    Feed up to `batches` batches through the model in train mode so that
    batchnorm layers can warm up their running stats.

    Reads `distributed`, `batch_size`, `workers` and `global_rank` from
    `args`. Each batch is expected to unpack into five items, of which the
    first two (video, audio) are moved to `device` and passed to the model.
    """
    print("Warming up batchnorm", flush=True)

    # A distributed sampler is only used for distributed runs; otherwise
    # the loader shuffles on its own.
    sampler = (
        torch.utils.data.distributed.DistributedSampler(dataset)
        if args.distributed
        else None
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=not sampler,
        num_workers=args.workers,
        collate_fn=None,
        pin_memory=True,
    )

    # Train mode, so the forward passes below run the model as in training.
    model.train()

    for step, batch in enumerate(loader):
        video, audio, _, _, _ = batch
        if step == batches:
            break
        if args.global_rank == 0:
            print((step, video.shape), flush=True)

        # Forward pass only; the output is discarded.
        model(video.to(device), audio.to(device))

    if args.distributed:
        dist.barrier()
    print("Finshed warming up batchnorm", flush=True)


def get_transforms(args):
    """Build the train/test video transform pipelines selected by ``args.augtype``.

    Reads ``args.augtype`` and, for augtype 1, ``args.colorjitter``,
    ``args.use_scale_jittering`` and (only with color jitter)
    ``args.dataset``.

    Side effect: for augtype 2 ``args.clip_len`` is set to 8, for augtype 3
    it is set to 32.

    Args:
        args: namespace-like object carrying the attributes listed above.

    Returns:
        Tuple ``(transform_train, transform_test, subsample)``;
        ``subsample`` is always False.

    Raises:
        ValueError: if ``args.augtype`` is not 0, 1, 2 or 3 (this used to
            surface as an UnboundLocalError at the return statement).
    """
    if args.augtype not in (0, 1, 2, 3):
        raise ValueError(
            'Unsupported augtype: %r (expected 0, 1, 2 or 3)' % (args.augtype,)
        )

    compose = torchvision.transforms.Compose
    # An earlier normalization (mean=[0.43216, 0.394666, 0.37645],
    # std=[0.22803, 0.22145, 0.216989]) was built and then immediately
    # overwritten; only the values below were ever in effect.
    normalize = T.Normalize(
        mean=[0.45, 0.45, 0.45],
        std=[0.225, 0.225, 0.225]
    )
    subsample = False

    if args.augtype == 1:
        print("using augmentation type 1: (resize 128,171 ->crop112)")
        use_color = bool(args.colorjitter)
        use_scale = bool(args.use_scale_jittering)
        if use_color and use_scale:
            print("Using Color and Multi-Scale Jittering")
        elif use_color:
            print("Using Color Jittering, No Multi-Scale Jittering")
        elif use_scale:
            print("Using Multi-Scale Jittering, no Color Jittering")
        else:
            print("No Color or Multi-Scale Jittering")

        # The four jitter combinations share one pipeline shape:
        # to-float -> spatial resize -> flip -> [color jitter] -> normalize
        # -> random crop.
        train_steps = [
            T.ToFloatTensorInZeroOne(),
            T.RandomSECrop(128, 160) if use_scale else T.Resize((128, 174)),
            T.RandomHorizontalFlip(),
        ]
        if use_color:
            train_steps.append(
                T.ColorJitter(0.4, 0.4, 0.4) if args.dataset == 'hmdb51'
                else T.ColorJitter(1, 1, 1)
            )
        train_steps.append(normalize)
        # Any kind of jittering trains on 128x128 crops, plain training on
        # 112x112 crops.
        train_crop = (128, 128) if (use_color or use_scale) else (112, 112)
        train_steps.append(T.RandomCrop(train_crop))
        transform_train = compose(train_steps)

        if use_scale:
            transform_test = compose([
                T.ToFloatTensorInZeroOne(),
                T.RandomSECrop(150, 150),
                normalize,
                T.CenterCrop((128, 128))
            ])
        else:
            # NOTE(review): the test pipeline uses RandomCrop here, not
            # CenterCrop -- kept as is, confirm this is intended.
            transform_test = compose([
                T.ToFloatTensorInZeroOne(),
                T.Resize((128, 174)),
                normalize,
                T.RandomCrop((112, 112))
            ])
    elif args.augtype in [2, 3]:
        # NOTE(review): the message mentions type 2 and clip-len 8 even for
        # augtype 3, which sets clip_len to 32 below.
        print("using augmentation type 2: (resize256->crop224). setting clip-len to 8!")
        # note that 8x224x224 = 400k, i.e. the same as 32x112x112
        transform_train = compose([
            T.ToFloatTensorInZeroOne(),
            T.Resize((256, 320)),
            T.RandomHorizontalFlip(),
            normalize,
            T.RandomCrop((112*2, 112*2))
        ])
        transform_test = compose([
            T.ToFloatTensorInZeroOne(),
            T.Resize((256, 320)),
            normalize,
            T.CenterCrop((112*2, 112*2))
        ])
        if args.augtype == 2:
            args.clip_len = 8
        if args.augtype == 3:
            args.clip_len = 32
    else:
        # augtype == 0: deterministic pipeline, identical for train and test.
        # NOTE(review): the message below says "type 1" although this is
        # the augtype 0 branch.
        print("using augmentation type 1: (resize 128,171 ->Cemtercrop112)")
        transform_train = compose([
            T.ToFloatTensorInZeroOne(),
            T.Resize((128, 171)),
            normalize,
            T.CenterCrop((112, 112))
        ])
        transform_test = compose([
            T.ToFloatTensorInZeroOne(),
            T.Resize((128, 171)),
            normalize,
            T.CenterCrop((112, 112))
        ])
    return transform_train, transform_test, subsample


def get_ds(args, epoch):
    """Build the training dataset described by ``args``.

    ``epoch`` is forwarded to ``load_dataset`` as the ``seed``. The time
    spent in ``load_dataset`` is printed.

    Returns:
        The dataset returned by ``load_dataset`` in 'train' mode.
    """
    # get_transforms may overwrite args.clip_len, so it has to run before
    # args.clip_len is read below. The test transform is not needed here.
    train_transform, _, subsample = get_transforms(args)

    print("Loading data", flush=True)
    started = time.time()
    dataset_kwargs = dict(
        dataset_name=args.dataset,
        fold=args.fold,
        mode='train',
        frames_per_clip=args.clip_len,
        transforms=train_transform,
        subsample=subsample,
        clips_per_video=args.clips_per_video,
        num_data_samples=args.num_data_samples,
        seed=epoch,
        model=args.model,
        sample_aud_ind=args.sample_aud_ind,
        sample_rate=args.sample_rate,
        train_crop_size=args.train_crop_size,
        colorjitter=args.colorjitter,
        dualdata=args.dualdata,
        # args.asynced == 0 means audio and video stay in sync
        synced=(args.asynced == 0),
        temp_jitter=args.use_temp_jitter,
        center_crop=args.center_crop,
        target_fps=args.target_fps,
        aug_audio=args.aug_audio,
        num_sec=args.num_sec,
        aud_sample_rate=args.aud_sample_rate,
        aud_spec_type=args.aud_spec_type,
        use_volume_jittering=args.use_volume_jittering,
        use_temporal_jittering=args.use_temporal_jittering,
        z_normalize=args.z_normalize,
    )
    train_dataset = load_dataset(**dataset_kwargs)
    print(f"Took {time.time() - started}", flush=True)
    return train_dataset


def get_dataloader(args, epoch):
    """Build the training dataset and a DataLoader over it.

    The dataset is created by ``get_ds`` (``epoch`` is used as its seed),
    so both entry points always construct it from the same arguments. For
    distributed runs a ``DistributedSampler`` is created and its epoch is
    set to ``epoch``.

    Args:
        args: namespace with the dataset options read by ``get_ds`` plus
            ``distributed``, ``batch_size`` and ``workers``.
        epoch: current epoch number.

    Returns:
        Tuple ``(dataset, data_loader)``. The loader drops the last
        incomplete batch and uses ``collate_fn``, which skips ``None``
        samples.
    """
    # Single source of truth for dataset construction (prints
    # "Loading data" and the elapsed time, exactly as before).
    dataset = get_ds(args, epoch)

    print("Creating data loaders", flush=True)
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        train_sampler.set_epoch(epoch)

    # NOTE(review): without a sampler no `shuffle` argument is passed, so
    # the DataLoader iterates in its default (sequential) order -- confirm
    # that this is intended for non-distributed training.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
        drop_last=True
    )
    return dataset, data_loader


def get_list_of_tensors(cat_tensor, num_heads):
    """Split a concatenated tensor back into a list of consecutive chunks.

    The chunk length is ``int(len(cat_tensor) / num_heads)``; when the
    length is not a multiple of ``num_heads`` a shorter trailing chunk is
    included as well.
    """
    total = len(cat_tensor)
    chunk = int(total / num_heads)
    return [
        cat_tensor[start: start + chunk]
        for start in range(0, total, chunk)
    ]


class Flatten(nn.Module):
    """Module that reshapes its input to ``(input.size(0), -1)``."""

    def __init__(self):
        super().__init__()

    def forward(self, input_tensor):
        # Keep the first dimension and merge all remaining ones.
        leading = input_tensor.size(0)
        return input_tensor.view(leading, -1)


def reduce_negatives(f1, f2, num_neg):
    """Pick ``num_neg`` random rows, using the same indices for both inputs.

    Indices are drawn with ``np.random.choice`` over ``len(f1)`` using its
    default settings, so an index may be picked more than once.

    Returns:
        Tuple ``(f1[idx], f2[idx])``.
    """
    picked = np.random.choice(len(f1), size=num_neg)
    return f1[picked], f2[picked]

def get_hyp_str(args):
    """Map the ``asynced`` / ``arrowtime`` settings to a short tag.

    ``args.asynced`` takes precedence: -1 gives 'vasynced', 1 gives
    'iasynced'. Otherwise ``args.arrowtime`` decides: -1 gives 'vtime',
    1 gives 'itime'. Anything else gives 'base'.
    """
    if args.asynced == -1:
        return 'vasynced'
    if args.asynced == 1:
        return 'iasynced'
    if args.arrowtime == -1:
        return 'vtime'
    if args.arrowtime == 1:
        return 'itime'
    return 'base'


class PermuteIter():
    """Iterator object that helps with on-GPU dataset

    Yields consecutive slices of ``bs`` elements from ``permutation``.
    Only full batches are produced: ``len(permutation) // bs`` of them,
    which is also what ``len()`` reports. A trailing remainder shorter
    than ``bs`` is dropped.
    """
    def __init__(self, permutation, bs):
        self.len = len(permutation) // bs  # number of full batches
        self.iter = permutation
        self.bs = bs
        self.n = 0  # element offset of the next batch

    def __iter__(self):
        # Restart from the beginning so the object can be iterated again.
        self.n = 0
        return self

    def __next__(self):
        # `n` is an element offset, so it is bounded by the permutation
        # length. Comparing it with the batch count (as before) stopped
        # the iteration far too early, or yielded an extra empty batch.
        if self.n + self.bs > len(self.iter):
            raise StopIteration
        result = self.iter[self.n:self.n + self.bs]
        self.n += self.bs
        return result

    def __len__(self):
        return self.len
