import torch
from statistics import mean

from models import ResConvModel_via_matmul
from torch_geometric.loader import DataLoader
from utils import QM7DataWithResolvents
from datasets import QM9, ResolvDataset
from util.checkpoint_io import save_chkpt
from util.base_util import save_path


def train():
    """Run one optimisation epoch over ``train_loader`` and return the mean loss.

    Uses the module-level ``model``, ``train_loader``, ``optimizer``,
    ``criterion`` and ``device`` that the script entry point sets up.

    Returns:
        float: Average of the batch losses, weighted by the number of graphs
        in each batch. This assumes ``criterion`` averages within a batch
        (``reduction='mean'``), so that a short final batch does not count as
        much as a full one.

    Raises:
        statistics.StatisticsError: If ``train_loader`` yields no graphs.
    """
    # Same exception type that ``mean([])`` raised for an empty loader.
    from statistics import StatisticsError

    model.train()

    loss_sum = 0.0
    graph_count = 0
    for data in train_loader:  # Iterate in batches over the training dataset.
        optimizer.zero_grad()  # Clear gradients.
        data = data.to(device)
        out = model(
            x=data.Z_one_hot,
            Z=data.Z,
            edge_index=data.r_edge_index,
            edge_attr=data.r_edge_attr,
            batch=data.batch,
        )  # Perform a single forward pass.

        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

        # Undo the per-batch mean so batches of different size are weighted
        # by how many graphs they actually contain.
        batch_graphs = data.num_graphs
        loss_sum += loss.item() * batch_graphs
        graph_count += batch_graphs

    if graph_count == 0:
        raise StatisticsError("train() requires a non-empty train_loader")

    return loss_sum / graph_count
        

def test(loader):
    """Evaluate the model on ``loader`` without gradient tracking.

    Uses the module-level ``model``, ``criterion`` and ``device`` that the
    script entry point sets up.

    Args:
        loader: Iterable of batches exposing ``Z_one_hot``, ``Z``,
            ``r_edge_index``, ``r_edge_attr``, ``batch``, ``y`` and
            ``num_graphs``.

    Returns:
        float: Average of the batch losses, weighted by the number of graphs
        in each batch. This assumes ``criterion`` averages within a batch
        (``reduction='mean'``), so the result is the loss over the whole
        loader rather than a mean of per-batch means.

    Raises:
        statistics.StatisticsError: If ``loader`` yields no graphs.
    """
    # Same exception type that ``mean([])`` raised for an empty loader.
    from statistics import StatisticsError

    model.eval()

    loss_sum = 0.0
    graph_count = 0
    with torch.no_grad():
        for data in loader:  # Iterate in batches over the given dataset.
            data = data.to(device)
            out = model(
                x=data.Z_one_hot,
                Z=data.Z,
                edge_index=data.r_edge_index,
                edge_attr=data.r_edge_attr,
                batch=data.batch,
            )  # Perform a single forward pass.

            loss = criterion(out, data.y)

            # Undo the per-batch mean so a short final batch is not
            # over-weighted in the reported metric.
            batch_graphs = data.num_graphs
            loss_sum += loss.item() * batch_graphs
            graph_count += batch_graphs

    if graph_count == 0:
        raise StatisticsError("test() requires a non-empty loader")

    return loss_sum / graph_count


if __name__ == "__main__":
    # Script entry point: configure the run, load and split the dataset,
    # build the model, then train while periodically saving checkpoints.
    # Names are kept at module level because train() and test() read
    # `model`, `optimizer`, `criterion`, `device` and `train_loader` as globals.

    ######################## Checkpoint folder  ########################
    folder_out = save_path()
    print(f"Output folder: {folder_out}")
    chkpt_save_frequency = 100  # Save a checkpoint every this many epochs
    n_epochs = 1000  # Total number of training epochs
    ##################################################################

    ######################## Resolvent Parameters ########################
    omega = -1  # Expansion parameter
    nf = 1  # Normalising factor
    ##################################################################

    ######################## Model Parameters ########################
    ModelShape = [64, 64]  # Shape of GNN
    K = 2  # Highest Resolvent exponent
    p = 1  # Aggregation norm
    input_dimension = 17  # Input dimensions
    ##################################################################

    ######################## Training Parameters ########################
    lr = 0.005  # Learning rate
    reg_lambda = 1e-3  # Weight decay Parameter  (# reg_lambda = 1e-4)
    # NOTE: `patience` and `min_delta` are declared but not used by the
    # training loop below; no early stopping is performed in this block.
    patience = 20  # Early stopping Patience
    min_delta = 1  # Early stopping tolerance
    batch_size = 256  # Batch size
    test_set_size = 1500  # Number of samples in the test set
    #####################################################################

    ####### Random seed ################
    torch.manual_seed(1)
    ####################################

    ######################## Dataset  ########################
    dataset = QM7DataWithResolvents(
        root='./data/Quamtum/QM7DataResolvents_with_Z', omega=omega, nf=nf
    )
    ##############################################################################################

    ##### Dataset Statistics #########
    print()
    print(f'Dataset: {dataset.data}:')
    print('====================')
    print(f'Number of graphs: {len(dataset)}')
    print(f'Number of features: {dataset.num_features}')
    print(f'number of classes is {dataset.num_classes}')
    ################################################################

    ############## Shuffle Dataset; Define TRAIN and TEST ##################
    print('====================')
    dataset = dataset.shuffle()

    # The first `test_set_size` shuffled samples form the test set; the
    # remainder is used for training.
    train_dataset = dataset[test_set_size:]
    test_dataset = dataset[:test_set_size]

    print(f'train_dataset: {train_dataset}')
    print(f'Number of training graphs: {len(train_dataset)}')
    print(f'Number of test graphs: {len(test_dataset)}')

    ############### Build batched loaders ####################
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    ################ Train AND Test ######################
    # Fall back to the CPU when no CUDA device is present instead of
    # crashing on `.to('cuda')`.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = ResConvModel_via_matmul(
        input_dimension=input_dimension,
        hidden_channel_list=ModelShape,
        K_minus=K,
        p=p,
        zero_order=False,
        bias=True,
    )

    model = model.to(device)
    print(f'Model: {model}')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg_lambda)

    # Mean absolute error, reported below as "MAE".
    criterion = torch.nn.L1Loss(reduction='mean')

    for epoch in range(n_epochs):
        train_loss = train()
        test_loss = test(test_loader)
        # Checkpoint on every `chkpt_save_frequency`-th epoch (including
        # epoch 0) and always on the final epoch.
        if epoch == n_epochs - 1 or epoch % chkpt_save_frequency == 0:
            save_chkpt(folder_out, model, epoch, optimizer)

        print(f'Epoch: {epoch:03d}, Train MAE: {train_loss:.4f}, Test MAE: {test_loss:.4f}')



