# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# Modified by Zhenda Xie
# --------------------------------------------------------

import os
import torch
import torch.distributed as dist
import numpy as np
from scipy import interpolate

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None

def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    """Resume model weights and, when training, optimizer/scheduler state.

    The checkpoint named by ``config.MODEL.RESUME`` is fetched through
    ``torch.hub`` when the path starts with ``https`` and read with
    ``torch.load`` otherwise. Model weights are always loaded non-strictly.
    Training state is restored only when ``config.EVAL_MODE`` is false and the
    checkpoint carries ``optimizer``, ``lr_scheduler`` and ``epoch`` entries;
    in that case ``config.TRAIN.START_EPOCH`` is set to the saved epoch + 1.

    Args:
        config: Config object; reads ``MODEL.RESUME``, ``EVAL_MODE`` and
            ``AMP_OPT_LEVEL`` and writes ``TRAIN.START_EPOCH`` (between
            ``defrost()`` and ``freeze()``).
        model: Module whose ``load_state_dict`` receives ``checkpoint['model']``.
        optimizer: Optimizer restored from ``checkpoint['optimizer']``.
        lr_scheduler: Scheduler restored from ``checkpoint['lr_scheduler']``.
        logger: Logger used for progress messages.

    Returns:
        The ``max_accuracy`` stored in the checkpoint when training state was
        restored and the entry exists, otherwise ``0.0``.
    """
    logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    has_training_state = all(
        entry in checkpoint for entry in ('optimizer', 'lr_scheduler', 'epoch'))
    if not config.EVAL_MODE and has_training_state:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
            if amp is None:
                # The module-level apex import fell back to None; calling
                # amp.load_state_dict would raise an opaque AttributeError.
                logger.warning(
                    "Checkpoint holds apex amp state but apex is not installed; "
                    "skipping amp state restore.")
            else:
                amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    # Drop the CPU copy of the checkpoint before training continues.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy


def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, tag=None):
    """Write a training checkpoint into ``config.OUTPUT`` on a best-effort basis.

    The file is named ``ckpt_epoch_{epoch}.pth``, prefixed with ``{tag}_`` when
    ``tag`` is truthy. A failed write is logged and swallowed so that training
    is not interrupted by a checkpointing problem.

    Args:
        config: Config object; reads ``AMP_OPT_LEVEL`` and ``OUTPUT`` and is
            itself stored in the checkpoint under ``'config'``.
        epoch: Epoch number stored in the checkpoint and used in the file name.
        model: Module whose ``state_dict()`` is saved.
        max_accuracy: Best accuracy so far, stored as-is.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        lr_scheduler: Scheduler whose ``state_dict()`` is saved.
        logger: Logger used for progress and failure messages.
        tag: Optional file-name prefix.
    """
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}
    if config.AMP_OPT_LEVEL != "O0":
        if amp is None:
            # The module-level apex import fell back to None; without this
            # guard amp.state_dict() raises AttributeError and nothing is saved.
            logger.warning(
                f"AMP_OPT_LEVEL is {config.AMP_OPT_LEVEL} but apex is not installed; "
                "checkpoint is saved without amp state.")
        else:
            save_state['amp'] = amp.state_dict()

    file_name = f'{tag}_ckpt_epoch_{epoch}.pth' if tag else f'ckpt_epoch_{epoch}.pth'
    save_path = os.path.join(config.OUTPUT, file_name)
    logger.info(f"{save_path} saving......")
    try:
        torch.save(save_state, save_path)
    except Exception as exc:
        # Deliberately non-fatal, but KeyboardInterrupt/SystemExit now
        # propagate and the failure reason is recorded in the log.
        logger.warning(f"save is not available now ({save_path}): {exc!r}")
    else:
        logger.info(f"{save_path} saved !!!")


def get_grad_norm(parameters, norm_type=2):
    """Return the total gradient norm over ``parameters`` as a Python float.

    Parameters whose ``.grad`` is ``None`` are ignored. For a finite
    ``norm_type`` p the result is ``(sum_i ||g_i||_p ** p) ** (1 / p)``; for
    ``norm_type == inf`` it is the largest per-parameter infinity norm.

    Args:
        parameters: A single tensor or an iterable of tensors.
        norm_type: Order of the norm; converted with ``float()``.

    Returns:
        The combined norm, ``0.0`` when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)

    if norm_type == float('inf'):
        # The power-sum formula degenerates here (x ** inf, then ** 0.0 == 1.0),
        # so the max-norm is taken directly.
        return max((g.norm(norm_type).item() for g in grads), default=0.0)

    total_norm = sum(g.norm(norm_type).item() ** norm_type for g in grads)
    return total_norm ** (1. / norm_type)


def auto_resume_helper(output_dir, logger):
    """Locate the most recently modified checkpoint file in ``output_dir``.

    Directory entries whose names end in ``pth`` are treated as checkpoints;
    the one with the greatest modification time wins.

    Args:
        output_dir: Directory to scan (non-recursively).
        logger: Logger that receives the list of candidates and the choice.

    Returns:
        Path of the newest checkpoint (``output_dir`` joined with the file
        name), or ``None`` when the directory holds no checkpoint.
    """
    found = list(filter(lambda name: name.endswith('pth'), os.listdir(output_dir)))
    logger.info(f"All checkpoints founded in {output_dir}: {found}")
    if not found:
        return None

    candidates = [os.path.join(output_dir, name) for name in found]
    newest = max(candidates, key=os.path.getmtime)
    logger.info(f"The latest checkpoint founded: {newest}")
    return newest


def reduce_tensor(tensor):
    """Return a copy of ``tensor`` averaged across all distributed processes.

    The input is left untouched: a clone is sum-reduced with
    ``dist.all_reduce`` and then divided in place by ``dist.get_world_size()``.
    """
    world_size = dist.get_world_size()
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged /= world_size
    return averaged


def load_comm_pretrained(config, model, logger, comm_adap_init=False):
    """Load pre-trained weights from ``config.PRETRAINED`` into ``model``.

    The state dict is taken from ``checkpoint['state_dict']`` when present and
    from ``checkpoint['model']`` otherwise. The first wrapper prefix found in
    any key (checked in the order ``module.encoder.``, ``module.``,
    ``encoder.``) is stripped and only keys starting with that prefix are
    kept. Keys containing ``head`` are dropped in every case except the
    ``module.encoder.`` one. For ``config.MODEL.TYPE`` of ``'swin'`` or
    ``'vit'`` the keys are additionally remapped before a non-strict load.

    Args:
        config: Config object; reads ``PRETRAINED`` and ``MODEL.TYPE``.
        model: Module receiving the weights.
        logger: Logger used for progress messages.
        comm_adap_init: Not used by this function; retained so existing
            callers keep working.
    """
    logger.info(f">>>>>>>>>> Fine-tuned from {config.PRETRAINED} ..........")
    checkpoint = torch.load(config.PRETRAINED, map_location='cpu')
    if 'state_dict' in checkpoint:
        checkpoint_model = checkpoint['state_dict']
    else:
        checkpoint_model = checkpoint['model']

    # (prefix, drop keys containing 'head'); order matters because the
    # longer prefix must be tried before the shorter ones it contains.
    prefix_rules = (
        ('module.encoder.', False),
        ('module.', True),
        ('encoder.', True),
    )
    for prefix, drop_head in prefix_rules:
        if any(prefix in k for k in checkpoint_model):
            checkpoint_model = {
                k.replace(prefix, ''): v
                for k, v in checkpoint_model.items()
                if k.startswith(prefix) and not (drop_head and 'head' in k)
            }
            logger.info(f'Detect pre-trained model, remove [{prefix}] prefix.')
            break
    else:
        checkpoint_model = {k: v for k, v in checkpoint_model.items() if 'head' not in k}
        logger.info('Detect pre-trained model, no remove prefix.')

    if config.MODEL.TYPE == 'swin':
        logger.info(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
        checkpoint_model = remap_pretrained_keys_swin(model, checkpoint_model, logger)
    elif config.MODEL.TYPE == 'vit':
        logger.info(f">>>>>>>>>> Remapping pre-trained keys for VIT ..........")
        checkpoint_model = remap_pretrained_keys_vit(model, checkpoint_model, logger)
    # Any other model type is loaded without remapping.

    msg = model.load_state_dict(checkpoint_model, strict=False)
    logger.info(msg)

    del checkpoint
    torch.cuda.empty_cache()
    logger.info(f">>>>>>>>>> loaded successfully '{config.PRETRAINED}'")
    
    

def load_pretrained(config, model, logger):
    """Load pre-trained weights from ``config.PRETRAINED`` into ``model``.

    The state dict is taken from ``checkpoint['state_dict']`` when present and
    from ``checkpoint['model']`` otherwise. The first wrapper prefix found in
    any key (checked in the order ``module.encoder.0.``, ``module.encoder.``,
    ``module.``, ``encoder.``) is stripped; only keys starting with it are
    kept and keys containing ``head`` are dropped. For the ``module.`` case
    the final ``norm.*`` entries are renamed to ``fc_norm.*`` when present.
    For ``config.MODEL.TYPE`` of ``'swin'`` or ``'vit'`` the keys are
    additionally remapped (best-effort for ViT) before a non-strict load.

    Args:
        config: Config object; reads ``PRETRAINED`` and ``MODEL.TYPE``.
        model: Module receiving the weights.
        logger: Logger used for progress messages.
    """
    logger.info(f">>>>>>>>>> Fine-tuned from {config.PRETRAINED} ..........")
    checkpoint = torch.load(config.PRETRAINED, map_location='cpu')
    if 'state_dict' in checkpoint:
        checkpoint_model = checkpoint['state_dict']
    else:
        checkpoint_model = checkpoint['model']

    # Longest prefix first: each later prefix is a substring of an earlier one.
    known_prefixes = ('module.encoder.0.', 'module.encoder.', 'module.', 'encoder.')
    matched_prefix = next(
        (prefix for prefix in known_prefixes
         if any(prefix in k for k in checkpoint_model)),
        None)

    if matched_prefix is None:
        checkpoint_model = {k: v for k, v in checkpoint_model.items() if 'head' not in k}
        logger.info('Detect pre-trained model, no remove prefix.')
    else:
        checkpoint_model = {
            k.replace(matched_prefix, ''): v
            for k, v in checkpoint_model.items()
            if k.startswith(matched_prefix) and 'head' not in k
        }
        if matched_prefix == 'module.':
            # Rename the final norm to fc_norm. Guarded so a checkpoint
            # without these entries no longer aborts with KeyError.
            for suffix in ('weight', 'bias'):
                if f'norm.{suffix}' in checkpoint_model:
                    checkpoint_model[f'fc_norm.{suffix}'] = checkpoint_model.pop(f'norm.{suffix}')
        logger.info(f'Detect pre-trained model, remove [{matched_prefix}] prefix.')

    if config.MODEL.TYPE == 'swin':
        logger.info(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
        checkpoint_model = remap_pretrained_keys_swin(model, checkpoint_model, logger)
    elif config.MODEL.TYPE == 'vit':
        try:
            logger.info(f">>>>>>>>>> Remapping pre-trained keys for VIT ..........")
            checkpoint_model = remap_pretrained_keys_vit(model, checkpoint_model, logger)
        except Exception as exc:
            # Remapping stays best-effort, but the reason is now logged and
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            logger.info(f">>>>>>>>>> Remapping pre-trained keys for VIT [is Cancelled] ..........")
            logger.info(f"VIT key remapping failed with: {exc!r}")
    # Any other model type is loaded without remapping.

    msg = model.load_state_dict(checkpoint_model, strict=False)
    logger.info(msg)

    del checkpoint
    torch.cuda.empty_cache()
    logger.info(f">>>>>>>>>> loaded successfully '{config.PRETRAINED}'")


def _interpolate_swin_bias_table(table, dst_len, logger):
    """Resize a ``(L1, nH)`` relative-position-bias table to ``(dst_len, nH)``.

    Each column is viewed as a ``sqrt(L1) x sqrt(L1)`` grid whose source
    coordinates are spaced geometrically around 0, and is resampled with a
    bicubic spline at unit-spaced target coordinates.
    """
    src_len, num_heads = table.size()
    src_size = int(src_len ** 0.5)
    dst_size = int(dst_len ** 0.5)

    # Bisect for the ratio q whose geometric series of src_size // 2 terms
    # (first term 1) reaches dst_size // 2.
    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        gp = (1.0 - q ** (src_size // 2)) / (1.0 - q)
        if gp > dst_size // 2:
            right = q
        else:
            left = q

    # Cumulative distances from the centre: 1, 1 + q, 1 + q + q**2, ...
    dis = []
    cur = 1
    for i in range(src_size // 2):
        dis.append(cur)
        cur += q ** (i + 1)
    r_ids = [-d for d in reversed(dis)]
    x = r_ids + [0] + dis

    t = dst_size // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)
    dy = np.arange(-t, t + 0.1, 1.0)

    logger.info("Original positions = %s" % str(x))
    logger.info("Target positions = %s" % str(dx))

    src_coords = np.asarray(x, dtype=np.float64)
    columns = []
    for head in range(num_heads):
        z = table[:, head].view(src_size, src_size).float().numpy()
        # RectBivariateSpline on z.T with a transposed result is the
        # documented replacement for interp2d(x, y, z, kind='cubic'), which
        # no longer works in SciPy >= 1.14.
        spline = interpolate.RectBivariateSpline(
            src_coords, src_coords, z.T.astype(np.float64), kx=3, ky=3)
        resized = np.ascontiguousarray(spline(dx, dy).T)
        columns.append(torch.Tensor(resized).contiguous().view(-1, 1).to(table.device))

    return torch.cat(columns, dim=-1)


def remap_pretrained_keys_swin(model, checkpoint_model, logger):
    """Adapt a Swin checkpoint state dict to ``model`` in place.

    Every ``relative_position_bias_table`` entry whose row count differs from
    the model's is resized by geometric interpolation; entries with a
    different head count, or whose resize fails, are left untouched and the
    reason is logged. All ``relative_position_index``,
    ``relative_coords_table`` and ``attn_mask`` entries are removed.

    Args:
        model: Target module; only ``state_dict()`` is used.
        checkpoint_model: Mutable mapping of parameter name to tensor.
        logger: Logger used for progress messages.

    Returns:
        The same ``checkpoint_model`` object, modified in place.
    """
    state_dict = model.state_dict()

    for key in list(checkpoint_model.keys()):
        if "relative_position_bias_table" not in key:
            continue
        try:
            pretrained_table = checkpoint_model[key]
            current_table = state_dict[key]
            L1, nH1 = pretrained_table.size()
            L2, nH2 = current_table.size()
            if nH1 != nH2:
                logger.info(f"Error in loading {key}, passing......")
            elif L1 != L2:
                logger.info(f"{key}: Interpolate relative_position_bias_table using geo.")
                checkpoint_model[key] = _interpolate_swin_bias_table(pretrained_table, L2, logger)
        except Exception as exc:
            # Still best-effort per key, but no longer silent.
            logger.info(f"Skip remapping {key}: {exc!r}")
            continue

    # These entries are deleted from the checkpoint rather than loaded.
    for pattern in ("relative_position_index", "relative_coords_table", "attn_mask"):
        for stale_key in [k for k in checkpoint_model if pattern in k]:
            del checkpoint_model[stale_key]

    return checkpoint_model


def _resize_vit_rel_pos_bias(rel_pos_bias, src_size, dst_size, logger):
    """Resize a ``(src_size**2, heads)`` bias grid to ``(dst_size**2, heads)``.

    Source coordinates are spaced geometrically around 0 and each head's grid
    is resampled with a bicubic spline at unit-spaced target coordinates.
    """
    num_attn_heads = rel_pos_bias.size(1)

    # Bisect for the ratio q whose geometric series of src_size // 2 terms
    # (first term 1) reaches dst_size // 2.
    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        gp = (1.0 - q ** (src_size // 2)) / (1.0 - q)
        if gp > dst_size // 2:
            right = q
        else:
            left = q

    # Cumulative distances from the centre: 1, 1 + q, 1 + q + q**2, ...
    dis = []
    cur = 1
    for i in range(src_size // 2):
        dis.append(cur)
        cur += q ** (i + 1)
    r_ids = [-d for d in reversed(dis)]
    x = r_ids + [0] + dis

    t = dst_size // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)
    dy = np.arange(-t, t + 0.1, 1.0)

    logger.info("Original positions = %s" % str(x))
    logger.info("Target positions = %s" % str(dx))

    src_coords = np.asarray(x, dtype=np.float64)
    columns = []
    for head in range(num_attn_heads):
        z = rel_pos_bias[:, head].view(src_size, src_size).float().numpy()
        # RectBivariateSpline on z.T with a transposed result is the
        # documented replacement for interp2d(x, y, z, kind='cubic'), which
        # no longer works in SciPy >= 1.14.
        spline = interpolate.RectBivariateSpline(
            src_coords, src_coords, z.T.astype(np.float64), kx=3, ky=3)
        resized = np.ascontiguousarray(spline(dx, dy).T)
        columns.append(
            torch.Tensor(resized).contiguous().view(-1, 1).to(rel_pos_bias.device))

    return torch.cat(columns, dim=-1)


def remap_pretrained_keys_vit(model, checkpoint_model, logger):
    """Adapt a ViT checkpoint state dict to ``model`` in place.

    When the model uses relative position bias and the checkpoint has a shared
    ``rel_pos_bias.relative_position_bias_table``, that table is copied to
    every ``blocks.{i}.attn.relative_position_bias_table`` and then removed.
    All ``relative_position_index`` entries are dropped. Every remaining
    ``relative_position_bias_table`` that the model also has is resized by
    geometric interpolation when its grid size differs; the trailing extra
    (non-grid) rows are carried over unchanged.

    Args:
        model: Target module; uses ``state_dict()``,
            ``patch_embed.patch_shape`` and, for the shared-bias case,
            ``use_rel_pos_bias`` and ``get_num_layers()``.
        checkpoint_model: Mutable mapping of parameter name to tensor.
        logger: Logger used for progress messages.

    Returns:
        The same ``checkpoint_model`` object, modified in place.

    Raises:
        NotImplementedError: If the model's patch grid is not square.
    """
    # Duplicate the shared rel_pos_bias to each layer. Everything below must
    # stay inside this condition: checkpoints without a shared table would
    # otherwise fail on the lookup.
    shared_key = "rel_pos_bias.relative_position_bias_table"
    if getattr(model, 'use_rel_pos_bias', False) and shared_key in checkpoint_model:
        logger.info("Expand the shared relative position embedding to each transformer block.")
        num_layers = model.get_num_layers()
        shared_bias = checkpoint_model.pop(shared_key)
        for i in range(num_layers):
            checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = shared_bias.clone()

    # Geometric interpolation when the pre-trained grid size differs from the
    # fine-tuned one.
    model_state = model.state_dict()
    for key in list(checkpoint_model.keys()):
        if "relative_position_index" in key:
            checkpoint_model.pop(key)
            continue
        if "relative_position_bias_table" not in key:
            continue
        if key not in model_state:
            # Nothing to compare against; leave it for the non-strict load.
            logger.info(f"Skip {key}: not present in the target model.")
            continue

        rel_pos_bias = checkpoint_model[key]
        src_num_pos, _ = rel_pos_bias.size()
        dst_num_pos, _ = model_state[key].size()
        dst_patch_shape = model.patch_embed.patch_shape
        if dst_patch_shape[0] != dst_patch_shape[1]:
            raise NotImplementedError()
        num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
        src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
        dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
        if src_size == dst_size:
            continue

        logger.info("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size))
        # Split by a positive index: slicing with -num_extra_tokens yields an
        # empty grid when num_extra_tokens is 0.
        split = src_num_pos - num_extra_tokens
        extra_tokens = rel_pos_bias[split:, :]
        grid_bias = rel_pos_bias[:split, :]

        resized = _resize_vit_rel_pos_bias(grid_bias, src_size, dst_size, logger)
        checkpoint_model[key] = torch.cat((resized, extra_tokens), dim=0)

    return checkpoint_model