import torch
import wandb

import torch.nn as nn

from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from tqdm import tqdm

from torch.nn.utils import clip_grad_norm_

# Public API of this module: the per-epoch training and validation loops.
__all__ = ['train_epoch', 'val_epoch']


def train_epoch(model: nn.Module,
                train_loader: DataLoader, optimizer: Optimizer,
                epoch: int, device: torch.device, print_or_not: bool = True,
                *, max_grad_norm: float = 2.0, skip_grad_norm: float = 10.0,
                min_lr: float = 5e-7) -> bool:
    """Run one training epoch and report whether the learning rate fell below ``min_lr``.

    For every batch, ``model.get_loss`` is called and ``loss_dict['total']`` is
    back-propagated. Gradients are clipped to ``max_grad_norm``, and the
    optimizer steps unless the pre-clipping gradient norm is rejected (see
    ``skip_grad_norm``). Per-step metrics are sent with ``wandb.log`` on a
    best-effort basis: logging errors are ignored.

    Args:
        model: Module exposing ``get_loss(batch) -> (loss_dict, batch_size)``.
            ``loss_dict['total']`` is optimised; every value must support
            ``.item()`` to be logged.
        train_loader: Iterable yielding 1-tuples ``(batch,)``.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Current epoch index; also logged as ``custom_step``.
        device: Device each batch is moved to.
        print_or_not: If True, wrap the loader in a tqdm progress bar whose
            description (loss, lr) is refreshed every 5 batches.
        max_grad_norm: Max total gradient norm passed to ``clip_grad_norm_``.
        skip_grad_norm: A step is skipped when the pre-clipping norm is not
            below this value. The check is not applied when ``epoch == 0``.
        min_lr: Learning-rate threshold used for the return value.

    Returns:
        True if the first param group's learning rate, read after the last
        batch, is below ``min_lr``. False for an empty loader.
    """
    model.train()
    if print_or_not:
        pbar = tqdm(train_loader, leave=False, dynamic_ncols=True)
    else:
        pbar = train_loader

    # Kept numeric (never a formatted string) so the final comparison is exact.
    # Stays inf -> returns False if the loader yields nothing.
    lr = float('inf')
    for i, (batch, ) in enumerate(pbar):
        optimizer.zero_grad()
        loss_dict, _ = model.get_loss(batch.to(device))
        loss: Tensor = loss_dict['total']
        loss.backward()
        # clip_grad_norm_ returns the total norm of the gradients *before* clipping.
        orig_grad_norm = clip_grad_norm_(model.parameters(), max_grad_norm)
        # A NaN/inf norm leaves NaN/inf gradients after clipping; stepping on
        # them would permanently corrupt the weights, so never step in that
        # case. Large-but-finite norms are still accepted when epoch == 0.
        norm_is_finite = bool(torch.isfinite(orig_grad_norm))
        if norm_is_finite and (epoch == 0 or orig_grad_norm < skip_grad_norm):
            optimizer.step()

        lr = float(optimizer.param_groups[0]['lr'])
        if print_or_not and i % 5 == 0:
            train_loss = str(round(loss.item(), 7)).ljust(8)
            pbar.set_description(f'Training loss: {train_loss} lr: {lr:.4e}')

        try:
            wandb.log({
                'train/lr': lr,
                'train/grad_norm': orig_grad_norm.item(),
                'custom_step': epoch
            } | {f'train/{k}': v.item() for k, v in loss_dict.items()})
        except Exception:
            # Logging is best-effort. Unlike the former bare ``except``, this
            # no longer swallows KeyboardInterrupt / SystemExit.
            pass
    return lr < min_lr


def val_epoch(model: nn.Module, val_loader: DataLoader, epoch: int,
              device: torch.device, print_or_not: bool = True) -> float:
    """Evaluate ``model`` on ``val_loader`` and return the batch-size-weighted mean score.

    The result is also sent with ``wandb.log`` as ``val/loss`` on a
    best-effort basis: logging errors are ignored.

    Args:
        model: Module exposing ``get_loss(batch) -> (loss_dict, batch_size)``.
            ``loss_dict['score']`` is the per-batch value being averaged.
        val_loader: Iterable yielding 1-tuples ``(batch,)``.
        epoch: Current epoch index; logged as ``custom_step``.
        device: Device each batch is moved to.
        print_or_not: If True, wrap the loader in a tqdm progress bar.

    Returns:
        ``sum(score_b * batch_size_b) / sum(batch_size_b)`` over all batches.

    Raises:
        ValueError: If the loader produced no samples (total batch size 0).
    """
    model.eval()
    if print_or_not:
        pbar = tqdm(val_loader, desc='Validation', leave=False, dynamic_ncols=True)
    else:
        pbar = val_loader

    sum_loss, sum_n = 0, 0
    with torch.no_grad():
        for (batch, ) in pbar:
            loss_dict, batch_size = model.get_loss(batch.to(device))
            loss: Tensor = loss_dict['score']
            # Presumably 'score' is a per-batch mean; weighting by batch_size
            # yields a per-sample average (a short last batch is not over-weighted).
            sum_loss += loss.item() * batch_size
            sum_n += batch_size

    if sum_n == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError('val_epoch: val_loader produced no samples; cannot average the loss')

    avg_loss = sum_loss / sum_n

    try:
        wandb.log({
            'val/loss': avg_loss,
            'custom_step': epoch
        })
    except Exception:
        # Logging is best-effort; don't swallow KeyboardInterrupt / SystemExit.
        pass

    return avg_loss