import os
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import torchvision, torchvision.transforms as T
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import warnings
from sklearn.cluster import KMeans

# `datasets` is an optional dependency: DataModule only needs it for the
# text-dataset path. When the import fails the name is bound to None so the
# rest of the module can test for availability instead of crashing at import.
try:
    import datasets
except ImportError:
    print("Warning: 'datasets' library not found. NLPDataModule will not be available.")
    print("Please install it: pip install datasets")
    datasets = None

# Suppress FutureWarning messages for the whole process.
warnings.filterwarnings("ignore", category=FutureWarning)


class DataModule:
    def __init__(self, cfg):
        """Load the configured dataset and split its training set across clients.

        The vision datasets ``cifar100``, ``cifar10``, ``tinyimagenet`` and
        ``emnist`` are loaded through torchvision. Any other dataset name is
        treated as a text dataset and handed to ``datasets.load_dataset``.

        Args:
            cfg: Configuration mapping. Its ``"data"`` sub-mapping is read for
                ``dataset``, ``num_clients``, ``batch_size``, ``cache_root``,
                ``text_column``, ``label_column``, ``emnist_split``,
                ``partition`` (plus ``dirichlet_alpha`` / ``zipf_s``),
                ``min_samples_per_client``, ``noise_ratio`` and ``noise_mode``.
                The top-level ``seed`` key is read by the partition and
                label-noise helpers.

        Raises:
            RuntimeError: The Tiny ImageNet ``train``/``val`` folders are
                missing under the dataset root directory.
            ValueError: Unknown EMNIST split, or a text dataset without a
                ``train`` split or without both ``validation`` and ``test``.
            ImportError: A text dataset is requested but ``datasets`` is not
                installed.
        """
        self.cfg = cfg
        dcfg = cfg.get("data", {})
        self.dataset_name = dcfg.get("dataset", "cifar100").lower()
        self.num_clients = int(dcfg.get("num_clients", 10))
        self.bs = int(dcfg.get("batch_size", 128))
        self.cache_root = dcfg.get("cache_root", "./cache")

        # Defaults; the encoder-specific transform is installed later through
        # apply_encoder_transform().
        self.data_type = "vision"
        self.transform = T.Compose([T.ToTensor()])

        self.text_column = dcfg.get("text_column", "text")
        self.label_column = dcfg.get("label_column", "label")

        self._root_dir = f"./{self.dataset_name}/data"

        if self.dataset_name in ["cifar100", "cifar10", "tinyimagenet", "emnist"]:
            self.data_type = "vision"
            if self.dataset_name in ("cifar100", "cifar10"):
                if self.dataset_name == "cifar100":
                    DS, self.num_classes = torchvision.datasets.CIFAR100, 100
                else:
                    DS, self.num_classes = torchvision.datasets.CIFAR10, 10
                self.train_dataset = DS(root=self._root_dir, train=True, download=True, transform=self.transform)
                self.test_dataset = DS(root=self._root_dir, train=False, download=True, transform=self.transform)
            elif self.dataset_name == "tinyimagenet":
                self.num_classes = 200
                train_path = os.path.join(self._root_dir, 'tiny-imagenet-200', 'train')
                test_path = os.path.join(self._root_dir, 'tiny-imagenet-200', 'val')
                # There is no automatic download for this dataset: both folders
                # must already exist in a layout ImageFolder can read.
                if not (os.path.exists(train_path) and os.path.exists(test_path)):
                    raise RuntimeError("Tiny ImageNet dataset not found or not prepared.")

                self.train_dataset = ImageFolder(root=train_path, transform=self.transform)
                self.test_dataset = ImageFolder(root=test_path, transform=self.transform)
            elif self.dataset_name == "emnist":
                DS = torchvision.datasets.EMNIST
                self.emnist_split = dcfg.get("emnist_split", "balanced").lower()
                split_classes = {
                    "balanced": 47,
                    "byclass": 62,
                    "bymerge": 47,
                    "digits": 10,
                    "letters": 26,
                    "mnist": 10
                }
                if self.emnist_split not in split_classes:
                    raise ValueError(f"Unknown EMNIST split: {self.emnist_split}. Must be one of {list(split_classes.keys())}")
                self.num_classes = split_classes[self.emnist_split]

                self.train_dataset = DS(root=self._root_dir, split=self.emnist_split, train=True, download=True, transform=self.transform)
                self.test_dataset = DS(root=self._root_dir, split=self.emnist_split, train=False, download=True, transform=self.transform)

            self.train_labels_np = np.array(self.train_dataset.targets, dtype=np.int64)

        else:
            if datasets is None:
                raise ImportError("Please install 'datasets' to use text datasets.")
            print(f"[Data] Loading text dataset: {self.dataset_name}")
            self.data_type = "text"

            try:
                dset = datasets.load_dataset(self.dataset_name)
            except Exception as e:
                # Report which dataset failed, then let the original error propagate.
                print(f"Error loading dataset {self.dataset_name}. {e}")
                raise

            if "train" not in dset:
                raise ValueError(f"Text dataset {self.dataset_name} does not have a 'train' split.")
            self.train_dataset = dset["train"]

            # Evaluation split: "validation" is preferred over "test".
            if "validation" in dset:
                self.test_dataset = dset["validation"]
            elif "test" in dset:
                self.test_dataset = dset["test"]
            else:
                raise ValueError(f"Text dataset {self.dataset_name} has no 'test' or 'validation' split.")

            self.train_labels_np = np.array(self.train_dataset[self.label_column], dtype=np.int64)

            try:
                self.num_classes = self.train_dataset.features[self.label_column].num_classes
            except (AttributeError, KeyError, TypeError):
                # The label feature exposes no class count: infer it from the
                # data. Cast to a plain int so num_classes has the same type as
                # on every other path (np.max returns a NumPy scalar).
                self.num_classes = int(np.max(self.train_labels_np)) + 1
            print(f"[Data] Text dataset loaded. Num classes: {self.num_classes}")

        self.partition = self._normalize_partition(dcfg)
        self.client_indices = self._make_partition()

        # feature_skew starts with empty clients (the partition is created
        # later from precomputed features), so there is nothing to repair yet.
        min_per = int(dcfg.get("min_samples_per_client", 1))
        if min_per > 0 and self.partition.get("type") != "feature_skew":
            self.client_indices = self._ensure_min_per_client(self.client_indices, min_per=min_per)

        # Precomputed encoder features, filled by maybe_precompute().
        self._train_feats = None
        self._train_targets = None
        self._test_feats = None
        self._test_targets = None

        # Per-client noisy labels, aligned with client_indices[cid].
        self._noisy_client_labels = {}
        self._build_label_noise(dcfg)

    def create_feature_skew_partition(self, features):
        print("[Data] Creating feature skew partition based on pre-computed features...")
        y_train = self.train_labels_np
        all_indices = np.arange(len(y_train))
        idxs_per_client = {i: [] for i in range(self.num_clients)}
        
        for c in range(self.num_classes):
            class_mask = (y_train == c)
            class_indices = all_indices[class_mask]
            
            if len(class_indices) == 0:
                continue
            
            z_c = features[class_mask].cpu().numpy()
            n_clusters = self.num_clients

            if len(class_indices) < n_clusters:
                print(f"[Data] WARN: Class {c} has {len(class_indices)} samples, less than num_clients ({n_clusters}). Distributing this class randomly.")
                shuffled_indices = np.random.permutation(class_indices)
                splits = np.array_split(shuffled_indices, self.num_clients)
                for i in range(self.num_clients):
                    idxs_per_client[i].extend(splits[i].tolist())
                continue

            kmeans = KMeans(n_clusters=n_clusters, random_state=int(self.cfg.get("seed", 0)), n_init='auto')
            cluster_labels = kmeans.fit_predict(z_c)

            for client_id in range(n_clusters):
                cluster_mask = (cluster_labels == client_id)
                indices_for_this_cluster = class_indices[cluster_mask]
                idxs_per_client[client_id].extend(indices_for_this_cluster.tolist())
        
        print("[Data] Successfully created feature_skew partition.")
        self.client_indices = idxs_per_client
        return idxs_per_client
        
    def _normalize_partition(self, dcfg):
        """Turn the ``partition`` config entry into a canonical spec dict.

        ``dcfg["partition"]`` may be a string (the partition type) or a dict
        with a ``type`` key and optional per-type parameters. The result always
        has a ``type`` key, plus ``alpha`` for ``dirichlet`` and ``s`` for
        ``zipf``. Parameters given inside a dict spec take precedence over the
        ``dirichlet_alpha`` / ``zipf_s`` keys of ``dcfg``. Anything unrecognised
        (unknown type name, or a value that is neither str nor dict) falls back
        to a Dirichlet spec using ``dcfg``'s ``dirichlet_alpha`` (default 0.1).

        Args:
            dcfg: The ``data`` section of the configuration.

        Returns:
            The normalised partition spec.
        """
        def fallback():
            return {"type": "dirichlet", "alpha": float(dcfg.get("dirichlet_alpha", 0.1))}

        raw = dcfg.get("partition", "dirichlet")
        if isinstance(raw, dict):
            kind = raw.get("type", "dirichlet").lower()
            overrides = raw
        elif isinstance(raw, str):
            kind = raw.lower()
            overrides = {}  # a bare string carries no per-type parameters
        else:
            return fallback()

        if kind in ("iid", "shard1", "shard2", "feature_skew"):
            return {"type": kind}
        if kind == "zipf":
            exponent = overrides.get("s", dcfg.get("zipf_s", 1.0))
            return {"type": "zipf", "s": float(exponent)}
        if kind == "dirichlet":
            concentration = overrides.get("alpha", dcfg.get("dirichlet_alpha", 0.1))
            return {"type": "dirichlet", "alpha": float(concentration)}
        return fallback()

    def apply_encoder_transform(self, enc_transform):
        """Install the encoder's input transform on the module and its datasets.

        For EMNIST a ``Grayscale(num_output_channels=3)`` step is put in front
        of the encoder transform. The encoder transform may be a ``Compose``
        (its steps are spliced in, as before), any other callable (used as a
        single step), or ``None`` (only the grayscale step is applied).

        ``self.transform`` is always updated. The ``transform`` attribute of
        the train/test datasets is only updated for vision data.

        Args:
            enc_transform: Transform supplied by the encoder.
        """
        final_transform = enc_transform

        if self.dataset_name == "emnist":
            if enc_transform is None:
                steps = []
            else:
                # Compose exposes its steps as `.transforms`; anything else is
                # treated as one opaque step rather than raising AttributeError.
                steps = getattr(enc_transform, "transforms", None)
                steps = [enc_transform] if steps is None else list(steps)
            final_transform = T.Compose([T.Grayscale(num_output_channels=3)] + steps)

        self.transform = final_transform

        if self.data_type == "vision":
            self.train_dataset.transform = self.transform
            self.test_dataset.transform = self.transform

    def client_ids(self):
        """Return the client identifiers ``0 .. num_clients - 1`` as a new list."""
        return [cid for cid in range(self.num_clients)]

    def _make_partition(self):
        ptype = self.partition.get("type")
        if ptype == "feature_skew":
            print("[Data] feature_skew partition will be created later (after feature pre-computation).")
            return {i: [] for i in range(self.num_clients)}

        p = self.partition
        labels = self.train_labels_np
        idxs_per_client = {i: [] for i in range(self.num_clients)}

        if ptype == "iid":
            perm = np.random.permutation(len(labels))
            chunks = np.array_split(perm, self.num_clients)
            for i, ch in enumerate(chunks):
                idxs_per_client[i] = ch.tolist()
            return idxs_per_client

        if ptype == "dirichlet":
            alpha = float(p.get("alpha", 0.1))
            class_indices = [np.where(labels == c)[0] for c in range(self.num_classes)]
            for c_idxs in class_indices:
                if len(c_idxs) == 0:
                    continue 
                np.random.shuffle(c_idxs)
                props = np.random.dirichlet([alpha] * self.num_clients)
                cut = (np.cumsum(props) * len(c_idxs)).astype(int)[:-1]
                splits = np.split(c_idxs, cut)
                for i in range(self.num_clients):
                    idxs_per_client[i].extend(splits[i].tolist())
            for i in range(self.num_clients):
                np.random.shuffle(idxs_per_client[i])
            return idxs_per_client

        if ptype in ("shard1", "shard2"):
            shards_per_client = 1 if ptype == "shard1" else 2
            total_shards = self.num_clients * shards_per_client
            order = np.argsort(labels)
            shards = np.array_split(order, total_shards)
            rng = np.random.RandomState(int(self.cfg.get("seed", 1337)))
            rng.shuffle(shards)
            for i in range(self.num_clients):
                pick = shards[i*shards_per_client:(i+1)*shards_per_client]
                acc = []
                for s in pick:
                    if s.size > 0:
                        acc.extend(s.tolist())
                idxs_per_client[i].extend(acc)
            return idxs_per_client

        if ptype == "zipf":
            s = float(p.get("s", 1.0))
            ranks = np.arange(1, self.num_clients + 1)
            probs = (1.0 / (ranks ** s)); probs /= probs.sum()
            counts = (probs * len(labels)).astype(int)
            diff = len(labels) - counts.sum()
            counts[:diff] += 1
            perm = np.random.permutation(len(labels))
            start = 0
            for i in range(self.num_clients):
                end = start + counts[i]
                idxs_per_client[i] = perm[start:end].tolist()
                start = end
            return idxs_per_client
        
        alpha = 0.1
        class_indices = [np.where(labels == c)[0] for c in range(self.num_classes)]
        for c in range(self.num_classes):
            idxs = class_indices[c]
            if len(idxs) == 0: continue
            np.random.shuffle(idxs)
            props = np.random.dirichlet([alpha] * self.num_clients)
            cut = (np.cumsum(props) * len(idxs)).astype(int)[:-1]
            splits = np.split(idxs, cut)
            for i in range(self.num_clients):
                idxs_per_client[i].extend(splits[i].tolist())
        for i in range(self.num_clients):
            np.random.shuffle(idxs_per_client[i])
        return idxs_per_client


    def _ensure_min_per_client(self, idxs_per_client, min_per=1):
        """Move samples between clients until each one holds ``min_per`` samples.

        One randomly chosen sample at a time is taken from the currently
        largest client that has more than ``min_per`` samples and given to a
        client that is short. Random choices come from a ``RandomState``
        seeded with ``cfg["seed"]`` (default 1337).

        The repair is skipped (input returned unchanged) when there are not
        enough samples overall. It stops early, with a warning, after 10 000
        transfers or when no client can donate.

        Args:
            idxs_per_client: Dict mapping client id to a list of indices.
            min_per: Minimum number of samples each client should end up with.

        Returns:
            A new dict of index lists, or ``idxs_per_client`` itself when the
            repair is skipped.
        """
        n_total = sum(map(len, idxs_per_client.values()))
        if n_total < self.num_clients * min_per:
            print(f"[Data] WARN: not enough samples to guarantee "
                  f"min_per={min_per} for {self.num_clients} clients (N={n_total}). Skipping repair.")
            return idxs_per_client

        rng = np.random.RandomState(int(self.cfg.get("seed", 1337)))
        pools = {cid: np.array(members, dtype=np.int64)
                 for cid, members in idxs_per_client.items()}

        # Clients still below the minimum, served first-in first-out.
        pending = [cid for cid, pool in pools.items() if pool.shape[0] < min_per]
        attempts = 0
        while pending:
            attempts += 1
            if attempts > 10_000:
                print("[Data] WARN: repair loop exceeded 10k iterations. Aborting repair.")
                break

            counts = {cid: pool.shape[0] for cid, pool in pools.items()}
            rich = [cid for cid, count in counts.items() if count > min_per]
            if not rich:
                print("[Data] WARN: no donors with > min_per. Stopping repair.")
                break

            # Largest donor; ties go to the client that comes first in the dict.
            donor = max(rich, key=counts.__getitem__)
            receiver = pending.pop(0)

            pick = rng.randint(0, pools[donor].shape[0])
            moved = pools[donor][pick:pick + 1]
            pools[donor] = np.delete(pools[donor], pick)
            pools[receiver] = np.concatenate([pools[receiver], moved])

            if pools[receiver].shape[0] < min_per:
                pending.append(receiver)

        return {cid: pool.tolist() for cid, pool in pools.items()}

    def _build_label_noise(self, dcfg):
        ratio = float(dcfg.get("noise_ratio", 0.0))
        mode  = str(dcfg.get("noise_mode", "")).lower()
        if ratio <= 0.0 or mode not in ("symmetric", "asymmetric"):
            return

        seed = int(self.cfg.get("seed", 1337))
        
        if mode == "asymmetric":
            pass

        labels_np = self.train_labels_np.copy()
        
        for cid in range(self.num_clients):
            idxs = self.client_indices[cid]
            if len(idxs) == 0:
                self._noisy_client_labels[cid] = torch.empty(0, dtype=torch.long)
                continue

            y = labels_np[idxs].copy()
            rng_c = np.random.RandomState(seed + 10000 + cid)

            for c in np.unique(y):
                pos = np.where(y == c)[0]
                if pos.size == 0:
                    continue
                k = int(np.floor(pos.size * ratio))
                if k <= 0:
                    continue
                flip_local = rng_c.choice(pos, size=k, replace=False)

                if mode == "symmetric":
                    choices = np.array([x for x in range(self.num_classes) if x != c], dtype=np.int64)
                    if len(choices) == 0: continue
                    y[flip_local] = rng_c.choice(choices, size=flip_local.size, replace=True)
                else:
                    pass

            self._noisy_client_labels[cid] = torch.from_numpy(y.astype(np.int64))

        print(f"[Data] label-noise enabled: mode={mode}, ratio={ratio:.3f}")

    def _build_cifar100_coarse_maps(self):
        base = os.path.join(self._root_dir, "cifar-100-python")
        train_pkl = os.path.join(base, "train")
        meta_pkl  = os.path.join(base, "meta")
        if not (os.path.exists(train_pkl) and os.path.exists(meta_pkl)):
            print("[Data] WARN: CIFAR-100 coarse mapping files not found; run dataset download first.")
            return None, None

        def _load_pickle(path):
            with open(path, "rb") as f:
                return pickle.load(f, encoding="latin1")

        train_obj = _load_pickle(train_pkl)
        fine_labels_train   = np.array(train_obj["fine_labels"], dtype=np.int64)
        coarse_labels_train = np.array(train_obj["coarse_labels"], dtype=np.int64)

        fine_to_coarse = np.full((100,), -1, dtype=np.int64)
        for f in range(100):
            cs = np.unique(coarse_labels_train[fine_labels_train == f])
            if cs.size > 0:
                fine_to_coarse[f] = int(cs[0])

        coarse_groups = {c: [] for c in range(20)}
        for f in range(100):
            c = fine_to_coarse[f]
            if c >= 0:
                coarse_groups[c].append(f)
        for c in coarse_groups:
            coarse_groups[c] = sorted(coarse_groups[c])

        return fine_to_coarse, coarse_groups


    def maybe_precompute(self, enc_mgr):
        """Encode the train and test sets once and keep the features in memory.

        Does nothing when ``enc_mgr`` is falsy or ``enc_mgr.can_precompute()``
        is false. Otherwise features are loaded from ``<cache_root>/<tag>_train.pt``
        and ``<tag>_test.pt`` when both exist and caching is enabled
        (``enc_mgr.save_cache``, default True). If not, they are computed batch
        by batch with ``enc_mgr.encode`` and, when caching is enabled, saved.

        The cache tag is built from the dataset name, ``enc_mgr.encoder_type``
        and, for text, the last path component of ``enc_mgr.model_name_or_path``;
        for other data, ``enc_mgr.model_size`` plus an optional ``_p<patch_size>``
        suffix.

        Results are stored in ``_train_feats`` / ``_train_targets`` and
        ``_test_feats`` / ``_test_targets``.

        Args:
            enc_mgr: Encoder manager object providing the attributes above.
        """
        if not enc_mgr or not enc_mgr.can_precompute():
            return

        os.makedirs(self.cache_root, exist_ok=True)

        patch = getattr(enc_mgr, 'patch_size', None)
        patch_suffix = f"_p{patch}" if patch else ""

        if self.data_type == "text":
            model_tag = enc_mgr.model_name_or_path.rsplit('/', 1)[-1]
            tag = f"{self.dataset_name}_{enc_mgr.encoder_type}_{model_tag}"
        else:
            tag = f"{self.dataset_name}_{enc_mgr.encoder_type}_{enc_mgr.model_size}{patch_suffix}"

        tr_path = os.path.join(self.cache_root, f"{tag}_train.pt")
        te_path = os.path.join(self.cache_root, f"{tag}_test.pt")
        use_cache = getattr(enc_mgr, "save_cache", True)

        if use_cache and os.path.exists(tr_path) and os.path.exists(te_path):
            cached = torch.load(tr_path)
            self._train_feats, self._train_targets = cached["x"], cached["y"]
            cached = torch.load(te_path)
            self._test_feats, self._test_targets = cached["x"], cached["y"]
            print(f"[Data] loaded precomputed features from cache: {tag}")
            return

        print(f"[Data] precomputing features: {tag}")
        bs = int(getattr(enc_mgr, "precompute_batch_size", 256))
        dev = getattr(enc_mgr, "device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))

        def encode_images(ds, split_name):
            # Batches come from a DataLoader; targets are read from ds.targets.
            loader = DataLoader(ds, batch_size=bs, shuffle=False, num_workers=2, pin_memory=True)
            chunks = []
            with torch.no_grad():
                for images, _ in tqdm(loader, desc=f"precompute[{split_name}]", leave=True):
                    images = images.to(dev, non_blocking=True)
                    chunks.append(enc_mgr.encode(images).detach().cpu())
            feats = torch.cat(chunks, dim=0)
            targets = torch.tensor(ds.targets, dtype=torch.long)
            assert feats.shape[0] == len(targets), "precompute shape mismatch with original targets"
            return feats, targets

        def encode_texts(ds, split_name):
            # For text data self.transform is used as the tokenizer callable.
            tokenizer = self.transform
            texts = ds[self.text_column]
            targets = torch.tensor(ds[self.label_column], dtype=torch.long)
            chunks = []
            with torch.no_grad():
                for start in tqdm(range(0, len(texts), bs), desc=f"precompute[{split_name}]", leave=True):
                    tokens = tokenizer(texts[start : start + bs])
                    chunks.append(enc_mgr.encode(tokens).detach().cpu())
            feats = torch.cat(chunks, dim=0)
            assert feats.shape[0] == len(targets), "precompute shape mismatch with original targets"
            return feats, targets

        encoder_for = {"vision": encode_images, "text": encode_texts}
        run = encoder_for.get(self.data_type)
        if run is not None:
            self._train_feats, self._train_targets = run(self.train_dataset, "train")
            self._test_feats, self._test_targets = run(self.test_dataset, "test")

        if use_cache:
            torch.save({"x": self._train_feats, "y": self._train_targets}, tr_path)
            torch.save({"x": self._test_feats, "y": self._test_targets}, te_path)

    def get_client_data(self, cid, input_type="features"):
        idxs = self.client_indices[cid]

        if not idxs:
             return (torch.empty(0), torch.empty(0, dtype=torch.long))

        if cid in self._noisy_client_labels:
            y = self._noisy_client_labels[cid]
        else:
            if input_type == "features" and self._train_feats is not None:
                y = self._train_targets[idxs]
            else:
                y = torch.tensor(self.train_labels_np[idxs], dtype=torch.long)

        if input_type == "features":
            if self._train_feats is None:
                 raise ValueError("Requested 'features' input_type, but features are not precomputed.")
            x = self._train_feats[idxs]
            return (x, y)
        
        if self.data_type == "vision":
            imgs = [self.train_dataset[i][0] for i in idxs]
            x = torch.stack(imgs)
            return (x, y)
        elif self.data_type == "text":
            x = [self.train_dataset[i][self.text_column] for i in idxs]
            return (x, y)

    def get_testset(self, input_type="features"):
        if input_type == "features":
            if self._test_feats is None:
                raise ValueError("Requested 'features' input_type for test set, but features are not precomputed.")
            return (self._test_feats, self._test_targets)

        if self.data_type == "vision":
            loader = DataLoader(self.test_dataset, batch_size=len(self.test_dataset), shuffle=False, num_workers=2)
            x_all, _ = next(iter(loader))
            y_all = torch.tensor(self.test_dataset.targets, dtype=torch.long)
            return (x_all, y_all)
        elif self.data_type == "text":
            y_all = torch.tensor(self.test_dataset[self.label_column], dtype=torch.long)
            x_all = self.test_dataset[self.text_column]
            return (x_all, y_all)