import torch
import shutil
import os
import numpy as np

from pathlib import Path
from abc import ABC, abstractmethod
from sklearn.model_selection import train_test_split
from flwr.dataset.utils.common import shuffle, sort_by_label, split_array_at_indices, sample_without_replacement

import logging
logger = logging.getLogger(__name__)

# import numpy as np
from typing import List, Tuple, Optional, Union
# from flwr.dataset.utils.common import shuffle, sort_by_label, split_array_at_indices, sample_without_replacement

# A dataset as an ``(x, y)`` pair of NumPy arrays: samples and their labels.
XY = Tuple[np.ndarray, np.ndarray]

def create_lda_partitions(
    dataset: XY,
    dirichlet_dist: Optional[np.ndarray],
    quantity_skew_dirichlet: Optional[np.ndarray] = None,
    num_partitions: int = 100,
    concentration: Union[float, np.ndarray] = 0.5,
    accept_imbalanced: bool = False,
    seed: Optional[int] = None,
) -> Tuple[List[XY], np.ndarray]:
    """Partition ``dataset`` among ``num_partitions`` clients with Dirichlet label skew.

    Variant of flwr's ``create_lda_partitions`` with optional per-client quantity skew.

    Args:
        dataset: ``(x, y)`` pair; anything ``np.asarray`` accepts (arrays, lists, CPU tensors).
        dirichlet_dist: Per-client class proportions, shape ``(num_partitions, num_classes)``.
            If None it is drawn from ``Dirichlet(concentration)`` using ``seed``. Pass the
            matrix returned by an earlier call to reproduce the same label skew on another
            split. Ignored when ``concentration`` is infinite.
        quantity_skew_dirichlet: Optional probability vector (one entry per client) giving
            each client's share of the samples. Every client first gets 2 samples and the rest
            is drawn from a multinomial with these probabilities. If None, samples are spread
            as evenly as possible.
        num_partitions: Number of clients.
        concentration: Dirichlet concentration, a scalar or one value per class. ``inf``
            gives every client a uniform label distribution.
        accept_imbalanced: Allow a sample count that is not a multiple of ``num_partitions``.
        seed: Seed for the Dirichlet and multinomial draws.

    Returns:
        ``(partitions, dirichlet_dist)``: one ``(x, y)`` pair per client, with ``y`` holding
        the original label values, and the class-proportion matrix that was used.

    Raises:
        ValueError: Imbalanced sample count without ``accept_imbalanced``, dataset too small
            for quantity skew, or a ``concentration`` / ``dirichlet_dist`` shape that does not
            match the number of classes.
    """
    x, y = dataset
    # Make sure we work with numpy (works fine even if tensors/lists were passed)
    x, y = np.asarray(x), np.asarray(y)
    # ``shuffle`` receives no seed here; reproducibility presumably relies on the caller having
    # seeded the global NumPy RNG beforehand — confirm against flwr's implementation.
    x, y = shuffle(x, y)
    x, y = sort_by_label(x, y)

    num_total = x.shape[0]
    if (num_total % num_partitions) and (not accept_imbalanced):
        raise ValueError(
            "Total number of samples must be a multiple of `num_partitions` "
            "or set `accept_imbalanced=True`."
        )

    # Per-client sample counts
    if quantity_skew_dirichlet is None:
        # Even split: the first ``num_total % num_partitions`` clients get one extra sample
        # (the same counts as dealing samples out round-robin).
        base, extra = divmod(num_total, num_partitions)
        num_samples = [base + (pid < extra) for pid in range(num_partitions)]
    else:
        total = num_total - (2 * num_partitions)
        if total < 0:
            raise ValueError(
                f"Dataset too small ({x.shape[0]}) for quantity skew with {num_partitions} clients."
            )
        rng_q = np.random.default_rng(seed)
        num_samples = rng_q.multinomial(total, quantity_skew_dirichlet)
        num_samples = (num_samples + 2).tolist()  # at least 2 samples per client
        assert sum(num_samples) == num_total

    # Classes and boundaries (y is sorted already)
    classes, start_indices = np.unique(y, return_index=True)

    # Normalize concentration to per-class vector
    concentration = np.asarray(concentration, dtype=float)
    if np.isinf(concentration).any():
        # IID case: uniform label distribution per client
        dirichlet_dist = np.full((num_partitions, classes.size), 1.0 / classes.size, dtype=float)
    else:
        if concentration.size == 1:
            concentration = np.repeat(concentration, classes.size)
        elif concentration.size != classes.size:
            raise ValueError(
                f"Concentration size ({concentration.size}) must be 1 or equal to num classes ({classes.size})."
            )
        if dirichlet_dist is None:
            dirichlet_dist = np.random.default_rng(seed).dirichlet(
                alpha=concentration, size=num_partitions
            )
        else:
            dirichlet_dist = np.asarray(dirichlet_dist, dtype=float)

        if dirichlet_dist.shape != (num_partitions, classes.size):
            raise ValueError(
                f"dirichlet_dist shape {dirichlet_dist.shape} != (num_partitions, num_classes) "
                f"= ({num_partitions}, {classes.size})"
            )

    # Split into list-of-arrays per class (indices are starts; function handles slicing)
    list_samples_per_class: List[List[np.ndarray]] = split_array_at_indices(x, start_indices)

    # flwr's ``sample_without_replacement`` tags each drawn sample with the *position* of its
    # class in ``list_samples_per_class`` (0..num_classes-1), not with the original label.
    # Translate positions back through ``classes`` when the two differ (labels such as 1..K,
    # or a class missing from this split); otherwise keep the positions exactly as returned.
    labels_are_positions = np.issubdtype(classes.dtype, np.number) and np.array_equal(
        classes, np.arange(classes.size)
    )

    # Sample for each client; ``empty_classes`` tracks exhausted classes across clients.
    partitions: List[XY] = []
    empty_classes = [False] * classes.size
    for pid in range(num_partitions):
        (x_part, class_pos), empty_classes = sample_without_replacement(
            distribution=dirichlet_dist[pid].copy(),
            list_samples=list_samples_per_class,
            num_samples=num_samples[pid],
            empty_classes=empty_classes,
        )
        y_part = class_pos if labels_are_positions else classes[class_pos]
        partitions.append((x_part, y_part))

    return partitions, dirichlet_dist

class FederatedDataset(ABC):
    """Base class for a dataset split across simulated federated-learning clients.

    Unless ``pre_partition`` is set, the raw ``train.pt`` / ``test.pt`` files under
    ``path_to_data`` are split into per-client files below ``dataset_fl_root``. Clients are
    grouped by LDA concentration: ``lda_alpha`` maps each alpha to a number of clients,
    client ids are assigned consecutively group by group (largest alpha first), and every
    group is split with :func:`create_lda_partitions`.

    Subclasses implement :meth:`download` and :meth:`get_dataloader`.
    """

    def __init__(self, ckp, path_to_data, dataset_fl_root, *args,
        pre_partition=False,
        lda_alpha=None,
        val_ratio=0,
        train_alpha=None,
        test_alpha=None,
        quantity_train_alpha=None,
        reset=False, **kwargs):
        """
        Args:
            ckp: Checkpoint object; ``ckp.config.simulation.num_clients`` gives the client
                pool size and the optional ``ckp.config.seed`` seeds the global NumPy RNG.
            path_to_data: Directory holding the raw ``train.pt`` and ``test.pt``.
            dataset_fl_root: Root directory for the federated partitions.
            pre_partition: True if the data is already split per client (nothing is built).
            lda_alpha: ``{alpha: num_clients}`` layout of the training partition directory.
                Required unless ``pre_partition``.
            val_ratio: Fraction of the raw training set held out as a central validation set.
            train_alpha: Alphas whose clients take part in training; matched numerically
                against the keys of ``lda_alpha``. Defaults to all of them.
            test_alpha: ``{alpha: num_clients}`` layout of the test partition directory
                (``fed_test_dir``). Defaults to ``lda_alpha``.
            quantity_train_alpha: If set, Dirichlet concentration for per-client quantity
                skew of the training data.
            reset: Delete and rebuild existing partition directories.
            *args, **kwargs: Accepted and ignored here.

        Raises:
            AssertionError: Missing ``lda_alpha``, an unknown train alpha, or client counts
                that do not add up to the pool size.
        """
        if hasattr(ckp.config, 'seed'):
            np.random.seed(ckp.config.seed)

        self.ckp = ckp
        self.config = ckp.config
        self.pool_size = ckp.config.simulation.num_clients
        self.pre_partition = pre_partition
        self.path_to_data = path_to_data
        self.dataset_fl_root = dataset_fl_root
        self.lda_alpha = lda_alpha
        self.train_alpha = train_alpha
        self.test_alpha = test_alpha
        self.quantity_train_alpha = quantity_train_alpha
        self.val_ratio = val_ratio
        self.reset = reset
        self.partitions = ['train.pt', 'test.pt']

        if not self.pre_partition:
            assert self.lda_alpha is not None, 'dataset is not pre-partitioned. data.args.lda_alpha must be defined.'
            # Canonical order: numerically descending, independent of key type. Configs reloaded
            # from disk may come back reordered or with stringified keys, and this order fixes
            # both the directory name and the client-id range owned by each alpha.
            self.lda_alpha = {str(alpha): n for alpha, n in self._sorted_by_alpha(self.lda_alpha).items()}

            if self.train_alpha is None:
                self.train_alpha = list(self.lda_alpha.keys())
            # Stored as the exact ``lda_alpha`` keys so membership tests cannot miss (1000 vs '1000.0').
            self.train_alpha = self._resolve_train_alpha(self.train_alpha)

            if self.test_alpha is None:
                self.test_alpha = self.lda_alpha
            self.test_alpha = self._sorted_by_alpha(self.test_alpha)

            assert self.pool_size == sum(self.lda_alpha.values()) == sum(self.test_alpha.values()), \
                'Num of clients must match total no. of clients defined in data.args.lda_alpha and data.args.test_alpha'

            self.fed_train_dir = self.get_fed_dir(self.lda_alpha)
            self.fed_test_dir = self.get_fed_dir(self.test_alpha)

    @staticmethod
    def _sorted_by_alpha(alpha_dict):
        """Return a copy of ``alpha_dict`` ordered by numeric alpha, largest first."""
        return dict(sorted(alpha_dict.items(), key=lambda item: float(item[0]), reverse=True))

    def _resolve_train_alpha(self, train_alpha):
        """Map each requested training alpha onto the matching ``self.lda_alpha`` key.

        Matching is numeric, so ``1000``, ``1000.0`` and ``'1000'`` all resolve to the same key.
        """
        resolved = []
        for a in map(str, train_alpha):
            target = float(a)
            key = next((k for k in self.lda_alpha if float(k) == target), None)
            assert key is not None, f'Train alpha ({a}) must be found in lda_alpha ({self.lda_alpha})'
            resolved.append(key)
        return sorted(resolved, key=float, reverse=True)

    def create_fl_partitions(self):
        """Build (or reuse) the train-layout and test-layout partition directories."""
        assert not self.pre_partition, 'FL dataset is pre-partitioned. Do not recreate fl partitions'
        self._create_fl_partition(self.lda_alpha)
        self._create_fl_partition(self.test_alpha)
        logger.info(f'Lda alpha:{self.lda_alpha}. Test alpha: {self.test_alpha}. Training with {len(self.get_available_training_clients())} clients.')

    def get_fed_dir(self, alpha_dict):
        """Return ``<dataset_fl_root>/<pool_size>/<alpha>_<n>_..._[QT<q>_]valratio<r>``."""
        name = ''.join(f'{alpha}_{num_clients}_' for alpha, num_clients in alpha_dict.items())
        if self.quantity_train_alpha is not None:
            name += f'QT{self.quantity_train_alpha}_'
        name += f'valratio{self.val_ratio}'

        return os.path.join(self.dataset_fl_root, str(self.pool_size), name)

    def get_available_training_clients(self):
        """Return the ids of clients whose alpha group is listed in ``train_alpha``."""
        if self.pre_partition:
            raise NotImplementedError  # overwrite this method if pre-partitioned

        train_alphas = set(self.train_alpha)
        available_clients = []
        start = 0
        for alpha, num_clients in self.lda_alpha.items():
            if alpha in train_alphas:
                available_clients.extend(range(start, start + num_clients))
            start += num_clients
        return available_clients

    def _create_fl_partition(self, alpha_dict):
        """Create (or reuse) the per-client partition directory for ``alpha_dict``."""
        # umask 0 makes the cached partitions writable by everyone (presumably so they can be
        # shared); restore the previous value so this process-wide setting does not leak into
        # every file the rest of the program creates.
        previous_umask = os.umask(0)
        try:
            self._build_fl_partition(alpha_dict)
        finally:
            os.umask(previous_umask)

    def _build_fl_partition(self, alpha_dict):
        """Body of :meth:`_create_fl_partition`, run with the umask already set."""
        dir_path = self.get_fed_dir(alpha_dict)

        if self.reset and os.path.exists(dir_path):
            logger.info(f'Reset flag is set for data federated splitting.. Deleting current {dir_path}')
            shutil.rmtree(dir_path)

        # If we already have a complete FL partition for this alpha mapping, just exit
        if self._has_fl_partition(dir_path):
            logger.info(f"FL partitioned dataset {dir_path} found.")
            if self.val_ratio > 0:
                # The root pointer may have been re-targeted by another configuration since this
                # split was built; its val.pt could then overlap this split's training data.
                self._refresh_root_val_pointer(os.path.abspath(os.path.join(dir_path, "val.pt")))
            return

        self.download()
        logger.info(f"Creating FL partitioned dataset {dir_path}..")

        if os.path.exists(dir_path):
            shutil.rmtree(dir_path)
        os.makedirs(dir_path, exist_ok=True)
        # ensure the global FL root exists (a root-level val.pt pointer is dropped there)
        os.makedirs(self.dataset_fl_root, exist_ok=True)

        fixed_seed = self.config.seed if hasattr(self.config, 'seed') else 42

        # Label-skew matrices start empty; the first partition processed (train.pt) fills them.
        dirichlet_dists = {}
        quantity_skew_dirichlet_dists = {}
        for alpha, num_clients in alpha_dict.items():
            dirichlet_dists[alpha] = None
            if self.quantity_train_alpha is not None and num_clients > 0:
                rng = np.random.default_rng(fixed_seed)
                quantity_skew_dirichlet_dists[alpha] = rng.dirichlet(
                    np.ones(num_clients) * self.quantity_train_alpha
                )
            else:
                quantity_skew_dirichlet_dists[alpha] = None

        for partition in self.partitions:
            raw_data_path = os.path.join(self.path_to_data, partition)
            server_data_path = os.path.join(self.dataset_fl_root, partition)

            # Keep a copy of the raw (global) partition at the FL root once
            if not os.path.exists(server_data_path):
                shutil.copyfile(raw_data_path, server_data_path)

            X_raw, Y_raw = torch.load(raw_data_path)

            # Carve the central validation split out of train, then point <FL root>/val.pt at it.
            if partition == 'train.pt' and self.val_ratio > 0.:
                val_path = os.path.join(dir_path, "val.pt")
                X_raw, X_val, Y_raw, y_val = train_test_split(
                    X_raw, Y_raw, test_size=self.val_ratio, stratify=Y_raw, shuffle=True, random_state=fixed_seed
                )
                torch.save([X_val, y_val], val_path)
                logger.info(f"val_path {val_path}")
                self._refresh_root_val_pointer(os.path.abspath(val_path))

            alpha_data = self._split_into_alpha_groups(X_raw, Y_raw, alpha_dict, fixed_seed)

            # Client ids run continuously across alpha groups, identically for train and test.
            client_offset = 0
            for alpha, group_data in alpha_data.items():
                num_clients = alpha_dict[alpha]
                qs_dirichlet = quantity_skew_dirichlet_dists[alpha] if 'train' in partition else None

                client_partitions, dirichlet_dist = create_lda_partitions(
                    dataset=group_data,
                    dirichlet_dist=dirichlet_dists[alpha],
                    quantity_skew_dirichlet=qs_dirichlet,
                    num_partitions=num_clients,
                    concentration=float(alpha),
                    accept_imbalanced=True,
                    seed=fixed_seed,
                )
                dirichlet_dists[alpha] = dirichlet_dist  # reuse for test to match label skew

                for i, client_data in enumerate(client_partitions):
                    client_path = os.path.join(dir_path, str(client_offset + i))
                    os.makedirs(client_path, exist_ok=True)
                    torch.save(client_data, os.path.join(client_path, partition))

                client_offset += num_clients

    def _split_into_alpha_groups(self, X, Y, alpha_dict, seed):
        """Split ``(X, Y)`` into one stratified chunk per alpha, sized by its share of clients."""
        alpha_data = {}
        unallocated_size = self.pool_size
        for alpha, num_clients in alpha_dict.items():
            if num_clients == 0:
                # A group without clients gets no data (train_test_split rejects test_size=0).
                continue
            ratio = num_clients / unallocated_size
            if ratio == 1:
                # This group owns every remaining client, so it takes all remaining samples.
                X_group, Y_group = X, Y
            else:
                X, X_group, Y, Y_group = train_test_split(
                    X, Y, test_size=ratio, stratify=Y, random_state=seed
                )
            alpha_data[alpha] = (X_group, Y_group)
            unallocated_size -= num_clients
        return alpha_data

    def _refresh_root_val_pointer(self, val_path_abs):
        """Point ``<dataset_fl_root>/val.pt`` at ``val_path_abs``.

        Uses a relative symlink when possible and falls back to copying the file where
        symlinks are not supported.
        """
        root_val = os.path.join(self.dataset_fl_root, "val.pt")
        target_rel = os.path.relpath(val_path_abs, self.dataset_fl_root)

        # Replace anything already there (file or possibly broken symlink) that points elsewhere.
        if os.path.lexists(root_val) and not self._is_symlink_to(root_val, target_rel):
            try:
                os.remove(root_val)  # removes file or symlink
            except OSError as err:
                logger.warning(f"Could not replace stale {root_val}: {err}")

        try:
            if not os.path.lexists(root_val):
                os.symlink(target_rel, root_val)
        except OSError:
            shutil.copyfile(val_path_abs, root_val)

    @staticmethod
    def _is_symlink_to(path, target):
        """Return True if ``path`` is a symlink whose stored target equals ``target``."""
        try:
            return os.path.islink(path) and os.path.normpath(os.readlink(path)) == os.path.normpath(target)
        except OSError:
            return False

    def _has_fl_partition(self, dir_path):
        """Return True if ``dir_path`` holds every client's train/test files, plus its own
        ``val.pt`` when a validation split is requested."""
        if self.val_ratio > 0 and not os.path.exists(os.path.join(dir_path, 'val.pt')):
            return False

        return all(
            os.path.exists(os.path.join(dir_path, str(cid), partition))
            for cid in range(self.pool_size)
            for partition in self.partitions
        )

    @abstractmethod
    def download(self):
        '''
        Downloads the dataset to self.path_to_data
        '''

    @abstractmethod
    def get_dataloader(self,
                    data_pool, # server/train/test
                    partition,
                    *args,
                    cid=None,
                    **kwargs):
        '''
        Class-specific dataloader
        '''
        




