# Modified based on the HRNet repo.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import logging
import numpy as np
import sys

import torch

from core.cls_evaluate import accuracy
sys.path.append("../")
from utils.utils import save_checkpoint, mdeq_hypersolver_loss
import random
from tqdm import tqdm


logger = logging.getLogger(__name__)


class AverageMeter(object):
    """Running tracker for the latest value and the weighted mean of a series.

    Attributes:
        val: The value most recently passed to ``update``.
        sum: Accumulated ``val * n`` over all updates since the last reset.
        count: Accumulated ``n`` over all updates since the last reset.
        avg: ``sum / count`` as of the last update (0 right after a reset).
    """

    def __init__(self):
        # A new meter starts in the same state as a freshly reset one.
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest observation, weighted by ``n``.

        ``n`` is typically the number of samples that ``val`` was averaged
        over, so that ``avg`` ends up being a per-sample mean.
        """
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

        
def train(config, train_loader, model, criterion, optimizer, lr_scheduler, epoch,
          output_dir, tb_log_dir, writer_dict, topk=(1,5)):
    """Train for one (possibly partial) epoch on the hypersolver loss.

    Only the first ``int(config.PERCENT * len(train_loader))`` batches are
    consumed. For every batch the model is run, the loss returned by
    ``mdeq_hypersolver_loss`` is back-propagated and the optimizer is stepped.
    The classification losses and accuracies of the three model outputs
    (``targ``, ``ref``, ``hyp``) are computed for logging only; they are not
    part of the back-propagated loss.

    Args:
        config: Experiment configuration. Reads ``PERCENT``, ``PRINT_FREQ``,
            ``MODEL.F_THRES``, ``MODEL.RAND_F_THRES``, ``TRAIN.CLIP`` and
            ``TRAIN.LR_SCHEDULER``.
        train_loader: Sized iterable yielding ``(input, target)`` batches.
        model: Wrapped model; ``model.module`` must expose a ``hypsolver``
            sub-module (presumably a ``DataParallel``-style wrapper -- confirm
            against the caller).
        criterion: Classification loss called as ``criterion(logits, target)``.
        optimizer: Optimizer to step after each backward pass.
        lr_scheduler: Learning-rate scheduler. It is stepped after every batch
            unless ``config.TRAIN.LR_SCHEDULER == 'step'``, and its
            ``_step_count - 1`` is used as the training-step index.
        epoch: Epoch index; only used in the log message.
        output_dir: Not used by this function.
        tb_log_dir: Not used by this function.
        writer_dict: Dict whose ``'train_global_steps'`` entry is read at the
            start and written back after every processed batch.
        topk: Exactly two ``k`` values forwarded to ``accuracy``; the results
            are logged as Acc@1 and Acc@5.

    Returns:
        None. Progress is reported through the module logger.
    """
    branches = ("targ", "ref", "hyp")

    def new_window_meters():
        """Return fresh meters for one logging window, keyed by log field name."""
        names = ["reco_loss", "init_loss"]
        for branch in branches:
            names.extend((f"{branch}_loss", f"{branch}_top1", f"{branch}_top5"))
        return {name: AverageMeter() for name in names}

    # The timers span the whole call; every other meter is restarted after
    # each log line so the printed averages cover one logging window only.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    meters = new_window_meters()
    global_steps = writer_dict['train_global_steps']

    # switch to train mode
    model.train()
    model.module.hypsolver.train()

    end = time.time()
    total_batch_num = len(train_loader)
    effec_batch_num = int(config.PERCENT * total_batch_num)
    for i, (images, target) in enumerate(train_loader):
        # train on partial training data
        if i >= effec_batch_num:
            break

        # measure data loading time
        data_time.update(time.time() - end)

        # NOTE(review): rand_f_thres is computed but the model is called with
        # the un-jittered config.MODEL.F_THRES, so MODEL.RAND_F_THRES has no
        # effect on the forward pass. The computation is kept only so the
        # `random` stream is consumed exactly as before; confirm whether
        # f_thres below was meant to be rand_f_thres.
        rand_f_thres = config.MODEL.F_THRES + (random.randint(-2,1) if config.MODEL.RAND_F_THRES else 0)

        # The scheduler is not stepped between the forward pass and the loss
        # computation, so both see the same step index.
        train_step = lr_scheduler._step_count - 1

        # compute output (different from seg_function, compute_jac_loss already includes dropping)
        preds_info = model(images, train_step=train_step,
                           compute_jac_loss=False,
                           f_thres=config.MODEL.F_THRES,
                           writer=None)
        target = target.cuda(non_blocking=True)
        (y_targ, y_ref, y_hyp, z_targ, z_ref, z_hyp,
         rel_trace, abs_trace, hyp_X, Galphas, z_init) = preds_info
        logits = {"targ": y_targ, "ref": y_ref, "hyp": y_hyp}

        # Classification losses, monitored only.
        cls_losses = {branch: criterion(logits[branch], target) for branch in branches}

        # Hypersolver losses; only `loss` is back-propagated.
        all_losses = mdeq_hypersolver_loss(model.module, z_targ, z_ref, rel_trace, hyp_X, Galphas, z_init,
                                           train_step=train_step)
        loss, reco_losses, ref_reco_loss, alpha_aux_loss, init_est_loss = all_losses

        # compute gradient and do update step
        optimizer.zero_grad()
        loss.backward()
        if config['TRAIN']['CLIP'] > 0:
            torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], config['TRAIN']['CLIP'])
        else:
            print("Hmm, you should use clipping for hypersolver")
        optimizer.step()
        if config.TRAIN.LR_SCHEDULER != 'step':
            lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # measure accuracy and record loss
        batch_size = images.size(0)
        meters["init_loss"].update(init_est_loss.mean().item(), batch_size)
        # Last entry along dim 1 of the batch-averaged reconstruction losses.
        meters["reco_loss"].update(reco_losses.mean(0)[-1].item(), batch_size)
        for branch in branches:
            meters[f"{branch}_loss"].update(cls_losses[branch].mean().item(), batch_size)
            prec1, prec5 = accuracy(logits[branch], target, topk=topk)
            meters[f"{branch}_top1"].update(prec1[0], batch_size)
            meters[f"{branch}_top5"].update(prec5[0], batch_size)

        if i % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{0}][{1}/{2}] ({3})\t' \
                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)  Speed {speed:.1f} samples/s\t' \
                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                  'Cls_Loss {targ_loss.avg:.3f} ({hyp_loss.avg:.3f}|{ref_loss.avg:.3f})\t' \
                  'Acc@1 {targ_top1.avg:.3f} ({hyp_top1.avg:.3f}|{ref_top1.avg:.3f})\t' \
                  'Acc@5 {targ_top5.avg:.3f} ({hyp_top5.avg:.3f}|{ref_top5.avg:.3f})\t' \
                  'L_reco {reco_loss.avg:.3f}   L_init {init_loss.avg:.3f}'.format(
                      epoch, i, effec_batch_num, global_steps, batch_time=batch_time,
                      speed=batch_size/batch_time.avg,
                      data_time=data_time, **meters)
            logger.info(msg)

            # Start a new logging window.
            meters = new_window_meters()

        global_steps += 1
        writer_dict['train_global_steps'] = global_steps

def validate(config, val_loader, model, criterion, lr_scheduler, epoch, output_dir, tb_log_dir,
             writer_dict=None, topk=(1,5), simple=False, cached_targ_t1=None, cached_targ_t5=None,
             cached_ref_t1=None, cached_ref_t5=None, override_f_thres=-1):
    """Evaluate top-k accuracy of the model outputs on ``val_loader``.

    In the default mode the wrapped ``model`` is called and the accuracies of
    its first three outputs (``targ``, ``ref``, ``hyp``) are measured. With
    ``simple=True`` the unwrapped ``model.module`` is called with
    ``simple=True`` and its return value is treated directly as the ``hyp``
    logits, so only the hypersolver accuracy is measured.

    Args:
        config: Experiment configuration; ``MODEL.F_THRES`` is read.
        val_loader: Iterable yielding ``(input, target)`` batches.
        model: Wrapped model; ``model.module`` must expose ``hypsolver``.
        criterion: Not used by this function.
        lr_scheduler: Its ``_step_count - 1`` is forwarded to the model as
            ``train_step``.
        epoch: Not used by this function.
        output_dir: Not used by this function.
        tb_log_dir: Not used by this function.
        writer_dict: Not used by this function.
        topk: Exactly two ``k`` values forwarded to ``accuracy``.
        simple: Evaluate only the hypersolver output (see above).
        cached_targ_t1: Previously measured target Acc@1. In simple mode, when
            this is not ``None``, a comparison line is logged using all four
            ``cached_*`` values, which must then be numbers.
        cached_targ_t5: Previously measured target Acc@5.
        cached_ref_t1: Previously measured reference Acc@1.
        cached_ref_t5: Previously measured reference Acc@5.
        override_f_thres: If non-negative, used instead of
            ``config.MODEL.F_THRES`` as the ``f_thres`` passed to the model.

    Returns:
        ``(hyp_top1, hyp_top5)`` averages when ``simple`` is true, otherwise
        ``(targ_top1, targ_top5, ref_top1, ref_top5, hyp_top1, hyp_top5)``.
    """
    batch_time = AverageMeter()
    update_list = ["hyp"] if simple else ["targ", "ref", "hyp"]
    top1 = {name: AverageMeter() for name in update_list}
    top5 = {name: AverageMeter() for name in update_list}

    # switch to evaluate mode
    model.eval()
    if simple:
        model.module.eval()
    model.module.hypsolver.eval()

    f_thres = config.MODEL.F_THRES if override_f_thres < 0 else override_f_thres
    # Simple mode bypasses the wrapper and calls the underlying module.
    net = model.module if simple else model

    with torch.no_grad():
        end = time.time()
        for images, target in val_loader:
            images = images.cuda()
            # compute output
            preds_info = net(images,
                             train_step=(lr_scheduler._step_count-1),
                             f_thres=f_thres,
                             simple=simple)
            target = target.cuda(non_blocking=True)

            if simple:
                logits = {"hyp": preds_info}
            else:
                # Only the three logit tensors are needed here; the remaining
                # outputs (states, traces, ...) are ignored.
                y_targ, y_ref, y_hyp, *_ = preds_info
                logits = {"targ": y_targ, "ref": y_ref, "hyp": y_hyp}

            # measure accuracy
            batch_size = images.size(0)
            for name in update_list:
                prec1, prec5 = accuracy(logits[name], target, topk=topk)
                top1[name].update(prec1[0], batch_size)
                top5[name].update(prec5[0], batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    hyp_top1, hyp_top5 = top1["hyp"], top5["hyp"]

    if simple:
        if cached_targ_t1 is not None:
            msg = f'Test: {f_thres}  ' + 'Acc@1 {targ_top1:.3f} ({hyp_top1.avg:.3f}|{ref_top1:.3f})\t' \
                'Acc@5 {targ_top5:.3f} ({hyp_top5.avg:.3f}|{ref_top5:.3f})\t'.format(
                    targ_top1=cached_targ_t1, targ_top5=cached_targ_t5, ref_top1=cached_ref_t1, ref_top5=cached_ref_t5,
                    hyp_top1=hyp_top1, hyp_top5=hyp_top5)
            logger.info(msg)
        # The original referenced an undefined `ave_time` here, which raised
        # NameError on every simple-mode call; the per-batch timer is
        # `batch_time`.
        logger.info(f"Acc@1 {hyp_top1.avg}  Acc@5 {hyp_top5.avg}  Avg. time per batch: {batch_time.avg}")
        return hyp_top1.avg, hyp_top5.avg

    targ_top1, targ_top5 = top1["targ"], top5["targ"]
    ref_top1, ref_top5 = top1["ref"], top5["ref"]
    msg = f'Test: {f_thres}  ' + 'Acc@1 {targ_top1.avg:.3f} ({hyp_top1.avg:.3f}|{ref_top1.avg:.3f})\t' \
            'Acc@5 {targ_top5.avg:.3f} ({hyp_top5.avg:.3f}|{ref_top5.avg:.3f})\t'.format(
                targ_top1=targ_top1, targ_top5=targ_top5, hyp_top1=hyp_top1, hyp_top5=hyp_top5,
                ref_top1=ref_top1, ref_top5=ref_top5)
    logger.info(msg)

    return targ_top1.avg, targ_top5.avg, ref_top1.avg, ref_top5.avg, hyp_top1.avg, hyp_top5.avg
