from utils import general
import torch
from models import models_3d
import cv2
import numpy as np
import os
import random 
import importlib
import argparse
import copy
from datetime import datetime
from skimage.transform import rescale
import time
import json
# exp = 'IDIR + an.TV.reg. with 20 seed'

def preprocess_fn(image, labels, voxel_size, target_spacing=(1.5, 1.5, 3.15),
                  crop_size=(128, 128, 32)):
    """Resample a 3-D image/label pair to ``target_spacing`` and center-crop it.

    The image is resampled with cubic interpolation and min-max normalised to
    [0, 1]. The labels are resampled with nearest-neighbour interpolation.
    Both arrays are then center-cropped to at most ``crop_size`` voxels per axis.

    Args:
        image: 3-D intensity volume.
        labels: 3-D label volume with the same shape as ``image``.
        voxel_size: Per-axis voxel spacing of the inputs, in the same units as
            ``target_spacing``.
        target_spacing: Per-axis spacing after resampling. The default matches
            the values previously hard-coded here.
        crop_size: Maximum output size per axis. Axes smaller than this are
            kept whole instead of being (wrongly) wrapped around.

    Returns:
        ``(cropped_image, cropped_labels)`` with identical shapes.

    Raises:
        ValueError: If the resampled image and label shapes differ.
    """
    def center_crop(img):
        # Per axis, keep min(size, crop) voxels centred in the volume. Clamping
        # prevents a negative start offset, which Python would interpret as an
        # index from the end of the axis.
        slices = []
        for dim, target in zip(img.shape, crop_size):
            keep = min(dim, target)
            start = (dim - keep) // 2
            slices.append(slice(start, start + keep))
        return img[tuple(slices)]

    # Zoom factor per axis: old spacing / new spacing.
    scale_factor = tuple(v / t for v, t in zip(voxel_size, target_spacing))

    scaled_image = rescale(image, scale_factor, order=3, mode='edge')
    lo, hi = scaled_image.min(), scaled_image.max()
    if hi > lo:
        scaled_image = (scaled_image - lo) / (hi - lo)
    else:
        # Constant volume: avoid 0/0 (NaNs) and return an all-zero image.
        scaled_image = np.zeros_like(scaled_image)

    # Nearest-neighbour for labels. Anti-aliasing is disabled and the value
    # range is preserved so that label ids are never blurred or rescaled into
    # non-label values, whatever dtype the labels arrive in.
    scaled_labels = rescale(labels, scale_factor, order=0, mode='edge',
                            anti_aliasing=False, preserve_range=True)

    if scaled_image.shape != scaled_labels.shape:
        raise ValueError(
            f'Resampled image shape {scaled_image.shape} does not match '
            f'resampled label shape {scaled_labels.shape}'
        )
    return center_crop(scaled_image), center_crop(scaled_labels)

# Limit torch's CPU intra-op parallelism.
torch.set_num_threads(16)

# Command-line options:
#   -m       python module that holds the experiment configs
#   -n       explicit name of the config variable inside that module
#   --model  model name; used to derive `config_<model>` when -n is absent
#   --runs   how many times the whole pair list is registered
parser = argparse.ArgumentParser()
_cli_options = (
    ('-m', str, 'configs.config_dirlab'),
    ('-n', str, None),
    ('--model', str, 'cheb_kan'),
    ('--runs', int, 2),
)
for _flag, _arg_type, _default in _cli_options:
    parser.add_argument(_flag, type=_arg_type, default=_default)
args = parser.parse_args()

cfg_module = importlib.import_module(args.m)
# An explicit -n wins; otherwise fall back to `config_<model in lowercase>`.
if args.n is None and args.model is None:
    raise Exception('Either config name or model name must be not None!')
variable_name = args.n if args.n is not None else 'config_' + args.model.lower()

curr_config = getattr(cfg_module, variable_name)
print(curr_config)
# Evaluated ACDC cases 101-150. Each case is loaded with order 0 and order 1
# (presumably the two registration directions -- see general.load_pair_acdc).
case_ids = range(101, 151)

acdc_path = "/srv/fast1/DIR_data/acdc"

patient_nums_0, patient_nums_1 = (
    [[case_id, direction] for case_id in case_ids] for direction in (0, 1)
)
# All order-0 pairs first, then all order-1 pairs. For a quick smoke test,
# replace with a short list such as [[101, 0], [101, 1]].
patient_nums = [*patient_nums_0, *patient_nums_1]

# Accumulators filled by the run loop.
overall_across_all = {}   # (patient_id, order) -> mean Dice, one entry per run
per_patient_metrics = {}  # "<patient_id>_<order>" -> list of per-run metric dicts
overall_regularity, overall_dice_arr, overall_runtime, overall_hd95 = [], [], [], []
peak_memory_consumption_mb = 0

# Set to False to skip preprocess_fn and evaluate at the native voxel size.
preprocess_as_in_corr_mlp = True
model_tag = args.model.lower()
exp = (f'INR registration on ACDC, preprocessing={preprocess_as_in_corr_mlp}, '
       f'model={model_tag}, {args.runs} runs')
def _fit_with_retry(registrator, patient_id, max_attempts=2):
    """Run ``registrator.fit()``, retrying on failure; return the fit time.

    Kept as a workaround: ``fit()`` can raise a torch.autograd error when the
    IDIR bending-energy regularizer is used, so a failed fit is retried on the
    same object. The error is printed instead of being silently swallowed.

    Returns:
        Wall-clock seconds taken by the successful ``fit()`` call, or ``None``
        if every attempt raised.
    """
    for attempt in range(1, max_attempts + 1):
        start = time.time()
        try:
            registrator.fit()
        except Exception as err:  # deliberate best-effort retry, see docstring
            if attempt < max_attempts:
                print(f'AUTOGRAD ERROR, TRY AGAIN... ({type(err).__name__}: {err})')
            else:
                print(f'AUTOGRAD ERROR AGAIN! SKIP PATIENT {patient_id}... '
                      f'({type(err).__name__}: {err})')
        else:
            return time.time() - start
    return None


for run in range(args.runs):
    # Seed taken from the wall clock's microsecond field (0..999999).
    seed = datetime.now().microsecond
    print(f'--- Exp: {exp}, seed = {seed} ---')
    all_mean_dices = []

    for pair in patient_nums:
        patient_id, order = pair
        moving_img, fixed_img, moving_labels, fixed_labels, voxel_size = general.load_pair_acdc(
            acdc_path, patient_id, order
        )

        if preprocess_as_in_corr_mlp:
            moving_img, moving_labels = preprocess_fn(moving_img, moving_labels, voxel_size)
            fixed_img, fixed_labels = preprocess_fn(fixed_img, fixed_labels, voxel_size)

        kwargs = copy.deepcopy(curr_config)
        # Sampling mask covering every voxel. This gives the same result as the
        # former `(labels.astype(int) + 10).astype(bool).astype(int)` trick.
        kwargs["mask"] = np.ones(moving_labels.shape, dtype=int)
        kwargs['seed'] = seed
        ImpReg = models_3d.ImplicitRegistrator3d(
            torch.FloatTensor(moving_img), torch.FloatTensor(fixed_img), **kwargs
        )

        fit_runtime = _fit_with_retry(ImpReg, patient_id)
        if fit_runtime is None:
            continue
        overall_runtime.append(fit_runtime)

        # GPU memory reading taken right after this pair's fit. The units and
        # semantics are defined by general.get_gpu_used_memory (presumably MB).
        pair_gpu_mem_mb = general.get_gpu_used_memory()
        peak_memory_consumption_mb = max(peak_memory_consumption_mb, pair_gpu_mem_mb)

        # Preprocessing resamples to this fixed spacing; otherwise use the native one.
        spacing = (1.5, 1.5, 3.15) if preprocess_as_in_corr_mlp else voxel_size
        mean_dice, dice_arr, hd95 = general.eval_segmentation_accuracy(
            ImpReg, fixed_img.shape, fixed_labels, moving_labels,
            voxel_size=spacing, device='cuda'
        )
        folded_voxels_percent = general.compute_deformation_regularity(
            ImpReg.network, ImpReg.possible_coordinate_tensor, output_shape=fixed_img.shape
        )
        overall_regularity.append(folded_voxels_percent)
        all_mean_dices.append(mean_dice)
        overall_dice_arr.append(dice_arr)
        overall_hd95.append(hd95)

        per_patient_metrics.setdefault(f"{patient_id}_{order}", []).append({
            "seed": seed,
            "mean_dice": float(mean_dice),
            "dice_array": [float(d) for d in dice_arr],
            "hd95": float(hd95),
            "folded_voxels_percent": float(folded_voxels_percent),
            # Fit time only, consistent with `overall_runtime`. The old value
            # also included the evaluation above.
            "runtime_sec": fit_runtime,
            # This pair's own reading, not the running maximum over all pairs.
            "peak_gpu_mem_mb": pair_gpu_mem_mb,
        })

        print(f"Patient {patient_id}, order {order},  mean: {mean_dice:.3f}, std: {np.std(dice_arr):.3}, HD95: {hd95:.3f}, folded voxels %: {folded_voxels_percent:.5f}")

        overall_across_all.setdefault(tuple(pair), []).append(mean_dice)

    # Guard against a run in which every pair was skipped (empty mean -> warning).
    run_mean_dice = float(np.mean(all_mean_dices)) if all_mean_dices else float('nan')
    print(f'Overall with seed = {seed}: {run_mean_dice:.3f}')

def _mean_std(values):
    """Return ``(mean, std)`` of ``values`` as Python floats.

    Returns ``(nan, nan)`` for an empty sequence (e.g. every pair was skipped)
    instead of triggering NumPy's empty-slice warning.
    """
    if len(values) == 0:
        return float('nan'), float('nan')
    arr = np.asarray(values, dtype=float)
    return float(arr.mean()), float(arr.std())


print(f'Cfg: {curr_config}')
print(f'---- Exp: {exp}, FINAL RESULTS ----')

# Dice mean: average over pairs of each pair's mean Dice across runs, so a pair
# that was skipped in some runs still carries equal weight.
pair_means = [np.mean(dices) for dices in overall_across_all.values()]
final_dice_mean = _mean_std(pair_means)[0]
# Dice std: spread of all individual Dice values from every evaluated pair/run.
if overall_dice_arr:
    final_dice_std = float(np.concatenate([np.ravel(d) for d in overall_dice_arr]).std())
else:
    final_dice_std = float('nan')
print(f"Overall dice, mean: {final_dice_mean:.3f}")
print(f"Overall dice, std: {final_dice_std:.3f}")

hd95_overall, hd95_overall_std = _mean_std(overall_hd95)
print(f"Overall HD95, mean: {hd95_overall:.3f}")
print(f"Overall HD95, std: {hd95_overall_std:.3f}")

folded_vox_percent_overall, folded_vox_percent_overall_std = _mean_std(overall_regularity)
print(f'% of folded voxels, mean: {folded_vox_percent_overall:.5f}')
print(f'% of folded voxels, std: {folded_vox_percent_overall_std:.5f}')

print("Mean runtime: {:.2f} seconds".format(_mean_std(overall_runtime)[0]))
print(f'Peak GPU memory consumption: {peak_memory_consumption_mb} MB')

def _numpy_json_default(obj):
    """``json.dump`` fallback that converts NumPy scalars/arrays to built-ins.

    Without it, values such as ``np.float32`` (e.g. a std computed over float32
    Dice scores) or a NumPy integer memory reading make ``json.dump`` raise
    ``TypeError``. Anything else is still rejected, as ``json`` requires.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


metrics_to_save = {
    "config": str(curr_config),
    "runs": args.runs,
    "dice_mean": final_dice_mean,
    "dice_std": final_dice_std,
    "hd95_mean": hd95_overall,
    "hd95_std": hd95_overall_std,
    "folded_voxels_percent_mean": folded_vox_percent_overall,
    "folded_voxels_percent_std": folded_vox_percent_overall_std,
    # NaN (not a warning) when no pair was successfully fitted.
    "mean_runtime_sec": float(np.mean(overall_runtime)) if overall_runtime else float('nan'),
    "peak_gpu_mem_mb": peak_memory_consumption_mb,
    "per_patient": per_patient_metrics,
}

output_filename = f"metrics_{args.model.lower()}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(output_filename, 'w', encoding='utf-8') as f:
    json.dump(metrics_to_save, f, indent=4, default=_numpy_json_default)

print(f"Metrics saved to {output_filename}")