import os
import pdb
import torch
import numpy as np
import pandas as pd
from PIL import Image
from glob import glob
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from partition import *


# =============================================================================
#                         Data loading
# =============================================================================
# Custom dataset preparation in PyTorch format
class SkinData(Dataset):
    """Image-classification dataset backed by a DataFrame of file paths.

    ``df`` must provide a ``'path'`` column (image file location) and a
    ``'target'`` column (integer class label). Rows are accessed by position,
    so the DataFrame's index labels do not matter. Each image is resized to
    64x64 before the optional ``transform`` is applied.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # int() accepts plain ints as well as 0-d tensors coming from samplers.
        row = int(index)
        img_path = self.df['path'].iloc[row]
        target = self.df['target'].iloc[row]

        # A missing path shows up as None/NaN (e.g. an image_id with no file on
        # disk); fail with a clear message instead of an opaque PIL error.
        if pd.isna(img_path):
            raise FileNotFoundError(f"No image file recorded for sample {row}: {img_path!r}")

        # Close the underlying file handle; resize() returns a new, fully
        # loaded image, so it stays valid after the source is closed.
        with Image.open(img_path) as img:
            X = img.resize((64, 64))
        y = torch.tensor(int(target))
        if self.transform:
            X = self.transform(X)

        return X, y


class DatasetSplit(Dataset):
    """Positional view over a subset of another dataset.

    Position ``k`` of the view returns ``dataset[idxs[k]]`` unpacked as an
    ``(image, label)`` pair. ``idxs`` may be any iterable (e.g. a set); it is
    materialised into a list once, and that list's order defines the view's order.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [i for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_index = self.idxs[item]
        sample, target = self.dataset[source_index]
        return sample, target


# Skin-lesion datasets loaded from local files via SkinData.
_SKIN_DATASETS = ('HAM', 'ISIC')

# torchvision dataset class names for the built-in benchmarks, keyed by the
# value accepted in ``args.dataset``; resolved on ``datasets`` at call time.
_TORCHVISION_DATASETS = {
    'MNIST': 'MNIST',
    'F-MNIST': 'FashionMNIST',
    'CIFAR10': 'CIFAR10',
    'CIFAR100': 'CIFAR100',
}


def _build_transforms(mean, std):
    """Return ``(train_transforms, test_transforms)`` normalising with mean/std.

    Both pipelines pad by 3 pixels and center-crop to 64x64 before converting
    to a tensor; the training pipeline additionally applies random flips and
    a random rotation of up to 10 degrees.
    """
    train_transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.Pad(3),
        transforms.RandomRotation(10),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    test_transforms = transforms.Compose([
        transforms.Pad(3),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    return train_transforms, test_transforms


def _load_ham(csv_path, train_transforms, test_transforms):
    """Build HAM train/test ``SkinData`` datasets from the metadata CSV.

    Image files are found by globbing ``./HAM/*/*.jpg`` (relative to the
    working directory, not to ``csv_path``) and matched to rows by ``image_id``.
    Labels are the categorical codes of the human-readable lesion names.
    """
    df = pd.read_csv(csv_path)
    lesion_type = {
        'nv': 'Melanocytic nevi',
        'mel': 'Melanoma',
        'bkl': 'Benign keratosis-like lesions ',
        'bcc': 'Basal cell carcinoma',
        'akiec': 'Actinic keratoses',
        'vasc': 'Vascular lesions',
        'df': 'Dermatofibroma'
    }

    imageid_path = {os.path.splitext(os.path.basename(x))[0]: x
                    for x in glob(os.path.join("./HAM", '*', '*.jpg'))}

    df['path'] = df['image_id'].map(imageid_path.get)
    df['cell_type'] = df['dx'].map(lesion_type.get)
    df['target'] = pd.Categorical(df['cell_type']).codes
    print(df['cell_type'].value_counts())
    print(df['target'].value_counts())

    # NOTE: unlike the ISIC split this one is neither stratified nor seeded,
    # so it varies between runs unless NumPy's global RNG is seeded.
    train, test = train_test_split(df, test_size=0.2)
    train, test = train.reset_index(drop=True), test.reset_index(drop=True)

    return (SkinData(train, transform=train_transforms),
            SkinData(test, transform=test_transforms))


def _load_isic(root, train_transforms, test_transforms):
    """Build ISIC train/test ``SkinData`` datasets from a class-per-folder tree.

    Each sub-directory of ``root`` is one class (labelled by its position in
    sorted order); .jpg/.jpeg/.png files inside it are its samples. The split
    is stratified by label and seeded for reproducibility.
    """
    classes = sorted([d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))])
    paths = []
    labels = []
    for idx, cls in enumerate(classes):
        cls_dir = os.path.join(root, cls)
        for img_name in os.listdir(cls_dir):
            if img_name.lower().endswith(('.jpg', '.png', '.jpeg')):
                paths.append(os.path.join(cls_dir, img_name))
                labels.append(idx)

    df = pd.DataFrame({'path': paths, 'target': labels})

    train_df, test_df = train_test_split(df, test_size=0.2, stratify=df['target'], random_state=42)
    train_df, test_df = train_df.reset_index(drop=True), test_df.reset_index(drop=True)

    return (SkinData(train_df, transform=train_transforms),
            SkinData(test_df, transform=test_transforms))


def _load_torchvision(name, train_transforms, test_transforms):
    """Download (if needed) and return the torchvision train/test datasets for ``name``."""
    dataset_cls = getattr(datasets, _TORCHVISION_DATASETS[name])
    dataset_train = dataset_cls(root='./', train=True,
                                transform=train_transforms, download=True)
    dataset_test = dataset_cls(root='./', train=False,
                               transform=test_transforms, download=True)
    return dataset_train, dataset_test


def data_load(args):
    """Load the dataset selected by ``args`` and partition it across clients.

    Reads ``args.chan`` (1 or 3), ``args.dataset`` (one of HAM, ISIC, MNIST,
    F-MNIST, CIFAR10, CIFAR100), ``args.data_path`` (HAM/ISIC only),
    ``args.iid``, ``args.num_users`` and, for non-IID splits, ``args.noniid_alpha``.

    Returns:
        ``(dataset_train, dataset_test, dict_users, dict_users_test)`` where the
        dicts map client id -> set of sample indices.

    Raises:
        ValueError: if ``args.chan`` or ``args.dataset`` is unsupported.
    """
    # ======================== Data preprocessing ========================
    if args.chan == 3:
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif args.chan == 1:
        mean, std = [0.5], [0.5]
    else:
        raise ValueError(f"Unsupported number of channels: {args.chan!r} (expected 1 or 3)")

    # Validate before building transforms or touching the filesystem/network.
    if args.dataset not in _SKIN_DATASETS and args.dataset not in _TORCHVISION_DATASETS:
        supported = ', '.join(_SKIN_DATASETS + tuple(_TORCHVISION_DATASETS))
        raise ValueError(f"Unsupported dataset: {args.dataset!r} (expected one of {supported})")

    train_transforms, test_transforms = _build_transforms(mean, std)

    # ======================== Dataset loading ========================
    if args.dataset == 'HAM':
        dataset_train, dataset_test = _load_ham(args.data_path, train_transforms, test_transforms)
    elif args.dataset == 'ISIC':
        dataset_train, dataset_test = _load_isic(args.data_path, train_transforms, test_transforms)
    else:
        dataset_train, dataset_test = _load_torchvision(args.dataset, train_transforms, test_transforms)

    # ======================== Partitioning ========================
    if args.iid:
        dict_users = dataset_iid(dataset_train, args.num_users)
        dict_users_test = dataset_iid(dataset_test, args.num_users)
    else:
        dict_users = create_noniid_indices(dataset_train, num_clients=args.num_users, alpha=args.noniid_alpha)
        dict_users_test = create_noniid_indices(dataset_test, num_clients=args.num_users, alpha=args.noniid_alpha)

    return dataset_train, dataset_test, dict_users, dict_users_test


# =====================================================================================================
# dataset_iid() will create a dictionary to collect the indices of the data samples randomly for each client
import numpy as np
from collections import defaultdict


def dataset_iid(dataset, num_users):
    """Randomly split sample indices into ``num_users`` disjoint, equal-size shards.

    Each client receives ``len(dataset) // num_users`` indices drawn without
    replacement from NumPy's global RNG; any remainder indices are left
    unassigned.

    Returns:
        dict mapping client id (0..num_users-1) -> set of sample indices.
    """
    per_client = len(dataset) // num_users
    remaining = list(range(len(dataset)))
    assignment = {}
    for client in range(num_users):
        picked = set(np.random.choice(remaining, per_client, replace=False))
        assignment[client] = picked
        # Shrink the pool so later clients never receive the same index.
        remaining = list(set(remaining) - picked)
    return assignment


def _extract_labels(dataset):
    """Return a NumPy array with the integer label of every sample in ``dataset``.

    Uses a precomputed label list when one is exposed, which avoids decoding
    and transforming every image just to read its label. The label list is
    either torchvision's ``targets`` attribute (only when no
    ``target_transform`` is set) or a ``df['target']`` column as used by
    ``SkinData``. Otherwise, it falls back to indexing each sample and taking
    element 1.
    """
    targets = None
    if getattr(dataset, 'target_transform', None) is None:
        targets = getattr(dataset, 'targets', None)
    if targets is None:
        df = getattr(dataset, 'df', None)
        if isinstance(df, pd.DataFrame) and 'target' in df.columns:
            targets = df['target']
    if targets is not None and len(targets) == len(dataset):
        # tolist() covers tensors, ndarrays and Series; int() matches the fallback.
        if hasattr(targets, 'tolist'):
            targets = targets.tolist()
        return np.array([int(t) for t in targets])
    return np.array([int(dataset[i][1]) for i in range(len(dataset))])


def create_noniid_indices(dataset, num_clients, alpha=0.5):
    """Partition sample indices across clients with a per-class Dirichlet split.

    For every distinct label, that class's indices are shuffled and divided
    among ``num_clients`` clients in proportions drawn from
    ``Dirichlet(alpha * ones(num_clients))``. Smaller ``alpha`` gives more
    skewed (more non-IID) splits. Uses NumPy's global RNG.

    Args:
        dataset: indexable dataset whose items are ``(sample, label)`` pairs.
        num_clients: number of clients to split across.
        alpha: Dirichlet concentration parameter.

    Returns:
        dict mapping client id (0..num_clients-1) -> set of sample indices;
        together the sets cover every index exactly once.
    """
    labels = _extract_labels(dataset)
    client_indices = defaultdict(list)

    # Iterate the labels actually present (sorted), not range(num_classes):
    # labels need not be contiguous 0..K-1 (e.g. a split missing a class).
    for c in np.unique(labels):
        idxs = np.where(labels == c)[0]
        np.random.shuffle(idxs)
        proportions = np.random.dirichlet(alpha * np.ones(num_clients))
        counts = (proportions * len(idxs)).astype(int)

        # Truncation can leave up to num_clients-1 samples unassigned;
        # hand them out round-robin starting at client 0.
        diff = len(idxs) - np.sum(counts)
        for i in range(diff):
            counts[i % num_clients] += 1

        start = 0
        for client_id, count in enumerate(counts):
            client_indices[client_id].extend(idxs[start:start + count])
            start += count

    dict_users = {i: set(client_indices[i]) for i in range(num_clients)}
    return dict_users

# import numpy as np
# from collections import defaultdict
#
# def create_noniid_indices(dataset, num_clients, alpha=0.5):

#     labels = np.array([int(dataset[i][1]) for i in range(len(dataset))])
#     num_classes = len(np.unique(labels))
#
#     label_indices = {i: np.where(labels == i)[0] for i in range(num_classes)}
#     client_indices = defaultdict(list)
#
#     for c, idxs in label_indices.items():
#         np.random.shuffle(idxs)
#         proportions = np.random.dirichlet(alpha * np.ones(num_clients))
#
#         proportions = proportions / proportions.sum()
#         counts = np.floor(proportions * len(idxs)).astype(int)
#
#         for i in range(len(counts)):
#             if counts[i] == 0:
#                 counts[i] = 1
#
#         diff = len(idxs) - np.sum(counts)
#         while diff != 0:
#             if diff > 0:
#                 i = np.random.randint(0, num_clients)
#                 counts[i] += 1
#                 diff -= 1
#             else:
#                 i = np.random.randint(0, num_clients)
#                 if counts[i] > 1:
#                     counts[i] -= 1
#                     diff += 1
#
#         start = 0
#         for client_id, count in enumerate(counts):
#             client_indices[client_id].extend(idxs[start:start + count])
#             start += count
#
#     dict_users = {i: set(client_indices[i]) for i in range(num_clients)}
#
#     empty_clients = [k for k, v in dict_users.items() if len(v) == 0]
#     if len(empty_clients) > 0:
#         raise ValueError(f"Empty client indices found: {empty_clients}")
#
#     return dict_users
