import torch
from torch.utils.data import Dataset, DataLoader

import numpy as np
import scipy.sparse as sp
from metrics import clustering_metrics

from sklearn.preprocessing import normalize
from sklearn.metrics import accuracy_score, f1_score, precision_score
from utils import (preprocess_adj, ProcessKNN, square_dist,
                   to_onehot, agc_labels_to_optimized_graph,
                   distance_to_centroids)


class Softmax(torch.nn.Module):
    """Bias-free linear layer that maps feature vectors to class scores.

    Despite its name, no softmax is applied in ``forward``: the raw logits are
    returned, and the training code in this module feeds them to
    ``torch.nn.CrossEntropyLoss``, which applies log-softmax internally.

    Args:
        n_inputs: size of each input feature vector.
        n_outputs: number of output scores per input row.
    """

    def __init__(self, n_inputs, n_outputs):
        super().__init__()
        self.n_inputs, self.n_outputs = n_inputs, n_outputs
        self.linear = torch.nn.Linear(n_inputs, n_outputs, bias=False)

    def forward(self, x):
        """Return ``x @ W.T`` (no bias term, no activation)."""
        return self.linear(x)


class Data2Torch(Dataset):
    """Map-style dataset that pairs feature rows with integer class targets.

    ``x`` and ``y`` are NumPy arrays that are expected to have the same number
    of rows. ``x`` is wrapped with ``torch.from_numpy`` and shares memory with
    the array. ``y`` is cast to ``torch.LongTensor`` so it can be used directly
    as a ``CrossEntropyLoss`` target.
    """

    def __init__(self, x, y):
        features = torch.from_numpy(x)
        targets = torch.from_numpy(y).type(torch.LongTensor)
        self.x, self.y = features, targets
        self.len = features.shape[0]

    def __len__(self):
        """Number of samples, i.e. rows of ``x``."""
        return self.len

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair stored at position ``idx``."""
        sample = (self.x[idx], self.y[idx])
        return sample


def SGC_torch(feats, adjacency, pow, mask_test, mask_train, gnd_lab):
    """Train and evaluate a linear SGC classifier on propagated node features.

    Features are multiplied ``pow`` times by
    ``0.5 * I + 0.5 * preprocess_adj(adjacency)``. A bias-free linear layer is
    then trained on the ``mask_train`` rows with cross-entropy, using Adam
    (lr 0.01), 100 epochs and mini-batches of 200. It is evaluated on the
    ``mask_test`` rows.

    Args:
        feats: node feature matrix with one row per node (ndarray,
            ``np.matrix`` or scipy sparse).
        adjacency: graph adjacency matrix, passed to ``preprocess_adj``.
        pow: number of propagation steps.
        mask_test, mask_train: boolean masks or index arrays selecting node rows.
        gnd_lab: 1-D array with an integer class label per node.

    Returns:
        ``(accuracy, micro_f1, micro_precision, last_loss)``. The metrics are
        computed on the test rows. ``last_loss`` is the final mini-batch loss as
        a detached scalar tensor.

    Raises:
        ValueError: if ``mask_train`` selects no nodes.
    """
    print(f' --- SGC_torch Labels --- ')

    propagation = preprocess_adj(adjacency)
    propagation = (1 - 0.5) * sp.eye(propagation.shape[0]) + 0.5 * propagation
    for _ in range(pow):
        feats = propagation.dot(feats)

    # torch.from_numpy needs a plain ndarray. Sparse input stays sparse
    # through .dot(), and np.matrix must be unwrapped.
    feats = feats.toarray() if sp.issparse(feats) else np.asarray(feats)

    X_train, X_test = feats[mask_train, :], feats[mask_test, :]
    y_train, y_test = gnd_lab[mask_train], gnd_lab[mask_test]
    if X_train.shape[0] == 0:
        raise ValueError('mask_train selects no nodes')

    train_data = Data2Torch(X_train, y_train)
    test_data = Data2Torch(X_test, y_test)

    # One logit per class. y_train.shape[-1] would be the number of *training
    # nodes* for 1-D labels, not the number of classes.
    n_classes = int(np.max(gnd_lab)) + 1
    clf = Softmax(X_train.shape[-1], n_classes).float()

    optimizer = torch.optim.Adam(clf.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    train_loader = DataLoader(dataset=train_data, batch_size=200)

    epochs = 100
    for _ in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = criterion(clf(x.float()), y)
            loss.backward()
            optimizer.step()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, y_pred = clf(test_data.x.float()).max(1)
    y_pred = y_pred.numpy()

    accuracy_sgc_torch = accuracy_score(y_test, y_pred, normalize=True)
    f1_sc_sgc_torch = f1_score(y_test, y_pred, zero_division=1, average='micro')
    precision_sgc_torch = precision_score(y_test, y_pred, average='micro')
    # Detach so the caller does not keep the last step's graph alive.
    loss_sgc_torch = loss.detach()

    print('loss_SGC_torch : {}, acc_SGC_torch: {}'.format(loss_sgc_torch,
                                                          accuracy_sgc_torch))
    return accuracy_sgc_torch, f1_sc_sgc_torch, precision_sgc_torch, loss_sgc_torch


def SGC_AGC_torch(kmeans_sgc_agc, predict_labels, u, feats, adjacency, gnd_lab, mask_test,
                  mask_train, prop_num=50, alpha=None, keep=2, message='only',
                  knngraph=None, modify_knn=False):
    """Train a linear SGC classifier with a graph penalty on its output logits.

    Features are propagated ``prop_num`` times with
    ``0.5 * I + 0.5 * preprocess_adj(adjacency)``.

    A penalty matrix ``M`` is built in three steps:
    - a structure graph is taken from ``knngraph`` or, when ``keep == 1``,
      from ``predict_labels`` via ``agc_labels_to_optimized_graph``;
    - it is mixed with a graph linking nodes that share a known (non-test)
      ground-truth label;
    - the mixture is blended 50/50 with the identity.

    The training loss is cross-entropy plus ``0.5 * trace(F^T M F)``. Here ``F``
    holds the current batch's logits, placed at their node rows of an
    otherwise all-zero ``(n_nodes, n_classes)`` matrix.

    Args:
        kmeans_sgc_agc, u: only needed by the unfinished ``keep != 1`` branch;
            currently unused.
        predict_labels: cluster assignment per node, used when ``keep == 1``.
        feats: node feature matrix with one row per node (ndarray,
            ``np.matrix`` or scipy sparse).
        adjacency: graph adjacency matrix.
        gnd_lab: 1-D array with an integer class label per node.
        mask_test, mask_train: boolean masks or index arrays selecting nodes.
        prop_num: number of feature-propagation steps.
        alpha: mixing weights. ``alpha[0]`` weights the structure graph and
            ``alpha[1]`` the label graph; ``alpha[2]`` is currently unused.
            Defaults to ``[0.5, 0.2, 0.3]``.
        keep: ``1`` builds the structure graph from ``predict_labels``. Other
            values are not implemented.
        message: tag shown in the banner print.
        knngraph: optional precomputed graph used instead of the cluster graph.
        modify_knn: if True, ``knngraph`` is first passed through
            ``ProcessKNN(mask_test, knngraph)``.

    Returns:
        ``(accuracy, micro_f1, micro_precision, last_loss)``. The metrics are
        computed on the test nodes. ``last_loss`` is the final mini-batch loss
        as a detached scalar tensor.

    Raises:
        NotImplementedError: ``knngraph`` is None and ``keep != 1``.
        ValueError: ``mask_train`` selects no nodes.
    """
    if alpha is None:
        alpha = [0.5, 0.2, 0.3]
    print(f' --- SGC_AGC_torch_{message}, power:{prop_num}, alphas: {alpha}  --- ')

    # Feature-propagation operator: half identity, half normalised adjacency.
    propagation = preprocess_adj(adjacency)
    propagation = (1 - 0.5) * sp.eye(propagation.shape[0]) + 0.5 * propagation

    # Structure graph for the output penalty.
    if knngraph is not None:
        adjnormal = ProcessKNN(mask_test, knngraph) if modify_knn else knngraph
    elif keep == 1:
        adjnormal = agc_labels_to_optimized_graph(predict_labels, adjacency.shape)
    else:
        # TODO: a top-`keep` graph built with
        # distance_to_centroids(kmeans_sgc_agc, u, predict_labels, dist_type=1, keep=keep)
        # was started but never finished.
        raise NotImplementedError('modify for selecting more than 1 smallest')

    # Label graph: nodes sharing a ground-truth label are linked; test labels hidden.
    label_onehot = to_onehot(gnd_lab).T
    label_onehot[mask_test, :] = 0

    adjnormal = preprocess_adj(adjnormal)
    adjnormal2 = preprocess_adj(label_onehot.dot(label_onehot.T) + sp.eye(adjacency.shape[0]))
    adjnormal = alpha[0] * adjnormal + alpha[1] * adjnormal2
    adjnormal = (1 - 0.5) * sp.eye(feats.shape[0]) + 0.5 * adjnormal

    for _ in range(prop_num):
        feats = propagation.dot(feats)
    # torch.from_numpy needs a plain ndarray (sparse stays sparse through .dot()).
    feats = feats.toarray() if sp.issparse(feats) else np.asarray(feats)

    X_train, X_test = feats[mask_train, :], feats[mask_test, :]
    y_train, y_test = gnd_lab[mask_train], gnd_lab[mask_test]
    if X_train.shape[0] == 0:
        raise ValueError('mask_train selects no nodes')
    # Node ids of the training rows, in the same order as X_train.
    train_nodes = torch.from_numpy(np.arange(feats.shape[0])[mask_train]).long()

    # One logit per class. y_train.shape[-1] would be the number of *training
    # nodes* for 1-D labels, not the number of classes.
    n_classes = int(np.max(gnd_lab)) + 1

    penalty = adjnormal.toarray() if sp.issparse(adjnormal) else np.asarray(adjnormal)
    penalty = torch.from_numpy(penalty).float()
    n_nodes = penalty.shape[0]
    cross_entropy = torch.nn.CrossEntropyLoss()
    reg_weight = 0.5

    def criterion(logits, target, nodes):
        """Cross-entropy plus ``reg_weight * trace(F^T M F)`` for this batch."""
        # Build F fresh on every call (out-of-place index_copy). Each step's
        # autograd graph then stays independent, so no retain_graph is needed.
        # It also works for any batch size.
        full = logits.new_zeros((n_nodes, logits.shape[1])).index_copy(0, nodes, logits)
        # sum(F * (M @ F)) == trace(F^T M F)
        # NOTE(review): M is a blend of normalised adjacencies, not a Laplacian.
        # Minimising trace(F^T M F) therefore does not reward *similar* logits
        # on linked nodes; a Laplacian penalty would use (I - M). Confirm intent.
        smooth = torch.sum(full * (penalty @ full))
        return cross_entropy(logits, target.long()) + reg_weight * smooth

    train_data = Data2Torch(X_train, y_train)
    test_data = Data2Torch(X_test, y_test)

    clf = Softmax(X_train.shape[-1], n_classes)
    optimizer = torch.optim.Adam(clf.parameters(), lr=0.01)
    train_loader = DataLoader(
        dataset=torch.utils.data.TensorDataset(train_data.x, train_data.y, train_nodes),
        batch_size=200)

    epochs = 100
    for _ in range(epochs):
        for x_batch, y_batch, node_batch in train_loader:
            optimizer.zero_grad()
            loss = criterion(clf(x_batch.float()), y_batch, node_batch)
            loss.backward()
            optimizer.step()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, y_pred = clf(test_data.x.float()).max(1)
    y_pred = y_pred.numpy()

    accuracy_sgc_agc_torch = accuracy_score(y_test, y_pred, normalize=True)
    f1_sc_sgc_agc_torch = f1_score(y_test, y_pred, zero_division=1, average='micro')
    precision_sgc_agc_torch = precision_score(y_test, y_pred, average='micro')
    # Detach so the caller does not keep the last step's graph alive.
    loss_sgc_agc_torch = loss.detach()

    print('loss_SGC_AGC_torch : {}, acc_SGC_AGC_torch: {}'.format(loss_sgc_agc_torch,
                                                                  accuracy_sgc_agc_torch))
    return accuracy_sgc_agc_torch, f1_sc_sgc_agc_torch, \
        precision_sgc_agc_torch, loss_sgc_agc_torch


def Fewshot_preprocess(filter, feats, pow=12, norm=None):
    """Smooth ``feats`` with ``pow`` applications of ``filter``, then optionally normalise rows.

    Args:
        filter: propagation operator exposing ``.dot`` (dense or scipy sparse).
        feats: feature matrix multiplied on the left by ``filter``.
        pow: number of propagation steps; ``0`` returns ``feats`` unchanged.
        norm: if truthy, passed to ``sklearn.preprocessing.normalize`` as the
            row norm (e.g. ``'l2'``).

    Returns:
        The propagated (and possibly row-normalised) feature matrix.
    """
    smoothed = feats
    for _ in range(pow):
        smoothed = filter.dot(smoothed)

    return normalize(smoothed, norm=norm, axis=1) if norm else smoothed


def Symm_NMF(symmilarity, iterations, num_clus, gnd_lab, step):
    """Cluster nodes with symmetric NMF, ``S ~= H H^T``, and return argmax labels.

    Uses the damped multiplicative update
    ``H <- H * ((1 - step) + step * (S H) / (H H^T H))``,
    starting from ``np.random.seed(0)`` uniform noise.

    Args:
        symmilarity: square node-similarity matrix ``S`` (dense or scipy sparse).
        iterations: number of update steps.
        num_clus: number of clusters (columns of ``H``).
        gnd_lab: ground-truth labels, used only to print per-iteration
            accuracy via ``clustering_metrics``. Pass ``None`` to skip this.
        step: damping factor in ``(0, 1]``; ``1`` gives the undamped update.

    Returns:
        1-D array holding the argmax column of ``H`` for each node.
    """
    print(f' --- FewShot using Symm NMF --- ')
    np.random.seed(0)
    H = np.random.random((symmilarity.shape[0], num_clus))
    eps = 1e-12  # keeps the element-wise division finite if a denominator hits 0

    predict_labels = np.argmax(H, axis=1)
    for it in range(1, iterations + 1):
        # np.asarray: with an np.matrix, `*` below would be a matrix product.
        numer = np.asarray(symmilarity.dot(H))
        denom = H.dot(H.T.dot(H)) + eps
        H = H * ((1.0 - step) + step * (numer / denom))
        predict_labels = np.argmax(H, axis=1)

        if gnd_lab is not None:
            cm = clustering_metrics(gnd_lab, predict_labels)
            acc, _, _ = cm.evaluationClusterModelFromLabel()
            print(f"iteration: {it}, acc: {acc}")

    return predict_labels


def SVD_soft_clusters(feats, num_clus, gnd_lab):
    """Assign each row of ``feats`` to the top-``num_clus`` left singular vector it loads on most.

    Args:
        feats: feature matrix (dense or scipy sparse, floating point).
            ``num_clus`` must be smaller than ``min(feats.shape)``, as
            required by ``svds``.
        num_clus: number of singular triplets to compute.
        gnd_lab: ground-truth labels, used only to print accuracy via
            ``clustering_metrics``. Pass ``None`` to skip this.

    Returns:
        1-D array holding the index of the winning singular vector for each row.
    """
    from scipy.sparse.linalg import svds

    print(f' --- FewShot using SVD --- ')

    u, _, _ = svds(feats, k=num_clus, which='LM')

    # Singular vectors are defined only up to sign, and svds starts from a
    # random vector. Fix each column's sign so its largest-magnitude entry is
    # positive. Otherwise argmax could silently ignore a flipped component.
    cols = np.arange(u.shape[1])
    pivots = np.argmax(np.abs(u), axis=0)
    signs = np.sign(u[pivots, cols])
    signs[signs == 0] = 1.0
    u = u * signs

    predict_labels = np.argmax(u, axis=1)
    print(predict_labels.shape, np.unique(predict_labels).shape)

    if gnd_lab is not None:
        cm = clustering_metrics(gnd_lab, predict_labels)
        acc, _, _ = cm.evaluationClusterModelFromLabel()
        print(f"acc: {acc}")

    return predict_labels


def NMF(feats, iterations, num_clus, gnd_lab, step):
    """Cluster rows of ``feats`` with NMF, ``X ~= W H``, and return argmax labels of ``W``.

    Uses damped Lee-Seung multiplicative updates, starting from
    ``np.random.seed(0)`` uniform noise (``H`` is drawn first, then ``W``):
        ``H <- H * ((1 - step) + step * (W^T X) / (W^T W H))``
        ``W <- W * ((1 - step) + step * (X H^T) / (W H H^T))``
    With ``step == 1`` these are the classic Lee-Seung rules.

    Args:
        feats: non-negative feature matrix ``X`` (dense, ``np.matrix`` or scipy sparse).
        iterations: number of update rounds.
        num_clus: factorisation rank / number of clusters.
        gnd_lab: ground-truth labels, used only to print per-iteration
            accuracy via ``clustering_metrics``. Pass ``None`` to skip this.
        step: damping factor in ``(0, 1]``.

    Returns:
        1-D array holding the argmax column of ``W`` for each row.
    """
    print(f' --- FewShot using NMF --- ')

    np.random.seed(0)
    H = np.random.random((num_clus, feats.shape[1]))
    W = np.random.random((feats.shape[0], num_clus))
    eps = 1e-12  # keeps the element-wise division finite if a denominator hits 0

    predict_labels = np.argmax(W, axis=1)
    for it in range(1, iterations + 1):
        # Put `feats` on the left of each product so sparse input works too.
        wtx = np.asarray(feats.T.dot(W)).T
        H = H * ((1.0 - step) + step * (wtx / (W.T.dot(W).dot(H) + eps)))
        xht = np.asarray(feats.dot(H.T))
        W = W * ((1.0 - step) + step * (xht / (W.dot(H.dot(H.T)) + eps)))
        predict_labels = np.argmax(W, axis=1)

        if gnd_lab is not None:
            cm = clustering_metrics(gnd_lab, predict_labels)
            acc, _, _ = cm.evaluationClusterModelFromLabel()
            print(f"iteration: {it}, acc: {acc}")

    return predict_labels


def FewShot(feats, filter, gnd_lab, norm='l2', num_clus=7, pow=12,
            iterations=300, step=0.5, type='symmNMF'):
    """Smooth features, cluster them with the chosen method, and score the clustering.

    Args:
        feats: node feature matrix.
        filter: propagation operator passed to ``Fewshot_preprocess``.
        gnd_lab: ground-truth labels used for evaluation.
        norm: row norm applied after propagation (falsy to skip).
        num_clus: number of clusters.
        pow: number of propagation steps.
        iterations: update steps for the NMF variants.
        step: damping factor for the NMF variants.
        type: one of ``'symmNMF'``, ``'svd'`` or ``'NMF'``.

    Returns:
        ``(accuracy, nmi, f1)`` as returned by
        ``clustering_metrics(...).evaluationClusterModelFromLabel()``.

    Raises:
        ValueError: if ``type`` is not one of the supported methods.
    """
    feats = Fewshot_preprocess(filter, feats, pow, norm)

    if type == 'symmNMF':
        predict_labels = Symm_NMF(feats.dot(feats.T), iterations, num_clus, gnd_lab, step)
    elif type == 'svd':
        predict_labels = SVD_soft_clusters(feats, num_clus, gnd_lab)
    elif type == 'NMF':
        predict_labels = NMF(feats, iterations, num_clus, gnd_lab, step)
    else:
        # Previously this only printed, then crashed on an unbound predict_labels.
        raise ValueError(f'type: {type} NOT Implemented!!')

    cm = clustering_metrics(gnd_lab, predict_labels)
    accuracy, nmi, f1_sc = cm.evaluationClusterModelFromLabel()

    print(f'{type} iter{iterations}, acc: {accuracy}')

    return accuracy, nmi, f1_sc


def cands_agc(kmeans_cands_agc, predict_labels, u, adjacency, gnd_lab, mask_test, p=0.6,
              correct_num=50, smooth_num=50, lp_p=0.5, lp_alpha=0.5, alpha=0.5, keep=2):
    """Refine clustering predictions with a correct-then-smooth propagation on the graph.

    Args:
        kmeans_cands_agc, u: passed to ``distance_to_centroids`` when ``keep != 1``.
        predict_labels: cluster label per node. When ``keep == 1`` it is one-hot
            encoded via ``to_onehot(...).T``; otherwise ``distance_to_centroids``
            builds the soft assignment.
        adjacency: graph adjacency, normalised with ``preprocess_adj``.
        gnd_lab: ground-truth label per node. Labels of ``mask_test`` nodes are hidden.
        mask_test: boolean mask or index array of nodes whose labels are hidden.
        p: exponent applied (sign-preserving) to the residual in the correct phase.
        correct_num, smooth_num: number of propagation steps in each phase.
        lp_alpha: propagation mixing weight used in both phases.
        lp_p, alpha: accepted for interface compatibility; not used.
        keep: ``1`` for hard one-hot predictions, otherwise passed to
            ``distance_to_centroids``.

    Returns:
        ``(accuracy, micro_f1, micro_precision)`` of the refined argmax labels.
    """
    print(f' --- CS_AGC  --- ')

    def to_dense(a):
        # Plain float ndarray: with np.matrix, `**` would be a matrix power.
        return np.asarray(a.toarray() if sp.issparse(a) else a, dtype=float)

    if keep == 1:
        base = to_onehot(predict_labels).T
    else:
        base = distance_to_centroids(kmeans_cands_agc, u, predict_labels, dist_type=1, keep=keep)
    base = to_dense(base)

    # Known labels only: test rows are zeroed.
    seeds = to_dense(to_onehot(gnd_lab).T)
    seeds[mask_test, :] = 0

    adjnormal = preprocess_adj(adjacency)

    # Correct: residual on labelled nodes only. Mask a *copy*: the old code
    # aliased `base`, so its test-node predictions were lost from the final
    # `base + resu`.
    # NOTE(review): Correct & Smooth uses the residual (labels - predictions);
    # this code uses (predictions - labels). Confirm the intended sign.
    masked = base.copy()
    masked[mask_test, :] = 0
    resu = masked - seeds
    for _ in range(correct_num):
        # Sign-preserving power: plain `resu ** p` gives NaN for negative entries.
        powered = np.sign(resu) * np.abs(resu) ** p
        resu = (1 - lp_alpha) * resu + lp_alpha * (adjnormal @ powered)

    resu = np.clip(base + resu, 0, 1)

    # Smooth
    for _ in range(smooth_num):
        resu = (1 - lp_alpha) * resu + lp_alpha * (adjnormal @ resu)

    resu = np.argmax(np.clip(np.asarray(resu), 0, 1), axis=1)

    # NOTE(review): scored on all nodes, including those whose labels were used
    # as seeds. Other evaluators in this file score only the test nodes.
    accuracy_lp_agc = accuracy_score(gnd_lab, resu, normalize=True)
    f1_sc_lp_agc = f1_score(gnd_lab, resu, zero_division=1, average='micro')
    precision_lp_agc = precision_score(gnd_lab, resu, average='micro')
    print(' acc_CS_AGC: {}'.format(accuracy_lp_agc))

    return accuracy_lp_agc, f1_sc_lp_agc, precision_lp_agc