
import os
import sys
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
#import cv2
import time
import logging
from cross_f1 import *
from utils import *

def validate(train_loader, model, criterion, conf):
    """Run one evaluation pass over a loader and score the collected codes.

    Despite its name, ``train_loader`` is just the loader to evaluate. For
    every batch the model is run in eval mode under ``torch.no_grad()``.
    Accuracy meters and the loss meter are updated. For non-inception
    networks, the model's ``quantized`` output is collected. After the pass,
    the collected features are passed through
    ``features_to_binary_features(..., bits=conf.bits)``. The result is then
    scored against the attributes returned by ``get_image_properties(paths)``
    using ``attribute_f1_seq``.

    Args:
        train_loader: sized iterable yielding ``(input, target, path)``
            triples; ``path`` must be extendable into a list.
        model: network to evaluate. If ``'inception'`` is in
            ``conf.netname`` it is called as ``output = model(x)``.
            Otherwise it must return a 5-tuple
            ``(output, _, moutput, patches, quantized)``.
        criterion: loss callable; its result is reduced with ``torch.mean``.
        conf: config object with ``netname`` and ``bits``, and optionally
            ``midlevel`` (checked with ``in`` before attribute access).

    Returns:
        ``(scores, loss, mscores, ascores, f1)``:
        - The first four are the ``.value()`` of the corresponding meters.
        - ``f1`` is the float returned by ``attribute_f1_seq(...).item()``.
        - ``f1`` is ``float('nan')`` when no quantized features were
          collected (inception networks).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    scores = AverageAccMeter()
    mscores = AverageAccMeter()   # accuracy of `moutput` (only if conf.midlevel)
    ascores = AverageAccMeter()   # accuracy of output + moutput (only if conf.midlevel)

    # Both flags depend only on `conf`, so evaluate them once, not per batch.
    use_midlevel = 'midlevel' in conf and conf.midlevel
    is_inception = 'inception' in conf.netname

    model.eval()
    features = []
    paths = []

    pbar = tqdm(train_loader, dynamic_ncols=True, total=len(train_loader))
    end = time.time()
    # model.eval() does not disable autograd; without no_grad() every batch
    # would build a backward graph that is never used, wasting GPU memory.
    with torch.no_grad():
        for batch, target, path in pbar:
            paths.extend(path)
            data_time.add(time.time() - end)
            batch = batch.cuda()
            target = target.cuda()

            # The inception branch yields no mid-level output; keep the name
            # bound so the midlevel check below cannot raise NameError.
            moutput = None
            if is_inception:
                output = model(batch)
            else:
                output, _, moutput, _patches, quantized = model(batch)
                features.append(quantized.detach())

            scores.add(output.data, target)
            if use_midlevel and moutput is not None:
                mscores.add(moutput.data, target)
                ascores.add(output.data + moutput.data, target)

            loss = torch.mean(criterion(output, target))
            losses.add(loss.item(), batch.size(0))
            del loss, output

            # measure elapsed time
            batch_time.add(time.time() - end)
            end = time.time()
            pbar.set_postfix(batch_time=batch_time.value(),
                             data_time=data_time.value(),
                             loss=losses.value())

    if not features:
        # No quantized codes were produced (inception path), so there is
        # nothing to score; torch.cat([]) would raise here.
        _f1 = float('nan')
    else:
        features = torch.cat(features)
        _, attrs = get_image_properties(paths)
        attrs = torch.from_numpy(attrs)
        binary_features = features_to_binary_features(features, bits=conf.bits)
        _f1 = attribute_f1_seq(attrs, binary_features.transpose(0, 1)).item()
    print(_f1)
    return scores.value(), losses.value(), mscores.value(), ascores.value(), _f1
