
from copy import deepcopy
from sklearn.metrics import accuracy_score, r2_score, f1_score, precision_score, recall_score
from torch_kmeans import KMeans
import ot
import torch
import torch.nn.functional as F
from model.decoder import NCDecoder
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score,silhouette_score,calinski_harabasz_score,davies_bouldin_score,mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, explained_variance_score

def evaluate_clustering(embeddings, true_labels, n_clusters):
    """Cluster ``embeddings`` with k-means and score the resulting partition.

    ``KMeans`` is whichever class is bound to that name at module level. In
    this file, ``sklearn.cluster.KMeans`` is imported after
    ``torch_kmeans.KMeans`` and therefore shadows it. The model is fitted with
    a fixed seed (42) and 10 restarts.

    Args:
        embeddings: Feature matrix, one row per sample.
        true_labels: Ground-truth labels aligned with the rows of ``embeddings``.
        n_clusters: Number of k-means clusters to fit.

    Returns:
        Tuple ``(nmi, ari, ss, chi, dbi)``. The first two are external scores
        against ``true_labels``: normalized mutual information and adjusted
        Rand index. The last three are internal scores of the partition on
        ``embeddings``: silhouette, Calinski-Harabasz and Davies-Bouldin.
    """
    predicted = KMeans(n_clusters=n_clusters, random_state=42, n_init=10).fit_predict(embeddings)

    # Agreement between the k-means partition and the reference labels.
    external = (
        normalized_mutual_info_score(true_labels, predicted),
        adjusted_rand_score(true_labels, predicted),
    )
    # Geometry-only quality measures of the partition itself.
    internal = tuple(
        metric(embeddings, predicted)
        for metric in (silhouette_score, calinski_harabasz_score, davies_bouldin_score)
    )
    return external + internal


class Embed_test:
    """Train a model on one graph and evaluate it on the original graph.

    The model is trained on ``con_data`` (presumably the condensed/synthetic
    graph). Model selection on the validation split and the final test are
    always performed on ``ori_data``.

    Two task types are supported:
    - classification, when ``ori_data.num_classes != 1``;
    - regression, when ``ori_data.num_classes == 1``.
    """

    def __init__(self, args, ori_data, con_data):
        """Store the run configuration and both graphs.

        Args:
            args: Run configuration. It is read for the model name,
                ``hidden_dim``, optimizer settings, epochs and the
                fine-tuning options.
            ori_data: Original graph. Must provide ``x``, ``y``,
                ``num_classes``, ``val_mask`` and ``test_mask``. It also needs
                ``train_mask`` when fine-tuning is enabled.
            con_data: Graph the model is trained on. Its ``y`` is the
                initial training target.
        """
        self.args = args
        self.ori_data = ori_data
        self.con_data = con_data
        self.dim = ori_data.x.shape[1]
        self.num_classes = ori_data.num_classes

    def model_train(self):
        """Train on ``con_data``, select by validation on ``ori_data``, and report test metrics.

        The learning rate is decayed by 10x halfway through training. The
        weights with the best validation score are restored before testing.
        When ``args.fine_tune`` is set and ``args.gc_method == 'pgc'``, the
        training targets are periodically blended with labels transported
        from the original graph (see ``_transported_labels``).

        Side effects:
            Sets ``self.y_train`` (current training targets) and
            ``self.best_epoch`` (index of the best validation epoch, or
            ``None`` if no epoch improved on the initial state).

        Returns:
            dict: for classification, ``{'acc', 'mac_f1'}``; for regression,
            ``{'mape', 'explained_variance', 'r2'}``.
        """
        ori_data = self.ori_data
        args = self.args

        # Resolve module ``model.<args.model>`` and instantiate its upper-cased class.
        model_module = getattr(__import__('model', fromlist=[args.model]), args.model)
        model_class = getattr(model_module, args.model.upper())
        model = model_class(args, self.dim, self.num_classes, self.args.hidden_dim).cuda()

        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        epochs = args.epochs
        lr = args.lr

        input_graph = self.con_data
        self.y_train = self.con_data.y

        # Always validate and test on the original graph.
        y_val = ori_data.y[ori_data.val_mask]
        y_test = ori_data.y[ori_data.test_mask]

        # Start from the initial weights so that load_state_dict below is
        # always well defined. This covers epochs == 0, every score <= the old
        # -100 sentinel, and NaN validation scores.
        best_result_val = float('-inf')
        weights = deepcopy(model.state_dict())
        self.best_epoch = None

        for i in range(epochs):
            if i == epochs // 2 and i > 0:
                # Decay the learning rate 10x halfway through. The optimizer is
                # rebuilt (resetting Adam's state), so it must receive the
                # decayed ``lr``, not ``args.lr``.
                lr = lr * 0.1
                optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay)
            model.train()
            optimizer.zero_grad()

            output, _ = model.forward(input_graph)

            if ori_data.num_classes != 1:
                # Cross-entropy against the target distributions in self.y_train
                # (soft labels after fine-tuning).
                loss_train = torch.mean(torch.sum(self.y_train * -F.log_softmax(output, dim=1), dim=1))
            else:
                loss_train = F.mse_loss(output.squeeze(1), self.y_train)

            loss_train.backward()
            optimizer.step()

            with torch.no_grad():
                model.eval()
                output, ori_emb = model.forward(ori_data)
                result_val = self._validation_score(output, y_val)

                if result_val > best_result_val:
                    best_result_val = result_val
                    self.best_epoch = i
                    weights = deepcopy(model.state_dict())

                if args.fine_tune == True and args.gc_method == 'pgc':
                    # Refresh targets every ``fine_tune_delta`` epochs, only during the first 601 epochs.
                    if (i + 1) % args.fine_tune_delta == 0 and i < 601:
                        print('fine tuning ...')
                        s_emb_label = self._transported_labels(model, input_graph, output, ori_emb)
                        self.y_train = args.fine_tune_ratio * s_emb_label + (1 - args.fine_tune_ratio) * self.y_train

        model.load_state_dict(weights)
        model.eval()
        with torch.no_grad():
            output, _ = model.forward(ori_data)
        return self._test_results(output, y_test)

    def _validation_score(self, output, y_val):
        """Return the validation score on ``ori_data``.

        The score is accuracy for classification and R^2 for regression.
        ``output`` is the model output on the whole original graph.
        """
        ori_data = self.ori_data
        if ori_data.num_classes != 1:
            pred = output.max(1)[1]
            pred = pred[ori_data.val_mask].cpu().numpy()
            return accuracy_score(y_val.cpu().numpy(), pred)
        # Outputs are mapped back to the raw target scale (via y_std/y_mean)
        # before scoring.
        pred = output[ori_data.val_mask].squeeze().cpu().numpy() * ori_data.y_std.cpu().numpy() + ori_data.y_mean.cpu().numpy()
        return r2_score(y_val.cpu().numpy(), pred)

    def _transported_labels(self, model, input_graph, output, ori_emb):
        """Derive new training targets for ``input_graph`` via exact optimal transport.

        Procedure:
        1. Solve ``ot.emd`` between the nodes of ``ori_data`` and the nodes of
           ``input_graph``, with uniform weights on both sides.
        2. Use as cost the Euclidean distance between embeddings (when
           ``args.fine_tune_method`` selects 'embedding') or between model
           outputs (otherwise).
        3. Assign each original node to the target node that receives most of
           its mass.
        4. Aggregate the original training labels onto target nodes through
           that assignment.

        Returns:
            For regression, a tensor of shape ``(n_target,)``. For
            classification, a tensor of shape ``(n_target, num_classes)``
            whose rows are L1-normalized.
        """
        args = self.args
        ori_data = self.ori_data
        sys_outpt, sys_emb = model.forward(input_graph)
        h1 = ot.unif(ori_data.x.shape[0], type_as=ori_data.x)
        h2 = ot.unif(input_graph.x.shape[0], type_as=input_graph.x)

        # 'embedding:' (with a stray trailing colon) was previously the only
        # spelling that selected the embedding cost. It is kept so that
        # existing configs behave the same.
        if args.fine_tune_method in ('embedding', 'embedding:'):
            cost_emb = ot.dist(ori_emb, sys_emb, metric='euclidean')
        else:
            cost_emb = ot.dist(output, sys_outpt, metric='euclidean')
        P = ot.emd(h1, h2, cost_emb, numItermax=500000)

        # Hard assignment: row r gets a single 1 at the column with the most transported mass.
        P_one_hot = torch.zeros_like(P).cuda()
        P_one_hot[torch.arange(P.shape[0]), P.argmax(dim=1)] = 1.0

        if self.num_classes == 1:
            # Regression: each target node gets the mean of the training
            # labels (y_regre_std) assigned to it. Columns are L1-normalized
            # over the training rows.
            train_labels = ori_data.y_regre_std[ori_data.train_mask]
            train_P = F.normalize(P_one_hot[ori_data.train_mask], p=1, dim=0)
            return torch.mm(train_labels.unsqueeze(0), train_P).squeeze(0)

        # Classification: count the assigned training labels per target node,
        # then normalize each row into a distribution. Non-training nodes
        # contribute all-zero rows.
        train_labels = ori_data.y[ori_data.train_mask]
        one_hot_train_labels = F.one_hot(train_labels, num_classes=ori_data.num_classes).float().cuda()
        one_hot_labels = torch.zeros(ori_data.num_nodes, ori_data.num_classes).cuda()
        one_hot_labels[ori_data.train_mask] = one_hot_train_labels
        s_emb_label = torch.mm(P_one_hot.t(), one_hot_labels)
        return F.normalize(s_emb_label.clamp(min=0), p=1, dim=1)

    def _test_results(self, output, y_test):
        """Compute and print test-set metrics on ``ori_data``.

        ``output`` is the model output on the whole original graph.
        """
        ori_data = self.ori_data
        all_results = {}
        if ori_data.num_classes != 1:
            pred = output[ori_data.test_mask].max(1)[1].cpu().numpy()
            y_true = y_test.cpu().numpy()
            result_test = accuracy_score(y_true, pred)
            all_results['acc'] = result_test
            all_results['mac_f1'] = f1_score(y_true, pred, average='macro')
            print(f"Test set results: test_acc= {result_test:.5f}")
        else:
            pred = output[ori_data.test_mask].squeeze().detach().cpu().numpy() * ori_data.y_std.cpu().numpy() + ori_data.y_mean.cpu().numpy()
            y_true = y_test.cpu().numpy()
            result_test = r2_score(y_true, pred)
            all_results['mape'] = mean_absolute_percentage_error(y_true, pred)
            all_results['explained_variance'] = explained_variance_score(y_true, pred)
            all_results['r2'] = result_test
            print(f"Test set results: test_r2= {result_test:.5f}")
        return all_results
