import os
import math
import torch
import numpy as np
import torch.distributed as dist
from collections import OrderedDict
from timm.utils import get_state_dict
try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None


def load_ema_checkpoint(config, model_ema, logger):
    """Restore EMA weights (and EMA bookkeeping) from ``config.MODEL.RESUME``.

    ``config.MODEL.RESUME`` is either an ``https`` URL (downloaded through
    ``torch.hub`` with hash checking) or a local path. When the checkpoint
    holds a ``'model_ema'`` entry, it is loaded non-strictly into
    ``model_ema.ema``. If ``model_ema.ema_has_module`` is true, every key
    lacking a ``'module.'`` prefix gets one.

    ``model_ema`` must expose ``ema``, ``ema_has_module`` and ``decay``
    (presumably a timm ``ModelEma`` — TODO confirm against the caller).

    Args:
        config: config node providing ``MODEL.RESUME``.
        model_ema: EMA wrapper whose ``ema`` module receives the weights.
        logger: logger for progress and warnings.

    Returns:
        The stored ``'max_accuracy_ema'``, or 0 when absent. As a side
        effect, ``model_ema.decay`` is overwritten with ``'ema_decay'``
        when the checkpoint has one.
    """
    logger.info(
        f'==============> Resuming form {config.MODEL.RESUME}....................'
    )
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(config.MODEL.RESUME,
                                                        map_location='cpu',
                                                        check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')

    assert isinstance(checkpoint, dict)
    if 'model_ema' in checkpoint:
        add_prefix = model_ema.ema_has_module
        new_state_dict = OrderedDict()
        for k, v in checkpoint['model_ema'].items():
            # Test for the full 'module.' prefix: a bare 'module' test would
            # wrongly treat keys such as 'module_list.0.weight' as prefixed.
            if add_prefix and not k.startswith('module.'):
                k = 'module.' + k
            new_state_dict[k] = v
        msg = model_ema.ema.load_state_dict(new_state_dict, strict=False)
        logger.info(msg)
        logger.info('Loaded state_dict_ema')
    else:
        logger.warning(
            'Failed to find state_dict_ema, starting from loaded model weights'
        )

    max_accuracy_ema = checkpoint.get('max_accuracy_ema', 0)
    if 'ema_decay' in checkpoint:
        model_ema.decay = checkpoint['ema_decay']
    # Drop the (possibly large) checkpoint before returning.
    del checkpoint
    return max_accuracy_ema


def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):
    """Resume model (and, when training, optimizer/scheduler/AMP) state.

    ``config.MODEL.RESUME`` is either an ``https`` URL (downloaded through
    ``torch.hub`` with hash checking) or a local path. Model weights are read
    from the ``'model'`` entry, falling back to ``'state_dict'``, and loaded
    non-strictly.

    Outside ``config.EVAL_MODE``, and only when the checkpoint has
    ``'optimizer'``, ``'lr_scheduler'`` and ``'epoch'``, the following also
    happens:

    * the optimizer is restored on a best-effort basis (a failure is logged,
      not raised);
    * the scheduler is restored;
    * ``config.TRAIN.START_EPOCH`` is set to ``epoch + 1``;
    * the AMP scaler is restored when AMP is enabled now and the checkpoint
      was not saved with ``AMP_OPT_LEVEL == 'O0'``.

    Args:
        config: config node with ``MODEL.RESUME``, ``EVAL_MODE``,
            ``AMP_OPT_LEVEL``, ``TRAIN.START_EPOCH`` and ``defrost``/``freeze``.
        model: module receiving the weights.
        optimizer: optimizer to restore, or None to skip.
        lr_scheduler: scheduler to restore, or None to skip.
        scaler: object with ``load_state_dict`` for AMP state, or None to skip.
        logger: logger for progress and warnings.

    Returns:
        The stored ``'max_accuracy'`` when the training state was restored,
        else 0.0.

    Raises:
        AssertionError: if the checkpoint has neither ``'model'`` nor
            ``'state_dict'``.
    """
    logger.info(
        f'==============> Resuming form {config.MODEL.RESUME}....................'
    )
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(config.MODEL.RESUME,
                                                        map_location='cpu',
                                                        check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')

    logger.info('resuming model')
    if 'model' in checkpoint:
        model_checkpoint = checkpoint['model']
    elif 'state_dict' in checkpoint:
        model_checkpoint = checkpoint['state_dict']
    else:
        raise AssertionError(
            f"checkpoint {config.MODEL.RESUME} has neither a 'model' nor a "
            "'state_dict' entry")

    msg = model.load_state_dict(model_checkpoint, strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    has_train_state = all(
        key in checkpoint for key in ('optimizer', 'lr_scheduler', 'epoch'))
    if not config.EVAL_MODE and has_train_state:
        if optimizer is not None:
            logger.info('resuming optimizer')
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except Exception as err:
                # Deliberately best-effort (e.g. param groups changed between
                # runs), but no longer swallows KeyboardInterrupt/SystemExit.
                logger.warning(f'resume optimizer failed: {err}')
        if lr_scheduler is not None:
            logger.info('resuming lr_scheduler')
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if ('amp' in checkpoint and scaler is not None
                and config.AMP_OPT_LEVEL != 'O0'):
            # The saved config may be missing; treat a present 'amp' entry as
            # evidence that AMP was on when the checkpoint was written.
            saved_config = checkpoint.get('config')
            if getattr(saved_config, 'AMP_OPT_LEVEL', None) != 'O0':
                scaler.load_state_dict(checkpoint['amp'])
        logger.info(
            f"=> loaded successfully {config.MODEL.RESUME} (epoch {checkpoint['epoch']})"
        )
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    del checkpoint, model_checkpoint
    torch.cuda.empty_cache()

    return max_accuracy


def load_pretrained(config, model, logger):
    """Load pretrained weights from ``config.MODEL.PRETRAINED`` for fine-tuning.

    The checkpoint is adapted to ``model`` before a non-strict
    ``load_state_dict``:

    * The weight dict comes from the ``'model'``, ``'module'`` or
      ``'state_dict'`` entry (first match), otherwise the checkpoint itself.
    * If the first key mentions ``student``/``teacher``, only ``student``
      weights are kept (minus ``student_proj``), with ``'student.'`` stripped.
    * If the first key mentions ``mask_token``, only ``dcnv3`` (not
      ``mm_dcnv3``) and ``clip_projector`` weights are kept, with ``'dcnv3.'``
      stripped. The ``clip.classifier*`` entries become ``fc_norm``/``head``.
    * ``relative_position_index``, ``relative_coords_table`` and
      ``attn_mask`` entries are dropped.
    * ``relative_position_bias_table`` / ``absolute_pos_embed`` are resized
      bicubically when their token count differs. Entries whose head/channel
      count differs are dropped with a warning.
    * A classifier head with a different class count is either scaled by
      0.001 and left un-loaded (``config.TRAIN.RAND_INIT_FT_HEAD``), or
      remapped 21841/47338 -> 1000 classes with ``meta_data/map22kto1k.txt``.

    Args:
        config: config node providing ``MODEL.PRETRAINED`` and
            ``TRAIN.RAND_INIT_FT_HEAD``.
        model: module to load into. Keys are matched against
            ``model.state_dict()`` and the head is reached as ``model.head``
            (or ``model.head[2]`` for an MLP head).
        logger: logger for progress and warnings.
    """
    logger.info(
        f'==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......'
    )
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')

    state_dict = checkpoint
    for container_key in ('model', 'module', 'state_dict'):
        if container_key in checkpoint:
            state_dict = checkpoint[container_key]
            break

    # The checkpoint flavour is recognised from its first key only.
    first_key = next(iter(state_dict), '')

    # Distillation checkpoint: keep the student weights only.
    if 'student' in first_key or 'teacher' in first_key:
        state_dict = OrderedDict(
            (k.replace('student.', ''), v) for k, v in state_dict.items()
            if 'student' in k and 'student_proj' not in k)

    # Weights from sim.
    if 'mask_token' in first_key:
        new_state_dict = OrderedDict(
            (k.replace('dcnv3.', ''), v) for k, v in state_dict.items()
            if 'mm_dcnv3' not in k and ('dcnv3' in k or 'clip_projector' in k))
        new_state_dict['fc_norm.weight'] = state_dict['clip.classifier_ln.weight']
        new_state_dict['fc_norm.bias'] = state_dict['clip.classifier_ln.bias']
        new_state_dict['head.weight'] = state_dict['clip.classifier.weight']
        new_state_dict['head.bias'] = state_dict['clip.classifier.bias']
        state_dict = new_state_dict

    # Drop buffers that the model always re-initialises itself.
    reinit_names = ('relative_position_index', 'relative_coords_table', 'attn_mask')
    for k in [k for k in state_dict if any(name in k for name in reinit_names)]:
        del state_dict[k]

    # Computed once; the original re-built the model state dict per key.
    model_state = model.state_dict()

    # Bicubic-interpolate relative_position_bias_table when sizes differ.
    for k in [k for k in state_dict if 'relative_position_bias_table' in k]:
        if k not in model_state:
            continue  # reported as unexpected by load_state_dict
        table_pretrained = state_dict[k]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = model_state[k].size()
        if nH1 != nH2:
            logger.warning(f'Error in loading {k}, passing......')
            # Actually skip it; leaving it would make load_state_dict raise a
            # size-mismatch error even with strict=False.
            del state_dict[k]
        elif L1 != L2:
            S1 = int(L1**0.5)
            S2 = int(L2**0.5)
            table_resized = torch.nn.functional.interpolate(
                table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                size=(S2, S2),
                mode='bicubic')
            state_dict[k] = table_resized.view(nH2, L2).permute(1, 0)

    # Bicubic-interpolate absolute_pos_embed when sizes differ.
    for k in [k for k in state_dict if 'absolute_pos_embed' in k]:
        if k not in model_state:
            continue
        pos_pretrained = state_dict[k]
        _, L1, C1 = pos_pretrained.size()
        _, L2, C2 = model_state[k].size()
        if C1 != C2:  # was `C1 != C1`, which could never be true
            logger.warning(f'Error in loading {k}, passing......')
            del state_dict[k]
        elif L1 != L2:
            S1 = int(L1**0.5)
            S2 = int(L2**0.5)
            pos_grid = pos_pretrained.reshape(-1, S1, S1, C1).permute(0, 3, 1, 2)
            pos_grid = torch.nn.functional.interpolate(pos_grid,
                                                       size=(S2, S2),
                                                       mode='bicubic')
            state_dict[k] = pos_grid.permute(0, 2, 3, 1).flatten(1, 2)

    # Check the classifier; if the class count differs, adapt or re-init it.
    if 'head.bias' in state_dict:
        head_key = 'head'  # one-layer head
    elif 'head.2.bias' in state_dict:
        head_key = 'head.2'  # MLP head
    else:
        head_key = None
    if head_key is not None:
        # Resolve the head module by attribute path instead of eval();
        # getattr(seq, '2') reaches child '2' of an nn.Sequential.
        head = model
        for attr in head_key.split('.'):
            head = getattr(head, attr)
        Nc1 = state_dict[f'{head_key}.bias'].shape[0]
        Nc2 = head.bias.shape[0]

        if Nc1 != Nc2:
            if config.TRAIN.RAND_INIT_FT_HEAD:
                head.weight.data = head.weight.data * 0.001
                head.bias.data = head.bias.data * 0.001
                del state_dict[f'{head_key}.weight']
                del state_dict[f'{head_key}.bias']
                logger.warning(
                    f'Error in loading classifier head, re-init classifier head to 0'
                )
            elif Nc2 == 1000 and Nc1 in (21841, 47338):
                if Nc1 == 21841:
                    logger.info(
                        'loading ImageNet-22K weight to ImageNet-1K ......')
                else:
                    # NOTE(review): the Bamboo-47K head is remapped with the
                    # ImageNet-22K index file; confirm this is intended.
                    logger.info('loading Bamboo-47K weight to ImageNet-1K ......')
                map22kto1k_path = 'meta_data/map22kto1k.txt'
                logger.info(map22kto1k_path)
                with open(map22kto1k_path) as f:
                    map22kto1k = [int(line.strip()) for line in f]
                state_dict[f'{head_key}.weight'] = state_dict[f'{head_key}.weight'][
                    map22kto1k, :]
                state_dict[f'{head_key}.bias'] = state_dict[f'{head_key}.bias'][map22kto1k]

    msg = model.load_state_dict(state_dict, strict=False)
    logger.warning(msg)

    logger.info(f'=> loaded successfully {config.MODEL.PRETRAINED}')

    del checkpoint, state_dict, model_state
    torch.cuda.empty_cache()


def convert_22k_head_to_1k(model, logger):
    """Shrink a 21841-way classifier at ``model.module.head`` to 1000 classes.

    The rows to keep are read, one integer per line, from
    ``meta_data/map22kto1k.txt``. The head's ``weight``/``bias`` are replaced
    with new ``Parameter``s built from those rows. Any other head size only
    produces a warning and leaves the model untouched.

    Args:
        model: wrapper exposing ``module.head`` (e.g. a DDP-wrapped model).
        logger: logger for progress and warnings.

    Returns:
        ``model`` itself, modified in place.
    """
    head = model.module.head
    weight, bias = head.weight, head.bias

    if bias.shape[0] != 21841:
        logger.warning('Error in converting classifier head')
        return model

    logger.info('converting ImageNet-22K head to ImageNet-1K ......')
    map_path = 'meta_data/map22kto1k.txt'
    logger.info(map_path)
    with open(map_path) as handle:
        kept_rows = [int(line.strip()) for line in handle.readlines()]
    head.weight = torch.nn.Parameter(weight[kept_rows, :])
    head.bias = torch.nn.Parameter(bias[kept_rows])
    return model


def save_checkpoint(config,
                    epoch,
                    model,
                    max_accuracy,
                    optimizer,
                    lr_scheduler,
                    scaler,
                    logger,
                    model_ema=None,
                    max_accuracy_ema=None,
                    ema_decay=None,
                    model_ems=None,
                    max_accuracy_ems=None,
                    ems_model_num=None,
                    best=None):
    """Write a training checkpoint to ``config.OUTPUT`` and prune old ones.

    The file is named ``ckpt_epoch_{epoch}.pth``, or ``ckpt_epoch_{best}.pth``
    when ``best`` is given. It always stores the model, optimizer and
    scheduler state, ``max_accuracy``, ``epoch`` and ``config``. Optional EMA
    entries are stored only when their argument is not None. The AMP scaler
    state goes under ``'amp'`` whenever ``config.AMP_OPT_LEVEL != 'O0'``.

    For an integer ``epoch``, on rank 0 (or when no process group is
    initialised), the checkpoint from
    ``epoch - config.SAVE_CKPT_NUM * config.SAVE_FREQ`` is deleted if present.

    Args:
        config: config node with ``OUTPUT``, ``AMP_OPT_LEVEL``,
            ``SAVE_CKPT_NUM`` and ``SAVE_FREQ``.
        epoch: current epoch (non-int values skip pruning).
        model, optimizer, lr_scheduler: objects exposing ``state_dict()``.
        max_accuracy: best accuracy so far.
        scaler: AMP scaler; only used when AMP is enabled.
        logger: logger for progress messages.
        model_ema, model_ems: EMA models serialised via timm ``get_state_dict``.
        max_accuracy_ema, ema_decay, max_accuracy_ems, ems_model_num:
            optional scalars stored verbatim.
        best: optional tag replacing the epoch in the file name.
    """
    save_state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'max_accuracy': max_accuracy,
        'epoch': epoch,
        'config': config
    }
    if model_ema is not None:
        save_state['model_ema'] = get_state_dict(model_ema)
    if max_accuracy_ema is not None:
        save_state['max_accuracy_ema'] = max_accuracy_ema
    if ema_decay is not None:
        save_state['ema_decay'] = ema_decay
    if model_ems is not None:
        save_state['model_ems'] = get_state_dict(model_ems)
    if max_accuracy_ems is not None:
        save_state['max_accuracy_ems'] = max_accuracy_ems
    if ems_model_num is not None:
        save_state['ems_model_num'] = ems_model_num
    if config.AMP_OPT_LEVEL != 'O0':
        save_state['amp'] = scaler.state_dict()

    tag = epoch if best is None else best
    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{tag}.pth')
    logger.info(f'{save_path} saving......')
    torch.save(save_state, save_path)
    logger.info(f'{save_path} saved !!!')

    # dist.get_rank() raises without an initialised process group, so
    # single-process runs count as rank 0.
    distributed = dist.is_available() and dist.is_initialized()
    is_main_process = not distributed or dist.get_rank() == 0
    if is_main_process and isinstance(epoch, int):
        to_del = epoch - config.SAVE_CKPT_NUM * config.SAVE_FREQ
        old_ckpt = os.path.join(config.OUTPUT, f'ckpt_epoch_{to_del}.pth')
        if os.path.exists(old_ckpt):
            os.remove(old_ckpt)


def get_grad_norm(parameters, norm_type=2):
    """Return the total gradient norm of ``parameters`` as a Python float.

    Parameters without a gradient are ignored. For a finite ``norm_type`` p,
    the result is ``(sum_i ||g_i||_p ** p) ** (1/p)``. For ``norm_type=inf``
    it is the largest per-tensor max-abs gradient, matching
    ``torch.nn.utils.clip_grad_norm_``. Returns 0.0 when no gradient exists.

    Args:
        parameters: a tensor or an iterable of tensors.
        norm_type: order of the norm (anything ``float()`` accepts).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return 0.0
    if norm_type == math.inf:
        # The p-norm formula degenerates for p=inf (x**inf then **0 -> 1.0).
        return max(g.norm(norm_type).item() for g in grads)
    total = sum(g.norm(norm_type).item()**norm_type for g in grads)
    return total**(1. / norm_type)


def auto_resume_helper(output_dir):
    """Return the most recently modified ``*pth`` file in ``output_dir``.

    Prints the candidate file names and the chosen path. Returns the full
    path of the newest candidate by modification time, or None when the
    directory holds no file ending in ``'pth'``.
    """
    candidates = [name for name in os.listdir(output_dir) if name.endswith('pth')]
    print(f'All checkpoints founded in {output_dir}: {candidates}')
    if not candidates:
        return None

    full_paths = (os.path.join(output_dir, name) for name in candidates)
    newest = max(full_paths, key=os.path.getmtime)
    print(f'The latest checkpoint founded: {newest}')
    return newest


def reduce_tensor(tensor):
    """Return the mean of ``tensor`` across all ranks of the default group.

    Works on a clone, so the caller's tensor is not modified. Requires an
    initialised ``torch.distributed`` process group.
    """
    world_size = dist.get_world_size()
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    return summed.div_(world_size)


# https://github.com/facebookresearch/ConvNeXt/blob/main/utils.py
class NativeScalerWithGradNormCount:
    """Wrap ``torch.cuda.amp.GradScaler`` and report the gradient norm per step.

    Adapted from ConvNeXt's ``utils.py``. Calling the instance runs the
    scaled backward pass and, unless ``update_grad`` is False, unscales the
    gradients, optionally clips them, steps the optimizer and updates the
    scale.
    """
    state_dict_key = 'amp_scaler'

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self,
                 loss,
                 optimizer,
                 clip_grad=None,
                 parameters=None,
                 create_graph=False,
                 update_grad=True):
        """Backpropagate ``loss`` and (optionally) apply an optimizer step.

        Args:
            loss: scalar loss tensor.
            optimizer: optimizer whose parameters receive the gradients.
            clip_grad: max gradient norm for clipping, or None for no clipping.
            parameters: parameters to clip/measure. Required when
                ``clip_grad`` is set. Without clipping, the norm is reported
                only if these are given.
            create_graph: forwarded to ``backward``.
            update_grad: when False, only accumulate gradients (no step).

        Returns:
            The gradient norm (tensor when clipping, float from
            ``get_grad_norm`` otherwise), or None when no step was taken or
            no ``parameters`` were given.
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            return None
        if clip_grad is not None:
            assert parameters is not None
        # Unscale the optimizer's gradients in place before measuring/clipping.
        self._scaler.unscale_(optimizer)
        if clip_grad is not None:
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        elif parameters is not None:
            norm = get_grad_norm(parameters)
        else:
            # Previously get_grad_norm(None) crashed iterating over None.
            norm = None
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the wrapped GradScaler's state."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped GradScaler's state."""
        self._scaler.load_state_dict(state_dict)


class MyAverageMeter(object):
    """Track the latest value plus the mean and std of a value history.

    Non-finite values (NaN/inf) are stored in ``val`` but kept out of the
    history. When ``max_len > 0``, only the newest ``max_len`` finite values
    are retained. Note that ``var`` holds the standard deviation
    (``np.std``), not the variance.
    """

    def __init__(self, name=None, max_len=-1):
        self.val_list = []  # finite values, trimmed to max_len when max_len > 0
        self.count = []  # not used by this class; kept for external callers
        self.avg_name = name
        self.max_len = max_len
        self.val = 0
        self.avg = 0
        self.var = 0

    def update(self, val):
        """Record ``val`` and recompute ``avg`` / ``var`` over the history."""
        self.val = val
        self.avg = self.var = 0
        if math.isfinite(val):
            self.val_list.append(val)
        if 0 < self.max_len < len(self.val_list):
            self.val_list = self.val_list[-self.max_len:]
        if self.val_list:
            history = np.array(self.val_list)
            self.avg = np.mean(history)
            self.var = np.std(history)
