from __future__ import print_function
import argparse
import torch
from PIL import Image
from torchvision.transforms import ToTensor
import copy
import numpy as np
import os

# Command-line settings for blending two trained super-resolution models and
# rendering the input image with each blended model.
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
parser.add_argument('--input_image', type=str, default="cat.jpg", help='input image to use')
parser.add_argument('--model1', type=str, default='model_epoch_30_seed_1.pth',
                    help='model file for the first interpolation endpoint (lambda = 0)')
parser.add_argument('--model2', type=str, default='model_epoch_30_seed_1000.pth',
                    help='model file for the second interpolation endpoint (lambda = 1)')
parser.add_argument('--folder', default="results", type=str,
                    help='directory where the interpolated output images are saved')
parser.add_argument('--cuda', action='store_true', help='use cuda')
opt = parser.parse_args()

print(opt)

# Fail fast with a clear CLI error instead of a deep traceback from .cuda().
if opt.cuda and not torch.cuda.is_available():
    parser.error('--cuda was given but no CUDA device is available')
# Load the input image and keep its luminance (Y) channel as the network input;
# the chroma channels (cb, cr) are kept for recombination after inference.
img = Image.open(opt.input_image).convert('YCbCr')
y, cb, cr = img.split()
img_to_tensor = ToTensor()
# Shape the Y channel as a batch of one single-channel image: (1, 1, H, W).
# PIL's size is (width, height), hence the index order.
input = img_to_tensor(y).view(1, -1, y.size[1], y.size[0])

# NOTE: torch.load unpickles arbitrary Python objects -- only load trusted
# checkpoints. These files are used as whole model objects (.cuda() and
# .state_dict() are called on them); on torch >= 2.6, where weights_only
# defaults to True, loading may need weights_only=False -- TODO confirm.
# map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines and
# keeps the models on the same device as `input` when --cuda is not given;
# they are moved to the GPU below when --cuda is set.
model1 = torch.load(opt.model1, map_location='cpu')
model2 = torch.load(opt.model2, map_location='cpu')
# Template model that receives the blended weights; deep-copying model2 avoids
# reading the same checkpoint from disk a second time.
model = copy.deepcopy(model2)
# Only `model` is run, and only for inference: disable training-mode behavior.
model.eval()
if opt.cuda:
    model1 = model1.cuda()
    model2 = model2.cuda()
    model = model.cuda()
    input = input.cuda()
sd1 = model1.state_dict()
sd2 = model2.state_dict()
sd = model.state_dict()

out_imgs = []  # NOTE(review): never read in the visible script

# Create the output directory once, before the loop, instead of per iteration.
if not os.path.exists(opt.folder):
    os.makedirs(opt.folder)

# Sweep the blend weight over 11 evenly spaced values in [0, 1]; linspace hits
# both endpoints exactly, unlike arange with a floating-point step.
for i, lmd in enumerate(torch.linspace(0, 1, 11)):
    # Plain Python float so the weight mixes cleanly with tensors on any device.
    w = float(lmd)
    # Linear interpolation of every state-dict entry:
    # w = 0 reproduces model1, w = 1 reproduces model2.
    for key in sd:
        sd[key] = (1 - w) * sd1[key] + w * sd2[key]
    model.load_state_dict(sd)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = model(input)
    out = out.cpu()
    out_img_y = out[0].detach().numpy()
    # Presumably the network predicts Y in [0, 1]; rescale and clip to 8-bit.
    out_img_y *= 255.0
    out_img_y = out_img_y.clip(0, 255)
    out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')

    # Resize the original chroma channels bicubically to the output size and
    # recombine them with the predicted luminance.
    out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
    out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
    out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')
    path = os.path.join(opt.folder, str(i) + '.png')
    out_img.save(path)
    print('output image saved to ', path)
