import torch
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
import tqdm
from scipy.sparse import csr_matrix
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from scipy.special import softmax
from models.gsl.base_learn import BaseLearn
import time
from entropy.partitionTree import PartitionTree, divide_community
from entropy.sample import get_community, get_layer_community, llm_reshape
from models.gnn.gcn import GCN
from models.gnn.gat import GAT
from models.gnn.graphsage import GraphSAGE
from models.utils import node_cls_train, node_cls_evaluate, cosine_sim, get_rankings
from llm.llmloaddata import get_nodescore
from transformers import AutoTokenizer, AutoModelForCausalLM

class LLaTALearn(BaseLearn):
    """Graph structure learner driven by a coding tree and LLM node scores.

    ``execute`` builds a ``PartitionTree`` over the dataset adjacency, asks an
    LLM (through ``get_nodescore``) for a per-class score vector for every
    node of every community, subdivides communities by K-means clustering of
    the softmax-normalised scores, obtains additional edges from
    ``llm_reshape`` and finally trains a GNN on the merged edge set.
    """

    # Score used to pad an LLM answer that has fewer entries than classes,
    # and to fill an answer that is still too short afterwards.
    _DEFAULT_SCORE = 5
    # Communities with at most this many nodes take their neighbour
    # candidates from the parent community and are never subdivided.
    _SMALL_GROUP_SIZE = 3
    # Minimum silhouette improvement required to keep increasing k.
    _MIN_SILHOUETTE_GAIN = 0.01
    # Communities whose best silhouette is below this stay undivided.
    _MIN_SILHOUETTE = 0.3
    # Values of ``args.gnn_type`` that ``execute`` knows how to build.
    _GNN_TYPES = ("gcn", "gat", "graphsage")

    def __init__(self, logger, dataset, args, device):
        """Cache dataset metadata, hyper-parameters and the raw node texts.

        Args:
            logger: Object exposing ``info``; used for progress messages.
            dataset: Dataset object; the attributes read here and in
                ``execute`` are ``num_features``, ``num_classes``,
                ``num_targets``, ``num_node``, ``total_idx``, ``class_name``,
                ``x``, ``y``, ``adj``, ``edge`` and the train/val/test
                index sets.
            args: Parsed run configuration (learning rate, LLaTA_* options,
                ``llm_path``, ``data_name``, ...).
            device: Torch device the labels, the LLM and the GNN are moved to.
        """
        super().__init__()
        self.logger = logger
        self.args = args
        self.feat_dim = dataset.num_features
        self.output_dim = dataset.num_classes
        self.num_targets = dataset.num_targets
        self.num_nodes = dataset.num_node
        self.total_idx = dataset.total_idx
        self.class_name = dataset.class_name

        self.dataset = dataset
        self.labels = self.dataset.y.to(device)
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.dropout = args.dropout
        self.K = self.args.LLaTA_K
        self.theta = self.args.LLaTA_theta
        self.r = self.args.LLaTA_r
        self.epsilon = self.args.LLaTA_epsilon
        self.nei_max = args.nei_max

        self.times = args.exp_times
        self.epochs = args.num_epochs
        self.early_stop = args.early_stop
        self.gnn_type = args.gnn_type
        self.device = device
        self.loss_fn = F.cross_entropy
        self.result = {'train': -1, 'valid': -1, 'test': -1}
        # Raw text of every node, indexed by node id when prompting the LLM.
        self.org_text = np.load(f'./datasets/{args.data_name}/raw/x_text.npy')

    def sym_adj(self):
        """Return the dense adjacency, symmetrised and with a zero diagonal."""
        sym_adj = self.dataset.adj.todense()
        sym_adj = np.maximum(sym_adj, sym_adj.T)
        np.fill_diagonal(sym_adj, 0)
        return sym_adj

    def execute(self, save_adj=False):
        """Run the full pipeline and train the GNN on the refined graph.

        Args:
            save_adj: When true, the dense refined adjacency is written to
                ``./datasets/<data_name>/processed/selm_adj.npy``.

        Returns:
            A ``(best_test, total_time)`` tuple: the test accuracy recorded at
            the epoch with the best validation accuracy, and the wall-clock
            seconds elapsed from just before the coding tree is built until
            training ends.

        Raises:
            ValueError: If ``args.gnn_type`` is not one of ``gcn``, ``gat``
                or ``graphsage``.
        """
        # Fail before the expensive LLM stage instead of after it.
        if self.gnn_type not in self._GNN_TYPES:
            raise ValueError(
                f"Unsupported gnn_type {self.gnn_type!r}; expected one of {self._GNN_TYPES}")

        edge_index = torch.stack([self.dataset.edge.row, self.dataset.edge.col], dim=1)
        # Add the reversed pairs so the edge list is undirected, then dedupe.
        edge_index = torch.concat((edge_index, torch.flip(edge_index, dims=[1])), dim=0)
        edge_index = torch.unique(edge_index, dim=0)
        print(edge_index.shape)
        adj_new = self.dataset.adj
        start_time = time.time()

        org_x = self.dataset.x.numpy()
        print('Building Coding Tree...')
        code_tree = PartitionTree(adj_new.todense())
        code_tree.build_coding_tree(self.K)
        node_dict = code_tree.tree_node
        community = get_layer_community(code_tree)

        raw_scores = self._run_llm_inference(community, node_dict, org_x)
        llm_node_x = self._normalize_scores(raw_scores)
        self._subdivide_communities(code_tree, community, node_dict, llm_node_x)
        print('Subdivide Finished')

        community, isleaf = get_community(code_tree)
        # Per community: [mean LLM score vector, mean original feature vector].
        llm_com_x = {}
        for com in community:
            node_list = node_dict[com].partition
            llm_com_x[com] = [
                np.mean(llm_node_x[node_list], axis=0),
                np.mean(org_x[node_list], axis=0),
            ]

        new_edge_index = llm_reshape(community, code_tree, isleaf, self.r, llm_node_x, org_x, llm_com_x, self.theta)
        # Merge with the original edges, symmetrise and dedupe again.
        new_edge_index = torch.concat((new_edge_index, edge_index.cpu()), dim=0)
        new_edge_index = torch.concat((new_edge_index, torch.flip(new_edge_index, dims=[1])), dim=0)
        new_edge_index = torch.unique(new_edge_index, dim=0)
        new_edge_index = new_edge_index.t()
        row, col = new_edge_index.numpy()
        num_edges = new_edge_index.shape[1]
        print(f'edge num: {num_edges}')
        data = np.ones(num_edges)
        adj_new = csr_matrix((data, (row, col)), shape=(self.num_nodes, self.num_nodes))

        gnn_model = self._build_gnn()
        self.logger.info("gnn train:")
        # GraphSAGE is preprocessed with the edge list, the others with the
        # sparse adjacency.
        if self.gnn_type == 'graphsage':
            gnn_model.preprocess(new_edge_index, self.dataset.x)
        else:
            gnn_model.preprocess(adj_new, self.dataset.x)
        best_val, best_test = self._train_gnn(gnn_model)

        total_time = time.time() - start_time
        self.logger.info("Optimization Finished!")
        self.logger.info(f'Best val: {best_val:.2f}, best test: {best_test:.2f}, time: {total_time}')

        if save_adj:
            np.save(f'./datasets/{self.args.data_name}/processed/selm_adj.npy', adj_new.todense())

        return best_test, total_time

    def _select_neighbors(self, node, candidates, org_x, train_set):
        """Pick the neighbour nodes whose texts accompany ``node`` in the prompt.

        A candidate gets similarity 1 when both it and ``node`` are training
        nodes with equal labels, otherwise the ``cosine_sim`` of their
        original features. At most ``min(len(candidates) // 2, nei_max)`` of
        the most similar candidates are considered, and only those whose
        similarity exceeds ``epsilon`` are kept.
        """
        # max(..., 0) keeps a non-positive nei_max from turning the slice
        # below into a "drop the last items" slice.
        nei_num = max(min(len(candidates) // 2, self.nei_max), 0)
        node_is_train = node in train_set
        sims = []
        for nei_node in candidates:
            if (node_is_train and nei_node in train_set
                    and self.labels[node] == self.labels[nei_node]):
                sims.append(1)
            else:
                sims.append(cosine_sim(org_x[nei_node], org_x[node]))
        order = np.argsort(sims)[::-1]
        return [candidates[idx] for idx in order[:nei_num] if sims[idx] > self.epsilon]

    def _score_node(self, llm_model, tokenizer, node, candidates, org_x, train_set):
        """Query the LLM for ``node`` and pad the answer to ``num_classes`` entries."""
        nei_list = self._select_neighbors(node, candidates, org_x, train_set)
        # NOTE(review): get_nodescore is assumed to return a mutable list of
        # numeric scores; confirm against llm.llmloaddata.
        answer = get_nodescore(llm_model, tokenizer, self.org_text[node], self.args.data_name,
                               self.device, self.org_text[nei_list])
        while len(answer) < self.dataset.num_classes:
            answer.append(self._DEFAULT_SCORE)
        return answer

    def _run_llm_inference(self, community, node_dict, org_x):
        """Load the LLM and collect one raw score list per node.

        The model and tokenizer are local to this helper, so this method's
        references to them are dropped when it returns, before GNN training
        starts. Nodes that appear in no community keep an all-zero score list.
        """
        # NOTE: trust_remote_code=True executes code shipped with the model
        # at args.llm_path; only point it at trusted checkpoints.
        llm_model = AutoModelForCausalLM.from_pretrained(
            self.args.llm_path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True).to(self.device).eval()
        tokenizer = AutoTokenizer.from_pretrained(self.args.llm_path, trust_remote_code=True)

        num_classes = self.dataset.num_classes
        llm_node_x = [[0] * num_classes for _ in range(self.num_nodes)]
        # Built once so the membership tests in the inner loops are O(1).
        train_set = set(torch.as_tensor(self.dataset.train_idx).reshape(-1).tolist())

        for com in tqdm.tqdm(community):
            group = node_dict[com].partition
            is_small = len(group) <= self._SMALL_GROUP_SIZE
            parent_id = node_dict[com].parent
            for node in group:
                # Small communities look for neighbours in the parent
                # community; larger ones stay inside their own group.
                pool = node_dict[parent_id].partition if is_small else group
                candidates = pool.copy()
                candidates.remove(node)
                llm_x = self._score_node(llm_model, tokenizer, node, candidates, org_x, train_set)
                llm_node_x[node] = llm_x
                if not is_small:
                    print(llm_x, softmax(llm_x))
        return llm_node_x

    def _normalize_scores(self, llm_node_x):
        """Force every score list to ``num_classes`` entries and apply a row softmax."""
        num_classes = self.dataset.num_classes
        for node in range(self.num_nodes):
            scores = llm_node_x[node]
            if len(scores) > num_classes:
                llm_node_x[node] = scores[:num_classes]
            elif len(scores) < num_classes:
                llm_node_x[node] = [self._DEFAULT_SCORE] * num_classes

        # Nodes whose scores are all the default value carry no preference.
        uncertain = sum(
            1 for node in range(self.num_nodes)
            if all(x == self._DEFAULT_SCORE for x in llm_node_x[node])
        )
        print("Uncertainty number:", uncertain)
        return softmax(llm_node_x, axis=1)

    def _subdivide_communities(self, code_tree, community, node_dict, llm_node_x):
        """Split communities whose LLM scores form well-separated clusters.

        For each community with more than ``_SMALL_GROUP_SIZE`` nodes, K-means
        is run for increasing ``k`` while the silhouette score keeps improving
        by more than ``_MIN_SILHOUETTE_GAIN``. The clustering that achieved
        the best score (not the last one tried) is handed to
        ``divide_community`` when that score reaches ``_MIN_SILHOUETTE``.
        """
        for com in community:
            group = node_dict[com].partition
            group_size = len(group)
            if group_size <= self._SMALL_GROUP_SIZE:
                continue

            llm_group_x = [llm_node_x[node] for node in group]
            best_score = -1
            best_labels = None
            best_k = 0
            for k in range(2, group_size // 2 + 1):
                kmeans = KMeans(n_clusters=k, random_state=self.args.seed)
                kmeans.fit(llm_group_x)
                labels = kmeans.labels_
                # A single-cluster result cannot be scored by silhouette_score.
                if all(x == 0 for x in labels):
                    break
                score = silhouette_score(llm_group_x, labels)
                if score - best_score > self._MIN_SILHOUETTE_GAIN:
                    # Remember the clustering together with its score so the
                    # partition applied below is the one that was accepted.
                    best_score, best_labels, best_k = score, labels, k
                else:
                    break

            if best_labels is None or best_score < self._MIN_SILHOUETTE:
                continue

            subgroup_list = [[] for _ in range(best_k)]
            for node, cluster_id in zip(group, best_labels):
                subgroup_list[cluster_id].append(node)

            divide_community(code_tree, com, subgroup_list)

    def _build_gnn(self):
        """Instantiate the GNN selected by ``gnn_type`` and move it to the device."""
        common = dict(feat_dim=self.feat_dim, hidden_dim=self.args.hidden_dim,
                      output_dim=self.output_dim, dropout=self.dropout, task_level='node')
        if self.gnn_type == "gcn":
            model = GCN(**common)
        elif self.gnn_type == "gat":
            model = GAT(num_heads=self.args.num_heads, **common)
        elif self.gnn_type == "graphsage":
            model = GraphSAGE(**common)
        else:
            raise ValueError(
                f"Unsupported gnn_type {self.gnn_type!r}; expected one of {self._GNN_TYPES}")
        return model.to(self.device)

    def _train_gnn(self, gnn_model):
        """Train ``gnn_model`` with early stopping; return ``(best_val, best_test)``."""
        self.gnn_optimizer = optim.Adam(gnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        best_val = 0.
        best_test = 0.
        stop = 0
        for _ in range(self.epochs):
            if stop > self.early_stop:
                print("Early stop!")
                break

            node_cls_train(gnn_model, self.dataset.train_idx, self.labels, self.device,
                           self.gnn_optimizer, self.loss_fn)
            acc_val, acc_test = node_cls_evaluate(gnn_model, self.dataset.val_idx, self.dataset.test_idx,
                                                  self.labels, self.device)

            if acc_val > best_val:
                best_val = acc_val
                best_test = acc_test
                stop = 0
            # Incremented every epoch, including improving ones, so the
            # counter restarts at 1 after an improvement.
            stop += 1
        return best_val, best_test