# -*- coding: UTF-8 -*-
import torch
import numpy as np
from torch import nn
import torch
import torch.nn.functional as F
from typing import Union
import os.path as osp
import os
import time
from torchvision.utils import save_image
import torch.distributed as dist
import math
import inspect
from torch._six import string_classes
import collections.abc as container_abcs
import warnings
from utils.ops import load_network


def get_varname(var):
    """
    Look up the name ``var`` is bound to somewhere in the current call stack.

    Frames are visited from the outermost caller inwards; within a frame the
    first local whose value *is* ``var`` (identity, not equality) wins. The
    frame of this function is part of the search and is visited last.
    :param var: object whose variable name is wanted.
    :return: string
    """
    # Capture the stack once, before any further calls, so the set of frames
    # searched is exactly the one active when this function was entered.
    call_stack = inspect.stack()
    for frame_info in call_stack[::-1]:
        for local_name, local_value in frame_info.frame.f_locals.items():
            if local_value is var:
                return local_name


def reduce_tensor(rt):
    """
    Return a copy of ``rt`` averaged over all distributed processes.

    The input tensor is cloned first, so the caller's tensor is never
    modified. When a process group is initialised the clone is summed across
    processes with ``all_reduce`` and divided by the world size; otherwise the
    clone is divided by 1 and returned.
    :param rt: tensor to average.
    :return: new tensor holding the cross-process mean of ``rt``.
    """
    rt = rt.clone()
    # torch.distributed exposes no other APIs when it is not available, so
    # is_available() must be checked before is_initialized() can be called.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        world_size = dist.get_world_size()
    else:
        world_size = 1
    rt /= world_size
    return rt


class LoggerX(object):
    """
    Rank-aware logger and checkpoint helper for distributed training.

    Creates ``save_models`` and ``save_images`` directories under
    ``save_root``, keeps a registry of modules (objects exposing
    ``state_dict`` / ``load_state_dict``) together with a name for each, and
    prints statistics, optionally forwarding them to Weights & Biases.
    Printing, checkpoint writing and wandb calls happen on rank 0 only.
    """

    def __init__(self, save_root, enable_wandb=False, **kwargs):
        """
        :param save_root: directory under which ``save_models`` and
            ``save_images`` are created.
        :param enable_wandb: if True, rank 0 starts a wandb run in ``save_root``
            and :meth:`msg` logs to it.
        :param kwargs: forwarded to ``wandb.init``.
        """
        # Rank and world size are queried below, so a process group must exist.
        assert dist.is_initialized()
        self.models_save_dir = osp.join(save_root, 'save_models')
        self.images_save_dir = osp.join(save_root, 'save_images')
        os.makedirs(self.models_save_dir, exist_ok=True)
        os.makedirs(self.images_save_dir, exist_ok=True)
        self._modules = []
        self._module_names = []
        self.world_size = dist.get_world_size()
        # NOTE: despite the attribute name, this holds dist.get_rank().
        self.local_rank = dist.get_rank()
        self.enable_wandb = enable_wandb
        if enable_wandb and self.local_rank == 0:
            import wandb
            wandb.init(dir=save_root, settings=wandb.Settings(_disable_stats=True, _disable_meta=True), **kwargs)

    @property
    def modules(self):
        """Registered modules, in registration order."""
        return self._modules

    @modules.setter
    def modules(self, modules):
        """
        Register every element of ``modules``, naming each via ``get_varname``.

        Assignment *appends* to the registry; it does not replace it.
        """
        # Index-based on purpose: binding each module to a local in this frame
        # would make get_varname() report that local's name whenever no caller
        # frame holds the module in a local variable.
        for i in range(len(modules)):
            self._modules.append(modules[i])
            self._module_names.append(get_varname(modules[i]))

    @property
    def module_names(self):
        """Names of the registered modules, parallel to :attr:`modules`."""
        return self._module_names

    def append(self, module, name=None):
        """
        Register a single module.

        :param module: object to register.
        :param name: name used in checkpoint file names; looked up with
            ``get_varname`` when None.
        """
        self._modules.append(module)
        if name is None:
            name = get_varname(module)
        self._module_names.append(name)

    def _checkpoint_path(self, module_name, epoch):
        """Return the checkpoint file path ``<models_save_dir>/<module_name>-<epoch>``."""
        return osp.join(self.models_save_dir, '{}-{}'.format(module_name, epoch))

    def checkpoints(self, epoch):
        """
        Save the ``state_dict`` of every registered module (rank 0 only).

        :param epoch: tag appended to each checkpoint file name.
        """
        if self.local_rank != 0:
            return
        for module_name, module in zip(self.module_names, self.modules):
            torch.save(module.state_dict(), self._checkpoint_path(module_name, epoch))

    def load_checkpoints(self, epoch, byol=True):
        """
        Restore registered modules from the checkpoints tagged ``epoch``.

        :param epoch: tag of the checkpoint files to read.
        :param byol: when True (the default, and the only path that was
            reachable before this parameter existed) every registered module
            is restored from its own file through ``load_network``. When
            False, only the first registered module is restored: the file is
            read with ``torch.load``, its ``'model'`` entry is remapped by
            ``load_model(name='cdc')`` and both resulting dicts are loaded
            non-strictly.
        """
        if byol:
            for module_name, module in zip(self.module_names, self.modules):
                module.load_state_dict(load_network(self._checkpoint_path(module_name, epoch)))
            return

        module_name = self.module_names[0]
        module = self.modules[0]
        state_dict = torch.load(self._checkpoint_path(module_name, epoch), map_location='cpu')
        state_dict = state_dict['model']
        cpk_q, cpk_k = load_model(state_dict, name='cdc')
        # strict=False: each remapped dict only matches part of the module's keys.
        module.load_state_dict(cpk_q, strict=False)
        module.load_state_dict(cpk_k, strict=False)

    def msg(self, stats, step):
        """
        Print (and optionally send to wandb) one line of statistics.

        :param stats: either a dict mapping names to values, or a list/tuple
            of values whose names are looked up with ``get_varname``. Tensor
            values are detached, averaged, passed through ``reduce_tensor``
            and converted with ``.item()``; other values must be formattable
            with ``{:2.5f}``.
        :param step: integer step, printed zero-padded and used as wandb step.
        :raises NotImplementedError: for a non-empty ``stats`` of another type.

        ``reduce_tensor`` performs a collective operation when a process group
        is initialised, so presumably every process has to call this method
        with the same tensors, even though only rank 0 prints.
        """
        output_str = '[{}] {:05d}, '.format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), step)
        output_dict = {}

        # Materialise names/values once instead of rebuilding the item list
        # on every iteration.
        if isinstance(stats, dict):
            names = list(stats.keys())
            values = list(stats.values())
        elif isinstance(stats, (list, tuple)):
            names = None  # resolved per value through get_varname below
            values = stats
        else:
            # Unsupported containers are rejected only when non-empty; an
            # empty one just prints the header, as it always did.
            if len(stats) > 0:
                raise NotImplementedError
            names = None
            values = ()

        # The local must stay named ``var``: get_varname() inspects this frame
        # and falls back to this name when no caller holds the value in a local.
        for i, var in enumerate(values):
            var_name = get_varname(var) if names is None else names[i]
            if isinstance(var, torch.Tensor):
                var = var.detach().mean()
                var = reduce_tensor(var)
                var = var.item()
            output_str += '{} {:2.5f}, '.format(var_name, var)
            output_dict[var_name] = var

        if self.local_rank != 0:
            return
        if self.enable_wandb:
            import wandb
            wandb.log(output_dict, step)
        print(output_str)

    def msg_str(self, output_str):
        """Print ``str(output_str)`` on rank 0 only."""
        if self.local_rank == 0:
            print(str(output_str))

    def save_image(self, grid_img, n_iter, sample_type):
        """
        Write ``grid_img`` to ``<images_save_dir>/<n_iter>_<rank>_<sample_type>.jpg``.

        Runs on every rank; the rank is part of the file name.
        :param grid_img: image tensor(s) handed to ``torchvision.utils.save_image``.
        :param n_iter: iteration tag used in the file name.
        :param sample_type: free-form tag used in the file name.
        """
        file_name = '{}_{}_{}.jpg'.format(n_iter, self.local_rank, sample_type)
        save_image(grid_img, osp.join(self.images_save_dir, file_name), nrow=1)
def _checkpoint_key_rules(name, branch):
    """
    Build the ordered key substitutions for one checkpoint flavour and branch.

    :param name: checkpoint flavour: 'cdc', 'cc' or 'tac'.
    :param branch: 'q' or 'k'; selects the ``encoder_<branch>`` and
        ``projector_<branch>`` target prefixes.
    :return: list of ``(old, new)`` pairs to apply in order with
        ``str.replace``, or None when ``name`` is not a known flavour.
    """
    encoder = 'encoder_' + branch
    projector = 'projector_' + branch

    def backbone_rules(prefix, first_layer_index):
        # 'module.' is stripped first, then conv1/bn1 map to encoder slots 0/1
        # and layer1..layer4 map to four consecutive slots.
        rules = [
            ('module.', ''),
            (prefix + '.conv1', encoder + '.0'),
            (prefix + '.bn1', encoder + '.1'),
        ]
        for offset in range(4):
            rules.append((
                '{}.layer{}'.format(prefix, offset + 1),
                '{}.{}'.format(encoder, first_layer_index + offset),
            ))
        return rules

    if name == 'cdc':
        return backbone_rules('backbone', 3) + [
            ('cluster_head.0', projector),
        ]
    if name == 'cc':
        return backbone_rules('backbone', 3) + [
            ('cluster_projector.0', projector + '.0'),
            ('cluster_projector.2', projector + '.2'),
        ]
    if name == 'tac':
        rules = backbone_rules('resnet', 4)
        if branch == 'k':
            # Only the k branch rewrites 'cluster_head.0' for this flavour.
            rules.append(('cluster_head.0', projector))
        rules += [
            ('resnet.fc', projector + '.0'),
            # Order matters: the longer 'cluster_head_image' must be replaced
            # before its prefix 'cluster_head_i'.
            ('cluster_head_image', projector),
            ('cluster_head_i', projector),
        ]
        return rules
    return None


def _rename_key(key, rules):
    """Apply each ``(old, new)`` pair of ``rules`` to ``key`` in order via ``str.replace``."""
    for old, new in rules:
        key = key.replace(old, new)
    return key


def load_model(state_dict, name):
    """
    Remap the keys of a checkpoint ``state_dict`` to q- and k-branch names.

    Every key is rewritten twice through chained substring replacement: once
    towards ``encoder_q`` / ``projector_q`` and once towards ``encoder_k`` /
    ``projector_k``. Values are shared between both results, not copied. When
    two source keys collapse to the same new key, the later one wins.
    :param state_dict: mapping from parameter name to value.
    :param name: checkpoint flavour: 'cdc', 'cc' or 'tac'. For 'cc' each
        remapped q key is printed with ``value.shape``.
    :return: tuple ``(cpk_q, cpk_k)`` of remapped dicts. For an unknown
        ``name`` a warning is issued and both dicts are empty.
    """
    cpk_q = {}
    cpk_k = {}

    rules_q = _checkpoint_key_rules(name, 'q')
    rules_k = _checkpoint_key_rules(name, 'k')
    if rules_q is None or rules_k is None:
        # Returning empty dicts keeps the old contract, but loading them with
        # strict=False restores nothing, so make the situation visible.
        warnings.warn(
            "load_model: unknown checkpoint type {!r}; expected 'cdc', 'cc' or 'tac'. "
            "Returning empty mappings.".format(name),
            stacklevel=2,
        )
        return cpk_q, cpk_k

    for k, v in state_dict.items():
        k_q = _rename_key(k, rules_q)
        cpk_q[k_q] = v
        if name == 'cc':
            # Kept from the original implementation, which prints for 'cc' only.
            print(k_q, ':', v.shape)

    for k, v in state_dict.items():
        cpk_k[_rename_key(k, rules_k)] = v

    return cpk_q, cpk_k