
from __future__ import print_function, absolute_import, division

import time

import torch
import torch.nn as nn

from progress.bar import Bar
from utils.utils import AverageMeter, lr_decay

'''
Code is modified from https://github.com/garyzhao/SemGCN
This train function is adapted from SemGCN for baseline training.
'''

def train(data_loader, model_pos, criterion, optimizer, device, lr_init, lr_now, step, decay, gamma, num_branches, max_norm=True):
    """Run one training epoch of a multi-branch 3D pose model.

    Every one of the first ``num_branches`` outputs of ``model_pos`` is
    supervised with the same root-relative 3D target, and the per-branch
    losses are averaged before backpropagation.

    Args:
        data_loader: iterable of batches supporting ``len()``; ``batch[0]``
            holds the 3D targets (indexed as ``[batch, joint, coord]``) and
            ``batch[1]`` holds the 2D inputs.
        model_pos: model whose output is indexable by branch; each branch's
            output must contain as many elements as the target batch.
        criterion: loss callable invoked as ``criterion(prediction, target)``.
        optimizer: optimizer stepping ``model_pos``'s parameters.
        device: device the inputs and targets are moved to.
        lr_init, decay, gamma: learning-rate schedule parameters forwarded to
            ``lr_decay``; the rate is updated every ``decay`` steps and on
            step 1.
        lr_now: current learning rate; returned unchanged if no update
            happens during this epoch.
        step: global step counter carried across epochs.
        num_branches: number of output branches to supervise (must be >= 1).
        max_norm: if truthy, clip the total gradient norm to 1 before each
            optimizer step.

    Returns:
        Tuple ``(avg_loss, lr_now, step)``: the epoch loss as averaged by
        ``AverageMeter`` (updated with the batch size as count), the latest
        learning rate, and the updated step counter.

    Raises:
        ValueError: if ``num_branches`` is less than 1.
    """
    if num_branches < 1:
        raise ValueError('num_branches must be >= 1, got {}'.format(num_branches))

    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_loss_3d_pos = AverageMeter()

    # Switch to train mode
    torch.set_grad_enabled(True)
    model_pos.train()
    end = time.time()

    bar = Bar('Train', max=len(data_loader))
    for i, batch in enumerate(data_loader):
        # Measure data loading time
        data_time.update(time.time() - end)

        inputs_2d, targets_3d = batch[1], batch[0]
        num_poses = targets_3d.size(0)

        # Express every joint relative to joint 0 (the output is root-relative).
        targets_3d = targets_3d - targets_3d[:, :1, :]

        step += 1
        if step % decay == 0 or step == 1:
            lr_now = lr_decay(optimizer, step, lr_init, decay, gamma)

        # All branches share one target, so a single copy is moved to the
        # device instead of stacking num_branches clones of it.
        targets_3d, inputs_2d = targets_3d.to(device), inputs_2d.to(device)
        # Reshape predictions to the target's (joints, coords) layout rather
        # than a hard-coded (16, 3), so other skeleton sizes also work.
        pose_shape = (-1,) + tuple(targets_3d.shape[1:])

        outputs_3d = model_pos(inputs_2d)

        optimizer.zero_grad()

        loss_3d_pos = sum(
            criterion(outputs_3d[nb].reshape(pose_shape), targets_3d)
            for nb in range(num_branches)
        ) / num_branches

        loss_3d_pos.backward()
        if max_norm:
            nn.utils.clip_grad_norm_(model_pos.parameters(), max_norm=1)
        optimizer.step()

        epoch_loss_3d_pos.update(loss_3d_pos.item(), num_poses)

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \
                     '| Loss: {loss: .4f}' \
            .format(batch=i + 1, size=len(data_loader), data=data_time.avg, bt=batch_time.avg,
                    ttl=bar.elapsed_td, eta=bar.eta_td, loss=epoch_loss_3d_pos.avg)
        bar.next()

    bar.finish()

    return epoch_loss_3d_pos.avg, lr_now, step