import copy
import datetime
import errno
import hashlib
import os
import time
from collections import defaultdict, deque, OrderedDict
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
from collections import defaultdict, deque, OrderedDict

# class SmoothedValue:
#     """Track a series of values and provide access to smoothed values over a
#     window or the global series average.
#     """

#     def __init__(self, window_size=20, fmt=None):
#         if fmt is None:
#             fmt = "{median:.4f} ({global_avg:.4f})"
#         self.deque = deque(maxlen=window_size)
#         self.total = 0.0
#         self.count = 0
#         self.fmt = fmt

#     def update(self, value, n=1):
#         self.deque.append(value)
#         self.count += n
#         self.total += value * n

#     def synchronize_between_processes(self):
#         """
#         Warning: does not synchronize the deque!
#         """
#         t = reduce_across_processes([self.count, self.total])
#         t = t.tolist()
#         self.count = int(t[0])
#         self.total = t[1]

#     @property
#     def median(self):
#         d = torch.tensor(list(self.deque))
#         return d.median().item()

#     @property
#     def avg(self):
#         d = torch.tensor(list(self.deque), dtype=torch.float32)
#         return d.mean().item()

#     @property
#     def global_avg(self):
#         return self.total / self.count

#     @property
#     def max(self):
#         return max(self.deque)

#     @property
#     def value(self):
#         return self.deque[-1]

#     def __str__(self):
#         return self.fmt.format(
#             median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
#         )


# class MetricLogger:
#     def __init__(self, delimiter="\t"):
#         self.meters = defaultdict(SmoothedValue)
#         self.delimiter = delimiter

#     def update(self, **kwargs):
#         for k, v in kwargs.items():
#             if isinstance(v, torch.Tensor):
#                 v = v.item()
#             assert isinstance(v, (float, int))
#             self.meters[k].update(v)

#     def __getattr__(self, attr):
#         if attr in self.meters:
#             return self.meters[attr]
#         if attr in self.__dict__:
#             return self.__dict__[attr]
#         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

#     def __str__(self):
#         loss_str = []
#         for name, meter in self.meters.items():
#             loss_str.append(f"{name}: {str(meter)}")
#         return self.delimiter.join(loss_str)

#     def synchronize_between_processes(self):
#         for meter in self.meters.values():
#             meter.synchronize_between_processes()

#     def add_meter(self, name, meter):
#         self.meters[name] = meter

#     def log_every(self, iterable, print_freq, header=None):
#         i = 0
#         if not header:
#             header = ""
#         start_time = time.time()
#         end = time.time()
#         iter_time = SmoothedValue(fmt="{avg:.4f}")
#         data_time = SmoothedValue(fmt="{avg:.4f}")
#         space_fmt = ":" + str(len(str(len(iterable)))) + "d"
#         if torch.cuda.is_available():
#             log_msg = self.delimiter.join(
#                 [
#                     header,
#                     "[{0" + space_fmt + "}/{1}]",
#                     "eta: {eta}",
#                     "{meters}",
#                     "time: {time}",
#                     "data: {data}",
#                     "max mem: {memory:.0f}",
#                 ]
#             )
#         else:
#             log_msg = self.delimiter.join(
#                 [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
#             )
#         MB = 1024.0 * 1024.0
#         for obj in iterable:
#             data_time.update(time.time() - end)
#             yield obj
#             iter_time.update(time.time() - end)
#             if i % print_freq == 0:
#                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
#                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
#                 if torch.cuda.is_available():
#                     print(
#                         log_msg.format(
#                             i,
#                             len(iterable),
#                             eta=eta_string,
#                             meters=str(self),
#                             time=str(iter_time),
#                             data=str(data_time),
#                             memory=torch.cuda.max_memory_allocated() / MB,
#                         )
#                     )
#                 else:
#                     print(
#                         log_msg.format(
#                             i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
#                         )
#                     )
#             i += 1
#             end = time.time()
#         total_time = time.time() - start_time
#         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
#         print(f"{header} Total time: {total_time_str}")



def reduce_across_processes(val):
    """Convert ``val`` to a tensor and, when distributed, all-reduce it.

    Args:
        val: Anything accepted by ``torch.tensor`` (e.g. a number or a list
            of numbers).

    Returns:
        A tensor built from ``val``. When a process group is initialized the
        tensor lives on the current CUDA device and has been combined across
        all ranks with ``dist.all_reduce`` using its default reduction op;
        otherwise it is a plain tensor holding ``val`` unchanged.
    """
    if is_dist_avail_and_initialized():
        reduced = torch.tensor(val, device="cuda")
        # Line every rank up before the collective, then reduce in place.
        dist.barrier()
        dist.all_reduce(reduced)
        return reduced

    # Nothing to synchronize; still hand back a tensor so callers see the
    # same return type as in the distributed case.
    return torch.tensor(val)


def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized.

    The availability check runs first, so ``dist.is_initialized()`` is never
    queried when ``dist.is_available()`` reports False.
    """
    return dist.is_available() and dist.is_initialized()


def init_distributed_mode(args):
    """Populate ``args`` with distributed settings and start the process group.

    The launch environment is detected in this order:

    1. ``RANK`` and ``WORLD_SIZE`` are both in the environment: ``args.rank``,
       ``args.world_size`` and ``args.gpu`` are read from ``RANK``,
       ``WORLD_SIZE`` and ``LOCAL_RANK``; ``args.gpu_counts`` is set to the
       number of visible CUDA devices.
    2. ``SLURM_PROCID`` is in the environment: ``args.rank`` comes from it and
       ``args.gpu`` is the rank modulo the number of visible CUDA devices.
    3. ``args`` already has a ``rank`` attribute: nothing is derived here.
    4. Otherwise distributed mode is disabled: ``args.distributed`` is set to
       False, ``args.gpu`` to 0, and the function returns immediately.

    For cases 1-3 the function selects the CUDA device ``args.gpu``, sets
    ``args.dist_backend`` to ``"nccl"``, initializes the process group from
    ``args.dist_url`` / ``args.world_size`` / ``args.rank``, waits on a
    barrier, calls ``setup_for_distributed`` with whether this is rank 0, and
    finally sets ``args.distributed`` to True.

    NOTE: ``args.dist_url`` is never set here and must be supplied by the
    caller. Cases 2 and 3 do not set ``args.world_size`` (and case 3 does not
    set ``args.gpu``), so those attributes must already exist on ``args``.

    Args:
        args: Mutable namespace-like object holding the run configuration;
            it is updated in place.
    """
    env = os.environ

    if "RANK" in env and "WORLD_SIZE" in env:
        args.gpu_counts = torch.cuda.device_count()
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.gpu = int(env["LOCAL_RANK"])
    elif "SLURM_PROCID" in env:
        args.rank = int(env["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    elif not hasattr(args, "rank"):
        # No launcher variables and no preset rank: single-process run.
        print("Not using distributed mode")
        args.distributed = False
        args.gpu = 0
        return

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)

    group_options = {
        "backend": args.dist_backend,
        "init_method": args.dist_url,
        "world_size": args.world_size,
        "rank": args.rank,
    }
    torch.distributed.init_process_group(**group_options)
    # Do not proceed until every rank has joined the group.
    torch.distributed.barrier()

    is_master = args.rank == 0
    setup_for_distributed(is_master)
    args.distributed = True

def setup_for_distributed(is_master):
    """
    Build a ``print`` wrapper that stays silent on non-master processes.

    The wrapper forwards to the built-in ``print`` only when ``is_master`` is
    true or the caller passes ``force=True``; the ``force`` keyword is consumed
    and never forwarded.

    NOTE: the statement that would install the wrapper as the global built-in
    is commented out at the end of this function, so calling it currently
    leaves printing enabled on every process.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    # builtins.print = print


