from utils import general, metrics, visualization, spatial_transformer
from models import models_3d
import numpy as np
import torch
import torch.nn.functional as F
import importlib
import argparse
import os
import copy
from datetime import datetime
import time
import matplotlib.pyplot as plt
import json


# Location of the DIR-LAB data; it is turned into a "<data_dir>/Case" prefix for
# general.load_image_DIRLab further down.
data_dir = "/srv/fast1/DIR_data/dirlab/"
# NOTE(review): out_dir does not appear to be referenced anywhere in this file.
out_dir = "output"

# Cap torch's intra-op CPU thread pool.
torch.set_num_threads(16)

def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` maps *every* non-empty string -- including
    ``"False"`` and ``"0"`` -- to ``True``. This parser reads the text instead.

    Args:
        value: the raw argument string, or an already-parsed ``bool`` (as
            supplied by ``const``/``default``).

    Returns:
        ``True`` for 1/true/t/yes/y/on, ``False`` for 0/false/f/no/n/off or an
        empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument("--savepath", type=str,
                    dest="savepath", default='./result/dirlab_test',
                    help="path for saving results")
parser.add_argument('-m', type=str, default='configs.config_dirlab',
                    help="dotted path of the module holding the config objects")
parser.add_argument('-n', type=str, default=None,
                    help="name of the config variable in that module "
                         "(default: config_<model>)")
parser.add_argument('--model', type=str, default='cheb_kan',
                    help="model name; selects config_<model> when -n is not given")
parser.add_argument('--runs', type=int, default=2,
                    help="number of independent runs, each with its own seed")
# Boolean flags: a bare `--save_imgs` means True, and the explicit form
# `--save_imgs True` / `--save_imgs False` is parsed correctly.
parser.add_argument("--save_imgs", type=_str2bool, nargs='?', const=True,
                    dest="save_imgs", default=False,
                    help="Save images during evaluation")

parser.add_argument("--visualize", type=_str2bool, nargs='?', const=True,
                    dest="visualize", default=False,
                    help="visualize images during evaluation")


args = parser.parse_args()
# Resolve the config object: an explicit -n name wins, otherwise config_<model>.
cfg_module = importlib.import_module(args.m)
if args.n is not None:
    variable_name = args.n
elif args.model is not None:
    variable_name = 'config_' + args.model.lower()
else:
    raise ValueError('Either config name or model name must be not None!')

# Fail with a readable message (and before creating any output directory) when
# the requested config does not exist in the module.
try:
    curr_config = getattr(cfg_module, variable_name)
except AttributeError as err:
    raise AttributeError(
        f"Config module '{args.m}' has no config named '{variable_name}'"
    ) from err

# The results sub-directory is named after --model (basename without extension).
# It falls back to the config variable name when no model name is available,
# which the branch above allows when -n is given.
name_source = args.model if args.model is not None else variable_name
config_name = os.path.splitext(os.path.basename(name_source))[0]
args.savepath = os.path.join(args.savepath, config_name)

# exist_ok avoids the check-then-create race of an isdir() test.
os.makedirs(args.savepath, exist_ok=True)

print(curr_config)
# Evaluate DIR-LAB cases 1 through 10.
# For a quick smoke test, narrow this down, e.g. `case_ids = [1, 2]`.
case_ids = range(1, 11)

# Accumulators spanning all runs: per-case lists keyed by case id, plus flat
# per-(run, case) lists and the running GPU-memory maximum.
overall_across_all_runs, overall_diffs_all_runs = {}, {}
overall_regularity_runs, overall_runtime_runs = [], []
peak_memory_consumption_mb_runs = 0

def _dense_displacement_field(network, grid_shape, image_size, mask):
    """Evaluate ``network`` on every voxel and return a masked dense field.

    Args:
        network: the fitted registration network (``ImpReg.network``).
        grid_shape: ``(H, W, D)`` shape of the voxel grid to evaluate.
        image_size: image shape forwarded to ``general.compute_landmarks_batch``.
        mask: array of shape ``(H, W, D)``; voxels where it is falsy get a zero
            vector.

    Returns:
        A CPU ``torch.Tensor`` of shape ``(1, 3, H, W, D)`` holding the per-voxel
        vectors returned by ``compute_landmarks_batch`` (presumably displacements).
    """
    H, W, D = grid_shape
    # Enumerate voxels with D slowest and H fastest; each row is (h, w, d).
    z, y, x = np.meshgrid(np.arange(D), np.arange(W), np.arange(H), indexing='ij')
    grid = np.stack([x, y, z], axis=-1).reshape(-1, 3)
    _, vectors = general.compute_landmarks_batch(network, grid, image_size=image_size)

    # Undo the (D, W, H) enumeration order, then move the vector axis to channel-first.
    field = torch.from_numpy(vectors).reshape((D, W, H, 3)).permute(2, 1, 0, 3)
    field = field.unsqueeze(0).permute(0, 4, 1, 2, 3)

    field *= torch.from_numpy(mask).bool()[None, None]
    return field


def _save_case_visualizations(case_dir, case_id, accuracies, fixed, moving, warped_moving,
                              mask, landmarks_fixed, landmarks_fixed_warped,
                              landmarks_moving, voxel_size):
    """Produce the per-case diagnostic figures.

    Figures are written into ``case_dir`` only when ``--save_imgs`` is set. They
    are shown interactively only when ``--visualize`` is set. Both flags are read
    from the global ``args``.
    """
    def save_path(filename):
        return os.path.join(case_dir, filename) if args.save_imgs else None

    # (file-name suffix, slicing axis) for each view.
    views = (('Y', 1), ('X', 2), ('Z', 0))

    visualization.plot_accuracy_histogram(
        accuracies, case_id, save_path=save_path('accuracy_histogram.png'))

    for label, axis in views:
        visualization.visualize_dirlab_warp_fixed(
            warped_moving, fixed, case_id,
            landmarks_fixed, landmarks_fixed_warped, landmarks_moving,
            axis=axis, voxel_size=voxel_size, visualize=args.visualize,
            save_path=save_path(f'kp_overlay_{label}.png'))

    # "init" overlays show the alignment before registration (fixed vs. moving);
    # "reg" overlays show it after registration (fixed vs. warped moving).
    for label, axis in views:
        visualization.overlay_images_dirlab(
            fixed, moving, case_id, axis=axis, voxel_size=voxel_size,
            red_title="Fixed", green_title="Moving",
            save_path=save_path(f'overlay_init_save_path_{label}.png'),
            visualize=args.visualize)

    for label, axis in views:
        visualization.overlay_images_dirlab(
            fixed, warped_moving, case_id, axis=axis, voxel_size=voxel_size,
            red_title="Fixed", green_title="Warped Moving",
            save_path=save_path(f'overlay_reg_save_path_{label}.png'),
            visualize=args.visualize)

    visualization.collage_images_dirlab(
        moving, warped_moving, fixed, case_id, axis=1, voxel_size=voxel_size,
        save_path=save_path('collage.png'), visualize=args.visualize)

    visualization.visualize_dirlab_warp_fixed_with_mask(
        warped_moving, fixed, mask, case_id,
        landmarks_fixed, landmarks_fixed_warped, landmarks_moving,
        axis=1, voxel_size=voxel_size, visualize=args.visualize,
        save_path=save_path('kp_overlay_mask.png'))


# Loop invariants.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Root directory of the precomputed per-case `embed_insp_slic.npy` files.
embed_root = '/home/IDIR/data/dirlab'

for run in range(args.runs):

    # The microsecond of the current wall-clock second serves as the run seed;
    # it also names the run's output directory.
    seed = datetime.now().microsecond
    run_dir = os.path.join(args.savepath, f'seed_{seed}')
    os.makedirs(run_dir, exist_ok=True)
    print(f'--- Exp seed = {seed} ---')

    # Per-run accumulators (the *_runs counterparts span all runs).
    overall_across_all = {}
    overall_diffs = {}
    overall_regularity = []
    overall_runtime = []
    peak_memory_consumption_mb = 0
    exp = f'INR registration on DIR-LAB, model={args.model.lower()}'

    for case_id in case_ids:

        case_dir = os.path.join(run_dir, f'Case{case_id}')
        os.makedirs(case_dir, exist_ok=True)
        (
            img_insp,
            img_exp,
            landmarks_insp,
            landmarks_exp,
            mask_exp,
            mask_insp,
            voxel_size,
        ) = general.load_image_DIRLab(case_id, "{}/Case".format(data_dir))

        # Deep-copy so per-case entries never leak into the shared config object.
        kwargs = copy.deepcopy(curr_config)
        kwargs["mask"] = mask_exp
        kwargs['seed'] = seed
        embed = np.load(f'{embed_root}/Case{case_id}Pack/embed_insp_slic.npy')
        kwargs['embed'] = torch.FloatTensor(embed)

        torch.cuda.empty_cache()
        t = time.perf_counter()
        ImpReg = models_3d.ImplicitRegistrator3d(img_exp, img_insp, **kwargs)
        ImpReg.fit()
        # Runtime covers model construction + fitting only. The same value feeds
        # both the per-run and the cross-run statistics.
        runtime = time.perf_counter() - t

        used_memory_mb = general.get_gpu_used_memory()
        peak_memory_consumption_mb = max(peak_memory_consumption_mb, used_memory_mb)
        peak_memory_consumption_mb_runs = max(peak_memory_consumption_mb_runs, used_memory_mb)
        overall_runtime.append(runtime)

        new_landmarks_orig, _ = general.compute_landmarks(
            ImpReg.network, landmarks_insp, image_size=img_insp.shape
        )

        accuracy_mean, accuracy_std, all_accuracies = general.compute_landmark_accuracy(
            new_landmarks_orig, landmarks_exp, voxel_size=voxel_size
        )

        folded_voxels_percent = general.compute_deformation_regularity(
            ImpReg.network, ImpReg.possible_coordinate_tensor, output_shape=img_insp.shape
        )

        overall_regularity.append(folded_voxels_percent)
        print("Case id: {} mean: {} std: {} folded % {:.5f}".format(case_id, accuracy_mean[0], accuracy_std[0], folded_voxels_percent))

        overall_across_all.setdefault(case_id, []).append(accuracy_mean[0])
        overall_diffs.setdefault(case_id, []).append(all_accuracies)

        # Dense field masked by mask_exp, saved with shape (1, 3, H, W, D).
        df = _dense_displacement_field(ImpReg.network, img_exp.shape, img_insp.shape, mask_exp)
        np.save(os.path.join(case_dir, 'deformation_field.npy'), df.cpu().numpy())

        fixed = img_insp if isinstance(img_insp, np.ndarray) else np.array(img_insp)
        moving = img_exp if isinstance(img_exp, np.ndarray) else np.array(img_exp)

        # Warp the moving image with the dense field.
        ST = spatial_transformer.SpatialTransformer(moving.shape).to(device)
        moving_st = torch.from_numpy(moving).to(device).float()[None, None]
        df = df.to(device).float()
        warped_moving, _ = ST(moving_st, df, return_phi=True)
        warped_moving = warped_moving.squeeze(0).squeeze(0)

        np.save(os.path.join(case_dir, 'fixed.npy'), fixed)
        np.save(os.path.join(case_dir, 'warped_moving.npy'), warped_moving.cpu().numpy())

        if args.save_imgs or args.visualize:
            _save_case_visualizations(
                case_dir, case_id, all_accuracies, fixed, moving, warped_moving,
                mask_exp, landmarks_insp, new_landmarks_orig, landmarks_exp, voxel_size)

        metrics_dict = {
            "accuracy_mean_mm": float(accuracy_mean[0]),
            "accuracy_std_mm": float(accuracy_std[0]),
            "folded_voxels_percent": float(folded_voxels_percent),
            # Running maximum over the cases of this run evaluated so far.
            "peak_memory_mb": peak_memory_consumption_mb,
            "all_accuracies_mm": [float(x) for x in all_accuracies]
        }
        metrics_path = os.path.join(case_dir, 'metrics.json')
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=4)
        print(f"Saved metrics for Case {case_id} to {metrics_path}")

        kps_dict = {
            'kp_fixed': landmarks_insp.tolist(),
            'kp_mov': landmarks_exp.tolist(),
            'kp_fixed_warped': new_landmarks_orig.tolist()
        }
        with open(os.path.join(case_dir, 'keypoints.json'), 'w') as f:
            json.dump(kps_dict, f, indent=4)
        print(f"Saved keypoints for Case {case_id} to keypoints.json")

        overall_regularity_runs.append(folded_voxels_percent)
        overall_runtime_runs.append(runtime)
        overall_across_all_runs.setdefault(case_id, []).append(accuracy_mean[0])
        overall_diffs_all_runs.setdefault(case_id, []).append(all_accuracies)

        # Drop this case's model and device tensors. The next case's empty_cache()
        # can then release their memory before the next model is built; otherwise
        # they stay alive (and are counted) until the names are rebound.
        del ImpReg, ST, moving_st, df, warped_moving

    # Flatten before aggregating so cases with different landmark counts still work.
    case_means = [m for means in overall_across_all.values() for m in means]
    run_accuracies = np.concatenate(
        [np.ravel(a) for accs in overall_diffs.values() for a in accs])
    summary = {
        'seed': seed,
        'mean_accuracy_mm': float(np.mean(case_means)),
        'std_accuracy_mm': float(np.std(run_accuracies)),
        'mean_folded_pct': float(np.mean(overall_regularity)),
        'std_folded_pct': float(np.std(overall_regularity)),
        'peak_memory_mb': peak_memory_consumption_mb
    }
    summary_path = os.path.join(run_dir, 'summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=4)
    print(f"Saved summary to {summary_path}")


# Cross-run report. `exp` is rebuilt here, so the report works even when the
# run loop never executed (e.g. --runs 0).
exp = f'INR registration on DIR-LAB, model={args.model.lower()}'
print(f'Cfg: {curr_config}')
print(f'---- Exp: {exp}, results after {args.runs} runs ----')

if not overall_across_all_runs:
    # Nothing was evaluated, so there is nothing to aggregate. Without this guard,
    # np.concatenate([]) would raise.
    print('No cases were evaluated; no aggregate results.')
else:
    case_means = []
    final_accuracies = []
    for key, case_run_means in overall_across_all_runs.items():
        # Mean over runs of the per-run case mean; std over every landmark error
        # of this case pooled across runs.
        case_mean = np.array(case_run_means).mean()
        final_case_accs = np.concatenate(overall_diffs_all_runs[key])
        case_std = np.std(final_case_accs)
        print(f'Case id: {key}, mean: {case_mean:.2f}, std: {case_std:.2f}, ')
        case_means.append(case_mean)
        final_accuracies.append(final_case_accs)

    # Average over the cases actually evaluated (equals len(case_ids) on a full run).
    final_mean = sum(case_means) / len(case_means)
    print(f'Overall mean: {final_mean:.3f}, overall std: {np.std(np.concatenate(final_accuracies)):.2f}')

    folded_vox_percent_overall = np.array(overall_regularity_runs).mean()
    folded_vox_percent_overall_std = np.array(overall_regularity_runs).std()
    print(f'% of folded voxels, mean: {folded_vox_percent_overall:.5f}')
    print(f'% of folded voxels, std: {folded_vox_percent_overall_std:.5f}')
    print("Mean runtime: {:.2f}".format(np.array(overall_runtime_runs).mean()))

print(f'Peak GPU memory consumption: {peak_memory_consumption_mb_runs} MB')