import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.optim.lr_scheduler import StepLR
import networkx as nx
import numpy as np
from torch.utils.data import DataLoader, Subset
from torch.utils.data.sampler import SubsetRandomSampler
import time
import dgl
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import os
import torch.multiprocessing as mp
# Select the 'file_descriptor' strategy torch.multiprocessing uses to share
# tensors between processes (mp.spawn children and DataLoader workers).
# NOTE(review): the reason for pinning this explicitly is not visible here —
# presumably portability across platforms whose default differs; confirm.
mp.set_sharing_strategy('file_descriptor')
# Alternative start method, currently disabled:
#mp.set_start_method('forkserver')

#import torchvision
#import torchvision.transforms as transforms
from dataset import load_dataset, NXDataset

from dataset import collator_same_size, collator

import warnings
# Silence every Python warning for this process, including ones emitted by
# torch / DDP / dgl. NOTE(review): this also hides useful diagnostics (e.g.
# DDP performance warnings); consider narrowing it to specific categories or
# modules once the noisy warnings are identified.
warnings.filterwarnings("ignore")



def _split_indices_by_size(nodes, num_train, small_max=15, medium_max=18):
    """Split dataset indices into size buckets, each sorted by node count.

    Args:
        nodes: 1-D integer tensor; ``nodes[i]`` is the node count of example ``i``.
        num_train: the first ``num_train`` examples form the training split,
            the remaining ones the test split.
        small_max: largest node count placed in the "small" training bucket.
        medium_max: largest node count placed in the "medium" training bucket.

    Returns:
        ``(small, medium, big, test)`` index tensors into the full dataset,
        each ordered by ascending node count.
    """
    train_indices = torch.arange(num_train)
    test_indices = torch.arange(num_train, len(nodes))

    def _sorted_by_size(indices):
        # Order dataset indices by the node count of the graphs they point to.
        return indices[torch.argsort(nodes[indices])]

    train_nodes = nodes[train_indices]
    small = _sorted_by_size(train_indices[train_nodes <= small_max])
    medium = _sorted_by_size(
        train_indices[(train_nodes > small_max) & (train_nodes <= medium_max)])
    big = _sorted_by_size(train_indices[train_nodes > medium_max])
    # The test split is drawn from test_indices (it previously indexed
    # train_indices, so "test" examples were actually training examples).
    test = _sorted_by_size(test_indices)
    return small, medium, big, test


def _refresh_attention_features(model, dataloaders, dataset, local_dev):
    """Recompute conv1/conv2 attention matrices and cache them in ``dataset``.

    For every batch served by ``dataloaders``, runs ``model.conv1.attention``
    and ``model.conv2.attention`` on the batch's Ising input and stores the
    result, moved to CPU, via ``dataset.update_graph_features(mat, index,
    'conv1' | 'conv2')``. Classical epochs read these back as
    ``precomputed_attention1/2``.

    Args:
        model: the unwrapped model (``ddp_model.module``).
        dataloaders: loaders yielding ``(graph, labels, ising, features, indices)``.
        dataset: the dataset whose cached features are updated.
        local_dev: device the attention modules run on.
    """
    layers = (('conv1', model.conv1), ('conv2', model.conv2))
    with torch.no_grad():
        for loader in dataloaders:
            for batched_graph, _labels, ising, _features, indices in loader:
                if len(ising.shape) > 1:
                    ising = ising.to(local_dev)
                    ising_batch = ising.shape[1]
                    N = int(batched_graph.batch_num_nodes()[0])
                    for key, layer in layers:
                        mat = layer.attention(N, ising, batch_size=ising_batch, observables=None)
                        if len(mat.shape) == 2:
                            mat = mat.unsqueeze(2)
                        # Always cache on CPU: previously 3-D outputs stayed on
                        # the GPU while 2-D ones were moved off it.
                        mat = mat.cpu()
                        for i, index in enumerate(indices):
                            dataset.update_graph_features(mat[:, :, i], index, key)
                else:
                    graphs = dgl.unbatch(batched_graph)
                    for ising_matrix, graph, index in zip(ising, graphs, indices):
                        ising_matrix = ising_matrix.to(local_dev)
                        N = graph.num_nodes()
                        for key, layer in layers:
                            mat = layer.attention(N, ising_matrix, batch_size=1, observables=None)
                            dataset.update_graph_features(mat.cpu(), index, key)


def training_local(rank, world_size, model, dataset, epochs, batch_size, train_size, display_every, quantum_every, save_file=None, print_batch=False, loss='crossentropy'):
    """Train ``model`` with DistributedDataParallel on process ``rank``.

    Meant to be launched by ``mp.spawn`` (see ``training_loop_parallel``):
    each process joins a "gloo" process group and trains on ``cuda:<rank>``.

    Epochs alternate between two modes:

    * quantum epochs (``epoch % quantum_every == 0``): the model consumes the
      Ising inputs directly, attention parameters are trainable, and afterwards
      the conv1/conv2 attention matrices of every example this rank sees are
      recomputed and cached in this process's copy of ``dataset``;
    * classical epochs: attention parameters are frozen and the cached
      attention matrices are fed in as ``precomputed_attention1/2``.

    Rank 0 evaluates the test-split loss after every epoch.

    Args:
        rank: process index, also used as the CUDA device index.
        world_size: total number of processes.
        model: module to train; moved to ``cuda:<rank>`` and wrapped in DDP.
        dataset: dataset whose items hold a graph with ``number_of_nodes()``
            at position 0 and which provides ``update_graph_features``.
        epochs: number of epochs to run.
        batch_size: unused; per-bucket batch sizes are fixed below.
        train_size: fraction of the dataset (taken from its start) used for
            training; the remainder is the test split.
        display_every: rank 0 prints a summary every this many epochs.
        quantum_every: period, in epochs, of the quantum epochs.
        save_file: unused.
        print_batch: print per-batch timings during quantum epochs.
        loss: ``'mse'`` selects ``F.mse_loss``; anything else selects
            ``F.cross_entropy``.
    """
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        local_dev = 'cuda:' + str(rank)
        # Make cuda:<rank> the current device so implicit-device calls such as
        # torch.cuda.empty_cache() do not touch cuda:0 from every process.
        torch.cuda.set_device(rank)
        model = model.to(local_dev)
        ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)

        print('Process ', str(rank))

        nodes = torch.tensor([item[0].number_of_nodes() for item in dataset])
        num_train = int(len(dataset) * train_size)
        sample_small, sample_medium, sample_big, test_sample = _split_indices_by_size(nodes, num_train)
        print(len(sample_small), len(sample_medium), len(sample_big))

        # The loaders must read from these Subsets: each DistributedSampler
        # yields positions *within* its Subset, which are not dataset indices.
        train_sets = [Subset(dataset, s.tolist()) for s in (sample_small, sample_medium, sample_big)]

        # set_epoch() is intentionally never called: every rank keeps a private
        # copy of ``dataset`` and only refreshes cached attention features for
        # the examples it is served, so the shard must stay fixed across epochs
        # for the classical epochs to read up-to-date features.
        train_samplers = [DistributedSampler(dataset=subset) for subset in train_sets]

        quantum_batch_sizes = (15, 6, 30)
        dataloaders_train = [
            DataLoader(subset, sampler=sampler, batch_size=size, drop_last=False,
                       collate_fn=collator_same_size, num_workers=world_size)
            for subset, sampler, size in zip(train_sets, train_samplers, quantum_batch_sizes)
        ]
        dataloaders_train_classical = [
            DataLoader(subset, sampler=sampler, batch_size=50, drop_last=False,
                       collate_fn=collator_same_size, num_workers=world_size)
            for subset, sampler in zip(train_sets, train_samplers)
        ]
        test_dataloader = DataLoader(
            dataset, sampler=test_sample.tolist(), batch_size=1, drop_last=False,
            collate_fn=collator_same_size)
        # Everything whose cached attention must be refreshed after a quantum epoch.
        dataloaders_all = dataloaders_train + [test_dataloader]

        loss_func = F.mse_loss if loss == 'mse' else F.cross_entropy
        acc_test = []
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.001)

        # Parameters of every layer owning an ``attention`` submodule; they are
        # trained only in quantum epochs. Iterate ``ddp_model.module``: the DDP
        # wrapper's only child is the wrapped module itself, so slicing
        # ``ddp_model.children()`` reached no layer and nothing was frozen.
        attention_params = []
        for layer in ddp_model.module.children():
            if hasattr(layer, 'attention'):
                attention_params.extend(layer.attention.parameters())

        for epoch in range(epochs):
            t0 = time.perf_counter()
            ddp_model.train()
            quantum_epoch = epoch % quantum_every == 0
            for p in attention_params:
                p.requires_grad = quantum_epoch

            if quantum_epoch:
                batch_num = 0
                for dataloader in dataloaders_train:
                    for batched_graph, labels, ising, _, _ in dataloader:
                        labels = labels.to(local_dev)
                        if len(ising.shape) > 1:
                            # Batched Ising tensor: its dim 1 is forwarded to
                            # the model as ``batch_size``.
                            ising = ising.to(local_dev)
                            tp2 = time.perf_counter()
                            pred = ddp_model(batched_graph.to(local_dev),
                                             batched_graph.ndata['feat'].to(local_dev),
                                             ising,
                                             unbatch=False,
                                             batch_size=ising.shape[1],
                                             observables=None)
                        else:
                            ising = [matrix.to(local_dev) for matrix in ising]
                            tp2 = time.perf_counter()
                            pred = ddp_model(batched_graph.to(local_dev),
                                             batched_graph.ndata['feat'].to(local_dev),
                                             ising, observables=None)
                        tp3 = time.perf_counter()

                        batch_loss = loss_func(pred, labels)
                        optimizer.zero_grad()
                        batch_loss.backward()
                        tp5 = time.perf_counter()
                        optimizer.step()
                        tp6 = time.perf_counter()

                        del ising, pred
                        torch.cuda.empty_cache()
                        batch_num += 1
                        if print_batch:
                            print("Epoch {} | Rank {} | Batch {} | F {} | B {} | O {}".format(epoch+1, rank, batch_num, tp3 - tp2, tp5 - tp3, tp6 - tp5))
                t1 = time.perf_counter()
                _refresh_attention_features(ddp_model.module, dataloaders_all, dataset, local_dev)
            else:
                for dataloader in dataloaders_train_classical:
                    for batched_graph, labels, _, features, _ in dataloader:
                        labels = labels.to(local_dev)
                        attention1 = [feat['conv1'].clone().to(local_dev) for feat in features]
                        attention2 = [feat['conv2'].clone().to(local_dev) for feat in features]
                        pred = ddp_model(batched_graph.to(local_dev), batched_graph.ndata['feat'].clone().to(local_dev),
                                         ising_matrices=None, precomputed_attention1=attention1,
                                         precomputed_attention2=attention2)
                        batch_loss = loss_func(pred, labels)
                        optimizer.zero_grad()
                        batch_loss.backward()
                        optimizer.step()
                        del pred
                        torch.cuda.empty_cache()
                # Measured in classical epochs too (previously left stale).
                t1 = time.perf_counter()
            t2 = time.perf_counter()

            if rank == 0:
                ddp_model.eval()
                num_tests = 0
                loss_test = 0.0
                with torch.no_grad():
                    for batched_graph, labels, _, features, _ in test_dataloader:
                        labels = labels.to(local_dev)
                        attention1 = [feat['conv1'].clone().to(local_dev) for feat in features]
                        attention2 = [feat['conv2'].clone().to(local_dev) for feat in features]
                        # Call the unwrapped module: only rank 0 runs this loop
                        # and a DDP forward may issue collectives (buffer
                        # broadcast) that the other ranks would never join.
                        pred = ddp_model.module(batched_graph.to(local_dev), batched_graph.ndata['feat'].clone().to(local_dev),
                                                None, precomputed_attention1=attention1,
                                                precomputed_attention2=attention2)
                        loss_test += loss_func(pred, labels).item()
                        num_tests += len(labels)
                # Accuracy is not computed at the moment; keep a placeholder so
                # the summary line keeps its format.
                acc_test = [0]

                if (epoch + 1) % display_every == 0:
                    mean_loss = loss_test / num_tests if num_tests else float('nan')
                    print("Epoch {} | Time 1 {} | Time 2 {} | Acc test {} | Loss test {}".format(epoch+1, t1-t0, t2-t0, acc_test[-1], mean_loss))
    finally:
        dist.destroy_process_group()


def training_loop_parallel(model, dataset, epochs, batch_size=10, train_size=.8, display_every=10, quantum_every=1, save_file=None, print_batch=False, loss='crossentropy', world_size=2):
    """Spawn ``world_size`` processes that each run ``training_local``.

    The rendezvous address defaults to ``localhost:1789``. MASTER_ADDR and
    MASTER_PORT are only set when absent from the environment, so callers can
    pick another port (e.g. to run several jobs on one host).

    Args:
        model: module to train; each spawned process receives its own copy.
        dataset: dataset forwarded to every process.
        epochs, batch_size, train_size, display_every, quantum_every,
        save_file, print_batch, loss: forwarded unchanged to ``training_local``.
        world_size: number of processes (one per GPU); defaults to 2, the
            previously hard-coded value.
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '1789')
    mp.spawn(training_local,
             args=(world_size, model, dataset, epochs, batch_size, train_size,
                   display_every, quantum_every, save_file, print_batch, loss),
             nprocs=world_size,
             join=True)
    
if __name__ == "__main__":
    max_node = 19
    graphs, targets = load_dataset('PTC_FM', min_node=4, max_node=max_node)
    graphs = [nx.convert_node_labels_to_integers(G) for G in graphs]
    # Fill a 1-D object array element by element: np.array(list_of_graphs,
    # dtype=object) may descend into the graphs, since networkx graphs
    # implement __len__/__getitem__ and so look like nested sequences.
    graph_array = np.empty(len(graphs), dtype=object)
    for i, G in enumerate(graphs):
        graph_array[i] = G
    graphs = graph_array
    dataset = NXDataset(graphs, targets, max_node, shuffle=True, seed=67, device='cpu')
    dataset.compute_obs_shortest_path(max_node, device='cpu')
    # NOTE(review): neither `QGraphClassification` nor `obs` is imported or
    # defined in the visible source, so this line raises NameError as written.
    # TODO: import the model class and supply the intended observables.
    model = QGraphClassification(4, 64, 2, 4, obs, apply_softmax=True, only_neighbors=False)
    # Pass the model just built (previously None was passed, and
    # training_local would fail on None.to(...)).
    training_loop_parallel(model, dataset, epochs=10, batch_size=10, train_size=.8, display_every=1, quantum_every=1, save_file=None)
