import os
import datetime

import numpy as np
import scipy.sparse as sp

import torch
import torch_geometric
from ogb.nodeproppred import PygNodePropPredDataset
from ogb.linkproppred import PygLinkPropPredDataset
from data_management import DataStructure
from model_transformer import TransformerNodeClassifier, TransformerLinkPredictor

# %%

def create_downstream_model(task, configs, num_hops, num_features, num_classes, device, state_dict=None):
    """Build the downstream transformer model that matches ``task``.

    Args:
        task: ``'node_classification'`` or ``'link_prediction'``.
        configs: Mapping read for its ``'stems'`` and ``'backbone'`` entries.
        num_hops: Forwarded to the model constructor.
        num_features: Forwarded to the model constructor.
        num_classes: Forwarded to the node classifier only; the link
            predictor does not receive it.
        device: Forwarded to the model constructor.
        state_dict: Optional value forwarded to the model constructor.

    Returns:
        A ``TransformerNodeClassifier`` or ``TransformerLinkPredictor``.

    Raises:
        NotImplementedError: If ``task`` is neither supported task name.
    """
    # Reject unknown tasks up front, before anything is read from configs.
    if task != 'node_classification' and task != 'link_prediction':
        raise NotImplementedError(f'Task type \'{task}\' is not implemented.')

    # Arguments common to both model constructors.
    shared_kwargs = dict(
        stems=configs['stems'],
        backbone=configs['backbone'],
        num_hops=num_hops,
        num_features=num_features,
        device=device,
        state_dict=state_dict,
    )

    if task == 'node_classification':
        return TransformerNodeClassifier(num_classes=num_classes, **shared_kwargs)
    return TransformerLinkPredictor(**shared_kwargs)


def _load_heterophilous_dataset(name):
    """Load ``./data/temp/heterophilous/<name>.npz`` and wrap it in ``Dummydatasets``."""
    t_npz = np.load(f'./data/temp/heterophilous/{name}.npz')
    graph = torch_geometric.data.Data(
        x=torch.as_tensor(t_npz['node_features']),
        # The stored edge array is transposed before use and then symmetrised.
        edge_index=torch_geometric.utils.to_undirected(torch.as_tensor(t_npz['edges'].T)),
        y=torch.as_tensor(t_npz['node_labels']))
    return Dummydatasets(graph, name)


def _load_amazon_dataset(dataset):
    """Return ``AmazonProducts`` when the name mentions 'products', else ``Amazon``."""
    if 'products' in dataset:
        return torch_geometric.datasets.AmazonProducts(root='./data/temp/AmazonProducts')
    # dataset[7:] drops a 7-character prefix (presumably 'amazon-' -- confirm naming).
    return torch_geometric.datasets.Amazon(root='./data/temp/Amazon', name=dataset[7:])


def _load_ssl_dataset(dataset):
    """Instantiate one raw dataset by name for the self-supervised setting."""
    if dataset in ['cora', 'cora_ml', 'citeseer', 'dblp', 'pubmed']:
        return torch_geometric.datasets.CitationFull(root='./data/temp/citation_full', name=dataset)
    if 'planetoid' in dataset:
        # dataset[:-10] drops the last 10 characters (the length of '-planetoid').
        return torch_geometric.datasets.Planetoid(root='./data/temp/planetoid', name=dataset[:-10])
    if 'ogbn' in dataset:
        return PygNodePropPredDataset(name=dataset)
    if 'ogbl' in dataset:
        return PygLinkPropPredDataset(name=dataset)
    if dataset == 'ppi':
        return torch_geometric.datasets.PPI(root='./data/temp/PPI')
    if 'amazon' in dataset and 'hetphl' not in dataset:
        return _load_amazon_dataset(dataset)
    if 'hetphl' in dataset:
        # dataset[7:] drops a 7-character prefix (presumably 'hetphl-' -- confirm naming).
        return _load_heterophilous_dataset(dataset[7:])
    raise NotImplementedError(f'Dataset {dataset} not implemented')


def create_data_structure(datasets, ssl, task=None, model_type='transformer'):
    """Load the requested dataset(s) and wrap them in a ``DataStructure``.

    Args:
        datasets: In SSL mode, a comma-separated string of dataset names
            (a single name is accepted) or an iterable of names. In
            downstream mode, a single dataset name.
        ssl: If truthy, build one ``DataStructure`` over all named datasets.
            Otherwise build it for one dataset with ``ssl=False`` and return
            its first entry.
        task: Forwarded to ``DataStructure`` in downstream mode only.
        model_type: ``'transformer'`` sets ``neighbourhood_aggregation`` to 2;
            any other value sets it to ``None``.

    Returns:
        The ``DataStructure`` itself in SSL mode, otherwise
        ``DataStructure(...).datasets[0]``.

    Raises:
        NotImplementedError: If a dataset name is not recognised.
    """
    neighbourhood_aggregation = 2 if model_type == 'transformer' else None

    if ssl:
        # Always split strings: without this a single name such as 'cora'
        # would be iterated character by character and rejected.
        if isinstance(datasets, str):
            datasets = [name.strip() for name in datasets.split(',')]
        loaded = [_load_ssl_dataset(dataset) for dataset in datasets]
        return DataStructure(loaded, neighbourhood_aggregation=neighbourhood_aggregation)

    squeeze_labels = False
    if datasets in ['cora', 'cora_ml', 'citeseer', 'dblp', 'pubmed']:
        raw = torch_geometric.datasets.CitationFull(root='./data/temp/citation_full', name=datasets)
    elif datasets in ['cora-planetoid', 'citeseer-planetoid', 'pubmed-planetoid']:
        raw = torch_geometric.datasets.Planetoid(root='./data/temp/planetoid', name=datasets[:-10])
    elif 'ogbn' in datasets:
        raw = PygNodePropPredDataset(name=datasets)
        # Only the OGB node datasets get their label tensor squeezed below.
        squeeze_labels = True
    elif 'amazon' in datasets and 'hetphl' not in datasets:
        # The original tested the undefined name `dataset` here (NameError).
        raw = _load_amazon_dataset(datasets)
    elif 'hetphl' in datasets:
        raw = _load_heterophilous_dataset(datasets[7:])
    else:
        # Previously fell through to `return data` with `data` unbound.
        raise NotImplementedError(f'Dataset {datasets} not implemented')

    data = DataStructure([raw], task=task,
                         neighbourhood_aggregation=neighbourhood_aggregation, ssl=False)
    data = data.datasets[0]
    if squeeze_labels:
        data.data.y = data.data.y.squeeze()

    return data


def create_save_path(model_name):
    """Create a fresh experiment folder and return the paths used for saving.

    The folder is ``./models/<dd_mm_yy>/expt_<i>``, where ``<i>`` is the
    smallest non-negative integer whose folder could be newly created.

    Args:
        model_name: File name appended to the experiment folder to form the
            model file path. The file itself is not created.

    Returns:
        A dict with ``'folder'`` (the new experiment directory) and
        ``'file'`` (``folder`` joined with ``model_name``).
    """
    date_folder = datetime.datetime.now().strftime('%d_%m_%y')
    date_dir = os.path.join("./models/", date_folder)
    # Creates ./models/ and the dated folder in one step; fine if they exist.
    os.makedirs(date_dir, exist_ok=True)

    i = 0
    while True:
        expt_dir = os.path.join(date_dir, 'expt_' + str(i))
        try:
            # Claim the name by creating it, rather than checking first: no
            # gap between check and creation, and a stray non-directory entry
            # with the same name is skipped instead of crashing.
            os.mkdir(expt_dir)
        except FileExistsError:
            i += 1
        else:
            break

    return {
        'folder': expt_dir,
        'file': os.path.join(expt_dir, model_name),
    }


def normalize_features(mx):
    """Row-normalize sparse matrix.

    Each row is scaled by the reciprocal of its sum. Rows that sum to zero
    are left as all zeros.

    Args:
        mx: Matrix exposing ``sum(1)``, e.g. a scipy sparse matrix.

    Returns:
        ``diag(1 / row_sum) . mx``.
    """
    row_sum = np.array(mx.sum(1))
    # NumPy refuses integer arrays raised to a negative integer power, so
    # integer-valued input is promoted. Float input keeps its dtype.
    if not np.issubdtype(row_sum.dtype, np.inexact):
        row_sum = row_sum.astype(np.float64)
    # Zero rows give inf here by design; they are zeroed on the next line.
    with np.errstate(divide='ignore'):
        r_inv = np.power(row_sum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    return r_mat_inv.dot(mx)


def normalize_adj(mx):
    """Row-column-normalize sparse matrix.

    Computes ``D^-1/2 . mx . D^-1/2`` where ``D`` is the diagonal matrix of
    row sums. Entries of ``D^-1/2`` that come out infinite (zero row sum)
    are replaced by zero.
    """
    degree = np.array(mx.sum(1))
    inv_sqrt_degree = np.power(degree, -0.5).flatten()
    # Rows with zero sum produce inf; zero them out.
    infinite = np.isinf(inv_sqrt_degree)
    inv_sqrt_degree[infinite] = 0.
    scaling = sp.diags(inv_sqrt_degree)
    left_scaled = scaling.dot(mx)
    return left_scaled.dot(scaling)


def accuracy(output, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        output: Score tensor; the maximum is taken along dimension 1.
        labels: Target tensor compared element-wise with the predictions.

    Returns:
        A double-precision scalar tensor: number of matches / ``len(labels)``.
    """
    _, predicted = output.max(1)
    predicted = predicted.type_as(labels)
    num_correct = predicted.eq(labels).double().sum()
    return num_correct / len(labels)


def accuracy_batch(output, labels):
    """Return the number of rows whose highest-scoring column equals the label.

    Unlike ``accuracy`` this returns the raw count (as a double-precision
    scalar tensor) and does not divide by the number of labels.
    """
    _, predicted = output.max(1)
    matches = predicted.type_as(labels).eq(labels)
    return matches.double().sum()


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor.

    Args:
        sparse_mx: A scipy sparse matrix; it is converted to COO format and
            its values are cast to ``float32``.

    Returns:
        A torch sparse COO tensor with ``int64`` indices, ``float32`` values
        and the same shape as ``sparse_mx``.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    # Stack row and column ids into the (2, nnz) index layout torch expects.
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor(...) is a deprecated legacy constructor;
    # torch.sparse_coo_tensor is its documented replacement.
    return torch.sparse_coo_tensor(indices, values, shape)


def torch_sparse_tensor_to_sparse_mx(torch_sparse):
    """Convert a torch sparse tensor to a scipy sparse matrix.

    The first two index rows of the tensor become the row and column
    coordinates of a ``scipy.sparse.coo_matrix`` with the tensor's first two
    dimensions as its shape.
    """
    coords = torch_sparse._indices().numpy()
    entries = torch_sparse._values().numpy()
    dims = torch_sparse.size()
    return sp.coo_matrix(
        (entries, (coords[0], coords[1])),
        shape=(dims[0], dims[1]),
    )


def re_features(adj, features, K):
    """Stack ``features`` with its ``K`` successive propagations through ``adj[0]``.

    Slot 0 along the hop axis holds ``features``; slot ``k`` holds
    ``adj[0]`` applied ``k`` times to ``features`` via ``torch.matmul``.

    Args:
        adj: Indexable object whose element 0 is the matrix multiplied into
            the features (presumably a normalised adjacency of shape
            ``(N, N)`` -- confirm against the caller).
        features: 2-D tensor with one row per node.
        K: Number of propagation steps.

    Returns:
        Tensor built as ``(N, 1, K + 1, F)`` and then passed through
        ``squeeze()``, i.e. ``(N, K + 1, F)`` when no other dimension has
        size 1.
    """
    num_nodes, num_feats = features.shape[0], features.shape[1]
    nodes_features = torch.empty(num_nodes, 1, K + 1, num_feats)

    # Hop 0 is the raw features; one slice assignment replaces the per-node loop.
    nodes_features[:, 0, 0, :] = features

    # matmul returns a new tensor each step, so no defensive copy is needed.
    x = features
    for hop in range(K):
        x = torch.matmul(adj[0], x)
        nodes_features[:, 0, hop + 1, :] = x

    # NOTE(review): a bare squeeze() also drops the node / hop / feature axes
    # when any of them has size 1 (N == 1, K == 0 or F == 1). Kept unchanged so
    # the returned shape stays what existing callers get.
    nodes_features = nodes_features.squeeze()

    return nodes_features


def nor_matrix(adj, a_matrix):
    """Multiply ``adj`` and ``a_matrix`` element-wise, then row-normalise.

    Each row of the element-wise product is divided by its own sum along
    dimension 1.
    """
    weighted = torch.mul(adj, a_matrix)
    row_totals = weighted.sum(dim=1, keepdim=True)
    return weighted / row_totals


class Dummydatasets:
    """Lightweight dataset-like wrapper around a single graph object.

    Stores the graph as ``data`` and derives ``num_nodes`` and
    ``num_node_features`` from the shape of ``d.x``, and ``num_classes``
    from the number of distinct values in ``d.y``.
    """

    def __init__(self, d, name):
        feature_shape = d.x.shape
        distinct_labels = torch.unique(d.y)

        self.name = name
        self.data = d
        # Rows of x are nodes, columns are per-node features.
        self.num_nodes = feature_shape[0]
        self.num_node_features = feature_shape[1]
        self.num_classes = distinct_labels.shape[0]
