import logging
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from utils_training.utils import flow2kps
from utils_training.evaluation import Evaluator
from utils_training.logger import Logger
import numpy as np
import cv2
import os.path as osp
import os


def EPE(input_flow, target_flow, sparse=True, mean=True, sum=False):
    """Compute the end-point error (EPE) between a predicted and a target flow.

    The per-pixel EPE is the L2 norm of ``target_flow - input_flow`` taken
    over dim 1, so the flow components are expected on dim 1 (presumably a
    ``(B, 2, H, W)`` layout -- the sparse mask reads channels 0 and 1;
    TODO confirm with callers).

    Args:
        input_flow: predicted flow tensor.
        target_flow: ground-truth flow tensor, presumably the same shape as
            ``input_flow``.
        sparse: if True, drop pixels whose target flow is exactly (0, 0);
            such pixels are treated as having no ground truth.
        mean: if True, return the mean EPE over the kept pixels.
        sum: only consulted when ``mean`` is False; if True, return the
            summed EPE. (The name shadows the builtin inside this function;
            it is kept for keyword-argument compatibility.)

    Returns:
        A 0-dim tensor. With ``mean=False`` and ``sum=False`` the result is the
        summed EPE divided by the number of kept pixels.
    """
    epe_map = torch.norm(target_flow - input_flow, 2, 1)
    if sparse:
        # Invalid flow is encoded with both flow components exactly 0.
        invalid = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0)
        epe_map = epe_map[~invalid]
    if mean:
        return epe_map.mean()
    if sum:
        return epe_map.sum()
    # Normalise by the number of pixels that survived masking. numel() works
    # for both sparse and dense inputs (there is no mask when sparse=False).
    return epe_map.sum() / epe_map.numel()


def train_epoch(net,
                optimizer,
                train_loader,
                device,
                epoch,
                train_writer, write_batch_idx=50):
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    Each mini-batch is read as a dict with ``'trg_img'``, ``'src_img'`` and
    ``'flow'`` entries. The network is called as ``net(trg_img, src_img)`` and
    trained on ``EPE(pred_flow, flow)`` with EPE's defaults (sparse, mean).

    Args:
        net: model to train; switched to training mode via ``net.train()``.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        train_loader: iterable of mini-batch dicts that supports ``len()``.
        device: target passed to ``.to(device)`` for every input tensor.
        epoch: epoch index shown in log lines. ``-1`` omits the epoch and
            loss fields from the log line.
        train_writer: accepted for interface compatibility; currently unused.
        write_batch_idx: emit a log line every this many batches.

    Returns:
        The average per-batch loss, or ``0.0`` when ``train_loader`` is empty.
    """
    net.train()
    running_total_loss = 0.0
    n_batches = len(train_loader)

    for idx, mini_batch in enumerate(train_loader):
        optimizer.zero_grad()
        flow_gt = mini_batch['flow'].to(device)

        pred_flow = net(mini_batch['trg_img'].to(device),
                        mini_batch['src_img'].to(device))

        loss = EPE(pred_flow, flow_gt)
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() forces a device->host sync.
        loss_value = loss.item()
        running_total_loss += loss_value
        if idx % write_batch_idx == 0:
            msg = '*** %s ' % 'Training'
            msg += '[Epoch: %02d] ' % epoch if epoch != -1 else ''
            msg += '[Batch: %04d/%04d] ' % (idx + 1, n_batches)
            if epoch != -1:
                msg += 'L: %6.5f  ' % loss_value
                msg += 'Avg L: %6.5f  ' % (running_total_loss / (idx + 1))
            Logger.info(msg)

    # Guard against an empty loader instead of dividing by zero.
    if n_batches == 0:
        return 0.0
    return running_total_loss / n_batches


def validate_epoch(net,
                   val_loader,
                   device,
                   epoch, write_batch_idx=50):
    """Evaluate ``net`` on ``val_loader`` and return ``(mean_loss, mean_pck)``.

    Each mini-batch is read as a dict with ``'trg_img'``, ``'src_img'``,
    ``'flow'``, ``'trg_kps'`` and ``'n_pts'`` entries (plus whatever
    ``Evaluator.eval_kps_transfer`` reads from it). Predicted flow is turned
    into keypoints with ``flow2kps`` and scored with
    ``Evaluator.eval_kps_transfer``. The loss is ``EPE`` with its defaults.

    Args:
        net: model to evaluate; switched to eval mode via ``net.eval()``.
        val_loader: iterable of mini-batch dicts that supports ``len()``.
        device: target passed to ``.to(device)`` for the input tensors.
        epoch: epoch index shown in log lines. ``-1`` omits the epoch and
            loss fields from the log lines.
        write_batch_idx: emit a progress log line every this many batches.

    Returns:
        ``(mean_loss, mean_pck)``: the average per-batch EPE and the mean of
        all collected PCK values. Each falls back to ``0.0`` when there is
        nothing to average (empty loader / no PCK values).
    """
    net.eval()
    running_total_loss = 0.0
    pck_array = []
    n_batches = len(val_loader)

    with torch.no_grad():
        for idx, mini_batch in enumerate(val_loader):
            flow_gt = mini_batch['flow'].to(device)
            pred_flow = net(mini_batch['trg_img'].to(device),
                            mini_batch['src_img'].to(device))

            estimated_kps = flow2kps(mini_batch['trg_kps'].to(device), pred_flow,
                                     mini_batch['n_pts'].to(device))
            eval_result = Evaluator.eval_kps_transfer(estimated_kps.cpu(), mini_batch)
            # presumably a list of per-sample PCK values; TODO confirm
            pck_array += eval_result['pck']

            loss_value = EPE(pred_flow, flow_gt).item()
            running_total_loss += loss_value
            if idx % write_batch_idx == 0:
                # Build a fresh message per logged batch so the log line does
                # not keep growing over the whole validation pass.
                msg = '*** %s ' % 'Validation'
                msg += '[Epoch: %02d] ' % epoch if epoch != -1 else ''
                msg += '[Batch: %04d/%04d] ' % (idx + 1, n_batches)
                if epoch != -1:
                    msg += 'L: %6.5f  ' % loss_value
                    msg += 'Avg L: %6.5f  ' % (running_total_loss / (idx + 1))
                Logger.info(msg)

    # Guard both averages against empty inputs instead of dividing by zero.
    mean_loss = running_total_loss / n_batches if n_batches else 0.0
    mean_pck = sum(pck_array) / len(pck_array) if pck_array else 0.0

    msg = '*** %s ' % 'Validation'
    msg += '[Epoch: %02d] ' % epoch if epoch != -1 else ''
    if epoch != -1:
        msg += 'Avg L: %6.5f  ' % mean_loss
    msg += 'mean PCK: %2.5f  ' % mean_pck
    Logger.info(msg)

    return mean_loss, mean_pck