from multiprocessing import reduction
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.optim.lr_scheduler import StepLR
import networkx as nx
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import time
from dgl.dataloading import GraphDataLoader
import dgl

from dataset import collator_same_size, collator_classical


def _forward_cached_attention(model, batched_graph, ising, features, device):
    """Run ``model`` on one batch using the attention matrices cached in ``features``.

    Each entry of ``features`` provides ``'conv1'`` and ``'conv2'`` tensors; they are
    moved to ``device`` and passed as ``precomputed_attention1`` /
    ``precomputed_attention2``. ``ising`` is forwarded unchanged (not moved).
    """
    attention1 = [feat['conv1'].to(device) for feat in features]
    attention2 = [feat['conv2'].to(device) for feat in features]
    return model(batched_graph.to(device), batched_graph.ndata['feat'].to(device),
                 ising, precomputed_attention1=attention1,
                 precomputed_attention2=attention2)


def _summed_cached_loss(model, dataloaders, loss_fn, denom, device):
    """Return ``sum(loss_fn(pred, labels, reduction='sum')) / denom`` over all batches.

    Batches are taken from every loader in ``dataloaders`` in order, using the cached
    attention matrices. Returns the int ``0`` when no batch is seen. The caller is
    responsible for ``torch.no_grad()`` / ``model.eval()``.
    """
    total = 0
    for dataloader in dataloaders:
        for batched_graph, labels, ising, features, _ in dataloader:
            labels = labels.to(device)
            pred = _forward_cached_attention(model, batched_graph, ising, features, device)
            total += loss_fn(pred, labels, reduction='sum') / denom
    return total


def _to_numpy(value):
    """Convert an accumulated loss (a tensor, or the int 0 for an empty split) to NumPy."""
    if isinstance(value, torch.Tensor):
        return value.cpu().numpy()
    return np.array(value)


def _train_epoch_cached(model, optimizer, dataloader, loss_func, device):
    """One optimisation pass feeding the cached attention matrices to the model."""
    for batched_graph, labels, ising, features, _ in dataloader:
        labels = labels.to(device)
        pred = _forward_cached_attention(model, batched_graph, ising, features, device)
        loss = loss_func(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del ising
        del pred
        torch.cuda.empty_cache()


def _train_epoch_quantum(model, optimizer, dataloaders, loss_func, device, epoch, print_batch):
    """One optimisation pass feeding the raw ``ising`` input to the model.

    If ``ising`` has more than one dimension it is passed as a single tensor with
    ``unbatch=False`` and ``batch_size=ising.shape[1]``. Otherwise each element is
    moved to ``device`` and passed as a list. Observables are disabled
    (``observables=None``). With ``print_batch`` the forward (F), backward (B) and
    optimizer-step (O) wall times are printed per batch.
    """
    batch_num = 0
    for dataloader in dataloaders:
        for batched_graph, labels, ising, _, _ in dataloader:
            labels = labels.to(device)
            if len(ising.shape) > 1:
                ising = ising.to(device)
                tp_forward = time.perf_counter()
                pred = model(batched_graph.to(device),
                             batched_graph.ndata['feat'].to(device),
                             ising,
                             unbatch=False,
                             batch_size=ising.shape[1],
                             observables=None)
            else:
                ising = [matrix.to(device) for matrix in ising]
                tp_forward = time.perf_counter()
                pred = model(batched_graph.to(device), batched_graph.ndata['feat'].to(device),
                             ising, observables=None)
            tp_backward = time.perf_counter()

            loss = loss_func(pred, labels)
            optimizer.zero_grad()
            loss.backward()
            tp_step = time.perf_counter()
            optimizer.step()
            tp_done = time.perf_counter()

            # Drop references before releasing cached GPU memory.
            del pred
            del labels
            del batched_graph
            torch.cuda.empty_cache()
            batch_num += 1

            if print_batch:
                print("Epoch {} | Batch {} | F {} | B {} | O {}".format(
                    epoch+1, batch_num, tp_backward - tp_forward,
                    tp_step - tp_backward, tp_done - tp_step))


def _refresh_attention_cache(model, dataset, dataloaders, device):
    """Recompute conv1/conv2 attention for every graph and store it in ``dataset``.

    ``model.conv1.attention`` and ``model.conv2.attention`` are evaluated without
    gradients. Each graph's result is written with
    ``dataset.update_graph_features(mat, index, 'conv1' | 'conv2')``.

    If ``ising`` has more than one dimension, the whole batch is evaluated in one
    call using the node count of the first graph. This presumably relies on every
    graph in the batch having the same size (``collator_same_size``) — TODO confirm.
    Otherwise the batch is unbatched and each graph is evaluated with
    ``batch_size=1``.
    """
    with torch.no_grad():
        for dataloader in dataloaders:
            for batched_graph, labels, ising, _, indices in dataloader:
                observables = [None] * len(labels)
                mat = None
                if len(ising.shape) > 1:
                    ising = ising.to(device)
                    batch_size = ising.shape[1]
                    N = int(batched_graph.batch_num_nodes()[0])
                    for key, layer in (('conv1', model.conv1), ('conv2', model.conv2)):
                        mat = layer.attention(N, ising, batch_size=batch_size, observables=None)
                        if len(mat.shape) == 2:
                            # Add a trailing axis so a 2-D result can be sliced per graph.
                            mat = mat.unsqueeze(2)
                        for i, index in enumerate(indices):
                            dataset.update_graph_features(mat[:, :, i], index, key)
                else:
                    graphs = dgl.unbatch(batched_graph)
                    # ``observables`` is all-None; it stays in the zip so the number
                    # of iterations is unchanged.
                    for ising_matrix, graph, _obs, index in zip(ising, graphs, observables, indices):
                        ising_matrix = ising_matrix.to(device)
                        N = graph.num_nodes()
                        mat = model.conv1.attention(N, ising_matrix, batch_size=1, observables=None)
                        dataset.update_graph_features(mat, index, 'conv1')
                        mat = model.conv2.attention(N, ising_matrix, batch_size=1, observables=None)
                        dataset.update_graph_features(mat, index, 'conv2')
                # Rebinding (instead of ``del``) also works when no graph was processed.
                ising = labels = mat = None
                torch.cuda.empty_cache()


def training_loop_quantum(model, optimizer, dataset, epochs, batch_size=10, train_size=.8, val_size=.1, display_every=10, loss='binary', print_batch=False, quantum_every=1, device='cpu'):
    """Train a graph model whose attention is either recomputed or reused from a cache.

    Splits: the first ``int(len(dataset) * train_size)`` indices are used for
    training, the next ``int(len(dataset) * val_size)`` for validation, and the rest
    for test. Training graphs are bucketed by node count (<=15, 16-18, >18). Each
    bucket is visited in increasing node count with its own fixed batch size.

    On epochs where ``epoch % quantum_every == 0``:
      * the ``.attention`` parameters of every child module except the last are
        trainable;
      * the model is trained on the raw ``ising`` input;
      * the conv1/conv2 attention of every train/val/test graph is then recomputed
        and stored via ``dataset.update_graph_features``.

    On all other epochs those parameters are frozen and the model is trained with the
    cached attention matrices.

    After every epoch, the test and validation losses (``loss_func_test``,
    sum-reduced and divided by the split size) are appended to their histories and
    saved to ``losses_saved/loss_classical_{val,test}_{epoch}.npy``.

    Args:
        model: Must expose ``conv1.attention`` and ``conv2.attention``. Every child
            module except the last must have an ``.attention`` submodule.
        optimizer: Optimizer over ``model``'s parameters.
        dataset: ``dataset[i][0]`` must provide ``number_of_nodes()``, and the
            dataset must implement ``update_graph_features(mat, index, key)``.
        epochs: Number of epochs.
        batch_size: Currently unused; kept for interface compatibility.
        train_size: Fraction of examples used for training.
        val_size: Fraction of examples used for validation.
        display_every: Print a progress line every this many epochs.
        loss: ``'mse'`` trains with MSE and evaluates with L1. Any other value uses
            cross-entropy for both.
        print_batch: Print per-batch timings during quantum epochs.
        quantum_every: Period (in epochs) of quantum epochs; must be >= 1.
        device: Anything accepted by ``torch.device``.

    Returns:
        ``[0]`` once at least one epoch has run (accuracy tracking is disabled), or
        ``[]`` when ``epochs == 0``.
    """
    device = torch.device(device)
    num_examples = len(dataset)
    num_train = int(num_examples * train_size)
    num_val = int(num_examples * val_size)
    num_test = num_examples - num_train - num_val

    # Contiguous index splits: [train | val | test].
    train_indices = torch.arange(num_train)
    val_indices = torch.arange(num_train, num_train + num_val)
    test_indices = torch.arange(num_train + num_val, num_examples)

    # Visit graphs in increasing node count; presumably collator_same_size needs
    # equally sized graphs within a batch — TODO confirm.
    nodes = torch.tensor([item[0].number_of_nodes() for item in dataset])
    train_sample = train_indices[torch.argsort(nodes[train_indices])]
    val_sample = val_indices[torch.argsort(nodes[val_indices])]
    test_sample = test_indices[torch.argsort(nodes[test_indices])]

    # Size buckets for quantum epochs, each with its own batch size below.
    train_nodes = nodes[train_indices]
    sample_train_small = train_indices[train_nodes <= 15]
    sample_train_medium = train_indices[(train_nodes > 15) & (train_nodes <= 18)]
    sample_train_big = train_indices[train_nodes > 18]
    sample_train_small = sample_train_small[np.argsort(nodes[sample_train_small])]
    sample_train_medium = sample_train_medium[np.argsort(nodes[sample_train_medium])]
    sample_train_big = sample_train_big[np.argsort(nodes[sample_train_big])]
    print(len(sample_train_small), len(sample_train_medium), len(sample_train_big))

    train_dataloader_small = DataLoader(
        dataset, sampler=sample_train_small, batch_size=30, drop_last=False, collate_fn=collator_same_size)
    train_dataloader_medium = DataLoader(
        dataset, sampler=sample_train_medium, batch_size=60, drop_last=False, collate_fn=collator_same_size)
    train_dataloader_big = DataLoader(
        dataset, sampler=sample_train_big, batch_size=20, drop_last=False, collate_fn=collator_same_size)
    dataloaders_train = [train_dataloader_small, train_dataloader_medium, train_dataloader_big]

    train_dataloader_classical = DataLoader(
        dataset, sampler=train_sample, batch_size=200, drop_last=False, collate_fn=collator_same_size)
    val_dataloader = DataLoader(
        dataset, sampler=val_sample, batch_size=200, drop_last=False, collate_fn=collator_same_size)
    test_dataloader = DataLoader(
        dataset, sampler=test_sample, batch_size=200, drop_last=False, collate_fn=collator_same_size)

    # Every graph whose attention cache is refreshed after a quantum epoch.
    dataloaders_all = dataloaders_train + [val_dataloader, test_dataloader]

    if loss == 'mse':
        loss_func, loss_func_test = F.mse_loss, F.l1_loss
    else:
        loss_func = loss_func_test = F.cross_entropy
    acc_test = []
    losses_test = []
    losses_val = []

    for epoch in range(epochs):
        t0 = time.perf_counter()
        model.train()

        quantum_epoch = epoch % quantum_every == 0
        for layer in list(model.children())[0:-1]:
            for p in layer.attention.parameters():
                p.requires_grad = quantum_epoch

        if quantum_epoch:
            _train_epoch_quantum(model, optimizer, dataloaders_train, loss_func,
                                 device, epoch, print_batch)
            t1 = time.perf_counter()
            _refresh_attention_cache(model, dataset, dataloaders_all, device)
        else:
            _train_epoch_cached(model, optimizer, train_dataloader_classical,
                                loss_func, device)
            # Previously t1 was only set on quantum epochs, so the printed time was stale.
            t1 = time.perf_counter()
        t2 = time.perf_counter()

        with torch.no_grad():
            model.eval()
            loss_test = _summed_cached_loss(model, [test_dataloader], loss_func_test,
                                            num_test, device)
            # Normalise by the validation split size (was wrongly num_test).
            loss_val = _summed_cached_loss(model, [val_dataloader], loss_func_test,
                                           num_val, device)
            losses_test.append(_to_numpy(loss_test))
            losses_val.append(_to_numpy(loss_val))

        # Accuracy tracking is disabled; a placeholder keeps the return value and log line.
        acc_test = [0]
        np.save(f'losses_saved/loss_classical_val_{epoch}.npy', np.array(losses_val))
        np.save(f'losses_saved/loss_classical_test_{epoch}.npy', np.array(losses_test))

        if (epoch + 1) % display_every == 0:
            with torch.no_grad():
                model.eval()
                train_loss = _summed_cached_loss(model, dataloaders_train, loss_func,
                                                 num_train, device)
            print("Epoch {} | Time 1 {} | Time 2 {} | Loss {} | Acc test {} | Loss val {} | Loss test {}".format(epoch+1, t1-t0, t2-t0, train_loss, acc_test[-1], loss_val, loss_test))

    return acc_test


def training_loop(model, optimizer, dataset, epochs, batch_size=10, train_size=.8, val_size=.1, display_every=10, loss='binary', print_batch=False, gamma=1., device='cpu'):
    """Train a classical graph model on ``collator_classical`` batches.

    Splits: indices ``[0, num_train)`` are used for training, the next ``num_val``
    for validation, and the remainder for test. Each split is visited in random
    order (``SubsetRandomSampler``). ``StepLR`` multiplies the learning rate by
    ``gamma`` every 20 epochs.

    After every epoch:
      * the validation and test losses (``loss_func_test``, sum-reduced and divided
        by the split size) are appended to their histories;
      * the model state is saved to ``models_saved/gat/model_{epoch}.pt``;
      * the histories are saved under ``losses_saved/gcn/``.

    Args:
        model: Called as ``model(graph, node_features)``.
        optimizer: Optimizer over ``model``'s parameters.
        dataset: Dataset consumed through ``collator_classical``.
        epochs: Number of epochs.
        batch_size: Batch size for all three loaders.
        train_size: Fraction of examples used for training.
        val_size: Fraction of examples used for validation.
        display_every: Print a progress line every this many epochs.
        loss: ``'mse'`` trains with MSE and evaluates with L1. Any other value uses
            cross-entropy for both.
        print_batch: Print the loss of every training batch.
        gamma: StepLR decay factor.
        device: Anything accepted by ``torch.device``.

    Returns:
        ``[0]`` once at least one epoch has run (accuracy tracking is disabled), or
        ``[]`` when ``epochs == 0``.
    """
    device = torch.device(device)
    num_examples = len(dataset)
    num_train = int(num_examples * train_size)
    num_val = int(num_examples * val_size)
    num_test = num_examples - num_train - num_val

    train_sampler = SubsetRandomSampler(torch.arange(num_train))
    val_sampler = SubsetRandomSampler(torch.arange(num_train, num_train + num_val))
    test_sampler = SubsetRandomSampler(torch.arange(num_train + num_val, num_examples))

    train_dataloader = DataLoader(
        dataset, sampler=train_sampler, batch_size=batch_size, drop_last=False, collate_fn=collator_classical)
    val_dataloader = DataLoader(
        dataset, sampler=val_sampler, batch_size=batch_size, drop_last=False, collate_fn=collator_classical)
    test_dataloader = DataLoader(
        dataset, sampler=test_sampler, batch_size=batch_size, drop_last=False, collate_fn=collator_classical)

    if loss == 'mse':
        loss_func, loss_func_test = F.mse_loss, F.l1_loss
    else:
        loss_func = loss_func_test = F.cross_entropy
    acc_test = []
    losses_test = []
    losses_val = []
    scheduler = StepLR(optimizer, step_size=20, gamma=gamma)

    for epoch in range(epochs):
        t0 = time.perf_counter()
        # Evaluation below switches to eval mode; restore training mode each epoch.
        model.train()
        for batch_num, (batched_graph, labels, _) in enumerate(train_dataloader, start=1):
            labels = labels.to(device)
            pred = model(batched_graph.to(device), batched_graph.ndata['feat'].to(device))
            batch_loss = loss_func(pred, labels)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            if print_batch:
                print("Epoch {} | Batch {} | Loss {}".format(epoch+1, batch_num, batch_loss))
        t1 = time.perf_counter()
        scheduler.step()
        # Accuracy tracking is disabled; a placeholder keeps the return value.
        acc_test = [0]

        loss_val, loss_test = 0., 0.

        # Evaluate without dropout or batch-statistic updates.
        model.eval()
        with torch.no_grad():
            for batched_graph, labels, _ in val_dataloader:
                labels = labels.to(device)
                pred = model(batched_graph.to(device), batched_graph.ndata['feat'].to(device))
                loss_val += loss_func_test(pred, labels, reduction='sum').cpu().numpy() / num_val

            for batched_graph, labels, _ in test_dataloader:
                labels = labels.to(device)
                pred = model(batched_graph.to(device), batched_graph.ndata['feat'].to(device))
                loss_test += loss_func_test(pred, labels, reduction='sum').cpu().numpy() / num_test

            losses_test.append(loss_test)
            losses_val.append(loss_val)

        if (epoch + 1) % display_every == 0:
            train_loss = 0
            with torch.no_grad():
                for batched_graph, labels, _ in train_dataloader:
                    labels = labels.to(device)
                    pred = model(batched_graph.to(device), batched_graph.ndata['feat'].to(device))
                    # Sum-reduce before dividing by the split size. The default
                    # mean reduction under-reported by a factor of ~#batches.
                    train_loss += loss_func(pred, labels, reduction='sum') / num_train

            print("Epoch {} | Time {} | Loss {} | Acc test {} | Loss val {} | Loss test {}".format(epoch+1, t1-t0, train_loss, 0, loss_val, loss_test))

        # NOTE(review): the model goes to ``gat/`` but the losses go to ``gcn/`` —
        # confirm which directory is intended.
        torch.save(model.state_dict(), f'models_saved/gat/model_{epoch}.pt')
        np.save(f'losses_saved/gcn/loss_val_{epoch}.npy', np.array(losses_val))
        np.save(f'losses_saved/gcn/loss_test_{epoch}.npy', np.array(losses_test))

    return acc_test
