import torch
import numpy as np
from torchvision.utils import make_grid

from utils import n_way_classification

def _split_batch(data):
    '''
        Splits a loader batch into (input, target).

        The input is data[0]; the target is data[1] when the batch has exactly
        two elements, otherwise the input itself is used as target.
    '''
    x = data[0]
    y = data[1] if len(data) == 2 else x
    return x, y


def _is_due(every, epoch):
    '''
        True when a periodic action with period `every` should run at `epoch`.

        A period of None, 0 or a negative number disables the action; epoch 0
        never triggers it.
    '''
    return every is not None and every > 0 and epoch > 0 and epoch % every == 0


def _mean(values):
    '''
        Mean of a list of floats, or NaN for an empty list.

        This avoids the RuntimeWarning np.mean emits on empty input.
    '''
    return float(np.mean(values)) if values else float("nan")


def _log_images(writter, tag, y, y_pred, ds, epoch):
    '''
        Logs targets and predictions side by side as an image grid.

        The two tensors are concatenated along the last axis. For ds == "vim"
        only channel 0 is kept. The code assumes the tensors are
        (batch, channel, H, W); TODO confirm.
    '''
    # Detach so the logged tensor does not keep the autograd graph alive.
    ims = torch.cat([y.detach(), y_pred.detach()], dim=-1)
    if ds == "vim":
        # Keep only the first channel but preserve the channel axis.
        ims = torch.unsqueeze(ims[:, 0, :, :], dim=1)
    grid = make_grid(ims, nrow=8)
    writter.add_image(tag, grid, epoch)


def train(model, loss_fn, train_dl, test_dl, params, writter, ds, verbose=1):
    '''
        Trains the given model using the given datasets

        :param model a torch.nn.Module containing the model to be trained
        :param loss_fn a loss function to optimize; called as
            loss_fn(y, model, x) and expected to return (loss, model_out),
            where model_out[0] is the prediction compared against y
        :param train_dl an iterable of batches (e.g. a torch DataLoader) with
            the train images; each batch is (x,) or (x, y)
        :param test_dl an iterable of batches with the test images, same format
        :param params a dict with keys "epochs", "test_every", "save_every"
            and "lr". "test_every"/"save_every" of None, 0 or a negative
            number disable evaluation / image logging.
        :param writter a Tensorboard writer to log to
        :param ds dataset name; for "vim" only the first channel of the
            logged images is kept
        :param verbose verbosity level (0=print nothing, 1=print epoch progress)
    '''

    # Get some params
    epochs = params["epochs"]
    test_every = params["test_every"]
    save_every = params["save_every"]
    lr = params["lr"]

    # Setup optimizer
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()

        train_loss = []
        for data in train_dl:
            x, y = _split_batch(data)

            loss, model_out = loss_fn(y, model, x)
            train_loss.append(loss.item())

            opt.zero_grad()
            loss.backward()
            opt.step()

        writter.add_scalar("train/loss", _mean(train_loss), epoch)

        # Train images/metrics come from the LAST train batch only. If the
        # loader yielded nothing, y/model_out do not exist, so skip them.
        if train_loss and _is_due(save_every, epoch):
            _log_images(writter, "train/images", y, model_out[0], ds, epoch)

        # Validation and metrics run at the same frequency.
        test_loss = []
        if _is_due(test_every, epoch):
            if train_loss:
                compute_metrics(y, model_out[0], writter, "train/", epoch)

            model.eval()
            with torch.no_grad():
                for data in test_dl:
                    x, y = _split_batch(data)
                    loss, model_out = loss_fn(y, model, x)
                    test_loss.append(loss.item())

            writter.add_scalar("test/loss", _mean(test_loss), epoch)

            # Test metrics/images also use only the last test batch. Skip them
            # when the test loader was empty, so that train tensors are not
            # logged under the "test/" prefix.
            if test_loss:
                compute_metrics(y, model_out[0], writter, "test/", epoch)
                if _is_due(save_every, epoch):
                    _log_images(writter, "test/images", y, model_out[0], ds, epoch)

        if verbose:
            out = "Epoch = {}, Average train loss = {:.4f}".format(epoch, _mean(train_loss))
            if test_loss:
                out += ", Average test loss = {:.4f}".format(np.mean(test_loss))
            print(out)

def compute_metrics(y, y_pred, writter, tb_path, epoch, *, ns=(2, 5, 10), metrics=("pearson", "ssim")):
    '''
        Computes n-way classification accuracies for a batch and logs them

        Each (metric, n) pair is scored with utils.n_way_classification and
        written to Tensorboard under the tag "<tb_path>acc_<metric>_<n>".
        With the defaults the tags are acc_pearson_2/5/10 and acc_ssim_2/5/10.

        :param y torch tensor with the targets (moved to CPU and converted to numpy)
        :param y_pred torch tensor with the predictions (moved to CPU and converted to numpy)
        :param writter a Tensorboard writer to log to
        :param tb_path tag prefix, e.g. "train/" or "test/"
        :param epoch global step used for the logged scalars
        :param ns values of n to evaluate (keyword-only)
        :param metrics metric names passed to n_way_classification (keyword-only)
    '''
    y = y.cpu().detach().numpy()
    y_pred = y_pred.cpu().detach().numpy()

    # Compute every score before logging any of them, iterating metric-major
    # (all n for the first metric, then the next). n_way_classification is
    # opaque here and may consume randomness, so its call order is kept
    # deterministic.
    scores = [
        (metric, n, n_way_classification(y, y_pred, n=n, metric=metric))
        for metric in metrics
        for n in ns
    ]

    for metric, n, acc in scores:
        writter.add_scalar(tb_path + "acc_{}_{}".format(metric, n), acc, epoch)

