import io
import os
import time
from collections import defaultdict, deque
import datetime

import torch
import torch.distributed as dist

import sys
from torch import optim as optim
from typing import List, Union
import json


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    The most recent ``window_size`` values are kept in a bounded deque and feed
    ``median`` / ``avg`` / ``max`` / ``value``; ``total`` and ``count``
    accumulate over every update and feed ``global_avg``.  All windowed
    statistics return 0 while no value has been recorded yet, so the meter can
    be formatted before its first update.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value`` as the mean of ``n`` samples.

        The window stores ``value`` once; the global totals weight it by ``n``.
        """
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        All-reduce ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        # Use CUDA when present (required by NCCL); fall back to CPU so that a
        # CPU-only (e.g. gloo) run does not crash on a hard-coded 'cuda' device.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the window (lower middle value for even sizes); 0 if empty."""
        if not self.deque:
            return 0
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the window computed in float32; 0 if empty."""
        if not self.deque:
            return 0
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update (0.0 before the first update)."""
        return self.total / max(1, self.count)

    @property
    def max(self):
        """Largest value in the window; 0 if empty."""
        return 0 if len(self.deque) == 0 else max(self.deque)

    @property
    def value(self):
        """Most recently recorded value; 0 if empty."""
        return 0 if len(self.deque) == 0 else self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    """Named collection of meters plus an iterator wrapper that prints progress.

    Meters are created on first use (``defaultdict(SmoothedValue)``) and can
    be read as attributes, e.g. ``logger.loss``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Add one observation per keyword to the meter of the same name.

        Tensor values are converted with ``.item()``; every value must end up
        as a Python int or float.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Python calls this only after normal lookup (instance __dict__ and
        # class attributes) has failed.  ``meters`` is read straight from
        # __dict__.  On an instance built without __init__ (copy.copy,
        # unpickling), ``self.meters`` would re-enter __getattr__ and recurse
        # until RecursionError instead of raising AttributeError.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Call ``synchronize_between_processes`` on every meter."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing meter."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable``, printing progress as they are consumed.

        ``iterable`` must support ``len()``.  A line is printed at step 0, at
        every ``print_freq``-th step and at the last step.  It shows the step,
        the ETA, all meters, the per-iteration and data-wait times, and (when
        CUDA is available) the peak allocated GPU memory in MB.  A summary line
        with the total and per-iteration time follows the loop.
        """
        i = 0
        if not header:
            header = ''
        total = len(iterable)
        use_cuda = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(total))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if use_cuda:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the iterable to produce the next item.
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time, including the caller's work on ``obj``.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                fields = dict(
                    eta=str(datetime.timedelta(seconds=int(eta_seconds))),
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if use_cuda:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, total, **fields))
                sys.stdout.flush()
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(1, total)))
        sys.stdout.flush()


def _load_checkpoint_for_ema(model_ema, checkpoint):
    """
    Workaround for ModelEma._load_checkpoint to accept an already-loaded object
    """
    mem_file = io.BytesIO()
    torch.save(checkpoint, mem_file)
    mem_file.seek(0)
    model_ema._load_checkpoint(mem_file)


def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` so that only the master process produces output.

    Non-master processes can still print by passing ``force=True``; that
    keyword is always stripped before the call reaches the original ``print``.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print


def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is built in and a process group is initialized."""
    # Short-circuit: is_initialized() is only consulted when the backend exists.
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Number of processes in the default group, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Rank of this process in the default group, or 0 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0


def is_main_process():
    """Return True on rank 0, which ``get_rank`` also reports when not running distributed."""
    rank = get_rank()
    return rank == 0


def save_on_master(*args, **kwargs):
    """Forward all arguments to ``torch.save`` on the main process; do nothing elsewhere."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """
    Detect the launcher from environment variables and initialize torch.distributed.

    Launchers are checked in this order:
      * Open MPI: ``OMPI_COMM_WORLD_RANK`` / ``OMPI_COMM_WORLD_SIZE``
      * torchrun-style: ``RANK`` / ``WORLD_SIZE`` (+ ``LOCAL_RANK`` if set)
      * SLURM: ``SLURM_PROCID`` (+ ``SLURM_NTASKS`` if set)
    If none is found, sets ``args.distributed = False`` and returns.

    Otherwise sets ``args.rank``, ``args.world_size``, ``args.gpu``,
    ``args.distributed`` and ``args.dist_backend`` ('nccl'), binds this
    process to ``args.gpu``, and initializes the default process group using
    the caller-supplied ``args.dist_url``.  It then calls
    ``setup_for_distributed(args.rank == 0)`` so only rank 0 prints.
    """
    if 'OMPI_COMM_WORLD_RANK' in os.environ:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = args.rank % torch.cuda.device_count()
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        # torchrun exports LOCAL_RANK; other launchers that only set
        # RANK/WORLD_SIZE fall back to round-robin GPU assignment instead of
        # failing with KeyError.
        if 'LOCAL_RANK' in os.environ:
            args.gpu = int(os.environ['LOCAL_RANK'])
        else:
            args.gpu = args.rank % torch.cuda.device_count()
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        # SLURM ranks run from 0 to SLURM_NTASKS - 1, so that is the world size.
        # Without it we keep whatever args.world_size the caller already set.
        if 'SLURM_NTASKS' in os.environ:
            args.world_size = int(os.environ['SLURM_NTASKS'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def build_optimizer(config, model, momentum=0.9):
    """
    Build optimizer, set weight decay of normalization to 0 by default.

    Parameters are grouped by ``set_weight_decay``.  The model may optionally
    expose these hooks:
      * ``no_weight_decay()``: exact parameter names that get zero decay.
      * ``no_weight_decay_keywords()``: name substrings that get zero decay.
      * ``low_weight_decay_keywords()``: name substrings for the reduced-decay group.

    Args:
        config: object with ``opt`` ('sgd' or 'adamw', case-insensitive),
            ``lr``, ``weight_decay`` and, for AdamW, ``opt_eps``.
        model: module whose trainable parameters are optimised.
        momentum: SGD momentum (Nesterov is always enabled for SGD).

    Returns:
        An ``optim.SGD`` or ``AdamWCustomized`` instance.

    Raises:
        NotImplementedError: ``config.opt`` is neither 'sgd' nor 'adamw'.
    """
    # Defaults for models without the optional hooks.  ``low_keywords`` had no
    # default before, so a model lacking ``low_weight_decay_keywords`` raised
    # UnboundLocalError on the set_weight_decay call.
    skip = {}
    skip_keywords = {}
    low_keywords = ()
    if hasattr(model, 'low_weight_decay_keywords'):
        low_keywords = model.low_weight_decay_keywords()
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()
    parameters = set_weight_decay(model, skip_list=skip, skip_keywords=skip_keywords, low_keywords=low_keywords)

    opt_lower = config.opt.lower()
    if opt_lower == 'sgd':
        return optim.SGD(parameters, momentum=float(momentum), nesterov=True,
                         lr=config.lr, weight_decay=config.weight_decay)
    if opt_lower == 'adamw':
        return AdamWCustomized(parameters, eps=config.opt_eps,
                               lr=config.lr, weight_decay=config.weight_decay)
    raise NotImplementedError("Unsupported optimizer: {}".format(config.opt))


def set_weight_decay(model, skip_list=(), skip_keywords=(), low_keywords=(),
                     low_weight_decay=0.0001):
    """
    Split the model's trainable parameters into three optimizer param groups.

    Groups, in the returned order:
      1. Regular decay: no ``weight_decay`` key, so the optimizer's default applies.
      2. Low decay (``low_weight_decay``, default 1e-4): names containing any
         of ``low_keywords``.
      3. No decay (0.): 1-D parameters, names ending in ``.bias``, names in
         ``skip_list``, or names containing any of ``skip_keywords``.
         The no-decay test wins over the low-decay test.

    Frozen parameters (``requires_grad=False``) are left out.  All three
    groups are always returned, even if some are empty.
    """
    has_decay = []
    low_decay = []
    no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if (len(param.shape) == 1 or name.endswith(".bias") or name in skip_list
                or check_keywords_in_name(name, skip_keywords)):
            no_decay.append(param)
        elif check_keywords_in_name(name, low_keywords):
            low_decay.append(param)
        else:
            has_decay.append(param)

    return [{'params': has_decay},
            {'params': low_decay, 'weight_decay': low_weight_decay},
            {'params': no_decay, 'weight_decay': 0.}]


def check_keywords_in_name(name, keywords=()):
    """Return True if any keyword occurs as a substring of ``name``.

    ``keywords`` may be any iterable of strings (tuple, list, set, dict keys).
    ``None`` or an empty iterable yields False.  The scan stops at the first
    match.
    """
    if not keywords:
        return False
    return any(keyword in name for keyword in keywords)


def create_2optimizers(args, model, filter_bias_and_bn=True):
    """
    Build two optimizers over disjoint parameter sets (see ``add_weight_decay_2ops``).

    Args:
        args: namespace with ``opt1`` / ``opt2`` ('adamw' or 'sgd'), ``lr1`` /
            ``lr2``, ``weight_decay1`` / ``weight_decay2``, and optionally
            ``opt_eps`` / ``opt_betas``, which are passed to AdamW only.
        model: module exposing ``lr2_params()`` and optionally ``no_weight_decay()``.
        filter_bias_and_bn: only True is supported.

    Returns:
        ``(optimizer1, optimizer2)``.

    Raises:
        NotImplementedError: ``filter_bias_and_bn`` is False, or an optimizer
            name is not 'adamw' / 'sgd'.
    """
    if not filter_bias_and_bn:
        raise NotImplementedError
    skip = {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    parameters1, parameters2 = add_weight_decay_2ops(model, args, skip)

    # One kwargs dict shared by both optimizers.  The old code built an
    # identical ``opt_args2`` but then passed ``opt_args1`` to optimizer2 anyway.
    adamw_kwargs = {}
    if getattr(args, 'opt_eps', None) is not None:
        adamw_kwargs['eps'] = args.opt_eps
    if getattr(args, 'opt_betas', None) is not None:
        adamw_kwargs['betas'] = args.opt_betas

    optimizer1 = _build_single_optimizer(args.opt1, parameters1, args.lr1, adamw_kwargs)
    optimizer2 = _build_single_optimizer(args.opt2, parameters2, args.lr2, adamw_kwargs)
    return optimizer1, optimizer2


def _build_single_optimizer(name, parameters, lr, adamw_kwargs):
    """Construct one AdamW (with ``adamw_kwargs``) or Nesterov-SGD optimizer."""
    if name == 'adamw':
        return optim.AdamW(parameters, lr=lr, **adamw_kwargs)
    if name == 'sgd':
        return optim.SGD(parameters, lr=lr, momentum=0.9, nesterov=True)
    raise NotImplementedError


def add_weight_decay_2ops(model, args, skip_list=()):
    """
    Partition trainable parameters into param groups for two optimizers.

    Returns ``(groups1, groups2)``:
      * ``groups1`` = [no-decay group (``weight_decay`` 0.),
                       decay group (``args.weight_decay1``)]
      * ``groups2`` = [every parameter named in ``model.lr2_params()``
                       (``args.weight_decay2``)]

    A parameter listed by ``lr2_params()`` always goes to ``groups2``.  Any
    other parameter that is 1-D, ends in ``.bias`` or is in ``skip_list``
    gets no decay.  Frozen parameters are left out.
    """
    # Query once and use a set for O(1) membership; this used to be
    # re-evaluated for every parameter inside the loop.
    lr2_names = set(model.lr2_params())

    lr1_decay = []
    lr1_no_decay = []
    lr2 = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if name in lr2_names:
            lr2.append(param)
        elif len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
            lr1_no_decay.append(param)
        else:
            lr1_decay.append(param)

    groups1 = [
        {'params': lr1_no_decay, 'weight_decay': 0.},
        {'params': lr1_decay, 'weight_decay': args.weight_decay1},
    ]
    groups2 = [{'params': lr2, 'weight_decay': args.weight_decay2}]
    return groups1, groups2


def weight_list2_flags(weight_list, num_blocks=12):
    """
    Convert dotted parameter names into a per-block table of six 0/1 flags.

    Only names with exactly five dot-separated parts are considered; others
    are ignored.  In a five-part name, part 2 is parsed as the block (row)
    index and part 4 selects the column:
    q=0, k=1, v=2, proj=3, fc1=4, fc2=5.

    Args:
        weight_list: iterable of dotted parameter names.
        num_blocks: number of rows in the table (default 12, the previously
            hard-coded size).

    Returns:
        A list of ``num_blocks`` independent lists ``[q, k, v, proj, fc1, fc2]``.

    Raises:
        NotImplementedError: a five-part name has an unrecognised fourth part.
        IndexError: a block index is ``>= num_blocks``.
    """
    column_of = {'q': 0, 'k': 1, 'v': 2, 'proj': 3, 'fc1': 4, 'fc2': 5}
    flags = [[0] * len(column_of) for _ in range(num_blocks)]

    for name in weight_list:
        parts = name.split('.')
        if len(parts) != 5:
            continue
        layer = parts[3]
        if layer not in column_of:
            raise NotImplementedError
        flags[int(parts[1])][column_of[layer]] = 1

    return flags


def flags2weight_list(flags):
    """
    Expand a per-block flag table into IECG parameter names.

    For row ``i`` and each truthy flag, taken in column order
    (q, k, v, proj, fc1, fc2), append the pair
    ``blocks.{i}.attn.IECG_<layer>_a.weight`` and
    ``blocks.{i}.attn.IECG_<layer>_b.weight``.  Columns past the sixth are
    ignored.
    """
    layers = ('q', 'k', 'v', 'proj', 'fc1', 'fc2')
    names = []
    for block_idx, row in enumerate(flags):
        for layer, enabled in zip(layers, row):
            if not enabled:
                continue
            names.extend(
                f'blocks.{block_idx}.attn.IECG_{layer}_{half}.weight'
                for half in ('a', 'b')
            )
    return names


class AdamWCustomized(optim.AdamW):
    """``optim.AdamW`` with an extra ``fake_step`` that only advances step counters."""

    def fake_step(self, closure=None):
        """Increment ``state['step']`` for every parameter that has a gradient.

        Parameter values and moment estimates are not changed.  Per-parameter
        state is created lazily (zero ``exp_avg`` / ``exp_avg_sq``, plus
        ``max_exp_avg_sq`` when the group uses ``amsgrad``) if it does not
        exist yet.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it runs with gradients enabled.

        Returns:
            The closure's loss, or ``None`` without a closure.  This matches
            the ``Optimizer.step`` contract; the value used to be discarded.

        Raises:
            RuntimeError: a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            amsgrad = group['amsgrad']
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # NOTE(review): 'step' starts as a Python int here; recent
                    # torch releases keep it as a tensor and may reject an int
                    # if the real step() later runs on this state.  Confirm
                    # against the installed torch version.
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                # Works for both int and (in-place) tensor step counters.
                state['step'] += 1

        return loss


def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """
    Sum of ``kernel_num`` Gaussian kernel matrices over the stacked samples.

    The rows of ``source`` and ``target`` (2-D, same number of columns) are
    stacked, source first, into N rows.  The result is an (N, N) tensor whose
    entry (i, j) is ``sum_k exp(-||x_i - x_j||^2 / bw_k)``, where
    ``bw_k = base / kernel_mul ** (kernel_num // 2) * kernel_mul ** k``.

    ``base`` is ``fix_sigma`` when it is truthy.  Otherwise it is the mean
    off-diagonal squared distance, computed without tracking gradients.
    """
    stacked = torch.cat([source, target], dim=0)
    rows, feats = int(stacked.size(0)), int(stacked.size(1))

    # Pairwise squared Euclidean distances, shape (rows, rows).
    pairwise_sq = (
        (stacked.unsqueeze(0).expand(rows, rows, feats)
         - stacked.unsqueeze(1).expand(rows, rows, feats)) ** 2
    ).sum(2)

    if fix_sigma:
        base = fix_sigma
    else:
        # Average over the rows**2 - rows off-diagonal pairs (diagonal is 0).
        base = torch.sum(pairwise_sq.detach()) / (rows ** 2 - rows)
    # Centre the geometric ladder of bandwidths on ``base``.
    base /= kernel_mul ** (kernel_num // 2)

    kernel_sum = 0
    for k in range(kernel_num):
        kernel_sum = kernel_sum + torch.exp(-pairwise_sq / (base * (kernel_mul ** k)))
    return kernel_sum


def mmd(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """
    Biased squared Maximum Mean Discrepancy under a multi-bandwidth Gaussian kernel.

    Mathematically ``mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts)``,
    taken over the blocks of the kernel matrix from ``guassian_kernel``.
    ``source`` and ``target`` may have different numbers of rows.
    """
    n_src = int(source.size()[0])
    n_tgt = int(target.size()[0])

    kernels = guassian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)

    def scaled_row_sums(block, denom):
        # Divide every entry by ``denom``, then sum each row -> shape (1, rows).
        return torch.div(block, denom).sum(dim=1).view(1, -1)

    k_ss = scaled_row_sums(kernels[:n_src, :n_src], n_src * n_src)   # source <-> source
    k_st = scaled_row_sums(kernels[:n_src, n_src:], -n_src * n_tgt)  # source <-> target
    k_ts = scaled_row_sums(kernels[n_src:, :n_src], -n_tgt * n_src)  # target <-> source
    k_tt = scaled_row_sums(kernels[n_src:, n_src:], n_tgt * n_tgt)   # target <-> target

    return (k_ss + k_st).sum() + (k_ts + k_tt).sum()


def mmd_linear(f_of_X, f_of_Y, margin=0):
    """
    Linear-kernel discrepancy between two feature batches.

    ``delta = |f_of_X - f_of_Y - margin|`` (elementwise, must be 2-D after
    broadcasting); the loss is ``mean(delta @ delta.T)``.

    The original passed the tensor to Python's builtin ``max``.  That raises
    "Boolean value of Tensor with more than one value is ambiguous" for any
    multi-element input, and for a single element it could return the int 0,
    which then broke ``torch.mm``.
    """
    # Elementwise version of the intended ``max(..., 0)``; it is a no-op after
    # ``abs()`` but keeps the written expression.
    # NOTE(review): if a hinge on |X - Y| was intended
    # (max(|X - Y| - margin, 0)), this expression must change.
    delta = torch.clamp((f_of_X - f_of_Y - margin).abs(), min=0)
    loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
    return loss


def read_json(filename: str) -> Union[list, dict]:
    """Read and parse a JSON file.

    The file is opened in binary mode and ``json.load`` detects the encoding
    itself (UTF-8, UTF-8 with BOM, UTF-16 or UTF-32).  The former
    ``encoding="utf-8"`` keyword was removed from ``json.load`` in
    Python 3.9 and raised TypeError there.
    """
    with open(filename, "rb") as fin:
        data = json.load(fin)
    return data


def CORAL(source, target, device="cuda"):
    """
    CORAL covariance-alignment loss between two 2-D feature batches.

    Computes the unbiased covariance matrices of ``source`` (ns, d) and
    ``target`` (nt, d) and returns ``||C_s - C_t||_F / (4 * d * d)``.

    NOTE(review): Deep CORAL (Sun & Saenko, 2016) uses the *squared*
    Frobenius norm.  This returns the norm itself (``.sqrt()``); it is left
    unchanged to preserve existing training behaviour.

    Args:
        source: source features, one sample per row.
        target: target features with the same number of columns.
        device: unused; kept for backward compatibility.  Intermediates are
            built from the inputs themselves, so they share the inputs'
            device and dtype.  The old ``torch.ones(...).to("cuda")`` failed
            for CPU tensors and for non-float32 dtypes.

    With a single row in either input, the division by ``n - 1`` yields
    inf/nan.
    """
    d = source.size(1)
    ns, nt = source.size(0), target.size(0)

    # source covariance; column sums (== ones(1, ns) @ source) have shape (1, d)
    tmp_s = source.sum(dim=0, keepdim=True)
    cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)

    # target covariance
    tmp_t = target.sum(dim=0, keepdim=True)
    ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)

    # frobenius norm
    loss = (cs - ct).pow(2).sum().sqrt()
    loss = loss / (4 * d * d)

    return loss