# Modified based on the HRNet repo.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np


def mdeq_hypersolver_loss(model, z_targ, z_ref, rel_trace, hyp_X, Galphas, z_init, train_step):
    """Compute the hypersolver training loss and its individually logged terms.

    Args:
        model: object exposing ``training`` and ``hypsolver.learn_alpha``.
        z_targ: target solution tensor; dim 0 is the batch.
        z_ref: reference solution with the same shape as ``z_targ``; only used
            for the (logged) ``ref_reco_loss``.
        rel_trace: only ``rel_trace.shape[1]`` is read -- the number of trailing
            hypersolver iterates to supervise.
        hyp_X: hypersolver iterates; ``hyp_X[:, -loss_len:]`` is compared against
            ``z_targ[:, None]`` (presumably (batch, steps, *z_targ.shape[1:]) --
            TODO confirm).
        Galphas: tensor whose mean forms the alpha auxiliary loss.
        z_init: initial estimate; penalized with MSE against ``z_targ`` unless
            it is all zeros.
        train_step: current training step; anneals the alpha auxiliary weight.

    Returns:
        Tuple ``(loss, reco_losses, ref_reco_loss, alpha_aux_loss, init_est_loss)``
        reshaped to ``(1, 1)``, ``(1, n_steps)``, ``(1, 1)``, ``(1, 1)``, ``(1, 1)``.
        In training mode ``reco_losses`` holds the weighted per-step terms that
        sum to the reconstruction part of ``loss``; in eval mode it holds the raw
        per-step norms and ``loss`` does not include them.
    """
    bsz = z_targ.shape[0]

    lam1 = 0.1   # weight of the per-iterate reconstruction loss
    lam2 = 5     # weight of the initial-estimate loss
    lam3 = 1e-5  # weight of the alpha auxiliary loss
    loss_len = rel_trace.shape[1]

    # Per-iterate L2 distance to the target, averaged over the batch.
    diff = hyp_X[:, -loss_len:] - z_targ[:, None]
    reco_losses = diff.reshape(*diff.shape[:2], -1).norm(dim=2).mean(0)
    # Alpha auxiliary weight decays linearly from 1 at step 0 to a floor of 0.005.
    ratio = max(1 - train_step / 1500, 0.005)
    if model.training:
        n_steps = reco_losses.shape[0]
        # Quadratic ramp (weight proportional to i**2 for i = 0..n_steps-1) so
        # later iterates dominate. It is built from n_steps instead of a fixed
        # 50-entry table, so any number of supervised iterates works.
        loss_weights = torch.arange(n_steps, device=reco_losses.device).to(reco_losses) ** 2
        if n_steps > 1:
            loss_weights = loss_weights / loss_weights.sum()
        else:
            # The ramp gives a lone iterate weight 0; normalizing would be 0/0 = NaN.
            loss_weights = torch.ones_like(loss_weights)
        reco_losses = reco_losses * loss_weights * lam1
        loss = reco_losses.sum()
    else:
        loss = torch.tensor(0.0).to(z_targ)

    init_est_loss = torch.tensor(0.0).to(loss)
    # An all-zero z_init is treated as "no initial estimate was produced".
    if not (z_init == 0).all():
        init_est_loss = F.mse_loss(z_init, z_targ) * lam2
        loss = loss + init_est_loss

    alpha_aux_loss = torch.tensor(0.0).to(loss)
    if model.hypsolver.learn_alpha:
        alpha_aux_loss = Galphas.mean() * lam3
        # Only the contribution to `loss` is annealed; the logged value is unscaled.
        loss = loss + alpha_aux_loss * ratio

    # Distance between target and reference solutions, returned for logging.
    ref_reco_loss = (z_targ - z_ref).view(bsz, -1).norm(dim=1).mean()

    return loss.view(-1, 1), reco_losses.view(1, -1), ref_reco_loss.view(1, -1), alpha_aux_loss.view(1, -1), init_est_loss.view(1, -1)


class FullModel(nn.Module):
    """
    Wrap a model together with its loss so the loss is computed inside forward().

    Computing the loss inside the wrapped module lets multi-GPU wrappers spread
    it across devices, which reduces memory use on the main GPU. See:
    https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551/21
    """

    def __init__(self, model, loss):
        super(FullModel, self).__init__()
        self.model = model
        self.loss = loss

    def forward(self, inputs, labels, train_step=-1, simple=False, **kwargs):
        """Run the wrapped model and compute losses.

        Args:
            inputs: model inputs, forwarded unchanged.
            labels: targets passed to ``self.loss``.
            train_step: forwarded to the model and to ``mdeq_hypersolver_loss``.
            simple: if true, the model is expected to return ``(y_hyp, hyp_result)``
                and only the task loss of ``y_hyp`` is computed.
            **kwargs: extra keyword arguments forwarded to the model.

        Returns:
            If ``simple``: ``(y_hyp, hyp_loss)``, where ``hyp_loss`` is computed
            with autograd enabled.
            Otherwise: ``((y_targ, y_ref, y_hyp), loss, reco_losses, ref_reco_loss,
            targ_loss, ref_loss, hyp_loss, alpha_aux_loss, init_est_loss)``,
            where the three task losses are computed under ``torch.no_grad()``.
        """
        output_info = self.model(inputs, train_step=train_step, simple=simple, **kwargs)
        if simple:
            y_hyp, _hyp_result = output_info
            hyp_loss = self.loss(y_hyp, labels)
            return y_hyp, hyp_loss

        y_targ, y_ref, y_hyp, result, hyp_result, z_init = output_info
        # Task losses of the three predictions are for monitoring only.
        with torch.no_grad():
            targ_loss = self.loss(y_targ, labels)
            ref_loss = self.loss(y_ref, labels)
            hyp_loss = self.loss(y_hyp, labels)

        z_targ = result['result']
        z_ref = result['intercept_result']

        all_losses = mdeq_hypersolver_loss(self.model, z_targ, z_ref, hyp_result['rel_trace'],
                                           hyp_result['X'], hyp_result['Galphas'], z_init,
                                           train_step=train_step)
        loss, reco_losses, ref_reco_loss, alpha_aux_loss, init_est_loss = all_losses

        return (y_targ, y_ref, y_hyp), \
            loss, reco_losses, ref_reco_loss, targ_loss, ref_loss, hyp_loss, alpha_aux_loss, init_est_loss

def get_world_size():
    """Return the number of distributed processes, or 1 when not running distributed.

    Also returns 1 on PyTorch builds without distributed support, where
    ``torch.distributed.is_initialized`` is not defined.
    """
    dist = torch.distributed
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size()

def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed.

    Also returns 0 on PyTorch builds without distributed support, where
    ``torch.distributed.is_initialized`` is not defined.
    """
    dist = torch.distributed
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank()


class AverageMeter(object):
    """Track the latest value and a weighted running average of all values."""

    def __init__(self):
        # All statistics stay None until the first update() call.
        self.initialized = False
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None

    def initialize(self, val, weight):
        """Seed the meter with its first weighted observation."""
        self.val = self.avg = val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        """Record ``val`` with ``weight``; the very first call seeds the meter."""
        record = self.add if self.initialized else self.initialize
        record(val, weight)

    def add(self, val, weight):
        """Fold one more weighted observation into the running totals."""
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count
        self.val = val

    def value(self):
        """Return the most recently recorded value (None before any update)."""
        return self.val

    def average(self):
        """Return the weighted mean of all recorded values (None before any update)."""
        return self.avg
    

def create_logger(cfg, cfg_name, phase='train'):
    """Create the output and tensorboard directories and configure the root logger.

    Args:
        cfg: config object; reads ``OUTPUT_DIR``, ``LOG_DIR``, ``DATASET.DATASET``,
            ``MODEL.NAME`` and ``MODEL.F_THRES``.
        cfg_name: path of the config file; its basename up to the first '.'
            names the run.
        phase: suffix used in the log file name (e.g. 'train').

    Returns:
        ``(logger, final_output_dir, tensorboard_log_dir)``: the root logger and
        the two created directories as strings.

    Note: ``logging.basicConfig`` does nothing when the root logger already has
    handlers, so in that case no log file is attached by this call.
    """
    root_output_dir = Path(cfg.OUTPUT_DIR)
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        # parents=True so a nested OUTPUT_DIR with missing parents still works.
        root_output_dir.mkdir(parents=True, exist_ok=True)

    dataset = cfg.DATASET.DATASET
    model = cfg.MODEL.NAME
    cfg_name = os.path.basename(cfg_name).split('.')[0]

    final_output_dir = root_output_dir / dataset / cfg_name

    print('=> creating {}'.format(final_output_dir))
    final_output_dir.mkdir(parents=True, exist_ok=True)

    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
    final_log_file = final_output_dir / log_file
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Attach a single console handler even if create_logger is called repeatedly;
    # otherwise every call would duplicate each console line.
    if not any(getattr(h, '_create_logger_console', False) for h in logger.handlers):
        console = logging.StreamHandler()
        console._create_logger_console = True
        logger.addHandler(console)

    tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / (cfg_name + f'-{cfg.MODEL.F_THRES}')
    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return logger, str(final_output_dir), str(tensorboard_log_dir)


def get_optimizer(cfg, model):
    """Build the optimizer named by ``cfg.TRAIN.OPTIMIZER``.

    Only parameters with ``requires_grad=True`` are handed to the optimizer.

    Args:
        cfg: config object; reads ``TRAIN.OPTIMIZER``, ``TRAIN.LR`` and, depending
            on the optimizer, ``TRAIN.MOMENTUM``, ``TRAIN.WD``, ``TRAIN.NESTEROV``,
            ``TRAIN.RMSPROP_ALPHA``, ``TRAIN.RMSPROP_CENTERED``.
        model: module whose parameters are optimized.

    Returns:
        A ``torch.optim`` optimizer instance.

    Raises:
        ValueError: if ``cfg.TRAIN.OPTIMIZER`` is not 'sgd', 'adam' or 'rmsprop'
            (previously ``None`` was returned, which failed later with a less
            helpful error).
    """
    name = cfg.TRAIN.OPTIMIZER
    trainable = [p for p in model.parameters() if p.requires_grad]

    if name == 'sgd':
        return optim.SGD(
            trainable,
            lr=cfg.TRAIN.LR,
            momentum=cfg.TRAIN.MOMENTUM,
            weight_decay=cfg.TRAIN.WD,
            nesterov=cfg.TRAIN.NESTEROV
        )
    if name == 'adam':
        return optim.Adam(trainable, lr=cfg.TRAIN.LR)
    if name == 'rmsprop':
        return optim.RMSprop(
            trainable,
            lr=cfg.TRAIN.LR,
            momentum=cfg.TRAIN.MOMENTUM,
            weight_decay=cfg.TRAIN.WD,
            alpha=cfg.TRAIN.RMSPROP_ALPHA,
            centered=cfg.TRAIN.RMSPROP_CENTERED
        )
    raise ValueError(
        "Unsupported optimizer {!r}; expected 'sgd', 'adam' or 'rmsprop'".format(name))


def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth.tar'):
    """Serialize ``states`` to ``output_dir/filename`` with ``torch.save``.

    When ``is_best`` is true and ``states`` has a ``'state_dict'`` entry, that
    entry alone is additionally written to ``output_dir/model_best.pth.tar``.
    """
    checkpoint_path = os.path.join(output_dir, filename)
    torch.save(states, checkpoint_path)

    if not is_best or 'state_dict' not in states:
        return
    best_path = os.path.join(output_dir, 'model_best.pth.tar')
    torch.save(states['state_dict'], best_path)


def get_confusion_matrix(label, pred, size, num_class, ignore=-1):
    """
    Calculate the confusion matrix for the given labels and predictions.

    Args:
        label: ground-truth label tensor, cropped to ``label[:, :size[-2], :size[-1]]``.
        pred: per-class score tensor laid out as (N, C, H, W); the predicted
            class is the argmax over C (stored as uint8).
        size: shape whose last two entries give the spatial crop for ``label``.
        num_class: number of classes.
        ignore: label value excluded from the counts.

    Returns:
        ``(num_class, num_class)`` float64 array; entry ``[i, j]`` counts pixels
        whose label is ``i`` and whose prediction is ``j``.
    """
    # detach() so predictions that still require grad can be converted to numpy.
    output = pred.detach().cpu().numpy().transpose(0, 2, 3, 1)
    seg_pred = np.asarray(np.argmax(output, axis=3), dtype=np.uint8)
    # np.int was removed in NumPy 1.24; use an explicit 64-bit integer dtype.
    seg_gt = np.asarray(
        label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=np.int64)

    valid = seg_gt != ignore
    seg_gt = seg_gt[valid]
    seg_pred = seg_pred[valid]

    # Encode each (label, pred) pair as a single flat index and count them all at once.
    index = (seg_gt * num_class + seg_pred).astype('int32')
    n_cells = num_class * num_class
    # minlength guarantees every cell exists; counts past n_cells (labels >= num_class)
    # are dropped.
    label_count = np.bincount(index, minlength=n_cells)[:n_cells]
    return label_count.reshape(num_class, num_class).astype(np.float64)


def adjust_learning_rate(optimizer, base_lr, max_iters,
                         cur_iters, power=0.9):
    """Apply polynomial decay to the learning rate of the first param group.

    ``lr = base_lr * (1 - cur_iters / max_iters) ** power``, with the remaining
    fraction clamped at 0 so that ``cur_iters > max_iters`` yields ``lr = 0``
    instead of a complex number (a negative float raised to a fractional power).

    Returns:
        The new learning rate, also written to ``optimizer.param_groups[0]['lr']``.
    """
    remaining = max(0.0, 1 - float(cur_iters) / max_iters)
    lr = base_lr * (remaining ** power)
    optimizer.param_groups[0]['lr'] = lr
    return lr