import argparse
import os
import json
import os
import torch
import torch_fidelity
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import logging
import pandas as pd
from helper import set_seeds
import torchmetrics.functional.multimodal.clip_score
from functools import partial
import numpy as np
logging.basicConfig(encoding="utf-8", level=logging.WARNING)
logger = logging.getLogger(__name__)
# Compute ISC, FID, KID, and precision/recall using torch-fidelity
def calculate_metrics(real_dataset, generated_dataset):
    """Run torch-fidelity on a generated/real dataset pair.

    Args:
        real_dataset: Reference dataset; forwarded to torch-fidelity as
            ``input2``. Must support ``len()``.
        generated_dataset: Dataset under evaluation; forwarded to
            torch-fidelity as ``input1``. Must support ``len()``.

    Returns:
        The value returned by ``torch_fidelity.calculate_metrics`` with ISC,
        FID, KID and precision/recall enabled.
    """
    # The KID subset size is capped at 1000 and at the smaller dataset's length.
    smallest_len = min(len(real_dataset), len(generated_dataset))

    options = dict(
        input1=generated_dataset,
        input2=real_dataset,
        cuda=torch.cuda.is_available(),  # use the GPU whenever one is visible
        isc=True,   # Inception Score
        fid=True,   # Fréchet Inception Distance
        kid=True,   # Kernel Inception Distance
        prc=True,   # precision / recall
        verbose=False,  # verbose output is switched off
        kid_subset_size=min(1000, smallest_len),
        samples_find_deep=True,
        cache=False,
        no_class=True,
    )
    return torch_fidelity.calculate_metrics(**options)


class ImageFolderDataset(Dataset):
    """Dataset over the image files found directly inside one directory.

    Only the top level of ``root_dir`` is listed (no recursion). Entries are
    selected by file extension (``.png``/``.jpg``/``.jpeg``, case-insensitive)
    and stored in sorted order so that an index always refers to the same file,
    independent of the order ``os.listdir`` happens to return.
    """

    # Accepted extensions, including the dot so that a name such as "rawpng"
    # is not mistaken for an image.
    _EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir, transform=None):
        """Collect the image paths in ``root_dir``.

        Args:
            root_dir: Directory to list.
            transform: Optional callable applied to every loaded RGB image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = [
            os.path.join(root_dir, name)
            for name in sorted(os.listdir(root_dir))
            if name.lower().endswith(self._EXTENSIONS)
        ]

    def __len__(self):
        """Return the number of image paths collected."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB and apply the transform.

        If the file cannot be opened, a warning is logged and the preceding
        index is tried instead (wrapping around to the end of the list), so a
        broken file yields a duplicate of a neighbouring image.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If none of the images can be opened.
        """
        total = len(self.image_paths)
        # Keep normal sequence semantics for out-of-range indices.
        _ = self.image_paths[idx]

        # Iterative fallback: each file is tried at most once, so a long run
        # of broken files cannot exhaust the recursion limit.
        for offset in range(total):
            img_path = self.image_paths[(idx - offset) % total]
            try:
                # The context manager closes the file handle right after the
                # pixel data has been converted.
                with Image.open(img_path) as handle:
                    image = handle.convert("RGB")
            except Exception as e:
                logger.warning(e)
                logger.warning(f"Could not open image at {img_path}")
                continue

            if self.transform:
                image = self.transform(image)
            return image

        raise RuntimeError(f"None of the {total} images in {self.root_dir} could be opened")

# Dataset of (image, prompt) pairs read from a JSON-lines metadata file
class ClipImageFolderDataset(Dataset):
    """Dataset of (image, prompt) pairs listed in a JSON-lines metadata file.

    Every non-blank line of the file is parsed as a JSON object that must
    contain ``"stat_data"``, ``"prompt"`` and either ``"image_path"`` or
    ``"path"``. Lines that cannot be parsed, or that lack one of these keys,
    are printed and skipped as a whole, so ``prompts`` and ``image_paths``
    always have the same length and matching indices.
    """

    def __init__(self, metadata_path, transform=None):
        """Read prompts and image paths from ``metadata_path``.

        Args:
            metadata_path: Path of the UTF-8 JSON-lines file to read.
            transform: Optional callable applied to every loaded RGB image.
        """
        self.transform = transform
        self.image_paths = []
        self.prompts = []
        # Number of blank lines encountered in the metadata file.
        self.skipped_lines = 0

        with open(metadata_path, mode="r", encoding="utf-8") as json_file:
            for i, line in enumerate(json_file):
                stripped = line.strip()
                if not stripped:
                    print(f"Skipping line {i} as strip did not work. Line: {line}, strip: {stripped}")
                    self.skipped_lines += 1
                    continue
                try:
                    record = json.loads(stripped)
                    # Kept from the original logic: a record without
                    # "stat_data" raises KeyError here and is skipped.
                    del record["stat_data"]
                    prompt = record["prompt"]
                    if "image_path" in record:
                        image_path = record["image_path"]
                    else:
                        image_path = record["path"]
                except (ValueError, KeyError, TypeError) as e:
                    # Report the raw line: the parsed record may not exist.
                    print(stripped)
                    print(e)
                    continue
                # Append both values only once both lookups succeeded, so the
                # two lists can never get out of step.
                self.prompts.append(prompt)
                self.image_paths.append(image_path)

    def __len__(self):
        """Return the number of (image, prompt) records collected."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB and return it with its prompt.

        If the file cannot be opened, a warning is logged and the preceding
        record (image and prompt together) is tried instead, wrapping around
        to the end of the list.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If none of the images can be opened.
        """
        total = len(self.image_paths)
        # Keep normal sequence semantics for out-of-range indices.
        _ = self.image_paths[idx]

        # Iterative fallback: each record is tried at most once, so a long run
        # of broken files cannot exhaust the recursion limit.
        for offset in range(total):
            position = (idx - offset) % total
            img_path = self.image_paths[position]
            prompt = self.prompts[position]
            try:
                # The context manager closes the file handle right after the
                # pixel data has been converted.
                with Image.open(img_path) as handle:
                    image = handle.convert("RGB")
            except Exception as e:
                logger.warning(e)
                logger.warning(f"Could not open image at {img_path}")
                continue

            if self.transform:
                image = self.transform(image)
            return image, prompt

        raise RuntimeError(f"None of the {total} listed images could be opened")

# Dataset of COCO val2014 images selected by image id

    
class COCODataset(Dataset):
    """Dataset over the COCO val2014 images that belong to a set of image ids.

    An id ``i`` is mapped to the file name ``COCO_val2014_<i, 12 digits>.jpg``.
    Only files that actually exist directly inside ``root_dir`` are kept; ids
    without a matching file are silently left out. Paths are stored in sorted
    order so that an index always refers to the same file.
    """

    def __init__(self, img_ids, root_dir="coco2014/val2014", transform=None):
        """Collect the paths of the requested images.

        Args:
            img_ids: Iterable of integer COCO image ids.
            root_dir: Directory holding the val2014 images.
            transform: Optional callable applied to every loaded RGB image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.img_ids = img_ids

        # A set gives constant-time membership tests while scanning the
        # directory; every wanted name already ends in ".jpg", so no separate
        # extension check is needed.
        wanted_names = {f"COCO_val2014_{image_id:012d}.jpg" for image_id in img_ids}
        self.image_paths = [
            os.path.join(root_dir, name)
            for name in sorted(os.listdir(root_dir))
            if name in wanted_names
        ]

    def __len__(self):
        """Return the number of image paths collected."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB and apply the transform.

        If the file cannot be opened, a warning is logged and the preceding
        index is tried instead (wrapping around to the end of the list).

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If none of the images can be opened.
        """
        total = len(self.image_paths)
        # Keep normal sequence semantics for out-of-range indices.
        _ = self.image_paths[idx]

        # Iterative fallback: each file is tried at most once, so a long run
        # of broken files cannot exhaust the recursion limit.
        for offset in range(total):
            img_path = self.image_paths[(idx - offset) % total]
            try:
                # The context manager closes the file handle right after the
                # pixel data has been converted.
                with Image.open(img_path) as handle:
                    image = handle.convert("RGB")
            except Exception as e:
                logger.warning(e)
                logger.warning(f"Could not open image at {img_path}")
                continue

            if self.transform:
                image = self.transform(image)
            return image

        raise RuntimeError(f"None of the {total} images in {self.root_dir} could be opened")
    
def infer_fid(coco_data_path, gen_data_path):
    """Compute the torch-fidelity metrics for one directory of generated images.

    The image ids are read from ``<gen_data_path>/metrics.json`` (one JSON
    object per line). The real dataset consists of the COCO images with those
    ids found in ``coco_data_path``; the generated dataset consists of the
    image files inside ``gen_data_path``. Both are resized to 299x299 and
    converted to uint8 tensors.

    Args:
        coco_data_path: Directory holding the real COCO images.
        gen_data_path: Directory holding the generated images and
            ``metrics.json``.

    Returns:
        The result of ``calculate_metrics`` for the two datasets.
    """
    print(f"Currently investigating {gen_data_path}!")

    metrics_name = "metrics.json"
    ids = []
    skipped_lines = 0
    with open(f"{gen_data_path}/{metrics_name}", mode="r", encoding="utf-8") as json_file:
        for i, line in enumerate(json_file):
            stripped = line.strip()
            if not stripped:
                print(f"Skipping line {i} as strip did not work. Line: {line}, strip: {stripped}")
                skipped_lines += 1
                continue
            try:
                record = json.loads(stripped)
                # Kept from the original logic: a record without "stat_data"
                # raises KeyError here and is skipped.
                del record["stat_data"]
                ids.append(record["image_id"])
            except (ValueError, KeyError, TypeError) as e:
                # Malformed JSON or a missing key: report and skip the line.
                print(e)
    if skipped_lines:
        logging.warning(f"Skipped {skipped_lines} lines")

    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        # ToTensor yields floats in [0, 1]; scale back to uint8 pixel values.
        transforms.Lambda(lambda x: (x * 255).to(torch.uint8))
    ])
    real_dataset = COCODataset(ids, coco_data_path, transform)
    generated_dataset = ImageFolderDataset(gen_data_path, transform)

    # Pass by keyword: calculate_metrics takes (real_dataset,
    # generated_dataset), and a positional call in the opposite order swaps
    # the roles of the two datasets.
    return calculate_metrics(
        real_dataset=real_dataset,
        generated_dataset=generated_dataset,
    )

def infer_clip(gen_data_path):
    """Compute the CLIP score of the images listed in ``metrics.json``.

    Images and prompts are read through ``ClipImageFolderDataset`` from
    ``<gen_data_path>/metrics.json``, resized to 224x224, converted to uint8
    tensors and scored in batches of 100. A batch that raises is reported and
    left out of the statistics.

    Args:
        gen_data_path: Directory holding ``metrics.json``.

    Returns:
        A ``(mean, std)`` tuple. ``mean`` is the average of the per-batch
        scores weighted by batch size; ``std`` is the standard deviation of
        the per-batch scores rounded to three decimals. If no batch could be
        scored, ``(0, -1)`` is returned.
    """
    score_fn = partial(torchmetrics.functional.multimodal.clip_score, model_name_or_path="openai/clip-vit-base-patch16")
    metrics_name = "metrics.json"
    metadata_path = f"{gen_data_path}/{metrics_name}"

    to_uint8 = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: (x * 255).to(torch.uint8))
    ])
    loader = DataLoader(
        ClipImageFolderDataset(metadata_path, transform = to_uint8),
        batch_size=100,
    )

    batch_scores = []
    weighted_total = 0
    sample_count = 0
    for batch_idx, (imgs, prompts) in enumerate(loader):
        try:
            score = calculate_clip_score(score_fn, imgs, list(prompts))
            batch_scores.append(score)
            # Weight by batch size so a smaller final batch counts less.
            weighted_total += score * len(prompts)
            sample_count += len(prompts)
        except Exception as e:
            print(f"Error processing batch {batch_idx} with error {e}")
    print(f"CLIP score: {batch_scores}")

    # Nothing was scored: return the untouched sum and a sentinel deviation.
    if not sample_count:
        return weighted_total, -1

    spread = round(np.std(batch_scores).item(), 3)
    return weighted_total / sample_count, spread


def calculate_clip_score(clip_score_fn, images, prompts):
    """Score one batch with ``clip_score_fn`` and return a rounded float.

    Args:
        clip_score_fn: Callable taking ``(images, prompts)`` and returning an
            object with ``detach()`` that converts to ``float``.
        images: Batch of images, passed through unchanged.
        prompts: Prompts for the batch, passed through unchanged.

    Returns:
        The score as a Python float rounded to four decimals.
    """
    raw_score = clip_score_fn(images, prompts)
    detached = raw_score.detach()
    return round(float(detached), 4)

if __name__ == "__main__":
    # Command-line entry point: compute the CLIP score and the torch-fidelity
    # metrics for one directory of generated images and write them to
    # <out_dir>/img_quality.csv.
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    # The three paths are mandatory: without them the script would only fail
    # later on paths such as "None/metrics.json".
    parser.add_argument("--coco_data_path", type=str, required=True)
    parser.add_argument("--gen_data_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, required=True)

    args = parser.parse_args()
    set_seeds(args.seed)

    # Create the output directory up front so that a bad location is detected
    # before the metric computation rather than after it.
    os.makedirs(args.out_dir, exist_ok=True)

    coco_data_path = args.coco_data_path
    gen_data_path = args.gen_data_path
    clip_score_mean, clip_score_std = infer_clip(gen_data_path)

    metrics_dict = infer_fid(coco_data_path, gen_data_path)
    metrics_dict['clip_score_std'] = clip_score_std
    metrics_dict['clip_score_mean'] = clip_score_mean
    print(metrics_dict)
    # A single-row frame: one column per metric.
    df = pd.DataFrame(metrics_dict, index=[0])
    df.to_csv(f"{args.out_dir}/img_quality.csv", mode='w', index=False)
    
    


