from model.losses import *
import torch
from sklearn.cluster import KMeans
from model.evaluation import *
from model.encoder import *
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
import os
import scipy
import torch.nn.functional as F


def trans_adj_to_edge_index(adj):
    """Convert a dense adjacency matrix into a COO ``edge_index`` tensor.

    Parameters
    ----------
    adj: 2-D torch tensor whose non-zero entries mark edges (any device).

    Returns
    -------
    Int64 tensor of shape (2, num_edges) on the CPU. Row 0 holds the row
    (source) indices and row 1 the column (target) indices of every non-zero
    entry, in row-major order. Edge weights are not returned.
    """
    # torch.nonzero enumerates entries in row-major order, the same order a
    # dense -> scipy COO conversion produces, so no host round trip through
    # scipy/numpy is needed. Detach so no autograd graph is kept alive.
    edge_index = torch.nonzero(adj.detach(), as_tuple=False).t()
    # Keep the original contract: a contiguous int64 tensor on the CPU.
    return edge_index.contiguous().to(device="cpu", dtype=torch.long)


class GOD_Trainer:
    """Trainer for the graph outlier detector (GOD).

    Optimises a joint objective of reconstruction, distribution-repulsion,
    clustering (KL) and regularisation terms. Nodes are scored using the
    model's two-cluster soft assignment.
    """

    def __init__(self, args, model, optimizer, alpha, beta, gamma, device):
        """Initialize GOD Trainer

        Parameters
        ----------
        args: The arguments of network, such as latent dimension.
        model: Graph Outlier Detector.
        optimizer: Optimizer of joint model.
        alpha: Weight of KL-Divergence loss.
        beta: Weight of Distribution Repulsion loss.
        gamma: Weight of Diversity loss.
        device: torch.device object for device to use.
        """
        self.model = model
        self.optimizer = optimizer
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.device = device
        self.args = args
        self.penalty_loss = TwoClusterEmptyPenalty(beta=1.0)

    def train(self, data, adj, anomaly_flag, total_epochs, log_interval=1):
        """Train the model and track the best ROC-AUC seen during training.

        Parameters
        ----------
        data: Graph data; node features are read from ``data.x``.
        adj: Adjacency Matrix.
        anomaly_flag: Ground-truth anomaly labels, used only for ROC-AUC.
        total_epochs: Total epochs for training.
        log_interval: Evaluate with ``test`` every ``log_interval`` epochs.
            Each evaluation also refreshes the clustering target
            distribution ``p``. Defaults to 1 (every epoch).

        Returns
        -------
        The best ROC-AUC returned by ``test`` during training, or ``-inf``
        if no evaluation ran.

        Raises
        ------
        ValueError: if ``log_interval`` is smaller than 1.
        """
        if log_interval < 1:
            raise ValueError("log_interval must be a positive integer")

        best_auc = -np.inf

        pbar = tqdm(range(1, total_epochs + 1))
        for epoch in pbar:
            if epoch == 1:
                # Initialise the cluster centres with K-Means on the initial
                # embeddings, and derive the first target distribution p.
                self.model.eval()
                with torch.no_grad():
                    _, emb, tmp_q = self.model.encode(data.x, adj, mode='train')

                kmeans = KMeans(n_clusters=2, n_init=100)
                kmeans.fit(emb.detach().cpu().numpy())

                cluster_layer = self.model.encoder.cluster_layer
                # cluster_centers_ is float64. Cast to the parameter's own
                # dtype so assigning .data does not silently switch the
                # clustering path to double precision.
                cluster_layer.data = torch.as_tensor(
                    kmeans.cluster_centers_,
                    dtype=cluster_layer.dtype,
                    device=self.device,
                )
                # p is a fixed (detached) target for the KL term.
                p = target_distribution(tmp_q.detach().clone()).to(self.device)
                self.model.train()

            if epoch % log_interval == 0:
                # Evaluation also returns a refreshed target distribution p.
                test_auc, p, _, _ = self.test(data, adj, anomaly_flag)
                best_auc = max(best_auc, test_auc)
                self.model.train()

            self.optimizer.zero_grad()
            _, z, tmp_q = self.model.encode(data.x, adj, mode='train')
            # The distribution-repulsion term uses z. The original code noted
            # that the first element of encode()'s first output could be used
            # instead, with z as the choice when memory is tight.
            first_emb = z
            y_pred = tmp_q.argmax(1)

            tmp_adj = trans_adj_to_edge_index(self.model.encoder.masked_A)

            pos_loss, neg_loss, attr_loss = self.model.recon_loss(z, adj, tmp_adj, x=data.x)
            recon_loss = pos_loss + neg_loss + attr_loss.mean()
            # Default reduction ('mean') averages over every element of q
            # rather than per node ('batchmean'). Kept as-is so the loss scale
            # is unchanged.
            kl_loss = F.kl_div(tmp_q.log(), p)
            reg_loss = self.penalty_loss(tmp_q)

            ab_idx, normal_idx, _, _ = get_candidate(data, y_pred, z, self.model.encoder.cluster_layer.data)
            if len(ab_idx) == 0:
                # No anomaly candidates this epoch: the repulsion term is skipped.
                distribution_constraint = 0.0
            else:
                # Negated, so minimising the total loss maximises this
                # discrepancy (the "distribution repulsion" term).
                distribution_constraint = -chebyshev_mmd_loss(first_emb[ab_idx, :], first_emb[normal_idx, :])

            loss = (
                recon_loss
                + self.beta * distribution_constraint
                + self.alpha * kl_loss
                + self.gamma * reg_loss
            )

            loss.backward()
            self.optimizer.step()
            # float() accepts both 0-dim tensors and Python numbers. Formatting
            # a plain int with '.4' would raise ValueError.
            pbar.set_description("Epoch {}| Total Loss: {:.4}, Reconstruction Loss {:.4}, Distribution Loss {:.4}, Clustering Loss {:.4}, Regularization Loss {:.4}".format(
                epoch,
                loss.item(),
                float(recon_loss),
                float(self.beta * distribution_constraint),
                float(self.alpha * kl_loss),
                float(self.gamma * reg_loss),
                )
            )
        return best_auc

    def test(self, data, adj, anomaly_flag):
        """Evaluate the model.

        Leaves the model in eval mode.

        Parameters
        ----------
        data: Graph data; node features are read from ``data.x``.
        adj: Adjacency Matrix.
        anomaly_flag: Ground-truth labels passed to ``roc_auc_score``.

        Returns
        -------
        test_auc: ROC-AUC of ``scores`` against ``anomaly_flag``.
        p: Target distribution derived from the soft assignment, on
            ``self.device``.
        emb: Embedding returned by ``model.encode``, computed without
            gradient tracking.
        scores: numpy array equal to one minus the soft assignment to
            cluster 0. It is used as the positive-class score for ROC-AUC.
        """
        self.model.eval()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            _, emb, tmp_q = self.model.encode(data.x, adj)

        p = target_distribution(tmp_q.detach().clone()).to(self.device)

        scores = 1.0 - tmp_q[:, 0].detach().cpu().numpy()
        test_auc = roc_auc_score(anomaly_flag, scores)
        return test_auc, p, emb, scores

    def save(self, path):
        """Save the model weights to ``path`` as ``{'model': state_dict}``."""
        torch.save({'model': self.model.state_dict()}, path)


