from utils import general
import torch
from models import models_3d
import cv2
import numpy as np
import os
import random 
import importlib
import argparse
import copy
from datetime import datetime
import time
import json
# exp = 'IDIR + an.TV.reg. with 20 seed'

# Cap the number of CPU threads torch uses for intra-op parallelism.
torch.set_num_threads(16)

# Command-line interface:
#   -m       dotted import path of the module that holds the config objects
#   -n       explicit name of the config variable inside that module
#   --model  model name; "config_<model>" is used when -n is not given
#   --runs   how many times every pair is registered (new seed per run)
parser = argparse.ArgumentParser()
parser.add_argument('-m', type=str, default='configs.config_dirlab')
parser.add_argument('-n', type=str, default=None)
parser.add_argument('--model', type=str, default='cheb_kan')
parser.add_argument('--runs', type=int, default=2)
args = parser.parse_args()

cfg_module = importlib.import_module(args.m)

# Guard first: with neither -n nor --model there is no way to pick a config.
if args.n is None and args.model is None:
    raise Exception('Either config name or model name must be not None!')
# An explicit -n wins over the naming convention derived from --model.
variable_name = args.n if args.n is not None else 'config_' + args.model.lower()

curr_config = getattr(cfg_module, variable_name)
print(curr_config)
# multiple_seeds = np.random.choice(10 ** 6, size=args.runs, replace=False)


# Location of the preprocessed 3-D OASIS data; subjects.txt holds one
# subject folder name per line.
oasis_path = "/home/datasets/oasis_1_3d"
oasis_folders_path = os.path.join(oasis_path, "subjects.txt")

with open(oasis_folders_path, 'r') as f:
    # Drop only the newline; every line (even an empty one) keeps its position
    # in the list.
    folders = [line.strip('\n') for line in f]

# Evaluation protocol switch. Each entry of pair_list is [moving_id, fixed_id].
#   CASES == 50 -> subjects 364..413 registered as consecutive pairs
#                  (i -> i + 1), i.e. 49 pairs;
#   otherwise   -> subjects 394..413, every ordered pair with i != j.
CASES = 50
if CASES != 50:
    case_ids = range(394, 414)
    pair_list = [[i, j] for i in case_ids for j in case_ids if i != j]
else:
    case_ids = range(364, 413)
    # case_ids = range(364, 365)  # single-pair smoke test
    pair_list = [[i, i + 1] for i in case_ids]

print(len(pair_list))

# ---- Accumulators filled by the experiment loop ----
overall_across_all = dict()  # (moving_id, fixed_id) -> mean Dice of every run on that pair

overall_regularity = []      # folded-voxel percentage, one entry per (run, pair)
overall_dice_arr = []        # dice_arr from eval_segmentation_accuracy, per (run, pair)
overall_runtime = []         # seconds spent in ImpReg.fit(), per (run, pair)
overall_hd95 = []            # HD95, one entry per (run, pair)
exp = f'INR registration on OASIS, model={args.model.lower()}'
peak_memory_consumption_mb = 0  # max of general.get_gpu_used_memory() over all pairs and runs
per_patient_metrics = {}     # "<moving>_<fixed>" -> list with one metrics dict per run

for run in range(args.runs):
    # The seed is the wall-clock microsecond field, so it changes between runs
    # and invocations; it is printed and stored with every per-pair record.
    seed = datetime.now().microsecond
    print(f'--- Exp: {exp}, seed = {seed} ---')
    all_mean_dices = []

    for moving_id, fixed_id in pair_list:
        moving_img, fixed_img, moving_labels, fixed_labels = general.load_pair_oasis_3d(oasis_path,
                                                                                    folders,
                                                                                    moving_id,
                                                                                    fixed_id)

        # 0/1 mask of the moving image's labelled voxels (label != 0 after the int cast).
        mask_exp = moving_labels.astype('int').astype('bool').astype('int')
        # Fresh copy per pair so the shared config object is never mutated.
        kwargs = copy.deepcopy(curr_config)
        kwargs["mask"] = mask_exp
        kwargs['seed'] = seed
        ImpReg = models_3d.ImplicitRegistrator3d(torch.FloatTensor(moving_img),
                                                 torch.FloatTensor(fixed_img),
                                                 **kwargs)

        torch.cuda.empty_cache()  # Clear cache before starting
        # torch.cuda.reset_peak_memory_stats()

        # Time only the optimisation. perf_counter is monotonic (time.time can
        # jump with clock adjustments), and CUDA kernels launch asynchronously,
        # so wait for outstanding GPU work before reading the clock.
        t = time.perf_counter()
        ImpReg.fit()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        fit_runtime = time.perf_counter() - t
        overall_runtime.append(fit_runtime)

        pair_gpu_mem_mb = general.get_gpu_used_memory()
        peak_memory_consumption_mb = max(peak_memory_consumption_mb, pair_gpu_mem_mb)
        mean_dice, dice_arr, hd95 = general.eval_segmentation_accuracy(ImpReg, fixed_img.shape, fixed_labels, moving_labels, device='cuda')
        folded_voxels_percent = general.compute_deformation_regularity(
            ImpReg.network, ImpReg.possible_coordinate_tensor, output_shape=fixed_img.shape
        )
        overall_regularity.append(folded_voxels_percent)
        all_mean_dices.append(mean_dice)
        overall_dice_arr.append(dice_arr)
        overall_hd95.append(hd95)

        print(f"Pair: {moving_id} and {fixed_id},  mean: {mean_dice:.3f}, std: {np.std(dice_arr):.3f}, HD95: {hd95:.3f}, folded voxels %: {folded_voxels_percent:.5f}")

        overall_across_all.setdefault((moving_id, fixed_id), []).append(mean_dice)

        per_patient_metrics.setdefault(f"{moving_id}_{fixed_id}", []).append({
            "run_seed": seed,
            "mean_dice": float(mean_dice),
            "dice_array": [float(d) for d in dice_arr],
            "hd95": float(hd95),
            "folded_voxels_percent": float(folded_voxels_percent),
            # Fit-only duration (excludes evaluation), identical to the value
            # appended to overall_runtime for this pair.
            "runtime_sec": fit_runtime,
            # NOTE(review): the type of voxel_size is not visible here; if it is
            # a numpy array or tensor, json serialisation needs a converter.
            "voxel_size": getattr(ImpReg, 'voxel_size', None),
            # Reading taken for this pair; the running max across all pairs is
            # kept separately in peak_memory_consumption_mb.
            "peak_gpu_mem_mb": pair_gpu_mem_mb,
        })

    print(f'Overall with seed = {seed}: {np.array(all_mean_dices).mean():.3f}')

# ---- Final report: aggregate every (run, pair) result collected above ----
print(f'Cfg: {curr_config}')
print(f'---- Exp: {exp}, FINAL RESULTS ----')

# Dice mean: average each pair over its runs first, then average those pair
# means, so every pair has equal weight. Dice std: taken over the concatenation
# of every dice_arr from every run and pair.
final_mean = 0
for run_dices in overall_across_all.values():
    final_mean += np.array(run_dices).mean()
final_dice_mean = final_mean / len(overall_across_all)
final_dice_std = np.concatenate(overall_dice_arr, axis=0).std()
print(f"Overall dice, mean: {final_dice_mean:.3f}")
print(f"Overall dice, std: {final_dice_std:.3f}")

hd95_values = np.array(overall_hd95)
hd95_overall = hd95_values.mean()
hd95_overall_std = hd95_values.std()
print(f"Overall HD95, mean: {hd95_overall:.3f}")
print(f"Overall HD95, std: {hd95_overall_std:.3f}")

folded_values = np.array(overall_regularity)
folded_vox_percent_overall = folded_values.mean()
folded_vox_percent_overall_std = folded_values.std()
print(f'% of folded voxels, mean: {folded_vox_percent_overall:.5f}')
print(f'% of folded voxels, std: {folded_vox_percent_overall_std:.5f}')

mean_runtime = np.array(overall_runtime).mean()
print("Mean runtime: {:.2f} seconds".format(mean_runtime))
print(f'Peak GPU memory consumption: {peak_memory_consumption_mb} MB')


# ---- Persist the summary and per-pair metrics as JSON ----


def _to_json_compatible(obj):
    """``default`` hook for json: convert numpy / torch values to plain Python.

    json only encodes builtin types natively. numpy scalars other than
    float64 (e.g. np.int64, np.float32), numpy arrays and tensors would
    otherwise raise TypeError here, after the whole experiment has run.
    Anything else still raises TypeError, exactly as json would.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


metrics_to_save = {
    "config": str(curr_config),
    "runs": args.runs,
    "dice_mean_overall": float(final_dice_mean),
    "dice_std_overall": float(final_dice_std),
    "hd95_mean_overall": float(hd95_overall),
    "hd95_std_overall": float(hd95_overall_std),
    "folded_vox_percent_mean_overall": float(folded_vox_percent_overall),
    "folded_vox_percent_std_overall": float(folded_vox_percent_overall_std),
    "mean_runtime_sec": float(np.array(overall_runtime).mean()),
    "peak_gpu_mem_mb": peak_memory_consumption_mb,
    "per_patient": per_patient_metrics
}

output_filename = f"metrics_oasis_{args.model.lower()}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

folder_path = "oasis_results"
os.makedirs(folder_path, exist_ok=True)

full_path = os.path.join(folder_path, output_filename)

# Serialise before opening the file so an encoding error cannot leave a
# truncated, half-written JSON file behind.
payload = json.dumps(metrics_to_save, indent=4, default=_to_json_compatible)
with open(full_path, 'w', encoding='utf-8') as f:
    f.write(payload)

print(f"Metrics saved to {full_path}")