import ast
import os

import cv2
import numpy as np
import pandas as pd
import pydicom
import SimpleITK as sitk

from ... import const

# Silence pandas' SettingWithCopyWarning for this module; it assigns new
# columns to frames obtained by column selection (e.g. `csv[columns]`),
# which would otherwise trigger the warning.
pd.set_option('mode.chained_assignment', None)

# Input locations, configured in the project-level `const` module.
KOA_RAW_DATA_DIR = const.KOA_RAW_DATA_DIR  # root containing 'dicoms/' (input) and 'pngs/' (output)
XRAY_JSW_PATH = const.XRAY_JSW_PATH  # '|'-delimited table with the JSW columns
XRAY_SQ_PATH = const.XRAY_SQ_PATH  # '|'-delimited table with the KLG / V01XROS* grades



def dicom_dataset_to_dict(dicom_header):
    '''
    Flatten a pydicom dataset into a plain Python dictionary.
    Based off of: https://github.com/cxr-eye-gaze/eye-gaze-dataset

    Rules, applied to each element of ``dicom_header``:
      * Pixel Data (tag (7FE0,0010)) is skipped.
      * A value that is itself a pydicom Dataset is converted recursively
        and stored under the element's *tag*.
      * For a value that is a pydicom Sequence, the elements of every item
        are stored directly in the result under their *name* (later items
        overwrite earlier ones with the same name); 'LUT Data' is skipped.
        Nested values are passed to ``_convert_value``, not recursed into.
      * Any other value is stored under the element's name after
        ``_convert_value``.

    :param dicom_header: input DICOM dataset
    :return: python dictionary of the DICOM header
    '''
    # repr() walks every element of the dataset. Kept from the original code,
    # presumably so that lazily-parsed raw elements are converted before
    # values() is read below -- do not remove without confirming.
    repr(dicom_header)

    flattened = {}
    for element in dicom_header.values():
        if element.tag == (0x7fe0, 0x0010):
            continue

        value_type = type(element.value)
        if value_type is pydicom.dataset.Dataset:
            flattened[element.tag] = dicom_dataset_to_dict(element.value)
        elif value_type is pydicom.sequence.Sequence:
            for item in element.value:
                for sub_element in item:
                    if sub_element.name == 'LUT Data':
                        continue
                    flattened[sub_element.name] = _convert_value(sub_element.value)
        else:
            flattened[element.name] = _convert_value(element.value)
    return flattened



def _convert_value(v):
    '''
    Convert a single DICOM element value into a plain Python value.
    Based off of: https://github.com/cxr-eye-gaze/eye-gaze-dataset

    Dispatch is on the *exact* type of ``v`` (subclasses do not match):
      * list / int / float  -> returned unchanged
      * str                 -> NUL characters removed, whitespace stripped
      * bytes               -> decoded as UTF-8 (undecodable bytes dropped),
                               then treated like str
      * pydicom DSfloat     -> float
      * pydicom IS          -> int
      * pydicom PersonName  -> str
      * anything else       -> repr(v)
    '''
    kind = type(v)
    if kind in (list, int, float):
        return v

    if kind is bytes:
        v = v.decode('utf-8', 'ignore')
        kind = str
    if kind is str:
        return v.replace(u"\u0000", "").strip()

    if kind is pydicom.valuerep.DSfloat:
        return float(v)
    if kind is pydicom.valuerep.IS:
        return int(v)
    if kind is pydicom.valuerep.PersonName:
        return str(v)
    return repr(v)



def apply_windowing(image, info):
    '''
    Auxiliary method to apply DICOM windowing.
    Based off of: https://github.com/cxr-eye-gaze/eye-gaze-dataset

    Without a 'Window Center' entry, the image is min-max normalised to
    0-255 and returned as uint8. Otherwise the first Window Center / Window
    Width values define the window [wc - ww/2, wc + ww/2] (lower bound
    clamped at 0). SimpleITK IntensityWindowing maps that window onto its
    default output range (0-255), and the result is returned as uint16.

    :param image: input image (numpy array)
    :param info: DICOM header dictionary (as built by dicom_dataset_to_dict)
    :return: image with windowing applied
    :raises KeyError: if 'Window Center' is present but 'Window Width' is not
    :raises ValueError/TypeError: if a window value cannot be parsed
    '''
    if 'Window Center' not in info:
        # No window in the header: stretch the full intensity range to 0-255.
        return cv2.normalize(image, dst=np.array([]), alpha=0, beta=255,
                             norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)

    # Centre and width are parsed independently. The old code decided how to
    # parse both from the type of the centre value alone.
    wc = _first_window_value(info['Window Center'])
    ww = _first_window_value(info['Window Width'])

    wl = wc - ww * 0.5
    if wl < 0:
        # Kept from the original implementation. Presumably because the
        # pixel data passed in is unsigned (callers cast to uint16) -- confirm.
        wl = 0
    wu = wc + ww * 0.5

    windowed = sitk.IntensityWindowing(sitk.GetImageFromArray(image), wl, wu)
    return np.asarray(sitk.GetArrayFromImage(windowed), np.uint16)


def _first_window_value(value):
    '''
    Return the first entry of a Window Center / Window Width value as a float.

    ``value`` may be a number, a list/tuple of numbers, or a string. Strings
    are presumably the repr() that _convert_value falls back to for
    multi-valued elements (e.g. "[1000, 2000]" or "['1000', '2000']"). They
    are parsed with ast.literal_eval. Plain eval() is never used, because
    header contents come from the input file. If parsing fails, brackets,
    quotes and spaces are stripped and the text is split on commas.
    '''
    if isinstance(value, str):
        try:
            value = ast.literal_eval(value.strip())
        except (ValueError, SyntaxError):
            cleaned = value
            for char in " '[]":
                cleaned = cleaned.replace(char, "")
            value = cleaned.split(",")[0]
    if isinstance(value, (list, tuple)):
        value = value[0]
    return float(value)



def resize_pad(image, height=1200, width=600):
    '''
    Resize an image, keeping its aspect ratio, and pad it with black to a fixed canvas.
    Based off of: https://github.com/cxr-eye-gaze/eye-gaze-dataset

    The longest side is scaled to ``width`` pixels with nearest-neighbour
    interpolation, then black padding is added evenly on both sides. Any
    odd pixel goes to the bottom / right.

    NOTE: despite the parameter names, the returned canvas has ``width``
    rows and ``height`` columns (600 x 1200 with the defaults). The padding
    amounts must be non-negative, so ``height`` must be at least the scaled
    column count.

    :param image: input image to resize and pad
    :param height: number of columns of the padded result
    :param width: number of rows of the padded result; also the length the
        longest side is scaled to
    :return: resized and padded image
    '''
    rows, cols = image.shape[:2]
    scale = float(width) / max(rows, cols)
    new_rows, new_cols = (int(n * scale) for n in (rows, cols))

    resized = cv2.resize(image, (new_cols, new_rows), interpolation=cv2.INTER_NEAREST)

    pad_rows = width - new_rows
    pad_cols = height - new_cols
    top = pad_rows // 2
    left = pad_cols // 2
    return cv2.copyMakeBorder(
        resized,
        top, pad_rows - top,
        left, pad_cols - left,
        cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )



def dicomToPng(dcm_name, dcm_dir, new_dir):
    '''
    Convert one knee X-ray DICOM into three 8-bit PNG files.

    Pipeline: read the DICOM, apply windowing (apply_windowing), resize/pad
    (resize_pad), trim to the bounding box of the non-zero pixels, then take
    a 256 x 512 centre crop. Writes into ``new_dir``:
      * <dcm_name>_full.png  -- the whole crop
      * <dcm_name>_right.png -- the LEFT half of the crop
      * <dcm_name>_left.png  -- the RIGHT half of the crop
    The half/side swap presumably follows the radiographic convention that
    the patient's right appears on the viewer's left -- confirm against the
    acquisition protocol.

    :param dcm_name: file name of the DICOM inside ``dcm_dir``; also the
        prefix of the output PNG names
    :param dcm_dir: directory containing the DICOM file
    :param new_dir: existing directory where the PNGs are written
    :return: 1 if all three PNGs were written, 0 otherwise (errors are
        printed and the sample is skipped)
    '''
    crop_h, crop_w = 256, 512
    try:
        dicom_img = pydicom.dcmread(os.path.join(dcm_dir, dcm_name))

        dictionary = dicom_dataset_to_dict(dicom_img)
        png_img = dicom_img.pixel_array.copy().astype(np.uint16)
        png_img = apply_windowing(png_img, dictionary)
        png_img = resize_pad(png_img)

        # Remove black borders: keep the bounding box of the non-zero pixels.
        # The end index is exclusive, hence +1 to keep the last non-zero row/column.
        y_nonzero, x_nonzero = np.nonzero(png_img)
        if y_nonzero.size == 0:
            print(f"DICOM {dcm_name} is blank after preprocessing... will ignore sample")
            return 0
        png_img = png_img[y_nonzero.min():y_nonzero.max() + 1,
                          x_nonzero.min():x_nonzero.max() + 1]

        # Centre crop. Clamp the offsets at 0: a negative slice start would
        # wrap around and take pixels from the far edge of a small image.
        rows, cols = png_img.shape[:2]
        y = max(rows // 2 - crop_h // 2, 0)
        x = max(cols // 2 - crop_w // 2, 0)
        png_img = png_img[y:y + crop_h, x:x + crop_w]

        half = png_img.shape[1] // 2
        right_knee = png_img[:, :half]
        left_knee = png_img[:, half:]

        expected = (crop_h, crop_w // 2)
        if right_knee.shape[:2] != expected:
            print(f"Wrong shape right knee {right_knee.shape}")
        if left_knee.shape[:2] != expected:
            print(f"Wrong shape left knee {left_knee.shape}")

        outputs = (
            ('_full.png', png_img),
            ('_left.png', left_knee),
            ('_right.png', right_knee),
        )
        for suffix, img in outputs:
            # cv2.imwrite reports failure through its return value, not an exception.
            if not cv2.imwrite(os.path.join(new_dir, dcm_name + suffix), img.astype(np.uint8)):
                print(f"Could not write {dcm_name + suffix}... will ignore sample")
                return 0
    except Exception as err:
        # Best effort: a single bad DICOM must not abort the whole conversion.
        print(f"Issue processing DICOM {dcm_name} ({err})... will ignore sample")
        return 0

    return 1



def koa_dicom_to_png():
    '''
    Convert all dicoms to pngs. Also create a CSV file that contains all of
    the necessary data to generate the koa datasets.

    Steps:
      1. Check that KOA_RAW_DATA_DIR and KOA_RAW_DATA_DIR/dicoms exist, and
         create KOA_RAW_DATA_DIR/pngs plus one sub-directory per patient ID.
      2. Convert every DICOM with dicomToPng. DICOMs whose file name is not
         an integer barcode, or whose barcode is not in the JSW table, are
         reported and skipped.
      3. Add KLG and the four V01XROS* grades from the SQ table to the JSW
         columns, drop rows containing NaN, replace barcode/side with a
         LOCATION column, print a KLG summary and write
         KOA_RAW_DATA_DIR/full_koa_dataset.csv.

    :raises FileNotFoundError: if KOA_RAW_DATA_DIR or its 'dicoms'
        sub-directory does not exist
    '''
    # Make sure necessary directories exist
    if not os.path.exists(KOA_RAW_DATA_DIR):
        raise FileNotFoundError('Directory ' + KOA_RAW_DATA_DIR + ' does not exist.')
    dicom_dir = os.path.join(KOA_RAW_DATA_DIR, 'dicoms')
    if not os.path.exists(dicom_dir):
        raise FileNotFoundError('Directory ' + dicom_dir + ' does not exist.')
    png_dir = os.path.join(KOA_RAW_DATA_DIR, 'pngs')
    os.makedirs(png_dir, exist_ok=True)

    # Load the JSW table and keep only the needed columns. copy() gives an
    # independent frame for the column assignments below.
    columns = ['ID',
               'V01BARCDJD',
               'SIDE',
               'V01JSW150',
               'V01JSW175',
               'V01JSW200',
               'V01JSW225',
               'V01JSW250',
               'V01JSW275',
               'V01JSW300',
               'V01LJSW700',
               'V01LJSW725',
               'V01LJSW750',
               'V01LJSW775',
               'V01LJSW800',
               'V01LJSW825',
               'V01LJSW850',
               'V01LJSW875',
               'V01LJSW900'
               ]
    csv = pd.read_csv(XRAY_JSW_PATH, delimiter='|')[columns].copy()

    # Create a sub-directory for each patient (each ID appears once per knee)
    for patient_id in csv['ID'].drop_duplicates():
        os.makedirs(os.path.join(png_dir, str(patient_id)), exist_ok=True)

    # Process DICOM images to PNG images
    count = _convert_dicoms(csv, dicom_dir, png_dir)
    print(f"Successfully converted {count} dicoms to png images")

    csv = _attach_sq_grades(csv)

    # Remove any rows with NaN values
    csv = csv.dropna()

    # Replace barcode and side columns with the PNG location of each knee
    file_locations = [_png_location(row, png_dir) for _, row in csv.iterrows()]
    csv = csv.drop(columns=['V01BARCDJD', 'SIDE'])
    csv['LOCATION'] = file_locations

    # Get basic dataset info
    print(f'Total dataset size is:  {csv.shape[0]}')
    for grade in range(5):
        print(f'Number of samples with KLG = {grade}: {csv.loc[csv["KLG"] == grade].shape[0]}')

    # Save the csv file
    csv.to_csv(os.path.join(KOA_RAW_DATA_DIR, 'full_koa_dataset.csv'))


def _convert_dicoms(csv, dicom_dir, png_dir):
    '''Convert each DICOM in dicom_dir into png_dir/<patient ID>/; return the number converted.'''
    # Build the barcode -> patient ID lookup once. The first row wins for a
    # repeated barcode, as the previous per-file .iloc[0] lookup did.
    barcode_to_id = {}
    for barcode, patient_id in zip(csv['V01BARCDJD'], csv['ID']):
        barcode_to_id.setdefault(barcode, patient_id)

    count = 0
    for dicom in os.listdir(dicom_dir):
        try:
            barcode = int(dicom)
        except ValueError:
            print(f"Skipping {dicom}: file name is not a numeric barcode")
            continue
        if barcode not in barcode_to_id:
            print(f"Skipping {dicom}: barcode not found in {XRAY_JSW_PATH}")
            continue
        new_dir = os.path.join(png_dir, str(barcode_to_id[barcode]))
        count += dicomToPng(dicom, dicom_dir, new_dir)
    return count


def _attach_sq_grades(csv):
    '''Copy KLG and the four V01XROS* grades from the SQ table onto matching (barcode, side) rows.'''
    sq = pd.read_csv(XRAY_SQ_PATH, delimiter='|')
    grade_columns = ['V01XROSFL', 'V01XROSFM', 'V01XROSTL', 'V01XROSTM']

    # 5 is a placeholder for "no matching SQ row". Unlike NaN it is NOT
    # removed by the later dropna(), so such knees stay in the dataset.
    csv['KLG'] = 5
    for col in grade_columns:
        csv[col] = 5

    # Rows are applied in file order, so a later duplicate overwrites an earlier one.
    for index, row in sq.iterrows():
        try:
            barcode = int(row['V01BARCDBU'])
        except (TypeError, ValueError):
            print(f"Ran into error -- going to ignore SQ row {index} (barcode {row['V01BARCDBU']!r})")
            continue
        mask = (csv['V01BARCDJD'] == barcode) & (csv['SIDE'] == row['SIDE'])
        csv.loc[mask, 'KLG'] = row['V01XRKL']
        for col in grade_columns:
            # A missing grade is recorded as 0. A missing KLG stays NaN and is dropped later.
            csv.loc[mask, col] = 0 if pd.isna(row[col]) else row[col]
    return csv


def _png_location(row, png_dir):
    '''Return the path of the per-knee PNG that dicomToPng writes for this row.'''
    # The leading '0' presumably restores zero padding of the DICOM file names
    # that is lost when the barcode is parsed as a number -- confirm against 'dicoms/'.
    barcode = '0' + str(int(row['V01BARCDJD']))
    side = 'right' if (row['SIDE'] == 1) else 'left'
    return os.path.join(png_dir, str(int(row['ID'])), barcode + '_' + side + '.png')



if __name__ == "__main__":
    # Running the module as a script performs the full DICOM -> PNG/CSV conversion.
    koa_dicom_to_png()