import os
import random
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
from torchvision import transforms


class MMNestedDataset(torch.utils.data.Dataset):
    """
    Patient-level ("nested") multimodal dataset.

    Each item is one patient paired with all (or up to ``N``) of that patient's
    lesions: per-lesion tabular features, images and malignancy targets. The
    number of lesions varies between patients, so items are meant to be batched
    with ``mmnested_collate_fn``.
    """

    # Extensions probed, in order, when looking up a lesion image on disk.
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.bmp')

    def __init__(self, patient_csv, lesion_csv, patient_target_csv, lesion_target_csv, image_dir,
                 transform=None, N=None, val_split=0.1, split='all', random_state=42):
        """
        Initializes the custom dataset.
        Args:
        - patient_csv: Path to the patient data CSV file (must contain 'patient_id').
        - lesion_csv: Path to the lesion data CSV file (must contain 'lesion_id' and 'patient_id').
        - patient_target_csv: Path to the patient target CSV file (must contain 'patient_id'
          and a 'risk_id' or 'risk' column).
        - lesion_target_csv: Path to the lesion target CSV file (must contain 'lesion_id'
          and 'malignant').
        - image_dir: Directory where lesion images (<lesion_id>.jpg/.png/.bmp) are stored.
        - transform: Callable applied to each RGB PIL image; defaults to transforms.ToTensor().
        - N: Max number of lesions per patient (None or 0 means no limit). When a patient
          has more than N lesions, malignant lesions are kept preferentially.
        - val_split: Proportion of patients used for validation (only relevant if split is
          'train' or 'val'; a value <= 0 keeps all patients).
        - split: One of {'train', 'val', 'all'}.
        - random_state: Seed for the per-risk-class train/val split and for shuffling the
          selected patients.
        """
        assert split in ['train', 'val', 'all'], "split must be one of 'train', 'val', 'all'"

        # Load patient, lesion, and target data
        self.patient_data = pd.read_csv(patient_csv)
        self.lesion_data = pd.read_csv(lesion_csv)
        self.patient_target_data = pd.read_csv(patient_target_csv)
        self.lesion_target_data = pd.read_csv(lesion_target_csv)

        # Set patient_id as the index for easier access
        self.patient_data.set_index('patient_id', inplace=True)
        self.patient_target_data.set_index('patient_id', inplace=True)

        # Store image directory and transformations
        self.image_dir = image_dir
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.N = N  # Maximum number of lesions to sample per patient
        self.split = split
        self.val_split = val_split

        # Lesion rows grouped by patient, for per-patient lookup in __getitem__.
        self.patient_to_lesions = self.lesion_data.groupby('patient_id')

        # lesion_id -> malignant label, built once instead of scanning the whole
        # target table for every lesion of every item. keep='first' selects the
        # same row a first-match scan would when a lesion_id is repeated, and
        # makes the index unique so it can be used with Series.map.
        self._lesion_malignant = (
            self.lesion_target_data
            .drop_duplicates('lesion_id', keep='first')
            .set_index('lesion_id')['malignant']
        )

        # Handle train/val splitting, done separately per risk class so both
        # classes appear in each split in roughly the same proportion.
        if split in ['train', 'val'] and val_split > 0:
            train_pos, val_pos = self._split_ids(self.get_patient_ids_by_risk(1), val_split, random_state)
            train_neg, val_neg = self._split_ids(self.get_patient_ids_by_risk(0), val_split, random_state)

            if split == 'train':
                selected_ids = train_pos + train_neg
            else:  # split == 'val'
                selected_ids = val_pos + val_neg

            # Seeded shuffle: the idx -> patient mapping is reproducible across
            # runs, just like the split itself.
            random.Random(random_state).shuffle(selected_ids)
            self.patient_data = self.patient_data.loc[selected_ids]
            self.patient_target_data = self.patient_target_data.loc[selected_ids]
        # else (split == 'all', or val_split <= 0): keep all patients

        self.patient_ids = list(self.patient_data.index)

    def __len__(self):
        return len(self.patient_data)

    @staticmethod
    def _split_ids(ids, val_split, random_state):
        """Split ``ids`` into (train_ids, val_ids) lists; an empty input yields two empty lists."""
        if not ids:
            # train_test_split raises on zero samples, but a cohort may lack one risk class.
            return [], []
        train_ids, val_ids = train_test_split(ids, test_size=val_split, random_state=random_state)
        return list(train_ids), list(val_ids)

    def _get_risk_column_name(self):
        """
        Determines whether 'risk_id' or 'risk' is present in the patient_target_data.
        Returns the name of the column to use for patient targets ('risk_id' wins if both exist).
        Raises ValueError if neither column exists.
        """
        if 'risk_id' in self.patient_target_data.columns:
            return 'risk_id'
        elif 'risk' in self.patient_target_data.columns:
            return 'risk'
        else:
            raise ValueError("Neither 'risk_id' nor 'risk' column found in patient_target_data.")

    def get_patient_ids_by_risk(self, risk_value):
        """
        Get a list of patient_ids by risk value (1 for positive, 0 for negative).
        This method will filter the patient_target_data DataFrame to return the
        patient IDs that match the specified risk value.
        """
        risk_column = self._get_risk_column_name()
        return self.patient_target_data[self.patient_target_data[risk_column] == risk_value].index.tolist()

    def _find_image_path(self, lesion_id):
        """Return the first existing ``<image_dir>/<lesion_id><ext>`` over IMAGE_EXTENSIONS."""
        for ext in self.IMAGE_EXTENSIONS:
            candidate = os.path.join(self.image_dir, f"{lesion_id}{ext}")
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(
            f"No image found for lesion ID {lesion_id} with extension {', '.join(self.IMAGE_EXTENSIONS)}")

    def _limit_lesions(self, lesions, is_positive):
        """
        Randomly subsample ``lesions`` down to ``self.N`` rows, preferring malignant ones.

        If there are at least N positive rows, N are drawn from them. Otherwise every
        positive row is kept and the remainder is drawn from the negatives. With no
        positives, N rows are drawn from all lesions.
        """
        lesions_pos = lesions[is_positive]
        if lesions_pos.empty:
            return lesions.sample(n=self.N)
        if len(lesions_pos) >= self.N:
            return lesions_pos.sample(n=self.N)
        # Count positives from the rows actually present so the result has exactly N rows.
        lesions_neg = lesions[~is_positive].sample(n=self.N - len(lesions_pos))
        return pd.concat([lesions_pos, lesions_neg])

    def __getitem__(self, idx):
        """
        Retrieve a single patient sample and corresponding lesion data.

        Returns:
        ((patient_tabular, patient_target), (lesion_tabular, lesion_images, lesion_targets))
        - patient_tabular: float32 tensor [1, n_patient_features]
        - patient_target: float32 tensor [1] holding the patient's risk value
        - lesion_tabular: float32 tensor [n_lesions, n_lesion_features]
        - lesion_images: the transformed images stacked along a new leading lesion dim
        - lesion_targets: float32 tensor [n_lesions, 1] (0 for lesions missing from the
          lesion target CSV)

        Raises:
        - KeyError if the patient has no rows in the lesion CSV.
        - FileNotFoundError if a lesion has no .jpg/.png/.bmp image in image_dir.
        """
        patient_id = self.patient_ids[idx]

        # Patient-level tabular data (patient_id is the index, so it is not among the values)
        patient_tabular = torch.tensor(self.patient_data.loc[patient_id].drop('patient_id', errors='ignore').values, dtype=torch.float32)

        risk_column = self._get_risk_column_name()
        patient_target = torch.tensor([self.patient_target_data.loc[patient_id][risk_column]], dtype=torch.float32)

        lesions = self.patient_to_lesions.get_group(patient_id)

        # Optionally limit the number of lesions to N, keeping malignant lesions first.
        # Positivity uses the same per-lesion label that is returned as the target.
        if self.N and len(lesions) > self.N:
            is_positive = lesions['lesion_id'].map(self._lesion_malignant) == 1.0
            lesions = self._limit_lesions(lesions, is_positive)

        lesion_tabular_list = []
        lesion_image_list = []
        lesion_target_list = []

        # Take ids from the column itself: iterrows() upcasts each row to one dtype,
        # which turns integer ids into floats (e.g. "12.0.jpg") when all columns are numeric.
        lesion_ids = lesions['lesion_id'].tolist()
        for lesion_id, (_, lesion_row) in zip(lesion_ids, lesions.iterrows()):
            # Tabular features: every column except the identifiers; non-numeric values become 0.
            lesion_data = lesion_row.drop(['lesion_id', 'patient_id'], errors='ignore')
            lesion_data_numeric = pd.to_numeric(lesion_data, errors='coerce').fillna(0)
            lesion_tabular_list.append(torch.tensor(lesion_data_numeric.values, dtype=torch.float32))

            image_path = self._find_image_path(lesion_id)
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            lesion_image_list.append(self.transform(image))

            lesion_target_value = self._lesion_malignant.get(lesion_id, 0)
            lesion_target_list.append(torch.tensor([lesion_target_value], dtype=torch.float32))

        # Return the individual sample (without batching)
        lesion_tabular_tensor = torch.stack(lesion_tabular_list)
        lesion_image_tensor = torch.stack(lesion_image_list)
        lesion_target_tensor = torch.stack(lesion_target_list)

        return (patient_tabular.unsqueeze(0), patient_target), (lesion_tabular_tensor, lesion_image_tensor, lesion_target_tensor)
    
def mmnested_collate_fn(batch):
    """
    Collate ``MMNestedDataset`` items into a batch.

    Patient-level tensors have a fixed shape and are concatenated along dim 0.
    Lesion-level tensors differ in length from patient to patient, so they are
    returned as plain Python lists (one entry per patient, in batch order).

    Returns:
    ((patient_data, patient_targets), (lesion_tabular_list, lesion_image_list, lesion_target_list))
    """
    samples = [
        (p_data, p_target, l_tabular, l_image, l_target)
        for (p_data, p_target), (l_tabular, l_image, l_target) in batch
    ]

    # Fixed-shape patient tensors; MMNestedDataset yields [1, F] features and [1]
    # targets per item, giving [B, F] and [B] after concatenation.
    patient_data = torch.cat([sample[0] for sample in samples], dim=0)
    patient_targets = torch.cat([sample[1] for sample in samples], dim=0)

    # Variable-length lesion tensors stay as per-patient lists.
    lesion_tabular_data = [sample[2] for sample in samples]
    lesion_image_data = [sample[3] for sample in samples]
    lesion_targets = [sample[4] for sample in samples]

    return (patient_data, patient_targets), (lesion_tabular_data, lesion_image_data, lesion_targets)

class MMFlattenDataset(torch.utils.data.Dataset):
    """
    Lesion-level ("flattened") multimodal dataset: one item per lesion.

    The item layout depends on which data is available and on the
    ``inner_only`` / ``image_only`` flags (see ``__getitem__``).
    """

    def __init__(self, patient_csv, lesion_csv, patient_target_csv, lesion_target_csv, image_dir,
                 transform=None, inner_only=False, image_only=False,
                 val_split=0.1, split='all', random_state=42, stratify=False):
        """
        Args:
        - patient_csv, patient_target_csv: Paths to patient data / targets (indexed by
          'patient_id'). Patient features are used only when BOTH paths are given.
        - lesion_csv: Path to lesion data (must contain 'lesion_id'; its 'patient_id'
          column links a lesion to its patient).
        - lesion_target_csv: Path to lesion targets (must contain 'lesion_id'; a
          'malignant' column is required when stratify=True).
        - image_dir: Directory holding <lesion_id>.jpg/.png/.bmp images.
        - transform: Callable applied to each RGB PIL image; defaults to transforms.ToTensor().
        - inner_only: Leave patient features out of items even when they are loaded.
        - image_only: Return only (image, lesion_target) items.
        - val_split: Fraction of lesions held out for validation ('train'/'val' only).
        - split: One of {'train', 'val', 'all'}.
        - random_state: Seed for the train/val split.
        - stratify: Stratify the train/val split on the lesion 'malignant' label.
        """
        assert split in ['train', 'val', 'all'], "split must be one of 'train', 'val', 'all'"
        self.image_dir = image_dir
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.inner_only = inner_only
        self.image_only = image_only

        # Load lesion-level data
        self.lesion_data = pd.read_csv(lesion_csv)
        self.lesion_target_data = pd.read_csv(lesion_target_csv)
        self.lesion_data.set_index('lesion_id', inplace=True)
        self.lesion_target_data.set_index('lesion_id', inplace=True)

        # Load patient-level data
        self.has_patient_data = patient_csv is not None and patient_target_csv is not None
        if self.has_patient_data:
            self.patient_data = pd.read_csv(patient_csv)
            self.patient_target_data = pd.read_csv(patient_target_csv)
            self.patient_data.set_index('patient_id', inplace=True)
            self.patient_target_data.set_index('patient_id', inplace=True)

        if split in ['train', 'val'] and val_split > 0:
            # NOTE(review): this split is per lesion, so lesions of one patient can land
            # in both train and val (MMNestedDataset splits by patient instead). With
            # patient-level features this may leak information — confirm it is intended.
            lesion_ids = self.lesion_data.index.tolist()
            indices = list(range(len(lesion_ids)))
            if stratify:
                # Use malignancy label for stratification
                targets = self.lesion_target_data.loc[lesion_ids]['malignant'].values
                train_indices, val_indices = train_test_split(
                    indices,
                    test_size=val_split,
                    stratify=targets,
                    random_state=random_state)
            else:
                train_indices, val_indices = train_test_split(
                    indices, test_size=val_split, random_state=random_state)

            selected_indices = train_indices if split == 'train' else val_indices
            self.lesion_ids = [lesion_ids[i] for i in selected_indices]
        else:
            self.lesion_ids = list(self.lesion_data.index)

    def __len__(self):
        return len(self.lesion_ids)

    def _find_image_path(self, lesion_id):
        """Return the first existing ``<image_dir>/<lesion_id>`` path with a .jpg, .png or .bmp suffix."""
        for ext in ('.jpg', '.png', '.bmp'):
            image_path = os.path.join(self.image_dir, f"{lesion_id}{ext}")
            if os.path.exists(image_path):
                return image_path
        raise FileNotFoundError(f"No image found for lesion ID {lesion_id}")

    def __getitem__(self, idx):
        """
        Retrieve a single lesion sample, with its patient's features when configured.

        Returns, depending on configuration:
        - image_only:                              (image, lesion_target)
        - patient data loaded and not inner_only:  (lesion_tabular, patient_tabular, image, lesion_target)
        - otherwise:                               (lesion_tabular, image, lesion_target)

        Raises:
        - FileNotFoundError if no .jpg/.png/.bmp image exists for the lesion.
        """
        # Get the lesion_id and its row of lesion-level data
        lesion_id = self.lesion_ids[idx]
        lesion_row = self.lesion_data.loc[lesion_id]

        # Lesion-level tabular features (lesion_id is the index; patient_id is excluded)
        lesion_tabular = torch.tensor(pd.to_numeric(lesion_row.drop('patient_id', errors='ignore').values), dtype=torch.float32)

        # Patient-level tabular features (if available)
        patient_tabular = None
        if self.has_patient_data:
            patient_id = lesion_row['patient_id']
            patient_tabular = torch.tensor(self.patient_data.loc[patient_id].values, dtype=torch.float32)

        image_path = self._find_image_path(lesion_id)
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)

        lesion_target = torch.tensor(pd.to_numeric(self.lesion_target_data.loc[lesion_id].drop('patient_id', errors='ignore').values), dtype=torch.float32)

        # image_only is checked first so the flag is honoured whether or not patient
        # data was loaded and regardless of inner_only.
        if self.image_only:
            return image, lesion_target
        if self.has_patient_data and not self.inner_only:
            return lesion_tabular, patient_tabular, image, lesion_target
        return lesion_tabular, image, lesion_target

class FeatureDataset(torch.utils.data.Dataset):
    """
    Index-aligned view over precomputed feature collections.

    Item ``i`` is ``(lesion_feats[i], patient_feats[i], image_feats[i], targets[i])``;
    the dataset length is taken from ``targets``.
    """

    def __init__(self, lesion_feats, patient_feats, image_feats, targets):
        self.lesion_feats, self.patient_feats, self.image_feats, self.targets = (
            lesion_feats, patient_feats, image_feats, targets
        )

    def __len__(self):
        # The targets collection defines how many samples there are.
        return len(self.targets)

    def __getitem__(self, idx):
        sources = (self.lesion_feats, self.patient_feats, self.image_feats, self.targets)
        return tuple(source[idx] for source in sources)
