from base.base_trainer import BaseTrainer
from sklearn.metrics.cluster import normalized_mutual_info_score
from utils.evaluation import acc
from base.base_dataset import BaseADDataset
from base.base_net import BaseNet
from sklearn.cluster import KMeans
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F
import logging
import time
import torch
import torch.optim as optim
import numpy as np
import random



def compute_balance(n_clusters, y_pred, yt):
    """Return the minimum per-cluster balance of a clustering.

    For every cluster id ``k`` in ``range(n_clusters)`` that has at least one
    member, the histogram of ``yt`` values inside that cluster is built over all
    distinct values present in ``yt``. The cluster's balance is
    ``min(hist) / max(hist)``: 1.0 means every value occurs equally often, and
    0.0 means some value is absent from the cluster. Empty clusters are skipped.
    The per-cluster balances and sizes are printed for inspection.

    Args:
        n_clusters: Number of cluster ids to inspect (ids ``0 .. n_clusters-1``).
        y_pred: Array of predicted cluster ids, aligned element-wise with ``yt``.
        yt: Array of label values, aligned element-wise with ``y_pred``.

    Returns:
        The smallest balance over all non-empty clusters.

    Raises:
        ValueError: If none of the cluster ids ``0 .. n_clusters-1`` has a member.
    """
    y_pred = np.asarray(y_pred)
    yt = np.asarray(yt)
    # Bin over the label values actually present, so labels need not be the
    # contiguous range 0..nbins-1 (e.g. {1, 2} is counted correctly).
    values = np.unique(yt)
    balance = []
    size_list = []
    print(np.size(np.unique(y_pred)))
    for k in range(n_clusters):
        idx = np.where(y_pred == k)
        # ravel (not squeeze) keeps a single-member cluster 1-D. A squeezed 0-d
        # array would hit np.nonzero on a 0-d input, which NumPy deprecated and
        # newer releases reject.
        y_k = np.ravel(yt[idx])
        hist = np.array([np.count_nonzero(y_k == v) for v in values], dtype=float)
        size_list.append(hist.sum())
        if y_k.size > 0:
            # Every member's value is one of `values`, so max(hist) > 0 here.
            balance.append(np.min(hist) / np.max(hist))
    print("balance list:", balance)
    print("size list:", size_list)
    if not balance:
        raise ValueError(
            "compute_balance: no cluster id in range(%d) has any member" % n_clusters
        )
    return np.amin(balance)

class RIMTrainer(BaseTrainer):
    """Trainer/evaluator for the Deep_RIM clustering network.

    Training minimizes ``aver_entropy - entropy_aver`` as returned by
    ``ae_net.compute_entropy(inputs)``. Performance is reported as clustering
    accuracy (``acc``) and normalized mutual information. Both are computed
    between the loader's labels and the argmax of the network's softmax
    scores.
    """

    def __init__(self, objective, optimizer_name: str = 'adam', lr: float = 0.001, n_epochs: int = 150, lr_milestones: tuple = (),
                 batch_size: int = 128, weight_decay: float = 1e-6, device: str = 'cuda', n_jobs_dataloader: int = 0):
        # `objective` is accepted for signature compatibility; it is not used
        # anywhere in this class.
        super().__init__(optimizer_name, lr, n_epochs, lr_milestones, batch_size, weight_decay, device,
                         n_jobs_dataloader)
        # Results
        self.train_time = None
        self.test_auc = None  # never assigned by this trainer; kept for interface compatibility
        self.test_time = None
        self.test_scores = None

    @staticmethod
    def _clustering_metrics(y, y_pred):
        """Compute, print and return ``(accuracy, nmi)`` of ``y_pred`` against ``y``."""
        final_acc = acc(y, y_pred)
        final_nmi = normalized_mutual_info_score(y, y_pred)
        print("acc: %.5f, nmi: %.5f" % (final_acc, final_nmi))
        return final_acc, final_nmi

    def train(self, dataset: BaseADDataset, ae_net: BaseNet):
        """Train ``ae_net`` on the train loader of ``dataset`` and return it.

        Each batch minimizes ``aver_entropy - entropy_aver`` from
        ``ae_net.compute_entropy(inputs)`` with Adam. Training uses a
        MultiStepLR schedule (gamma 0.1 at ``self.lr_milestones``). After every
        epoch it logs the mean loss, plus acc/NMI of the argmax softmax
        predictions collected during that epoch.
        """
        logger = logging.getLogger()
        # Set device for network
        ae_net = ae_net.to(self.device)
        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)
        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(ae_net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for param in ae_net.parameters():
            print(param.shape)

        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        ae_net.train()

        for epoch in range(self.n_epochs):
            if epoch in self.lr_milestones:
                # The scheduler is stepped at the end of each epoch, so the decay
                # belonging to milestone `epoch` is already in effect here. Read
                # the live LR instead of scheduler.get_lr(), which is not meant to
                # be called outside step() and can report a doubly-decayed value.
                logger.info('  LR scheduler: new learning rate is %g' % float(optimizer.param_groups[0]['lr']))

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            label_score = []
            for data in train_loader:
                inputs, labels, _, _ = data
                inputs = inputs.to(self.device)
                optimizer.zero_grad()

                # Scores are only recorded for the epoch metrics; no autograd
                # graph is needed for them.
                with torch.no_grad():
                    scores = F.softmax(ae_net(inputs), dim=1)
                aver_entropy, entropy_aver = ae_net.compute_entropy(inputs)
                loss = aver_entropy - entropy_aver
                loss.backward()
                optimizer.step()
                loss_epoch += loss.item()
                n_batches += 1
                label_score += list(zip(labels.cpu().numpy().tolist(),
                                        scores.cpu().numpy().tolist()))

            # Since PyTorch 1.1 the scheduler must be stepped after
            # optimizer.step(); stepping it first shifts every milestone early.
            scheduler.step()

            labels, scores = zip(*label_score)
            y = np.array(labels)
            scores = np.array(scores)

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'
                        .format(epoch + 1, self.n_epochs, epoch_train_time, loss_epoch / n_batches))
            y_pred = np.argmax(scores, axis=1)
            final_acc, final_nmi = self._clustering_metrics(y, y_pred)
            logger.info('Train set nmi: {:.2f}%'.format(100. * final_nmi))
            logger.info('Train set acc: {:.2f}%'.format(100. * final_acc))

        self.train_time = time.time() - start_time
        logger.info('Pretraining time: %.3f' % self.train_time)
        logger.info('Finished pretraining.')

        return ae_net

    def test(self, dataset: BaseADDataset, ae_net: BaseNet):
        """Evaluate ``ae_net`` on the test loader of ``dataset``.

        Runs the network in eval mode without gradients. It stores one
        ``(idx, label, softmax_scores, psv)`` tuple per test sample in
        ``self.test_scores`` and the elapsed time in ``self.test_time``. It
        also logs acc/NMI of the argmax predictions.
        """
        logger = logging.getLogger()

        # Set device for network
        ae_net = ae_net.to(self.device)

        # Get test data loader
        _, test_loader = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

        # Testing
        logger.info('Testing')
        start_time = time.time()
        idx_label_score = []
        ae_net.eval()
        with torch.no_grad():
            for data in test_loader:
                inputs, labels, idx, psv = data
                inputs = inputs.to(self.device)
                scores = F.softmax(ae_net(inputs), dim=1)
                idx_label_score += list(zip(idx.cpu().numpy().tolist(),
                                            labels.cpu().numpy().tolist(),
                                            scores.cpu().numpy().tolist(),
                                            psv.cpu().numpy().tolist()))

        _, labels, scores, _ = zip(*idx_label_score)
        y = np.array(labels)
        scores = np.array(scores)
        print("scores:", scores)
        self.test_scores = idx_label_score

        # Cluster assignment = most probable softmax component; `acc` is
        # expected to handle the cluster-to-label matching (TODO confirm in
        # utils.evaluation).
        y_pred = np.argmax(scores, axis=1)
        print(y_pred, y_pred.shape)
        final_acc, final_nmi = self._clustering_metrics(y, y_pred)

        logger.info('Test set nmi: {:.2f}%'.format(100. * final_nmi))
        logger.info('Test set acc: {:.2f}%'.format(100. * final_acc))

        self.test_time = time.time() - start_time
        logger.info('Deep_RIM testing time: %.3f' % self.test_time)
        logger.info('Finished testing Deep_RIM.')
