import numpy as np
import networkx as nx
import torch
from torch_geometric.utils import degree, to_dense_adj, dense_to_sparse
import torch.nn.functional as F
import copy
from sklearn.manifold import MDS

import os
import sys
par_dir = os.path.abspath("../")
sys.path.append(par_dir)
from LinearFGW.dataio import k_barycenter_load
from LinearFGW.algot1 import *
import time

import numpy as np


def preprocess(data_in):
    """Deep-copy ``data_in`` and attach node features and edge weights to each graph.

    If the first graph has no node features (``x is None``), every graph gets
    degree-based features, where the degree is computed from the first row of
    ``edge_index``:

    * maximum degree below 2000: a one-hot encoding with ``max_degree + 1``
      columns;
    * otherwise: a single column holding the degree standardised with the
      mean / std taken over all nodes of all graphs.

    Afterwards, for every graph, ``edge_weight`` is set to the values returned
    by ``dense_to_sparse(to_dense_adj(edge_index))``, ``num_nodes`` is set to
    the side length of that dense adjacency, and ``edge_attr`` is cleared.

    Args:
        data_in: indexable and iterable collection of graph objects exposing
            ``x``, ``edge_index`` and ``edge_attr`` (presumably
            ``torch_geometric.data.Data`` held in a list -- TODO confirm).

    Returns:
        The modified deep copy; ``data_in`` itself is left untouched.
    """
    dataset = copy.deepcopy(data_in)
    is_plain = dataset[0].x is None

    if is_plain:    # use node degree as attributes
        # Compute each graph's degree vector once and reuse it below.
        degs = [degree(data.edge_index[0], dtype=torch.long) for data in dataset]
        max_degree = 0
        for deg in degs:
            max_degree = max(max_degree, deg.max().item())

        if max_degree < 2000:
            for data, deg in zip(dataset, degs):
                data.x = F.one_hot(deg, num_classes=max_degree + 1).to(torch.float)
        else:
            all_deg = torch.cat(degs, dim=0).to(torch.float)
            mean, std = all_deg.mean().item(), all_deg.std().item()
            for data, deg in zip(dataset, degs):
                data.x = ((deg - mean) / std).view(-1, 1)

    for data in dataset:
        adj = to_dense_adj(data.edge_index)
        _, edge_weight = dense_to_sparse(adj)
        data.edge_weight = edge_weight
        # Read the node count from the last dimension instead of
        # len(adj.squeeze()): squeezing collapses a one-node adjacency to a
        # 0-d tensor, for which len() raises.
        data.num_nodes = adj.size(-1)
        if data.edge_attr is not None:
            data.edge_attr = None

    return dataset


def prepare_graph_info(dataset, index, device, normalize=False):
    """Collect per-graph tensors for the graphs selected by ``index``.

    Args:
        dataset: indexable collection of graph objects exposing ``num_nodes``,
            ``y``, ``x``, ``edge_index`` and a ``to(device)`` method.
        index: sequence of positions into ``dataset``.
        device: device every returned tensor is moved to.
        normalize: when true, ``x`` is passed through
            ``torch.nn.functional.normalize`` before being moved.

    Returns:
        A list with one ``[structure, p, feature, label]`` entry per index:
        ``structure`` is the dense float ``(num_nodes, num_nodes)`` matrix with
        a one accumulated for every column of ``edge_index``; ``p`` is a
        uniform ``(num_nodes, 1)`` weight vector; ``feature`` is ``x`` as
        float; ``label`` is ``y`` with a leading dimension added.
    """
    graph_records = []
    for graph_idx in index:
        graph = dataset[graph_idx]
        n_nodes = graph.num_nodes

        # graph label
        label = graph.y.to(device)

        # node features, optionally normalised before the device transfer
        feats = graph.x
        if normalize:
            feats = torch.nn.functional.normalize(feats)
        feats = feats.to(device)

        # dense structure matrix built from a sparse COO tensor of ones
        edge_values = torch.ones(graph.edge_index.shape[1]).to(device)
        edges_on_device = graph.to(device).edge_index
        adjacency = torch.sparse_coo_tensor(
            edges_on_device, edge_values, (n_nodes, n_nodes)
        ).to_dense()

        # uniform node distribution
        node_mass = (torch.ones(n_nodes) / n_nodes).to(device)

        graph_records.append(
            [adjacency.float(), node_mass.unsqueeze(1), feats.float(), label.unsqueeze(0)]
        )

    return graph_records



def _stack_fgw_embeddings(featurized_graphs, alpha, emb_size):
    """Flatten every graph's (W, GW) features into one row of a dense matrix."""
    feature_scale = np.sqrt(alpha)
    structure_scale = np.sqrt(1 - alpha)
    emb_matrix = torch.zeros(len(featurized_graphs), emb_size)
    for row, (W_feats, GW_feats, _) in enumerate(featurized_graphs):
        emb_matrix[row] = torch.cat(
            [feature_scale * W_feats[0].flatten(), structure_scale * GW_feats[0].flatten()],
            dim=0,
        )
    return emb_matrix


def compute_pairwise_distance_with_emb(graph_format_set1, graph_format_set2, graph_name, center_graph_size, alphas, device, force_recompute=True):
    """Embed two graph sets via k-barycenters and return their pairwise squared distances.

    For every ``alpha`` the barycenters are loaded from
    ``../barycenters/k_barycenter_{graph_name}.pkl`` (only when that file
    exists and ``force_recompute`` is false) or recomputed with
    ``k_barycenter`` over both sets.  Each graph is then featurised with
    ``fusedGW_featurize`` and its W / GW features are scaled by
    ``sqrt(alpha)`` / ``sqrt(1 - alpha)`` and concatenated into one row.

    Args:
        graph_format_set1: list of per-graph entries; element ``[2]`` of an
            entry is read as the feature matrix (see ``prepare_graph_info``).
        graph_format_set2: second list in the same format.
        graph_name: name used to build the barycenter file path.
        center_graph_size: iterable of barycenter sizes; their sum enters the
            embedding width.
        alphas: iterable of trade-off values (alpha weights the graph
            features, ``1 - alpha`` the graph structure).
        device: forwarded to ``k_barycenter`` and ``fusedGW_featurize``.
        force_recompute: when true (default) the stored barycenters are
            ignored and recomputed.

    Returns:
        ``(output, emb_output)`` where ``output[alpha]`` is the
        ``(len(set1), len(set2))`` matrix of squared Euclidean distances
        between embeddings (clamped to be non-negative) and
        ``emb_output[alpha]`` is ``(emb_matrix_set1, emb_matrix_set2)``.
    """
    # alpha: ratio for graph feature
    # 1 - alpha: ratio for graph structure

    # define hyper-params
    ot_method = "ppa"
    gamma = 0.1
    gwb_layers = 3
    ot_layers = 3
    n_iters = 3

    dim_embedding = graph_format_set1[0][2].shape[1]

    # NOTE(review): the path does not depend on alpha, so with
    # force_recompute=False every alpha loads the same stored barycenters --
    # confirm this is intended.
    filename = f'../barycenters/k_barycenter_{graph_name}.pkl'

    output = {}
    emb_output = {}
    for alpha in alphas:
        print("pairwise distance computation with alpha: ", alpha)
        time1 = time.time()
        if os.path.exists(filename) and not force_recompute:
            # reuse the pre-stored barycenters
            centers, ps, embeddings = k_barycenter_load(filename)
        else:
            # (re-)compute the k-barycenters over both sets
            centers, ps, embeddings = k_barycenter(graph_format_set1 + graph_format_set2, center_graph_size,
                                                    dim_embedding, ot_method, gamma, gwb_layers,
                                                    ot_layers, n_iters, filename, alpha, device)

        time2 = time.time()
        print("k barycenter time: ", time2 - time1)
        centers = [torch.tensor(center).type(torch.FloatTensor) for center in centers]
        ps = [torch.tensor(p).type(torch.FloatTensor) for p in ps]
        if embeddings[0] is not None:
            embeddings = [torch.tensor(embedding).type(torch.FloatTensor) for embedding in embeddings]

        G2Fs_train = fusedGW_featurize(graph_format_set1, centers, ps, embeddings, alpha, device)
        G2Fs_val = fusedGW_featurize(graph_format_set2, centers, ps, embeddings, alpha, device)
        time3 = time.time()
        print("featurize time: ", time3 - time2)

        # the total number of centers times (GW feature dimension + W feature dimension)
        emb_size = sum(center_graph_size) * (G2Fs_train[0][0][0].shape[1] + G2Fs_train[0][1][0].shape[1])
        emb_matrix_train = _stack_fgw_embeddings(G2Fs_train, alpha, emb_size)
        emb_matrix_val = _stack_fgw_embeddings(G2Fs_val, alpha, emb_size)

        time4 = time.time()
        print("embedding computation time: ", time4 - time3)

        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated for all pairs at once
        A_squared = torch.sum(emb_matrix_train**2, dim=1, keepdim=True)  # Shape: (n_train, 1)
        B_squared = torch.sum(emb_matrix_val**2, dim=1, keepdim=True).T  # Shape: (1, n_val)
        dot_product = torch.matmul(emb_matrix_train, emb_matrix_val.T)  # Shape: (n_train, n_val)
        # Floating-point cancellation can push near-zero entries slightly
        # below zero; a squared distance must stay non-negative.
        output[alpha] = torch.clamp(A_squared + B_squared - 2 * dot_product, min=0)
        emb_output[alpha] = (emb_matrix_train, emb_matrix_val)
        time5 = time.time()
        print("distance matrix for-loop time: ", time5 - time4)
    return output, emb_output



def compute_pairwise_label_distance(pairwise_distance_dict, label_dict_train, label_dict_val):
    """Compute, per alpha, an OT distance between every (train label, val label) pair.

    Args:
        pairwise_distance_dict: maps alpha to a 2-D cost matrix indexed as
            ``[train position, val position]``.
        label_dict_train: maps a label to the list of train positions
            carrying it.
        label_dict_val: maps a label to the list of val positions carrying it.

    Returns:
        ``{alpha: {(train_label, val_label): distance}}``.  The distance is the
        ``ot.emd2`` value between uniform distributions over the two groups,
        using the matching sub-block of the cost matrix.  Pairs in which either
        group is empty are left out.
    """
    label_distances = {}
    for alpha, cost_matrix in pairwise_distance_dict.items():
        distances_for_alpha = {}
        for train_label, train_members in label_dict_train.items():
            row_idx = torch.tensor(train_members)
            for val_label, val_members in label_dict_val.items():
                col_idx = torch.tensor(val_members)

                # nothing to compare when one of the groups is empty
                if len(row_idx) == 0 or len(col_idx) == 0:
                    continue

                # sub-block: rows = this train label, columns = this val label
                M = cost_matrix[row_idx[:, None], col_idx[None, :]]

                # uniform distributions over both groups
                n_rows, n_cols = M.shape[0], M.shape[1]
                p = torch.ones(n_rows) / n_rows
                q = torch.ones(n_cols) / n_cols

                W0, _ = ot.emd2(p, q, M, log=True)
                distances_for_alpha[(train_label, val_label)] = W0.item()

        label_distances[alpha] = distances_for_alpha

    return label_distances


def compute_label_aware_embedding(original_emb_dict, pairwise_label_distance_dict, label_dict_train):
    """Append an MDS-derived label embedding to every graph embedding.

    Args:
        original_emb_dict: maps alpha to a 2-D tensor with one embedding row
            per graph.
        pairwise_label_distance_dict: maps alpha to a dict keyed by
            ``(label_i, label_j)``; labels are looked up as the integers
            ``0 .. len(label_dict_train) - 1``.
        label_dict_train: maps a label to the list of rows carrying it.

    Returns:
        ``{alpha: tensor}`` where each tensor is the original embedding with
        the label embedding of each row's label concatenated along dim 1.
        Rows that appear under no label receive zeros.
    """
    n_labels = len(label_dict_train.keys())
    embedder = MDS(n_components=n_labels, dissimilarity='precomputed', normalized_stress='auto')

    result = {}
    for alpha, graph_emb in original_emb_dict.items():
        # label-to-label dissimilarities; the diagonal stays zero
        # NOTE(review): (i, j) and (j, i) are filled from separate entries, so
        # this matrix is not guaranteed symmetric -- confirm MDS accepts it.
        dissimilarity = torch.zeros(n_labels, n_labels)
        off_diagonal = ((a, b) for a in range(n_labels) for b in range(n_labels) if a != b)
        for a, b in off_diagonal:
            dissimilarity[a][b] = pairwise_label_distance_dict[alpha][(a, b)]

        label_vectors = torch.from_numpy(embedder.fit_transform(dissimilarity)).float()

        # scatter each label's vector onto the rows that carry this label
        per_graph = torch.zeros(graph_emb.shape[0], label_vectors.shape[1])
        for label, members in label_dict_train.items():
            per_graph[torch.tensor(members)] = label_vectors[label]

        result[alpha] = torch.cat([graph_emb, per_graph], dim=1)

    return result


def aug_emb(pairwise_label_distance_dict):
    """Embed the labels in 2-D with MDS, separately for every alpha.

    Args:
        pairwise_label_distance_dict: maps alpha to a dict keyed by
            ``(i, j)`` label pairs; ``i`` and ``j`` are used directly as
            row / column positions of the distance matrix.

    Returns:
        ``{alpha: array}`` holding the ``fit_transform`` output of a
        2-component MDS (``random_state=42``) on the distance matrix.
    """
    embeddings_by_alpha = {}
    for alpha, pair_distances in pairwise_label_distance_dict.items():
        # the number of distinct labels fixes the matrix size
        labels_seen = sorted({label for pair in pair_distances for label in pair})
        side = len(labels_seen)

        # write each value to both (i, j) and (j, i); when both orientations
        # are present in the dict, the later entry overwrites the earlier one
        dissimilarity = np.zeros((side, side))
        for (row, col), value in pair_distances.items():
            dissimilarity[row, col] = value
            dissimilarity[col, row] = value

        embedder = MDS(n_components=2, random_state=42, dissimilarity='precomputed')
        embeddings_by_alpha[alpha] = embedder.fit_transform(dissimilarity)
    return embeddings_by_alpha

def compute_new_train_node_weight_ver1(pairwise_distances_dict, pairwise_label_distance_dict, 
                                  label_signal, train_index, val_index,
                                  label_list, learning_rate=1e-4, update_step=40, sparsity_ratio=(0.5, 0.5)):
    """Select training graphs by iteratively re-weighting and sparsifying them.

    For every alpha the graph-to-graph cost matrix is augmented with
    ``label_signal`` times the distance between the labels of each
    (train, val) pair.  Starting from uniform train weights, each step solves
    ``ot.emd2`` against a uniform validation distribution, moves the weights
    against ``logD['u']``, zeroes the smallest weights and renormalises.

    Args:
        pairwise_distances_dict: maps alpha to a ``(n_train, n_val)`` cost
            tensor.
        pairwise_label_distance_dict: maps alpha to a dict keyed by
            ``(train_label, val_label)``.
        label_signal: multiplier applied to the label distance.
        train_index: graph ids of the training split, in cost-matrix row order.
        val_index: graph ids of the validation split, in column order.
        label_list: label of every graph, indexed by graph id.
        learning_rate: upper bound on the step size.
        update_step: number of re-weighting steps.
        sparsity_ratio: ``(start, end)`` fraction of training graphs kept;
            interpolated linearly over the steps.  With a single step the
            ``end`` fraction is used.

    Returns:
        ``{alpha: tensor}`` with the positions of the training graphs whose
        final weight is non-zero.
    """
    new_train_node_weight = {}

    for alpha in pairwise_distances_dict.keys():

        train_graph_num, val_graph_num = pairwise_distances_dict[alpha].shape
        # initial train node weight is uniform distribution
        train_node_weight = torch.ones(train_graph_num) / train_graph_num

        # work on a copy so the caller's cost matrix is not modified
        label_corrected_pairwise_distance = copy.deepcopy(pairwise_distances_dict[alpha])

        train_labels = np.array([label_list[train_index[i]] for i in range(train_graph_num)])
        val_labels = np.array([label_list[val_index[j]] for j in range(val_graph_num)])
        distance_matrix = np.vectorize(
            lambda x, y: label_signal * pairwise_label_distance_dict[alpha][(x, y)]
        )(train_labels[:, None], val_labels[None, :])

        label_corrected_pairwise_distance += distance_matrix

        # q stays fixed to the uniform distribution for every step
        q = torch.ones(val_graph_num) / val_graph_num

        for step in range(update_step):
            # p is the current train node weight
            _, logD = ot.emd2(train_node_weight, q, label_corrected_pairwise_distance, log=True)
            dual = logD['u']

            # Linear schedule from sparsity_ratio[0] to sparsity_ratio[1].
            # A single step has nothing to interpolate (and would divide by
            # zero), so it goes straight to the final ratio.
            if update_step > 1:
                keep_ratio = sparsity_ratio[0] + (sparsity_ratio[1] - sparsity_ratio[0]) * step / (update_step - 1)
            else:
                keep_ratio = sparsity_ratio[1]
            sparsity_num = int(keep_ratio * train_graph_num)

            # Largest step that keeps every updated weight inside [0, 1]:
            # weight / u where u > 0, (weight - 1) / u where u < 0; entries
            # with u == 0 impose no bound.  Capped by learning_rate.
            _learning_rate = torch.where(dual > 0, train_node_weight / dual, 
                                         torch.where(dual < 0, (train_node_weight - 1.) / dual, 
                                         torch.tensor(learning_rate, dtype = train_node_weight.dtype, 
                                                     device = train_node_weight.device))).min()
            _learning_rate = min(_learning_rate.item(), learning_rate)
            train_node_weight = train_node_weight - _learning_rate * dual
            # shift back to non-negative values if rounding went below zero
            if train_node_weight.min() < 0:
                train_node_weight -= train_node_weight.min()
            # zero the (train_graph_num - sparsity_num) smallest weights, then renormalise
            train_node_weight[torch.topk(-train_node_weight, k = train_graph_num - sparsity_num).indices] = 0.
            train_node_weight = train_node_weight / train_node_weight.sum()

        new_train_node_weight[alpha] = train_node_weight.nonzero().squeeze(1)

    return new_train_node_weight



def compute_new_train_node_weight_lava(pairwise_distances_dict, pairwise_label_distance_dict, 
                                  label_signal, train_index, val_index,
                                  label_list, selection_ratio):
    """Select the training graphs with the smallest calibrated OT gradient.

    For every alpha the graph-to-graph cost matrix is augmented with
    ``label_signal`` times the distance between the labels of each
    (train, val) pair.  ``ot.emd2`` is solved once between uniform train and
    validation distributions, and ``logD['u']`` is calibrated as
    ``n / (n - 1) * u - sum(u) / (n - 1)`` with ``n`` the number of training
    graphs.

    Args:
        pairwise_distances_dict: maps alpha to a ``(n_train, n_val)`` cost
            tensor.
        pairwise_label_distance_dict: maps alpha to a dict keyed by
            ``(train_label, val_label)``.
        label_signal: multiplier applied to the label distance.
        train_index: graph ids of the training split, in cost-matrix row order.
        val_index: graph ids of the validation split, in column order.
        label_list: label of every graph, indexed by graph id.
        selection_ratio: fraction of training graphs to keep.

    Returns:
        ``{alpha: tensor}`` with the positions of the
        ``int(selection_ratio * n_train)`` training graphs that have the
        smallest calibrated gradient.
    """
    new_train_node_weight = {}

    for alpha in pairwise_distances_dict.keys():

        train_graph_num, val_graph_num = pairwise_distances_dict[alpha].shape

        # work on a copy so the caller's cost matrix is not modified
        label_corrected_pairwise_distance = copy.deepcopy(pairwise_distances_dict[alpha])

        train_labels = np.array([label_list[train_index[i]] for i in range(train_graph_num)])
        val_labels = np.array([label_list[val_index[j]] for j in range(val_graph_num)])
        distance_matrix = np.vectorize(
            lambda x, y: label_signal * pairwise_label_distance_dict[alpha][(x, y)]
        )(train_labels[:, None], val_labels[None, :])

        label_corrected_pairwise_distance += distance_matrix

        # both marginals are uniform
        p = torch.ones(train_graph_num) / train_graph_num
        q = torch.ones(val_graph_num) / val_graph_num

        _, logD = ot.emd2(p, q, label_corrected_pairwise_distance, log=True)
        dual = logD['u']

        if train_graph_num > 1:
            calibrate_grad = (train_graph_num * 1.0 / (train_graph_num - 1)) * dual - \
                (1.0 / (train_graph_num - 1)) * torch.sum(dual)
        else:
            # The calibration divides by (n - 1); with a single training graph
            # there is nothing to calibrate against, so use the raw value.
            calibrate_grad = dual

        k = int(selection_ratio * len(calibrate_grad))
        # smallest calibrated gradients first
        _, topk_indices = torch.topk(calibrate_grad, k, largest=False)

        new_train_node_weight[alpha] = topk_indices

    return new_train_node_weight




def our_selection_method_wrap(args, dataset, 
                              dataset_name, label_dict, 
                              label_list, train_idx, val_idx):
    """Run the configured selection method for every label signal and selection ratio.

    Steps: build the FGW input for the train / val graphs, load the pairwise
    distances and embeddings from ``pairwise_distance_store/`` (or compute and
    store them), derive the label-to-label distances, then call the selection
    routine chosen by ``args.our_method_type``.

    Args:
        args: namespace providing ``dataset_name``, ``linearfgw_device``,
            ``domain_shift_type``, ``domain_shift_order``, ``val_test_setting``,
            ``barycenter_sizes``, ``alphas``, ``our_method_type``,
            ``label_signals``, ``selection_ratios`` and ``update_steps``.
        dataset: graphs, indexable by graph id.
        dataset_name: name used in the cache and barycenter file names.
        label_dict: maps a label to the graph ids carrying it.
        label_list: label of every graph, indexed by graph id.
        train_idx: tensor of training graph ids.
        val_idx: tensor of validation graph ids.

    Returns:
        ``{label_signal: {selection_ratio: {alpha: tensor}}}`` holding the
        selected training positions.  The innermost dicts are only filled for
        the method types ``'ver1'`` and ``'lava'``.
    """
    # graph id -> position inside the train / val split
    train_id_idx_mapping = {train_idx[pos].item(): pos for pos in range(len(train_idx))}
    val_id_idx_mapping = {val_idx[pos].item(): pos for pos in range(len(val_idx))}

    # gather train, val graph information for FGW computation;
    # only the datasets listed here use normalize=True
    use_normalized_features = args.dataset_name in ['ENZYMES', 'PROTEINS', 'ogbg-molhiv', 'ogbg-molbace', 'ogbg-molbbbp']
    train_graph_format = prepare_graph_info(dataset, train_idx, args.linearfgw_device, normalize=use_normalized_features)
    val_graph_format = prepare_graph_info(dataset, val_idx, args.linearfgw_device, normalize=use_normalized_features)

    # pairwise distances / embeddings are cached on disk per experiment setting
    pairwise_loc = "pairwise_distance_store/"
    os.makedirs(pairwise_loc, exist_ok=True)
    storing_file_name = pairwise_loc + f"pairwise_info_{dataset_name}_{args.domain_shift_type}_{args.domain_shift_order}_{args.val_test_setting}.pt"
    if os.path.exists(storing_file_name):
        print("distance/emb exists...")
        pairwise_distances_dict, emb_matrix = torch.load(storing_file_name)
    else:
        print("distance/emb not exists...")
        pairwise_distances_dict, emb_matrix = compute_pairwise_distance_with_emb(
            train_graph_format,
            val_graph_format,
            dataset_name,
            args.barycenter_sizes,
            args.alphas,
            args.linearfgw_device,
        )
        torch.save((pairwise_distances_dict, emb_matrix), storing_file_name)

    # label -> positions (inside the split) of the graphs carrying that label
    label_dict_train = {
        label: [train_id_idx_mapping[graph_id]
                for graph_id in set(members).intersection(train_idx.tolist())]
        for label, members in label_dict.items()
    }
    label_dict_val = {
        label: [val_id_idx_mapping[graph_id]
                for graph_id in set(members).intersection(val_idx.tolist())]
        for label, members in label_dict.items()
    }

    # compute pairwise label-distance
    pairwise_label_distance_dict = compute_pairwise_label_distance(
        pairwise_distances_dict, label_dict_train, label_dict_val
    )

    if args.our_method_type == 'kmed':
        # graph embeddings extended by a scaled 2-D label embedding
        # NOTE(review): aug_emb_dict is built but not used further in this
        # function -- confirm whether it should be returned or consumed.
        label_embedding_dict = aug_emb(pairwise_label_distance_dict)
        train_label_list = torch.tensor(label_list)[train_idx]
        val_label_list = torch.tensor(label_list)[val_idx]
        aug_emb_dict = {}
        for alpha in args.alphas:
            per_alpha = aug_emb_dict[alpha] = {}
            for label_signal in args.label_signals:
                label_embeddings_train = label_embedding_dict[alpha][train_label_list]
                label_embeddings_val = label_embedding_dict[alpha][val_label_list]

                train_part, val_part = emb_matrix[alpha][0], emb_matrix[alpha][1]
                aug_train_emb = np.hstack([train_part, label_signal * label_embeddings_train])
                aug_val_emb = np.hstack([val_part, label_signal * label_embeddings_val])

                per_alpha[label_signal] = (aug_train_emb, aug_val_emb)

    # store the new train weight (index)
    new_train_weight_dict = {}

    for label_signal in args.label_signals:
        per_signal = new_train_weight_dict[label_signal] = {}

        for selection_ratio in args.selection_ratios:
            per_ratio = per_signal[selection_ratio] = {}
            print(f"currently select: label signal {label_signal}, {selection_ratio}")

            if args.our_method_type == 'ver1':
                new_train_weight_per_ratio = compute_new_train_node_weight_ver1(
                    pairwise_distances_dict,
                    pairwise_label_distance_dict,
                    label_signal,
                    train_idx,
                    val_idx,
                    label_list,
                    update_step=args.update_steps,
                    sparsity_ratio=(1.0, selection_ratio),
                )
            elif args.our_method_type == 'lava':
                new_train_weight_per_ratio = compute_new_train_node_weight_lava(
                    pairwise_distances_dict,
                    pairwise_label_distance_dict,
                    label_signal,
                    train_idx,
                    val_idx,
                    label_list,
                    selection_ratio,
                )
            else:
                # other method types record no selection here
                continue

            # one entry per alpha
            for alpha_value in new_train_weight_per_ratio:
                print(f"new weight computation with alpha: {alpha_value}, label signal: {label_signal}, select_ratio: {selection_ratio}")
                per_ratio[alpha_value] = torch.tensor(new_train_weight_per_ratio[alpha_value])

    return new_train_weight_dict

            



