#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import os
import numpy as np
from modules import PredsegModule


def _log_grad_stats(tb_writer, tag, params, step):
    """Log a gradient histogram and the fraction of exactly-zero gradients for params."""
    grads = [p.grad.cpu().flatten().detach().numpy()
             for p in params if p.grad is not None]
    if not grads:
        # nothing to report, e.g. the whole group is frozen
        return
    grads = np.concatenate(grads)
    # fixed bins on [-1, 1]: train_step clips every gradient value to that range
    tb_writer.add_histogram('grad/%s' % tag, grads, step,
                            bins=np.linspace(-1, 1, 200))
    tb_writer.add_scalar('grad0/%s' % tag, np.sum(grads == 0) / len(grads), step)


def _log_diagnostics(tb_writer, module, im, pars, p_not, step):
    """Write the periodic tensorboard report: input batch, p_module parameters, gradients."""
    # NOTE(review): the cast to uint8 presumably expects pixel values in 0..255 -- confirm
    tb_writer.add_images('input', im.cpu().detach().numpy().astype('uint8'), step)
    for m_idx, m in enumerate(module.p_modules()):
        tb_writer.add_histogram('log_c/%d' % m_idx, m.log_c.flatten().cpu().detach(), step)
        tb_writer.add_histogram('prior_w/%d' % m_idx, m.prior_w.flatten().cpu().detach(), step)
    _log_grad_stats(tb_writer, 'pars', pars, step)
    _log_grad_stats(tb_writer, 'p_not', p_not, step)


def train(module, data_loader, folder, lr=0.005, n_epoch=1, max_samples=np.inf,
          device='cpu', start_epoch=0, pars_train='pred', loss_type='shuffle',
          momentum=0, weight_decay=10**-4, n_report=10000, n_acc=5,
          lr_fact=10, neigh=8, **kwargs):
    """
    trains the spatial prediction for a given module saving the parameters
    into (sub-folders of) folder.
    Which parameters are trained is selected by pars_train.
    additional keyword arguments are passed to the training step (and from
    there to the loss function) and are recorded in hparam.json.

    The following files are written into folder (which must already exist):
    hparam.json (appended, one JSON object per call), tensorboard/ (event
    files), checkpoints/losses_<epoch>.npy and checkpoints/cp_<epoch>.pth
    (once per epoch, also when the epoch is interrupted by an exception)
    and pars.pth (state dict after the last epoch).

    Parameters
    ----------
    module : the predseg module to be trained
        must provide get_pred_pars(), get_other_pars(), p_modules() and
        whatever train_step calls on it
    data_loader : data loader (presumably torch.utils.data.DataLoader)
        should return a batch of images to be passed into the network and
        nothing else. Its len(), batch_size and dataset.folder are read.
    folder : str, folder name
        where to save logs and checkpoints
    lr : float
        learning rate of the SGD optimizer
    n_epoch : int
        number of epochs to run
    max_samples : number
        training stops once the sample counter reaches this value. The
        counter starts at start_epoch * len(data_loader) * batch_size.
    device : str or torch.device
        device the module and the batches are moved to
    start_epoch : int
        index of the first epoch; used for the sample counter and for the
        checkpoint file names
    pars_train : str
        'pred' trains only the parameters from module.get_pred_pars() and
        freezes the others, 'other' does the opposite, any other value
        trains both groups
    loss_type : str
        passed on to train_step
    momentum, weight_decay : float
        passed to the SGD optimizer
    n_report : number
        minimal number of samples between two tensorboard reports of the
        input images, p_module parameters and gradients
    n_acc : int
        passed on to train_step
    lr_fact : float
        factor applied to lr for the get_pred_pars() group when both
        groups are trained
    neigh : int
        only recorded in hparam.json, not otherwise used here

    Returns
    -------
    None.

    """
    hparams = {
        'lr': lr,
        'lr_fact': lr_fact,
        'start_epoch': start_epoch,
        'n_epoch': n_epoch,
        'max_samples': max_samples,
        # a torch.device is not JSON serializable; for plain strings str() is a no-op
        'device': str(device),
        'pars_train': pars_train,
        'loss_type': loss_type,
        'momentum': momentum,
        'neigh': neigh,
        'weight_decay': weight_decay,
        'folder_data': data_loader.dataset.folder}
    hparams.update(kwargs)
    # serialize first, so a failure cannot leave a half-written object in the file
    hparam_text = json.dumps(hparams, indent=2)
    # append mode keeps the records of earlier runs; the newline separates them
    with open(os.path.join(folder, 'hparam.json'), 'a') as f:
        f.write(hparam_text + '\n')
    module.to(device)
    # materialize both groups: they are iterated several times below (freezing,
    # optimizer, gradient logging), which would silently fail for generators
    pars = list(module.get_pred_pars())
    p_not = list(module.get_other_pars())
    if pars_train == 'pred':
        for p in p_not:
            p.requires_grad = False
        optimizer = torch.optim.SGD(pars, lr=lr, momentum=momentum,
                                    weight_decay=weight_decay)
    elif pars_train == 'other':
        for p in pars:
            p.requires_grad = False
        optimizer = torch.optim.SGD(p_not, lr=lr, momentum=momentum,
                                    weight_decay=weight_decay)
    else:
        optimizer = torch.optim.SGD([
            {'params': pars, 'lr': lr_fact*lr},
            {'params': p_not, 'lr': lr}],
            momentum=momentum, weight_decay=weight_decay)
    os.makedirs(os.path.join(folder, 'checkpoints'), exist_ok=True)
    loss_file = os.path.join(folder, 'checkpoints', 'losses_%d.npy')
    cp_file = os.path.join(folder, 'checkpoints', 'cp_%d.pth')
    final_file = os.path.join(folder, 'pars.pth')
    tb_writer = SummaryWriter(os.path.join(folder, 'tensorboard'))
    try:
        # k counts the samples seen, continuing from earlier epochs
        k = start_epoch * len(data_loader) * data_loader.batch_size
        last_report = -np.inf
        losses = []
        for i_epoch in range(start_epoch, start_epoch + n_epoch):
            # t stays None if creating the progress bar fails, so that the
            # cleanup below does not mask the original error with a NameError
            t = None
            try:
                t = tqdm(iter(data_loader))
                for im in t:
                    im = im.to(device)
                    loss = train_step(module, im, optimizer, loss_type, n_acc, **kwargs)
                    tb_writer.add_scalar('Loss/train', loss, k)
                    losses.append(loss)
                    t.set_postfix({'loss': losses[-1],
                                   'ep_mean': np.mean(losses)})
                    k += im.shape[0]
                    if k >= max_samples:
                        break
                    if k - last_report >= n_report:
                        last_report = k
                        _log_diagnostics(tb_writer, module, im, pars, p_not, k)
            finally:
                # logging & checkpoint, also when the epoch was interrupted
                if t is not None:
                    t.close()
                np.save(loss_file % i_epoch, np.array(losses))
                torch.save(module.state_dict(), cp_file % i_epoch)
                losses = []
            if k >= max_samples:
                break
        torch.save(module.state_dict(), final_file)
    finally:
        # flush pending events and release the event file
        tb_writer.close()


def train_step(module, images, optimizer, loss_type, n_acc=0, **kwargs):
    """ single optimisation step on one batch of images

    The `feat` attribute of every module returned by module.p_modules() is
    reset to None, the gradients are zeroed and the batch is passed through
    the module. Afterwards all gradient values are clipped to [-1, 1] and the
    optimizer takes one step.

    n_acc selects how the gradient is obtained:
    n_acc >= 1 : module.grad_acc(loss_type, n_acc=n_acc, **kwargs) is called
        and its return value is reported. No backward pass is run here, so
        grad_acc is presumably responsible for filling the gradients itself.
        n_acc = 1 is useful for testing: it runs the accumulation code with a
        single repeat.
    otherwise : module.get_loss(loss_type, **kwargs) is evaluated once,
        backpropagated, and returned as a numpy value.

    Returns the loss value to report for this step.
    """
    for sub_module in module.p_modules():
        sub_module.feat = None
    optimizer.zero_grad()
    # the output is not needed here, the loss is queried from the module below
    module(images)
    accumulate = n_acc >= 1
    if not accumulate:
        batch_loss = module.get_loss(loss_type, **kwargs)
        batch_loss.backward()
        reported = batch_loss.detach().cpu().numpy()
    else:
        reported = module.grad_acc(loss_type, n_acc=n_acc, **kwargs)
    # gradient clipping to avoid initial parameter explosion
    for group in optimizer.param_groups:
        torch.nn.utils.clip_grad_value_(group['params'], 1)
    optimizer.step()
    return reported
