import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from net import GraphAttentionLayer
from utils import sparse_to_adj

# Directory containing this source file; realpath() resolves symlinks first.
PWD = os.path.dirname(os.path.realpath(__file__)) 

class KnowledgeGenerator(nn.Module):
    def __init__(self, args, num_classes, num_features, node_size, interLayer_edgeMat, layer_edgeMat, device, logger):
        """Store the hierarchy description, build the sub-modules and pre-sample nodes.

        Args:
            args: Configuration object. ``childbd``, ``P_rw``, ``S_rw``,
                ``C_rw`` and ``teacher_entropy`` are read from it here.
            num_classes: Passed to both attention layers as ``out_features``.
            num_features: Width of one node feature vector.
            node_size: Number of nodes per layer, indexed by layer.
            interLayer_edgeMat: Per-layer edge matrices. For ``layer >= 1``
                row 0 is matched against the node id of that layer and row 1
                supplies the ids of its children.
            layer_edgeMat: Per-layer edges, stored for later use by ``forward``.
            device: Device the attention layers are moved to.
            logger: Logger stored for later use by ``forward``.
        """
        super().__init__()

        # Plain bookkeeping attributes.
        self.args = args
        self.num_classes = num_classes
        self.num_features = num_features
        self.node_size = node_size
        self.height = len(node_size)
        self.interLayer_edgeMat = interLayer_edgeMat
        self.layer_edgeMat = layer_edgeMat
        self.device = device
        self.logger = logger

        # Sub-modules are created in a fixed order (inter, intra, t, s, rw) so
        # that parameter registration and initialisation order stay stable.
        self.CalInter = GraphAttentionLayer(
            in_features=self.num_features, out_features=self.num_classes
        ).to(self.device)
        self.CalIntra = GraphAttentionLayer(
            in_features=self.num_features, out_features=self.num_classes
        ).to(self.device)

        self.t_numfeat = sum(self.node_size[1:])
        self.s_numfeat = self.num_features * self.args.childbd
        self.MLP_t = nn.Linear(self.t_numfeat, self.t_numfeat)
        self.MLP_s = nn.Linear(self.s_numfeat, self.num_features)
        self.MLP_rw = nn.Linear(self.num_features * (self.height + 1), self.num_features)

        # Normalise the three random-walk weights so that they sum to one.
        rw_total = self.args.P_rw + self.args.S_rw + self.args.C_rw
        self.P_rw = self.args.P_rw / rw_total
        self.S_rw = self.args.S_rw / rw_total
        self.C_rw = self.args.C_rw / rw_total

        self.teacher_entropy = self.args.teacher_entropy

        # For every non-leaf layer: record each node's children, then draw the
        # nodes of the other groups that will be used for that node later on.
        self.all_children_list = []
        self.selected_backup = []
        for layer in range(1, self.height):
            nodenum = self.node_size[layer]
            edges = self.interLayer_edgeMat[layer]

            children_list = [
                edges[1, edges[0] == node].tolist() for node in range(nodenum)
            ]
            self.all_children_list.append(children_list)

            self.selected_backup.extend(
                self.rand_select(node, children_list, nodenum)
                for node in range(nodenum)
            )

    def forward(self, data_x, labels, train_idx, node_size):
        """Aggregate features layer by layer and return the combined loss.

        For every non-leaf layer each node receives a feature that is the
        attention-weighted sum of its children's features (``CalIntra``),
        plus, when the layer has more than one node, an attention-weighted
        sum over the pre-sampled nodes of the other groups (``CalInter``).
        These per-node features form the "teacher" matrix ``feature_t``; the
        flattened raw child features, zero-padded or truncated to
        ``self.s_numfeat``, form the "student" input ``feature_s``.

        Args:
            data_x: Leaf-layer features. Assumed to be (num_leaves,
                num_features) -- TODO confirm against caller.
            labels: Targets, indexed with ``train_idx`` for the walk loss.
            train_idx: Index or mask selecting the training rows.
            node_size: Per-layer node counts used for the layer loops. The
                leaf count for the random walks is read from
                ``self.node_size[0]`` instead.

        Returns:
            Tuple ``(loss, MLP_rw_feature)`` where ``loss`` is
            ``loss_kd + loss_rw`` and ``MLP_rw_feature`` is the output of
            ``self.MLP_rw`` on the random-walk features of all leaves.

        Side effects:
            Overwrites ``self.features``, ``self.selected_index``,
            ``self.children_list`` and ``self.rw_feature``; writes two log
            lines through ``self.logger``.
        """
        # Layer 0 holds the input features with an extra leading dimension.
        self.features = [data_x.unsqueeze(0)]
        # Cursor into self.selected_backup, advanced by cal_alpha_inter().
        self.selected_index = 0
        feature_list_s = []
        # Starts as an empty (1, 0) tensor; see the numel() check below.
        feature_t = torch.tensor([[]]).to(self.device)
        
        for layer in range(1, self.height): 

            nodenum = node_size[layer]
            layer_edge = self.layer_edgeMat[layer-1]
            layer_adj = sparse_to_adj(layer_edge, self.device)
            # NOTE(review): squeeze() drops every singleton dimension, so a
            # lower layer with a single node or a single feature would lose
            # an axis here -- confirm this cannot occur.
            layer_feature = self.features[layer-1].squeeze()

            self.children_list = self.all_children_list[layer - 1]

            feature_list_layer = []
            for node in range(nodenum):

                # Features and adjacency restricted to this node's children.
                node_children = self.children_list[node]
                intra_children_feature = layer_feature[node_children]
                intra_adj = layer_adj[node_children][:,node_children]
  
                alpha_intra = self.CalIntra(intra_children_feature, intra_adj)
                feature_intra = torch.sum(intra_children_feature * alpha_intra.view(-1, 1), dim=0).squeeze(0) # weighted sum

                # Student input starts with the flattened child features.
                feat_s = intra_children_feature.view(-1) 

                if nodenum > 1:
                    # Add the contribution of the nodes sampled from the
                    # other groups of this layer.
                    alpha_inter, selected_flat = self.cal_alpha_inter(layer_feature, layer_adj)
                    inter_children_feature = layer_feature[selected_flat]
                    feature_inter = torch.sum(inter_children_feature * alpha_inter.view(-1, 1), dim=0).squeeze(0)
                    node_feature_tensor = (feature_inter + feature_intra)
                    
                    feat_s = torch.cat((feat_s, inter_children_feature.view(-1))) 
                    
                else:
                    node_feature_tensor = feature_intra
                
                # Zero-pad on the right, then truncate, so every student row
                # has exactly self.s_numfeat entries.
                pad_length = max(0, self.s_numfeat - len(feat_s)) 
                padded_feat_s = F.pad(feat_s, (0, pad_length), value=0)
                padded_feat_s = padded_feat_s[:self.s_numfeat].tolist()

                # The first row is joined along dim=1 because feature_t is
                # still the empty (1, 0) placeholder; later rows are stacked
                # along dim=0.
                if feature_t.numel() == 0:
                    feature_t = torch.cat((feature_t, node_feature_tensor.unsqueeze(0)), dim=1)
                else:
                    feature_t = torch.cat((feature_t, node_feature_tensor.unsqueeze(0)), dim=0)
                
                feature_list_s.append(padded_feat_s)
                feature_list_layer.append(node_feature_tensor.tolist()) 
            
            # NOTE(review): rebuilding the layer from Python lists creates a
            # fresh tensor without autograd history, so upper layers and the
            # random-walk features do not back-propagate into the attention
            # layers -- confirm this is intended.
            self.features.append(torch.Tensor(feature_list_layer).to(self.device))
        
        feature_s = torch.Tensor(feature_list_s).to(self.device) 

        # Per-node temperature: the sigmoid output is rescaled to lie between
        # 1 and self.teacher_entropy, then reshaped to a column.
        T = torch.sigmoid(self.MLP_t(self.cal_entropy(feature_t)))
        T = torch.reshape(T*(self.teacher_entropy-1) + 1, (T.size()[0],1)) 

        # One random-walk feature row per leaf, stacked in leaf order.
        self.rw_feature = self.rand_walk(0)
        leaf_node = self.node_size[0]
        for i in range(1, leaf_node):
            tmp_feat = self.rand_walk(i)
            self.rw_feature = torch.concat((self.rw_feature, tmp_feat), dim=0)
            
        loss_kd = self.KD_loss_with_temp(self.MLP_s(feature_s), feature_t, T)

        # NOTE(review): MLP_rw outputs num_features values per leaf and they
        # are used as class scores here -- confirm num_classes was not meant.
        MLP_rw_feature = self.MLP_rw(self.rw_feature)
        loss_rw = F.cross_entropy(MLP_rw_feature[train_idx], labels[train_idx])

        loss = loss_kd + loss_rw 

        self.logger.info(f"loss_kd: {loss_kd}, loss_rw: {loss_rw}")
        self.logger.info(f"loss: {loss}")
        return loss, MLP_rw_feature

    def rand_walk(self, cur_node):
        cur_layer = 0
        rw_feature = (self.features[0].squeeze())[cur_node].squeeze()
        for _ in range(self.height):
            if cur_layer > 0:
                probabilities = [self.P_rw, 1 - self.P_rw - self.C_rw, self.C_rw]
                events = [cur_layer + 1, cur_layer, cur_layer - 1]
            elif cur_layer == 0:
                probabilities = [self.P_rw, 1 - self.P_rw]
                events = [cur_layer + 1, cur_layer]
            else:
                raise ValueError("cur_layer cannot be negative")
            
            rw_layer = np.random.choice(events, p=probabilities)
            layer_feature = self.features[rw_layer].squeeze()

            if rw_layer == cur_layer: # sybling
                sybling_node = random.randint(0,layer_feature.size(1)-1)
                rw_feat = layer_feature[sybling_node].squeeze()

            elif rw_layer < cur_layer: # children
                children_list = self.layer_children_list[cur_layer-1][cur_node]
                child_node = random.choice(children_list)
                rw_feat = layer_feature[child_node].squeeze()

            elif rw_layer > cur_layer: #parent
                interlayer = self.interLayer_edgeMat[rw_layer].tolist()
                index_p = interlayer[1].index(cur_node)
                parent_node = interlayer[0][index_p]
                rw_feat = layer_feature[parent_node].squeeze()
            rw_feature = torch.concat((rw_feature,rw_feat),dim=0)
            
        return rw_feature.unsqueeze(0)

    def rand_select(self, cur_node, children_list, nodenum): 
        selected = []
        select_num = min(len(children_list[cur_node]), sum(len(sublist) for sublist in children_list)-len(children_list[cur_node]))
        select_times = [0]*(nodenum - 1)
        
        index, t = 0, 0
        while(select_num):
            if t == cur_node: t = (t + 1) % nodenum

            if len(children_list[t]) > select_times[index]:
                select_times[index] += 1
                index =  (index + 1) % (nodenum - 1)
                t = (t + 1) % nodenum
                select_num -= 1
            else:
                index =  (index + 1) % (nodenum - 1)
                t = (t + 1) % nodenum
        
        t = 0
        for stime in select_times:
            if t == cur_node: t += 1
            tmp = random.sample(children_list[t],stime)
            selected.append(tmp)
            t += 1

        return selected
    
    def cal_alpha_inter(self, layer_feature, layer_adj):

        selected = self.selected_backup[self.selected_index]
        self.selected_index += 1
        

        selected_flat = [elem for sublist in selected for elem in sublist]
        other_intra_children_feature = layer_feature[selected_flat]  
        other_intra_adj = layer_adj[selected_flat][:, selected_flat]

        alpha_inter = self.CalInter(other_intra_children_feature, other_intra_adj)


        # alpha_inter_list, selected_flat = [], []
        # index = 0
        # alpha_inter = torch.tensor([]).to(self.device)
        # # print("nodenum: ", nodenum)
        # # current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        # for other in range(nodenum):
        #     if other == cur_node: continue
            
        #     other_node_children = self.children_list[other]
        #     other_intra_children_feature = layer_feature[other_node_children]  
        #     print("other_intra_children_feature: ", other_intra_children_feature)
        #     other_intra_adj = layer_adj[other_node_children][:, other_node_children]
        #     other_alpha_intra = self.CalIntra(other_intra_children_feature, other_intra_adj)
        #     print("other_alpha_intra: ", other_alpha_intra)
            
        #     inter_node_children = self.children_list[other] + self.children_list[cur_node]
        #     inter_feature = layer_feature[inter_node_children]
        #     inter_adj = layer_adj[inter_node_children][:, inter_node_children]
        #     other_alpha_inter = self.CalInter(inter_feature, inter_adj)

        #     s_node = selected[index]
        #     selected_flat.extend(s_node) 
        #     index += 1
        #     s_indices = [other_node_children.index(element) for element in s_node]
 
        #     alpha_inter_element = []
        #     for indice in s_indices:
        #         alpha_inter_element.append(max(other_alpha_inter[indice], other_alpha_intra[indice]).item())
        #     alpha_inter_list.extend(alpha_inter_element)
        
        # alpha_inter = torch.Tensor(alpha_inter_list).to(self.device)
        return alpha_inter, selected_flat
    
    def cal_entropy(self, result):
        probs = F.softmax(result, dim=1)
        return F.cross_entropy(probs, probs, reduction='none')

    def KD_loss_with_temp(self, y_s, y_t, T):
        
        p_s = F.softmax(y_s, dim=1)
        p_t = F.softmax(y_t / T, dim=1)
     
        loss = F.cross_entropy(p_s, p_t) 
        # loss = F.mse_loss(p_s, p_t)
        return loss

    def evaluate(self, mask_idx, labels):
        s = self.MLP_rw(self.rw_feature)
        q_result = s[mask_idx].log_softmax(dim=-1)
        true_result = labels[mask_idx].squeeze()
        
        _, indices = torch.max(q_result, dim=1)
        correct = torch.sum(indices == true_result)
        accs = correct.item() * 1.0 / len(true_result)
        return accs

