import random

import numpy as np
import omegaconf
import torch
import wandb
import time
from collections import defaultdict, deque
import datetime


import torch
import torch.distributed as dist


def set_seed_everywhere(seed):
    """Seed PyTorch (CPU, plus every CUDA device when available), NumPy's
    global RNG and Python's ``random`` module with the same ``seed``.
    """
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def to_torch(xs, device):
    """Convert every element of ``xs`` to a tensor on ``device``.

    Uses ``torch.as_tensor``, so existing tensors already on ``device`` are
    not copied. Returns a tuple in the same order as ``xs``.
    """
    converted = [torch.as_tensor(item, device=device) for item in xs]
    return tuple(converted)


def setup_wandb(cfg):
    """Start a Weights & Biases run and attach the fully resolved config.

    The OmegaConf config is converted to plain containers first (interpolations
    resolved, ``throw_on_missing=True``), so a broken config raises before any
    run is started. Project and run name are taken from ``cfg.wandb.project``
    and ``cfg.wandb.exp_name``; the entity is fixed to ``'emsa'``.
    """
    resolved_cfg = omegaconf.OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=True
    )
    wandb.init(
        entity='emsa',  # hard-coded W&B entity; TODO(review): consider reading it from cfg
        project=cfg.wandb.project,
        settings=wandb.Settings(start_method="thread"),
        name=cfg.wandb.exp_name,
    )
    # allow_val_change lets keys already present in wandb.config be overwritten.
    wandb.config.update(resolved_cfg, allow_val_change=True)





class SmoothedValue(object):
    """Track a series of values and provide smoothed statistics.

    The most recent ``window_size`` values are kept for the windowed
    statistics (``median``, ``avg``, ``max``, ``value``). A running weighted
    ``total``/``count`` covers the whole series for ``global_avg``.

    Every statistic returns ``nan`` while it has no data. As a result,
    ``str()`` of a freshly created meter does not raise.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)  # most recent `window_size` raw values
        self.total = 0.0  # weighted sum of all values ever recorded
        self.count = 0  # sum of all weights `n` ever recorded
        self.fmt = fmt  # format string; may use median/avg/global_avg/max/value

    def update(self, value, n=1):
        """Record ``value`` with weight ``n``.

        The window stores ``value`` once. ``n`` only affects ``count`` and
        ``total`` (and therefore ``global_avg``).
        """
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across all distributed processes.

        This is a no-op when torch.distributed is unavailable or not
        initialized.

        Warning: does not synchronize the deque, so windowed statistics
        remain process-local.
        """
        if not is_dist_avail_and_initialized():
            return
        # Previously hard-coded to 'cuda', which cannot allocate on a machine
        # without a GPU; keep CUDA when present, otherwise reduce on CPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the windowed values (``nan`` when empty).

        ``torch.median`` returns the lower middle value for even-sized windows.
        """
        if not self.deque:
            return float('nan')
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the windowed values, computed in float32 (``nan`` when empty)."""
        if not self.deque:
            return float('nan')
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over the whole series (``nan`` before any update)."""
        if self.count == 0:
            return float('nan')
        return self.total / self.count

    @property
    def max(self):
        """Largest windowed value (``nan`` when empty)."""
        if not self.deque:
            return float('nan')
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value (``nan`` when empty)."""
        if not self.deque:
            return float('nan')
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)

class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with periodic progress printing.

    Meters are created on first ``update`` through the ``defaultdict``. They
    can also be registered explicitly with ``add_meter``. Each meter is
    readable as an attribute, e.g. ``logger.loss`` is ``logger.meters['loss']``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)  # name -> meter, created on first access
        self.delimiter = delimiter  # separator between fields in printed lines

    def update(self, **kwargs):
        """Record one value per keyword into the meter of the same name.

        Tensors are converted with ``.item()`` (so they must hold a single
        element). Afterwards every value must be a Python ``float`` or ``int``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes; only reached when normal lookup fails."""
        # Read `meters` straight from the instance dict. If this runs before
        # __init__ has set it (copy.copy / unpickling create the object without
        # calling __init__), `self.meters` would re-enter __getattr__ forever.
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        """Return ``"name: <meter>"`` for every meter, joined by ``self.delimiter``."""
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Call ``synchronize_between_processes`` on every meter."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register (or replace) the meter stored under ``name``."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None, wandb=None, is_test=False, epoch=None):
        """Yield the items of ``iterable`` while periodically printing progress.

        A progress line is printed every ``print_freq`` iterations and on the
        last iteration. It contains:

        - the iteration counter and ETA;
        - all meters;
        - the average per-iteration time and data-loading time;
        - when CUDA is available, the peak allocated GPU memory in MB.

        After the iterable is exhausted, a summary line with the total elapsed
        time is printed.

        Args:
            iterable: sized iterable to wrap (``len()`` is required).
            print_freq: print every this many iterations.
            header: optional prefix for each printed line.
            wandb, is_test, epoch: currently unused (the W&B logging below is
                commented out); kept for call-site compatibility.

        Yields:
            The items of ``iterable`` unchanged.
        """
        if not header:
            header = ''
        total = len(iterable)
        start_time = time.time()
        end = time.time()
        # Time from the end of the previous iteration until the caller hands
        # control back, i.e. data loading plus the caller's work.
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        # Time spent waiting for `iterable` to produce the next item.
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(total))) + 'd'  # pad counter to width of `total`
        log_parts = [
            header,
            '[{0' + space_fmt + '}/{1}]',  # iteration count
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            log_parts.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_parts)
        MB = 1024.0 * 1024.0
        i = 0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)

            if i % print_freq == 0 or i == total - 1:
                # i + 1 iterations are finished, so total - i - 1 remain.
                eta_seconds = iter_time.global_avg * (total - i - 1)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fields = dict(
                    eta=eta_string,
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if use_cuda:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, total, **fields))

                # if wandb is not None:
                #     if is_test:
                #         pass
                #     else:
                #         wandb.log({"iteration":i,
                #                 #    "lr":self.meters['lr'].value,
                #                    "loss_avg":self.meters['loss'].avg,
                #                    "loss_median":self.meters['loss'].median,
                #                    "loss_global_avg":self.meters['loss'].global_avg,
                #                    "loss_max":self.meters['loss'].max,
                #                    "epoch":epoch})
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Average over the iterations actually run; max() guards an empty iterable.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(i, 1)))

def setup_for_distributed(is_master):
    """Replace the builtin ``print`` so that only the master process prints.

    On non-master processes, output is suppressed unless the call passes
    ``force=True``. The ``force`` keyword is always stripped before the call
    is forwarded to the original ``print``.
    """
    import builtins

    original_print = builtins.print

    def gated_print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = gated_print




def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is built in AND a process group
    has been initialized; otherwise False.
    """
    # Short-circuit keeps is_initialized() from being called when the
    # distributed package is unavailable.
    return dist.is_available() and dist.is_initialized()




def get_world_size():
    """Return the number of processes in the default group, or 1 when not
    running distributed.
    """
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1




def get_rank():
    """Return this process's rank in the default group, or 0 when not
    running distributed.
    """
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0




def is_main_process():
    """Return True on rank 0, which includes every non-distributed run
    (``get_rank`` reports 0 there).
    """
    rank = get_rank()
    return rank == 0




def save_on_master(*args, **kwargs):
    """Forward to ``torch.save(*args, **kwargs)`` on the main process only.

    On every other rank this is a no-op. Always returns None.
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)




def init_distributed_mode(args):
    """Configure ``args`` for distributed training -- currently always disabled.

    The environment-based detection is commented out. As a result, this
    function unconditionally prints 'Not using distributed mode', sets
    ``args.distributed = False`` and returns.

    Everything after that ``return`` is unreachable. Presumably it is kept so
    distributed mode can be re-enabled later. If it ran, it would:

    - set ``args.distributed`` and ``args.dist_backend``;
    - read ``args.gpu``, ``args.rank``, ``args.world_size`` and ``args.dist_url``;
    - initialize an NCCL process group;
    - silence ``print`` on non-zero ranks via ``setup_for_distributed``.

    Args:
        args: mutable namespace-like object; at minimum ``args.distributed``
            is assigned.
    """
    # NOTE(review): `os` is not imported at the top of this file; re-enabling
    # the commented-out detection below also requires `import os`.
    # NOTE(review): the SLURM branch does not set args.world_size; presumably it
    # is expected to be provided on `args` already -- confirm before re-enabling.
    # if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
    #     args.rank = int(os.environ["RANK"])
    #     args.world_size = int(os.environ['WORLD_SIZE'])
    #     args.gpu = int(os.environ['LOCAL_RANK'])
    # elif 'SLURM_PROCID' in os.environ:
    #     args.rank = int(os.environ['SLURM_PROCID'])
    #     args.gpu = args.rank % torch.cuda.device_count()
    # else:
    print('Not using distributed mode')
    args.distributed = False
    return

    # ---- Unreachable below this line (see docstring). ----

    args.distributed = True


    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # Wait until every rank has joined the group before continuing.
    torch.distributed.barrier()
    # Only rank 0 keeps printing (others need print(..., force=True)).
    setup_for_distributed(args.rank == 0)



