import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import sys
import argparse
from data import create_dataset
from data.universal_dataset import AlignedDataset_all
from src.model_udbm import (ResidualDiffusion,Trainer, Unet, UnetRes,set_seed)
def _str_to_bool(value):
    """Convert a command-line token into a bool for use as an argparse ``type=``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This converter interprets the
    usual textual spellings instead.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        The boolean the token spells. The empty string maps to ``False``,
        which matches what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parsr_args():
    """Parse the command-line options for this script.

    Reads ``sys.argv`` through ``argparse`` and returns the populated
    namespace. All options are optional and fall back to the defaults below.

    Returns:
        argparse.Namespace with the attributes ``dataroot``, ``phase``,
        ``max_dataset_size``, ``load_size``, ``crop_size``, ``direction``,
        ``preprocess``, ``no_flip``, ``bsize`` and ``ckpt_path_s1``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataroot", type=str, default='/data1/all_in_one')
    parser.add_argument("--phase", type=str, default='test')
    # argparse only applies ``type`` to string values, so the non-string
    # default stays float("inf"); a value given on the command line must be
    # an int.
    parser.add_argument("--max_dataset_size", type=int, default=float("inf"))
    # NOTE(review): the original carried a bare "#568" after this line,
    # presumably an alternative load size - confirm before relying on it.
    parser.add_argument('--load_size', type=int, default=256, help='scale images to this size')
    parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
    parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
    parser.add_argument('--preprocess', type=str, default='none', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
    # Was ``type=bool``, which turned "--no_flip False" into True.
    parser.add_argument('--no_flip', type=_str_to_bool, default=True, help='if specified, do not flip the images for data augmentation')
    parser.add_argument("--bsize", type=int, default=2)
    parser.add_argument("--ckpt_path_s1", type=str, default="/pretrain/model_nafnet_32_16.pt")
    return parser.parse_args()

sys.stdout.flush()
# Seed through the project's set_seed helper (imported from src.model_udbm).
set_seed(10)

# --- Values handed to Trainer further down in this script ---
save_and_sample_every = 1000
train_num_steps = 100000
train_batch_size = 1
num_samples = 1

# --- Values shared by the network and the diffusion wrapper ---
condition = True
image_size = 256

# All command-line handling goes through argparse.
opt = parsr_args()

results_folder = "./ckpt_universal/model_udbm"

num_unet = 1
objective = 'pred_res'
test_res_or_noise = "res"
# sampling_timesteps used to be read from int(sys.argv[1]) and was then
# unconditionally overwritten with 1 here. The argv read was removed: it
# raised ValueError as soon as the first argument was an argparse flag
# (e.g. "--dataroot"), and its result was never used.
sampling_timesteps = 1
sum_scale = 0.01
ddim_sampling_eta = 0.
delta_end = 1.8e-3

model = UnetRes(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    num_unet=num_unet,
    condition=condition,
    objective=objective,
    test_res_or_noise=test_res_or_noise,
    test_mode=False,
)
diffusion = ResidualDiffusion(
    model,
    image_size=image_size,
    timesteps=1000,           # number of steps
    delta_end=delta_end,
    sampling_timesteps=sampling_timesteps,
    ddim_sampling_eta=ddim_sampling_eta,
    objective=objective,
    loss_type='l1',            # L1 or L2
    condition=condition,
    sum_scale=sum_scale,
    test_res_or_noise=test_res_or_noise,
)
import numpy as np
def _average_metrics(metric_dicts):
    """Average every metric over the sub-task results that report it.

    Metric names are collected from all result dicts, in first-seen order,
    rather than from the first dict only. A metric reported by only some
    sub-tasks is therefore still summarised instead of being dropped.

    Args:
        metric_dicts: Non-empty list of ``{metric_name: value}`` mappings.

    Returns:
        Dict mapping each metric name to the ``np.mean`` of its values.
    """
    # dict.fromkeys gives an ordered, de-duplicated union of the keys.
    metric_names = dict.fromkeys(
        name for metrics in metric_dicts for name in metrics
    )
    return {
        name: np.mean([metrics[name] for metrics in metric_dicts if name in metrics])
        for name in metric_names
    }


# Maps a main task to the sub-task names evaluated for it. A main task that is
# not listed here is run as a single sub-task with the same name.
TASK_MAPPING = {
    'rain': ['rain1', 'rain2', 'rain3', 'rain4', 'rain5'],
    'snow': ['snow1', 'snow2'],
    'real_dark': ['real_dark_mef', 'real_dark_dice', 'real_dark_npe'],
    'real_blur': ['real_hide', 'real_j', 'real_r'],
    'cd11': ['l', 'h', 'r', 's', 'lh', 'lr', 'ls', 'hr', 'hs', 'lhr', 'lhs'],
}

# Main tasks to evaluate in this run. The original assigned
# ['light_only', 'rain', 'blur', 'fog', 'snow'] first and immediately
# overwrote it; only the effective selection is kept.
input_tasks = ['fog']

# main task name -> {metric name: mean over its sub-tasks}
final_summary = {}

print(f"Start processing tasks: {input_tasks}")

# --- 4. Main loop ---
for main_task in input_tasks:
    print(f"\n{'='*20} Processing Main Task: {main_task} {'='*20}")

    sub_tasks = TASK_MAPPING.get(main_task, [main_task])

    # One metrics dict per sub-task that was actually evaluated.
    main_task_metrics = []

    for sub_task in sub_tasks:
        print(f"--- Running sub-task: {sub_task} ---")

        dataset = AlignedDataset_all(
            opt,
            image_size,
            augment_flip=False,
            equalizeHist=True,
            crop_patch=False,
            generation=False,
            task=sub_task,
        )

        # A fresh Trainer is built for every sub-task dataset.
        trainer = Trainer(
            diffusion,
            dataset,
            opt,
            train_batch_size=train_batch_size,
            num_samples=num_samples,
            train_lr=2e-4,
            train_num_steps=train_num_steps,
            gradient_accumulate_every=2,
            ema_decay=0.995,
            amp=False,
            convert_image_to="RGB",
            results_folder=results_folder,
            condition=condition,
            save_and_sample_every=save_and_sample_every,
            num_unet=num_unet,
        )

        # Only the local main process loads weights and runs the test.
        if trainer.accelerator.is_local_main_process:
            # NOTE(review): presumably checkpoint/milestone 600 under
            # results_folder - confirm against Trainer.load.
            trainer.load(600)
            trainer.set_results_folder('./result')

            current_metric = trainer.test(last=True, task=sub_task)
            main_task_metrics.append(current_metric)
            print(f"Finished {sub_task}: {current_metric}")

    if main_task_metrics:
        avg_dict = _average_metrics(main_task_metrics)

        # Store in the summary dict.
        final_summary[main_task] = avg_dict

        print(f"\n>>> Average Scores for [{main_task}]:")
        for k, v in avg_dict.items():
            print(f"{k}: {v:.4f}")
    else:
        print(f"No metrics collected for {main_task}")


print("\n" + "#"*30)
print("FINAL SUMMARY REPORT")
print("#"*30)
for task_name, scores in final_summary.items():
    score_str = ", ".join([f"{k}: {v:.4f}" for k, v in scores.items()])
    print(f"{task_name}: {score_str}")