import numpy as np
import os
import matplotlib.pyplot as plt
import openai
from openai import OpenAI
import base64
import json
import re
from prompts import PROMPTS
import torch

def normalize_percentile(img, lower_percentile=1, upper_percentile=99.9, clip=True):
    """Rescale intensities so the given percentiles map to 0 and 1.

    Utility function adapted from:
    https://github.com/melanieganz/ImageQualityMetricsMRI/blob/main/utils/data_utils.py

    Args:
        img: Array-like of intensities; converted to ``float32`` first.
        lower_percentile: Percentile mapped to 0.
        upper_percentile: Percentile mapped to 1.
        clip: If True, clip the result to ``[0, 1]``.

    Returns:
        A new array; the input is not modified. If the two percentiles
        coincide (e.g. a constant image) the intensities are only shifted by
        the lower percentile instead of divided by zero, so a constant image
        becomes all zeros rather than NaN.
    """
    img = np.asarray(img, dtype=np.float32)
    lower = np.percentile(img, lower_percentile)
    upper = np.percentile(img, upper_percentile)
    span = upper - lower
    if span > 0:
        img = (img - lower) / span
    else:
        # Degenerate range: dividing would give 0/0 = NaN (and +/-inf for
        # outliers), so fall back to a pure shift.
        img = img - lower
    if clip:
        img = np.clip(img, 0, 1)
    return img

def image_generation(recon, save_path, slice_indices=None):
    """Render a grid of 2-D slices from a 3-D volume and save it to disk.

    The volume is intensity-normalised with ``normalize_percentile`` (default
    1st/99.9th percentiles, clipped to [0, 1]). The grid has one row per view
    label ('Axial', 'Sagittal', 'Coronal') and one column per slice index.
    The 'Axial' row slices array axis 0, 'Sagittal' axis 2 and 'Coronal'
    axis 1. Whether those labels match the anatomy depends on the volume's
    orientation (TODO confirm against the data loader).

    Args:
        recon: 3-D volume as a torch tensor (any device, grad allowed), a
            NumPy array, or any object exposing ``.numpy()``.
        save_path: Output image path; missing parent directories are created.
        slice_indices: Slice positions taken along each axis; defaults to
            ``[80, 108, 151]``. Every index must be valid on all three axes.
    """
    if slice_indices is None:
        slice_indices = [80, 108, 151]

    if isinstance(recon, torch.Tensor):
        # detach/cpu so GPU tensors or tensors requiring grad also work; for a
        # plain CPU tensor this yields the same array as ``recon.numpy()``.
        recon = recon.detach().cpu().numpy()
    elif hasattr(recon, "numpy"):
        recon = recon.numpy()
    volume = normalize_percentile(np.asarray(recon))

    views = ['Axial', 'Sagittal', 'Coronal']
    # Which array axis each row slices along.
    slicers = {
        'Axial': lambda vol, idx: vol[idx, :, :],
        'Sagittal': lambda vol, idx: vol[:, :, idx],
        'Coronal': lambda vol, idx: vol[:, idx, :],
    }
    n_slices = len(slice_indices)
    # squeeze=False keeps ``axes`` 2-D even when a single slice is requested.
    fig, axes = plt.subplots(
        len(views), n_slices, figsize=(n_slices * 4, len(views) * 4), squeeze=False
    )
    try:
        fig.subplots_adjust(wspace=0.03, hspace=0.04)
        for row, view in enumerate(views):
            for col, slice_idx in enumerate(slice_indices):
                ax = axes[row, col]
                ax.imshow(slicers[view](volume, slice_idx), cmap='gray')
                ax.axis('off')

        out_dir = os.path.dirname(save_path)
        # dirname is '' for a bare filename, and os.makedirs('') raises.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Release the figure even if slicing or saving fails.
        plt.close(fig)

def chatgpt_scoring(image_path, json_folder, temperature, expert_num, prompt=None, api_key=None):
    """Ask GPT-4o to grade motion-artifact severity of an MRI image.

    The image file is sent inline as a base64 ``image/png`` data URL together
    with ``prompt``. The severity label is parsed from the reply, and the
    reply plus label are written to
    ``<json_folder>/eval_temp<temperature>_expnum<expert_num>.json``.

    Args:
        image_path: Path of the image to evaluate (sent with a PNG MIME type).
        json_folder: Output directory for the JSON result; created if missing.
        temperature: Sampling temperature forwarded to the API.
        expert_num: Run index used only in the output filename.
        prompt: Instruction text; defaults to ``PROMPTS['long']``.
        api_key: OpenAI API key. When None, the OpenAI client falls back to
            the ``OPENAI_API_KEY`` environment variable.

    Returns:
        The canonical label: "No Motion", "Mild", "Moderate" or "Severe".

    Raises:
        IndexError: if the reply contains no recognisable severity label
            (IndexError is kept for compatibility with earlier versions).
    """
    with open(image_path, "rb") as f:
        image_b64_str = base64.b64encode(f.read()).decode("utf-8")
    if prompt is None:
        prompt = PROMPTS['long']

    # A per-call client avoids mutating the process-wide ``openai.api_key``.
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "system",
                "content": (
                    "You are an MRI image analysis expert whose task is to evaluate "
                    "the severity of motion artifacts in MRI images. This is not medical advice."
                )
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Here is the MRI image for evaluation :"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{image_b64_str}"
                        },
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        max_tokens=600,
        temperature=temperature,
    )

    gpt_out = response.choices[0].message.content
    print(gpt_out)
    # ``content`` can be None (e.g. a refusal). Treat it as empty text so the
    # failure surfaces as "no severity found" instead of an AttributeError.
    analysis_text = (gpt_out or "").strip()

    # The verdict is expected on the last line, as the original code assumed.
    score_line = analysis_text.split("\n")[-1]
    pattern = re.compile(
        r"(?:Severity\s*Level:\s*)?(No\s*Motion|Mild|Moderate|Severe)",
        re.IGNORECASE
    )
    matches = pattern.findall(score_line)
    if not matches:
        # Fallback: the model sometimes adds trailing text after the verdict.
        # Only accept explicitly labelled verdicts here, so incidental words
        # in the explanation are not picked up. \W* tolerates markdown such
        # as "**Severity Level:** Mild".
        matches = re.findall(
            r"Severity\s*Level\W*(No\s*Motion|Mild|Moderate|Severe)",
            analysis_text,
            re.IGNORECASE,
        )
    print(matches)
    if not matches:
        raise IndexError(f"No severity level found in model output: {analysis_text!r}")

    # IGNORECASE and \s* allow variants such as "no motion" or "NoMotion".
    # Map them onto one canonical spelling.
    canonical_labels = {
        "nomotion": "No Motion",
        "mild": "Mild",
        "moderate": "Moderate",
        "severe": "Severe",
    }
    severity = canonical_labels["".join(matches[-1].split()).lower()]

    output_data = {
        "text": analysis_text,
        "Severity Level": severity
    }

    # Save the output as a JSON file.
    if json_folder:
        os.makedirs(json_folder, exist_ok=True)
    output_path = os.path.join(json_folder, f'eval_temp{temperature}_expnum{expert_num}.json')
    with open(output_path, 'w', encoding='utf-8') as json_file:
        json.dump(output_data, json_file, ensure_ascii=False, indent=4)
    return output_data["Severity Level"]


def qwen_vl_scoring(image_path, json_folder, temperature, expert_num, prompt=None, api_key=None):
    """Ask Qwen-VL-Max to grade motion-artifact severity of an MRI image.

    The request goes through DashScope's OpenAI-compatible endpoint. The
    image file is sent inline as a base64 ``image/png`` data URL together
    with ``prompt``. The severity label is parsed from the reply, and the
    reply plus label are written to
    ``<json_folder>/eval_temp<temperature>_expnum<expert_num>.json``.

    Args:
        image_path: Path of the image to evaluate (sent with a PNG MIME type).
        json_folder: Output directory for the JSON result; created if missing.
        temperature: Sampling temperature forwarded to the API.
        expert_num: Run index used only in the output filename.
        prompt: Instruction text; defaults to ``PROMPTS['long']``.
        api_key: DashScope API key passed to the OpenAI-compatible client.

    Returns:
        The canonical label: "No Motion", "Mild", "Moderate" or "Severe".

    Raises:
        IndexError: if the reply contains no recognisable severity label
            (IndexError is kept for compatibility with earlier versions).
    """
    with open(image_path, "rb") as f:
        image_b64_str = base64.b64encode(f.read()).decode("utf-8")
    if prompt is None:
        prompt = PROMPTS['long']

    client = OpenAI(
        api_key=api_key,
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    response = client.chat.completions.create(
        model="qwen-vl-max",
        messages=[
            {
                "role": "system",
                "content": (
                    "You are an MRI image analysis expert whose task is to evaluate "
                    "the severity of motion artifacts in MRI images. This is not medical advice."
                )
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Here is the MRI image for evaluation :"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{image_b64_str}"
                        },
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        max_tokens=600,
        temperature=temperature,
    )

    gpt_out = response.choices[0].message.content
    print(gpt_out)
    # ``content`` can be None. Treat it as empty text so the failure surfaces
    # as "no severity found" instead of an AttributeError.
    analysis_text = (gpt_out or "").strip()

    # The verdict is expected on the last line, as the original code assumed.
    score_line = analysis_text.split("\n")[-1]
    pattern = re.compile(
        r"(?:Severity\s*Level:\s*)?(No\s*Motion|Mild|Moderate|Severe)",
        re.IGNORECASE
    )
    matches = pattern.findall(score_line)
    if not matches:
        # Fallback: accept only explicitly labelled verdicts anywhere in the
        # reply. \W* tolerates markdown such as "**Severity Level:** Mild".
        matches = re.findall(
            r"Severity\s*Level\W*(No\s*Motion|Mild|Moderate|Severe)",
            analysis_text,
            re.IGNORECASE,
        )
    print(matches)
    if not matches:
        raise IndexError(f"No severity level found in model output: {analysis_text!r}")

    # Map case and whitespace variants ("no motion", "NoMotion") onto one
    # canonical spelling.
    canonical_labels = {
        "nomotion": "No Motion",
        "mild": "Mild",
        "moderate": "Moderate",
        "severe": "Severe",
    }
    severity = canonical_labels["".join(matches[-1].split()).lower()]

    output_data = {
        "text": analysis_text,
        "Severity Level": severity
    }

    # Save the output as a JSON file.
    if json_folder:
        os.makedirs(json_folder, exist_ok=True)
    output_path = os.path.join(json_folder, f'eval_temp{temperature}_expnum{expert_num}.json')
    with open(output_path, 'w', encoding='utf-8') as json_file:
        json.dump(output_data, json_file, ensure_ascii=False, indent=4)
    return output_data["Severity Level"]


def M3D_LaMed_Phi_vlm_scoring(image_path, json_folder, temperature, expert_num, prompt=None, tokenizer=None, model=None, device=None):
    """Grade motion-artifact severity of a 3-D volume with a local M3D-LaMed model.

    The volume is loaded with ``np.load(image_path)`` and gains a leading
    batch dimension. It is cast to float16 on ``device``. The prompt is
    prefixed with 256 ``<im_patch>`` placeholder tokens. Up to 100 new tokens
    are generated, and the severity label is parsed from the decoded text.
    The text plus label are written to
    ``<json_folder>/eval_temp<temperature>_expnum<expert_num>.json``.

    Args:
        image_path: Path to a ``.npy`` array (expected shape is whatever the
            model's vision encoder takes, minus the batch axis; TODO confirm).
        json_folder: Output directory for the JSON result; created if missing.
        temperature: Sampling temperature. A value <= 0 selects greedy
            decoding, because sampling requires a strictly positive temperature.
        expert_num: Run index used only in the output filename.
        prompt: Instruction text; defaults to ``PROMPTS['long']``.
        tokenizer: Tokenizer paired with ``model`` (required).
        model: Model exposing ``generate(image, input_ids, **kwargs)`` (required).
        device: Device for the input tensors.

    Returns:
        The canonical label: "No Motion", "Mild", "Moderate" or "Severe".

    Raises:
        TypeError: if ``tokenizer`` or ``model`` is not provided.
        IndexError: if the output contains no recognisable severity label.
    """
    if prompt is None:
        prompt = PROMPTS['long']
    if tokenizer is None or model is None:
        raise TypeError("M3D_LaMed_Phi_vlm_scoring requires both `tokenizer` and `model`")

    # Number of <im_patch> placeholders; presumably must equal the model's
    # vision-projector output token count (TODO confirm against model config).
    proj_out_num = 256
    dtype = torch.float16
    image_tokens = "<im_patch>" * proj_out_num
    input_txt = image_tokens + prompt
    input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)

    image_np = np.load(image_path)
    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)

    generate_kwargs = {"max_new_tokens": 100}
    if temperature is None or temperature > 0:
        generate_kwargs.update(do_sample=True, top_p=0.9, temperature=temperature)
    else:
        # Hugging Face sampling rejects non-positive temperatures; T=0 is
        # conventionally greedy decoding, matching how the API scorers treat it.
        generate_kwargs["do_sample"] = False
    generation = model.generate(image_pt, input_id, **generate_kwargs)

    gpt_out = tokenizer.batch_decode(generation, skip_special_tokens=True)[0]
    print(gpt_out)
    analysis_text = gpt_out.strip()

    # The verdict is expected on the last line. A bare label such as "Mild"
    # also matches here.
    score_line = analysis_text.split("\n")[-1]
    pattern = re.compile(
        r"(?:Severity\s*Level:\s*)?(No\s*Motion|Mild|Moderate|Severe)",
        re.IGNORECASE
    )
    matches = pattern.findall(score_line)
    if not matches:
        # Fallback: accept only explicitly labelled verdicts anywhere in the text.
        matches = re.findall(
            r"Severity\s*Level\W*(No\s*Motion|Mild|Moderate|Severe)",
            analysis_text,
            re.IGNORECASE,
        )
    print(matches)
    if not matches:
        raise IndexError(f"No severity level found in model output: {analysis_text!r}")

    # Map case and whitespace variants onto one canonical spelling.
    canonical_labels = {
        "nomotion": "No Motion",
        "mild": "Mild",
        "moderate": "Moderate",
        "severe": "Severe",
    }
    output_data = {
        "text": analysis_text,
        "Severity Level": canonical_labels["".join(matches[-1].split()).lower()]
    }

    # Save the output as a JSON file.
    if json_folder:
        os.makedirs(json_folder, exist_ok=True)
    output_path = os.path.join(json_folder, f'eval_temp{temperature}_expnum{expert_num}.json')
    with open(output_path, 'w', encoding='utf-8') as json_file:
        json.dump(output_data, json_file, ensure_ascii=False, indent=4)
    return output_data["Severity Level"]