import numpy as np
import os
import sys
import time
import torch
from torch.utils.tensorboard import SummaryWriter

import util


def train_sub(model, optimizer, dataloader, epoch, verbose=True):
    """Run one training pass over ``dataloader`` and return the summed loss.

    For every ``(inputs, labels)`` batch the model is run forward, the loss is
    computed with ``model.loss_function(out, labels[:,0], labels[:,1], 'train')``
    and a single optimizer step is taken.

    Args:
        model: Object providing ``train()``, ``__call__`` and
            ``loss_function``.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        dataloader: Iterable yielding ``(inputs, labels)`` pairs. ``labels``
            is sliced as ``labels[:,0]`` and ``labels[:,1]``; the meaning of
            the two columns is defined by ``model.loss_function``.
        epoch: Epoch index. Currently unused by this function.
        verbose: If true, print elapsed time and the running average loss
            every 1000 batches, plus a summary at the end of the pass.

    Returns:
        float: The SUM (not the mean) of ``loss.item()`` over all batches;
        ``0.0`` when the dataloader yields no batches.
    """
    start_time = time.time()

    sum_loss = 0.0
    count = 0
    model.train()
    for inputs, labels in dataloader:
        inputs = util.try_gpu(inputs)
        labels = util.try_gpu(labels)

        # forward computation
        out = model(inputs)
        loss = model.loss_function(out, labels[:,0], labels[:,1], 'train')
        sum_loss += loss.item()

        # reset gradients, backpropagate, then update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        count += 1
        if verbose and count % 1000 == 0:
            elapsed_time = time.time() - start_time
            print('elapsed_time (sec.): {0:.2f}'.format(elapsed_time))
            print('batch count = %d, average train loss = %f' % (count,sum_loss/count), flush=True)

    if verbose:
        elapsed_time = time.time() - start_time
        print('elapsed_time (sec.): {0:.2f}'.format(elapsed_time))
        # An empty dataloader leaves count at 0; report NaN rather than
        # crashing with ZeroDivisionError.
        average = sum_loss / count if count else float('nan')
        print('average train loss = %f' % average)
    return sum_loss

def compute_loss(model, dataloader, mode, epoch = -1, verbose=True):
    """Evaluate ``model`` over ``dataloader`` and return the per-batch mean loss.

    The model is switched to eval mode and every batch is processed under
    ``torch.no_grad()``. The loss for a batch is
    ``model.loss_function(out, labels[:,0], labels[:,1], mode)`` and may be
    either a scalar-like object exposing ``item()`` or a ``dict`` of values.

    Args:
        model: Object providing ``eval()``, ``__call__`` and
            ``loss_function``.
        dataloader: Iterable yielding ``(inputs, labels)`` pairs.
        mode: Passed through to ``model.loss_function`` and used in the
            progress messages (this file calls it with ``'val'`` and
            ``'test'``).
        epoch: Epoch index. Currently unused by this function.
        verbose: If true and the loss is scalar, print elapsed time and the
            running average every 1000 batches, plus a final summary.
            Nothing is printed for dict-valued losses.

    Returns:
        float or dict: For scalar losses, the sum of ``loss.item()`` divided
        by the number of batches (``nan`` if the dataloader yields no
        batches). For dict losses, a dict with the same keys whose values
        are the per-key sums divided by the number of batches.
    """
    start_time = time.time()
    sum_loss = 0.0
    count = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = util.try_gpu(inputs)
            labels = util.try_gpu(labels)

            # forward computation
            out = model(inputs)
            loss = model.loss_function(out, labels[:,0], labels[:,1], mode)
            if isinstance(loss, dict):
                if isinstance(sum_loss, float):
                    # First dict-valued batch: start per-key accumulation
                    # from a shallow copy of this batch's values.
                    sum_loss = dict(loss)
                else:
                    for key, value in loss.items():
                        # Add out of place so objects returned by the loss
                        # function are never mutated through the accumulator.
                        sum_loss[key] = sum_loss[key] + value
            else:
                sum_loss += loss.item()

            count += 1
            if verbose and isinstance(sum_loss,float) and count % 1000 == 0:
                elapsed_time = time.time() - start_time
                print('elapsed_time (sec.): {0:.2f}'.format(elapsed_time))
                print('batch count = %d, average %s loss = %f' % (count,mode,sum_loss/count), flush=True)

    if isinstance(sum_loss,float):
        # count is 0 only when the dataloader yielded nothing; return NaN
        # instead of raising ZeroDivisionError.
        average = sum_loss / count if count else float('nan')
        if verbose:
            elapsed_time = time.time() - start_time
            print('elapsed_time (sec.): {0:.2f}'.format(elapsed_time))
            print('average %s loss = %f' % (mode, average))
        return average

    # The dict branch is only reached after at least one batch, so count >= 1.
    return {key: total / count for key, total in sum_loss.items()}

def train(dataset, model, nn_param, logger_loss, verbose = True):
    """Train ``model`` for up to ``nn_param['num_epoch']`` epochs.

    Each epoch runs ``train_sub`` on the training dataloader and
    ``compute_loss`` (mode ``'val'``) on the validation dataloader. Both
    values are written to TensorBoard when ``logger_loss`` is given.

    Args:
        dataset: Object whose ``datamodule`` provides ``train_dataloader()``,
            ``val_dataloader()`` and, when ``predict_train`` is enabled,
            ``predict_train_dataloader()``.
        model: Model to train. Its optimizer is obtained from
            ``model.configure_optimizers()``; any ``learning_rate`` entry in
            ``nn_param`` is not used by this function.
        nn_param: Mapping of options read with ``get``:
            ``num_epoch`` (default 1), ``predict_train`` (default False) and
            ``early_stopping`` (default -1, i.e. disabled).
        logger_loss: ``None`` to disable logging, otherwise an object with
            ``dir_name`` and ``model_name`` attributes that are joined into
            the TensorBoard log directory.
        verbose: If true, print the epoch counter and let ``train_sub`` /
            ``compute_loss`` print their progress.

    Returns:
        The same ``model`` object, after training.

    Note:
        The value logged as ``train_loss`` is whatever ``train_sub`` returns.
        Early stopping compares ``val_loss`` with ``<``, so it assumes the
        validation loss is a scalar rather than a dict.
    """
    # set logger
    logger = None
    if logger_loss is not None:
        log_dir = os.path.join(logger_loss.dir_name, logger_loss.model_name)
        logger = SummaryWriter(log_dir=log_dir)

    try:
        # create dataloaders
        train_dataloader = dataset.datamodule.train_dataloader()
        val_dataloader = dataset.datamodule.val_dataloader()

        # set optimizer
        optimizer = model.configure_optimizers()

        best_val_loss = sys.float_info.max
        best_iteration = -1
        epochs = nn_param.get('num_epoch', 1)
        predict_train = nn_param.get('predict_train', False)
        patience = nn_param.get('early_stopping', -1)

        # main loop
        for i in range(epochs):
            if verbose:
                print('epoch = %d/%d' % (i+1,epochs))

            if predict_train:
                # Refresh the training-set predictions stored on the loss
                # function before this epoch's updates.
                pt_dataloader = dataset.datamodule.predict_train_dataloader()
                model.loss_function.pred_train = predict(pt_dataloader, model)

            train_loss = train_sub(model, optimizer, train_dataloader, i, verbose=verbose)
            if logger is not None:
                logger.add_scalar("train_loss", train_loss, i)

            val_loss = compute_loss(model, val_dataloader, 'val', i, verbose=verbose)
            if logger is not None:
                logger.add_scalar("val_loss", val_loss, i)

            if patience > 0:
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_iteration = i
                # Stop once more than `patience` epochs have passed since
                # the best validation loss was seen.
                if best_iteration + patience < i:
                    print('Early stopping here', best_val_loss)
                    break
    finally:
        # Close the writer even if training raised.
        if logger is not None:
            logger.close()
    return model

def validate(model, val_dataloader):
    """Return ``compute_loss`` of ``model`` over ``val_dataloader`` in ``'val'`` mode.

    Thin convenience wrapper; ``compute_loss`` is called with its default
    ``epoch`` and ``verbose`` arguments.
    """
    val_result = compute_loss(model, val_dataloader, mode='val')
    return val_result

def test(model, test_dataloader):
    """Return ``compute_loss`` of ``model`` over ``test_dataloader`` in ``'test'`` mode.

    Thin convenience wrapper; ``compute_loss`` is called with its default
    ``epoch`` and ``verbose`` arguments.
    """
    test_result = compute_loss(model, test_dataloader, mode='test')
    return test_result
    
def predict(dataloader, model):
    """Run ``model`` over every batch of ``dataloader`` and stack the outputs.

    The model is put in eval mode and batches are processed under
    ``torch.no_grad()``. The model is left in eval mode on return.

    Args:
        dataloader: Iterable yielding input batches (inputs only, no labels).
        model: Callable model providing ``eval()``.

    Returns:
        numpy.ndarray: All batch outputs concatenated along the first axis.

    Raises:
        ValueError: Raised by ``np.concatenate`` if the dataloader yields no
            batches.
    """
    pred_list = []
    model.eval()
    with torch.no_grad():
        for inputs in dataloader:
            inputs = util.try_gpu(inputs)
            out = model(inputs)
            if isinstance(out, torch.Tensor):
                # np.concatenate cannot consume CUDA tensors; moving each
                # batch to host memory here also avoids keeping every
                # prediction on the device until the end.
                out = out.detach().cpu().numpy()
            pred_list.append(out)

    return np.concatenate(pred_list)
