import os
import re
import torch
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from PIL import Image
from torchmetrics import StructuralSimilarityIndexMeasure
import json
import pandas as pd
import argparse


def image_grid_to_tensor(grid_image_path, img_size):
    """Load a 10x10 grid image and cut it into 100 separate tiles.

    The image at ``grid_image_path`` is opened as RGB and converted with
    torchvision's ``ToTensor``. Tiles are sliced at exact multiples of
    ``img_size`` along both spatial axes, so no padding or border between
    tiles is accounted for.

    Args:
        grid_image_path: Path of the grid image file.
        img_size: Side length, in pixels, of each square tile.

    Returns:
        The 100 tiles stacked along a new leading dimension, ordered row by
        row (left to right, then top to bottom).
    """
    to_tensor = transforms.ToTensor()
    grid = to_tensor(Image.open(grid_image_path).convert('RGB'))

    # Row-major walk over the fixed 10x10 layout; the first axis of ``grid``
    # is kept whole while the two spatial axes are windowed.
    tiles = [
        grid[
            :,
            row * img_size:(row + 1) * img_size,
            col * img_size:(col + 1) * img_size,
        ]
        for row in range(10)
        for col in range(10)
    ]
    return torch.stack(tiles)



def evaluate_backdoor_one_grid(grid_image_path, target_image_path, img_size):
    """Score one grid of backdoor samples against the attack target image.

    Args:
        grid_image_path: Path of the grid image holding the generated samples.
        target_image_path: Path of the backdoor target image.
        img_size: Side length, in pixels, of each sample in the grid.

    Returns:
        A tuple ``(ASR, MSE, SSIM)``:
        * ASR  - fraction of samples whose summed absolute difference to the
          target is below the fixed threshold of 500.
        * MSE  - mean squared difference over all samples and all elements.
        * SSIM - structural similarity between the samples and the target
          repeated once per sample.
    """
    # The target is loaded first so a missing target file is reported before
    # any work is done on the grid.
    target = transforms.ToTensor()(Image.open(target_image_path).convert('RGB'))
    samples = image_grid_to_tensor(grid_image_path, img_size)

    # The target broadcasts against every sample in the stack.
    difference = samples - target

    # Attack success rate: share of samples closer than the L1 threshold.
    asr_threshold = 500
    per_sample_l1 = difference.abs().sum(dim=[1, 2, 3])
    num_hits = (per_sample_l1 < asr_threshold).sum().item()
    attack_success_rate = num_hits / len(per_sample_l1)

    mean_squared_error = torch.mean(difference ** 2).item()

    ssim_metric = StructuralSimilarityIndexMeasure()
    structural_similarity = ssim_metric(samples, target.expand_as(samples)).item()

    return attack_success_rate, mean_squared_error, structural_similarity

def evaluate_one_model(folder_path, target_image_path, img_size):
    """Evaluate every saved backdoor sample grid of one trained model.

    Scans ``<folder_path>/backdoor_samples`` for files whose name starts with
    a four-digit epoch number and ends with ``_noclip.png`` and scores each
    one with ``evaluate_backdoor_one_grid``.

    Args:
        folder_path: Result folder of one model.
        target_image_path: Path of the backdoor target image.
        img_size: Side length, in pixels, of each sample in a grid.

    Returns:
        A dict mapping the epoch (as a string, zero-padded to at least two
        digits, e.g. ``"07"``, ``"49"``, ``"100"``) to
        ``{"ASR": ..., "MSE": ..., "SSIM": ...}``.

    Raises:
        OSError: If the ``backdoor_samples`` folder cannot be listed.
    """
    backdoor_samples_path = os.path.join(folder_path, "backdoor_samples")

    evaluation_results_one_model = {}
    # Sorted so the resulting dict (and any JSON dump of it) has a stable order.
    for filename in sorted(os.listdir(backdoor_samples_path)):
        epoch_prefix = filename[:4]
        # isdecimal() guarantees int() below can parse the prefix.
        if not (filename.endswith("_noclip.png") and epoch_prefix.isdecimal()):
            continue

        # Use the whole four-digit prefix. Slicing only its last two digits
        # made epoch 0100 collide with (and overwrite) epoch 0000. The key is
        # unchanged for epochs below 100 and matches the f"{epoch:02d}" lookups
        # used when the results are read back.
        epoch_number = f"{int(epoch_prefix):02d}"

        grid_image_path = os.path.join(backdoor_samples_path, filename)
        ASR, MSE, SSIM = evaluate_backdoor_one_grid(grid_image_path, target_image_path, img_size)

        evaluation_results_one_model[epoch_number] = {
            "ASR": ASR,
            "MSE": MSE,
            "SSIM": SSIM
        }

    # TODO: evaluation of the clean grids in "<folder_path>/samples" (e.g. an
    # FID score) is not implemented.

    return evaluation_results_one_model

def evaluate_all(model_type, img_size, base_dir = "./"):
    """Evaluate every model result folder found directly under ``base_dir``.

    A sub-folder is treated as a model result folder when its name matches
    the naming pattern of the given ``model_type``. The number of epochs, the
    poison rate, the trigger and the target are parsed from the folder name.

    Args:
        model_type: ``"DDPM"`` or ``"NCSN"``; selects the folder-name pattern.
        img_size: Side length, in pixels, of each sample in a grid.
        base_dir: Directory that contains the model result folders.

    Returns:
        A list with one dict per evaluated folder:
        ``{"model_infor": {"num_epoch", "poison_rate", "trigger", "target"},
        "performance": <result of evaluate_one_model>}``. The parsed values
        in ``model_infor`` are kept as strings.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    # NOTE: the dots in these patterns are unescaped and therefore match any
    # character; they are left as they are so exactly the same folder names
    # keep matching.
    if model_type == "DDPM":
        pattern = re.compile(r"res_DDPM-CIFAR10-32_CIFAR10_ep(\d+)_sde_c1.0_p(\d+\.\d+)_epr0.0_(\w+)-(\w+)_psi1_lr0.0002_vp1.0_ve1.0")
    elif model_type == "NCSN":
        pattern = re.compile(r"res_NCSN_CIFAR10_my_CIFAR10_ep(\d+)_sde_c1.0_p(\d+\.\d+)_epr0.0_(\w+)-(\w+)_psi0.0_lr2e-05_vp1.0_ve1.0_flex_new-set")
    else:
        # Fail early with a clear message instead of an unbound-variable error
        # on the first folder.
        raise ValueError(f"Unsupported model_type {model_type!r}; expected 'DDPM' or 'NCSN'")

    target_image_paths = {
        "HAT": "static/HAT_target.png",
        "CAT": "static/CAT_target.png",
    }

    evaluation_results_all_models = []
    # Sorted so the result list (and its JSON dump) has a stable order.
    for folder_name in sorted(os.listdir(base_dir)):
        folder_path = os.path.join(base_dir, folder_name)
        if not os.path.isdir(folder_path):
            continue

        match = pattern.match(folder_name)
        if match is None:
            continue
        num_epoch, poison_rate, trigger, target = match.groups()

        # Without this check an unknown target would silently reuse the target
        # image of the previously processed folder.
        target_image_path = target_image_paths.get(target)
        if target_image_path is None:
            print(f"Skipping {folder_name}: no target image known for target {target!r}")
            continue

        evaluation_results_one_model = evaluate_one_model(folder_path, target_image_path, img_size)
        evaluation_results_all_models.append({
            "model_infor": {
                "num_epoch": num_epoch,
                "poison_rate": poison_rate,
                "trigger": trigger,
                "target": target,
            },
            "performance": evaluation_results_one_model,
        })
    return evaluation_results_all_models

def evaluate_by_poison_rate(evaluation_results, target_eval, poison_rate_range):
    """Summarise the best epoch of every model within a poison-rate range.

    For each model whose target equals ``target_eval`` and whose poison rate
    lies in ``poison_rate_range`` (both bounds inclusive), the epoch with the
    highest ASR is selected. Only epoch 0 and odd epochs are examined, and
    epochs without a stored result are skipped. On ties the earliest epoch
    wins.

    Args:
        evaluation_results: List of per-model dicts with the keys
            ``"model_infor"`` and ``"performance"``.
        target_eval: Target name a model must have to be included.
        poison_rate_range: Two-element sequence ``[min_rate, max_rate]``.

    Returns:
        A DataFrame with the columns ``trigger``, ``target``, ``poison_rate``,
        ``best_epoch``, ``ASR``, ``MSE`` and ``SSIM``, sorted by ascending
        poison rate. When no examined epoch has an ASR above 0, the
        ``best_epoch``, ``ASR``, ``MSE`` and ``SSIM`` entries are None.
    """
    rate_min, rate_max = poison_rate_range[0], poison_rate_range[1]

    evaluation_results_list = []
    for evaluation_result_one_model in evaluation_results:
        model_infor = evaluation_result_one_model["model_infor"]
        poison_rate = float(model_infor["poison_rate"])
        target = model_infor["target"]
        trigger = model_infor["trigger"]
        num_epoch = int(model_infor["num_epoch"])

        if not (rate_min <= poison_rate <= rate_max) or target != target_eval:
            continue

        print(poison_rate)
        print(target)
        print(trigger)

        performance = evaluation_result_one_model["performance"]
        epoch_best = None
        ASR_best = 0
        # Epochs are numbered 0 .. num_epoch - 1. The previous upper bound of
        # num_epoch - 1 left the final epoch out of the search.
        for current_epoch in range(num_epoch):
            if not (current_epoch == 0 or current_epoch % 2 == 1):
                continue
            metrics = performance.get(f"{current_epoch:02d}")
            if metrics is None:
                # No sample grid was evaluated for this epoch.
                continue
            if metrics["ASR"] > ASR_best:
                ASR_best = metrics["ASR"]
                epoch_best = current_epoch

        if epoch_best is None:
            ASR = None
            MSE = None
            SSIM = None
        else:
            best_metrics = performance[f"{epoch_best:02d}"]
            ASR = ASR_best
            MSE = best_metrics["MSE"]
            SSIM = best_metrics["SSIM"]
        evaluation_results_list.append([trigger, target, poison_rate, epoch_best, ASR, MSE, SSIM])

    column_names = ["trigger", "target", "poison_rate", "best_epoch", "ASR", "MSE", "SSIM"]
    df_evaluation_results = pd.DataFrame(evaluation_results_list, columns=column_names)
    df_evaluation_results = df_evaluation_results.sort_values(by="poison_rate", ascending=True)
    df_evaluation_results = df_evaluation_results.reset_index(drop=True)
    return df_evaluation_results

def evaluation_by_epoch(evaluation_results, trigger_eval, target_eval, poison_rate_eval):
    """Return the per-epoch metrics of one specific model as a table.

    The first entry of ``evaluation_results`` whose poison rate, trigger and
    target all equal the requested values is used.

    Args:
        evaluation_results: List of per-model dicts with the keys
            ``"model_infor"`` and ``"performance"``.
        trigger_eval: Trigger name of the wanted model.
        target_eval: Target name of the wanted model.
        poison_rate_eval: Poison rate of the wanted model, compared exactly
            against the stored rate converted to float.

    Returns:
        A DataFrame with an integer ``epoch`` column followed by the metric
        columns, sorted by ascending epoch. If no model matches, a message is
        printed and None is returned.
    """
    wanted = (poison_rate_eval, trigger_eval, target_eval)

    selected = None
    for candidate in evaluation_results:
        details = candidate["model_infor"]
        found = (float(details["poison_rate"]), details["trigger"], details["target"])
        if found == wanted:
            selected = candidate
            break

    if selected is None:
        print("This model doesn't exist")
        return

    # One row per epoch; the dict keys (epoch strings) become the "epoch" column.
    table = (
        pd.DataFrame.from_dict(selected["performance"], orient="index")
        .rename_axis("epoch")
        .reset_index()
    )
    # Epochs are stored as strings; cast them so sorting is numeric.
    table["epoch"] = table["epoch"].astype(int)
    return table.sort_values(by="epoch", ascending=True).reset_index(drop=True)

if __name__ == '__main__':
    def _str_to_bool(value):
        """Convert a command-line string such as 'True' or 'false' to a bool."""
        normalized = value.strip().lower()
        if normalized in ("true", "1", "yes", "y"):
            return True
        if normalized in ("false", "0", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_mode', choices=['across_poison_rates', 'specific_poison_rate'], default='across_poison_rates')
    parser.add_argument('--poison_rate_min', type=float, default=0.002)         # only used when eval_mode=='across_poison_rates'
    parser.add_argument('--poison_rate_max', type=float, default=1)             # only used when eval_mode=='across_poison_rates'
    parser.add_argument('--poison_rate_specific', type=float, default=0.005)    # only used when eval_mode=='specific_poison_rate'
    parser.add_argument('--target', choices=['HAT', 'CAT'], default='HAT')
    parser.add_argument('--dataset_name', choices=['CIFAR_10', 'CELEBA_HQ'], default='CIFAR_10')
    parser.add_argument('--model_type', choices=['DDPM', 'NCSN'], default='DDPM')
    parser.add_argument('--trigger', choices=['TooBad', 'STOP_SIGN_14', 'STOP_SIGN_18', 'BOX_18'], default='STOP_SIGN_14')
    # argparse checks `choices` after applying `type`. Without a converter the
    # raw strings "True"/"False" never matched the boolean choices, so the
    # option could not be passed at all and the cached results were never used.
    parser.add_argument('--recompute', type=_str_to_bool, choices=[True, False], default=True)
    args = parser.parse_args()

    model_type_eval = args.model_type
    target_eval = args.target
    trigger_eval = args.trigger
    eval_mode = args.eval_mode
    poison_rate_min = args.poison_rate_min
    poison_rate_max = args.poison_rate_max
    poison_rate_specific = args.poison_rate_specific
    dataset_name = args.dataset_name
    recompute = args.recompute

    # The "TooBad" trigger is stored under a composite name that also encodes
    # the model type, the dataset and the target.
    if trigger_eval=="TooBad":
        trigger_eval = trigger_eval+"_"+model_type_eval+"_"+dataset_name+"_"+target_eval

    # Results are cached as JSON per model type and dataset; they are
    # recomputed when the cache is missing or when explicitly requested.
    file_path = f"evaluation_results_{model_type_eval}_{dataset_name}.json"
    if recompute or not os.path.exists(file_path):
        # NOTE(review): img_size is fixed to 32 regardless of --dataset_name;
        # confirm this is intended for CELEBA_HQ.
        evaluation_results = evaluate_all(model_type=model_type_eval, img_size=32, base_dir="./")
        with open(file_path, "w") as file:
            json.dump(evaluation_results, file, indent=4)
    else:
        with open(file_path, "r") as file:
            evaluation_results = json.load(file)

    if eval_mode=="across_poison_rates":
        df_results = evaluate_by_poison_rate(evaluation_results, target_eval, poison_rate_range=[poison_rate_min, poison_rate_max])
        print(df_results)
    elif eval_mode=="specific_poison_rate":
        df_results = evaluation_by_epoch(evaluation_results, trigger_eval, target_eval, poison_rate_eval=poison_rate_specific)
        print(df_results)
