import pandas as pd
from glob import glob
from collections import Counter
from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont
import imageio
import numpy as np
import warnings
from scipy.ndimage import gaussian_filter
import re
import os
import json
# NOTE(review): with no category/module arguments this silences every warning for the
# whole process (e.g. pandas SettingWithCopyWarning on filtered frames); consider
# narrowing it to the specific categories that are actually noisy.
warnings.filterwarnings("ignore")

def generate_heatmap(base_image, width, height, x_ratio, y_ratio, gaze_data, radius=5, _x='x_position', _y='y_position', fps=10):
    """Render a fixation-duration heatmap overlay and a fixation-marker animation.

    Args:
        base_image: PIL image copied for every animation frame (the callers in this
            file pass an RGBA image of size ``width`` x ``height``).
        width, height: Size in pixels of the heatmap canvas.
        x_ratio, y_ratio: Factors mapping gaze coordinates into canvas pixels.
        gaze_data: DataFrame with coordinate columns ``_x``/``_y`` and a
            ``'Time (in secs)'`` column holding each fixation's duration.
        radius: Marker radius in pixels. Also used as the Gaussian sigma and as the
            border margin inside which fixations are ignored.
        _x, _y: Names of the coordinate columns in ``gaze_data``.
        fps: Frames per second used to turn a fixation duration into a frame count.

    Returns:
        ``(heatmap_image, frames)``.
        ``heatmap_image`` is an RGBA PIL image. Each pixel is
        ``(255 - i, 0, 0, i)``, where ``i`` is the normalised, smoothed intensity in
        [0, 255]. It is fully transparent when no fixation contributes.
        ``frames`` is a list of numpy arrays, one per video frame. The frames for a
        single fixation share the same array object, so treat them as read-only.
    """
    intensity_map = np.zeros((height, width))
    frames = []
    for gaze in gaze_data.to_dict(orient='records'):
        x = int(gaze[_x] * x_ratio)
        y = int(gaze[_y] * y_ratio)
        # Ignore fixations whose marker would be clipped by the canvas border.
        if not (radius <= x < width - radius and radius <= y < height - radius):
            continue
        duration = gaze['Time (in secs)']
        intensity_map[y, x] += duration
        num_frames = int(round(duration * fps))
        if num_frames <= 0:
            continue
        # Render the marker once; every frame of this fixation is identical.
        frame_image = base_image.copy()
        ImageDraw.Draw(frame_image).ellipse(
            (x - radius, y - radius, x + radius, y + radius), fill="red")
        frame_array = np.array(frame_image)
        frames.extend([frame_array] * num_frames)

    # Smooth the point intensities into a continuous heatmap.
    intensity_map = gaussian_filter(intensity_map, sigma=radius)

    # Normalise to [0, 1]; an all-zero map (no usable fixations) stays zero
    # instead of becoming NaN through a 0/0 division.
    peak = intensity_map.max()
    if peak > 0:
        intensity_map /= peak

    # Truncate toward zero like int() and clamp to the 8-bit range.
    intensity = np.clip((intensity_map * 255).astype(int), 0, 255)

    rgba = np.zeros((height, width, 4), dtype=np.uint8)
    rgba[..., 0] = 255 - intensity
    rgba[..., 3] = intensity
    heatmap_image = Image.fromarray(rgba)

    return heatmap_image, frames

def process_f(f, width, height, mode='REFLACX'):
    """Serialise fixations as text, one line per row.

    Coordinates are divided by ``width`` / ``height`` and printed with three decimals.
    The fixation duration is printed with one decimal. ``mode == 'REFLACX'`` reads
    ``x_position``/``y_position``; any other mode reads ``X_ORIGINAL``/``Y_ORIGINAL``.
    Every line, including the last, ends with a newline.
    """
    if mode == 'REFLACX':
        x_key, y_key = 'x_position', 'y_position'
    else:
        x_key, y_key = 'X_ORIGINAL', 'Y_ORIGINAL'
    rows = f.to_dict(orient='records')
    return "".join(
        f"X: {row[x_key]/width:.3f}, Y: {row[y_key]/height:.3f}, Fixation Time: {row['Time (in secs)']:.1f} seconds\n"
        for row in rows
    )


class GazeVideoGenerator:
    """Build MIMIC-Eye gaze heatmaps, fixation videos and report-text records.

    The constructor loads the metadata, report, split and CheXpert spreadsheets and
    discovers study directories. :meth:`process_all` processes every study and writes
    the train and val/test JSON files.
    """

    def __init__(self,
                 mimic_eye_path='physionet.org/files/mimic-eye',
                 mimic_cxr_path='physionet.org/files/mimic-cxr/2.0.0'
                 ):
        """Load the spreadsheets and discover study directories.

        Args:
            mimic_eye_path: Root directory of the MIMIC-Eye dataset.
            mimic_cxr_path: Root directory of MIMIC-CXR (sectioned reports and
                CheXpert label CSV).
        """
        self.mimic_eye_path = mimic_eye_path
        self.mimic_cxr_path = mimic_cxr_path
        mimic_eye_images_path = os.path.join(mimic_eye_path, 'patient_*/CXR-JPG/s*')
        meta_file = os.path.join(mimic_eye_path, 'spreadsheets/cxr_meta.csv')
        cxr_split_file = os.path.join(mimic_eye_path, 'spreadsheets/CXR-JPG/cxr_split.csv')
        cxr_reports_file = os.path.join(mimic_cxr_path, 'mimic-cxr-sections/mimic_cxr_sectioned.csv')
        cxr_chexpert_file = os.path.join(mimic_cxr_path, 'mimic-cxr-2.0.0-chexpert.csv')

        self.meta_df = pd.read_csv(meta_file)
        # process_patient looks up a single metadata row per subject, so subjects
        # listed more than once are excluded entirely.
        subject_counts = Counter(self.meta_df['subject_id'].tolist())
        remove_list = [k for k, v in subject_counts.items() if v > 1]
        self.meta_df.set_index('subject_id', inplace=True)
        self.remove_list = [str(i) for i in remove_list]
        self.meta_df.drop(index=remove_list, inplace=True)

        # Sorted so the output JSON order is reproducible across filesystems.
        self.patient_subjects = sorted(glob(mimic_eye_images_path))

        self.cxr_reports = pd.read_csv(cxr_reports_file)
        self.cxr_reports.fillna('', inplace=True)
        self.cxr_reports.set_index('study', inplace=True)
        self.cxr_split = pd.read_csv(cxr_split_file, index_col=1)
        self.chexpert_df = pd.read_csv(cxr_chexpert_file)

    @staticmethod
    def _resize_longest_side(image, size=512):
        """Return ``image`` resized so its longer side is ``size`` pixels, keeping the aspect ratio."""
        width, height = image.size
        if width > height:
            return image.resize((size, int(float(height) / float(width) * size)))
        return image.resize((int(float(width) / float(height) * size), size))

    def _load_eye_gaze_fixations(self, patient_id, width, height):
        """Read the patient's EyeGaze fixations that fall strictly inside the image.

        Returns a DataFrame with columns ``Time (in secs)`` (per-fixation duration),
        ``X_ORIGINAL``, ``Y_ORIGINAL`` and ``transcript``. Returns None when the file is
        missing or empty, or when no fixation lies inside ``width`` x ``height``.
        """
        path = os.path.join(self.mimic_eye_path, f'patient_{patient_id}/EyeGaze/fixations.csv')
        try:
            gaze_data = pd.read_csv(path)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            return None
        if gaze_data.empty:
            return None
        times = gaze_data['Time (in secs)']
        # The column is treated as a running end timestamp. Its first difference is the
        # per-fixation duration; the first fixation is taken to start at 0.
        gaze_data['Time (in secs)'] = times.diff().fillna(times.iloc[0])
        in_bounds = ((gaze_data['X_ORIGINAL'] > 0) & (gaze_data['Y_ORIGINAL'] > 0)
                     & (gaze_data['X_ORIGINAL'] < width) & (gaze_data['Y_ORIGINAL'] < height))
        gaze_data = gaze_data.loc[in_bounds, ['Time (in secs)', 'X_ORIGINAL', 'Y_ORIGINAL', 'transcript']]
        return None if gaze_data.empty else gaze_data

    def _load_reflacx_fixations(self, patient_id, width, height):
        """Read the first readable REFLACX session's fixations that fall inside the image.

        Sessions under ``REFLACX/main_data`` are tried in sorted order. Returns a
        DataFrame with columns ``Time (in secs)`` (end minus start timestamp),
        ``x_position``, ``y_position`` and ``transcript``. Returns None when no session
        file can be read or no fixation lies inside ``width`` x ``height``.
        """
        main_dir = os.path.join(self.mimic_eye_path, f'patient_{patient_id}/REFLACX/main_data/')
        if not os.path.isdir(main_dir):
            return None
        gaze_data = None
        for session in sorted(os.listdir(main_dir)):
            try:
                gaze_data = pd.read_csv(os.path.join(main_dir, session, 'fixations.csv'))
                break
            except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError):
                continue
        if gaze_data is None:
            return None
        in_bounds = ((gaze_data['x_position'] > 0) & (gaze_data['y_position'] > 0)
                     & (gaze_data['x_position'] < width) & (gaze_data['y_position'] < height))
        # .copy() so the new column is written to an independent frame, not a slice.
        gaze_data = gaze_data[in_bounds].copy()
        gaze_data['Time (in secs)'] = gaze_data['timestamp_end_fixation'] - gaze_data['timestamp_start_fixation']
        gaze_data = gaze_data[['Time (in secs)', 'x_position', 'y_position', 'transcript']]
        return None if gaze_data.empty else gaze_data

    @staticmethod
    def _write_video(path, frames, fps=10):
        """Write ``frames`` (numpy arrays) to a video file at ``path``."""
        with imageio.get_writer(path, fps=fps) as video_writer:
            for frame in frames:
                video_writer.append_data(frame)

    def _differential_diagnosis(self, study_id):
        """Return the newline-joined CheXpert labels that are positive (> 0) for the study.

        ``study_id`` is the directory name such as ``'s123'``; the leading ``'s'`` is
        stripped and the rest is compared numerically with the CSV's ``study_id``
        column. Returns ``''`` if the study has no CheXpert row.
        """
        numeric_id = int(study_id.lstrip('s'))
        target = self.chexpert_df[self.chexpert_df['study_id'] == numeric_id]
        if target.empty:
            # Without this guard, dropna(axis=1) on zero rows would keep every column.
            return ""
        positive = target[target > 0].dropna(axis=1)
        diseases = positive.drop(['subject_id', 'study_id'], axis=1).columns.tolist()
        return "\n".join(diseases)

    def process_patient(self, bp):
        """Generate all artefacts and the dataset record for one study directory.

        Args:
            bp: Study directory path of the form
                ``<mimic_eye_path>/patient_<id>/CXR-JPG/s<study>``.

        Returns:
            A dict with the relative artefact paths, the fixation text, report sections,
            differential diagnosis, split and gaze source. Returns None when the patient
            is excluded, the study has no report, or no usable fixations exist.

        Side effects:
            Writes ``*_512.png``, ``*_heatmap.png``, ``*_fullheatmap.png``,
            ``*_eyefixation.mp4`` and ``*_fulleyefixation.mp4`` next to the source JPEG.
        """
        patient_id = bp.split('/patient_')[-1].split('/')[0]
        if patient_id in self.remove_list:
            return None
        study_id = bp.split('/')[-1]
        meta = self.meta_df.loc[int(patient_id)]
        dicom_id = meta['dicom_id']
        split = self.cxr_split.loc[dicom_id]['split']
        EG = meta['in_eye_gaze']

        try:
            report = self.cxr_reports.loc[study_id]
            findings = re.sub(r"\s+", " ", report['findings'])
            impression = re.sub(r"\s+", " ", report['impression'])
        except (KeyError, TypeError):
            # KeyError: study has no sectioned report. TypeError: the lookup did not
            # yield plain strings (e.g. a duplicated study index).
            return None
        # NOTE(review): empty findings/impression strings are kept; confirm whether such
        # studies should be skipped instead.

        # No leading slash: os.path.join would otherwise discard mimic_eye_path.
        image_path = os.path.join(self.mimic_eye_path, f'patient_{patient_id}/CXR-JPG/{study_id}/{dicom_id}.jpg')
        with Image.open(image_path) as image:
            width, height = image.size
            image512 = self._resize_longest_side(image, 512)

        if EG:
            x_col, y_col, mode = 'X_ORIGINAL', 'Y_ORIGINAL', 'EG'
            gaze_data = self._load_eye_gaze_fixations(patient_id, width, height)
        else:
            x_col, y_col, mode = 'x_position', 'y_position', 'REFLACX'
            gaze_data = self._load_reflacx_fixations(patient_id, width, height)
        if gaze_data is None:
            image512.close()
            return None

        # Only the extension is replaced, never a '.jpg' elsewhere in the path.
        stem = os.path.splitext(image_path)[0]
        resized_image_path = stem + '_512.png'
        image512.save(resized_image_path)
        width512, height512 = image512.size
        base = image512.convert('RGBA')
        image512.close()
        x_ratio, y_ratio = width512 / width, height512 / height

        full_heatmap_image, full_frames = generate_heatmap(
            base, width512, height512, x_ratio, y_ratio, gaze_data, _x=x_col, _y=y_col)
        # Keep only fixations longer than the average duration for the filtered views.
        mean_time_diff = gaze_data['Time (in secs)'].mean()
        gaze_data = gaze_data[gaze_data['Time (in secs)'] > mean_time_diff].dropna()
        text = process_f(gaze_data, width, height, mode=mode)
        heatmap_image, frames = generate_heatmap(
            base, width512, height512, x_ratio, y_ratio, gaze_data, _x=x_col, _y=y_col)

        heatmap_image_path = stem + '_heatmap.png'
        full_heatmap_image_path = stem + '_fullheatmap.png'
        for overlay, path in ((heatmap_image, heatmap_image_path),
                              (full_heatmap_image, full_heatmap_image_path)):
            composite = Image.alpha_composite(base, overlay)
            composite.save(path)
            composite.close()
        base.close()

        video_path = stem + '_eyefixation.mp4'
        full_video_path = stem + '_fulleyefixation.mp4'
        self._write_video(video_path, frames)
        self._write_video(full_video_path, full_frames)

        ddx = self._differential_diagnosis(study_id)

        temp_dict = {}
        temp_dict['image_id'] = resized_image_path.replace(self.mimic_eye_path, '')
        temp_dict['heatmap_image_id'] = heatmap_image_path.replace(self.mimic_eye_path, '')
        temp_dict['full_heatmap_image_id'] = full_heatmap_image_path.replace(self.mimic_eye_path, '')
        temp_dict['video_id'] = video_path.replace(self.mimic_eye_path, '')
        temp_dict['full_video_id'] = full_video_path.replace(self.mimic_eye_path, '')
        temp_dict['fixation_text'] = text
        temp_dict['findings'] = findings
        temp_dict['impression'] = impression
        temp_dict['differential_diagnosis'] = ddx
        temp_dict['split'] = split
        temp_dict['source'] = 'EG' if EG else 'REFLACX'
        return temp_dict

    def process_all(self):
        """Process every discovered study and write the two JSON split files.

        Records whose ``split`` is ``'train'`` go to ``mimic-eye-video-alpha.json``.
        All other records go to ``mimic-eye-video-beta.json``. Both files are written
        to the current working directory.
        """
        train_records = []
        val_test_records = []
        for bp in tqdm(self.patient_subjects, total=len(self.patient_subjects)):
            record = self.process_patient(bp)
            if record is None:
                continue
            if record['split'] == 'train':
                train_records.append(record)
            else:
                val_test_records.append(record)

        with open('mimic-eye-video-alpha.json', 'w') as fi:
            json.dump(train_records, fi)

        with open('mimic-eye-video-beta.json', 'w') as fi:
            json.dump(val_test_records, fi)

if __name__ == '__main__':
    import argparse

    # Command-line entry point: dataset roots are configurable, defaults mirror the class.
    cli = argparse.ArgumentParser()
    cli.add_argument("--mimic-eye-path", type=str, default='physionet.org/files/mimic-eye')
    cli.add_argument("--mimic-cxr-path", type=str, default='physionet.org/files/mimic-cxr/2.0.0')
    options = cli.parse_args()

    generator = GazeVideoGenerator(
        mimic_eye_path=options.mimic_eye_path,
        mimic_cxr_path=options.mimic_cxr_path,
    )
    generator.process_all()