"""Download data and partition data with different partitioning strategies."""

import os
import json
from collections import defaultdict
from typing import List, Optional, Tuple, Any, Callable, DefaultDict, Dict


import h5py
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.utils.data as data
from torch.utils.data import ConcatDataset, Dataset, Subset
from torchvision.datasets import CIFAR10, MNIST, FashionMNIST, CIFAR100

from sklearn.model_selection import train_test_split


## Tiny imageNet Dataset
import os, csv, zipfile, urllib.request, shutil
from glob import glob
from typing import Tuple, Optional, List
from PIL import Image

# Per-channel normalization statistics (the usual ImageNet values) used for the
# Tiny-ImageNet transform pipelines.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Location of the official Tiny-ImageNet-200 zip archive.
TIN_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"


def ensure_tiny_imagenet(root: str = "data", url: str = TIN_URL, force: bool = False) -> str:
    """
    Ensure Tiny-ImageNet-200 exists at ``<root>/tiny-imagenet-200``.

    If the extracted dataset is missing (or ``force`` is set), the zip is
    downloaded to ``<root>/tiny-imagenet-200.zip`` (unless that file already
    exists and ``force`` is False) and extracted through a scratch directory.

    Parameters
    ----------
    root : str
        Directory holding the zip archive and the extracted dataset.
    url : str
        Where to download the archive from.
    force : bool
        Re-download and re-extract even if the dataset is already present.

    Returns
    -------
    str
        Absolute path to the extracted ``tiny-imagenet-200`` directory.

    Raises
    ------
    RuntimeError
        If an archive member is an absolute / drive-qualified path or contains a
        ``..`` component, or if the archive has no ``tiny-imagenet-200/`` folder.
    """
    target_dir = os.path.join(root, "tiny-imagenet-200")
    anno = os.path.join(target_dir, "val", "val_annotations.txt")
    # The val annotations file serves as the marker of a completed extraction.
    if os.path.isdir(target_dir) and os.path.isfile(anno) and not force:
        return os.path.abspath(target_dir)

    os.makedirs(root, exist_ok=True)
    zip_path = os.path.join(root, "tiny-imagenet-200.zip")

    if force or not os.path.isfile(zip_path):
        def _progress(block_num, block_size, total_size):
            if total_size > 0:
                # The last block can overshoot total_size; clamp to 100%.
                pct = min(100, block_num * block_size * 100 // total_size)
                print(f"\rDownloading Tiny-ImageNet… {pct:3d}%", end="")

        part_path = zip_path + ".part"
        print(f"Fetching Tiny-ImageNet from {url}")
        urllib.request.urlretrieve(url, part_path, _progress)
        # Publish the archive only once it is complete, so an interrupted
        # download is never mistaken for a cached zip on the next call.
        os.replace(part_path, zip_path)
        print("\nDownload complete:", zip_path)

    tmpdir = os.path.join(root, "_extract_tmp")
    with zipfile.ZipFile(zip_path, "r") as zf:
        # Validate every member before anything is written to disk.
        for member in zf.namelist():
            parts = member.replace("\\", "/").split("/")
            if member.startswith(("/", "\\")) or ":" in parts[0] or ".." in parts:
                raise RuntimeError("Unsafe path in zip (path traversal)")

        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)
        try:
            zf.extractall(tmpdir)
            src = os.path.join(tmpdir, "tiny-imagenet-200")
            if not os.path.isdir(src):
                raise RuntimeError("Archive did not contain tiny-imagenet-200/")
            if os.path.isdir(target_dir):
                shutil.rmtree(target_dir)
            shutil.move(src, target_dir)
        finally:
            # Drop the scratch directory on success and on failure alike.
            shutil.rmtree(tmpdir, ignore_errors=True)

    return os.path.abspath(target_dir)

class TinyImageNet(Dataset):
    """
    Tiny-ImageNet-200 (64x64) wrapper with 200 classes.

    Expects folder structure as in the official zip:
      tiny-imagenet-200/
        train/<wnid>/images/*.JPEG
        val/images/*.JPEG
        val/val_annotations.txt
        test/images/*.JPEG

    Labels are indices into the sorted list of class directories under
    ``train/``. The ``val`` split reuses that mapping whenever ``train/`` exists,
    so train and val labels always agree. Samples are sorted by path so that
    index-based (seeded) partitioning is reproducible across file systems.

    Attributes
    ----------
    samples : List[Tuple[str, int]]
        ``(image_path, label)`` pairs; the label is -1 for the unlabeled test split.
    targets : List[int]
        Labels aligned with ``samples`` (same attribute name torchvision datasets use).
    wnid_to_idx : Dict[str, int]
        WordNet id -> label index (empty for the test split).
    classes : List[int]
        ``list(range(200))``.
    """

    def __init__(self, root: str, split: str = "train",
                 transform: Optional[Callable] = None):
        """
        Parameters
        ----------
        root : str
            Directory that contains ``tiny-imagenet-200/``.
        split : str
            One of ``"train"``, ``"val"`` or ``"test"``.
        transform : callable, optional
            Applied to each RGB PIL image in ``__getitem__``.

        Raises
        ------
        ValueError
            If ``split`` is not a supported split name.
        """
        if split not in {"train", "val", "test"}:
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
        self.root = os.path.join(root, "tiny-imagenet-200")
        self.split = split
        self.transform = transform
        train_dir = os.path.join(self.root, "train")

        samples: List[Tuple[str, int]]
        if split == "train":
            self.wnid_to_idx = self._wnid_to_idx_from_train(train_dir)
            samples = []
            for wnid, label in self.wnid_to_idx.items():
                # glob order is file-system dependent; sort for reproducibility.
                imgs = sorted(glob(os.path.join(train_dir, wnid, "images", "*.JPEG")))
                samples.extend((path, label) for path in imgs)

        elif split == "val":
            img2wnid = self._read_val_annotations(
                os.path.join(self.root, "val", "val_annotations.txt")
            )
            if os.path.isdir(train_dir):
                # Share the train mapping so val labels match train labels.
                self.wnid_to_idx = self._wnid_to_idx_from_train(train_dir)
            else:
                wnids = sorted(set(img2wnid.values()))
                self.wnid_to_idx = {wnid: i for i, wnid in enumerate(wnids)}
            imgs = sorted(glob(os.path.join(self.root, "val", "images", "*.JPEG")))
            samples = [
                (path, self.wnid_to_idx[img2wnid[os.path.basename(path)]])
                for path in imgs
                if os.path.basename(path) in img2wnid
            ]

        else:  # test (no labels)
            imgs = sorted(glob(os.path.join(self.root, "test", "images", "*.JPEG")))
            self.wnid_to_idx = {}
            samples = [(path, -1) for path in imgs]

        self.samples = samples
        self.targets: List[int] = [label for _, label in samples]
        self.classes = list(range(200))

    @staticmethod
    def _wnid_to_idx_from_train(train_dir: str) -> Dict[str, int]:
        """Map each class directory under ``train_dir`` (sorted by name) to a label."""
        wnids = sorted(
            entry for entry in os.listdir(train_dir)
            if os.path.isdir(os.path.join(train_dir, entry))
        )
        return {wnid: i for i, wnid in enumerate(wnids)}

    @staticmethod
    def _read_val_annotations(anno_file: str) -> Dict[str, str]:
        """Return ``{image file name: wnid}`` from the tab-separated annotations file."""
        with open(anno_file, "r", newline="") as f:
            # Rows with fewer than two fields (e.g. a trailing blank line) are skipped.
            return {row[0]: row[1] for row in csv.reader(f, delimiter="\t") if len(row) >= 2}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is RGB and passed through ``transform``."""
        path, target = self.samples[idx]
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, target


def _download_data(
    dataset_name="cifar10", model_type: Optional[str] = None, root: str = "data"
) -> Tuple[Dataset, Dataset]:
    """Download (if needed) and load the requested dataset.

    Supported names are ``"cifar10"``, ``"cifar100"``, ``"mnist"``, ``"fmnist"``
    and Tiny-ImageNet (``"tiny_imagenet"``, ``"tinyimagenet"`` or
    ``"tiny-imagenet"``).

    Parameters
    ----------
    dataset_name : str
        Which dataset to load, by default "cifar10".
    model_type : str, optional
        Only used for CIFAR-10: ``"vit"`` selects a 224x224 pipeline normalized
        to [-1, 1]; anything else uses the 32x32 reflect-pad + random-crop pipeline.
    root : str, optional
        Directory the datasets are downloaded to and read from, by default "data".

    Returns
    -------
    Tuple[Dataset, Dataset]
        The training dataset, the test dataset (the official val split for
        Tiny-ImageNet).

    Raises
    ------
    NotImplementedError
        If ``dataset_name`` is not supported.
    """
    if dataset_name == "cifar10":
        if model_type == "vit":
            transform_train = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])
            transform_test = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])
        else:
            transform_train = transforms.Compose([
                # Reflect-pad the PIL image directly (as the CIFAR-100 branch does):
                # no lossy float->uint8 round trip, and no lambda, so the transform
                # can be pickled for multi-process DataLoader workers.
                transforms.Pad(4, padding_mode="reflect"),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ])
            transform_test = transforms.Compose([
                transforms.ToTensor(),
            ])
        trainset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
        testset = CIFAR10(root=root, train=False, download=True, transform=transform_test)

    elif dataset_name == "cifar100":
        cifar100_norm = transforms.Normalize(
            (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
        )
        transform_train = transforms.Compose([
            transforms.Pad(4, padding_mode="reflect"),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            cifar100_norm,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            cifar100_norm,
        ])
        trainset = CIFAR100(root=root, train=True, download=True, transform=transform_train)
        testset = CIFAR100(root=root, train=False, download=True, transform=transform_test)

    elif dataset_name in ("mnist", "fmnist"):
        dataset_cls = MNIST if dataset_name == "mnist" else FashionMNIST
        trainset = dataset_cls(
            root=root, train=True, download=True,
            transform=transforms.Compose([transforms.ToTensor()]),
        )
        testset = dataset_cls(
            root=root, train=False, download=True,
            transform=transforms.Compose([transforms.ToTensor()]),
        )

    elif dataset_name in {"tiny_imagenet", "tinyimagenet", "tiny-imagenet"}:
        ensure_tiny_imagenet(root)
        # 224x224 pipeline for every model_type so ImageNet weights can be reused.
        # For native 64x64 training switch to RandomCrop(64) / CenterCrop(64)
        # and adjust the model stem.
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
        # `root` contains "tiny-imagenet-200" after ensure_tiny_imagenet().
        trainset = TinyImageNet(root=root, split="train", transform=transform_train)
        testset = TinyImageNet(root=root, split="val", transform=transform_test)

    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset_name!r}")

    return trainset, testset


# pylint: disable=too-many-locals
def partition_data(
    num_clients, similarity=1.0, seed=42, dataset_name="cifar10", model_type: Optional[str] = None
) -> Tuple[List[Dataset], Dataset]:
    """Partition the dataset into subsets for each client.

    A ``similarity`` fraction of the training set is sampled uniformly at random
    and split evenly (IID) across clients. Each client is then assigned up to two
    distinct classes of the remaining samples: client ``i`` always gets the
    ``i mod K``-th remaining class, and the rest are drawn at random. Each assigned
    class's remaining samples are shuffled and split evenly among the clients
    holding it; remaining samples of classes no client holds are left unused.

    Parameters
    ----------
    num_clients : int
        The number of clients that hold a part of the data
    similarity: float
        Fraction in [0, 1] of the training set distributed IID, by default 1.0
    seed : int, optional
        Used to set a fix seed to replicate experiments, by default 42
    dataset_name : str
        Name of the dataset to be used
    model_type : str, optional
        Forwarded to ``_download_data`` (selects the transform pipeline)

    Returns
    -------
    Tuple[List[Dataset], Dataset]
        The list of datasets for each client (plain ``Subset`` objects when
        nothing remains after the IID part, otherwise ``ConcatDataset``), the
        test dataset.

    Raises
    ------
    ValueError
        If ``similarity`` is outside [0, 1].
    """
    if not 0.0 <= similarity <= 1.0:
        raise ValueError(f"similarity must be in [0, 1], got {similarity}")
    trainset, testset = _download_data(dataset_name, model_type=model_type)
    prng = np.random.default_rng(seed)
    num_samples = len(trainset)

    # IID part: a random `similarity` fraction of the data, split evenly.
    iid_idxs = prng.choice(num_samples, int(similarity * num_samples), replace=False)
    rem_idxs = np.setdiff1d(np.arange(num_samples), iid_idxs)
    trainsets_per_client: List[Dataset] = [
        Subset(trainset, part) for part in np.array_split(iid_idxs, num_clients)
    ]
    if rem_idxs.size == 0:
        return trainsets_per_client, testset

    # Labels of the remaining samples; datasets without `targets` expose
    # (path, label) pairs via `samples`.
    if hasattr(trainset, "targets"):
        all_targets = trainset.targets
    else:
        all_targets = [s[1] for s in trainset.samples]
    if isinstance(all_targets, torch.Tensor):
        all_targets = all_targets.numpy()
    targets = np.asarray(all_targets)[rem_idxs]
    remaining_classes = np.unique(targets).tolist()
    num_remaining_classes = len(remaining_classes)
    # Two classes per client, or fewer when the remainder has fewer distinct
    # classes (otherwise the rejection sampling below could never finish).
    classes_per_client = min(2, num_remaining_classes)

    client_classes: List[List] = []
    times = [0] * num_remaining_classes  # number of clients holding each class
    for i in range(num_clients):
        first = i % num_remaining_classes
        chosen = [remaining_classes[first]]
        times[first] += 1
        while len(chosen) < classes_per_client:
            index = prng.choice(num_remaining_classes)
            if remaining_classes[index] not in chosen:
                chosen.append(remaining_classes[index])
                times[index] += 1
        client_classes.append(chosen)

    rem_parts_per_client: List[List[Dataset]] = [[] for _ in range(num_clients)]
    for k, class_t in enumerate(remaining_classes):
        if times[k] == 0:
            # No client holds this class (np.array_split rejects 0 sections).
            continue
        idx_k = np.where(targets == class_t)[0]
        prng.shuffle(idx_k)
        holders = [j for j in range(num_clients) if class_t in client_classes[j]]
        for part, j in zip(np.array_split(idx_k, times[k]), holders):
            rem_parts_per_client[j].append(Subset(trainset, rem_idxs[part]))

    combined: List[Dataset] = [
        ConcatDataset([iid_part] + rem_parts)
        for iid_part, rem_parts in zip(trainsets_per_client, rem_parts_per_client)
    ]
    return combined, testset


def partition_data_dirichlet(
    num_clients, alpha, seed=42, dataset_name="cifar10", model_type: Optional[str] = None
) -> Tuple[List[Dataset], Dataset]:
    """Partition according to the Dirichlet distribution.

    For every class, its samples are shuffled and split across clients in
    proportions drawn from ``Dirichlet(alpha, ..., alpha)``. The whole draw is
    repeated until every client holds at least 10 samples.

    Parameters
    ----------
    num_clients : int
        The number of clients that hold a part of the data
    alpha: float
        Parameter of the Dirichlet distribution
    seed : int, optional
        Used to set a fix seed to replicate experiments, by default 42
    dataset_name : str
        Name of the dataset to be used
    model_type : str, optional
        Forwarded to ``_download_data`` (selects the transform pipeline)

    Returns
    -------
    Tuple[List[Subset], Dataset]
        The list of datasets for each client, the test dataset.

    Raises
    ------
    ValueError
        If the training set has fewer than ``10 * num_clients`` samples, in
        which case the resampling loop could never terminate.
    """
    trainset, testset = _download_data(dataset_name, model_type=model_type)
    min_required_samples_per_client = 10
    prng = np.random.default_rng(seed)

    # Datasets that only expose `samples` ((path, label) pairs) get a `targets`
    # list derived from it.
    if not hasattr(trainset, "targets") and hasattr(trainset, "samples"):
        trainset.targets = [s[1] for s in trainset.samples]
    tmp_t = trainset.targets
    if isinstance(tmp_t, torch.Tensor):
        tmp_t = tmp_t.numpy()
    tmp_t = np.asarray(tmp_t)

    total_samples = len(tmp_t)
    if total_samples < num_clients * min_required_samples_per_client:
        raise ValueError(
            f"Cannot give each of {num_clients} clients at least "
            f"{min_required_samples_per_client} samples from {total_samples} samples"
        )
    classes = np.unique(tmp_t)

    while True:
        idx_clients: List[List] = [[] for _ in range(num_clients)]
        for k in classes:
            idx_k = np.where(tmp_t == k)[0]
            prng.shuffle(idx_k)
            proportions = prng.dirichlet(np.repeat(alpha, num_clients))
            # Clients that already hold at least total_samples / num_clients
            # samples receive no share of this class.
            below_even_share = np.array(
                [len(idx_j) < total_samples / num_clients for idx_j in idx_clients]
            )
            proportions = proportions * below_even_share
            proportions = proportions / proportions.sum()
            cut_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_clients = [
                idx_j + part.tolist()
                for idx_j, part in zip(idx_clients, np.split(idx_k, cut_points))
            ]
        if min(len(idx_j) for idx_j in idx_clients) >= min_required_samples_per_client:
            break

    trainsets_per_client = [Subset(trainset, idxs) for idxs in idx_clients]
    return trainsets_per_client, testset


def partition_data_label_quantity(
    num_clients,
    labels_per_client,
    seed=42,
    dataset_name="cifar10",
    model_type: Optional[str] = None,
) -> Tuple[List[Dataset], Dataset]:
    """Partition the data according to the number of labels per client.

    Logic from https://github.com/Xtra-Computing/NIID-Bench/. Client ``i``
    always receives class ``i mod num_classes`` plus randomly drawn distinct
    classes until it holds ``labels_per_client`` of them. Each class's samples
    are shuffled and split evenly among the clients holding it; samples of
    classes that no client holds are left out.

    Parameters
    ----------
    num_clients : int
        The number of clients that hold a part of the data
    labels_per_client: int
        Number of distinct labels per client (values below 1 behave like 1)
    seed : int, optional
        Used to set a fix seed to replicate experiments, by default 42
    dataset_name : str
        Name of the dataset to be used
    model_type : str, optional
        Forwarded to ``_download_data`` (selects the transform pipeline)

    Returns
    -------
    Tuple[List[Subset], Dataset]
        The list of datasets for each client, the test dataset.

    Raises
    ------
    ValueError
        If ``labels_per_client`` exceeds the number of classes (the random
        class selection could never finish).
    """
    trainset, testset = _download_data(dataset_name, model_type=model_type)
    prng = np.random.default_rng(seed)

    # Datasets without `targets` expose (path, label) pairs via `samples`.
    if hasattr(trainset, "targets"):
        targets = trainset.targets
    else:
        targets = [s[1] for s in trainset.samples]
    if isinstance(targets, torch.Tensor):
        targets = targets.numpy()
    targets = np.asarray(targets)
    classes = np.unique(targets)
    num_classes = len(classes)
    if labels_per_client > num_classes:
        raise ValueError(
            f"labels_per_client ({labels_per_client}) exceeds the number of "
            f"classes ({num_classes})"
        )

    times = [0] * num_classes  # number of clients holding each class index
    contains: List[List] = []  # class indices held by each client
    for i in range(num_clients):
        current = [i % num_classes]
        times[i % num_classes] += 1
        while len(current) < labels_per_client:
            index = prng.choice(num_classes, 1)[0]
            if index not in current:
                current.append(index)
                times[index] += 1
        contains.append(current)

    idx_clients: List[List] = [[] for _ in range(num_clients)]
    for k in range(num_classes):
        if times[k] == 0:
            # No client holds this class (np.array_split rejects 0 sections).
            continue
        idx_k = np.where(targets == classes[k])[0]
        prng.shuffle(idx_k)
        holders = [j for j in range(num_clients) if k in contains[j]]
        for part, j in zip(np.array_split(idx_k, times[k]), holders):
            idx_clients[j] += part.tolist()
    trainsets_per_client = [Subset(trainset, idxs) for idxs in idx_clients]
    return trainsets_per_client, testset

def _read_dataset(path: str) -> Tuple[List, DefaultDict, List]:
    """Read and merge every ``*.json`` leaf dataset file found in ``path``.

    Files are processed in sorted file-name order, so the result does not
    depend on ``os.listdir`` ordering. If a user appears in several files, the
    file read last wins.

    Parameters
    ----------
        path : str
            The directory containing the leaf dataset json files

    Returns
    -------
    Tuple[List, DefaultDict, List]
        ``(users, data, num_example)``: the sorted user ids present in the
        merged ``user_data``; the merged ``user_data`` mapping, where looking up
        an unknown user returns ``None``; and the concatenated ``num_samples``
        lists in file order.
    """
    user_data: DefaultDict[str, Any] = defaultdict(lambda: None)
    num_example: List = []

    file_names = sorted(f for f in os.listdir(path) if f.endswith(".json"))
    for file_name in file_names:
        with open(os.path.join(path, file_name), encoding="utf-8") as file:
            dataset = json.load(file)
        user_data.update(dataset["user_data"])
        num_example.extend(dataset["num_samples"])

    users = sorted(user_data.keys())
    return users, user_data, num_example


def support_query_split(
    data,
    label,
    support_ratio: float,
    random_state: int = 42,
) -> Tuple[List, List, List, List]:
    """Split one client's samples into a support set and a query set.

    The split is stratified by ``label``, so both sets keep the client's class
    proportions, and it is reproducible for a fixed ``random_state``.

    Parameters
    ----------
        data: array-like
            The client's samples
        label: array-like
            Labels aligned with ``data``; also used for stratification
        support_ratio : float
            Fraction (between 0 and 1) of the samples placed in the support set
        random_state : int, optional
            Seed forwarded to ``train_test_split``, by default 42

    Returns
    -------
    Tuple[List, List, List, List]
        ``(support_x, query_x, support_y, query_y)``

    Raises
    ------
    ValueError
        Propagated from ``train_test_split`` when it cannot perform the
        stratified split (e.g. a class with too few samples).
    """
    sup_x, qry_x, sup_y, qry_y = train_test_split(
        data, label, train_size=support_ratio, stratify=label, random_state=random_state
    )
    return sup_x, qry_x, sup_y, qry_y


def split_train_validation_test_clients(
    clients: List,
    train_rate: float = 0.8,
    val_rate: float = 0.1,
    seed: int = 42,
) -> Tuple[List[str], List[str], List[str]]:
    """Classification of all clients into train, valid, and test.

    The clients are shuffled with a private ``RandomState(seed)``; this gives
    the same permutation as ``np.random.seed(seed); np.random.permutation(n)``
    but leaves NumPy's global random state untouched.

    Parameters
    ----------
        clients: List,
            Full list of clients for the sampled leaf dataset.
        train_rate: float,  optional
            The ratio of training clients to total clients
            by default 0.8
        val_rate: float,  optional
            The ratio of validation clients to total clients
            by default 0.1
        seed: int, optional
            Seed of the shuffle, by default 42

    Returns
    -------
    Tuple[List, List, List]
        List of each train client, valid client, and test client; the test
        clients are everything not taken by the first two (rounding down).
    """
    num_train = int(train_rate * len(clients))
    num_val = int(val_rate * len(clients))

    order = np.random.RandomState(seed).permutation(len(clients))
    client_arr = np.asarray(clients)
    train_clients = client_arr[order[:num_train]].tolist()
    val_clients = client_arr[order[num_train : num_train + num_val]].tolist()
    test_clients = client_arr[order[num_train + num_val :]].tolist()

    return train_clients, val_clients, test_clients


# pylint: disable=too-many-locals
def _partition_data(
    data_type: str,
    dir_path: str,
    support_ratio: float,
) -> Tuple[Dict, Dict]:
    """Classification of support sets and query sets by client.

    For every user found in ``<dir_path>/train``, that user's train and test
    samples are pooled and re-split into a support set and a query set.

    Parameters
    ----------
        data_type: str,
            Either ``"femnist"`` (stratified split, singleton classes dropped)
            or ``"shakespeare"`` (plain random split)
        dir_path: str,
            The path where the leaf dataset was downloaded (must contain
            ``train/`` and ``test/``)
        support_ratio: float,
            The ratio of Support set for each client (between 0 and 1)

    Returns
    -------
    Tuple[Dict, Dict]
        Support set and query set, each a dict with keys ``"users"``,
        ``"user_data"`` (user -> ``{"x", "y"}``) and ``"num_samples"``.
        femnist users whose split is infeasible are omitted from both.

    Raises
    ------
    ValueError
        If ``data_type`` is not ``"femnist"`` or ``"shakespeare"``.
    """
    if data_type not in ("femnist", "shakespeare"):
        raise ValueError(f"Unsupported data_type: {data_type!r}")

    train_users, train_data, _ = _read_dataset(os.path.join(dir_path, "train"))
    _, test_data, _ = _read_dataset(os.path.join(dir_path, "test"))

    support_dataset: Dict[str, Any] = {"users": [], "user_data": {}, "num_samples": []}
    query_dataset: Dict[str, Any] = {"users": [], "user_data": {}, "num_samples": []}

    for user in train_users:
        train_part = train_data[user]
        test_part = test_data[user]
        if test_part is None:
            # The test mapping yields None for users absent from the test files.
            test_part = {"x": [], "y": []}
        all_x = np.asarray(train_part["x"] + test_part["x"])
        all_y = np.asarray(train_part["y"] + test_part["y"])

        if data_type == "femnist":
            # A class with a single sample cannot appear in both stratified
            # sets, so drop such classes before splitting.
            classes, counts = np.unique(all_y, return_counts=True)
            keep = np.isin(all_y, classes[counts == 1], invert=True)
            all_x, all_y = all_x[keep], all_y[keep]

            try:
                sup_x, qry_x, sup_y, qry_y = support_query_split(
                    all_x, all_y, support_ratio
                )
            except ValueError:
                # The split is infeasible for this client (e.g. too few
                # samples); leave the client out of both datasets.
                continue
        else:  # shakespeare
            sup_x, qry_x, sup_y, qry_y = train_test_split(
                all_x, all_y, train_size=support_ratio, random_state=42
            )

        support_dataset["users"].append(user)
        support_dataset["user_data"][user] = {"x": sup_x, "y": sup_y}
        support_dataset["num_samples"].append(len(sup_y))

        query_dataset["users"].append(user)
        query_dataset["user_data"][user] = {"x": qry_x, "y": qry_y}
        query_dataset["num_samples"].append(len(qry_y))

    return support_dataset, query_dataset