from argparse import ArgumentParser
import os
import json
import sys
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

sys.path.append(".")
sys.path.append("..")

from criteria.lpips.lpips import LPIPS
from datasets.gt_res_dataset import GTResDataset


def parse_args():
    """Parse the command line for the image-metric script.

    Returns:
        argparse.Namespace with attributes ``mode`` ('lpips' or 'l2'),
        ``data_path``, ``gt_path``, ``workers``, ``batch_size`` and the
        boolean flag ``is_cars``.
    """
    # NOTE: add_help=False means ``-h`` / ``--help`` is not registered.
    cli = ArgumentParser(add_help=False)
    option_specs = (
        ('--mode', dict(type=str, default='lpips', choices=['lpips', 'l2'])),
        ('--data_path', dict(type=str, default='results')),
        ('--gt_path', dict(type=str, default='gt_images')),
        ('--workers', dict(type=int, default=4)),
        ('--batch_size', dict(type=int, default=4)),
        ('--is_cars', dict(action='store_true')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def _build_loss(mode, device):
    """Return the loss module for ``mode`` ('lpips' or 'l2'), placed on ``device``."""
    if mode == 'lpips':
        loss_func = LPIPS(net_type='alex')
    elif mode == 'l2':
        loss_func = torch.nn.MSELoss()
    else:
        # ValueError subclasses Exception, so callers catching the old
        # generic Exception still catch this.
        raise ValueError('Not a valid mode!')
    loss_func.to(device)
    # Evaluation only: switch off any train-mode behaviour the module may have.
    loss_func.eval()
    return loss_func


def _score_pairs(loss_func, dataloader, pairs, device):
    """Score every (result, gt) pair one image at a time.

    Returns:
        (all_scores, scores_dict): the per-pair losses in loader order, and
        a mapping from result-file basename to its loss.
    """
    all_scores = []
    scores_dict = {}
    global_i = 0
    with torch.no_grad():  # inference only; don't build autograd graphs
        for result_batch, gt_batch in tqdm(dataloader):
            # Transfer each batch once instead of once per image slice.
            result_batch = result_batch.to(device)
            gt_batch = gt_batch.to(device)
            # Use the real batch length: the final batch may be smaller
            # than ``batch_size`` now that nothing is dropped.
            for i in range(result_batch.size(0)):
                # Slice i:i+1 keeps the batch dim, so each loss is per image.
                loss = float(loss_func(result_batch[i:i + 1], gt_batch[i:i + 1]))
                all_scores.append(loss)
                # Relies on shuffle=False so loader order matches ``pairs``;
                # presumably pairs[k][0] is the result-image path of sample k
                # (TODO confirm against GTResDataset).
                im_path = pairs[global_i][0]
                # NOTE: results sharing a basename overwrite each other here.
                scores_dict[os.path.basename(im_path)] = loss
                global_i += 1
    return all_scores, scores_dict


def _write_results(data_path, mode, result_str, scores_dict):
    """Write the summary line and per-image scores next to ``data_path``."""
    # normpath strips a trailing separator so 'results/' and 'results'
    # resolve to the same parent directory.
    parent = os.path.dirname(os.path.normpath(data_path))
    out_path = os.path.join(parent, 'inference_metrics')
    os.makedirs(out_path, exist_ok=True)

    with open(os.path.join(out_path, 'stat_{}.txt'.format(mode)), 'w', encoding='utf-8') as f:
        f.write(result_str)
    with open(os.path.join(out_path, 'scores_{}.json'.format(mode)), 'w', encoding='utf-8') as f:
        json.dump(scores_dict, f)


def run(args):
    """Score generated images against ground truth and save the results.

    Each (result, ground-truth) pair from ``GTResDataset`` is scored
    individually, using LPIPS (AlexNet backbone) or mean-squared error.
    The mean +- std over all pairs is printed and written to
    ``<parent of data_path>/inference_metrics/stat_<mode>.txt``. The
    per-image scores, keyed by result-file basename, go to
    ``scores_<mode>.json`` in the same directory.

    Args:
        args: Namespace as produced by ``parse_args`` (``mode``,
            ``data_path``, ``gt_path``, ``workers``, ``batch_size``,
            ``is_cars``).

    Raises:
        ValueError: if ``args.mode`` is neither 'lpips' nor 'l2'.
    """
    # transforms.Resize takes (height, width).
    resize_dims = (192, 256) if args.is_cars else (256, 256)
    transform = transforms.Compose([transforms.Resize(resize_dims),
                                    transforms.ToTensor(),
                                    # maps ToTensor's [0, 1] range to [-1, 1]
                                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    print('Loading dataset')
    dataset = GTResDataset(root_path=args.data_path,
                           gt_dir=args.gt_path,
                           transform=transform)

    # shuffle=False is required: scores are matched to file names by index.
    # drop_last=False so the trailing partial batch is scored too.
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=int(args.workers),
                            drop_last=False)

    # Same CUDA behaviour as before when a GPU exists; CPU fallback otherwise.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loss_func = _build_loss(args.mode, device)

    all_scores, scores_dict = _score_pairs(loss_func, dataloader, dataset.pairs, device)

    # Statistics cover every scored pair, even when basenames collide in the dict.
    mean = np.mean(all_scores)
    std = np.std(all_scores)
    result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std)
    print('Finished with ', args.data_path)
    print(result_str)

    _write_results(args.data_path, args.mode, result_str, scores_dict)


if __name__ == '__main__':
    # Script entry point: parse the CLI, then compute and save the metrics.
    cli_args = parse_args()
    run(cli_args)
