# Copyright (c) Winci.
# Licensed under the Apache License, Version 2.0 (the "License");

import argparse
import logging
import os
import numpy as np
import torch
import torch.distributed as dist
import math

def setup_logging(log_file, level, include_host=False):
    """Configure the root logger with a console handler and an optional file handler.

    The root logger and every logger currently registered with the logging
    manager are set to ``level``. Each call attaches new handlers to the root
    logger; handlers added by earlier calls are not removed.

    Args:
        log_file: Path of a log file to write to; a falsy value disables
            file logging.
        level: Logging level applied to the root logger and all known loggers.
        include_host: When true, the local hostname is embedded in every
            log line.
    """
    date_format = '%Y-%m-%d,%H:%M:%S'
    if include_host:
        import socket
        host = socket.gethostname()
        line_format = f'%(asctime)s |  {host} | %(levelname)s | %(message)s'
    else:
        line_format = '%(asctime)s | %(levelname)s | %(message)s'
    formatter = logging.Formatter(line_format, datefmt=date_format)

    # Apply the level to the root logger and to every logger known so far.
    logging.root.setLevel(level)
    for logger_name in list(logging.root.manager.loggerDict):
        logging.getLogger(logger_name).setLevel(level)

    def _attach(handler):
        # Give the handler the shared format and register it on the root logger.
        handler.setFormatter(formatter)
        logging.root.addHandler(handler)

    _attach(logging.StreamHandler())
    if log_file:
        _attach(logging.FileHandler(filename=log_file))

# Load both the encoder and the momentum encoder from an ImageNet-pretrained checkpoint
def load_pretrained_clue(encoder, pretrained_path):
    """Load pre-trained weights into ``encoder`` after stripping wrapper prefixes.

    Keys that start with ``module.encoder.``, ``module.momentum_encoder.`` or
    ``module.net_vlad.`` are renamed by dropping that prefix; all other keys
    are left untouched. The result is loaded with ``strict=False`` and the
    load report is logged.

    When the checkpoint holds both ``module.encoder.*`` and
    ``module.momentum_encoder.*`` entries for the same parameter, both map to
    the same stripped name, so the entry that comes later in the checkpoint's
    key order is the one that ends up being loaded.

    Args:
        encoder: Object exposing ``load_state_dict(state_dict, strict=False)``.
        pretrained_path: Checkpoint path. A falsy value makes this a no-op; a
            path that is not a file is only logged.
    """
    if not pretrained_path:
        return
    if not os.path.isfile(pretrained_path):
        logging.info(f"no checkpoint found at '{pretrained_path}'")
        return

    logging.info(f"loading checkpoint '{pretrained_path}'")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(pretrained_path, map_location="cpu")

    # The weights are either nested under 'state_dict' or are the checkpoint itself.
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    # History: the original method used the momentum encoder for downstream
    # tasks, while MoCo itself uses the encoder, so the encoder was used here.
    # Update 2025-09-22: use momentum_encoder for downstream tasks.
    # Prefixes include the trailing dot so that the match and the slice length
    # agree; a key such as 'module.encoder_x' is no longer mangled.
    prefixes = ("module.encoder.", "module.momentum_encoder.", "module.net_vlad.")
    for key in list(state_dict.keys()):
        for prefix in prefixes:
            if key.startswith(prefix):
                # Re-insert under the stripped name and drop the prefixed key.
                state_dict[key[len(prefix):]] = state_dict.pop(key)
                break

    msg = encoder.load_state_dict(state_dict, strict=False)
    logging.info(msg)

    if 'epoch' in checkpoint:
        logging.info(f"loaded pre-trained encoder (epoch {checkpoint['epoch']})")


# Load the encoder from an ImageNet-pretrained checkpoint
def load_pretrained_im(clue, pretrained_path):
    """Load a plain backbone checkpoint into both branches of ``clue``.

    Every key of the checkpoint's state dict is duplicated under the
    ``encoder.`` and ``momentum_encoder.`` prefixes and the un-prefixed key is
    removed. The result is loaded with ``strict=False`` and the load report
    is logged.

    Args:
        clue: Object exposing ``load_state_dict(state_dict, strict=False)``.
        pretrained_path: Checkpoint path. A falsy value makes this a no-op; a
            path that is not a file is only logged.
    """
    if not pretrained_path:
        return
    if not os.path.isfile(pretrained_path):
        logging.info(f"no checkpoint found at '{pretrained_path}'")
        return

    logging.info(f"loading checkpoint '{pretrained_path}'")
    checkpoint = torch.load(pretrained_path, map_location="cpu")

    # The weights are either nested under 'state_dict' or are the checkpoint itself.
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    # Rename in place, iterating over a snapshot of the original key names.
    for name in list(state_dict.keys()):
        weights = state_dict[name]
        # Add the "encoder." prefix.
        state_dict["encoder." + name] = weights
        # Add the "momentum_encoder." prefix; in principle this copy would not
        # need to be loaded.
        state_dict["momentum_encoder." + name] = weights
        del state_dict[name]

    load_report = clue.load_state_dict(state_dict, strict=False)
    logging.info(load_report)

    if 'epoch' in checkpoint:
        logging.info(f"loaded pre-trained encoder (epoch {checkpoint['epoch']})")

def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Build a per-iteration schedule: linear warmup followed by cosine decay.

    Args:
        base_value: Value reached at the end of warmup and at the start of
            the cosine phase.
        final_value: Value the cosine phase decays towards.
        epochs: Total number of epochs covered by the schedule.
        niter_per_ep: Number of iterations in one epoch.
        warmup_epochs: Length of the linear warmup in epochs. May be
            fractional; it is rounded to a whole number of iterations.
        start_warmup_value: Value at the first warmup iteration.

    Returns:
        A 1-D numpy array with ``epochs * niter_per_ep`` entries, one per
        training iteration.
    """
    # np.linspace needs an integer sample count, so round here; this is a
    # no-op for integer warmup_epochs.
    warmup_iters = int(round(warmup_epochs * niter_per_ep))
    warmup_schedule = np.array([])
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    # Remaining iterations follow half a cosine period from base to final.
    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule

def restart_from_checkpoint(ckp_paths, run_variables=None, **kwargs):
    """
    Re-start from checkpoint.

    Picks the first existing file among ``ckp_paths``, loads it, and restores
    state into the given objects. Returns silently when no candidate exists.

    Args:
        ckp_paths: A single checkpoint path, or a list/tuple of candidate
            paths tried in order. Falsy entries are skipped.
        run_variables: Optional dict; every key that is also present in the
            checkpoint is overwritten in place with the checkpoint's value.
        **kwargs: Maps a checkpoint key to the object that should receive it
            via ``load_state_dict``, e.g. ``state_dict=model``. Entries whose
            value is None, or whose key is missing from the checkpoint, are
            reported with a warning.
    """
    # look for a checkpoint in exp repository
    candidates = ckp_paths if isinstance(ckp_paths, (list, tuple)) else [ckp_paths]
    ckp_path = next((path for path in candidates if path and os.path.isfile(path)), None)
    if ckp_path is None:
        return

    logging.info(f"Found checkpoint at {ckp_path}")

    # Map tensors onto this process's GPU when CUDA is present; otherwise stay
    # on CPU. Without an initialised process group the rank is taken as 0.
    if torch.cuda.is_available():
        group_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
        rank = torch.distributed.get_rank() if group_ready else 0
        map_location = "cuda:" + str(rank % torch.cuda.device_count())
    else:
        map_location = "cpu"

    # open checkpoint file
    checkpoint = torch.load(ckp_path, map_location=map_location)

    # key is what to look for in the checkpoint file
    # value is the object to load
    # example: {'state_dict': model}
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            try:
                msg = value.load_state_dict(checkpoint[key], strict=False)
                logging.info(msg)
            except TypeError:
                # Objects whose load_state_dict has no ``strict`` argument.
                value.load_state_dict(checkpoint[key])
            logging.info(f"=> loaded {key} from checkpoint '{ckp_path}'")
        else:
            logging.warning(f"=> failed to load {key} from checkpoint '{ckp_path}'")

    # re load variable important for the run
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]

def build_optimizer(parameters, args):
    """Create the optimizer selected by ``args.optimizer``.

    The optimizer is constructed with a learning rate of 0.

    Args:
        parameters: Parameters (or parameter groups) to optimize.
        args: Namespace providing ``optimizer`` (``'sgd'`` or ``'adamw'``)
            and ``wd``; for ``'sgd'`` it must also provide ``lr``.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``args.optimizer`` is not a supported name.

    Side effects:
        For ``'sgd'``, ``args.final_lr`` is set to ``0.1 * args.lr``.
    """
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(parameters, 0, momentum=0.9, weight_decay=args.wd)
        # We observe a slight performance drop in CLUE when the learning rate
        # decreases to a very small value during the later stages of training.
        # We hypothesize that an excessively small learning rate may amplify
        # the regularization effect of weight decay in SGD, causing the
        # weights to diverge from the optimal solution. To address this, we
        # set the minimum learning rate in the cosine decay schedule to
        # 0.1 * lr (rather than 0.0).
        args.final_lr = args.lr * 0.1
    elif args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(parameters, 0, weight_decay=args.wd)
    else:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError(
            f"unsupported optimizer '{args.optimizer}' (expected 'sgd' or 'adamw')")

    return optimizer

def world_info_from_env():
    """Read local rank, global rank and world size from the environment.

    For each quantity, a fixed list of environment variable names is checked
    in order and the first one that is set is parsed as an integer.

    Returns:
        A ``(local_rank, global_rank, world_size)`` tuple; quantities with no
        matching variable default to 0, 0 and 1 respectively.
    """
    def _first_int(candidates, fallback):
        # Return the first set variable among ``candidates`` as an int.
        for env_name in candidates:
            if env_name in os.environ:
                return int(os.environ[env_name])
        return fallback

    local_rank = _first_int(
        ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'), 0)
    global_rank = _first_int(
        ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'), 0)
    world_size = _first_int(
        ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'), 1)
    return local_rank, global_rank, world_size

def init_distributed_device(args):
    """Initialise the process group and pick the device for this process.

    Distributed training = training on more than one GPU. Works in both
    single and multi-node scenarios.

    Args:
        args: Namespace providing ``dist_backend``, ``dist_url`` and
            ``no_set_device_rank``. It is updated in place with
            ``local_rank``, ``rank``, ``world_size``, ``distributed`` (always
            True) and ``device`` (a device string).

    Returns:
        A ``(torch.device, args)`` tuple.
    """
    if 'SLURM_PROCID' not in os.environ:
        # DDP via torchrun, torch.distributed.launch
        args.local_rank = world_info_from_env()[0]
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
        )
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    else:
        # DDP via SLURM
        args.local_rank, args.rank, args.world_size = world_info_from_env()
        # SLURM var -> torch.distributed vars in case needed
        exported = (
            ('LOCAL_RANK', args.local_rank),
            ('RANK', args.rank),
            ('WORLD_SIZE', args.world_size),
        )
        for env_name, env_value in exported:
            os.environ[env_name] = str(env_value)
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )
    args.distributed = True

    if not torch.cuda.is_available():
        device = 'cpu'
    else:
        # Bind to the local-rank GPU unless the caller opted out of that.
        device = 'cuda:0' if args.no_set_device_rank else 'cuda:%d' % args.local_rank
        torch.cuda.set_device(device)
    args.device = device
    return torch.device(device), args

class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes are ``val`` (most recent value), ``sum`` (weighted sum of all
    values), ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all statistics to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The new value.
            n: Weight of the value, e.g. a batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch first) keeps the average
        # at 0 instead of raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: Score tensor; the top-k search runs along dimension 1.
        target: Tensor of ground-truth indices; its first dimension is used
            as the batch size.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the best predictions, transposed so that row i holds the
        # i-th ranked guess of every sample.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        percent = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent)
            for k in topk
        ]