import os
from enum import Enum
import numpy as np
import pandas as pd
import PIL
import torch
from torchvision import transforms
import random
random.seed(0)
from constants import *

# Class names FundusDataset falls back to when it is constructed with
# classname=None.
_CLASSNAMES = [
    "fundus",
]

# Per-channel mean / std handed to transforms.Normalize by FundusDataset
# (presumably the standard ImageNet RGB statistics -- confirm if the backbone
# was pretrained with different normalisation).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class DatasetSplit(Enum):
    """Splits that FundusDataset knows how to load.

    TRAIN and VAL (and REMAIN_5000) are built from the EDDFS and BRSET index
    CSVs; JSIEC and RIADD each load a single external test CSV.
    """

    TRAIN = "train"
    VAL = "val"
    JSIEC = "JSIEC"
    RIADD = "RIADD"
    # FundusDataset.__init__ compares against this member (it loads the
    # 'train_5000_remain.csv' index files for it). Without the member every
    # non-TRAIN split raised AttributeError while evaluating that comparison.
    # Appended last so the iteration order of the existing members is unchanged.
    REMAIN_5000 = "remain_5000"


class FundusDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for Fundus.

    On construction the dataset reads the CSV index file(s) belonging to the
    requested split and builds two parallel sequences: ``img_paths`` (absolute
    image paths) and ``targets`` (0 for normal, 1 for abnormal). Normal images
    always precede abnormal ones. Items are returned as dicts holding the
    resized / centre-cropped / normalised image tensor plus label metadata.
    """

    def __init__(
        self,
        source,
        classname,
        resize=224,
        imagesize=224,
        split=DatasetSplit.TRAIN,
        train_val_split=1.0,
        rotate_degrees=0,
        translate=0,
        brightness_factor=0,
        contrast_factor=0,
        saturation_factor=0,
        gray_p=0,
        h_flip_p=0,
        v_flip_p=0,
        scale=0,
        train_scale=1,
        category='normal',
        **kwargs,
    ):
        """
        Args:
            source: Stored on the instance as ``self.source``; not otherwise
                used here (image locations are hard-coded below).
            classname: Single class name to use, or None to fall back to
                ``_CLASSNAMES``.
            resize: Size passed to ``transforms.Resize``.
            imagesize: Size passed to ``transforms.CenterCrop``; also defines
                ``self.imagesize = (3, imagesize, imagesize)``.
            split: Which ``DatasetSplit`` to load.
            train_val_split: Stored on the instance; not used for loading.
            rotate_degrees, translate, brightness_factor, contrast_factor,
                saturation_factor, gray_p, h_flip_p, v_flip_p, scale:
                Accepted for call-site compatibility; ignored by this class.
            train_scale: Selects the training index ``train_for_{train_scale}.csv``.
            category: ``'normal'`` or ``'abnormal'``; chooses which rows of the
                training CSVs are kept. Only consulted for the TRAIN split.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``split`` is not a supported split, or if
                ``category`` is invalid while loading the TRAIN split.
        """
        super().__init__()
        self.source = source
        self.split = split
        self.classnames_to_use = [classname] if classname is not None else _CLASSNAMES
        self.train_val_split = train_val_split
        self.transform_std = IMAGENET_STD
        self.transform_mean = IMAGENET_MEAN
        self.eddfs_image_dir = '/root/raw_dataset/EDDFS/PreprocessedImages'
        self.brset_image_dir = '/root/raw_dataset/brazilian-ophthalmological/1.0.1/fundus_photos_preprocessed'
        self.train_scale = train_scale

        # Looked up defensively: if DatasetSplit does not declare REMAIN_5000,
        # a plain attribute access here would raise AttributeError and take
        # down every split that is checked after it (VAL, RIADD, JSIEC).
        remain_split = getattr(DatasetSplit, "REMAIN_5000", None)

        if self.split == DatasetSplit.TRAIN:
            # Validate once, before touching the filesystem.
            if category == 'normal':
                wanted_label = 0
            elif category == 'abnormal':
                wanted_label = 1
            else:
                raise ValueError(f"Invalid category: {category}")

            eddfs_rows = pd.read_csv(os.path.join(EDDFS_tgt_root, f'train_for_{train_scale}.csv'))
            eddfs_rows = eddfs_rows[eddfs_rows['abnormal'] == wanted_label]
            brset_rows = pd.read_csv(os.path.join(BRSET_tgt_root, f'train_for_{train_scale}.csv'))
            brset_rows = brset_rows[brset_rows['abnormal'] == wanted_label]

            self.img_paths = self._eddfs_paths(eddfs_rows) + self._brset_paths(brset_rows)
            # NOTE(review): training targets are all zeros ("good") even when
            # category == 'abnormal'. Kept as in the original; confirm this is
            # the intended labelling for abnormal-only training.
            self.targets = np.zeros(len(self.img_paths))

        elif remain_split is not None and self.split == remain_split:
            self._load_eddfs_brset_split([
                (os.path.join(EDDFS_tgt_root, 'train_5000_remain.csv'), self._eddfs_paths),
                (os.path.join(BRSET_tgt_root, 'train_5000_remain.csv'), self._brset_paths),
            ])

        elif self.split == DatasetSplit.VAL:
            # Validation also folds in the BRSET test index.
            self._load_eddfs_brset_split([
                (os.path.join(EDDFS_tgt_root, 'val.csv'), self._eddfs_paths),
                (os.path.join(BRSET_tgt_root, 'val.csv'), self._brset_paths),
                (os.path.join(BRSET_tgt_root, 'test.csv'), self._brset_paths),
            ])

        elif self.split == DatasetSplit.RIADD:
            self.image_dir = '/root/raw_dataset/RIADD'
            self._load_external_split(
                '/root/dataset/fundus/RIADD/test.csv',
                dir_column='dir',
                name_column='ID',
                suffix='.png',
            )

        elif self.split == DatasetSplit.JSIEC:
            self.image_dir = '/root/raw_dataset/JSIEC/1000images'
            self._load_external_split(
                '/root/dataset/fundus/JSIEC/test.csv',
                dir_column='dirs',
                name_column='fnames',
                suffix='',
            )
        else:
            raise ValueError(f"Invalid split: {self.split}")

        self.labels = ["good" if x == 0 else "bad" for x in self.targets]
        print(len(self.labels))
        self.data_to_iterate = list(zip(self.img_paths, self.labels))

        self.transform_img = [
            transforms.Resize(resize),
            transforms.CenterCrop(imagesize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
        self.transform_img = transforms.Compose(self.transform_img)

        self.imagesize = (3, imagesize, imagesize)

    @staticmethod
    def _split_by_label(frame):
        """Return (normal_rows, abnormal_rows) of *frame* using its 'abnormal' column."""
        return frame[frame['abnormal'] == 0], frame[frame['abnormal'] == 1]

    def _eddfs_paths(self, frame):
        """Build EDDFS image paths from the 'fnames' column of *frame*."""
        return [os.path.join(self.eddfs_image_dir, fname) for fname in frame['fnames']]

    def _brset_paths(self, frame):
        """Build BRSET image paths ('<image_id>.jpg') from the 'image_id' column of *frame*."""
        return [
            os.path.join(self.brset_image_dir, image_id + '.jpg')
            for image_id in frame['image_id']
        ]

    def _load_eddfs_brset_split(self, sources):
        """Populate img_paths/targets from several (csv_path, path_builder) pairs.

        Normal paths of all sources (in the given order) come first, followed
        by the abnormal paths of all sources in the same order.
        """
        normal_paths, abnormal_paths = [], []
        for csv_path, build_paths in sources:
            normal_rows, abnormal_rows = self._split_by_label(pd.read_csv(csv_path))
            normal_paths += build_paths(normal_rows)
            abnormal_paths += build_paths(abnormal_rows)

        self.img_paths = normal_paths + abnormal_paths
        self.targets = np.concatenate(
            (np.zeros(len(normal_paths)), np.ones(len(abnormal_paths)))
        )

    def _load_external_split(self, csv_path, dir_column, name_column, suffix):
        """Populate img_paths/targets from one CSV whose rows name a sub-directory and a file.

        Paths are ``self.image_dir / <dir_column> / <name_column><suffix>``.
        """
        normal_rows, abnormal_rows = self._split_by_label(pd.read_csv(csv_path))

        def build_paths(rows):
            return [
                os.path.join(self.image_dir, sub_dir, str(name) + suffix)
                for sub_dir, name in zip(rows[dir_column], rows[name_column])
            ]

        normal_paths = build_paths(normal_rows)
        abnormal_paths = build_paths(abnormal_rows)
        self.img_paths = normal_paths + abnormal_paths
        # Built from Python ints (unlike the zeros/ones arrays used for the
        # EDDFS/BRSET splits) to keep the original integer targets here.
        self.targets = np.array(len(normal_paths) * [0] + len(abnormal_paths) * [1])

    def __getitem__(self, idx):
        """Load item *idx* and return a dict with the image tensor and its label metadata.

        The returned "mask" is an all-ones tensor of shape (1, H, W) matching
        the transformed image; no real segmentation mask is loaded.
        """
        classname = 'fundus'
        anomaly = 'good' if self.targets[idx] == 0 else 'bad'
        image_path = self.img_paths[idx]

        # convert() loads the pixel data into a new image, so the underlying
        # file can be closed right away instead of lingering until GC.
        with PIL.Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transform_img(image)

        mask = torch.ones([1, *image.size()[1:]])

        return {
            "image": image,
            "mask": mask,
            "classname": classname,
            "anomaly": anomaly,
            "is_anomaly": int(anomaly != "good"),
            "image_name": os.path.basename(image_path),
            "image_path": image_path,
        }

    def __len__(self):
        """Return the number of images in the loaded split."""
        return len(self.img_paths)