import torch
import wandb

import torch.nn as nn

from torch import Tensor
from torch.optim import Optimizer
from torch_geometric.loader import DataLoader

from models.diffusion import Diffusion

from tqdm import tqdm

from torch.nn.utils import clip_grad_norm_

# Public API of this module: the per-epoch training and validation drivers.
__all__ = ['train_epoch', 'val_epoch']


def train_epoch(
    model: Diffusion,
    train_loader: DataLoader, optimizer: Optimizer,
    epoch: int, device: torch.device,
    *,
    max_grad_norm: float = 2.0,
    skip_grad_norm: float = 10.0,
    min_lr: float = 5e-7,
) -> bool:
    """Run one training epoch over ``train_loader``.

    For every batch the ``'total'`` loss returned by ``model.get_loss`` is
    back-propagated, gradients are clipped to ``max_grad_norm``, and the
    optimizer steps only when the pre-clipping gradient norm is finite and
    below ``skip_grad_norm``. The norm threshold is waived during epoch 0,
    but a non-finite norm is never stepped on. Per-batch metrics are sent to
    wandb on a best-effort basis.

    Args:
        model: model exposing ``get_loss(func_value, condition)``, which
            returns ``(loss_dict, batch_size)``; ``loss_dict['total']`` is
            the loss that is optimized.
        train_loader: loader yielding ``(_, y, _)`` triples; only ``y`` is used.
        optimizer: optimizer over ``model``'s parameters.
        epoch: current epoch index; also logged to wandb as ``custom_step``.
        device: device each batch tensor is moved to.
        max_grad_norm: max total gradient norm passed to ``clip_grad_norm_``.
        skip_grad_norm: pre-clipping norm at or above which the optimizer
            step is skipped (ignored in epoch 0).
        min_lr: learning-rate floor used to compute the return value.

    Returns:
        True if the learning rate of the first param group, read after the
        last batch, is below ``min_lr``; False otherwise (including when
        ``train_loader`` yields no batches).
    """
    model.train()
    pbar = tqdm(train_loader, leave=False, dynamic_ncols=True)
    lr = float('inf')  # stays inf (-> returns False) if the loader is empty
    for _, y, _ in pbar:
        optimizer.zero_grad()
        # NOTE(review): assumes y has a size-1 dim 3 and exactly 2 entries on
        # its last dim, so the permute unpacks into two tensors — confirm
        # against the dataset.
        condition, func_value = y.to(device).squeeze(dim=3).permute(3, 0, 1, 2)
        loss_dict, _ = model.get_loss(func_value, condition)

        loss: Tensor = loss_dict['total']
        loss.backward()

        # clip_grad_norm_ returns the total norm measured *before* clipping.
        grad_norm = clip_grad_norm_(model.parameters(), max_grad_norm)
        # A NaN/inf norm makes clipping scale the gradients by a NaN/zero
        # coefficient, so stepping could write NaNs into the parameters and
        # optimizer state: never step on it. Otherwise skip unusually large
        # norms, except during the first epoch.
        if torch.isfinite(grad_norm) and (grad_norm < skip_grad_norm or epoch == 0):
            optimizer.step()

        lr = optimizer.param_groups[0]['lr']
        train_loss = str(round(loss.item(), 7)).ljust(8)
        pbar.set_description(f'Training loss: {train_loss} lr: {lr:.4e}')

        try:
            wandb.log({
                'train/lr': lr,
                'train/grad_norm': grad_norm.item(),
                'custom_step': epoch
            } | {f'train/{k}': v.item() for k, v in loss_dict.items()})
        except Exception:
            # Best-effort logging: a wandb failure (presumably e.g. no active
            # run) must not abort training. Unlike a bare ``except``, this
            # still lets KeyboardInterrupt/SystemExit propagate.
            pass
    return float(lr) < min_lr


def val_epoch(
    model: nn.Module,
    val_loader: DataLoader,
    epoch: int,
    device: torch.device
) -> float:
    """Return the sample-weighted mean ``'score'`` loss over ``val_loader``.

    Runs the model in eval mode without gradient tracking, averages
    ``loss_dict['score']`` weighted by the batch size reported by
    ``model.get_loss``, and logs the result to wandb on a best-effort basis.

    Args:
        model: module exposing ``get_loss(func_value, condition)``, which
            returns ``(loss_dict, batch_size)``.
        val_loader: loader yielding ``(_, y, _)`` triples; only ``y`` is used.
        epoch: current epoch index; logged to wandb as ``custom_step``.
        device: device each batch tensor is moved to.

    Returns:
        The weighted mean validation ``'score'`` loss.

    Raises:
        ValueError: if ``val_loader`` yields no samples (total batch size 0).
    """
    sum_loss, sum_n = 0, 0
    model.eval()
    pbar = tqdm(val_loader, desc='Validation', leave=False, dynamic_ncols=True)
    with torch.no_grad():
        for _, y, _ in pbar:
            # NOTE(review): same input layout assumption as training — y has
            # a size-1 dim 3 and 2 entries on its last dim; confirm.
            condition, func_value = y.to(device).squeeze(dim=3).permute(3, 0, 1, 2)
            loss_dict, batch_size = model.get_loss(func_value, condition)

            loss: Tensor = loss_dict['score']
            # Weight by batch size so a short final batch is not over-counted.
            sum_loss += loss.item() * batch_size
            sum_n += batch_size

    if sum_n == 0:
        raise ValueError(
            'val_epoch: val_loader yielded no samples; '
            'cannot compute a mean validation loss'
        )
    avg_loss = sum_loss / sum_n

    try:
        wandb.log({
            'val/loss': avg_loss,
            'custom_step': epoch
        })
    except Exception:
        # Best-effort logging: a wandb failure must not abort training, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        pass

    return avg_loss