import torch
import networkx as nx
import numpy as np
from torch.utils.data.sampler import SequentialSampler
import dgl
from dgl.dataloading import GraphDataLoader
from torch.utils.data import DataLoader
import torch.nn.functional as F
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import time


from dataset import load_dataset, NXDataset, collator_same_size
from training import training_loop_quantum, training_loop
from models import QGraphClassification, CGraphRegression, GCNRegression
from utils import obs_ZZ, generate_ising_matrices_torch, return_observables_torch
from parallel_training import training_loop_parallel

import warnings
# Silence every warning raised from this point on. The filter is installed
# only after the imports above have executed, so this call does not affect
# warnings emitted while those modules were being imported.
warnings.filterwarnings("ignore")

def main():
    """Run the PTC_FM experiment with the quantum graph model on CPU.

    Steps, in order:
      1. Load the 'PTC_FM' graphs with 4 to 20 nodes and relabel their nodes
         to consecutive integers.
      2. Wrap them in an ``NXDataset`` and call ``compute_obs_shortest_path``.
      3. Build a table of ``obs_ZZ(n, i, j)`` for every size ``n`` in
         ``2..max_node`` and every index pair ``i <= j < n``, and hand it to
         ``return_observables_torch``.
      4. Build a ``QGraphClassification`` model and train it through
         ``training_loop_parallel``.

    Returns nothing; the value returned by the training loop is only bound
    to a local name.
    """
    node_cap = 20
    cpu = torch.device('cpu')

    graphs, targets = load_dataset('PTC_FM', min_node=4, max_node=node_cap)
    graphs = np.array(
        [nx.convert_node_labels_to_integers(g) for g in graphs],
        dtype=object,
    )

    # Targets are reshaped into a single column before being handed over.
    column_targets = targets.reshape((-1, 1))
    dataset = NXDataset(graphs, column_targets, node_cap, shuffle=True, seed=67, device=cpu)
    dataset.compute_obs_shortest_path(node_cap, device='cpu')

    # size -> {(i, j) -> obs_ZZ(size, i, j)} for all i <= j < size.
    zz_by_size = {
        size: {
            (row, col): obs_ZZ(size, row, col)
            for row in range(size)
            for col in range(row, size)
        }
        for size in range(2, node_cap + 1)
    }

    # The returned observables are not used again inside this function.
    obs = return_observables_torch(max_N=node_cap, device='cpu', precomputed_zz=zz_by_size)

    model = QGraphClassification(
        4, 64, 2, 4, None, apply_softmax=True, only_neighbors=False
    ).to(cpu)
    # The optimizer is created but not passed to training_loop_parallel.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

    test_acc = training_loop_parallel(
        model,
        dataset,
        10,
        batch_size=20,
        display_every=1,
        train_size=0.8,
        quantum_every=5,
    )


def qm9_training():
    """Continue training the quantum graph model on a random QM9 sample.

    Steps, in order:
      1. Load the 'QM9' graphs with 4 to 20 nodes and print how many there
         are.
      2. Draw a reproducible random subset (NumPy seed 62) of at most 10000
         graphs, standardise the matching targets with ``StandardScaler`` and
         relabel graph nodes to consecutive integers.
      3. Print the seconds spent on loading and preprocessing.
      4. Build a ``QGraphClassification`` model, find the index of the
         smallest entry in 'losses_saved/loss_val_499.npy' and load
         'models_saved/model_{index}.pt' into the model.
      5. Build an ``NXDataset`` (``classification=False``) and run
         ``training_loop_quantum`` with ``loss='mse'``.

    The model runs on CUDA when it is available and on CPU otherwise; the
    dataset itself is always created with ``device='cpu'``.

    Returns:
        The value returned by ``training_loop_quantum``.

    Raises:
        Whatever ``np.load`` / ``torch.load`` raise when the loss history or
        the checkpoint file is missing.
    """
    max_node = 20
    sample_size = 10000

    t0 = time.time()
    graphs, targets = load_dataset('QM9', min_node=4, max_node=max_node)
    print(len(graphs))

    # Seed NumPy's global RNG so the same subset is drawn on every run.
    np.random.seed(62)
    # Cap the request at the population size: choice(replace=False) raises
    # ValueError when asked for more items than exist.
    sample = np.random.choice(
        len(graphs), size=min(sample_size, len(graphs)), replace=False
    )
    graphs_sample = graphs[sample]
    targets_sample = targets[sample]

    # NOTE(review): the scaler is fitted on the whole sample, i.e. before the
    # train/test split implied by train_size below — confirm this is intended.
    scaler = StandardScaler()
    targets = scaler.fit_transform(targets_sample)

    graphs = np.array(
        [nx.convert_node_labels_to_integers(G) for G in graphs_sample],
        dtype=object,
    )

    # Time spent on loading and preprocessing, in seconds.
    print(time.time() - t0)

    torch.random.manual_seed(32)
    # Prefer the GPU but stay runnable on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = QGraphClassification(
        4, 1024, 19, 4, None, apply_softmax=True, only_neighbors=False
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

    # Resume from the saved model whose index has the smallest recorded loss.
    loss_val = np.load('losses_saved/loss_val_499.npy')
    best_model = np.argmin(loss_val)
    path = f'models_saved/model_{best_model}.pt'
    # map_location lets a checkpoint saved on one device load onto another.
    model.load_state_dict(torch.load(path, map_location=device))

    dataset = NXDataset(
        graphs, targets, max_node,
        shuffle=True, seed=67, device='cpu', classification=False,
    )

    return training_loop_quantum(
        model, optimizer, dataset, 500,
        batch_size=10, display_every=1, train_size=.8,
        quantum_every=1000, device=device, loss='mse',
    )



def qm9_training_classical():
    """Train the classical GCN regression baseline on a random QM9 sample.

    Steps, in order:
      1. Load the 'QM9' graphs with 4 to 20 nodes and print how many there
         are.
      2. Draw a reproducible random subset (NumPy seed 62) of at most 10000
         graphs, standardise the matching targets with ``StandardScaler`` and
         relabel graph nodes to consecutive integers.
      3. Build an ``NXDataset`` with ``classification=False`` and
         ``compute_ising=False``.
      4. Train a ``GCNRegression`` model with ``training_loop`` and
         ``loss='mse'``.

    The model runs on CUDA when it is available and on CPU otherwise; the
    dataset itself is always created with ``device='cpu'``.

    Returns:
        The value returned by ``training_loop``.
    """
    max_node = 20
    sample_size = 10000

    graphs, targets = load_dataset('QM9', min_node=4, max_node=max_node)
    print(len(graphs))

    # Seed NumPy's global RNG so the same subset is drawn on every run.
    np.random.seed(62)
    # Cap the request at the population size: choice(replace=False) raises
    # ValueError when asked for more items than exist.
    sample = np.random.choice(
        len(graphs), size=min(sample_size, len(graphs)), replace=False
    )
    graphs_sample = graphs[sample]
    targets_sample = targets[sample]

    # NOTE(review): the scaler is fitted on the whole sample, i.e. before the
    # train/test split implied by train_size below — confirm this is intended.
    scaler = StandardScaler()
    targets = scaler.fit_transform(targets_sample)

    graphs = np.array(
        [nx.convert_node_labels_to_integers(G) for G in graphs_sample],
        dtype=object,
    )

    torch.random.manual_seed(32)
    # Prefer the GPU but stay runnable on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = NXDataset(
        graphs, targets, max_node,
        shuffle=True, seed=67, device='cpu',
        classification=False, compute_ising=False,
    )

    model = GCNRegression(4, 1024, 19, n_layers=1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    return training_loop(
        model, optimizer, dataset, 500,
        batch_size=200, display_every=1, train_size=.8,
        loss='mse', device=device,
    )



if __name__ == '__main__':
    # Script entry point: runs the QM9 quantum-model experiment.
    # Call main() here instead to run the PTC_FM experiment.
    qm9_training()
