import torch
import torch.nn.functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader
import networkx as nx
import pickle
import os, sys
sys.path.append(os.path.abspath(os.getcwd()))

from utils.graph_utils import node_flags, graphs_to_tensor, mask_x
from utils.node_features import NodeCycleFeatures, EigenFeatures


# -------- Create initial node features --------
def init_features(feat_dict, adjs, nfeat=10):
    """Build the initial node-feature tensor for a batch of adjacency matrices.

    One feature block is produced per entry of ``feat_dict.type`` and the
    blocks are concatenated along the last dimension, in order.

    Supported feature types:
        * ``'deg'``    -- one-hot node degree with ``nfeat`` classes.
        * ``'3cycle'`` -- first output of ``NodeCycleFeatures(10)``.
        * ``'4cycle'`` -- second output of ``NodeCycleFeatures(10)``.
        * ``'eig<k>'`` -- last column of ``EigenFeatures(k)``, multiplied by
          ``feat_dict.scale``.
        * ``'ones'``   -- constant all-ones feature of width 2.

    Args:
        feat_dict: config object exposing ``type`` (iterable of feature-type
            strings) and, for eigen features, ``scale``.
        adjs: batched adjacency tensor; indexed as (batch, node, node) here
            -- TODO confirm against ``graphs_to_tensor``.
        nfeat: number of classes for the one-hot degree encoding.

    Returns:
        A pair ``(features, feat_dim)``: the concatenated features after
        ``mask_x(features, flags)``, and a list with the width of each
        feature block.

    Raises:
        NotImplementedError: if a feature type is not recognised.
    """
    flags = node_flags(adjs)
    # Cycle features are computed at most once and shared between '3cycle'
    # and '4cycle' instead of being recomputed for each of them.
    cycle_feats = None
    feature = []
    feat_dim = []
    for feat_type in feat_dict.type:
        if feat_type == 'deg':
            deg = adjs.sum(dim=-1).to(torch.long)
            # F.one_hot requires every degree to be < nfeat.
            feat = F.one_hot(deg, num_classes=nfeat).to(torch.float32)
        elif feat_type in ('3cycle', '4cycle'):
            if cycle_feats is None:
                cycle_feats = NodeCycleFeatures(10)(adjs, flags)
            feat = cycle_feats[0 if feat_type == '3cycle' else 1]
        elif 'eig' in feat_type:
            idx = int(feat_type.split('eig')[-1])
            eigvec = EigenFeatures(idx)(adjs, flags)
            # Keep only the last returned column, as a width-1 feature.
            feat = eigvec[..., -1:] * feat_dict.scale
        elif feat_type == 'ones':
            feat = torch.ones(
                (adjs.size(0), adjs.size(1), 2),
                dtype=torch.float32,
                device=adjs.device,
            )
        else:
            raise NotImplementedError(f'Feature: {feat_type} not implemented.')
        feature.append(feat)
        feat_dim.append(feat.shape[-1])
    feature = torch.cat(feature, dim=-1)

    return mask_x(feature, flags), feat_dim


def feat_diff(x, adjs, flags, feat_dict):
    """Measure how far node features ``x`` are from those implied by ``adjs``.

    ``x`` is split along its last dimension into consecutive blocks whose
    widths are given by ``feat_dict.dim``; block ``i`` is compared with the
    feature of type ``feat_dict.type[i]`` recomputed from ``adjs``.  For each
    block the per-graph discrepancy is divided by ``flags.sum(-1)``, averaged
    over the batch and rounded to two decimals.

    Args:
        x: node features, sliced here as ``x[:, :, start:stop]``.
        adjs: batched adjacency tensor (same batch/node layout as ``x``).
        flags: node mask; ``flags.sum(-1)`` is used as the per-graph
            normaliser -- presumably the number of real nodes, TODO confirm.
        feat_dict: config object exposing ``type``, ``dim`` and (for eigen
            features) ``scale``.

    Returns:
        A list with one float per feature type.  If computing the eigen
        features raises, a list of ``-1`` of the same length is returned
        instead.

    Raises:
        NotImplementedError: if a feature type is not recognised.
    """
    diffs = []
    eig_orders = [
        int(feat_type.split('eig')[-1])
        for feat_type in feat_dict.type
        if 'eig' in feat_type
    ]
    eigvec = None
    if eig_orders:
        try:
            eigvec = EigenFeatures(max(eig_orders))(adjs, flags)
        except Exception:
            # Best effort: the eigen computation may fail on some inputs;
            # signal this with -1 for every feature.  Unlike a bare
            # ``except:``, KeyboardInterrupt/SystemExit still propagate.
            return [-1] * len(feat_dict.type)

    n_nodes = flags.sum(-1)
    sdim = 0
    for feat_type, feat_dim in zip(feat_dict.type, feat_dict.dim):
        x_ = x[:, :, sdim:sdim + feat_dim]
        if feat_type == 'ones':
            feat = torch.ones(
                (adjs.size(0), adjs.size(1), 2),
                dtype=torch.float32,
                device=adjs.device,
            )
            # Reduce over both the feature and the node axes so the result is
            # one value per graph, like the other branches.  Summing over the
            # feature axis only leaves a (batch, node) tensor that cannot be
            # divided by the (batch,) normaliser in general.
            # NOTE(review): the reference is all-ones on every node slot; if
            # x is zeroed on padded nodes those slots count as error --
            # confirm whether the reference should be masked with flags.
            fdiff = (x_ - feat).abs().sum(dim=(-2, -1)) / n_nodes
        elif 'eig' in feat_type:
            idx = int(feat_type.split('eig')[-1])
            x_col = (x_ / feat_dict.scale).squeeze(-1)
            ref = eigvec[..., idx - 1]
            # Eigenvectors are defined up to sign: align the reference with x
            # using the sign of the first node's entry in each graph.
            x_pm = x_col[:, 0] / x_col[:, 0].abs()
            eig_pm = ref[:, 0] / ref[:, 0].abs()
            pm = x_pm * eig_pm
            fdiff = (x_col - ref * pm[:, None]).abs().square().sum(-1) / n_nodes
        else:
            if feat_type == 'deg':
                feat = adjs.sum(dim=-1).to(torch.long)
            elif feat_type == '3cycle':
                feat = NodeCycleFeatures(10)(adjs, flags, is_onehot=False)[0]
            elif feat_type == '4cycle':
                feat = NodeCycleFeatures(10)(adjs, flags, is_onehot=False)[1]
            else:
                raise NotImplementedError(f'Feature: {feat_type} not implemented.')
            # x_ holds one-hot style scores; decode them to class indices.
            x_feat = torch.argmax(x_, dim=-1)
            fdiff = (x_feat - feat).abs().sum(-1) / n_nodes
        diffs.append(round(fdiff.mean().item(), 2))
        sdim += feat_dim
    return diffs


def graphs_to_dataloader(config, graph_list, return_feat_dim=False):
    """Wrap a list of graphs in a shuffling ``DataLoader`` of (x, adj) pairs.

    The graphs are converted to a batched adjacency tensor with
    ``graphs_to_tensor`` (using ``config.data.max_node_num``) and their node
    features are built with ``init_features`` from ``config.data.feat`` and
    ``config.data.max_feat_num``.

    Args:
        config: configuration object; the ``data`` sub-config supplies
            ``max_node_num``, ``feat``, ``max_feat_num`` and ``batch_size``.
        graph_list: graphs to load.
        return_feat_dim: when true, also return the per-feature widths
            reported by ``init_features``.

    Returns:
        The ``DataLoader``, or ``(DataLoader, feat_dim)`` when
        ``return_feat_dim`` is true.
    """
    data_cfg = config.data
    adjs_tensor = graphs_to_tensor(graph_list, data_cfg.max_node_num)
    x_tensor, feat_dim = init_features(
        data_cfg.feat, adjs_tensor, data_cfg.max_feat_num
    )

    loader = DataLoader(
        TensorDataset(x_tensor, adjs_tensor),
        batch_size=data_cfg.batch_size,
        shuffle=True,
    )
    return (loader, feat_dim) if return_feat_dim else loader


def dataloader(config, get_graph_list=False):
    """Load the pickled train/val/test graph splits and wrap them in loaders.

    The splits are read from ``{config.data.dir}/{config.data.data}.pkl`` and
    their sizes are printed.

    Args:
        config: configuration object; ``config.data.dir`` and
            ``config.data.data`` locate the pickle file.
        get_graph_list: when true, return the raw graph splits instead of
            building data loaders.

    Returns:
        ``(train_graphs, val_graphs, test_graphs)`` when ``get_graph_list`` is
        true; otherwise ``(train_loader, val_loader, test_loader, feat_dim)``
        where ``feat_dim`` comes from the training split.
    """
    pkl_path = f'{config.data.dir}/{config.data.data}.pkl'
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # dataset files that come from a trusted source.
    with open(pkl_path, 'rb') as f:
        splits = pickle.load(f)
    train_graphs, val_graphs, test_graphs = splits
    print(f'Dataset sizes: train {len(train_graphs)}, val {len(val_graphs)}, test {len(test_graphs)}')

    if not get_graph_list:
        train_loader, feat_dim = graphs_to_dataloader(config, train_graphs, True)
        val_loader = graphs_to_dataloader(config, val_graphs)
        test_loader = graphs_to_dataloader(config, test_graphs)
        return train_loader, val_loader, test_loader, feat_dim
    return train_graphs, val_graphs, test_graphs


def preprocess(data_dir='data', dataset='sbm', measure_train_mmd=False):
    """Split a raw dataset file into train/val/test graphs and pickle them.

    The raw ``.pt`` file for ``dataset`` is loaded from ``data_dir``; 20% of
    the graphs go to the test split and the remainder is divided 80/20 into
    train and validation, using a fixed random seed.  The three graph lists
    are written to ``{data_dir}/{dataset}.pkl``.

    Args:
        data_dir: directory holding the raw ``.pt`` file and receiving the
            pickle.
        dataset: one of ``'sbm'``, ``'planar'`` or ``'proteins'``.
        measure_train_mmd: when true, also print statistics comparing the
            training graphs against the test graphs.

    Returns:
        None.  If the raw dataset file does not exist, a notice is printed
        and nothing is written.

    Raises:
        NotImplementedError: if ``dataset`` is not a known dataset name.
    """
    dataset_files = {
        'sbm': 'sbm_200.pt',
        'planar': 'planar_64_200.pt',
        'proteins': 'proteins_100_500.pt',
    }
    if dataset not in dataset_files:
        raise NotImplementedError(f'Dataset {dataset} not implemented.')
    filename = f'{data_dir}/{dataset_files[dataset]}'

    if not os.path.isfile(filename):
        # Previously a silent no-op; make the skipped work visible.
        print(f'Dataset file {filename} not found; nothing to preprocess.')
        return

    # Only the adjacency matrices are used; the tuple is still unpacked in
    # full so that an unexpected file layout fails loudly.
    adjs, eigvals, eigvecs, n_nodes, max_eigval, min_eigval, same_sample, n_max = torch.load(filename)
    print(f'Dataset {filename} loaded from file')
    test_len = int(round(len(adjs) * 0.2))
    train_len = int(round((len(adjs) - test_len) * 0.8))
    val_len = len(adjs) - train_len - test_len

    # Fixed seed so the split is reproducible across runs.
    train_set, val_set, test_set = random_split(
        adjs,
        [train_len, val_len, test_len],
        generator=torch.Generator().manual_seed(1234),
    )
    train_graphs = tensor_to_graphs(train_set)
    val_graphs = tensor_to_graphs(val_set)
    test_graphs = tensor_to_graphs(test_set)

    with open(f'{data_dir}/{dataset}.pkl', 'wb') as f:
        pickle.dump(obj=(train_graphs, val_graphs, test_graphs), file=f, protocol=pickle.HIGHEST_PROTOCOL)

    if measure_train_mmd:
        # Imported lazily: only needed for this optional report.
        from evaluation.stats import degree_stats, orbit_stats_all, clustering_stats, spectral_stats, \
                                    eval_sbm, eval_planar, connected_stats
        kernel = 'tv'
        train_mmd_degree = degree_stats(test_graphs, train_graphs, kernel)
        train_mmd_4orbits = orbit_stats_all(test_graphs, train_graphs, kernel)
        train_mmd_clustering = clustering_stats(test_graphs, train_graphs, kernel)
        train_mmd_spectral = spectral_stats(test_graphs, train_graphs, kernel)
        train_conn = connected_stats(test_graphs, train_graphs)
        # The statistics above compare the training graphs with the TEST
        # split, so the label says so.
        print('TV measures of Training set vs Test set: ')
        print(f'Deg.: {train_mmd_degree:.4f}, Clus.: {train_mmd_clustering:.4f} '
                f'Orbits: {train_mmd_4orbits:.4f}, Spec.: {train_mmd_spectral:.4f}, Conn.:{train_conn}')
        if dataset == 'sbm' or dataset == 'planar':
            val_fn = eval_sbm if dataset == 'sbm' else eval_planar
            train_uniq, train_uniq_non_iso, train_eval = val_fn(test_graphs, train_graphs)
            print(f'V.U.N.: {train_eval} Uniq.: {train_uniq} U.N.: {train_uniq_non_iso}')

def tensor_to_graphs(adjs):
    """Convert an iterable of adjacency tensors into networkx graphs.

    Each element is moved to the CPU, detached and converted to a NumPy
    array before being handed to networkx.

    Args:
        adjs: iterable of adjacency matrices supporting
            ``.cpu().detach().numpy()`` (e.g. torch tensors).

    Returns:
        A list with one ``networkx`` graph per adjacency matrix.
    """
    # nx.from_numpy_matrix was removed in networkx 3.0; from_numpy_array is
    # its replacement and is available in networkx >= 2.0.
    return [nx.from_numpy_array(adj.cpu().detach().numpy()) for adj in adjs]


if __name__ == '__main__':
    import argparse

    # Script entry point: preprocess the dataset chosen with --data and,
    # when --mmd is given, also print the train-split statistics.
    cli = argparse.ArgumentParser()
    cli.add_argument('--data', type=str, default='sbm')
    cli.add_argument('--mmd', action='store_true')
    # parse_known_args is used so unrecognised command-line flags are ignored.
    known_args, _extra = cli.parse_known_args()

    preprocess(dataset=known_args.data, measure_train_mmd=known_args.mmd)