from functools import partial

import wandb
import numpy as np
import torch
from torch import Tensor
from torch import nn
from torch.optim import Optimizer
from torch_geometric.loader import DataLoader

def train(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: Optimizer,
    loss_func: nn.modules.loss._Loss,
    device: str,
) -> float:
    """Run one optimisation pass over ``train_loader``.

    For every batch the gradients are reset, the model is evaluated, the loss
    is back-propagated, the gradient L2 norm is clipped to 20 and the
    optimizer takes a step.

    Args:
        model: Network to optimise; switched to training mode here.
        train_loader: Yields batches exposing ``to``, ``detach``,
            ``num_graphs`` and a ``'label'`` entry.
        optimizer: Optimizer that updates the parameters of ``model``.
        loss_func: Called as ``loss_func(label, prediction)`` -- label first.
        device: Device every batch is moved to before the forward pass.

    Returns:
        A Python float: the sum over batches of
        ``loss.item() * data.num_graphs`` divided by
        ``len(train_loader.dataset)``.
        NOTE(review): this is a per-graph average only if ``loss_func``
        returns a batch mean -- confirm against the losses actually used.

    Raises:
        AssertionError: If a batch has no ``'label'`` entry or the label and
            prediction sizes differ.
    """
    model.train()
    loss_all = 0.
    for data in train_loader:
        data = data.to(device)
        # presumably drops any autograd history carried by the loaded batch
        # -- TODO confirm this is still needed
        data = data.detach()
        optimizer.zero_grad()
        pred = model(data)

        assert 'label' in data, "batch has no 'label' entry"
        # An exact size match rules out silent broadcasting inside the loss.
        assert data['label'].size() == pred.size(), (
            f"label size {tuple(data['label'].size())} does not match "
            f"prediction size {tuple(pred.size())}"
        )
        loss = loss_func(data['label'], pred)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
        # Weight each batch by its graph count so that a smaller final batch
        # does not skew the epoch figure.
        loss_all += loss.item() * data.num_graphs
        optimizer.step()
    return loss_all / len(train_loader.dataset)

def eval(
    model: nn.Module,
    loader: DataLoader,
    loss_func: nn.modules.loss._Loss,
    device: str,
) -> float:
    """Compute the dataset-averaged loss of ``model`` over ``loader``.

    No parameters are updated. The model is switched to evaluation mode and
    the whole pass runs with gradient tracking disabled.

    NOTE: the name shadows the built-in ``eval``; it is kept so that existing
    callers keep working.

    Args:
        model: Network to evaluate; left in evaluation mode afterwards.
        loader: Yields batches exposing ``to``, ``detach``, ``num_graphs``
            and a ``'label'`` entry.
        loss_func: Called as ``loss_func(label, prediction)`` -- label first.
        device: Device every batch is moved to before the forward pass.

    Returns:
        A Python float: the sum over batches of
        ``loss.item() * data.num_graphs`` divided by ``len(loader.dataset)``.

    Raises:
        AssertionError: If a batch has no ``'label'`` entry or the label and
            prediction sizes differ.
    """
    model.eval()
    loss_all = 0.
    # Nothing is back-propagated here, so do not build an autograd graph for
    # every forward pass. A model that needs gradients inside its forward
    # can re-enable them locally with ``torch.enable_grad()``.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            data = data.detach()
            pred = model(data)

            assert 'label' in data, "batch has no 'label' entry"
            # An exact size match rules out silent broadcasting in the loss.
            assert data['label'].size() == pred.size(), (
                f"label size {tuple(data['label'].size())} does not match "
                f"prediction size {tuple(pred.size())}"
            )
            loss = loss_func(data['label'], pred)
            # Weight each batch by its graph count before averaging.
            loss_all += loss.item() * data.num_graphs
    return loss_all / len(loader.dataset)

def timer(func, *args, **kwargs):
    """Invoke ``func(*args, **kwargs)`` and measure how long the call takes.

    Args:
        func: Callable to run.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        A ``(duration, result)`` tuple: the elapsed time in seconds as read
        from ``time.perf_counter`` and whatever ``func`` returned.
    """
    from time import perf_counter

    began = perf_counter()
    outcome = func(*args, **kwargs)
    elapsed = perf_counter() - began
    return elapsed, outcome

def run_experiment(
    model: nn.Module,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    test_loader: DataLoader,
    num_epochs: int,
    optimizer: Optimizer,
    loss_func: nn.modules.loss._Loss,
    eval_interval: int = 5,
    early_stop: float = float('inf'),
    device: str = 'cpu'
) -> None:
    """Train ``model`` with periodic evaluation, logging everything to wandb.

    The loop runs for epoch indices ``0 .. num_epochs`` inclusive, i.e.
    ``num_epochs + 1`` training passes. After each pass the training time and
    loss are logged. Whenever the epoch index is a multiple of
    ``eval_interval`` (which includes index 0) the validation and test
    loaders are evaluated, and the evaluation with the lowest validation
    loss so far is remembered together with its test loss.

    Args:
        model: Network to train; moved to ``device`` before training starts.
        train_loader: Batches used for optimisation.
        valid_loader: Batches used to select the best epoch.
        test_loader: Batches whose loss is reported at the best epoch.
        num_epochs: Highest epoch index to run (inclusive).
        optimizer: Optimizer passed through to ``train``.
        loss_func: Loss passed through to ``train`` and ``eval``.
        eval_interval: Evaluate on epochs whose index is divisible by this.
        early_stop: Stop once the current evaluation epoch is at least this
            many epochs past the best one. The check only happens on
            evaluation epochs; the default of infinity never triggers it.
        device: Device used for the model and every batch.

    Returns:
        None. Results are reported via ``print`` and ``wandb.log``; the final
        log entry carries ``best_epoch`` and ``best_test``.
    """
    print(model)
    trainable_count = sum(
        param.numel() for param in model.parameters() if param.requires_grad
    )
    wandb.log({'total_param': trainable_count})
    print(f'Number of Parameters: {trainable_count}')

    model = model.to(device)

    # Best evaluation seen so far; all three stay None until the first one.
    best_epoch = None
    best_valid_loss = None
    best_test_loss = None

    def timed_eval(loader):
        # Evaluate one loader and return (elapsed seconds, loss).
        return timer(
            eval, model=model, loss_func=loss_func, device=device, loader=loader
        )

    for epoch_idx in range(num_epochs + 1):
        train_time, train_loss = timer(
            train, model, train_loader, optimizer, loss_func, device
        )
        train_record = {
            'epoch': epoch_idx,
            'train_times': train_time,
            'train_loss': train_loss,
        }
        wandb.log(train_record)
        print(f'Epoch: {epoch_idx} | Train Loss: {train_loss:.6f}')

        # Only every ``eval_interval``-th epoch is evaluated.
        if epoch_idx % eval_interval != 0:
            continue

        valid_time, valid_loss = timed_eval(valid_loader)
        test_time, test_loss = timed_eval(test_loader)

        is_first_eval = best_valid_loss is None
        if is_first_eval or valid_loss < best_valid_loss:
            best_epoch = epoch_idx
            best_valid_loss = valid_loss
            best_test_loss = test_loss

        eval_record = {
            'epoch': epoch_idx,
            'valid_times': valid_time,
            'valid_loss': valid_loss,
            'test_times': test_time,
            'test_loss': test_loss,
        }
        wandb.log(eval_record)

        print(f'>>> Epoch: {epoch_idx} | Valid Loss: {valid_loss:.6f} | Test Loss: {test_loss:.6f}')
        print(f'*** Best Epoch: {best_epoch} | Valid Loss: {best_valid_loss:.6f} | Test Loss: {best_test_loss:.6f}')

        epochs_since_best = epoch_idx - best_epoch
        if epochs_since_best >= early_stop:
            wandb.log({'early_stop_epoch': epoch_idx})
            print(f'Early stopped! Epoch: {epoch_idx}')
            break

    wandb.log({
        'best_epoch': best_epoch,
        'best_test': best_test_loss,
    })
    return