# Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from inception import inception_score
import os
import shutil
import torch
import torchvision.utils
from torch.utils.data import DataLoader
from torchvision import transforms
from config import *



def main(
    num_images: int = 10000,
    latent_dim: int = 100,
    batch_size: int = 100,
    splits: int = 10,
) -> None:
    """Generate images with the trained generator and report their Inception Score.

    Deletes and recreates ``exp_dir``, loads the weights stored at
    ``model_path`` into ``model`` (these names, and ``device``, come from
    ``config``), samples ``num_images`` images one at a time from Gaussian
    noise, saves each one as a PNG under ``exp_dir``, then appends the
    Inception Score mean/std to ``IS.txt`` in the current working directory.

    Args:
        num_images: Number of images to generate and score.
        latent_dim: Channel count of the ``[1, latent_dim, 1, 1]`` noise
            tensor fed to the generator.
        batch_size: Batch size passed (positionally) to ``inception_score``.
        splits: Number of splits passed to ``inception_score``.

    Raises:
        ValueError: If ``num_images`` is less than 1 (there would be nothing
            to concatenate or score).
    """
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    # Create an experiment result folder; any previous contents are deleted.
    if os.path.exists(exp_dir):
        shutil.rmtree(exp_dir)
    os.makedirs(exp_dir)

    # Load model weights and start the verification (inference) mode.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # Half-precision inference is only enabled on CUDA: many PyTorch builds
    # lack (or run very slowly) half-precision conv kernels on CPU.
    use_half = torch.device(device).type == "cuda"
    if use_half:
        model.half()

    # Zero-pad file names to the width of the largest index so they sort
    # correctly; a fixed 3-digit pad breaks ordering once index reaches 1000.
    name_width = max(3, len(str(num_images - 1)))

    generated_images = []
    with torch.no_grad():
        for index in range(num_images):
            # Sample a fresh Gaussian latent vector for every image.
            noise = torch.randn([1, latent_dim, 1, 1], device=device)
            if use_half:
                noise = noise.half()
            image = model(noise)
            generated_images.append(image)

            # Save each generated image.
            torchvision.utils.save_image(
                image, os.path.join(exp_dir, f"{index:0{name_width}d}.png")
            )
            print(f"The {index + 1:03d} image is being created using the model...")

    # Combine all generated images into a single tensor along dim 0.
    generated_images = torch.cat(generated_images)

    # Dimension 1 is treated as the channel axis: replicate single-channel
    # output to 3 channels (presumably the Inception network expects RGB).
    if generated_images.shape[1] == 1:
        generated_images = generated_images.repeat(1, 3, 1, 1)

    scores = inception_score(
        generated_images, device, batch_size, resize=True, splits=splits
    )
    print(f"i {scores}")
    with open("IS.txt", "a", encoding="utf-8") as f:
        f.write(
            f"Inception Mean: {float(scores[0])}, "
            f"Inception Std: {float(scores[1])}\n"
        )

# Run generation + Inception Score evaluation only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()