# # Copyright (c) Meta Platforms, Inc. and affiliates.
# # All rights reserved.

# # This source code is licensed under the license found in the
# # LICENSE file in the root directory of this source tree.
# # --------------------------------------------------------
# # References:
# # DeiT: https://github.com/facebookresearch/deit
# # BEiT: https://github.com/microsoft/unilm/tree/master/beit
# # --------------------------------------------------------

# import builtins
# import datetime
# import os
# import sys
# import time
# from collections import defaultdict, deque
# from pathlib import Path

# import torch
# import torch.distributed as dist
# import wandb
# from torch._six import inf


# class SmoothedValue(object):
#     """Track a series of values and provide access to smoothed values over a
#     window or the global series average.
#     """

#     def __init__(self, window_size=20, fmt=None):
#         if fmt is None:
#             fmt = "{median:.4f} ({global_avg:.4f})"
#         self.deque = deque(maxlen=window_size)
#         self.total = 0.0
#         self.count = 0
#         self.fmt = fmt

#     def update(self, value, n=1):
#         self.deque.append(value)
#         self.count += n
#         self.total += value * n

#     def synchronize_between_processes(self):
#         """
#         Warning: does not synchronize the deque!
#         """
#         if not is_dist_avail_and_initialized():
#             return
#         t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
#         dist.barrier()
#         dist.all_reduce(t)
#         t = t.tolist()
#         self.count = int(t[0])
#         self.total = t[1]

#     @property
#     def median(self):
#         d = torch.tensor(list(self.deque))
#         return d.median().item()

#     @property
#     def avg(self):
#         d = torch.tensor(list(self.deque), dtype=torch.float32)
#         return d.mean().item()

#     @property
#     def global_avg(self):
#         if self.total == 0:
#             return 0
#         else:
#             return self.total / self.count

#     @property
#     def max(self):
#         return max(self.deque)

#     @property
#     def value(self):
#         return self.deque[-1]

#     def __str__(self):
#         return self.fmt.format(
#             median=self.median,
#             avg=self.avg,
#             global_avg=self.global_avg,
#             max=self.max,
#             value=self.value)

# class NativeScalerWithGradNormCount:
#     state_dict_key = "amp_scaler"

#     def __init__(self):
#         self._scaler = torch.cuda.amp.GradScaler()

#     def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
#         self._scaler.scale(loss).backward(create_graph=create_graph)
#         if update_grad:
#             if clip_grad is not None:
#                 assert parameters is not None
#                 self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
#                 norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
#             else:
#                 self._scaler.unscale_(optimizer)
#                 norm = get_grad_norm_(parameters)
#             self._scaler.step(optimizer)
#             self._scaler.update()
#         else:
#             norm = None
#         return norm

#     def state_dict(self):
#         return self._scaler.state_dict()

#     def load_state_dict(self, state_dict):
#         self._scaler.load_state_dict(state_dict)

# def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
#     if isinstance(parameters, torch.Tensor):
#         parameters = [parameters]
#     parameters = [p for p in parameters if p.grad is not None]
#     norm_type = float(norm_type)
#     if len(parameters) == 0:
#         return torch.tensor(0.)
#     device = parameters[0].grad.device
#     if norm_type == inf:
#         total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
#     else:
#         total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
#                                 norm_type)
#     return total_norm




# def create_output_dir(args):
#     # # pretrain dir
#     # pretrain_dir = args.output
#     # pretrain_dir += f'{args.setting}_{args.split}Split/'
#     # pretrain_dir += f'{args.label_ratio}Label/'
#     # pretrain_dir += f'{args.steps}Steps/'

#     # # record mechanism
#     # output_dir = pretrain_dir + f'BufferSize{args.size_replay_buffer}_LR{args.blr}'
#     # if args.sampling == 'batchmix':
#     #     output_dir += f'Sampling{args.sampling}{args.mem_sampling_rate}'
#     # if args.unsup_loss:
#     #     output_dir += f'_{args.batch_split}unsup_loss'

#     # if args.replay_first:
#     #     output_dir += "replay_first"
#     # if args.mask_cur_loss:
#     #     output_dir += "mask_cur_loss"
#     # if args.pretrained:
#     #     pretrain_dir += '_Pretrained/'

#     # # runs count
#     # if os.path.exists(output_dir):
#     #     dirs = [a for a in os.listdir(output_dir) if a[:3] == 'run']
#     #     if dirs != []:
#     #         lastrun = max([int(a[3:]) for a in dirs])
#     #         if args.resume and not args.resume_dir:
#     #             output_dir += f'/run{lastrun}'
#     #         elif args.resume and args.resume_dir:
#     #             output_dir = args.resume_dir
#     #         else:
#     #             output_dir += f'/run{lastrun + 1}'
#     # else:
#     #     output_dir += '/run0'



#     return output_dir


def fix_random_seed(seed: int):
    """Seed every random number generator this project relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator, and PyTorch's CUDA
    generators (current device, then all devices).

    Args:
        seed: Integer seed applied to each generator.
    """
    import random as py_random
    import torch
    import numpy as np

    # Order mirrors the sequence of seeding calls: python -> numpy -> torch.
    seeders = (
        py_random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# def init_wandb_writer(args):
#     # if args.run_name == '':
#     #     id = f'{args.steps}steps'
#     # else:
#     #     id = args.run_name

#     wandb.login(key='')
#     wandb.init(project=args.project_name, entity=args.entity, name=args.run_name, reinit=True,
#                dir=args.output_dir)
#     wandb.config.update(args)



def find_latest_checkpoint(dir):
    """Return the highest task index among the ``*.pth`` files in ``dir``.

    A file contributes a task index when its name ends with ``.pth`` and the
    dot-separated field immediately before the extension parses as an integer
    (e.g. ``checkpoint.12.pth`` -> ``12``).  Other ``.pth`` files, such as
    ``model.best.pth``, are ignored instead of aborting the scan.

    Args:
        dir: Path of the directory to scan (not recursive).

    Returns:
        The largest task index found, as an ``int``.  When no usable checkpoint
        exists the tuple ``(-1, 0)`` is returned.
        NOTE(review): the empty-case return is a tuple while the normal case is
        an int; kept as-is for backward compatibility -- confirm with callers
        whether a plain ``-1`` is what they actually expect.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``dir`` cannot be listed.
    """
    # Local import: this module has no active top-level ``import os``.
    import os

    tasks = []
    for name in os.listdir(dir):
        if not name.endswith('.pth'):
            continue
        try:
            tasks.append(int(name.split('.')[-2]))
        except ValueError:
            # Penultimate field is not a task number; not a task checkpoint.
            continue

    if not tasks:
        return -1, 0

    last_task = max(tasks)
    # flush=True so the message shows up immediately in buffered job logs.
    print(f'find ckpt at task {last_task}', flush=True)
    return last_task


def resume_ckpt(args, model, last_task='no'):
    """Load model weights from ``<args.output_dir>/checkpoint.<last_task>.pth``.

    Args:
        args: Object exposing an ``output_dir`` attribute (the checkpoint
            directory).
        model: Object whose ``load_state_dict`` is called with the ``'model'``
            entry of the loaded checkpoint.
        last_task: Task index of the checkpoint to load.  The default sentinel
            ``'no'`` means "look up the latest one with
            ``find_latest_checkpoint``".

    Returns:
        The task index that was looked for, whether or not a checkpoint file
        existed.  When it was looked up automatically this is whatever
        ``find_latest_checkpoint`` returned.

    Raises:
        KeyError: If the checkpoint has no ``'model'`` entry.
    """
    # Local imports: this module has no active top-level imports for these.
    import os
    import torch

    out_dir = args.output_dir
    model_name = "checkpoint"

    if last_task == 'no':
        last_task = find_latest_checkpoint(out_dir)

    path = os.path.join(out_dir, f'{model_name}.{last_task}.pth')
    if not os.path.isfile(path):
        print("=> no checkpoint found")
        return last_task

    print("=> loading checkpoint '{}'".format(last_task))
    # Prefer the first GPU, but fall back to CPU so that resuming also works
    # on hosts without CUDA instead of failing inside torch.load.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    print("=> loaded checkpoint '{}'"
          .format(last_task))
    # Drop the reference so the loaded tensors can be freed promptly.
    del checkpoint
    return last_task


# def save_checkpoint(state, task, args):
#     if args.gpu != 0:
#         return

#     torch.save(state, f'{args.output_dir}/checkpoint.{task}.pth')
#     print('ckpt saved')


# def logging(x_name, x_value, y_name, y_value, args):
#     if args.gpu != 0:
#         return
#     if args.wandb_log:
#         wandb.define_metric(x_name)
#         wandb.define_metric(y_name, step_metric=x_name)
#         wandb.log({
#             x_name: x_value,
#             y_name: y_value
#         })


# def setup_for_distributed(is_master):
#     """
#     This function disables printing when not in master process
#     """
#     builtin_print = builtins.print

#     def print(*args, **kwargs):
#         force = kwargs.pop('force', False)
#         force = force or (get_world_size() > 8)
#         if is_master or force:
#             now = datetime.datetime.now().time()
#             builtin_print('[{}] '.format(now), end='')  # print with time stamp
#             builtin_print(*args, **kwargs)

#     builtins.print = print

# def is_dist_avail_and_initialized():
#     if not dist.is_available():
#         return False
#     if not dist.is_initialized():
#         return False
#     return True

# def get_world_size():
#     if not is_dist_avail_and_initialized():
#         return 1
#     return dist.get_world_size()

# def get_rank():
#     if not is_dist_avail_and_initialized():
#         return 0
#     return dist.get_rank()

# def is_main_process():
#     return get_rank() == 0

# def save_on_master(*args, **kwargs):
#     if is_main_process():
#         torch.save(*args, **kwargs)

# def all_reduce_mean(x):
#     world_size = get_world_size()
#     if world_size > 1:
#         x_reduce = torch.tensor(x).cuda()
#         dist.all_reduce(x_reduce)
#         x_reduce /= world_size
#         return x_reduce.item()
#     else:
#         return x
