import torch
import torch.nn as nn
import networkx as nx
import numpy as np
from torch.utils.data.sampler import SequentialSampler
import dgl
from dgl.dataloading import GraphDataLoader
from dgl.data import DGLDataset

from torch.utils.data import DataLoader
import torch.nn.functional as F
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import time
import torch.nn.init as init

from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import TensorDataset
import torch.optim as optim
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist
from dgl.dataloading import GraphCollator

import wandb
import uuid

from dataset import load_dataset, collator_same_size, CustomDataset, CustomDatasetClassical
from training import training_loop_quantum, training_loop
from models.quantum import QGraphClassification, CGraphRegression, GCNRegression, QGraphNetworkCustom, HybridModel
from utils import generate_ising_matrices_torch, obs_ZZ

import os
import time
import sys
import pickle

def return_dataloaders(dataset_list, dataset_classical_list, batch_sizes, world_size=None, batch_size_classical=3000):
    """Build paired quantum/classical DataLoaders for a list of datasets.

    Parameters
    ----------
    dataset_list : list of Dataset
        Datasets for the quantum pass; one loader is built per entry.
    dataset_classical_list : list of Dataset
        Companion datasets for the classical pass, paired positionally with
        ``dataset_list``.
    batch_sizes : list of int
        Batch size for each entry of ``dataset_list`` (cast with ``int``).
    world_size : int or None
        When given, every loader gets a ``DistributedSampler`` built with its
        defaults (replica count and rank are taken from the initialised default
        process group, shuffle=True with seed 0) and ``num_workers=world_size``.
        When ``None``, plain sequential, unshuffled loaders are built.
    batch_size_classical : int
        Batch size used for every classical loader (default 3000, the value
        that used to be hard-coded).

    Returns
    -------
    tuple
        ``(dataloader_list, dataloader_classical_list, num_items)``, where
        ``num_items`` is the summed ``len`` of the entries of ``dataset_list``
        (full dataset sizes, not per-rank shard sizes).

    Raises
    ------
    ValueError
        If ``dataset_list``, ``dataset_classical_list`` and ``batch_sizes`` do
        not all have the same length (``zip`` would otherwise silently drop
        datasets).
    """
    if not (len(dataset_list) == len(dataset_classical_list) == len(batch_sizes)):
        raise ValueError(
            f"dataset_list ({len(dataset_list)}), dataset_classical_list "
            f"({len(dataset_classical_list)}) and batch_sizes ({len(batch_sizes)}) "
            "must have the same length")

    dataloader_list = []
    dataloader_classical_list = []
    num_items = 0

    for dataset, dataset_classical, batch_size in zip(dataset_list, dataset_classical_list, batch_sizes):
        if world_size is not None:
            # Both samplers use the same default seed; for equally sized datasets
            # they therefore select the same shard on a given rank.
            # NOTE(review): num_workers is tied to world_size here — confirm intent.
            sampler = DistributedSampler(dataset=dataset)
            sampler_classical = DistributedSampler(dataset=dataset_classical)
            dataloader = DataLoader(
                dataset, sampler=sampler, batch_size=int(batch_size), drop_last=False, shuffle=False,
                num_workers=world_size)
            dataloader_classical = DataLoader(
                dataset_classical, sampler=sampler_classical, batch_size=int(batch_size_classical),
                drop_last=False, shuffle=False, num_workers=world_size)
        else:
            dataloader = DataLoader(
                dataset, batch_size=int(batch_size), drop_last=False, shuffle=False)
            dataloader_classical = DataLoader(
                dataset_classical, batch_size=int(batch_size_classical), drop_last=False, shuffle=False)

        dataloader_list.append(dataloader)
        dataloader_classical_list.append(dataloader_classical)
        num_items += len(dataset)

    return dataloader_list, dataloader_classical_list, num_items

def compute_attention_matrices_single(model, dataloader, dataset_classical, device='cpu', verbose=False):
    """Cache the attention matrix of every head into ``dataset_classical``.

    Runs entirely under ``torch.no_grad()``. For each batch
    ``(features, labels, ising, adj, idx)`` yielded by ``dataloader``, each head
    of each layer in ``model.layers`` is queried with
    ``head.attention(features.shape[1], ising.to(device), batch_size=labels.shape[0])``.
    The result is stored with
    ``dataset_classical.update_attention_matrices(mat, (layer_idx, head_idx), idx)``.

    When a batch holds more than one sample, the returned matrix has its last
    axis moved to the front (``permute(2, 0, 1)``) before being stored. This
    presumably turns a trailing batch axis into a leading one — TODO confirm
    against the head implementation.

    ``verbose`` prints the ``(layer_idx, head_idx)`` pair before each query.
    """
    with torch.no_grad():
        for features, labels, ising, _adj, idx in dataloader:
            n_nodes = features.shape[1]
            n_samples = labels.shape[0]
            all_heads = (
                (layer_no, head_no, head)
                for layer_no, layer in enumerate(model.layers)
                for head_no, head in enumerate(layer.heads)
            )
            for layer_no, head_no, head in all_heads:
                if verbose:
                    print(layer_no, head_no)
                attn = head.attention(n_nodes, ising.to(device), batch_size=n_samples)
                if n_samples > 1:
                    attn = attn.permute(2, 0, 1)
                dataset_classical.update_attention_matrices(attn, (layer_no, head_no), idx)





def training_local(rank, world_size, model, dataset_list, dataset_classical_list,
                   epochs, batch_sizes_train, batch_sizes_val=None, batch_sizes_test=None, quantum_every=4,
                   dataset_val_list=None, dataset_val_classical_list=None,
                   dataset_test_list=None, dataset_test_classical_list=None, lr=.001, gamma=1., dataset_name='', wb=False, config=None, file_output=None,
                   loss_func=None, loss_func_test=None):
    """Perform the DDP training loop for a single process (one GPU).

    Intended to be launched via ``mp.spawn``, which passes ``rank`` as the first
    argument.

    Epochs are 1-indexed. On epochs 1, 1 + quantum_every, 1 + 2*quantum_every, ...
    two things happen:

    * the model is trained on the quantum loaders, and
    * each classical training dataset's attention matrices are refreshed with
      ``compute_attention_matrices_single``.

    On every other epoch the model is trained on the classical loaders, which
    feed the cached attention matrices back into it.

    After each epoch, rank 0 does the following:

    * evaluates the optional validation/test splits;
    * saves the time/loss histories and a checkpoint under ``file_output``;
    * logs to W&B when ``wb`` is set.

    Parameters
    ----------
    rank : int
        Rank of the process; also used as the CUDA device index.
    world_size : int
        Number of processes.
    model : nn.Module
        The model to train.
    dataset_list : list of datasets
        Training datasets for the quantum pass.
    dataset_classical_list : list of datasets
        Classical companions of ``dataset_list`` (hold the cached attention matrices).
    epochs : int
        Number of epochs to train for.
    batch_sizes_train, batch_sizes_val, batch_sizes_test : list of int
        Per-dataset batch sizes of the quantum loaders for each split.
    quantum_every : int
        Period, in epochs, of the quantum training pass (must be >= 1).
    dataset_val_list, dataset_val_classical_list : list or None
        Optional validation split (evaluated on rank 0 only).
    dataset_test_list, dataset_test_classical_list : list or None
        Optional test split (evaluated on rank 0 only).
    lr : float
        Adam learning rate.
    gamma, dataset_name :
        Accepted for signature compatibility; not used here.
    wb : bool
        Log to Weights & Biases from rank 0.
    config : dict
        W&B run config; ``config['name']`` is used as the run name.
    file_output : str
        Directory receiving ``time.npy``, ``loss_val.npy``, ``loss_test.npy``
        and ``parameters/model_{epoch}.pt``.
    loss_func : callable, optional
        Training loss; defaults to ``F.mse_loss``.
    loss_func_test : callable, optional
        Evaluation loss, called with ``reduction='sum'``; defaults to ``F.l1_loss``.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    local_dev = 'cuda:' + str(rank)
    model = model.to(local_dev)
    ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    # NOTE: set_epoch() is deliberately never called on the DistributedSamplers.
    # Each rank only refreshes attention matrices for the samples its quantum loader
    # visits, and its classical loader presumably must keep visiting that very
    # same shard — a fixed permutation (no set_epoch) keeps the two in step.
    dataloader_list, dataloader_classical_list, num_train = return_dataloaders(
        dataset_list, dataset_classical_list, batch_sizes_train, world_size)

    if dataset_val_list is not None:
        dataloader_val_list, dataloader_val_classical_list, num_val = return_dataloaders(
            dataset_val_list, dataset_val_classical_list, batch_sizes_val, None)

    if dataset_test_list is not None:
        dataloader_test_list, dataloader_test_classical_list, num_test = return_dataloaders(
            dataset_test_list, dataset_test_classical_list, batch_sizes_test, None)

    print('Process ', str(rank))

    optimizer = optim.Adam(ddp_model.parameters(), lr=lr)

    ddp_model.train()
    time_list = []
    loss_val_list, loss_test_list = [], []

    if loss_func is None:
        loss_func = F.mse_loss
    if loss_func_test is None:
        loss_func_test = F.l1_loss

    if wb and rank == 0:
        wandb.init(
            project="QGNN",
            entity="qgnn",
            name=config['name'],
            config=config
        )
        wandb.watch(ddp_model)

    # Code that runs on rank 0 alone (evaluation, attention refresh) calls the
    # wrapped module directly. DDP's own forward may launch cross-rank
    # collectives (e.g. buffer broadcast) that the other ranks would never join.
    eval_model = ddp_model.module

    for epoch in range(1, epochs + 1):
        t0 = time.perf_counter()
        # Quantum pass on epochs 1, 1+q, 1+2q, ... Unlike `epoch % q == 1`, this
        # also fires when quantum_every == 1.
        quantum_epoch = (epoch - 1) % quantum_every == 0

        if quantum_epoch:
            for dataloader, dataset_classical in zip(dataloader_list, dataset_classical_list):
                for features, labels, ising, adj, _ in dataloader:
                    labels = labels.to(local_dev)
                    pred = ddp_model(features.to(local_dev), ising.to(local_dev), adj.to(local_dev), labels.shape[0])
                    loss = loss_func(pred, labels)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    del labels
                    del pred
                    torch.cuda.empty_cache()
                # Refresh this rank's cached attention matrices with the updated weights.
                compute_attention_matrices_single(eval_model, dataloader, dataset_classical, local_dev)
                if epoch == 1:
                    print(dataset_classical.n_nodes)
        else:
            for dataloader_classical in dataloader_classical_list:
                for features, labels, adj, attention, _ in dataloader_classical:
                    labels = labels.to(local_dev)
                    pred = ddp_model(features.to(local_dev), None, adj.to(local_dev), labels.shape[0], attention.to(local_dev))
                    loss = loss_func(pred, labels)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    del labels
                    del pred
                    torch.cuda.empty_cache()

        t1 = time.perf_counter()

        if rank == 0:
            loss_val, loss_test = 0, 0
            if dataset_val_list is not None:
                with torch.no_grad():
                    if quantum_epoch:
                        for dataloader, dataset_classical in zip(dataloader_val_list, dataset_val_classical_list):
                            compute_attention_matrices_single(eval_model, dataloader, dataset_classical, local_dev)
                    for dataloader_classical in dataloader_val_classical_list:
                        for features, labels, adj, attention, _ in dataloader_classical:
                            labels = labels.to(local_dev)
                            batch_size = labels.shape[0]
                            pred = eval_model(features.to(local_dev), None, adj.to(local_dev), batch_size, attention.to(local_dev))
                            loss_val += loss_func_test(pred, labels, reduction='sum').cpu() / num_val

            if dataset_test_list is not None:
                with torch.no_grad():
                    if quantum_epoch:
                        for dataloader, dataset_classical in zip(dataloader_test_list, dataset_test_classical_list):
                            compute_attention_matrices_single(eval_model, dataloader, dataset_classical, local_dev)
                    for dataloader_classical in dataloader_test_classical_list:
                        for features, labels, adj, attention, _ in dataloader_classical:
                            labels = labels.to(local_dev)
                            batch_size = labels.shape[0]
                            pred = eval_model(features.to(local_dev), None, adj.to(local_dev), batch_size, attention.to(local_dev))
                            loss_test += loss_func_test(pred, labels, reduction='sum').cpu() / num_test

            loss_val_list.append(loss_val)
            loss_test_list.append(loss_test)
            time_list.append(t1 - t0)

            logs = {
                'epoch': epoch,
                'loss_val': loss_val,
                'loss_test': loss_test,
                'time': t1 - t0,
            }
            if wb:
                wandb.log(logs)
            np.save(f'{file_output}/time.npy', np.array(time_list))
            np.save(f'{file_output}/loss_val.npy', np.array(loss_val_list))
            np.save(f'{file_output}/loss_test.npy', np.array(loss_test_list))
            torch.save(eval_model.state_dict(), f'{file_output}/parameters/model_{epoch}.pt')

            # `loss` is the last training-batch loss of this epoch.
            print("Epoch {} | Time {} | Loss {} | Val {} | Test {} ".format(epoch, t1 - t0, loss.detach().cpu(), loss_val, loss_test))

    if wb and rank == 0:
        wandb.finish()
    dist.destroy_process_group()


def training_loop_parallel(model, dataset_list, dataset_classical_list, epochs, world_size, batch_sizes_train, batch_sizes_val=None, batch_sizes_test=None, quantum_every=4,
                           dataset_val_list=None, dataset_val_classical_list=None,
                           dataset_test_list=None, dataset_test_classical_list=None, lr=.001, gamma=1., dataset_name='',
                           wb=False, cli_args=None, loss_func=None, loss_func_test=None):
    """Spawn ``world_size`` processes that each run ``training_local`` under DDP.

    Creates the result directory ``results/QGNN_{dataset_name}_{exp_id}``
    (plus its ``parameters/`` sub-directory) and pickles the run config to
    ``config.pickle`` in it. It then launches the workers with ``mp.spawn`` and
    blocks until they all finish (``join=True``). W&B logging, when ``wb`` is
    set, is started by rank 0 inside ``training_local``.

    All dataset/batch-size/loss arguments are forwarded unchanged to
    ``training_local``; see its docstring for their meaning. ``model.args`` and
    ``cli_args`` are recorded in the config.
    """
    # Rendezvous address for dist.init_process_group inside training_local.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '1876'

    exp_id = uuid.uuid4().hex[:10]

    name = f"QGNN_{dataset_name}"
    file_output = f'results/{name}_{exp_id}'
    config = {
        "name": f"{name}_{exp_id}",
        "lr": lr,
        "gamma": gamma,
        "n_epochs": epochs,
        "dataset": dataset_name,
        "model": 'QGNN',
        "args": model.args,
        "cli_args": cli_args,
    }
    # makedirs also creates the parent ``results/`` directory when it is missing
    # (os.mkdir raised FileNotFoundError there); it still raises FileExistsError
    # if the run directory already exists.
    os.makedirs(f'{file_output}/parameters')

    with open(f'{file_output}/config.pickle', 'wb') as f:
        pickle.dump(config, f)

    mp.spawn(training_local,
             args=(world_size, model, dataset_list, dataset_classical_list, epochs, batch_sizes_train, batch_sizes_val, batch_sizes_test, quantum_every, dataset_val_list,
                   dataset_val_classical_list, dataset_test_list, dataset_test_classical_list, lr, gamma, dataset_name, wb, config, file_output, loss_func, loss_func_test),
             nprocs=world_size,
             join=True)


def training_loop_single(model, dataset_list, dataset_classical_list, epochs, batch_sizes_train, batch_sizes_val=None, batch_sizes_test=None, device='cpu' ,quantum_every=4, dataset_val_list=None, dataset_val_classical_list=None,
                   dataset_test_list=None, dataset_test_classical_list=None, lr=.001, gamma=1., dataset_name='', wb=False, cli_args=None, seed=0, loss_func=None, loss_func_test=None, save_loss_train=True):
    """Train ``model`` in a single process, alternating quantum and classical passes.

    Epochs are 0-indexed. The schedule is:

    * On epochs with ``epoch % quantum_every == 0``, the model is trained on
      the quantum loaders, and each classical training dataset then has its
      attention matrices recomputed with ``compute_attention_matrices_single``.
    * On the other epochs, the model is trained on the classical loaders,
      which pass the cached attention matrices to the model.

    After every epoch the following losses are computed on the classical loaders
    (validation/test attention matrices are refreshed on quantum epochs first):

    * the training loss (if ``save_loss_train``);
    * the validation loss (if given);
    * the test loss (if given).

    They are printed, and logged to W&B when ``wb`` is set.

    Parameters
    ----------
    model : nn.Module
        Model to train; ``model.args`` is recorded in the W&B config.
    dataset_list, dataset_classical_list : list of datasets
        Training datasets for the quantum pass and their classical companions.
    epochs : int
        Number of epochs.
    batch_sizes_train, batch_sizes_val, batch_sizes_test : list of int
        Per-dataset batch sizes of the quantum loaders for each split.
    device : str or torch.device
        Device to train on.
    quantum_every : int
        Period, in epochs, of the quantum training pass (must be >= 1).
    dataset_val_list, dataset_val_classical_list : list or None
        Optional validation split.
    dataset_test_list, dataset_test_classical_list : list or None
        Optional test split.
    lr : float
        Adam learning rate.
    gamma, seed, cli_args, dataset_name :
        Recorded in the W&B config (``dataset_name`` also names the run).
    wb : bool
        Log to Weights & Biases.
    loss_func : callable, optional
        Training loss; defaults to ``F.l1_loss``. Also used (with
        ``reduction='sum'``) for the training-loss evaluation.
    loss_func_test : callable, optional
        Evaluation loss, called with ``reduction='sum'``; defaults to ``F.l1_loss``.
    save_loss_train : bool
        Evaluate the summed training loss over the classical training loaders,
        divided by the number of training samples.
    """
    exp_id = uuid.uuid4().hex[:10]

    if isinstance(model, HybridModel):
        name = f"Hybrid_{dataset_name}"
    else:
        name = f"QGNN_{dataset_name}"

    config = {
        "lr": lr,
        "gamma": gamma,
        "n_epochs": epochs,
        "dataset": dataset_name,
        "model": 'QGNN',
        "args": model.args,
        "seed": seed,
        "cli_args": cli_args
    }

    if wb:
        wandb.init(
            project="QGNN",
            entity="qgnn",
            name=f"{name}_{exp_id}",
            config=config
        )
        wandb.watch(model)

    device = torch.device(device)
    model = model.to(device)

    time_list = []

    dataloader_list, dataloader_classical_list, num_train = return_dataloaders(dataset_list, dataset_classical_list, batch_sizes_train)

    if dataset_val_list is not None:
        dataloader_val_list, dataloader_val_classical_list, num_val = return_dataloaders(dataset_val_list, dataset_val_classical_list, batch_sizes_val)

    if dataset_test_list is not None:
        dataloader_test_list, dataloader_test_classical_list, num_test = return_dataloaders(dataset_test_list, dataset_test_classical_list, batch_sizes_test)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    if loss_func is None:
        loss_func = F.l1_loss
    if loss_func_test is None:
        loss_func_test = F.l1_loss

    print(len(dataloader_list))
    loss_val_list, loss_test_list = [], []

    for epoch in range(epochs):
        t0 = time.perf_counter()
        if epoch % quantum_every == 0:
            for dataloader, dataset_classical in zip(dataloader_list, dataset_classical_list):
                for features, labels, ising, adj, _ in dataloader:
                    pred = model(features.to(device), ising.to(device), adj.to(device), labels.shape[0])
                    labels = labels.to(device)
                    try:
                        loss = loss_func(pred, labels)
                    except RuntimeError:
                        # Retry with flattened labels when the shapes are incompatible.
                        labels = labels.reshape((-1,))
                        loss = loss_func(pred, labels)
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                    del labels
                    del pred
                    torch.cuda.empty_cache()

                # Refresh the cached attention matrices with the updated weights.
                compute_attention_matrices_single(model, dataloader, dataset_classical, device)

                # Epochs are 0-indexed here, so the first epoch is 0.
                if epoch == 0:
                    print(dataset_classical.n_nodes)

        else:
            for dataloader_classical in dataloader_classical_list:
                for features, labels, adj, attention, _ in dataloader_classical:
                    pred = model(features.to(device), None, adj.to(device), labels.shape[0], attention.to(device))
                    labels = labels.to(device)
                    try:
                        loss = loss_func(pred, labels)
                    except RuntimeError:
                        labels = labels.reshape((-1,))
                        loss = loss_func(pred, labels)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    del labels
                    del pred
                    torch.cuda.empty_cache()
        t1 = time.perf_counter()
        time_list.append(t1 - t0)

        loss_train, loss_val, loss_test = 0, 0, 0

        # Evaluation uses `batch_loss` so `loss` keeps the last training-batch loss
        # for the epoch summary printed below.
        if save_loss_train:
            with torch.no_grad():
                # Evaluate on the *training* classical loaders (normalised by num_train).
                for dataloader_classical in dataloader_classical_list:
                    for features, labels, adj, attention, _ in dataloader_classical:
                        labels = labels.to(device)
                        batch_size = labels.shape[0]
                        pred = model(features.to(device), None, adj.to(device), batch_size, attention.to(device))
                        try:
                            batch_loss = loss_func(pred, labels, reduction='sum')
                        except RuntimeError:
                            labels = labels.reshape((-1,))
                            batch_loss = loss_func(pred, labels, reduction='sum')
                        loss_train += batch_loss.cpu() / num_train

        if dataset_val_list is not None:
            with torch.no_grad():
                if epoch % quantum_every == 0:
                    for dataloader, dataset_classical in zip(dataloader_val_list, dataset_val_classical_list):
                        compute_attention_matrices_single(model, dataloader, dataset_classical, device)

                for dataloader_classical in dataloader_val_classical_list:
                    for features, labels, adj, attention, _ in dataloader_classical:
                        labels = labels.to(device)
                        batch_size = labels.shape[0]
                        pred = model(features.to(device), None, adj.to(device), batch_size, attention.to(device))
                        try:
                            batch_loss = loss_func_test(pred, labels, reduction='sum')
                        except RuntimeError:
                            labels = labels.reshape((-1,))
                            batch_loss = loss_func_test(pred, labels, reduction='sum')
                        loss_val += batch_loss.cpu() / num_val
                loss_val_list.append(loss_val)

        if dataset_test_list is not None:
            with torch.no_grad():
                if epoch % quantum_every == 0:
                    for dataloader, dataset_classical in zip(dataloader_test_list, dataset_test_classical_list):
                        compute_attention_matrices_single(model, dataloader, dataset_classical, device)
                for dataloader_classical in dataloader_test_classical_list:
                    for features, labels, adj, attention, _ in dataloader_classical:
                        labels = labels.to(device)
                        batch_size = labels.shape[0]
                        pred = model(features.to(device), None, adj.to(device), batch_size, attention.to(device))
                        try:
                            batch_loss = loss_func_test(pred, labels, reduction='sum')
                        except RuntimeError:
                            labels = labels.reshape((-1,))
                            batch_loss = loss_func_test(pred, labels, reduction='sum')
                        loss_test += batch_loss.cpu() / num_test
                loss_test_list.append(loss_test)

        # Collect history in W&B
        logs = {
            'epoch': epoch + 1,
            'loss_val': loss_val,
            'loss_test': loss_test,
            'loss_train': loss_train,
            'true_loss_train': loss_train,
            'time': t1 - t0,
        }
        if wb:
            wandb.log(logs)
        print("Epoch {} | Time {} | Loss {} | Val {} | Test {} ".format(epoch, t1 - t0, loss.detach().cpu(), loss_val, loss_test))
    if wb:
        wandb.finish()
