import torch
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from time import time as default_timer
import tqdm
import os
import gc
import wandb
from  .training_handler import *



def train_classification(*,
    model,
    train_loader,
    val_loader,
    test_loader,
    params,
):
    """Train ``model`` as a classifier on ``train_loader``.

    For every epoch the model is run over all batches of ``train_loader``,
    the mean loss and the mean per-batch accuracy are accumulated, the
    learning-rate scheduler is stepped once, and (depending on ``params``)
    metrics are printed / sent to wandb and the model weights are saved.

    Args:
        model: ``torch.nn.Module`` to train. It is moved to ``params.device``.
            Its output is reduced with ``argmax(dim=1)`` and compared to the
            targets to compute accuracy.
        train_loader: Iterable of batches; ``batch[0]`` is used as the model
            input and ``batch[1]`` as the target.
        val_loader: Accepted for interface compatibility; not used here.
        test_loader: Accepted for interface compatibility; not used here.
        params: Configuration object. The attributes read by this function
            are ``wandb_log_interval``, ``wandb_log``, ``save_model``, ``lr``,
            ``weight_decay``, ``epochs``, ``optimizer_kwargs``,
            ``scheduler_kwargs``, ``device``, ``clip_gradient``,
            ``gradient_clip_value``, ``scheduler``, ``weight_saving_interval``,
            ``save_model_path`` and ``save_model_name``. It is also passed to
            ``get_optimizer``, ``get_scheduler`` and ``get_loss``.

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    log_interval = params.wandb_log_interval
    wandb_log = params.wandb_log
    save_model = params.save_model
    epochs = params.epochs
    device = params.device

    optimizer = get_optimizer(params)(
        model.parameters(),
        lr=params.lr,
        weight_decay=params.weight_decay,
        **params.optimizer_kwargs,
    )
    scheduler = get_scheduler(params)(optimizer, **params.scheduler_kwargs)
    loss_function = get_loss(params)

    model = model.to(device)

    for ep in range(epochs):
        model.train()
        t1 = default_timer()
        train_loss = 0.0
        train_acc = 0.0
        train_steps = 0
        train_loader_iter = tqdm.tqdm(
            train_loader,
            desc=f'Epoch {ep}/{epochs}',
            leave=False,
            ncols=100
        )
        for data in train_loader_iter:
            optimizer.zero_grad()
            x, y = data[0].to(device), data[1].to(device)
            yp = model(x)
            # The loss may be returned per-sample, so reduce it to a scalar.
            loss = loss_function(yp, y).mean()
            loss.backward()

            if params.clip_gradient:
                nn.utils.clip_grad_value_(
                    model.parameters(), params.gradient_clip_value)
            optimizer.step()

            # Accuracy is a metric only; detach so no autograd state is kept.
            train_acc += (yp.detach().argmax(dim=1) == y).float().mean().item()
            train_loss += loss.item()
            train_steps += 1
            # Drop batch tensors before the next batch is allocated.
            del x, y, yp, loss

        if train_steps == 0:
            raise ValueError(
                "train_loader yielded no batches; cannot compute epoch averages")

        # One full collection per epoch is enough; running it after every
        # batch only slows the loop down.
        gc.collect()
        torch.cuda.empty_cache()

        avg_train_l2 = train_loss / train_steps
        avg_train_acc = train_acc / train_steps

        # The plateau scheduler needs the monitored metric; every other
        # scheduler is advanced once per epoch without arguments so that the
        # configured learning-rate schedule actually takes effect.
        if params.scheduler == 'plateau':
            scheduler.step(avg_train_l2)
        else:
            scheduler.step()

        t2 = default_timer()
        epoch_train_time = t2 - t1

        if ep % log_interval == 0:
            values_to_log = dict(train_err=avg_train_l2,
                                 train_acc=avg_train_acc,
                                 time=epoch_train_time,
                                 c_lr=optimizer.param_groups[0]['lr'])
            print(f"Epoch {ep}: "
                  f"Time: {epoch_train_time:.2f}s, "
                  f"Loss: {avg_train_l2:.4f}, "
                  f"Accuracy: {avg_train_acc:.4f}")
            if wandb_log:
                wandb.log(values_to_log, commit=True)

        # Checkpoint at the configured interval and always after the last epoch.
        if save_model and (
            ep % params.weight_saving_interval == 0 or ep == epochs - 1
        ):
            if params.save_model_path:
                # torch.save does not create missing directories.
                os.makedirs(params.save_model_path, exist_ok=True)
            weight_path_model = os.path.join(
                params.save_model_path, f"{ep} {params.save_model_name}")
            torch.save(model.state_dict(), weight_path_model)

