import torch
import torch.distributed as dist


def is_dist_avail_and_initialized():
    """Tell whether ``torch.distributed`` is both available and initialized.

    Returns:
        ``True`` only when ``dist.is_available()`` and
        ``dist.is_initialized()`` both report true, ``False`` otherwise.
    """
    # Short-circuit: is_initialized() is only consulted once is_available()
    # has already returned true.
    return dist.is_available() and dist.is_initialized()


class AverageMeter(object):
    """Tracks the latest value together with a running sum, count and mean.

    Attributes are plain public fields:
        val:   the value passed to the most recent ``update`` call.
        sum:   accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg:   ``sum / count`` (``0`` while ``count`` is zero).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        Args:
            val: the new measurement (typically a per-batch mean).
            n: weight of the measurement, e.g. the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An empty meter updated with n=0 must not divide by zero.
        self.avg = self.sum / self.count if self.count else 0

    def synchronize_between_processes(self):
        """Sum ``count``, ``sum`` and ``val`` over all processes and refresh ``avg``.

        Does nothing when torch.distributed is unavailable or uninitialized.
        Every participating process has to call this method, because it
        issues collective operations.
        """
        if not is_dist_avail_and_initialized():
            return
        # Fall back to a CPU tensor on hosts without CUDA instead of failing
        # on a hard-coded 'cuda' device.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        stats = torch.tensor([self.count, self.sum, self.val], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(stats)
        total_count, total_sum, total_val = stats.tolist()
        self.count = int(total_count)
        self.sum = total_sum
        # NOTE(review): after the reduction ``val`` is the SUM of each rank's
        # latest value, not a single "current" value - kept for compatibility.
        self.val = total_val
        # No samples recorded on any rank: keep the average at zero.
        self.avg = self.sum / self.count if self.count else 0


import numpy as np
import matplotlib.pyplot as plt


# continual learning metric
class ContinaulMetric(object):
    """Accuracy bookkeeping for a model evaluated after every learning step.

    Keeps one ``AverageMeter`` per (model step, data split) pair, mirrors
    their averages into ``self.matrix`` (rows: model step, columns: data
    split) and derives summary statistics from that matrix.

    The misspelled class name is kept so existing imports keep working.
    """

    def __init__(self, args):
        """Build the meters from ``args.split`` and ``args.setting``.

        Args:
            args: object exposing ``split`` (number of tasks) and ``setting``
                (``'class_incremental'`` or ``'data_incremental'``).

        Raises:
            ValueError: if ``args.setting`` is not one of the two supported
                modes (previously this surfaced as an AttributeError on
                ``self.dtask`` inside ``reset``).
        """
        self.task = args.split
        self.mode = args.setting
        if self.mode == 'class_incremental':
            # One evaluation split per task.
            self.dtask = self.task
        elif self.mode == 'data_incremental':
            # A single evaluation split shared by every step.
            self.dtask = 1
        else:
            raise ValueError(
                f"unsupported setting {self.mode!r}; "
                "expected 'class_incremental' or 'data_incremental'"
            )
        self.reset()
        self.set_task()

    def reset(self):
        """Zero the summary statistics and the (task x dtask) matrix."""
        self.task_avg = 0
        self.bwt = 0
        self.tla = 0
        self.sample_avg = 0
        self.matrix = np.zeros([self.task, self.dtask])
        self.current_last_task = 0

    def set_task(self):
        """Create fresh meters: one per matrix cell plus one per model step."""
        self.metrics = {
            t: {dt: AverageMeter() for dt in range(self.dtask)}
            for t in range(self.task)
        }
        self.weighted = {t: AverageMeter() for t in range(self.task)}

    def update(self, model_no, data_no, val, n=1):
        """Feed a measurement into the cell meter and the per-step meter.

        Args:
            model_no: index of the model step (matrix row).
            data_no: index of the data split (matrix column).
            val: measured value.
            n: weight of the measurement (number of samples).
        """
        self.current_last_task = max(model_no, self.current_last_task)
        self.metrics[model_no][data_no].update(val, n)
        self.weighted[model_no].update(val, n)

    def update_metric(self, model_no, data_no):
        """Copy one meter average into the matrix and refresh the summaries.

        In ``'class_incremental'`` mode this recomputes:
            task_avg: mean of the latest row over the splits seen so far.
            tla:      mean of the matrix diagonal up to the latest step.
            bwt:      mean, over earlier splits, of (best earlier value minus
                      latest value); only updated once a second step exists.
        ``sample_avg`` (mean of the per-step meter averages) is refreshed in
        every mode.
        """
        # NOTE: the meter is not synchronized across processes here; the
        # value stored is this process's local average.
        self.matrix[model_no, data_no] = self.metrics[model_no][data_no].avg
        last = self.current_last_task
        if self.mode == 'class_incremental':
            self.task_avg = np.mean(self.matrix[last, :last + 1])
            self.tla = np.mean(np.diag(self.matrix[:last + 1]))
            if last > 0:
                best_before = np.max(self.matrix[:last, :last], axis=0)
                self.bwt = np.mean(best_before - self.matrix[last, :last])

        self.sample_avg = np.mean([self.weighted[i].avg for i in range(last + 1)])

    def print_matrix(self, name):
        """Print the matrix under a ``Metric <name>`` header, 3 decimals."""
        print(f'Metric {name}')
        with np.printoptions(precision=3, suppress=True):
            print(self.matrix)

    def wandb_log_matrix(self, name, args):
        """Render the matrix as an annotated heat map and log it.

        Only acts in ``'class_incremental'`` mode.

        Args:
            name: suffix used in the logged key.
            args: object whose ``writer.log`` accepts a dict mapping a key to
                the pyplot module (presumably a wandb run - confirm).
        """
        if self.mode != 'class_incremental':
            return

        # matshow() opens its own appropriately sized figure, so no separate
        # plt.figure() call is needed (it only left an empty figure behind).
        plt.matshow(self.matrix, cmap='seismic')
        fig = plt.gcf()
        try:
            plt.xlabel('data of different task')
            plt.ylabel('model trained after different task')
            for (y, x), value in np.ndenumerate(self.matrix):
                plt.text(x, y, f"{value:.2f}", va="center", ha="center")

            args.writer.log({f'task{self.current_last_task}/accuracy_matrix_{name}': plt})
        finally:
            # Close (not just clear) the figure so repeated calls do not
            # accumulate open figures.
            plt.close(fig)


class ChangePointDetecotr(object):
    """Flags change points in a stream of scalar observations.

    Every observation is stored; once more than ``w`` observations exist, the
    mean and standard deviation of the last ``w`` are tracked, and a change
    is reported from how those two series evolve.
    """

    def __init__(self):
        self.points = []          # raw observations, in arrival order
        self.smoothed_point = []  # windowed means
        self.change = []          # successive drops between windowed means
        self.std = []             # windowed standard deviations

    def update(self, point):
        """Store ``point``, refresh the window statistics, return the flag.

        Returns:
            ``1`` if a change is detected at this step, otherwise ``0``.
        """
        self.points.append(point)
        self.smooth()
        return self.checkchange()

    def smooth(self, w: int = 10):
        """Append mean/std of the last ``w`` points once more than ``w`` exist."""
        if len(self.points) <= w:
            return
        window = np.array(self.points[-w:])
        self.smoothed_point.append(np.mean(window))
        self.std.append(np.std(window))

    def checkchange(self):
        """Compare the two latest windowed means and report a change.

        A change is reported when the latest drop falls below a fifth of the
        largest drop recorded so far, or when the latest windowed std exceeds
        five times the largest std recorded so far.
        """
        if len(self.smoothed_point) < 2:
            return 0

        previous_mean, latest_mean = self.smoothed_point[-2:]
        self.change.append(previous_mean - latest_mean)

        drop_threshold = max(self.change) / 5
        std_threshold = max(self.std) * 5
        changed = self.change[-1] < drop_threshold or self.std[-1] > std_threshold
        return 1 if changed else 0
