
import numpy as np


import torch
import torch.nn.functional as F
import SimpleITK as sitk
import os
import json
import random

# Label map for task 'CNvsMCI': cognitively normal (CN) -> 0 vs. early
# cognitive impairment -> 1. Text keys are diagnosis names; the numeric-string
# keys presumably are clinical severity scores (e.g. CDR) — TODO confirm.
early_label_map = {
    '0.0': 0,
    '0.5': 1,
    '1.0': 1,
    'CN': 0,
    'MCI': 1,
}

# Label map for task 'CNvsCI': cognitively normal (CN) -> 0 vs. any cognitive
# impairment (MCI, AD and the higher numeric scores) -> 1.
label_map = {
    '0.0': 0,
    '0.5': 1,
    '1.0': 1,
    '2.0': 1,
    '3.0': 1,
    'CN': 0,
    'MCI': 1,
    'AD': 1,
}

# Question templates used as the prompt for each sample. Every template embeds
# the '<img><non_img>' placeholder pair at the point where the image and the
# non-image (findings + clinical text) inputs belong.
ad_templates = [
    # Stage diagnosis
    "Based on the cognitive assessment data: <img><non_img>, what is the current stage of Alzheimer's progression? (CN/MCI/AD)",
    "Analyze the neuroimaging and test results: <img><non_img> to determine if the patient is at CN, MCI, or AD stage.",
    "Classify the Alzheimer's disease stage for this patient profile: <img><non_img> as Cognitive Normal (CN), Mild Cognitive Impairment (MCI), or Alzheimer's Disease (AD).",

    # Binary diagnosis
    "Does the biomarker profile: <img><non_img> indicate presence of Alzheimer's Disease? ",
    "Based on the clinical presentation: <img><non_img>, should this patient be diagnosed with Alzheimer's Disease?",

    # Progression analysis
    "Evaluate the progression severity in this AD case: <img><non_img> and classify as CN, MCI, or AD.",
    "Given the input data: <img><non_img>, has the patient progressed to AD or remains at MCI/CN stage?",

    # Differential diagnosis
    "Differentiate between normal aging (CN), MCI and AD based on: <img><non_img>. Provide the most likely diagnosis.",
    # The placeholder was previously misspelled '<non-img>' here, unlike every
    # other template.
    "Are the cognitive changes in <img><non_img> consistent with CN, MCI, or full Alzheimer's Dementia?",

    # Risk prediction
    "Predict the current AD stage (CN/MCI/AD) for this high-risk patient: <img><non_img>",
    "Assess the likelihood of progression to AD based on: <img><non_img>.",
]

class FinetuneDataset(torch.utils.data.Dataset):
    """
    Supervised fine-tuning dataset pairing a 3-D scan with a templated prompt/answer.

    ``data`` is a list of record dicts with the keys ``img_path``,
    ``img_finding``, ``text``, ``label`` and ``reasoning`` (the records that
    ``get_dataset`` builds). ``__getitem__`` returns the tuple
    ``(input_ids, labels, input_mask, image, non_image, label)``.
    """

    def __init__(self, data, tokenizer, max_words=100):
        self.files = data
        self.tokenizer = tokenizer
        self.max_words = max_words

    def __len__(self):
        return len(self.files)

    def _load_image(self, img_path):
        """Read a 3-D volume, L2-normalize it and resample it to 128x128x128."""
        volume = sitk.GetArrayFromImage(sitk.ReadImage(img_path)).astype(np.float32)
        volume = torch.FloatTensor(volume)
        # F.normalize divides by the L2 norm along dim=1 (its default). It is not
        # a min-max rescale, but for non-negative intensities every value still
        # lies in [0, 1].
        volume = F.normalize(volume)
        # interpolate expects (N, C, D, H, W): add batch/channel axes, then drop them.
        return F.interpolate(volume.unsqueeze(0).unsqueeze(0), size=(128, 128, 128), mode='trilinear').squeeze()

    def _encode_pair(self, question, answer):
        """
        Tokenize ``question`` and ``question + answer`` into training tensors.

        Returns ``(prompt_ids, input_ids, labels, input_mask)``:
          * ``prompt_ids`` -- unpadded token ids of the question alone.
          * ``input_ids``  -- ids of question+answer, padded with
            ``tokenizer.pad_id`` or truncated to ``max_words``; pad positions
            are then zeroed.
          * ``labels``     -- copy of ``input_ids`` with the question positions
            and pad positions set to 0 (presumably the loss's ignore index --
            confirm against the training loop).
          * ``input_mask`` -- float mask, 1.0 on real tokens, 0.0 on padding.
        """
        prompt_ids = torch.tensor(
            self.tokenizer.encode(question, bos=True, eos=False, allowed_special={"<img>"}),
            dtype=torch.int64,
        )
        input_ids = torch.tensor(
            self.tokenizer.encode(question + answer, bos=True, eos=True, allowed_special={"<img>"}),
            dtype=torch.int64,
        )

        padding = self.max_words - input_ids.shape[0]
        if padding > 0:
            input_ids = F.pad(input_ids, (0, padding), "constant", self.tokenizer.pad_id)
        elif padding < 0:
            print('Truncate the tensor! length({}) is larger than max_words({})'.format(input_ids.shape[0], self.max_words))
            input_ids = input_ids[:self.max_words]

        labels = input_ids.clone()
        labels[:len(prompt_ids)] = -1  # no loss on the question tokens

        # The ge(0) masks rely on tokenizer.pad_id being negative (e.g. -1);
        # otherwise padding would count as real tokens -- confirm for new tokenizers.
        input_mask = input_ids.ge(0)
        label_mask = labels.ge(0)
        input_ids[~input_mask] = 0
        labels[~label_mask] = 0
        return prompt_ids, input_ids, labels, input_mask.float()

    def __getitem__(self, idx):
        """Return ``(input_ids, labels, input_mask, image, non_image, label)`` for record ``idx``."""
        record = self.files[idx]
        image = self._load_image(record['img_path'])
        non_image = record['img_finding'] + record['text']
        # A fresh random template on every access acts as prompt augmentation.
        question = random.choice(ad_templates)
        _, input_ids, labels, input_mask = self._encode_pair(question, record['reasoning'])
        return input_ids, labels, input_mask, image, non_image, record['label']

class RLDataset(torch.utils.data.Dataset):
    """
    Dataset for the RL stage: a 3-D scan plus a deterministic templated prompt/answer.

    ``data`` is a list of record dicts with the keys ``img_path``,
    ``img_finding``, ``text`` and ``reasoning`` (as built by
    ``get_rl_dataset``). ``__getitem__`` returns a dict with the keys
    ``input_ids``, ``labels``, ``input_mask``, ``question``, ``answer``,
    ``prompt_ids``, ``img`` and ``non_img``.
    """

    def __init__(self, data, tokenizer, max_words=100):
        self.files = data
        self.tokenizer = tokenizer
        self.max_words = max_words

    def __len__(self):
        return len(self.files)

    def _load_image(self, img_path):
        """Read a 3-D volume, L2-normalize it and resample it to 128x128x128."""
        volume = sitk.GetArrayFromImage(sitk.ReadImage(img_path)).astype(np.float32)
        volume = torch.FloatTensor(volume)
        # F.normalize divides by the L2 norm along dim=1 (its default). It is not
        # a min-max rescale, but for non-negative intensities every value still
        # lies in [0, 1].
        volume = F.normalize(volume)
        # interpolate expects (N, C, D, H, W): add batch/channel axes, then drop them.
        return F.interpolate(volume.unsqueeze(0).unsqueeze(0), size=(128, 128, 128), mode='trilinear').squeeze()

    def _encode_pair(self, question, answer):
        """
        Tokenize ``question`` and ``question + answer`` into training tensors.

        Returns ``(prompt_ids, input_ids, labels, input_mask)``:
          * ``prompt_ids`` -- unpadded token ids of the question alone.
          * ``input_ids``  -- ids of question+answer, padded with
            ``tokenizer.pad_id`` or truncated to ``max_words``; pad positions
            are then zeroed.
          * ``labels``     -- copy of ``input_ids`` with the question positions
            and pad positions set to 0.
          * ``input_mask`` -- float mask, 1.0 on real tokens, 0.0 on padding.
        """
        prompt_ids = torch.tensor(
            self.tokenizer.encode(question, bos=True, eos=False, allowed_special={"<img>"}),
            dtype=torch.int64,
        )
        input_ids = torch.tensor(
            self.tokenizer.encode(question + answer, bos=True, eos=True, allowed_special={"<img>"}),
            dtype=torch.int64,
        )

        padding = self.max_words - input_ids.shape[0]
        if padding > 0:
            input_ids = F.pad(input_ids, (0, padding), "constant", self.tokenizer.pad_id)
        elif padding < 0:
            print('Truncate the tensor! length({}) is larger than max_words({})'.format(input_ids.shape[0], self.max_words))
            input_ids = input_ids[:self.max_words]

        labels = input_ids.clone()
        labels[:len(prompt_ids)] = -1  # no loss on the question tokens

        # The ge(0) masks rely on tokenizer.pad_id being negative (e.g. -1);
        # otherwise padding would count as real tokens -- confirm for new tokenizers.
        input_mask = input_ids.ge(0)
        label_mask = labels.ge(0)
        input_ids[~input_mask] = 0
        labels[~label_mask] = 0
        return prompt_ids, input_ids, labels, input_mask.float()

    def __getitem__(self, idx):
        """Return the tokenized sample dict for record ``idx``."""
        record = self.files[idx]
        image = self._load_image(record['img_path'])
        non_image = record['img_finding'] + record['text']
        # Deterministic per index: cycle through however many templates exist
        # (previously hard-coded as `% 11`, which silently broke if the list changed).
        question = ad_templates[idx % len(ad_templates)]
        answer = record['reasoning']
        prompt_ids, input_ids, labels, input_mask = self._encode_pair(question, answer)

        return {
            'input_ids': input_ids,
            'labels': labels,
            'input_mask': input_mask,
            'question': question,
            "answer": answer,
            "prompt_ids": prompt_ids,
            'img': image,
            'non_img': non_image,
        }



def _reasoning_file_for(img_path, reasoning_dir):
    """Return the path of the reasoning JSON file that belongs to ``img_path``."""
    if 'adni' in img_path:
        # ADNI-style paths: the two directories above the file name (presumably
        # patient id and visit -- confirm) name the reasoning file.
        ptid, visit = img_path.split('/')[-3:-1]
        return os.path.join(reasoning_dir, ptid + '_' + visit + '.txt')
    # Otherwise the image file name up to its first '.' names the reasoning file.
    stem = img_path.split('/')[-1].split('.')[0]
    return os.path.join(reasoning_dir, stem + '.txt')


def _load_reasoning(reasoning_file):
    """Load ``gpt_diagnosis`` from a JSON file, drop newlines and tag its sections."""
    with open(reasoning_file, 'r', encoding='utf-8') as file:
        reasoning_data = json.load(file)
    # NOTE(review): these are plain substring replacements, so any occurrence of
    # 'Reasoning'/'Diagnosis'/'Confidence' inside the body text is tagged too,
    # and no closing </confidence> tag is emitted -- confirm this is intended.
    return (reasoning_data['gpt_diagnosis']
            .replace('\n', '')
            .replace('Reasoning', '<reasoning>')
            .replace('Diagnosis', '</reasoning><diagnosis>')
            .replace('Confidence', '</diagnosis><confidence>'))


def get_dataset(datalist=None, task='CNvsMCI', tokenizer=None, max_words=100):
    """
    Load tab-separated record lists and wrap them in a FinetuneDataset.

    Each line of ``./local_data/<name>.csv`` must hold four tab-separated
    fields: ``img_path``, ``img_finding``, ``text`` and ``diagnosis``. The label
    name is the text after ``'Diagnosis: '`` in ``diagnosis``, stripped, with its
    last character dropped (presumably a trailing period -- confirm). It is
    mapped through the task's label map, and rows whose name is not in the map
    are skipped. Each kept row's reasoning JSON is read from
    ``./local_data/alzheimer_diagnosis_results/<prefix>/``, where ``<prefix>`` is
    the part of ``<name>`` before the first ``-``.

    Args:
        datalist: dataset names such as ``'ADNI-train'``; a single string is
            accepted too. Defaults to ``['ADNI-train']``.
        task: ``'CNvsCI'`` (uses ``label_map``) or ``'CNvsMCI'`` (uses
            ``early_label_map``).
        tokenizer, max_words: forwarded to ``FinetuneDataset``.

    Returns:
        A ``FinetuneDataset`` over the loaded records.

    Raises:
        ValueError: if ``task`` is not one of the supported names.
    """
    if task == 'CNvsCI':
        task_label_map = label_map
    elif task == 'CNvsMCI':
        task_label_map = early_label_map
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(f"Unknown task {task!r}; expected 'CNvsCI' or 'CNvsMCI'")

    if datalist is None:
        datalist = ['ADNI-train']
    elif isinstance(datalist, str):
        datalist = [datalist]

    files = []
    for data in datalist:
        filename = f'./local_data/{data}.csv'
        dataset_name = data.split('-')[0]
        reasoning_dir = os.path.join('./local_data/alzheimer_diagnosis_results', dataset_name)
        print('load data from', filename)
        with open(filename, encoding='utf-8') as f:
            for line in f:
                line = line.strip('\n')
                if not line.strip():
                    continue  # tolerate blank lines, e.g. an extra newline at EOF
                img_path, img_finding, text, diagnosis = line.split('\t')

                # Filter on the label first so skipped rows never need a reasoning file.
                name = diagnosis.split('Diagnosis: ')[1].strip()[:-1]
                if name not in task_label_map:
                    continue

                files.append(
                    {
                        'img_path': img_path,
                        'img_finding': img_finding,
                        'text': text,
                        'label': task_label_map[name],
                        'reasoning': _load_reasoning(_reasoning_file_for(img_path, reasoning_dir)),
                    }
                )

    return FinetuneDataset(data=files, tokenizer=tokenizer, max_words=max_words)

def get_rl_dataset(datalist=None, task='CNvsMCI', tokenizer=None, max_words=100):
    """
    Load the same records as ``get_dataset`` and wrap them in an RLDataset.

    Args:
        datalist: dataset names such as ``'ADNI-train'``; a single string is
            accepted too. Defaults to ``['ADNI-train']``.
        task: label-map selector, forwarded to ``get_dataset``
            (``'CNvsCI'`` or ``'CNvsMCI'``).
        tokenizer, max_words: forwarded to ``RLDataset``.

    Returns:
        An ``RLDataset`` over the loaded records.
    """
    if datalist is None:
        datalist = ['ADNI-train']
    elif isinstance(datalist, str):
        datalist = [datalist]
    # The CSV/reasoning parsing used to be a verbatim copy of get_dataset's;
    # reuse it so both loaders stay in sync. Only the wrapping Dataset differs.
    records = get_dataset(datalist=datalist, task=task, tokenizer=tokenizer, max_words=max_words).files
    return RLDataset(data=records, tokenizer=tokenizer, max_words=max_words)
