from matrics_calculator import MetricsCalculator
import glob
import os
from tqdm import tqdm
import tqdm
import torch
from PIL import Image
import numpy as np
import csv
import argparse
class ImagePathDataset(torch.utils.data.Dataset):
    """Dataset that loads image files as fixed-size RGB arrays.

    Each item is a ``(path, image)`` pair, where ``image`` is the NumPy array
    obtained by opening the file with PIL, converting it to RGB and resizing
    it to ``size``.

    Args:
        files: Sequence of image file paths (must support ``len`` and
            integer indexing).
        size: Target ``(width, height)`` handed to ``PIL.Image.resize``.
            Defaults to ``(1536, 512)``.
    """

    def __init__(self, files, size=(1536, 512)):
        self.files = files
        self.size = size

    def __len__(self):
        """Return the number of file paths in the dataset."""
        return len(self.files)

    def _load(self, path):
        """Read *path* and return it as an RGB array resized to ``self.size``."""
        # The context manager releases the file handle even if decoding fails.
        with Image.open(path) as image:
            return np.array(image.convert('RGB').resize(self.size))

    def __getitem__(self, i):
        """Return ``(path, image)`` for index *i*.

        If the file cannot be loaded, the second file of the dataset
        (``self.files[1]``) is loaded in its place and a warning is emitted;
        the returned path is still the one that was requested.

        Raises:
            Exception: The original loading error, when no distinct
                substitute file is available, or the substitute's own
                error when it cannot be loaded either.
        """
        path = self.files[i]
        try:
            img = self._load(path)
        except Exception as err:
            # No usable substitute: surface the real error instead of an
            # unrelated IndexError or a second identical failure.
            if len(self.files) < 2 or self.files[1] == path:
                raise
            import warnings

            # NOTE(review): the substitute's pixels are reported under the
            # failing path, so downstream per-file results for this path do
            # not describe the requested file.
            warnings.warn(
                f"could not load {path!r} ({err}); "
                f"substituting {self.files[1]!r}"
            )
            img = self._load(self.files[1])

        return path, img
    
def argparser():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: Namespace with the attributes ``device``
        (defaults to ``"cuda:0"``), ``result_dir``, ``avg_results_path``
        and ``save_file`` (the last three are required).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default="cuda:0",
                        help='Device handed to MetricsCalculator (default: cuda:0).')
    parser.add_argument('--result_dir', type=str, required=True,
                        help='Directory containing the images to evaluate.')
    parser.add_argument('--avg_results_path', type=str, required=True,
                        help='Path of the text file receiving the averaged metrics.')
    parser.add_argument('--save_file', type=str, required=True,
                        help='Path of the CSV file receiving the per-image metrics.')
    return parser.parse_args()

def main(args):
    """Compute per-image metrics for every file in ``args.result_dir``.

    Every file matched by ``{result_dir}/*`` is loaded through
    ``ImagePathDataset`` and cut into thirds along axis 1. The first third
    (``img_struct``) and the last third (``img_result``) are passed to
    ``MetricsCalculator`` to obtain the structure distance, PSNR, LPIPS, MSE
    and SSIM; the middle third is not used.

    Outputs:
        * ``args.avg_results_path``: text file with the image count followed
          by the average of each metric, one entry per line.
        * ``args.save_file``: CSV file with one row per image.

    Args:
        args: Namespace providing ``device``, ``result_dir``,
            ``avg_results_path`` and ``save_file``.

    Raises:
        ValueError: If ``result_dir`` contains no files, since the averages
            would otherwise be a division by zero.
    """
    device = args.device

    metrics_calculator = MetricsCalculator(device)

    result_dir = args.result_dir
    save_file = args.save_file
    image_paths = glob.glob(f"{result_dir}/*")
    print(result_dir)
    if not image_paths:
        raise ValueError(f"no files found in result_dir: {result_dir!r}")

    dataset = ImagePathDataset(image_paths)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=8,
                                             shuffle=False,
                                             drop_last=False,
                                             num_workers=8)

    list_dicts = []
    # Running sums, keyed by the CSV column they are accumulated from.
    totals = {
        "struct_distance": 0.0,
        "psnr": 0.0,
        "lpips": 0.0,
        "mse": 0.0,
        "ssim": 0.0,
    }

    for batch in dataloader:
        for result_path, img_np in zip(batch[0], batch[1]):
            split_size = img_np.shape[1] // 3
            img_struct = img_np[:, :split_size, :]
            img_result = img_np[:, -split_size:, :]
            data_dict = {
                "image_path": result_path,
                "struct_distance": metrics_calculator.calculate_structure_distance(
                    img_result, img_struct).item(),
                "psnr": metrics_calculator.calculate_psnr(img_result, img_struct),
                "lpips": metrics_calculator.calculate_lpips(img_result, img_struct),
                "mse": metrics_calculator.calculate_mse(img_result, img_struct),
                "ssim": metrics_calculator.calculate_ssim(img_result, img_struct),
            }
            for key in totals:
                totals[key] += data_dict[key]
            list_dicts.append(data_dict)
            # Progress output: images processed so far and the running
            # structure-distance sum.
            print(len(list_dicts))
            print(totals["struct_distance"])

    count = len(list_dicts)
    avg_distance = float(totals["struct_distance"]) / count
    avg_psnr = float(totals["psnr"]) / count
    avg_lpips = float(totals["lpips"]) / count
    avg_mse = float(totals["mse"]) / count
    avg_ssim = float(totals["ssim"]) / count
    summary = [
        f"{avg_distance=}",
        f"{avg_psnr=}",
        f"{avg_lpips=}",
        f"{avg_mse=}",
        f"{avg_ssim=}",
    ]
    for line in summary:
        print(line)

    with open(args.avg_results_path, 'w') as f:
        # The count gets its own line; without the newline it would run
        # straight into the first average.
        f.write(f"image number: {count}\n")
        f.writelines(f"{line}\n" for line in summary)

    with open(save_file, mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=list_dicts[0].keys())
        writer.writeheader()
        writer.writerows(list_dicts)
            

# Script entry point: parse the command line, then run the evaluation.
if __name__ == "__main__":
    main(argparser())