import os
import sys
import os.path as osp
import numpy as np
import random
import copy
from tqdm import tqdm
import pickle

import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, GRU

import torch_geometric.transforms as T
from torch_geometric.datasets import QM9
from torch_geometric.data import Data
from torch_geometric.nn import NNConv, Set2Set
from torch_geometric.data import DataLoader, Batch, DataLoader, Dataset
from torch_geometric.utils import remove_self_loops
from torch_geometric.utils.convert import to_networkx, to_scipy_sparse_matrix, from_scipy_sparse_matrix
from semi.igsd import IGSD
from semi.argparser import args
from semi.dataset import load
from itertools import cycle
from scipy.sparse import csr_matrix
import networkx as nx
import torch
from scipy.linalg import fractional_matrix_power, inv
import scipy.sparse as sp
from typing import Optional
import warnings
# TODO: ignore only the GRU memory warning. As written, filterwarnings('ignore')
# has no category or message filter, so it silences every warning.
warnings.filterwarnings('ignore')

class Encoder(torch.nn.Module):
    """NNConv + GRU message-passing encoder with a Set2Set graph readout.

    Node features are projected to ``hid_dim``, refined by three rounds of
    edge-conditioned convolution (NNConv) whose output drives a GRU update,
    and finally pooled per graph with Set2Set (output size ``2 * hid_dim``).

    Args:
        num_features: size of each input node feature vector (``data.x``).
        hid_dim: hidden node-embedding size.
        edge_type: size of each edge feature vector (``data.edge_attr``).
    """

    def __init__(self, num_features, hid_dim, edge_type=5):
        super(Encoder, self).__init__()
        self.lin0 = torch.nn.Linear(num_features, hid_dim)

        # Edge network: maps each edge feature vector to a hid_dim x hid_dim
        # matrix that NNConv uses as that edge's message weight.
        nn = Sequential(Linear(edge_type, 128), ReLU(), Linear(128, hid_dim * hid_dim))
        self.conv = NNConv(hid_dim, hid_dim, nn, aggr='mean', root_weight=False)
        self.gru = GRU(hid_dim, hid_dim)

        self.set2set = Set2Set(hid_dim, processing_steps=3)

    def forward(self, data, latent=None):
        """Encode a graph batch.

        Args:
            data: graph (batch) with ``x``, ``edge_index``, ``edge_attr`` and,
                for batched input, ``batch``.
            latent: unused; kept for interface compatibility.

        Returns:
            ``(graph_embedding, None)`` with one embedding row per graph.
        """
        out = F.relu(self.lin0(data.x))
        # Initial GRU hidden state: (num_layers=1, num_nodes, hid_dim).
        h = out.unsqueeze(0)
        edge_attr = data.edge_attr.float()

        for _ in range(3):
            m = F.relu(self.conv(out, data.edge_index, edge_attr))
            # Re-compact the GRU weights into contiguous memory before use.
            self.gru.flatten_parameters()
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)

        # Follow the model's device instead of the global args.device, and
        # treat an un-batched graph (no ``batch`` vector) as a single graph.
        batch = getattr(data, 'batch', None)
        if batch is None:
            batch = torch.zeros(out.size(0), dtype=torch.long, device=out.device)
        out = self.set2set(out, batch.to(out.device))
        return out, None

class MyTransform(object):
    """Dataset transform that keeps only the regression target being trained.

    ``data.y`` is replaced by its column ``args.target``.
    """

    def __call__(self, data):
        chosen_column = args.target
        data.y = data.y[:, chosen_column]
        return data


class Complete(object):
    """Dataset transform that replaces a graph's edges with the complete graph.

    Every ordered pair of distinct nodes becomes an edge. Edges present in the
    input keep their features; newly added edges get an all-zero feature
    vector. Self-loops are removed.
    """

    def __call__(self, data):
        # Build the new tensors on the input graph's own device rather than on
        # the global training device: the samples come straight from the
        # dataset, and mixing devices inside one Data object can break the
        # indexing in remove_self_loops and in later transforms.
        device = data.edge_index.device
        n = data.num_nodes

        row = torch.arange(n, dtype=torch.long, device=device)
        col = torch.arange(n, dtype=torch.long, device=device)
        row = row.view(-1, 1).repeat(1, n).view(-1)
        col = col.repeat(n)
        edge_index = torch.stack([row, col], dim=0)

        edge_attr = None
        if data.edge_attr is not None:
            # Scatter the original edge features into a dense n*n table keyed
            # by the flattened (source, target) index used to build edge_index.
            idx = data.edge_index[0] * n + data.edge_index[1]
            size = list(data.edge_attr.size())
            size[0] = n * n
            edge_attr = data.edge_attr.new_zeros(size)
            edge_attr[idx] = data.edge_attr

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        data.edge_attr = edge_attr
        data.edge_index = edge_index

        return data


def train(epoch):
    """Run one training epoch and return the size-weighted mean training loss.

    Reads the module-level ``model``, ``optimizer`` and ``train_loader``. When
    ``args.use_unsup_loss`` is set it also reads ``adj_loader``/``diff_loader``,
    plus ``best_test_error`` for the progress message.

    With the unsupervised loss enabled, labelled batches are cycled in lockstep
    with the unlabelled (adj, diff) loaders, and step ``i`` does the following:

    * ``i < num_train``: optimise ``sup_loss + args.lamda * unsup_loss``;
    * ``num_train <= i <= num_train * args.unsup_times``: optimise ``unsup_loss`` only;
    * afterwards, or when the unlabelled loaders are exhausted, the epoch stops.

    Args:
        epoch: current epoch number (unused; kept for the caller's signature).

    Returns:
        Sum over supervised steps of ``loss.item() * num_graphs``, divided by
        the number of labelled training graphs. ``loss`` is the combined loss
        when the unsupervised term is enabled, otherwise the MSE loss.
    """
    model.train()
    loss_all = 0
    num_train = len(train_loader)  # number of labelled batches per epoch

    if not args.use_unsup_loss:
        for data in tqdm(train_loader):
            data = data.to(args.device)
            optimizer.zero_grad()
            logits = model(data)
            sup_loss = F.mse_loss(logits, data.y)
            sup_loss.backward()
            loss_all += sup_loss.item() * data.num_graphs
            optimizer.step()
        return loss_all / len(train_loader.dataset)

    sup_loss_all = 0
    unsup_loss_all = 0
    # ``cycle`` keeps labelled batches available while the unlabelled loaders
    # set the pace; labelled batches are only consumed while i < num_train.
    batches = zip(cycle(train_loader), adj_loader, diff_loader)
    for i, (data, adj, diff) in enumerate(tqdm(batches)):
        # Test the stop condition first so the discarded final batch is not
        # copied to the device.
        if i >= num_train and i > num_train * args.unsup_times:
            break

        adj = adj.to(args.device)
        diff = diff.to(args.device)
        optimizer.zero_grad()
        unsup_loss = model.unsup_loss(adj, diff)
        del adj, diff

        if i < num_train:
            data = data.to(args.device)
            logits = model(data)
            sup_loss = F.mse_loss(logits, data.y)
            loss = sup_loss + unsup_loss * args.lamda
            loss.backward()
            model.update_moving_average()

            sup_loss_all += sup_loss.item()
            unsup_loss_all += unsup_loss.item()
            loss_all += loss.item() * data.num_graphs
            optimizer.step()
        else:
            # Unsupervised-only phase: the labelled batch is not used, so it is
            # not transferred to the device.
            unsup_loss.backward()
            model.update_moving_average()
            optimizer.step()

            unsup_loss_all += unsup_loss.item()

    print("(Property {}, Unsuptimes {}) Supervised total loss:{:.7f}, Unsupervised total loss:{:.7f}, Current Best MAE:{:.7f}".\
          format(args.target, args.unsup_times, sup_loss_all, unsup_loss_all, best_test_error))
    return loss_all / len(train_loader.dataset)


def test(loader):
    """Return the mean absolute error over ``loader`` in original target units.

    Predictions and labels are both multiplied by the module-level ``std``,
    which undoes the target normalisation, before the absolute errors are
    summed. The sum is divided by the number of graphs in ``loader.dataset``.
    Runs under ``torch.no_grad()``.

    NOTE(review): ``model.eval()`` is deliberately not called, so the model is
    evaluated in whatever mode it was left in (train mode after ``train()``).
    Confirm this is intended if the model contains dropout or batch-norm layers.
    """
    with torch.no_grad():
        total_abs_error = 0
        for batch in loader:
            batch = batch.to(args.device)
            prediction = model(batch)
            total_abs_error += (prediction * std - batch.y * std).abs().sum().item()  # MAE
    return total_abs_error / len(loader.dataset)

def seed_everything(seed=1234):
    """Seed every random number generator the script uses and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices).
    Exports ``PYTHONHASHSEED``, enables deterministic cuDNN kernels and
    disables cuDNN benchmarking/autotuning.

    NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so setting it
    here only affects processes launched afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def compute_ppr(graph: "nx.Graph", alpha=0.2, self_loop=True):
    """Return the dense personalized-PageRank (PPR) diffusion matrix of a graph.

    Computes ``alpha * (I - (1 - alpha) * D^-1/2 A D^-1/2)^-1``, where ``A`` is
    the adjacency matrix (plus ``I`` when ``self_loop`` is true) and ``D`` is
    the diagonal matrix of its row sums.

    Args:
        graph: a networkx graph, or a square NumPy adjacency matrix that is
            used as-is (it is not modified).
        alpha: restart (teleport) probability.
        self_loop: add the identity to the adjacency before normalising.

    Returns:
        ``(n, n)`` float ndarray.

    Nodes with zero degree (only possible when ``self_loop`` is false) get a
    zero normalisation factor instead of turning the whole result into NaN.
    """
    if isinstance(graph, np.ndarray):
        a = np.asarray(graph, dtype=float)
    else:
        a = nx.convert_matrix.to_numpy_array(graph)
    n = a.shape[0]
    if self_loop:
        a = a + np.eye(n)                                        # A^ = A + I_n
    deg = a.sum(axis=1)                                          # D^_ii = sum_j A^_ij
    # D^(-1/2) kept as a vector. D is diagonal, so an elementwise power is exact.
    # It also avoids fractional_matrix_power, which yields NaNs when D is singular.
    with np.errstate(divide='ignore', invalid='ignore'):
        dinv = np.where(deg > 0, deg ** -0.5, 0.0)
    at = dinv[:, None] * a * dinv[None, :]                       # A~ = D^(-1/2) x A^ x D^(-1/2)
    return alpha * inv(np.eye(n) - (1 - alpha) * at)            # a(I_n-(1-a)A~)^-1

def get_diff(dataloader, shuffle=False):
    """Compute the PPR diffusion matrix of every item yielded by ``dataloader``.

    Each item is converted with ``to_networkx`` and passed to ``compute_ppr``.

    Args:
        dataloader: iterable of PyG ``Data``/``Batch`` objects.
            NOTE(review): when this is a ``DataLoader``, each item is a whole
            mini-batch, so one matrix is produced per batch rather than per
            molecule. Confirm this is what the caller wants.
        shuffle: unused; kept for backward compatibility.

    Returns:
        List of dense ``(num_nodes, num_nodes)`` ndarrays, one per item.
    """
    ppr_list = []
    for data in dataloader:
        # ``to_networkx`` takes attribute *names*, not tensors. compute_ppr only
        # uses the graph structure, so no attributes need to be copied.
        # (The local is named ``graph`` so it does not shadow the ``nx`` module.)
        graph = to_networkx(data)
        ppr_list.append(compute_ppr(graph))
    return ppr_list

if __name__ == '__main__':
    seed_everything()
    print(args)

    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9')
    transform = T.Compose([MyTransform(), Complete(), T.Distance(norm=False)])
    dataset = QM9(path, transform=transform).shuffle()
    print('num_features : {}\n'.format(dataset.num_features))

    target = args.target
    epochs = args.num_epoch
    batch_size = args.batch_size
    use_unsup_loss = args.use_unsup_loss
    feat_dim = dataset.num_features

    # Normalize targets to mean = 0 and std = 1. ``std`` stays a module-level
    # global because test() uses it to report MAE in the original units.
    mean = dataset.data.y[:, target].mean().item()
    std = dataset.data.y[:, target].std().item()
    dataset.data.y[:, target] = (dataset.data.y[:, target] - mean) / std

    # Split datasets.
    test_dataset = dataset[:10000]
    val_dataset = dataset[10000:20000]
    train_dataset = dataset[20000:20000 + args.train_num]

    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    if use_unsup_loss:
        # Pre-computed unlabelled graph views; this script only loads them.
        # NOTE: pickle can execute arbitrary code, so only load trusted files.
        with open('./semi/data/QM9/unsup_adj_all.pkl', 'rb') as f:
            unsup_adj = pickle.load(f)
        with open('./semi/data/QM9/unsup_diff_all.pkl', 'rb') as f:
            unsup_diff = pickle.load(f)
        # NOTE(review): the two loaders shuffle independently, so the i-th adj
        # batch and the i-th diff batch generally hold different graphs. If
        # model.unsup_loss expects paired views of the same graphs, they must
        # share one permutation. Confirm against IGSD.unsup_loss.
        adj_loader = DataLoader(unsup_adj, batch_size=batch_size, shuffle=True)
        diff_loader = DataLoader(unsup_diff, batch_size=batch_size, shuffle=True)
        print("Size: Training:{}, Validation:{}, Testing:{}, Unsup:{}".\
              format(len(train_dataset), len(val_dataset),len(test_dataset), len(unsup_adj)))
    else:
        print("Size: Training:{}, Validation:{}, Testing:{}".format(len(train_dataset), len(val_dataset),len(test_dataset)))

    online_encoder = Encoder(num_features=feat_dim, hid_dim=args.hid_dim, edge_type=1)
    sup_encoder = Encoder(num_features=feat_dim, hid_dim=args.hid_dim)
    model = IGSD(online_encoder, sup_encoder, feat_dim, hidden_layer=args.num_layer,\
                 projection_size=args.projection_size, projection_hidden_size=args.projection_hidden_size)
    model.to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.7, patience=5, min_lr=0.000001)

    best_val_error = None
    best_test_error = 1e7  # read by train() for its progress message
    val_error = test_error = None  # defined even if no epoch runs
    # Run exactly ``args.num_epoch`` epochs (the old ``range(1, epochs)`` ran one fewer).
    for epoch in range(1, epochs + 1):
        lr = scheduler.optimizer.param_groups[0]['lr']
        loss = train(epoch)
        val_error = test(val_loader)
        scheduler.step(val_error)

        # Re-evaluate the test set only when validation MAE improves, so
        # ``test_error`` is the test MAE at the best validation epoch so far.
        if best_val_error is None or val_error <= best_val_error:
            print('Update')
            test_error = test(test_loader)
            best_val_error = val_error

        # Lowest test MAE seen so far. It is only reported, not used for model selection.
        if best_test_error > test_error:
            best_test_error = test_error

        print('(Target{}) Epoch: {:03d}, LR: {:7f}, Loss: {:.7f}, Validation MAE: {:.7f}, '
              'Test MAE: {:.7f}, Best Test MAE: {:.7f}'.format(args.target, epoch, lr, loss, val_error, test_error, best_test_error))

    with open('supervised.log', 'a+') as f:
        f.write('{},{},{},{},{},{},{}\n'.format(target, args.train_num, use_unsup_loss, args.lamda,
                                                   args.weight_decay, val_error, test_error))

    # Saving is best-effort: report failures instead of aborting after training.
    try:
        os.makedirs('saved_models', exist_ok=True)
        torch.save(model, 'saved_models/{}.model'.format(target))
    except Exception as e:
        print(e)
