'''
Score all reconstruction volumes for motion-artifact severity with a 3D medical VLM (Med3DVLM).
'''
import os
# NOTE(review): `...` is a redacted placeholder. os.environ only accepts str
# values, so the assignments below raise TypeError until this is set to a
# real directory path.
cache_dir = ...
# Set before torch/transformers are imported further down, presumably so the
# Hugging Face libraries pick up this cache location when they initialize.
# NOTE(review): recent transformers releases treat TRANSFORMERS_CACHE as
# deprecated in favor of HF_HOME - confirm for the pinned version.
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HF_DATASETS_CACHE"] = cache_dir
import argparse
import torch
import json
import openai
from vlm_score_utils import normalize_percentile,M3D_LaMed_Phi_vlm_scoring
from prompts import PROMPTS
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F

def _load_vlm(model_name, device, dtype, model_cache_dir):
    """Load the VLM and its tokenizer, moving the model to ``device``.

    Returns:
        ``(model, tokenizer)``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        trust_remote_code=True,
        cache_dir=model_cache_dir,
    )
    model = model.to(device=device)
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        model_max_length=512,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True,
        cache_dir=model_cache_dir,
    )
    return model, tokenizer


def _resize_volume_for_vlm(recon_path, size):
    """Load a saved reconstruction, percentile-normalize it and resize it.

    Args:
        recon_path: path to a ``.pt`` file readable by ``torch.load``.
        size: target spatial size passed to ``F.interpolate``.

    Returns:
        CPU tensor of shape ``(1, *size)``.
    """
    volume = torch.load(recon_path, map_location='cpu')
    # squeeze() drops singleton dims before normalization
    # (presumably leaving a 3D volume - TODO confirm).
    volume = normalize_percentile(volume.squeeze().numpy())
    # mode='trilinear' requires a 5D (N, C, D, H, W) input.
    volume = torch.from_numpy(volume).unsqueeze(0).unsqueeze(0)
    return F.interpolate(
        volume,
        size=size,
        mode='trilinear',
        align_corners=False,
    ).squeeze(0)


def _score_scan(npy_path, output_dir, tokenizer, model, device,
                severity_level_score, num_experiments, max_attempts):
    """Query the VLM ``num_experiments`` times and average the severity scores.

    Each experiment is retried up to ``max_attempts`` times, because a query can
    fail, e.g. when the answer is not a key of ``severity_level_score``.

    Returns:
        Mean numeric severity over the successful experiments, or ``None`` if
        every experiment exhausted its attempts.
    """
    scores = []
    for num_exp in range(num_experiments):
        for attempt in range(max_attempts):
            try:
                # NOTE(review): the meaning of the positional 0.5 is defined in
                # vlm_score_utils (a sampling temperature?) - confirm.
                sl = M3D_LaMed_Phi_vlm_scoring(
                    npy_path, output_dir, 0.5, num_exp,
                    tokenizer=tokenizer, model=model, device=device,
                )
                scores.append(severity_level_score[sl.strip().lower()])
                break
            except Exception as e:
                # Best-effort retry: log the failure and query again.
                print(f"Attempt {attempt + 1} failed: {e}")
        else:
            print(f"Experiment {num_exp} for {npy_path} failed after "
                  f"{max_attempts} attempts; skipping it.")
    if not scores:
        return None
    return sum(scores) / len(scores)


def main(args):
    """Score every baseline reconstruction with the VLM and save a JSON summary.

    For each baseline in ``baseline_list`` and each scan ``S{i}_{j}`` (i in 1..8,
    j in 1..3), this function:

    - loads ``<recon_save_path>/<baseline>/<scan>.pt``;
    - normalizes and resizes it;
    - saves it as ``.npy`` under ``<npy_save_base_path>/Med3DVLM_npy_resize``;
    - queries the VLM five times and records the mean severity
      (0 = no motion ... 3 = severe).

    All scores are written to ``<score_save_path>/Med3DVLM_scores_short.json``.
    A scan for which no query succeeded is recorded as ``null``.

    Args:
        args: parsed CLI namespace (see the ``__main__`` block). When it has a
            non-None ``cache_dir`` attribute, that directory is used for the
            model/tokenizer cache instead of the module-level ``cache_dir``.
    """
    dtype = torch.float16  # or bfloat16, float32
    device = torch.device(f'cuda:{args.gpu}')
    desired_image_size = (128, 256, 256)
    num_experiments = 5
    max_attempts = 1000

    # Honor --cache_dir when given; otherwise fall back to the module-level
    # cache_dir that the HF_* environment variables point to.
    model_cache_dir = getattr(args, 'cache_dir', None)
    if model_cache_dir is None:
        model_cache_dir = cache_dir

    model, tokenizer = _load_vlm(args.vlm_model, device, dtype, model_cache_dir)

    severity_level_score = {'no motion': 0, 'mild': 1, 'moderate': 2, 'severe': 3}
    image_save_path = os.path.join(args.npy_save_base_path, 'Med3DVLM_npy_resize')
    baseline_list = ['AltOpt', 'MotionTTT', 'stacked_unet']
    # Create the output folder up front, so a bad path fails before the long
    # scoring loop instead of after it.
    os.makedirs(args.score_save_path, exist_ok=True)

    score_record = {}
    for baseline in baseline_list:
        score_record[baseline] = {}
        os.makedirs(os.path.join(image_save_path, baseline), exist_ok=True)
        for i in range(8):
            for j in range(3):
                scan_id = f"S{i+1}_{j+1}"
                vlm_volume = _resize_volume_for_vlm(
                    os.path.join(args.recon_save_path, baseline, f"{scan_id}.pt"),
                    desired_image_size,
                )
                # The scoring helper reads the volume back from this .npy file.
                npy_path = os.path.join(image_save_path, baseline, f"{scan_id}.npy")
                np.save(npy_path, vlm_volume.numpy())

                gpt_output_save_path = os.path.join(
                    args.npy_save_base_path, 'Med3DVLM_output_short', baseline, scan_id
                )
                os.makedirs(gpt_output_save_path, exist_ok=True)
                score = _score_scan(
                    npy_path, gpt_output_save_path, tokenizer, model, device,
                    severity_level_score, num_experiments, max_attempts,
                )
                if score is None:
                    print(f"No valid VLM answer for {baseline}/{scan_id}; recording null.")
                score_record[baseline][scan_id] = score

    with open(os.path.join(args.score_save_path, "Med3DVLM_scores_short.json"), 'w') as f:
        json.dump(score_record, f, indent=4)

if __name__ == "__main__":
    # Command-line entry point: parse paths/options and run the scoring loop.
    parser = argparse.ArgumentParser(
        description="Score all reconstruction volumes for motion severity with a 3D medical VLM"
    )
    parser.add_argument("--recon_save_path", type=str, required=True, help="Path to reconstruction volumes")
    parser.add_argument("--npy_save_base_path", type=str, required=True, help="Base folder to save numpy files")
    parser.add_argument("--score_save_path", type=str, required=True, help="Path for saving VLM scores")
    parser.add_argument("--gpu", type=int, default=0, help="GPU device ID to use")
    parser.add_argument("--vlm_model", type=str, default="MagicXin/Med3DVLM-Qwen-2.5-7B", help="VLM model to use for scoring")
    parser.add_argument("--cache_dir", type=str, default=None, help="Cache directory for model files")
    args = parser.parse_args()

    main(args)
