from typing import Optional

from sympy import use
import torch
from tensornvme import DiskOffloader
from loguru import logger
from utils.others import compute_micro_f1

def validate_test(epoch, model, loader, inductive: bool):
    r"""Evaluate ``model`` on ``loader`` for inductive training.

    Runs a single forward pass with autograd disabled, then averages
    ``compute_micro_f1`` over the batches, weighting each batch by its
    ``batch_size``.

    Args:
        epoch: not used by this function.
        model: called with no arguments; its result is indexed per batch
            (``out[i]`` pairs with the i-th batch of ``loader``).
        loader: iterable yielding ``(batch, batch_size, *extra)``; the labels
            are ``batch.y[:batch_size]``.
        inductive: must be truthy; this routine is only for inductive training.

    Returns:
        The batch-size-weighted average of ``compute_micro_f1``, or ``0.0``
        when the loader yields no samples.

    Raises:
        AssertionError: if ``inductive`` is falsy.
    """
    assert inductive, 'validate/test function must be called by inductive training'

    model.eval()
    # Inference only: do not build (and retain) an autograd graph.
    with torch.no_grad():
        out = model()

    n_correct = 0
    n_val_test = 0

    for i, (batch, batch_size, *_) in enumerate(loader):
        y = batch.y[:batch_size]
        n_val_test += batch_size
        n_correct += compute_micro_f1(out[i], y) * batch_size

    # Guard the empty-loader case instead of raising ZeroDivisionError.
    return n_correct / n_val_test if n_val_test else 0.0

def _prefetch_output(accelerator_outs, host_storage_tensors, idx, stream):
    r"""Re-allocate ``accelerator_outs[idx]``'s storage and start its async upload on ``stream``."""
    out = accelerator_outs[idx]
    # Restore the full byte size of the storage (it is shrunk to 0 after use).
    out.untyped_storage().resize_(out.numel() * out.element_size())
    # The storage is allocated on the current stream. The caching allocator may
    # return a block that was just freed while kernels already enqueued on the
    # current stream still read it. Order the copy after that pending work.
    # (Assumes async_upload enqueues its copy on ``stream``.)
    stream.wait_stream(torch.cuda.current_stream())
    host_storage_tensors.async_upload(idx, out, stream)


def _release_output(accelerator_outs, idx):
    r"""Free the accelerator storage of ``accelerator_outs[idx]`` by shrinking it to 0 bytes."""
    accelerator_outs[idx].untyped_storage().resize_(0)


def train(epoch, model, loader, criterion, optimizer, inductive: Optional[bool] = False,
          storage_offload: Optional[bool] = False):
    r"""Run one training epoch with double-buffered host-to-accelerator uploads.

    One forward pass, ``model()``, returns the per-batch outputs together with
    an object holding their host-side copies. The function then iterates the
    loader:

    - it uploads the output of batch ``i`` (prefetching batch ``i+1`` on the
      other of two CUDA streams);
    - it computes that batch's loss;
    - it frees the accelerator storage again.

    Finally it calls ``model.backward(losses)`` and ``optimizer.step()``.

    Args:
        epoch: used only in the log message.
        model: callable returning ``(accelerator_outs, host_storage_tensors)``
            and exposing ``backward(losses)``.
        loader: iterable yielding ``(batch, batch_size, *extra)``. In the
            non-inductive case ``batch`` must carry ``train_mask``,
            ``val_mask`` and ``test_mask``.
        criterion: ``criterion(logits, labels)`` returning a scalar tensor.
        optimizer: object exposing ``zero_grad()`` and ``step()``.
        inductive: if False, validation/test accuracy is computed from the
            same forward pass using the masks, and only masked nodes train.
        storage_offload: currently unused.

    Returns:
        ``dict(model=..., val_acc=..., test_acc=...)``. Both accuracies are
        ``0.0`` in inductive mode, and ``0.0`` for a split with no nodes.
    """
    model.train()

    optimizer.zero_grad()

    accelerator_outs, host_storage_tensors = model()  # forward

    test_acc = 0.0
    val_acc = 0.0

    # Only meaningful when not inductive; initialized unconditionally for clarity.
    n_test_correct = 0
    n_test = 0
    n_val_correct = 0
    n_val = 0

    # The loss is small, so it is kept on the accelerator rather than offloaded.
    losses = []
    total_loss = 0.0
    n_batches = len(loader)

    # Double buffering: while batch i is consumed, batch i+1 uploads on the other stream.
    # NOTE(review): the device is hard-coded to cuda:0, while labels use .cuda()
    # (the current device) -- confirm these always agree.
    pool_size = 2
    h2d_streams = [torch.cuda.Stream('cuda:0') for _ in range(pool_size)]

    _prefetch_output(accelerator_outs, host_storage_tensors, 0, h2d_streams[0])

    for i, (batch, batch_size, *_) in enumerate(loader):
        y = batch.y[:batch_size].cuda()

        # Make batch i's activation available to kernels on the current stream.
        cur_stream = h2d_streams[i % pool_size]
        host_storage_tensors.h2d_synchronize(cur_stream)
        torch.cuda.current_stream().wait_stream(cur_stream)

        if i != n_batches - 1:
            _prefetch_output(accelerator_outs, host_storage_tensors, i + 1,
                             h2d_streams[(i + 1) % pool_size])

        out = accelerator_outs[i]

        if not inductive:
            train_mask = batch.train_mask[:batch_size]
            val_mask = batch.val_mask[:batch_size]
            test_mask = batch.test_mask[:batch_size]
            n_val_batch = val_mask.sum().item()
            n_test_batch = test_mask.sum().item()
            n_val += n_val_batch
            n_test += n_test_batch
            if n_val_batch != 0:
                n_val_correct += compute_micro_f1(out, y, val_mask) * n_val_batch
            if n_test_batch != 0:
                n_test_correct += compute_micro_f1(out, y, test_mask) * n_test_batch
            if train_mask.sum() == 0:
                # No training nodes in this batch. Still release its storage;
                # skipping this step leaked the re-allocated buffer.
                # NOTE(review): no loss is appended here, so model.backward must
                # not assume one loss per batch -- confirm.
                _release_output(accelerator_outs, i)
                del y, train_mask
                continue
            loss = criterion(out[train_mask], y[train_mask])
        else:
            loss = criterion(out, y)

        losses.append(loss)
        total_loss += loss.item()

        # The loss has been built; free the accelerator copy of this output.
        _release_output(accelerator_outs, i)

    if not inductive:
        test_acc = n_test_correct / n_test if n_test else 0.0
        val_acc = n_val_correct / n_val if n_val else 0.0
        # Average over the batches that actually contributed a loss.
        mean_loss = total_loss / len(losses) if losses else 0.0

        logger.info(f'Epoch {epoch: 4d} | Loss: {mean_loss: 4f} | Val. Acc (%): {val_acc: .2%} | Test Acc (%): {test_acc: .2%}')

    model.backward(losses)

    optimizer.step()

    return dict(
        model=model,
        val_acc=val_acc,
        test_acc=test_acc
    )