import csv
from typing import List, Dict
from cp_files import cal_metrics
def generate_path(path: str, algorithm: str) -> str:
    """Build the per-algorithm location by joining *path* and *algorithm*.

    The two parts are formatted and separated by a single ``'/'``; no
    normalization is performed, so a trailing slash on *path* yields ``'//'``.

    Args:
        path: Base location that holds one entry per algorithm.
        algorithm: Algorithm name appended as the final path component.

    Returns:
        The string ``"<path>/<algorithm>"``.
    """
    return "{}/{}".format(path, algorithm)

def calculate_and_format_metrics(path: str, algorithm: str) -> Dict[str, float]:
    """Compute metrics for one algorithm and round them for display.

    Calls ``cal_metrics`` on ``generate_path(path, algorithm)`` and keeps only
    the PSNR, SSIM, LPIPS and FID entries, in that order.

    Args:
        path: Base location passed to ``generate_path``.
        algorithm: Algorithm name passed to ``generate_path``.

    Returns:
        A dict with keys ``'PSNR'``, ``'SSIM'``, ``'LPIPS'``, ``'FID'``. PSNR and
        FID are rounded to 2 decimals; SSIM and LPIPS are rounded to 4.

    Raises:
        KeyError: if the result of ``cal_metrics`` lacks one of those keys.
    """
    # Decimal places per reported metric; insertion order is the output order.
    decimals = {'PSNR': 2, 'SSIM': 4, 'LPIPS': 4, 'FID': 2}
    raw_scores = cal_metrics(generate_path(path, algorithm))
    return {name: round(raw_scores[name], places) for name, places in decimals.items()}

def compare_and_bold(metrics_a: Dict[str, float], metrics_b: Dict[str, float]) -> Dict[str, str]:
    """Render *metrics_a* as strings, bolding entries that strictly beat *metrics_b*.

    Keys starting with ``PSNR`` or ``SSIM`` win when larger. Keys starting with
    ``LPIPS`` or ``FID`` win when smaller. Ties are not bolded. Keys matching
    none of these prefixes are omitted from the result and never looked up in
    *metrics_b*.

    Args:
        metrics_a: Metrics to render.
        metrics_b: Reference metrics; must contain every recognised key of
            *metrics_a*.

    Returns:
        A dict mapping each recognised key to ``"<b>value</b>"`` or ``"value"``.
    """
    higher_wins = ('PSNR', 'SSIM')
    lower_wins = ('LPIPS', 'FID')

    rendered = {}
    for name, value in metrics_a.items():
        if name.startswith(higher_wins):
            is_better = value > metrics_b[name]
        elif name.startswith(lower_wins):
            is_better = value < metrics_b[name]
        else:
            continue  # unrecognised metric: not emitted
        rendered[name] = f"<b>{value}</b>" if is_better else f"{value}"
    return rendered

def write_metrics_to_html(path: str, algorithms: List[str], output_file: str):
    """Evaluate each algorithm under *path* and write the results as an HTML table.

    One row per algorithm (in the given order) with PSNR, SSIM, LPIPS and FID
    columns, as returned by ``calculate_and_format_metrics``.

    Args:
        path: Base location passed through to ``calculate_and_format_metrics``.
        algorithms: Algorithm names; each becomes one table row.
        output_file: Destination HTML file. Its parent directory is created if
            missing, and an existing file is overwritten.
    """
    import html
    import os

    # Compute every row before opening the output, so a failure while computing
    # metrics cannot leave a truncated, half-written HTML file behind.
    rows = []
    for algorithm in algorithms:
        print(algorithm)  # progress output
        metrics = calculate_and_format_metrics(path, algorithm)
        rows.append(
            f"<tr><td>{html.escape(str(algorithm))}</td><td>{metrics['PSNR']}</td>"
            f"<td>{metrics['SSIM']}</td><td>{metrics['LPIPS']}</td>"
            f"<td>{metrics['FID']}</td></tr>\n"
        )

    # Ensure e.g. 'results/' exists; an empty dirname means the current directory.
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_file, mode='w', encoding='utf-8') as file:
        file.write("<html><body>\n")
        file.write("<table border='1'>\n")
        file.write("<tr><th>algorithm</th><th>PSNR</th><th>SSIM</th><th>LPIPS</th><th>FID</th></tr>\n")
        file.writelines(rows)
        file.write("</table>\n")
        file.write("</body></html>\n")

# Run configuration: base location of the per-algorithm outputs, the
# algorithms to evaluate, and where the HTML summary table is written.
path = 'exp/image_samples/ffhq/inp_noisy'
algorithms = ['latent_dps', 'ldir', 'psld', 'resample', 'stsl']
output_file = 'results/ffhq_inp_noisy.html'

# Only run the (side-effecting) evaluation when executed as a script, so the
# helpers above can be imported without triggering it.
if __name__ == "__main__":
    write_metrics_to_html(path, algorithms, output_file)