import os
import pickle
import random
import sys
import warnings
from copy import deepcopy

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset


class PickleDataset(Dataset):
    """Dataset served from a single pickle file.

    The file must unpickle to a three-item sequence
    ``(data, classes, class_to_idx)``. The three parts are stored on the
    instance under those names, and samples are returned from ``data``
    unchanged.
    """

    def __init__(self, file_path):
        """Load the pickled payload found at ``file_path``."""
        # NOTE(review): unpickling can execute arbitrary code; this presumably
        # only ever reads files produced by the project itself -- confirm.
        with open(file_path, 'rb') as handle:
            payload = pickle.load(handle)
        self.data, self.classes, self.class_to_idx = payload

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the stored sample at position ``idx`` as-is."""
        return self.data[idx]

class HAM10000DatasetBalanced(Dataset):
    """Image-classification dataset driven by a HAM10000-style metadata CSV.

    Every CSV row describes one sample. Columns are read by position: the
    second column is the image file stem (``<stem>.jpg`` inside ``img_dir``)
    and the third column is the diagnosis code, which is mapped to an integer
    through ``class_to_idx``.

    NOTE(review): despite the class name, nothing in this class re-balances
    the class distribution -- presumably the CSV is balanced upstream; confirm.
    """

    def __init__(self, csv_file="/data1/home/ict04/data/surgical/HAM10000/HAM10000_metadata.csv", 
                                img_dir="/data1/home/ict04/data/surgical/HAM10000/img", transform=None):
        """
        Args:
            csv_file: Path to the metadata CSV.
            img_dir: Directory holding the ``.jpg`` images.
            transform: Callable applied to each RGB PIL image. When ``None``,
                a resize-to-224 / to-tensor / normalize pipeline is used.
        """
        self.skin_df = pd.read_csv(csv_file)
        self.img_dir = img_dir
        if transform is None:
            # Build the default pipeline only when the caller did not bring
            # one. InterpolationMode is torchvision's own enum for Resize.
            transform = transforms.Compose([
                transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.transform = transform
        self.class_to_idx = {
            'nv': 0,      # Melanocytic nevi
            'mel': 1,     # Melanoma
            'bkl': 2,     # Benign keratosis-like lesions
            'bcc': 3,     # Basal cell carcinoma
            'akiec': 4,   # Actinic keratoses
            'vasc': 5,    # Vascular lesions
            'df': 6       # Dermatofibroma
        }
        self.classes = list(self.class_to_idx)

    def __len__(self):
        """Return the number of rows in the metadata CSV."""
        return len(self.skin_df)

    def label_to_int(self, label):
        """Map a diagnosis code to its class index; unknown codes map to -1."""
        return self.class_to_idx.get(label, -1)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx`` of the metadata CSV."""
        img_name = os.path.join(self.img_dir, self.skin_df.iloc[idx, 1] + '.jpg')
        # Force three channels so the 3-channel Normalize in the default
        # pipeline cannot fail on grayscale / RGBA / palette files, and close
        # the file handle as soon as the pixel data has been read.
        with Image.open(img_name) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.label_to_int(self.skin_df.iloc[idx, 2])
        return image, label


# PACS Domain Generalization Dataset Classes
def find_classes(dir_name):
    """Find the class sub-directories of a directory.

    :param dir_name: directory whose immediate sub-directories are the classes
    :return: ``(classes, class_to_idx)`` where ``classes`` is the sorted list
        of sub-directory names and ``class_to_idx`` maps each name to its
        position in that list
    """
    # The context manager releases the directory handle even if iteration fails.
    with os.scandir(dir_name) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    class_to_idx = {name: index for index, name in enumerate(classes)}
    return classes, class_to_idx


def get_random_subset(names, labels, percent):
    """
    Randomly split parallel name/label lists into a train and a validation part.

    :param names: list of names
    :param labels: list of labels, parallel to ``names``
    :param percent: 0 < float < 1, fraction of the samples sent to validation
        (the count is truncated with ``int``)
    :return: ``(name_train, name_val, labels_train, labels_val)``; the train
        lists keep the input order, the validation lists are in drawn order
    """
    total = len(names)
    val_count = int(total * percent)
    val_order = random.sample(range(total), val_count)
    # Set membership keeps the train filter linear instead of quadratic.
    val_lookup = set(val_order)

    name_val = [names[k] for k in val_order]
    labels_val = [labels[k] for k in val_order]
    name_train = [name for k, name in enumerate(names) if k not in val_lookup]
    labels_train = [label for k, label in enumerate(labels) if k not in val_lookup]
    return name_train, name_val, labels_train, labels_val


def get_split_domain_info_from_dir(domain_path, val_percentage=0.1):
    """Collect the images of one domain directory and split them into train/val.

    The directory is expected to hold one sub-directory per class. Files whose
    name ends in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are
    collected as ``<class_name>/<file_name>`` paths relative to
    ``domain_path`` and labelled with the index of their class in the sorted
    class list.

    :param domain_path: directory containing one sub-directory per class
    :param val_percentage: fraction of the images sent to the validation part
    :return: ``(name_train, name_val, labels_train, labels_val, classes)``
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    classes, class_to_idx = find_classes(domain_path)

    names, labels = [], []
    for class_name in classes:
        class_path = os.path.join(domain_path, class_name)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order -- and therefore any seeded split -- reproducible.
        for img_name in sorted(os.listdir(class_path)):
            if img_name.lower().endswith(image_suffixes):
                names.append(os.path.join(class_name, img_name))
                labels.append(class_to_idx[class_name])

    name_train, name_val, labels_train, labels_val = get_random_subset(names, labels, val_percentage)
    return name_train, name_val, labels_train, labels_val, classes


class PACSDataset(Dataset):
    """PACS Domain Generalization Dataset.

    Serves ``(image, label)`` pairs from parallel lists of image paths
    (relative to ``dataset_path``) and labels.
    """

    def __init__(self, names, labels, dataset_path, img_transformer=None):
        """
        Args:
            names: Image paths relative to ``dataset_path``.
            labels: Labels parallel to ``names``.
            dataset_path: Directory the relative paths are joined onto.
            img_transformer: Callable applied to each RGB PIL image. When
                ``None``, a resize-to-224 / to-tensor / normalize pipeline
                is used.
        """
        if img_transformer is None:
            resize = transforms.Resize((224, 224))
            to_tensor = transforms.ToTensor()
            normalize = transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            img_transformer = transforms.Compose([resize, to_tensor, normalize])

        self.names = names
        self.labels = labels
        self.dataset_path = dataset_path
        self._image_transformer = img_transformer

    def __len__(self):
        """Return the number of image paths."""
        return len(self.names)

    def __getitem__(self, index):
        """Load image ``index`` as RGB, transform it and return it with its label."""
        full_path = os.path.join(self.dataset_path, self.names[index])
        rgb_image = Image.open(full_path).convert('RGB')
        return self._image_transformer(rgb_image), self.labels[index]


class PACSMultiDomainDataset(Dataset):
    """PACS Multi-Domain Dataset for Domain Generalization.

    In 'train' and 'val' mode the samples come from the source domains, split
    per domain into a training part and a validation part. In 'test' mode
    every sample of the target domain is used.

    The train/val split is derived from ``split_seed`` over the sorted sample
    list, so a 'train' instance and a 'val' instance built with the same
    arguments partition each source domain without overlap.
    """

    def __init__(self, data_root, source_domains, target_domain, val_percentage=0.1, mode='train',
                 split_seed=0):
        """
        Args:
            data_root: Root directory containing PACS data
            source_domains: List of source domains for training
            target_domain: Target domain for testing
            val_percentage: Fraction (0-1) of each source domain held out for
                validation
            mode: 'train', 'val', or 'test'
            split_seed: Seed of the per-domain train/val split; use the same
                value for the 'train' and the 'val' instance

        Raises:
            ValueError: If ``mode`` is not 'train', 'val' or 'test'.
        """
        if mode not in ('train', 'val', 'test'):
            # An unknown mode used to yield a silently empty dataset.
            raise ValueError("mode must be 'train', 'val' or 'test', got {!r}".format(mode))

        self.data_root = data_root
        self.source_domains = source_domains
        self.target_domain = target_domain
        self.mode = mode
        self.split_seed = split_seed

        # PACS domain names and classes
        self.pacs_domains = ["art_painting", "cartoon", "photo", "sketch"]
        self.pacs_classes = ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']
        self.classes = self.pacs_classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # Image transformations: augmentation for training only
        if mode == 'train':
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop((224, 224), scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

        # Load data
        self.data = []
        self.labels = []

        if mode == 'test':
            # The whole target domain is used for testing.
            selected = self._list_domain_samples(target_domain, val_percentage)
        else:
            selected = []
            take_val = mode == 'val'
            for domain in source_domains:
                samples = self._list_domain_samples(domain, val_percentage)
                val_count = int(len(samples) * val_percentage)
                # A private, seeded generator gives the same validation
                # positions to every instance built with the same arguments,
                # without touching the global random state.
                rng = random.Random(split_seed)
                val_positions = set(rng.sample(range(len(samples)), val_count))
                selected.extend(
                    sample for position, sample in enumerate(samples)
                    if (position in val_positions) == take_val
                )

        for img_path, label in selected:
            self.data.append(img_path)
            self.labels.append(label)

    def _list_domain_samples(self, domain, val_percentage):
        """Return every ``(image_path, label)`` pair of one domain, sorted by path.

        A missing domain directory yields an empty list and a warning.
        """
        domain_path = os.path.join(self.data_root, domain)
        if not os.path.exists(domain_path):
            warnings.warn("PACS domain directory not found, skipping: {}".format(domain_path))
            return []

        # The helper performs its own unseeded random split, which differs
        # between calls; merge its two halves back together and sort so the
        # split done by this class starts from a stable order.
        name_train, name_val, labels_train, labels_val, _ = get_split_domain_info_from_dir(
            domain_path, val_percentage
        )
        pairs = sorted(zip(name_train + name_val, labels_train + labels_val))
        return [(os.path.join(domain_path, name), label) for name, label in pairs]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as RGB, transform it and return it with its label."""
        img_path = self.data[idx]
        img = Image.open(img_path).convert('RGB')
        img = self.transform(img)
        label = self.labels[idx]
        return img, label