import numpy as np
from scipy import sparse
from abc import ABCMeta, abstractclassmethod
import multiprocessing
from joblib import Parallel, delayed
import sys, os, datetime, matplotlib
import matplotlib.pyplot as plt
import graphlearning as gl


class v_poisson(gl.ssl.ssl):
    """Iterative Poisson-style label propagation with a mean-correction term.

    Starting from a zero iterate ``ut``, every step applies ``P * ut + Db``
    (``P`` is the degree-normalised transposed weight matrix, ``Db`` the
    degree-normalised centred one-hot source on the training nodes) and then
    corrects it with the column average of ``ut`` scaled by ``lamb``. How the
    average is taken is selected by ``var``; how it enters the update is
    selected by ``mode``.
    """

    def __init__(self, W=None, var='weighted', mode='accum', class_priors=None, lamb=1.0, use_cuda=False, min_iter=50, max_iter=1000):
        """Store the hyper-parameters and hand the graph to the base class.

        Args:
            W: Forwarded unchanged to ``gl.ssl.ssl.__init__``.
            var: ``'mean'`` uses the plain column mean of the iterate; any
                other value uses the degree-weighted column mean.
            mode: ``'scale'`` subtracts ``lamb * average`` and divides the
                result by ``1 - lamb``; ``'accum'`` adds
                ``lamb * (ut - average)``.
            class_priors: Forwarded unchanged to ``gl.ssl.ssl.__init__``.
            lamb: Weight of the correction term.
            use_cuda: Run the iteration with torch tensors on a CUDA device.
            min_iter: Lower bound on the number of iterations.
            max_iter: Upper bound on the number of iterations.
        """
        super().__init__(W, class_priors)
        self.lamb = lamb
        self.use_cuda = use_cuda
        self.min_iter = min_iter
        self.max_iter = max_iter
        self.var = var
        self.mode = mode

        # Setup accuracy filename
        self.accuracy_filename = '_v_poisson'
        # Setup algorithm name
        self.name = 'v_poisson with lamb=%.4f' % lamb

    def _fit(self, train_ind, train_labels, all_labels=None):
        """Run the iteration and return the final iterate.

        Args:
            train_ind: Indices of the labelled nodes.
            train_labels: Labels of the nodes in ``train_ind``.
            all_labels: Optional labels for every node. When given, accuracy
                is printed after each iteration and ``self.prob`` is updated
                with the current iterate.

        Returns:
            numpy array with one row per graph node and one column per
            distinct training label.

        Raises:
            ValueError: If ``self.mode`` is neither ``'scale'`` nor
                ``'accum'``.
        """
        # An unknown mode used to skip the update silently and return zeros.
        if self.mode not in ('scale', 'accum'):
            raise ValueError("mode must be 'scale' or 'accum', got %r" % (self.mode,))

        n = self.graph.num_nodes
        k = len(np.unique(train_labels))
        num_train = len(train_ind)

        # Zero-out diagonal for faster convergence
        W = self.graph.weight_matrix
        W = W - sparse.spdiags(W.diagonal(), 0, n, n)
        G = gl.graph(W)

        # Source term: centred one-hot labels on training nodes, zero elsewhere
        onehot = gl.utils.labels_to_onehot(train_labels)
        source = np.zeros((n, k))
        source[train_ind] = onehot - np.mean(onehot, axis=0)

        # Setup matrices
        D = G.degree_matrix(p=1)
        D_inv = G.degree_matrix(p=-1)
        P = D_inv * W.transpose()
        Db = D_inv * source

        # Random-walk distribution started uniformly on the training nodes;
        # its distance to deg / sum(deg) drives the stopping criterion.
        v = np.zeros(n)
        v[train_ind] = 1
        v = v / np.sum(v)
        deg = G.degree_vector()
        vinf = deg / np.sum(deg)
        RW = W.transpose() * D_inv

        weighted = self.var != 'mean'
        deg_total = None  # only needed for the degree-weighted average
        ut = np.zeros((n, k))
        if self.use_cuda:
            import torch
            P = gl.utils.torch_sparse(P).cuda()
            ut = torch.from_numpy(ut).float().cuda()
            Db = torch.from_numpy(Db).float().cuda()
            if weighted:
                D = gl.utils.torch_sparse(D).cuda()
                deg_total = torch.sparse.sum(D)
        elif weighted:
            deg_total = D.sum()

        T = 0
        while (T < self.min_iter or np.max(np.absolute(v - vinf)) > 1/n) and (T < self.max_iter):
            # Column average of the current iterate (plain or degree-weighted)
            if not weighted:
                ut_avg = ut.mean(0)
            elif self.use_cuda:
                ut_avg = torch.sparse.mm(D, ut).sum(0) / deg_total
            else:
                ut_avg = (D * ut).sum(0) / deg_total

            # Propagation step shared by both modes
            if self.use_cuda:
                propagated = torch.sparse.addmm(Db, P, ut)
            else:
                propagated = P * ut + Db

            if self.mode == 'scale':
                # NOTE(review): lamb == 1 makes this a division by zero and
                # fills the iterate with inf/nan -- confirm the intended range
                # of lamb for this mode.
                ut = (propagated - self.lamb * ut_avg) / (1 - self.lamb)
            else:
                ut = propagated + self.lamb * (ut - ut_avg)

            v = RW * v
            T += 1

            # Compute accuracy if all labels are provided
            if all_labels is not None:
                u = ut.cpu().numpy() if self.use_cuda else ut
                self.prob = u
                print(self.prob.max(), self.prob.min())
                labels = self.predict()
                acc = gl.ssl.ssl_accuracy(labels, all_labels, num_train)
                print('%d,Accuracy = %.2f'%(T, acc), 'Mean of u is', u.mean() / u.max())

        # Converting here (not only inside the loop) keeps the result defined
        # even when the loop body never runs, e.g. max_iter == 0.
        return ut.cpu().numpy() if self.use_cuda else ut
    

class v_poisson_mbo(gl.ssl.ssl):
    """Refine a ``v_poisson`` prediction by alternating diffusion and projection.

    The labels predicted by an internal ``v_poisson`` model are turned into a
    one-hot matrix. That matrix is then updated ``T`` times: ``Ns`` steps of
    ``ut <- P * ut + dt * Db`` with ``P = I - dt * L``, followed by
    ``volume_label_projection`` and re-encoding of the labels as one-hot.
    """

    def __init__(self, W=None, class_priors=None, var='mean', mode='scale',lamb=1.0, use_cuda=False, min_iter=50, max_iter=1000, Ns=40, mu=1, T=20):
        """Store the hyper-parameters and build the initialising model.

        Args:
            W: Forwarded unchanged to ``gl.ssl.ssl.__init__`` and ``v_poisson``.
            class_priors: Forwarded unchanged to ``gl.ssl.ssl.__init__``.
            var, mode, lamb, min_iter, max_iter: Passed to the internal
                ``v_poisson`` model.
            use_cuda: Run the diffusion steps (and the internal model) with
                torch tensors on a CUDA device.
            Ns: Number of diffusion steps between two projections.
            mu: Stored as ``self.mu``; not read by ``_fit``.
            T: Number of diffusion/projection rounds.
        """
        super().__init__(W, class_priors)
        self.v_poisson_model = v_poisson(W, lamb=lamb, var=var, mode=mode, use_cuda=use_cuda, min_iter=min_iter, max_iter=max_iter)

        self.Ns = Ns
        self.mu = mu
        self.use_cuda = use_cuda
        self.T = T
        self.lamb = lamb

        # Setup accuracy filename
        self.accuracy_filename = '_v_poisson_mbo'
        # Setup algorithm name
        self.name = 'v_poisson_mbo with lamb=%.2f' % lamb

    def _fit(self, train_ind, train_labels, all_labels=None):
        """Run the diffusion/projection rounds and return the final one-hot labels.

        Args:
            train_ind: Indices of the labelled nodes.
            train_labels: Labels of the nodes in ``train_ind``.
            all_labels: Optional labels for every node. They are forwarded to
                the initialising model, and when given the accuracy is printed
                after each round.

        Returns:
            One-hot encoding (numpy array) of the labels produced by the last
            projection, or of the initial prediction when ``self.T`` is 0.
        """
        n = self.graph.num_nodes
        k = len(np.unique(train_labels))
        num_train = len(train_ind)

        # Zero-out diagonal for faster convergence
        W = self.graph.weight_matrix
        W = W - sparse.spdiags(W.diagonal(), 0, n, n)
        G = gl.graph(W)

        # Source term: centred one-hot labels on training nodes, zero elsewhere
        onehot = gl.utils.labels_to_onehot(train_labels)
        source = np.zeros((n, k))
        source[train_ind] = onehot - np.mean(onehot, axis=0)

        # Degree-normalised source
        Db = G.degree_matrix(p=-1) * source

        # Initialize via V-Poisson Learning
        labels = self.v_poisson_model.fit_predict(train_ind, train_labels, all_labels=all_labels)
        u = gl.utils.labels_to_onehot(labels)

        # Time step for stability
        dt = 1 / np.max(G.degree_vector())

        # Diffusion operator for one explicit step
        P = sparse.identity(n) - dt * G.laplacian()

        if self.use_cuda:
            import torch
            P = gl.utils.torch_sparse(P).cuda()
            Db = torch.from_numpy(Db).float().cuda()

        # Loop-invariant forcing term added after every diffusion step
        forcing = dt * Db

        for i in range(self.T):
            ut = torch.from_numpy(u).float().cuda() if self.use_cuda else u
            for _ in range(self.Ns):
                # The product allocates a new array, so the in-place add
                # below never writes into ``u``.
                ut = torch.sparse.mm(P, ut) if self.use_cuda else P * ut
                ut += forcing

            u = ut.cpu().numpy() if self.use_cuda else ut
            # Projection step
            self.prob = u
            labels = self.volume_label_projection()
            u = gl.utils.labels_to_onehot(labels)

            # Compute accuracy if all labels are provided
            if all_labels is not None:
                acc = gl.ssl.ssl_accuracy(labels,all_labels,num_train)
                print('%d, Accuracy = %.2f'%(i,acc))

        return u