import warnings

import torch
import torch.nn as nn
import wandb
from torch import Tensor
from torch.nn.utils import clip_grad_norm_
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.diffusion import Diffusion

__all__ = [
    'train_epoch',
    'val_epoch'
]


def train_epoch(
    model: Diffusion,
    train_loader: DataLoader, optimizer: Optimizer,
    epoch: int, device: torch.device
):
    """Run one optimisation pass over ``train_loader``.

    For every batch, the loss stored under ``'total'`` in the dict returned
    by ``model.get_loss`` is back-propagated, gradients are clipped to a
    global norm of 100.0, and the optimizer is stepped unless the
    pre-clipping gradient norm is non-finite. Per-batch metrics are sent to
    wandb on a best-effort basis: a logging failure emits a warning and
    training continues.

    Args:
        model: Model exposing ``get_loss(x0)``, which returns a pair whose
            first item is a dict of single-element loss tensors containing
            at least the key ``'total'``.
        train_loader: Loader yielding one-element tuples ``(x0,)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        epoch: Current epoch index, logged to wandb as ``'custom_step'``.
        device: Device each batch is moved to before computing the loss.

    Returns:
        ``True`` if the learning rate of the first parameter group, as read
        on the last batch, is below ``5e-7``; ``False`` otherwise, including
        when the loader yields no batches. (Presumably used by the caller as
        a stop signal -- confirm against the training loop.)
    """
    model.train()
    pbar = tqdm(train_loader, leave=False, dynamic_ncols=True)
    # Stays at +inf when the loader is empty, so the function returns False.
    lr = float('inf')
    for (x0, ) in pbar:
        optimizer.zero_grad()
        loss_dict, _ = model.get_loss(x0.to(device))

        loss: Tensor = loss_dict['total']
        loss.backward()

        orig_grad_norm = clip_grad_norm_(model.parameters(), 100.0)
        # Skip the update for NaN *and* infinite norms: with an infinite norm
        # the clip coefficient becomes 0 and inf * 0 puts NaN into the
        # gradients, which would then be written into the parameters.
        if torch.isfinite(orig_grad_norm):
            optimizer.step()

        # Keep the learning rate numeric; it is only formatted for display.
        lr = float(optimizer.param_groups[0]['lr'])
        train_loss = str(round(loss.item(), 7)).ljust(8)
        pbar.set_description(f'Training loss: {train_loss} lr: {lr:.4e}')

        try:
            wandb.log({
                'train/lr': lr,
                'train/grad_norm': orig_grad_norm.item(),
                'custom_step': epoch
            } | {f'train/{k}': v.item() for k, v in loss_dict.items()})
        except Exception as exc:
            # Logging is best-effort and must not abort training, but the
            # failure should stay visible. KeyboardInterrupt and SystemExit
            # are deliberately not caught here.
            warnings.warn(f'wandb logging failed: {exc!r}', RuntimeWarning)
    return lr < 5e-7


def val_epoch(
    model: nn.Module,
    val_loader: DataLoader,
    epoch: int,
    device: torch.device
):
    """Compute the weighted average validation loss over ``val_loader``.

    The model is put in eval mode and gradients are disabled. For each batch
    the loss stored under ``'score'`` in the dict returned by
    ``model.get_loss`` is weighted by the second value that call returns
    (used here as the batch size) and accumulated. The resulting average is
    sent to wandb on a best-effort basis: a logging failure emits a warning
    instead of raising.

    Args:
        model: Module that must expose ``get_loss(x0)`` returning
            ``(loss_dict, batch_size)``, where ``loss_dict`` contains a
            single-element tensor under ``'score'``.
        val_loader: Loader yielding one-element tuples ``(x0,)``.
        epoch: Current epoch index, logged to wandb as ``'custom_step'``.
        device: Device each batch is moved to before computing the loss.

    Returns:
        The average ``'score'`` loss weighted by batch size, or ``nan`` (with
        a warning and no wandb entry) when no samples were processed.
    """
    sum_loss, sum_n = 0.0, 0
    model.eval()
    pbar = tqdm(val_loader, desc='Validation', leave=False, dynamic_ncols=True)
    with torch.no_grad():
        for (x0, ) in pbar:
            loss_dict, batch_size = model.get_loss(x0.to(device))

            loss: Tensor = loss_dict['score']
            sum_loss += loss.item() * batch_size
            sum_n += batch_size

    if sum_n == 0:
        # An empty loader would otherwise fail with ZeroDivisionError below.
        warnings.warn(
            'val_epoch processed no samples; returning nan', RuntimeWarning
        )
        return float('nan')

    avg_loss = sum_loss / sum_n

    try:
        wandb.log({
            'val/loss': avg_loss,
            'custom_step': epoch
        })
    except Exception as exc:
        # Logging is best-effort; report the failure but still return the
        # loss. KeyboardInterrupt and SystemExit are deliberately not caught.
        warnings.warn(f'wandb logging failed: {exc!r}', RuntimeWarning)

    return avg_loss