import pickle
import random
import shutil
import sys
from datetime import datetime
import os
import time
from collections import OrderedDict, defaultdict, deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import wandb
from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union

# Names exported by ``from <module> import *``. Only ``RandomSampler`` is
# listed, so every other definition in this file has to be imported by name.
__all__ = [
    "RandomSampler",
]

# Covariant type variable used as the element-type parameter of ``Sampler``.
T_co = TypeVar('T_co', covariant=True)


class Sampler(Generic[T_co]):
    r"""Abstract root of the sampler hierarchy in this module.

    A concrete sampler must implement :meth:`__iter__`, which produces the
    items (dataset indices, or groups of indices) handed to the data loader.
    Subclasses are also expected to implement :meth:`__len__` so that code
    which needs the length of a :class:`~torch.utils.data.DataLoader` can
    compute it.

    .. note:: :class:`~torch.utils.data.DataLoader` does not strictly need
              :meth:`__len__`, but anything that asks for the loader's length
              relies on it.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        """Accept a data source and ignore it; subclasses store their own state."""

    def __iter__(self) -> Iterator[T_co]:
        """Yield sampled items; the base class provides no implementation."""
        raise NotImplementedError

class NoSampler_SubsetRandomSampler(Sampler[int]):
    r"""Yield 5-tuples of elements taken from independent shuffles of ``indices``.

    Five random permutations of ``indices`` are drawn and walked in lockstep,
    so each yielded item is a tuple of five ints. Within any single tuple
    position, every element of ``indices`` appears exactly once per epoch
    (sampling without replacement per position).

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling. When ``None`` a
            private generator is created and re-seeded from the global torch
            RNG before each permutation.
    """
    indices: Sequence[int]

    # Number of independent permutations, i.e. the length of each yielded tuple.
    _NUM_WAYS = 5

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.indices = indices
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        """Yield ``len(indices)`` tuples of five ints drawn from ``indices``."""
        count = len(self.indices)

        # Honour a caller-supplied generator (previously it was stored but
        # silently ignored). Without one, keep the historical behaviour of
        # re-seeding a private generator from the global RNG for every shuffle.
        use_private = self.generator is None
        generator = torch.Generator() if use_private else self.generator

        permutations = []
        for _ in range(self._NUM_WAYS):
            if use_private:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
                generator.manual_seed(seed)
            permutations.append(torch.randperm(count, generator=generator))

        for positions in zip(*permutations):
            yield tuple(int(self.indices[pos]) for pos in positions)

    def __len__(self) -> int:
        """Return the number of tuples produced per iteration."""
        return len(self.indices)

class SubsetRandomSampler(Sampler[int]):
    r"""Yield random 5-tuples of distinct indices into the innermost dataset.

    Each yielded item is a tuple of five distinct ints produced by
    ``random.sample(range(len(dataset.dataset.dataset)), 5)``, so the
    innermost dataset needs at least five elements. The number of tuples per
    epoch is ``limit_batches * batch_size``, capped at ``len(dataset)``.

    Args:
        dataset: object exposing ``__len__`` and a ``.dataset.dataset``
            attribute chain whose final object also has a length.
        limit_batches: maximum number of batches per epoch.
        batch_size: number of tuples per batch.
    """

    def __init__(self, dataset, limit_batches=200, batch_size=4) -> None:
        self.dataset = dataset
        self.limit_batches = limit_batches
        self.batch_size = batch_size
        self.total_batches = int(self.limit_batches*self.batch_size)

    def _num_tuples(self) -> int:
        """Return how many tuples one pass over the sampler produces."""
        return max(0, min(self.total_batches, len(self.dataset)))

    def __iter__(self) -> Iterator[int]:
        """Return an iterator over freshly drawn 5-tuples of indices."""
        population = range(len(self.dataset.dataset.dataset))
        # Draw only the tuples that will actually be yielded, instead of
        # building ``total_batches`` of them and discarding the surplus.
        drawn = [tuple(random.sample(population, 5))
                 for _ in range(self._num_tuples())]
        return iter(drawn)

    def __len__(self) -> int:
        """Return the number of tuples ``__iter__`` yields.

        This previously reported ``len(dataset)`` even when fewer tuples were
        produced, which made the data loader's length disagree with the
        number of batches it could deliver.
        """
        return self._num_tuples()

class RandomSampler(Sampler[int]):
    r"""Random sampler that yields 5-tuples of index tensors.

    Every yielded item is a tuple of five tensors drawn one after another
    from the same generator.

    * With ``replacement=True``: ``num_samples // 32`` tuples are yielded, each
      holding five int64 tensors of 32 random indices in ``[0, n)``, followed by
      one extra tuple whose tensors have ``num_samples % 32`` elements (these
      may be empty).
    * With ``replacement=False``: ``num_samples // n`` tuples are yielded, each
      holding five permutations of ``range(n)``, followed by one more such
      tuple only when ``num_samples % n`` is non-zero.

    Here ``n`` is ``len(data_source)``.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): draw with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        requested = self.num_samples
        if not (isinstance(requested, int) and requested > 0):
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        """Explicit sample count, or the current dataset length if none was given."""
        # The dataset may grow or shrink at runtime, so the fallback is not cached.
        if self._num_samples is not None:
            return self._num_samples
        return len(self.data_source)

    def __iter__(self) -> Iterator[int]:
        """Yield tuples of five index tensors as described on the class."""
        n = len(self.data_source)

        generator = self.generator
        if generator is None:
            # No generator supplied: seed a private one from the global torch RNG.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)

        if self.replacement:
            for _chunk in range(self.num_samples // 32):
                yield tuple(
                    torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator)
                    for _way in range(5)
                )
            # The trailing tuple is always emitted, even when the remainder is 0.
            remainder = self.num_samples % 32
            yield tuple(
                torch.randint(high=n, size=(remainder,), dtype=torch.int64, generator=generator)
                for _way in range(5)
            )
        else:
            for _epoch in range(self.num_samples // n):
                yield tuple(torch.randperm(n, generator=generator) for _way in range(5))

            # The five permutations are drawn unconditionally (advancing the
            # generator) but only handed out when a remainder exists.
            trailing = tuple(torch.randperm(n, generator=generator) for _way in range(5))
            if self.num_samples % n:
                yield trailing

    def __len__(self) -> int:
        """Return ``num_samples``.

        NOTE(review): this is not the number of tuples ``__iter__`` yields;
        confirm which of the two callers are meant to rely on.
        """
        return self.num_samples


def jsd_loss(logit1, logit2):
    """Return a symmetric KL-based divergence between two sets of logits.

    Both inputs are normalised with a softmax over ``dim=0``. The mixture
    ``M = 0.5 * (P1 + P2)`` is then used as the *target* of
    ``F.kl_div(log_softmax(logit_i), M, reduction="batchmean")`` for each
    input, and half the sum of the two terms is returned.

    NOTE(review): with PyTorch's ``kl_div(input, target)`` convention each term
    is KL(M || P_i), whereas the textbook Jensen-Shannon divergence uses
    KL(P_i || M). Confirm this direction is intended.
    """
    mixture = 0.5 * (F.softmax(logit1, dim=0) + F.softmax(logit2, dim=0))

    divergences = [
        F.kl_div(F.log_softmax(logits, dim=0), mixture, reduction="batchmean")
        for logits in (logit1, logit2)
    ]

    accumulated = 0.0 + divergences[0] + divergences[1]
    return 0.5 * accumulated
  
class Logger:
    """Experiment logger writing metrics to a text file and, optionally, wandb.

    Running statistics are kept per key in ``self.logs`` (one ``AverageMeter``
    each). Every value that gets written is also appended to
    ``self.logs_for_save`` so the full history can be dumped with
    :meth:`save_log`.
    """

    def __init__(
        self,
        exp_name,
        log_dir=None,
        exp_suffix="",
        write_textfile=True,
        use_wandb=False,
        wandb_project_name=None,
        entity='wandb'
    ):
        """Create the log directory and open the configured sinks.

        Args:
            exp_name: run name; also used to derive the wandb group.
            log_dir: sub-path appended to ``'./logs/'``. ``None`` (the default)
                now means ``'./logs/'`` itself instead of raising TypeError.
            exp_suffix: accepted for compatibility; currently unused.
            write_textfile: if True, write lines to ``<log_dir>/logs.txt``.
            use_wandb: if True, start a wandb run and mirror logs to it.
            wandb_project_name: wandb project passed to ``wandb.init``.
            entity: wandb entity passed to ``wandb.init``.
        """
        # Plain concatenation (not os.path.join) is kept on purpose so that
        # existing ``log_dir`` values resolve to exactly the same location.
        self.log_dir = './logs/' + (log_dir if log_dir is not None else '')
        os.makedirs(self.log_dir, exist_ok=True)
        self.write_textfile = write_textfile
        self.use_wandb = use_wandb

        self.logs_for_save = {}
        self.logs = {}

        self.f = None
        if self.write_textfile:
            self.f = open(os.path.join(self.log_dir, 'logs.txt'), 'w')

        if self.use_wandb:
            self.run = wandb.init(
                config=wandb.config,
                entity=entity,
                project=wandb_project_name,
                name=exp_name,
                group=exp_name.split('net-')[0],
                reinit=True)

    def close(self):
        """Close the text log file if one is open; safe to call repeatedly."""
        handle = getattr(self, 'f', None)
        if handle is not None and not handle.closed:
            handle.close()

    def watch(self, model):
        """Register ``model`` with ``wandb.watch`` when wandb is enabled."""
        if self.use_wandb:
            wandb.watch(model)

    def update_config(self, v, is_args=False):
        """Record configuration ``v`` locally and in the wandb config.

        With ``is_args=True`` the object is stored under the ``'args'`` key;
        otherwise ``v`` is merged into ``logs_for_save`` as a mapping.
        """
        if is_args:
            self.logs_for_save.update({'args': v})
        else:
            self.logs_for_save.update(v)
        if self.use_wandb:
            wandb.config.update(v, allow_val_change=True)

    def _emit(self, log_str, log_dict, step):
        """Write one formatted line to the text file and one record to wandb."""
        if self.write_textfile:
            self.f.write(log_str+'\n')
            self.f.flush()

        if self.use_wandb:
            wandb.log(log_dict, step=step)

    def write_log_nohead(self, element, step):
        """Log the raw ``key -> value`` pairs in ``element`` at ``step``."""
        log_str = f"{step} | "
        log_dict = {}
        for key, val in element.items():
            self.logs_for_save.setdefault(key, []).append(val)
            log_str += f'{key} {val} | '
            log_dict[f'{key}'] = val

        self._emit(log_str, log_dict, step)

    def write_log(self, element, step, img_dict=None, tbl_dict=None):
        """Log the running averages of the meters named in ``element``.

        Args:
            element: mapping ``head -> iterable of meter keys``; each value is
                reported to wandb as ``'{head}/{key}'``.
            step: step index attached to the record.
            img_dict: optional extra entries merged into the wandb record.
            tbl_dict: optional extra entries merged into the wandb record.

        Raises:
            KeyError: if a requested key has no meter in ``self.logs``.
        """
        log_str = f"{step} | "
        log_dict = {}
        for head, keys in element.items():
            for k in keys:
                # AverageMeter stores its mean in ``average``; reading ``avg``
                # here used to raise AttributeError.
                v = self.logs[k].average
                self.logs_for_save.setdefault(k, []).append(v)
                log_str += f'{k} {v}| '
                log_dict[f'{head}/{k}'] = v

        if img_dict is not None:
            log_dict.update(img_dict)

        if tbl_dict is not None:
            log_dict.update(tbl_dict)

        self._emit(log_str, log_dict, step)

    def save_log(self, name=None):
        """Serialise ``logs_for_save`` with ``torch.save`` (default ``logs.pt``)."""
        name = 'logs.pt' if name is None else name
        torch.save(self.logs_for_save, os.path.join(self.log_dir, name))

    def update(self, key, v, n=1):
        """Feed value ``v`` with weight ``n`` into the meter for ``key``."""
        if key not in self.logs:
            self.logs[key] = AverageMeter()
        self.logs[key].update(v, n)

    def reset(self, keys=None, except_keys=()):
        """Replace meters with fresh ones.

        Args:
            keys: a single key or a list of keys to reset; ``None`` resets
                every existing meter.
            except_keys: keys to leave untouched when ``keys`` is ``None``.
        """
        if keys is not None:
            if isinstance(keys, list):
                for key in keys:
                    self.logs[key] = AverageMeter()
            else:
                self.logs[keys] = AverageMeter()
        else:
            for key in self.logs.keys():
                if key not in except_keys:
                    self.logs[key] = AverageMeter()

    def avg(self, keys=None, except_keys=()):
        """Return running averages.

        Args:
            keys: a single key (returns that average), a list of keys (returns
                a dict restricted to keys that exist), or ``None`` (returns a
                dict of all meters).
            except_keys: keys to omit when ``keys`` is ``None``.
        """
        if keys is not None:
            if isinstance(keys, list):
                return {key: self.logs[key].average for key in keys if key in self.logs}
            return self.logs[keys].average

        return {
            key: meter.average
            for key, meter in self.logs.items()
            if key not in except_keys
        }



class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes (all reset to 0):
        value: the most recent value passed to :meth:`update`.
        sum: weighted sum of all values seen.
        count: total weight seen.
        average: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.value = 0
        self.average = 0
        self.sum = 0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value`` as occurring ``n`` times and refresh the mean."""
        self.value = value
        self.sum += value * n
        self.count += n
        self.average = self.sum / self.count

    @property
    def avg(self):
        """Read-only alias of ``average``.

        Other code in this file (``Logger``, ``SmoothedValue``) uses the name
        ``avg`` for the running mean, so the meter answers to both spellings.
        """
        return self.average


class AverageMeterList(object):
    """Fixed-length collection of ``AverageMeter`` objects updated element-wise."""

    def __init__(self, list_num):
        self.list_num = list_num
        meters = []
        for _ in range(self.list_num):
            meters.append(AverageMeter())
        self.avg_list = meters

    def reset(self):
        """Discard all statistics by installing ``list_num`` fresh meters."""
        fresh = []
        for _ in range(self.list_num):
            fresh.append(AverageMeter())
        self.avg_list = fresh

    def update(self, _avg_list, n=1):
        """Feed ``_avg_list[i]`` (with weight ``n``) into the i-th meter."""
        for position in range(self.list_num):
            meter = self.avg_list[position]
            meter.update(_avg_list[position], n)

    def return_average(self):
        """Return the current average of every meter as a list."""
        averages = []
        for position in range(self.list_num):
            averages.append(self.avg_list[position].average)
        return averages


def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and torch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def load_checkpoint(logdir, mode='last'):
    """Load the checkpoint files ``<logdir>/<mode>.*``.

    Args:
        logdir: directory containing the checkpoint files.
        mode: file-name prefix, e.g. ``'last'`` or ``'best'``.

    Returns:
        A 5-tuple ``(model_state, optim_state, cfg, lr_dict, ema_dict)``.
        All five are ``None`` when ``<mode>.model`` does not exist.
        ``lr_dict`` and ``ema_dict`` are ``None`` individually when the
        ``.lr`` or ``.ema`` file is absent.

    Raises:
        FileNotFoundError: if the model file exists but the ``.optim`` or
            ``.configs`` companion is missing.
    """
    def checkpoint_file(extension):
        return os.path.join(logdir, f'{mode}.{extension}')

    print("=> Loading checkpoint from '{}'".format(logdir))

    model_path = checkpoint_file('model')
    if not os.path.exists(model_path):
        return None, None, None, None, None

    # ``pickle5`` is only a backport for old interpreters. Prefer it when it
    # is installed (previous behaviour), otherwise use the standard library
    # instead of failing with ImportError.
    try:
        import pickle5 as config_pickle
    except ImportError:
        import pickle as config_pickle

    model_state = torch.load(model_path)
    optim_state = torch.load(checkpoint_file('optim'))
    # NOTE: unpickling executes arbitrary code; only load trusted checkpoints.
    with open(checkpoint_file('configs'), 'rb') as handle:
        cfg = config_pickle.load(handle)

    lr_path = checkpoint_file('lr')
    lr_dict = torch.load(lr_path) if os.path.exists(lr_path) else None

    ema_path = checkpoint_file('ema')
    ema_dict = torch.load(ema_path) if os.path.exists(ema_path) else None

    return model_state, optim_state, cfg, lr_dict, ema_dict


def save_checkpoint(P, step, best, model_state, optim_state, logdir, is_best=False):
    """Write a checkpoint to ``logdir`` under the prefix ``best`` or ``last``.

    Files written (``<prefix>`` is ``'best'`` if ``is_best`` else ``'last'``):
        ``<prefix>.model`` / ``<prefix>.optim``: ``model_state`` / ``optim_state``.
        ``<prefix>.configs``: pickled ``{'step': step, 'best': best}``.
        ``<prefix>.lr``: ``P.inner_lr``, only if it is an ``OrderedDict``.
        ``<prefix>.ema``: ``P.moving_average``, if that attribute exists.
        ``<prefix>.lr_ema``: ``P.moving_inner_lr``, if that attribute exists.
        ``<prefix>.predictor``: ``P.predictor``, if that attribute exists.

    Args:
        P: namespace-like object carrying the optional attributes above.
        step: current step, stored in the config file.
        best: best metric so far, stored in the config file.
        model_state: object saved as the model file.
        optim_state: object saved as the optimizer file.
        logdir: existing directory receiving the files.
        is_best: selects the ``best`` prefix instead of ``last``.
    """
    prefix = 'best' if is_best else 'last'

    def target(extension):
        return os.path.join(logdir, f'{prefix}.{extension}')

    # ``inner_lr`` is optional: a missing attribute is treated like a
    # non-OrderedDict value and simply skipped.
    inner_lr = getattr(P, 'inner_lr', None)
    if isinstance(inner_lr, OrderedDict):
        torch.save(inner_lr, target('lr'))
    if hasattr(P, 'moving_average'):
        torch.save(P.moving_average, target('ema'))
    if hasattr(P, 'moving_inner_lr'):
        torch.save(P.moving_inner_lr, target('lr_ema'))
    if hasattr(P, 'predictor'):
        # Save the predictor itself; this branch used to write
        # ``P.moving_average`` into the predictor file.
        torch.save(P.predictor, target('predictor'))

    opt = {
        'step': step,
        'best': best
    }
    torch.save(model_state, target('model'))
    torch.save(optim_state, target('optim'))
    with open(target('configs'), 'wb') as handle:
        pickle.dump(opt, handle, protocol=pickle.HIGHEST_PROTOCOL)


def save_checkpoint_step(P, step, best, model_state, optim_state, logdir):
    """Write a checkpoint to ``logdir`` under the prefix ``step<step>``.

    Files written (``<prefix>`` is ``f'step{step}'``):
        ``<prefix>.model`` / ``<prefix>.optim``: ``model_state`` / ``optim_state``.
        ``<prefix>.configs``: pickled ``{'step': step, 'best': best}``.
        ``<prefix>.lr``: ``P.inner_lr``, only if it is an ``OrderedDict``.
        ``<prefix>.ema``: ``P.moving_average``, if that attribute exists.
        ``<prefix>.lr_ema``: ``P.moving_inner_lr``, if that attribute exists.
        ``<prefix>.predictor``: ``P.predictor``, if that attribute exists.

    Args:
        P: namespace-like object carrying the optional attributes above.
        step: current step; part of every file name and stored in the config.
        best: best metric so far, stored in the config file.
        model_state: object saved as the model file.
        optim_state: object saved as the optimizer file.
        logdir: existing directory receiving the files.
    """
    def target(extension):
        return os.path.join(logdir, f'step{step}.{extension}')

    # ``inner_lr`` is optional: a missing attribute is treated like a
    # non-OrderedDict value and simply skipped.
    inner_lr = getattr(P, 'inner_lr', None)
    if isinstance(inner_lr, OrderedDict):
        torch.save(inner_lr, target('lr'))
    if hasattr(P, 'moving_average'):
        torch.save(P.moving_average, target('ema'))
    if hasattr(P, 'moving_inner_lr'):
        torch.save(P.moving_inner_lr, target('lr_ema'))
    if hasattr(P, 'predictor'):
        # Save the predictor itself; this branch used to write
        # ``P.moving_average`` into the predictor file.
        torch.save(P.predictor, target('predictor'))

    opt = {
        'step': step,
        'best': best
    }
    torch.save(model_state, target('model'))
    torch.save(optim_state, target('optim'))
    with open(target('configs'), 'wb') as handle:
        pickle.dump(opt, handle, protocol=pickle.HIGHEST_PROTOCOL)


def cycle(loader):
    """Iterate over ``loader`` endlessly, restarting it each time it is exhausted."""
    while True:
        yield from loader


def one_hot(ids, n_class):
    """Encode a 1-D tensor of class ids as a one-hot float matrix.

    Adapted from a CSDN blog post by ke1th:
    https://blog.csdn.net/u012436149/article/details/77017832

    ids: 1-D integer tensor of class indices, shape [batch_size]
    out_tensor: CPU float tensor, shape [batch_size, n_class]
    """
    assert len(ids.shape) == 1, 'the ids should be 1-D'

    # scatter_ needs an index column of shape [batch_size, 1] on the same
    # device as the freshly created (CPU) output.
    column_index = ids.cpu().unsqueeze(1)
    return torch.zeros(len(ids), n_class).scatter_(1, column_index, 1.)


class _ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model.
    (This isn't necessary for temperature scaling, just a cool metric).
    The input to this loss is the logits of a model, NOT the softmax scores.
    This divides the confidence outputs into equally-sized interval bins.
    In each bin, we compute the confidence gap:
    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |
    We then return a weighted average of the gaps, based on the number
    of samples in each bin
    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    def __init__(self, n_bins=20):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece


def dist_gather(P, X):
    """Gather ``X`` from all processes and concatenate the results along dim 0.

    When ``P.distributed`` is falsy, ``X`` is returned untouched. Otherwise
    ``P.world_size`` buffers shaped like ``X`` are filled by
    ``dist.all_gather`` and concatenated.
    """
    if not P.distributed:
        return X

    buffers = [torch.zeros_like(X) for _ in range(P.world_size)]
    dist.all_gather(buffers, X)
    return torch.cat(buffers, 0)


def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialised."""
    if dist.is_available() and dist.is_initialized():
        return True
    return False


class SmoothedValue(object):
    """Keep a sliding window of recent values plus global running totals.

    Windowed statistics (``median``, ``avg``, ``max``, ``value``) are computed
    over the last ``window_size`` values; ``global_avg`` covers every update.
    """

    def __init__(self, window_size=20, fmt=None):
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Push ``value`` into the window once and add it ``n`` times to the totals."""
        self.deque.append(value)
        self.total += value * n
        self.count += n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(packed)
            reduced_count, reduced_total = packed.tolist()
            self.count = int(reduced_count)
            self.total = reduced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (computed in float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update since construction."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently added value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)


class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with periodic printing.

    Meters are created on first use and are reachable both through
    ``self.meters[name]`` and as attributes (``logger.loss``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed each keyword value into the meter of the same name.

        ``None`` values are skipped and tensors are converted with ``.item()``.
        Any other value must be an ``int`` or ``float``.
        """
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Resolve unknown attributes as meter names.

        ``meters`` is looked up through ``__dict__`` rather than
        ``self.meters``: when it has not been set yet (for example on an
        instance created without ``__init__``), going through normal attribute
        access would re-enter this method forever.
        """
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        """Return ``name: meter`` entries joined by the delimiter."""
        return self.delimiter.join(
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Synchronise every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items of ``iterable`` while periodically printing progress.

        A status line (progress counter, ETA, all meters, iteration and data
        loading time, plus peak CUDA memory when CUDA is available) is printed
        every ``print_freq`` iterations and on the last one. A total-time
        summary is printed after the iterable is exhausted.

        Args:
            iterable: sized iterable to wrap (``len()`` is required).
            print_freq: number of iterations between status lines.
            header: optional text placed at the start of every line.
        """
        # Imported locally: at module level ``datetime`` is bound to the
        # ``datetime.datetime`` class, which has no ``timedelta`` attribute.
        from datetime import timedelta

        if not header:
            header = ''
        total_iters = len(iterable)
        cuda_available = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(total_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if cuda_available:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total_iters - 1:
                eta_seconds = iter_time.global_avg * (total_iters - i)
                fields = {
                    'eta': str(timedelta(seconds=int(eta_seconds))),
                    'meters': str(self),
                    'time': str(iter_time),
                    'data': str(data_time),
                }
                if cuda_available:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, total_iters, **fields))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(timedelta(seconds=int(total_time)))
        # max(..., 1) keeps the summary printable for an empty iterable.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(total_iters, 1)))
