# Modified based on the HRNet repo.

import logging
import os
import time
import numpy as np
import sys

import numpy as np
import numpy.ma as ma
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn import functional as F

from utils.utils import AverageMeter
from utils.utils import get_confusion_matrix
from utils.utils import adjust_learning_rate
from utils.utils import get_world_size, get_rank
from utils.utils import save_checkpoint

logger = logging.getLogger(__name__)

def reduce_tensor(inp):
    """
    Reduce a tensor from all processes onto the process with rank 0.

    ``dist.reduce`` is called with its default reduction op, so rank 0
    ends up with the element-wise sum over all processes, not the mean;
    callers that want an average divide by the world size themselves.
    When fewer than two processes are running, ``inp`` is returned
    unchanged.

    Args:
        inp: tensor to reduce. It is never modified by this function.

    Returns:
        ``inp`` itself when the world size is below 2, otherwise a new
        tensor holding the reduced values (only meaningful on rank 0).
    """
    world_size = get_world_size()
    if world_size < 2:
        return inp
    with torch.no_grad():
        # dist.reduce works in place. Reduce a copy so the caller's tensor
        # (e.g. a loss that is still going to be back-propagated) keeps its
        # local value instead of being overwritten with the reduced result.
        reduced_inp = inp.clone()
        dist.reduce(reduced_inp, dst=0)
    return reduced_inp

def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
         trainloader, optimizer, lr_scheduler, model, output_dir, writer_dict, device):
    """
    Run one pass over ``trainloader``, updating the model after every batch.

    For each batch the model returns three predictions (referred to here as
    "targ", "hyp" and "ref") together with several losses. The total loss is
    back-propagated; the remaining losses are only tracked for logging.
    Every ``config.PRINT_FREQ`` iterations the per-prediction confusion
    matrices are reduced across processes, turned into mean IoU values and
    reset, and rank 0 logs the running averages.

    Args:
        config: experiment configuration. Fields read here:
            DATASET.NUM_CLASSES, MODEL.F_THRES, TRAIN.IGNORE_LABEL,
            TRAIN.CLIP, TRAIN.LR_SCHEDULER and PRINT_FREQ.
        epoch: index of the current epoch (used for logging and to compute
            the global iteration passed to ``adjust_learning_rate``).
        num_epoch: total number of epochs (logging only).
        epoch_iters: number of iterations per epoch.
        base_lr: base learning rate forwarded to ``adjust_learning_rate``.
        num_iters: total iteration count forwarded to
            ``adjust_learning_rate``.
        trainloader: iterable yielding 4-tuples whose first two items are
            the images and the labels.
        optimizer: optimizer stepped once per batch.
        lr_scheduler: scheduler; stepped once per batch unless
            ``config.TRAIN.LR_SCHEDULER == 'step'``.
        model: wrapped model; must expose ``model.module.model.hypsolver``.
        output_dir: unused in this function.
        writer_dict: dict whose ``'train_global_steps'`` entry is
            incremented after every batch.
        device: device the batch tensors are moved to.

    Returns:
        None.
    """
    num_classes = config.DATASET.NUM_CLASSES
    heads = ("targ", "hyp", "ref")
    # Order matters: it fixes the order of the collective reduce calls below.
    meter_names = ("loss", "targ", "hyp", "ref", "init", "reco")

    def fresh_confusion():
        # One zeroed (num_classes x num_classes) matrix per prediction head.
        return {head: np.zeros((num_classes, num_classes)) for head in heads}

    def fresh_meters():
        # One running-average meter per logged loss.
        return {name: AverageMeter() for name in meter_names}

    # Training
    model.train()
    model.module.model.hypsolver.train()
    batch_time = AverageMeter()
    meters = fresh_meters()
    tic = time.time()
    cur_iters = epoch*epoch_iters
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()
    confusion = fresh_confusion()

    for i_iter, batch in enumerate(trainloader):
        images, labels, _, _ = batch
        size = labels.size()
        images = images.to(device)
        labels = labels.long().to(device)

        losses_info = model(images, labels, train_step=(lr_scheduler._step_count-1),
                            f_thres=config.MODEL.F_THRES,
                            writer=None)

        # The fourth and eighth entries are returned by the model but are
        # not used by this loop.
        ((y_targ, y_ref, y_hyp), loss, reco_losses, _ref_reco_loss, targ_loss,
         ref_loss, hyp_loss, _alpha_aux_loss, init_est_loss) = losses_info

        outputs = {"targ": y_targ, "hyp": y_hyp, "ref": y_ref}
        for head in heads:
            # Resize each prediction to the label resolution before scoring.
            preds = F.interpolate(input=outputs[head], size=(size[-2], size[-1]),
                                  mode='bilinear', align_corners=True)
            confusion[head] += get_confusion_matrix(
                labels,
                preds,
                size,
                num_classes,
                config.TRAIN.IGNORE_LABEL)

        loss = loss.mean()
        reco_losses = reco_losses.mean(0)
        targ_loss = targ_loss.mean()
        hyp_loss = hyp_loss.mean()
        ref_loss = ref_loss.mean()
        init_est_loss = init_est_loss.mean()

        # reduce_tensor is a collective call: every process must issue these
        # in the same order (dict literals evaluate top to bottom).
        reduced = {
            "loss": reduce_tensor(loss),
            "targ": reduce_tensor(targ_loss),
            "hyp": reduce_tensor(hyp_loss),
            "ref": reduce_tensor(ref_loss),
            "init": reduce_tensor(init_est_loss),
            # Only the last entry of the reconstruction losses is logged.
            "reco": reduce_tensor(reco_losses[-1]),
        }

        model.zero_grad()
        loss.backward()
        if config['TRAIN']['CLIP'] > 0:
            torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], config['TRAIN']['CLIP'])
        else:
            print("Hmm, you should use clipping for hypersolver")
        optimizer.step()
        if config.TRAIN.LR_SCHEDULER != 'step':
            lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        for name in meter_names:
            meters[name].update(reduced[name].item())

        lr = adjust_learning_rate(optimizer,
                                  base_lr,
                                  num_iters,
                                  i_iter+cur_iters)

        if i_iter % config.PRINT_FREQ == 0:
            # The reduction below is collective, so it runs on every rank
            # even though only rank 0 logs the result.
            mIoUs = []
            for head in heads:
                cm_tensor = torch.from_numpy(confusion[head]).to(device)
                cm = reduce_tensor(cm_tensor).cpu().numpy()
                pos = cm.sum(1)
                res = cm.sum(0)
                tp = np.diag(cm)
                IoU_array = (tp / np.maximum(1.0, pos + res - tp))
                mIoUs.append(IoU_array.mean())

            # Start a fresh accumulation window for the next report.
            confusion = fresh_confusion()

            if rank == 0:
                # Reduced values are divided by the world size here to
                # report a per-process average.
                avg = {name: meters[name].average() / world_size
                       for name in meter_names}

                msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                    'lr: {:.4f}, Loss: {:.5f}, Recon_Loss: {:.4f}, Init_Loss: {:.4f}, Seg_Loss: {:.5f}({:.5f}|{:.5f}), mIoU: {:.4f}({:.4f}|{:.4f})' .format(
                        epoch, num_epoch, i_iter, epoch_iters, batch_time.average(),
                        lr, avg["loss"], avg["reco"], avg["init"], avg["targ"], avg["hyp"], avg["ref"], mIoUs[0], mIoUs[1], mIoUs[2])
                logging.info(msg)

                batch_time = AverageMeter()
                meters = fresh_meters()

        global_steps += 1
        writer_dict['train_global_steps'] = global_steps
    

def validate(config, testloader, model, lr_scheduler, epoch, writer_dict, device, simple=False, override_f_thres=-1):
    """
    Evaluate the model on ``testloader`` and return loss / IoU statistics.

    In the default mode the model returns three predictions ("targ", "hyp"
    and "ref") and all three are scored. With ``simple=True`` the model is
    called with ``simple=True`` and only the "hyp" prediction is scored.

    Args:
        config: experiment configuration. Fields read here:
            DATASET.NUM_CLASSES, MODEL.F_THRES and TRAIN.IGNORE_LABEL.
        testloader: iterable yielding 4-tuples whose first two items are
            the image and the label.
        model: wrapped model; must expose ``model.module.model.hypsolver``.
        lr_scheduler: scheduler whose ``_step_count`` is forwarded to the
            model as ``train_step``.
        epoch: unused in this function.
        writer_dict: unused in this function.
        device: device the batch tensors are moved to.
        simple: when True, evaluate only the "hyp" prediction.
        override_f_thres: when non-negative, used instead of
            ``config.MODEL.F_THRES``.

    Returns:
        A list with one ``(loss, mean_IoU, IoU_array)`` tuple per evaluated
        prediction, in the order ``["hyp"]`` (simple) or
        ``["targ", "hyp", "ref"]`` (default).
    """
    world_size = get_world_size()
    model.eval()
    model.module.model.hypsolver.eval()

    num_classes = config.DATASET.NUM_CLASSES
    update_list = ["hyp"] if simple else ["targ", "hyp", "ref"]
    # One running loss average and one confusion matrix per evaluated head.
    meters = {suffix: AverageMeter() for suffix in update_list}
    confusion = {suffix: np.zeros((num_classes, num_classes))
                 for suffix in update_list}

    f_thres = config.MODEL.F_THRES if override_f_thres < 0 else override_f_thres

    with torch.no_grad():
        for batch in testloader:
            image, label, _, _ = batch
            size = label.size()
            image = image.to(device)
            label = label.long().to(device)

            losses_info = model(image, label, train_step=(lr_scheduler._step_count-1),
                                f_thres=f_thres,
                                writer=None,
                                simple=simple)
            if simple:
                y_hyp, hyp_loss = losses_info
                outputs = {"hyp": y_hyp}
                losses = {"hyp": hyp_loss}
            else:
                (y_targ, y_ref, y_hyp), _, _, _, targ_loss, ref_loss, hyp_loss, _, _ = losses_info
                outputs = {"targ": y_targ, "hyp": y_hyp, "ref": y_ref}
                losses = {"targ": targ_loss, "hyp": hyp_loss, "ref": ref_loss}

            for suffix in update_list:
                # Resize the prediction to the label resolution before scoring.
                pred = F.interpolate(input=outputs[suffix], size=(size[-2], size[-1]),
                                     mode='bilinear', align_corners=True)
                loss = losses[suffix].mean()
                # reduce_tensor is a collective call issued on every rank.
                meters[suffix].update(reduce_tensor(loss).item())

                confusion[suffix] += get_confusion_matrix(
                    label,
                    pred,
                    size,
                    num_classes,
                    config.TRAIN.IGNORE_LABEL)

    to_return = []
    for suffix in update_list:
        cm_tensor = torch.from_numpy(confusion[suffix]).to(device)
        cm = reduce_tensor(cm_tensor).cpu().numpy()
        pos = cm.sum(1)
        res = cm.sum(0)
        tp = np.diag(cm)
        IoU_array = (tp / np.maximum(1.0, pos + res - tp))
        mean_IoU = IoU_array.mean()
        # Reduced losses are divided by the world size to report a
        # per-process average.
        print_loss = meters[suffix].average()/world_size
        to_return.append((print_loss, mean_IoU, IoU_array))

    return to_return
    

def testval(config, test_dataset, testloader, model, 
        sv_dir='', sv_pred=False):
    """
    Evaluate the model with multi-scale inference and report metrics.

    Args:
        config: experiment configuration. Fields read here:
            DATASET.NUM_CLASSES, TEST.SCALE_LIST, TEST.FLIP_TEST and
            TRAIN.IGNORE_LABEL.
        test_dataset: dataset object providing ``multi_scale_inference``
            and ``save_pred``.
        testloader: iterable yielding ``(image, label, _, name)`` batches.
        model: model to evaluate; switched to eval mode.
        sv_dir: directory under which ``test_val_results`` is created when
            predictions are saved.
        sv_pred: when True, every prediction is written with
            ``test_dataset.save_pred``.

    Returns:
        ``(mean_IoU, IoU_array, pixel_acc, mean_acc)`` computed from the
        confusion matrix accumulated over the whole loader.
    """
    model.eval()
    num_classes = config.DATASET.NUM_CLASSES
    confusion_matrix = np.zeros((num_classes, num_classes))

    sv_path = os.path.join(sv_dir, 'test_val_results')
    if sv_pred:
        # Create the output directory once, up front. exist_ok avoids the
        # check-then-create race when several processes start together.
        os.makedirs(sv_path, exist_ok=True)

    with torch.no_grad():
        for index, batch in enumerate(tqdm(testloader)):
            image, label, _, name = batch
            size = label.size()
            pred = test_dataset.multi_scale_inference(
                        model, 
                        image, 
                        scales=config.TEST.SCALE_LIST, 
                        flip=config.TEST.FLIP_TEST)

            # Bring the prediction to the label resolution if it differs.
            if pred.size()[-2] != size[-2] or pred.size()[-1] != size[-1]:
                pred = F.interpolate(pred, (size[-2], size[-1]), 
                                   mode='bilinear', align_corners=True)

            confusion_matrix += get_confusion_matrix(
                label,
                pred,
                size,
                num_classes,
                config.TRAIN.IGNORE_LABEL)

            if sv_pred:
                test_dataset.save_pred(pred, sv_path, name)

            if index % 100 == 0:
                # Periodic progress report with the running mean IoU.
                logging.info('processing: %d images', index)
                pos = confusion_matrix.sum(1)
                res = confusion_matrix.sum(0)
                tp = np.diag(confusion_matrix)
                IoU_array = (tp / np.maximum(1.0, pos + res - tp))
                mean_IoU = IoU_array.mean()
                logging.info('mIoU: %.4f', mean_IoU)

    pos = confusion_matrix.sum(1)
    res = confusion_matrix.sum(0)
    tp = np.diag(confusion_matrix)
    # Guard the denominator like the other metrics so an empty confusion
    # matrix yields 0 instead of NaN.
    pixel_acc = tp.sum()/np.maximum(1.0, pos.sum())
    mean_acc = (tp/np.maximum(1.0, pos)).mean()
    IoU_array = (tp / np.maximum(1.0, pos + res - tp))
    mean_IoU = IoU_array.mean()

    return mean_IoU, IoU_array, pixel_acc, mean_acc

def test(config, test_dataset, testloader, model, 
        sv_dir='', sv_pred=True):
    """
    Run multi-scale inference over ``testloader`` and optionally save results.

    Args:
        config: experiment configuration. Fields read here:
            TEST.SCALE_LIST and TEST.FLIP_TEST.
        test_dataset: dataset object providing ``multi_scale_inference``
            and ``save_pred``.
        testloader: iterable yielding ``(image, size, name)`` batches; only
            the first entry of ``size`` is used.
        model: model to run; switched to eval mode.
        sv_dir: directory under which ``test_results`` is created when
            predictions are saved.
        sv_pred: when True, every prediction is written with
            ``test_dataset.save_pred``.

    Returns:
        None.
    """
    model.eval()

    sv_path = os.path.join(sv_dir, 'test_results')
    if sv_pred:
        # Create the output directory once, up front. exist_ok avoids the
        # check-then-create race when several processes start together.
        os.makedirs(sv_path, exist_ok=True)

    with torch.no_grad():
        for batch in tqdm(testloader):
            image, size, name = batch
            size = size[0]
            pred = test_dataset.multi_scale_inference(
                        model, 
                        image, 
                        scales=config.TEST.SCALE_LIST, 
                        flip=config.TEST.FLIP_TEST)

            # NOTE(review): assumes size is laid out as (height, width, ...),
            # which is what the size check always indexed; confirm against
            # the dataset. The previous resize used size[-2]/size[-1], which
            # disagrees with the check whenever size has more than two
            # entries.
            target_h, target_w = int(size[0]), int(size[1])
            if pred.size()[-2] != target_h or pred.size()[-1] != target_w:
                pred = F.interpolate(pred, (target_h, target_w), mode='bilinear', align_corners=True)

            if sv_pred:
                test_dataset.save_pred(pred, sv_path, name)
