from __future__ import absolute_import
import datetime
import shutil
from pathlib import Path
import os, time
import numpy as np
import matplotlib.pyplot as plt

import torch
import logging

from .options import args
from sklearn.cluster import AffinityPropagation
from sklearn.cluster.affinity_propagation_ import euclidean_distances
from sklearn.random_projection import SparseRandomProjection
'''
#label smooth
class CrossEntropyLabelSmooth(nn.Module):

  def __init__(self, num_classes, epsilon):
    super(CrossEntropyLabelSmooth, self).__init__()
    self.num_classes = num_classes
    self.epsilon = epsilon
    self.logsoftmax = nn.LogSoftmax(dim=1)

  def forward(self, inputs, targets):
    log_probs = self.logsoftmax(inputs)
    targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
    targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
    loss = (-targets * log_probs).mean(0).sum()
    return loss
'''
"""Computes and stores the average and current value"""


class AverageMeter(object):
    """Keep the most recent value together with a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, mean and sum back to 0.0 and the count to 0."""
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh ``avg``.

        Args:
            val: the newly observed value (stored as ``self.val``).
            n: weight of this observation in the running mean.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def ensure_path(directory):
    """Create ``directory`` along with any missing parents; an existing directory is left alone."""
    Path(directory).mkdir(exist_ok=True, parents=True)


def mkdir(path):
    """Create directory ``path`` and any missing parents; do nothing if it already exists.

    Args:
        path: directory path to create (absolute or relative).

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    # An empty path refers to the current directory, which always exists.
    # (The previous recursive version looped forever on '' when given a
    # missing single-component relative path such as 'logs'.)
    if not path:
        return
    os.makedirs(path, exist_ok=True)


'''Save model and record configurations'''


class checkpoint():
    """Prepare a job directory and save model checkpoints into it.

    Construction creates the following layout under ``args.job_dir``::

        job_dir/
            checkpoint/   # receives model_best.pt
            run/          # receives model.pt (latest save)
            config.txt    # timestamp + every attribute of ``args``; written only once

    Args:
        args: object whose attributes are dumped to ``config.txt``. It must
            provide ``job_dir`` and ``reset``. When ``reset`` is truthy, the
            existing job directory is deleted before being recreated.
    """

    def __init__(self, args):
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        self.args = args
        self.job_dir = Path(args.job_dir)
        self.ckpt_dir = self.job_dir / 'checkpoint'
        self.run_dir = self.job_dir / 'run'

        if args.reset:
            # Delete in-process rather than through a shell command, so the
            # path is never parsed by a shell (spaces, metacharacters).
            # ignore_errors mirrors `rm -rf` on a missing directory.
            shutil.rmtree(self.job_dir, ignore_errors=True)

        for directory in (self.job_dir, self.ckpt_dir, self.run_dir):
            directory.mkdir(parents=True, exist_ok=True)

        config_path = self.job_dir / 'config.txt'
        # Keep the configuration of the first run in this job directory.
        if not config_path.exists():
            with open(config_path, 'w') as f:
                f.write(now + '\n\n')
                for arg in vars(args):
                    f.write('{}: {}\n'.format(arg, getattr(args, arg)))
                f.write('\n')

    def save_model(self, state, epoch, is_best):
        """Save ``state`` to ``run/model.pt`` and, if ``is_best``, copy it to ``checkpoint/model_best.pt``.

        Args:
            state: object passed to ``torch.save``.
            epoch: accepted for caller compatibility; not used.
            is_best: whether to also keep a copy as the best model.
        """
        save_path = self.run_dir / 'model.pt'
        torch.save(state, save_path)
        if is_best:
            shutil.copyfile(save_path, self.ckpt_dir / 'model_best.pt')


class get_logger():
    """Attach and detach a file + console handler pair on the shared ``'gal'`` logger."""

    def __init__(self):
        self.logger = logging.getLogger('gal')
        self.log_format = '%(asctime)s | %(message)s'
        # Populated by add_logger(); None means nothing is attached.
        self.file_handler = None
        self.stream_handler = None

    def add_logger(self, file_path):
        """Send INFO-and-above records to ``file_path`` and to a ``StreamHandler``.

        NOTE(review): calling this twice without remove_logger() in between
        leaves the earlier handlers attached; only the latest pair can be
        removed afterwards.
        """
        self.formatter = logging.Formatter(self.log_format, datefmt='%m/%d %I:%M:%S %p')
        self.file_handler = logging.FileHandler(file_path)
        self.file_handler.setFormatter(self.formatter)
        self.stream_handler = logging.StreamHandler()
        self.stream_handler.setFormatter(self.formatter)

        self.logger.addHandler(self.file_handler)
        self.logger.addHandler(self.stream_handler)
        self.logger.setLevel(logging.INFO)

    def remove_logger(self):
        """Detach and close the handlers added by add_logger(); a no-op if none are attached."""
        for handler in (self.file_handler, self.stream_handler):
            if handler is not None:
                self.logger.removeHandler(handler)
                # Closing releases the log file's descriptor.
                handler.close()
        self.file_handler = None
        self.stream_handler = None


"""Computes the precision@k for the specified values of k"""


def accuracy(output, target, topk=(1, )):
    """Compute top-k accuracy, in percent, for every ``k`` in ``topk``.

    Args:
        output: prediction scores with classes along dimension 1.
        target: ground-truth class indices, one per row of ``output``.
        topk: the ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k``: the percentage of
        rows whose target appears among the ``k`` highest scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Transposed so that row r holds every sample's (r+1)-th best class.
        top_indices = output.topk(largest_k, 1, True, True)[1].t()
        hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)
            for k in topk
        ]


class RecorderMeter(object):
    """Record per-epoch train/val loss and accuracy, and plot the curves.

    Attributes:
        total_epoch: number of epoch slots allocated.
        current_epoch: one past the most recent index passed to ``update``.
        epoch_losses: float32 array of shape ``(total_epoch, 2)``; column 0
            is train loss, column 1 is validation loss. Unrecorded slots hold -1.
        epoch_accuracy: float32 array of shape ``(total_epoch, 2)``; column 0
            is train accuracy, column 1 is validation accuracy. Unrecorded
            slots hold 0.
    """

    def __init__(self, total_epoch):
        assert total_epoch > 0
        self.total_epoch = total_epoch
        self.current_epoch = 0
        # [epoch, train/val]; -1 marks epochs not recorded yet.
        self.epoch_losses = np.full((self.total_epoch, 2), -1, dtype=np.float32)
        # [epoch, train/val]
        self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32)

    def update(self, idx, train_loss, train_acc, val_loss, val_acc):
        """Store the metrics of epoch ``idx``.

        Returns:
            True if ``val_acc`` is the best validation accuracy among epochs
            ``0 .. idx`` (ties count as best), otherwise False.
        """
        assert 0 <= idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(
            self.total_epoch, idx)
        self.epoch_losses[idx, 0] = train_loss
        self.epoch_losses[idx, 1] = val_loss
        self.epoch_accuracy[idx, 0] = train_acc
        self.epoch_accuracy[idx, 1] = val_acc
        self.current_epoch = idx + 1
        # Compare against the stored float32 value rather than the raw
        # argument: a float64 accuracy is generally not exactly representable
        # in float32, so `max == val_acc` can be False for the best epoch.
        return bool(self.max_accuracy(False) == self.epoch_accuracy[idx, 1])

    def max_accuracy(self, istrain):
        """Return the highest train (``istrain=True``) or validation accuracy recorded so far, or 0 if none."""
        if self.current_epoch <= 0:
            return 0
        column = 0 if istrain else 1
        return self.epoch_accuracy[:self.current_epoch, column].max()

    def plot_curve(self, save_path, dpi=1000):
        """Plot train/val accuracy and 50x-scaled loss on a single 0-100 axis.

        Args:
            save_path: file to write the figure to; if None, nothing is saved.
            dpi: resolution passed to ``savefig``.
        """
        title = 'the accuracy/loss curve of train/val'

        fig = plt.figure()
        x_axis = np.arange(self.total_epoch)  # epochs

        plt.xlim(0, self.total_epoch)
        plt.ylim(0, 100)
        interval_y = 5
        interval_x = 10
        plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
        plt.yticks(np.arange(0, 100 + interval_y, interval_y))
        plt.title(title)
        plt.xlabel('the training epoch')
        plt.ylabel('accuracy')

        # (values, scale, color, linestyle, label). Losses are scaled by 50 so
        # they are visible on the 0-100 accuracy axis.
        series = (
            (self.epoch_accuracy[:, 0], 1, 'g', '-', 'train-accuracy'),
            (self.epoch_accuracy[:, 1], 1, 'y', '-', 'valid-accuracy'),
            (self.epoch_losses[:, 0], 50, 'g', ':', 'train-loss-x50'),
            (self.epoch_losses[:, 1], 50, 'y', ':', 'valid-loss-x50'),
        )
        for values, scale, color, linestyle, label in series:
            plt.plot(x_axis, values * scale, color=color, linestyle=linestyle, label=label, lw=2)
        plt.legend(loc=4)

        if save_path is not None:
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
            print('---- save figure {} into {}'.format(title, save_path))
        plt.close(fig)

    def plot_curve_imagenet(self, save_path, dpi=1000):
        """Plot losses on a left axis (0-10) and accuracies on a right axis (0-100).

        Args:
            save_path: file to write the figure to; if None, nothing is saved.
            dpi: resolution passed to ``savefig``.
        """
        title = 'the accuracy/loss curve of train/val'

        fig = plt.figure()
        x_axis = np.arange(self.total_epoch)  # epochs

        ax1 = fig.add_subplot(111)
        plot1 = ax1.plot(x_axis, self.epoch_losses[:, 0], 'b:', label='train-loss', lw=2)
        plot2 = ax1.plot(x_axis, self.epoch_losses[:, 1], 'r:', label='valid-loss', lw=2)
        ax1.set_xlim([0, self.total_epoch])
        ax1.set_ylim([0, 10.0])
        # Set on ax1: twinx() hides the twin axis' x-axis, so a label placed
        # on ax2 would never be drawn.
        ax1.set_xlabel('epochs')
        ax1.set_ylabel('loss')
        ax1.set_title(title)

        # Second y-axis sharing ax1's x-axis, used for accuracy.
        ax2 = ax1.twinx()
        plot3 = ax2.plot(x_axis, self.epoch_accuracy[:, 0], 'b-', label='train-accuracy', lw=2)
        plot4 = ax2.plot(x_axis, self.epoch_accuracy[:, 1], 'r-', label='valid-accuracy', lw=2)
        ax2.set_ylim([0, 100])
        ax2.set_ylabel('accuracy(%)', rotation=270, labelpad=13)
        lines = plot1 + plot2 + plot3 + plot4
        ax2.legend(lines, [line.get_label() for line in lines], loc='upper left')

        fig.tight_layout()

        if save_path is not None:
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
            print('---- save figure {} into {}'.format(title, save_path))
        plt.close(fig)


def time_string():
    """Return the current UTC time as ``'[YYYY-MM-DD <locale time>]'``."""
    stamp = time.strftime('%Y-%m-%d %X', time.gmtime())
    return '[' + stamp + ']'


def convert_secs2time(epoch_time):
    """Split a duration in seconds into ``(hours, minutes, seconds)``.

    Each component is truncated toward zero with ``int``.
    """
    hours = int(epoch_time / 3600)
    leftover = epoch_time - 3600 * hours
    minutes = int(leftover / 60)
    seconds = int(leftover - 60 * minutes)
    return hours, minutes, seconds

def cluster_weight(weight, beta=None):
    """Cluster the slices ``weight[i]`` of a 4-D (convolution) weight with Affinity Propagation.

    Each ``weight[i]`` is flattened into one sample. The preference passed
    to ``AffinityPropagation`` is the per-column median of the negative
    squared Euclidean distance matrix, multiplied by ``beta``.

    Args:
        weight: 4-D tensor; dimension 0 indexes the samples to cluster.
        beta: preference scale; defaults to ``args.preference_beta``.

    Returns:
        ``(labels_, cluster_centers_, cluster_centers_indices_)`` of the
        fitted ``AffinityPropagation`` model.

    Raises:
        ValueError: if ``weight`` is not 4-D.
    """
    if beta is None:
        beta = args.preference_beta
    if weight.dim() != 4:  # only convolution weights are supported
        raise ValueError('The weight dim must be 4!!!')

    # detach() so tensors that require grad (e.g. nn.Parameter) can be turned
    # into a NumPy array for scikit-learn.
    A = weight.detach().cpu().reshape(weight.size(0), -1).numpy()

    affinity_matrix = -euclidean_distances(A, squared=True)
    preference = np.median(affinity_matrix, axis=0) * beta
    cluster = AffinityPropagation(preference=preference)
    cluster.fit(A)
    return cluster.labels_, cluster.cluster_centers_, cluster.cluster_centers_indices_


def random_project(weight, channel_num, random_state=None):
    """Project every flattened ``weight[i]`` with a sparse random projection.

    The target dimensionality is ``channel_num * weight.size(2) * weight.size(3)``.

    Args:
        weight: tensor with at least four dimensions; dimension 0 indexes
            the samples to project.
        channel_num: number of channels the projection should correspond to.
        random_state: forwarded to ``SparseRandomProjection`` for
            reproducible projections; ``None`` keeps the non-deterministic
            default.

    Returns:
        One projected row per ``weight[i]``, as produced by
        ``SparseRandomProjection.fit_transform``.
    """
    # detach() so tensors that require grad can be converted for scikit-learn.
    A = weight.detach().cpu().reshape(weight.size(0), -1).numpy()
    rp = SparseRandomProjection(
        n_components=channel_num * weight.size(2) * weight.size(3),
        random_state=random_state,
    )
    return rp.fit_transform(A)

def direct_project(weight, indices):
    """Gather the dimension-1 slices of a 4-D weight listed in ``indices``.

    Returns a new tensor of shape
    ``(weight.size(0), len(indices), weight.size(2), weight.size(3))`` whose
    ``i``-th slice along dimension 1 is ``weight[:, indices[i], :, :]``. The
    result is allocated with torch's default dtype and device, and the
    gathered values are copied (and cast) into it.
    """
    # torch.empty rather than torch.randn: every slice is overwritten below,
    # and drawing random numbers would needlessly advance the global RNG state.
    A = torch.empty(weight.size(0), len(indices), weight.size(2), weight.size(3))
    for i, indice in enumerate(indices):
        A[:, i, :, :] = weight[:, indice, :, :]

    return A
