from typing import List, Union
from logging import Logger

import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable

from .model import (
    print_loss_vae, 
    print_disentanglement_metrics, 
    save_epoch_model, 
    BaseCDVAE, 
    CDVAEDataParallel
)
from src.metrics import obtain_metrics

# Public API of this module: the per-epoch training and evaluation loops.
__all__ = ["train_epoch", "test_epoch"]


def train_epoch(
        logger: Logger,
        epoch: int,
        n_epochs: int,
        n_iters_train: int,
        device: torch.device,
        model_dir: str,
        train_loader: DataLoader,
        model: Union[BaseCDVAE, CDVAEDataParallel],
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        save: bool=False,
        print_iters: int=1000,
    ):
    """Run one training epoch, step the LR scheduler and checkpoint the model.

    Args:
        logger: Logger handed to ``print_loss_vae`` for progress output.
        epoch: Current epoch index; forwarded to logging, ``scheduler.step``
            and ``save_epoch_model``.
        n_epochs: Total number of epochs (used only for logging).
        n_iters_train: Iterations per epoch (used only for logging).
        device: Device that inputs and labels are moved to.
        model_dir: Directory forwarded to ``save_epoch_model``.
        train_loader: Yields ``(input, labels)`` batches, where ``labels`` is
            either a tensor or a list/tuple of tensors.
        model: Model whose call returns a dict of outputs that is unpacked
            into ``model.calc_loss``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once at the end of the epoch.
        save: Forwarded to ``save_epoch_model`` (presumably toggles whether a
            checkpoint is actually written — confirm in ``save_epoch_model``).
        print_iters: Logging period forwarded to ``print_loss_vae``.
    """
    # Running sums of the per-batch mean losses; print_loss_vae receives the
    # sums (NOTE(review): presumably it normalizes by ``iter`` — confirm).
    total_loss, recons_loss, kl_loss = 0, 0, 0
    model.train()

    for step, (x, labels) in enumerate(train_loader, start=1):

        #======Data preparation=======
        # ``torch.autograd.Variable`` has been a no-op wrapper since PyTorch
        # 0.4, so the tensor is moved/cast directly.
        x = x.to(device=device, dtype=torch.float, non_blocking=True)
        # Accept tuples as well as lists of label tensors (a tuple would
        # otherwise fall through to ``.to`` and raise AttributeError).
        if isinstance(labels, (list, tuple)):
            labels = [lbl.to(device=device, non_blocking=True) for lbl in labels]
        else:
            labels = labels.to(device=device, non_blocking=True)

        #======Forward=======
        output = model(x)
        iter_loss = model.calc_loss(**output, x=x, labels=labels)
        # torch.mean reduces whatever shape calc_loss returns to a scalar;
        # the total is computed once and reused for logging and backward.
        batch_total = torch.mean(iter_loss['total_loss'])
        total_loss += batch_total.detach()
        kl_loss += torch.mean(iter_loss['kl_loss']).detach()
        recons_loss += torch.mean(iter_loss['recons_loss']).detach()

        #======Backward=======
        optimizer.zero_grad()
        batch_total.backward()
        optimizer.step()

        #======Logs=======
        print_loss_vae(
            logger=logger,
            train=True,
            epoch=epoch,
            n_epochs=n_epochs,
            iter=step,
            n_iters=n_iters_train,
            kl_loss=kl_loss,
            recons_loss=recons_loss,
            total_loss=total_loss,
            print_iters=print_iters,
        )

    # NOTE(review): passing ``epoch`` to ``scheduler.step`` is deprecated in
    # PyTorch and emits a warning; it is kept because it sets the scheduler's
    # ``last_epoch`` to ``epoch`` explicitly instead of incrementing it.
    scheduler.step(epoch)

    save_epoch_model(
        epoch, model_dir, model, optimizer, scheduler, save
    )



def test_epoch(
        logger: Logger,
        epoch: int,
        n_epochs: int,
        n_iters_test: int,
        device: torch.device,
        test_loader: DataLoader,
        model: Union[BaseCDVAE, CDVAEDataParallel],
        print_iters: int=1000,
        calc_dis_metrics: bool=False,
        save_recon: bool=False,
        n_metric_samples: Union[int, None]=1000,
    ):
    """Evaluate ``model`` on ``test_loader`` and collect latent statistics.

    Losses are computed under ``torch.no_grad`` and logged each iteration via
    ``print_loss_vae``. Optionally computes disentanglement metrics.

    Args:
        logger: Logger for loss and metric output.
        epoch: Current epoch index (used only for logging).
        n_epochs: Total number of epochs (used only for logging).
        n_iters_test: Iterations in the test loader (used only for logging).
        device: Device that inputs and labels are moved to.
        test_loader: Yields ``(input, labels)`` batches, where ``labels`` is
            either a tensor or a list/tuple of tensors.
        model: Model called with ``n_samples=0``; its output dict must
            contain ``'mean'`` (and ``'x_hat'`` when ``save_recon``).
        print_iters: Logging period forwarded to ``print_loss_vae``.
        calc_dis_metrics: If True, run ``obtain_metrics`` on the collected
            labels/means and log the result.
        save_recon: If True, collect ``output['x_hat']`` for every sample.
        n_metric_samples: Number of leading samples passed to
            ``obtain_metrics``; ``None`` uses all samples. Defaults to 1000.

    Returns:
        dict with keys:
            'recon': numpy array of reconstructions if ``save_recon``,
                otherwise an empty list.
            'labels': numpy array of labels concatenated over batches.
            'mean': numpy array of ``output['mean']`` concatenated over
                batches.
            'logvar': always an empty list (nothing is collected).

    Raises:
        RuntimeError: from ``torch.cat`` if ``test_loader`` yields no batches.
    """
    total_loss, recons_loss, kl_loss = 0, 0, 0
    # Whole per-batch CPU tensors; concatenated along dim 0 after the loop.
    mean_batches, label_batches, recon_batches = [], [], []
    # NOTE(review): never populated, so 'logvar' is returned as an empty list.
    test_logvar = []
    model.eval()

    for step, (x, labels) in enumerate(test_loader, start=1):

        #======Data preparation=======
        # ``torch.autograd.Variable`` is a no-op wrapper since PyTorch 0.4.
        x = x.to(device=device, dtype=torch.float, non_blocking=True)
        if isinstance(labels, (list, tuple)):
            labels = [lbl.to(device=device, non_blocking=True) for lbl in labels]
        else:
            labels = labels.to(device=device, non_blocking=True)

        #======Forward=======
        with torch.no_grad():
            output = model(x, n_samples=0)
            iter_loss = model.calc_loss(**output, x=x, labels=labels)
            total_loss += torch.mean(iter_loss['total_loss']).detach()
            kl_loss += torch.mean(iter_loss['kl_loss']).detach()
            recons_loss += torch.mean(iter_loss['recons_loss']).detach()
            mean_batches.append(output['mean'].detach().cpu())
            # Put the sample axis first: a single label tensor gains a
            # singleton dim 1; a list of per-factor tensors is stacked and
            # transposed (assumes each is 1-D of shape (batch,) — confirm).
            if isinstance(labels, torch.Tensor):
                labels = labels.unsqueeze(1)
            else:
                labels = torch.stack(labels).T
            label_batches.append(labels.cpu())
            if save_recon:
                recon_batches.append(output['x_hat'].detach().cpu())

        #======Logs=======
        print_loss_vae(
            logger=logger,
            train=False,
            epoch=epoch,
            n_epochs=n_epochs,
            iter=step,
            n_iters=n_iters_test,
            kl_loss=kl_loss,
            recons_loss=recons_loss,
            total_loss=total_loss,
            print_iters=print_iters,
        )

    # Concatenating whole batches gives the same array as stacking one tensor
    # per sample, without allocating a Python tensor object per sample.
    test_mean = torch.cat(mean_batches).numpy()
    test_labels = torch.cat(label_batches).numpy()

    if calc_dis_metrics:
        dis_metrics = obtain_metrics(
            test_labels[:n_metric_samples], test_mean[:n_metric_samples]
        )
        print_disentanglement_metrics(logger, dis_metrics)

    test_recon = torch.cat(recon_batches).numpy() if save_recon else []

    return {
        'recon':    test_recon,
        'labels':   test_labels,
        'mean':     test_mean,
        'logvar':   test_logvar
    }