import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import copy
from dataset import MetricDataset
from metric import Predictor
from tqdm import tqdm

# Evaluation configuration: which metric variant (dataset + checkpoint) to score.
tp = 'sem'
# '~' is not expanded by open()/os path APIs, so unless MetricDataset expands it
# itself (not visible here) the path would resolve to a literal '~' directory
# under the CWD. expanduser is a no-op on an already-expanded path.
data_root = os.path.expanduser('~/Project/pytorch/new_data/metric_{}'.format(tp))
metric_datasets = {
    x: MetricDataset(
        data_root=data_root,
        split=x
    )
    for x in ['train', 'val']
}
# Evaluation only: a fixed iteration order keeps runs reproducible, and the
# reported loss is an average over the whole split, so shuffling buys nothing.
dataloaders = {
    x: DataLoader(metric_datasets[x], batch_size=16, shuffle=False, num_workers=4)
    for x in ['train', 'val']
}
dataset_sizes = {x: len(metric_datasets[x]) for x in ['train', 'val']}

# Prefer the second GPU as before, but fall back to the first GPU (or the CPU)
# instead of erroring out on machines with fewer devices.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

model = Predictor()
# Load the checkpoint onto the CPU first so a state dict saved from GPU tensors
# can be restored on any machine; the model is then moved to `device`.
model.load_state_dict(torch.load('./{}_metric.pth'.format(tp), map_location='cpu'))
model = model.to(device)

def eval_model(model):
    """Print and return the mean absolute error of ``model`` on each split.

    Iterates the module-level ``dataloaders`` for the 'train' and 'val'
    splits. Each batch is moved to the module-level ``device`` as float32.
    The per-batch L1 loss is weighted by batch size, so the reported figure
    is a per-sample average over the whole split.

    Args:
        model: network to evaluate; it is switched to eval mode.

    Returns:
        dict mapping each split name to its average loss, or ``nan`` when the
        split yielded no samples.
    """
    model.eval()  # the mode never changes inside the loop, so set it once
    losses = {}
    with torch.no_grad():
        for phase in ['train', 'val']:
            running_loss = 0.0
            seen = 0
            for inputs, labels in tqdm(dataloaders[phase], desc=phase):
                inputs = inputs.to(device).float()
                labels = labels.to(device).float()
                outputs = model(inputs)
                if outputs.shape != labels.shape and outputs.numel() == labels.numel():
                    # e.g. (B, 1) predictions vs (B,) targets: without this the
                    # subtraction broadcasts to (B, B) and the loss is silently
                    # wrong. Shapes that genuinely differ in size are left alone.
                    outputs = outputs.reshape(labels.shape)
                loss = torch.mean(torch.abs(outputs - labels))
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                seen += batch_size
            if seen == 0:
                # Avoid ZeroDivisionError on an empty split.
                print(f'{phase} Loss: n/a (no samples)')
                losses[phase] = float('nan')
                continue
            epoch_loss = running_loss / seen
            print(f'{phase} Loss: {epoch_loss:.4f}')
            losses[phase] = epoch_loss
    return losses


# Run only when executed as a script. DataLoader workers (num_workers > 0)
# started with the 'spawn' method re-import this module, and an unguarded call
# would try to run the evaluation (and start new workers) again inside each one.
if __name__ == "__main__":
    eval_model(model)
