# src/data/shakespeare_dataset.py

import os
import json
import math
import logging
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple, Dict, Iterable, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from .fl_dataset import FederatedDataset

logger = logging.getLogger(__name__)


# ----------------------------
# Small tensor dataset (X: [N, L], y: [N, L])
# ----------------------------

class SequenceDatasetFL(Dataset):
    """Map-style dataset over an ``(X, y)`` pair stored in one ``.pt`` file.

    The file must contain a 2-tuple as written by ``torch.save((X, y), path)``.
    Each member may be a tensor or anything ``torch.as_tensor`` accepts (for
    example a NumPy array); both are converted to ``torch.long``.

    Args:
        path_to_pt: Path of the ``.pt`` file to load.
    """

    def __init__(self, path_to_pt: str):
        super().__init__()
        self.X, self.y = self._load_pair(path_to_pt)
        self.X = self._to_long(self.X)
        self.y = self._to_long(self.y)

    @staticmethod
    def _load_pair(path_to_pt: str):
        """Load the pickled ``(X, y)`` tuple across PyTorch versions."""
        # PyTorch >= 2.6 defaults to weights_only=True, which refuses tuples of
        # NumPy arrays. SECURITY: weights_only=False performs full unpickling,
        # so only point this at files produced by a trusted pipeline.
        try:
            return torch.load(path_to_pt, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not know the weights_only keyword.
            return torch.load(path_to_pt)

    @staticmethod
    def _to_long(value) -> torch.Tensor:
        """Return ``value`` as a ``torch.long`` tensor."""
        if torch.is_tensor(value):
            return value.long()
        return torch.as_tensor(value, dtype=torch.long)

    def __len__(self):
        return int(self.X.shape[0])

    def __getitem__(self, idx: int):
        """Return ``(x, y)`` for row ``idx``.

        When the stored target has at least one dimension, only its last
        element is returned as the label.
        """
        x = self.X[idx]
        y = self.y[idx]
        # If y is a sequence (T,), use the last token as the class id.
        # NOTE(review): when the file was saved with y == X, this label equals
        # the last input token - confirm that is what the model expects.
        if y.ndim > 0:
            y = y[-1]
        return x, y


# ----------------------------
# Helpers
# ----------------------------

def _save_server_train_val(
    root_dir: str,
    X_list: List[np.ndarray],
    y_list: List[np.ndarray],
    val_ratio: float = 0.1,
    seed: int = 123,
    max_train: Optional[int] = None,
) -> None:
    """Concatenate server-train shards, optionally cap, and split into train/val.

    Args:
        root_dir: Directory receiving ``train.pt`` (and ``val.pt``).
        X_list: Input shards, concatenated along axis 0.
        y_list: Target shards, concatenated along axis 0.
        val_ratio: Fraction of rows moved to ``val.pt``. When it is ``<= 0``
            (or there are no rows) only ``train.pt`` is written and any
            existing ``val.pt`` is left untouched.
        seed: Seed for both the cap subsample and the train/val permutation.
        max_train: Optional cap on the number of rows kept before splitting.
            Truncated to an integer; values that truncate to ``<= 0`` (and
            non-numeric / non-finite values) mean "no cap".
    """
    if X_list:
        X = np.concatenate(X_list, axis=0)
        y = np.concatenate(y_list, axis=0)
    else:
        # No shards at all: write empty arrays with the fixed width of 80.
        X = np.empty((0, 80), dtype=np.int64)
        y = np.empty((0, 80), dtype=np.int64)

    # Truncate the cap once so a fractional value such as 0.5 cannot end up
    # requesting a zero-row subsample.
    cap = 0
    if isinstance(max_train, (int, float)):
        try:
            cap = int(max_train)
        except (ValueError, OverflowError):  # NaN / infinity
            cap = 0
    if 0 < cap < X.shape[0]:
        rng = np.random.RandomState(seed)
        sel = rng.choice(X.shape[0], size=cap, replace=False)
        X, y = X[sel], y[sel]

    n = X.shape[0]
    train_path = os.path.join(root_dir, "train.pt")
    if n == 0 or val_ratio <= 0.0:
        torch.save((X, y), train_path, pickle_protocol=4)
        return

    rng = np.random.RandomState(seed)
    perm = rng.permutation(n)
    if n > 1:
        # At least one validation row, but always leave at least one row for
        # training even when val_ratio is close to (or above) 1.
        n_val = min(n - 1, max(1, int(round(val_ratio * n))))
    else:
        n_val = 0
    vidx, tidx = perm[:n_val], perm[n_val:]

    torch.save((X[tidx], y[tidx]), train_path, pickle_protocol=4)
    torch.save((X[vidx], y[vidx]), os.path.join(root_dir, "val.pt"), pickle_protocol=4)


def _leaf_read_dir_json(data_dir: str):
    """Read LEAF-style JSON shards (users → {'x': [...], 'y': [...]}).

    Every ``*.json`` file directly inside ``data_dir`` is parsed and its
    ``"user_data"`` mapping merged into one dictionary. Files are processed in
    sorted filename order, so when the same user appears in several files the
    entry from the lexicographically last file wins.

    Args:
        data_dir: Directory containing the JSON shards.

    Returns:
        ``(clients, groups, data)`` where ``clients`` is the sorted list of
        user keys found in the merged ``user_data``, ``groups`` collects every
        ``"hierarchies"`` entry in file order, and ``data`` is a
        ``defaultdict`` (missing users map to ``None``) of the merged records.
    """
    groups = []
    data = defaultdict(lambda: None)

    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    shard_names = sorted(f for f in os.listdir(data_dir) if f.endswith(".json"))
    for fname in shard_names:
        # JSON interchange text is UTF-8; do not depend on the locale default.
        with open(os.path.join(data_dir, fname), "r", encoding="utf-8") as inf:
            cdata = json.load(inf)
        if "hierarchies" in cdata:
            groups.extend(cdata["hierarchies"])
        data.update(cdata.get("user_data", {}))

    # The client list is derived from the merged records, not from "users".
    clients = sorted(data.keys())
    return clients, groups, data


def _build_vocab_from_clients(train_data: Dict, test_data: Optional[Dict] = None) -> Tuple[Dict[str, int], List[str]]:
    """Build a character vocabulary from the string samples of all users.

    Only samples that are ``str`` contribute characters; any other sample
    (lists, numbers, ...) found under ``"x"`` or ``"y"`` is ignored.

    Args:
        train_data: Mapping ``user -> {"x": [...], "y": [...]}``.
        test_data: Optional mapping of the same shape.

    Returns:
        ``(stoi, itos)``: ``itos`` starts with ``"<pad>"`` and ``"<unk>"``
        followed by the sorted characters; ``stoi`` is its inverse mapping.
    """
    seen = set()

    for split in (train_data, test_data):
        if not split:
            continue
        for uid in split.keys():
            inputs = split[uid].get("x", [])
            targets = split[uid].get("y", [])
            for field in (inputs, targets):
                for sample in field:
                    if isinstance(sample, str):
                        seen.update(sample)

    # Specials are reserved at the front of the table.
    itos = ["<pad>", "<unk>"]
    itos.extend(sorted(seen))
    stoi = {}
    for position, token in enumerate(itos):
        stoi[token] = position
    return stoi, itos


def _encode_seq(x: Iterable, stoi: Dict[str, int]) -> np.ndarray:
    """Encode ``x`` as an ``int64`` id array.

    * ``str``: each character is looked up in ``stoi``.
    * iterable whose NumPy dtype is integral: returned as ``int64`` unchanged.
    * any other iterable: each element is looked up by its ``str()`` form.

    Unknown tokens map to ``stoi["<unk>"]``.
    """
    def _lookup(token):
        return stoi.get(token, stoi["<unk>"])

    if isinstance(x, str):
        codes = [_lookup(ch) for ch in x]
        return np.array(codes, dtype=np.int64)

    values = np.asarray(list(x))
    if not np.issubdtype(values.dtype, np.integer):
        codes = [_lookup(str(item)) for item in values]
        return np.array(codes, dtype=np.int64)
    return values.astype(np.int64)


def _shard_indices(idxs: List[int], n_shards: int) -> List[List[int]]:
    """Split a list of indices into n_shards contiguous shards (balanced).

    Returns exactly ``min(n_shards, len(idxs))`` non-empty shards whose sizes
    differ by at most one (the larger shards come first). Slicing with a
    ``ceil(n / n_shards)`` step instead can return fewer shards than requested
    (5 items / 4 shards -> 3 shards), which leaves the caller short of its
    target client count.

    Args:
        idxs: Indices to split; order is preserved.
        n_shards: Requested number of shards.

    Returns:
        List of contiguous slices of ``idxs``. When ``n_shards <= 1`` or
        ``idxs`` is empty the result is ``[idxs]`` (the same object, one shard).
    """
    n = len(idxs)
    if n_shards <= 1 or n == 0:
        return [idxs]

    count = min(int(n_shards), n)
    base, extra = divmod(n, count)
    shards = []
    start = 0
    for i in range(count):
        # The first `extra` shards absorb the remainder, one element each.
        stop = start + base + (1 if i < extra else 0)
        shards.append(idxs[start:stop])
        start = stop
    return shards


def _tokenize_and_chunk_ds(
    ds_split,
    text_key: str,
    rows_idx: List[int],
    seq_len: int,
    char2id: Dict[str, int],
    stride: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Join the selected rows into one character stream and cut fixed windows.

    The text of ``ds_split[i][text_key]`` for every ``i`` in ``rows_idx`` is
    concatenated, mapped through ``char2id`` (unknown characters become
    ``char2id["<unk>"]``) and sliced into windows of ``seq_len`` tokens that
    start at ``0, stride, 2 * stride, ...``. A falsy or non-positive ``stride``
    means ``seq_len`` (non-overlapping windows).

    Returns:
        ``(X, y)`` of shape ``[N, seq_len]``. When windows exist, ``y`` is the
        very same array as ``X`` (targets are shifted in the loss); otherwise
        two empty ``int64`` arrays are returned.
    """
    if not rows_idx:
        no_x = np.empty((0, seq_len), dtype=np.int64)
        no_y = np.empty((0, seq_len), dtype=np.int64)
        return no_x, no_y

    pieces = [str(ds_split[row][text_key]) for row in rows_idx]
    stream = "".join(pieces)
    token_ids = np.fromiter(
        (char2id.get(ch, char2id["<unk>"]) for ch in stream),
        dtype=np.int64,
    )

    window = int(seq_len)
    if stride and stride > 0:
        hop = int(stride)
    else:
        hop = window

    # Not enough tokens for even a single window.
    if token_ids.size < window:
        return np.empty((0, window), dtype=np.int64), np.empty((0, window), dtype=np.int64)

    slices = []
    for begin in range(0, token_ids.size - window + 1, hop):
        slices.append(token_ids[begin:begin + window])
    X = np.stack(slices, axis=0)
    return X, X  # y == X (targets are shifted in the loss)


def _cap_sequences(
    X: np.ndarray,
    Y: np.ndarray,
    cap: Optional[int],
    seed: int = 123,
) -> Tuple[np.ndarray, np.ndarray]:
    """Randomly cap number of sequences to `cap` (if provided).

    Args:
        X: Array whose first axis indexes sequences.
        Y: Array indexed with the same row selection as ``X``.
        cap: Maximum number of rows to keep. It is truncated to an integer;
            ``None``, non-numeric and non-finite values, and anything that
            truncates to ``<= 0`` all mean "no cap".
        seed: Seed of the ``RandomState`` used to pick rows (without
            replacement).

    Returns:
        ``(X, Y)`` unchanged when no capping is needed, otherwise the same
        random row subset of both arrays.
    """
    if not isinstance(cap, (int, float)):
        return X, Y
    try:
        # Truncate once so a fractional cap in (0, 1) is treated like 0
        # ("no cap") instead of requesting a zero-row sample.
        limit = int(cap)
    except (ValueError, OverflowError):  # NaN / infinity
        return X, Y
    if limit <= 0 or X.shape[0] <= limit:
        return X, Y

    rng = np.random.RandomState(seed)
    sel = rng.choice(X.shape[0], size=limit, replace=False)
    return X[sel], Y[sel]


# ----------------------------
# Dataset
# ----------------------------

class ShakespeareDataset(FederatedDataset):
    """
    Federated char-level next-token dataset.

    Saves:
      - Per client: dataset_fl_root/<cid>/{train,test}.pt  (X:[N,L], y:[N,L] with y==X)
      - Server:     dataset_fl_root/{train,val,test}.pt
      - Vocab:      dataset_fl_root/vocab.json ({"stoi": {...}, "itos": [...]})
    """
    def __init__(self, *args, **kwargs):
        """Initialise the dataset, building partitions when needed.

        All arguments are forwarded to the base class. Afterwards this reads
        ``self.ckp.config`` and ``self.dataset_fl_root`` (presumably provided
        by the base class - confirm in ``FederatedDataset``).

        Partitions are (re)built when ``data.args.reset`` is truthy or when any
        of ``train.pt``, ``test.pt`` or ``vocab.json`` is missing under
        ``dataset_fl_root``. Finally the vocabulary metadata is loaded into
        ``self.stoi`` / ``self.itos`` / ``self.vocab_size``; if it cannot be
        read an empty vocabulary is used and a warning is logged.
        """
        super().__init__(*args, **kwargs)

        self.seq_len = int(getattr(self.ckp.config.data.args, "seq_len", 80))
        self.num_clients = int(self.ckp.config.simulation.num_clients)

        # (Re)build if missing or reset requested
        required = [
            os.path.join(self.dataset_fl_root, name)
            for name in ("train.pt", "test.pt", "vocab.json")
        ]
        reset = bool(getattr(self.ckp.config.data.args, "reset", False))
        if reset or not all(os.path.exists(p) for p in required):
            self.create_fl_partitions()
        self.pre_partition = True

        # Load vocab metadata (best effort: fall back to an empty vocab).
        vpath = os.path.join(self.dataset_fl_root, "vocab.json")
        try:
            with open(vpath, "r") as f:
                meta = json.load(f)
            self.stoi = {str(k): int(v) for k, v in meta.get("stoi", {}).items()}
            self.itos = [str(x) for x in meta.get("itos", [])]
        except Exception as err:  # deliberate best-effort, but no longer silent
            logger.warning(
                "Could not load vocab metadata from %s (%s); using an empty vocab",
                vpath,
                err,
            )
            self.stoi, self.itos = {}, []
        self.vocab_size = len(self.itos)

    # -----------------------------------------------------
    # Partition creation (LEAF JSON or HF speaker-slicing)
    # -----------------------------------------------------
    def create_fl_partitions(self):
        """Build and write all client/server partitions plus the vocabulary.

        Two data sources are supported:

        1. LEAF JSON - used when both ``<path_to_data>/train`` and
           ``<path_to_data>/test`` exist. Each user's ``"x"`` samples (train
           followed by test) are concatenated into one character stream, which
           is cut into contiguous shards; each shard becomes one client with
           an 80/20 train/test split.
        2. Hugging Face fallback - downloads ``flwrlabs/shakespeare``, groups
           rows by a speaker-like column, shuffles and shards each speaker's
           rows; the first ~20% of a shard's rows form the client test set.

        Config values read from ``self.ckp.config``: ``seed``,
        ``simulation.num_clients``, and under ``data.args``:
        ``server_val_ratio``, ``seq_stride``, ``max_server_sequences``,
        ``max_client_sequences``, ``max_client_test_sequences``,
        ``user_key_hint``, ``virtual_clients_per_user``; optionally
        ``models.net.args.vocab_size`` to cap the HF vocabulary.

        Side effects: writes ``<cid>/{train,test}.pt`` per client, the server
        ``train.pt`` / ``val.pt`` / ``test.pt`` and ``vocab.json`` under
        ``dataset_fl_root``, and sets ``self.num_clients`` to the number of
        clients actually built.

        Raises:
            AssertionError: LEAF train and test user sets differ.
            ValueError: LEAF directories contain no users, or the HF dataset
                has no recognisable text column.
        """
        print("Creating Shakespeare partitions...")

        os.makedirs(self.dataset_fl_root, exist_ok=True)

        # config-driven val ratio/seed/stride/caps
        try:
            val_ratio = float(self.ckp.config.data.args.server_val_ratio)
        except Exception:
            val_ratio = 0.1
        try:
            seed = int(self.ckp.config.seed)
        except Exception:
            seed = 123

        seq_len = int(self.seq_len)
        # A missing, null or non-positive stride means non-overlapping windows.
        # Normalising here keeps both code paths consistent: a zero stride
        # would otherwise reach range() as a zero step in the LEAF path.
        stride_cfg = getattr(self.ckp.config.data.args, "seq_stride", None)
        seq_stride = int(stride_cfg) if stride_cfg else seq_len
        if seq_stride <= 0:
            seq_stride = seq_len

        max_server_sequences      = getattr(self.ckp.config.data.args, "max_server_sequences", None)
        max_client_sequences      = getattr(self.ckp.config.data.args, "max_client_sequences", None)
        max_client_test_sequences = getattr(self.ckp.config.data.args, "max_client_test_sequences", None)
        if isinstance(max_server_sequences, (float, int)):
            max_server_sequences = int(max_server_sequences)
        if isinstance(max_client_sequences, (float, int)):
            max_client_sequences = int(max_client_sequences)
        if isinstance(max_client_test_sequences, (float, int)):
            max_client_test_sequences = int(max_client_test_sequences)

        root = Path(self.path_to_data)
        leaf_train = root / "train"
        leaf_test  = root / "test"

        # ---------- 1) LEAF JSON path ----------
        if leaf_train.exists() and leaf_test.exists():
            tr_clients, _, tr_data = _leaf_read_dir_json(str(leaf_train))
            te_clients, _, te_data = _leaf_read_dir_json(str(leaf_test))
            assert set(tr_clients) == set(te_clients), "Train/test client sets differ in Shakespeare LEAF data"

            total = len(tr_clients)
            target_K = int(self.ckp.config.simulation.num_clients)
            if total == 0:
                raise ValueError("No users in LEAF Shakespeare JSONs.")

            # Build vocab from LEAF (reserve specials)
            stoi, itos = _build_vocab_from_clients(tr_data, te_data)
            self._persist_vocab(stoi, itos)

            # If fewer users than desired, speaker-slice each user to reach K.
            shards_per_user = max(1, math.ceil(target_K / total))
            print(f"LEAF found {total} users; speaker-slicing with shards_per_user={shards_per_user}")

            def _mk(seq):
                """Cut a 1-D id array into [N, seq_len] windows using seq_stride."""
                if seq.size < seq_len:
                    return np.empty((0, seq_len), np.int64)
                starts = range(0, seq.size - seq_len + 1, seq_stride)
                return np.stack([seq[s:s + seq_len] for s in starts], axis=0)

            server_trX, server_trY = [], []
            server_teX, server_teY = [], []
            train_size = test_size = 0
            built = 0

            users_sorted = sorted(tr_clients)
            for u in users_sorted:
                if built >= target_K:
                    break
                # Rebuild this user's token stream from its stored samples so
                # it can be re-windowed with the configured stride.
                xs = tr_data[u].get("x", []) + te_data[u].get("x", [])
                txt = "".join(str(s) for s in xs)
                ids = np.array([stoi.get(ch, stoi["<unk>"]) for ch in txt], dtype=np.int64)
                if ids.size < seq_len:
                    continue

                # Pseudo-rows index non-overlapping chunks of seq_len tokens;
                # sharding those keeps every shard a contiguous stream block.
                pseudo_rows = list(range(0, max(1, ids.size // seq_len)))
                shards = _shard_indices(pseudo_rows, shards_per_user)

                for shard in shards:
                    if built >= target_K:
                        break
                    if not shard:
                        continue

                    # Build start/end positions for this shard
                    start = min(s * seq_len for s in shard)
                    end   = min(ids.size, max((s + 1) * seq_len for s in shard))
                    block = ids[start:end]
                    if block.size < seq_len:
                        continue

                    # 80/20 shard split (slicing past the end yields an empty tail)
                    split = int(round(0.8 * block.size))
                    train_blk = block[:split]
                    test_blk  = block[split:]

                    # Re-window with configured stride
                    xtr = _mk(train_blk)
                    xte = _mk(test_blk)

                    # y == x (model shifts internally)
                    ytr, yte = xtr, xte

                    # caps
                    xtr, ytr = _cap_sequences(xtr, ytr, max_client_sequences,      seed=seed + built)
                    xte, yte = _cap_sequences(xte, yte, max_client_test_sequences, seed=seed + built)

                    if xtr.shape[0] == 0 and xte.shape[0] == 0:
                        continue

                    # Client ids are consecutive integers in build order.
                    cdir = os.path.join(self.dataset_fl_root, str(built))
                    os.makedirs(cdir, exist_ok=True)
                    torch.save((xtr, ytr), os.path.join(cdir, "train.pt"))
                    torch.save((xte, yte), os.path.join(cdir, "test.pt"))

                    train_size += xtr.shape[0]
                    test_size  += xte.shape[0]
                    if xtr.size:
                        server_trX.append(xtr)
                        server_trY.append(ytr)
                    if xte.size:
                        server_teX.append(xte)
                        server_teY.append(yte)

                    built += 1

            print(f"Built {built} clients via LEAF speaker-slicing | Train seqs: {train_size} | Test seqs: {test_size}")
            self.num_clients = built

            # Save global
            _save_server_train_val(
                self.dataset_fl_root,
                server_trX, server_trY,
                val_ratio=val_ratio, seed=seed,
                max_train=max_server_sequences,
            )

            if server_teX:
                gx = np.concatenate(server_teX, axis=0)
                gy = np.concatenate(server_teY, axis=0)
            else:
                gx = np.empty((0, seq_len), dtype=np.int64)
                gy = np.empty((0, seq_len), dtype=np.int64)
            torch.save((gx, gy), os.path.join(self.dataset_fl_root, "test.pt"), pickle_protocol=4)
            return

        # ---------- 2) HF fallback: speaker-slicing (no Dirichlet) ----------
        print("LEAF JSON not found; downloading 'flwrlabs/shakespeare' from Hugging Face...")
        # Imported lazily so the dependency is only needed on this path.
        from datasets import load_dataset
        ds = load_dataset("flwrlabs/shakespeare")
        full = ds["train"]
        cols = set(full.column_names)

        # user/text keys (allow hint)
        user_key_hint = getattr(self.ckp.config.data.args, "user_key_hint", None)

        def _pick(keys):
            """Return the first candidate that is an actual column, else None."""
            for k in keys:
                if k in cols:
                    return k
            return None

        user_key = user_key_hint if (user_key_hint in cols) else _pick(
            ["character", "speaker", "role", "user_id", "userid", "writer_id"]
        )
        text_key = _pick(["text", "content", "line", "sentence", "x"])
        if text_key is None:
            raise ValueError(f"Could not find a text column in {cols}")

        # Group row indices by speaker; without a user column everything is
        # treated as one "GLOBAL" user.
        by_user = defaultdict(list)
        if user_key is not None:
            for i, ex in enumerate(full):
                by_user[str(ex[user_key])].append(i)
        else:
            by_user["GLOBAL"] = list(range(len(full)))

        users = sorted(by_user.keys())
        available = len(users)
        target_K = int(self.ckp.config.simulation.num_clients)

        # Build vocabulary (reserve specials; optionally cap to model vocab_size)
        try:
            model_vocab_size = int(self.ckp.config.models.net.args.vocab_size)
        except Exception:
            model_vocab_size = None

        all_text = "".join(str(full[i][text_key]) for i in range(len(full)))
        chars, freqs = np.unique(list(all_text), return_counts=True)
        # Most frequent characters first, so truncation drops the rarest ones.
        base_vocab = ["<pad>", "<unk>"] + [c for _, c in sorted(zip(-freqs, chars))]
        if model_vocab_size is not None and len(base_vocab) > model_vocab_size:
            base_vocab = base_vocab[:model_vocab_size]
        char2id = {ch: i for i, ch in enumerate(base_vocab)}
        # Truncation to fewer than two entries can cut off the specials.
        if "<pad>" not in char2id:
            base_vocab = ["<pad>"] + base_vocab
            char2id = {ch: i for i, ch in enumerate(base_vocab)}
        if "<unk>" not in char2id:
            base_vocab = base_vocab + ["<unk>"]
            char2id["<unk>"] = len(base_vocab) - 1

        # Persist vocab through the shared writer (same keys as the LEAF path).
        self._persist_vocab(char2id, base_vocab)

        # Decide shards per user to reach K
        shards_per_user_cfg = getattr(self.ckp.config.data.args, "virtual_clients_per_user", None)
        shards_per_user = int(shards_per_user_cfg) if shards_per_user_cfg else max(1, math.ceil(target_K / max(1, available)))
        print(f"Using speaker-slicing: users={available}, shards_per_user={shards_per_user}, stride={seq_stride}")

        server_train_x, server_train_y = [], []
        server_test_x,  server_test_y  = [], []
        train_size = test_size = 0
        built = 0

        # Fixed seed: the row shuffle does not depend on the configured seed.
        rng = np.random.RandomState(123)

        for u in users:
            if built >= target_K:
                break
            idxs = by_user[u][:]
            if not idxs:
                continue
            rng.shuffle(idxs)

            # Split user rows into shards
            shards = _shard_indices(idxs, shards_per_user)
            for shard in shards:
                if built >= target_K:
                    break
                if not shard:
                    continue

                n_test_rows = int(0.2 * len(shard))
                xtr, ytr = _tokenize_and_chunk_ds(full, text_key, shard[n_test_rows:], seq_len, char2id, stride=seq_stride)  # ~80%
                xte, yte = _tokenize_and_chunk_ds(full, text_key, shard[:n_test_rows], seq_len, char2id, stride=seq_stride)  # ~20%

                # caps
                xtr, ytr = _cap_sequences(xtr, ytr, max_client_sequences,      seed=seed + built)
                xte, yte = _cap_sequences(xte, yte, max_client_test_sequences, seed=seed + built)

                if xtr.shape[0] == 0 and xte.shape[0] == 0:
                    continue

                cdir = os.path.join(self.dataset_fl_root, str(built))
                os.makedirs(cdir, exist_ok=True)
                torch.save((xtr, ytr), os.path.join(cdir, "train.pt"))
                torch.save((xte, yte), os.path.join(cdir, "test.pt"))

                train_size += xtr.shape[0]
                test_size  += xte.shape[0]
                if xtr.size:
                    server_train_x.append(xtr)
                    server_train_y.append(ytr)
                if xte.size:
                    server_test_x.append(xte)
                    server_test_y.append(yte)

                built += 1

        print(f"Built {built} clients via speaker-slicing | Train seqs: {train_size} | Test seqs: {test_size}")
        self.num_clients = built  # reflect what we actually built

        # Save global server train/val/test with caps
        _save_server_train_val(
            self.dataset_fl_root,
            server_train_x, server_train_y,
            val_ratio=val_ratio, seed=seed,
            max_train=max_server_sequences,
        )

        if server_test_x:
            gx = np.concatenate(server_test_x, axis=0)
            gy = np.concatenate(server_test_y, axis=0)
        else:
            gx = np.empty((0, seq_len), dtype=np.int64)
            gy = np.empty((0, seq_len), dtype=np.int64)
        torch.save((gx, gy), os.path.join(self.dataset_fl_root, "test.pt"), pickle_protocol=4)

    # ---------- LEAF user record → arrays ----------
    def _leaf_user_to_arrays(self, user_dict: Dict, stoi: Dict[str, int]) -> Tuple[np.ndarray, np.ndarray]:
        """Encode one LEAF user record into ``(X, y)`` arrays with ``y == X``.

        Every sample under ``user_dict["x"]`` is encoded with ``_encode_seq``;
        samples whose encoded length differs from ``self.seq_len`` are dropped.
        The record's ``"y"`` entries are not used.

        Returns:
            ``(X, X)`` (the same stacked array twice) when at least one sample
            is kept, otherwise two empty ``int64`` arrays of shape
            ``(0, seq_len)``.
        """
        encoded = (_encode_seq(sample, stoi) for sample in user_dict.get("x", []))
        kept = [ids for ids in encoded if ids.shape[0] == self.seq_len]

        if not kept:
            width = int(self.seq_len)
            return np.empty((0, width), dtype=np.int64), np.empty((0, width), dtype=np.int64)

        X = np.stack(kept, axis=0)
        return X, X  # y == x

    def _persist_vocab(self, stoi: Dict[str, int], itos: List[str]) -> None:
        """Write ``stoi``, ``itos`` and ``seq_len`` to ``dataset_fl_root/vocab.json``."""
        target = os.path.join(self.dataset_fl_root, "vocab.json")
        payload = dict(stoi=stoi, itos=itos, seq_len=self.seq_len)
        with open(target, "w") as handle:
            handle.write(json.dumps(payload))

    # -----------------------------------------------------
    # Standard Dataset API
    # -----------------------------------------------------
    def download(self):
        """Do nothing.

        LEAF Shakespeare is expected to be pre-downloaded; the Hugging Face
        fallback is handled in ``create_fl_partitions``.
        """
        return None

    def get_available_training_clients(self) -> List[int]:
        """Return the client ids ``0 .. self.num_clients - 1`` as a list."""
        return [*range(self.num_clients)]

    def get_dataloader(
        self,
        data_pool,
        partition,
        batch_size,
        num_workers,
        augment,            # unused (kept for signature parity)
        shuffle=False,
        cid=None,
        path=None,
        val_ratio=0.0,      # optional per-client split from train
        seed=None,
        **kwargs,
    ):
        """
        Build a DataLoader for Shakespeare.

        Pools:
          - 'server' -> global {train,val,test}.pt under dataset_fl_root
          - 'train'/'test' -> per-client shards at dataset_fl_root/<cid>/{train,test}.pt
                               (if cid is None, falls back to global .pt)
        Partitions: 'train' | 'val' | 'test'

        Args:
            data_pool: 'server', 'train' or 'test' (case-insensitive).
            partition: 'train', 'val' or 'test' (case-insensitive).
            batch_size: Forwarded to ``DataLoader``.
            num_workers: Forwarded to ``DataLoader``.
            augment: Ignored.
            shuffle: Forwarded to ``DataLoader``.
            cid: Client id; must be ``None`` for the 'server' pool when no
                existing ``path`` is given.
            path: Optional existing directory used instead of
                ``dataset_fl_root`` (``<path>/<cid>`` when ``cid`` is given).
            val_ratio: When truthy with ``partition == 'train'`` and a ``cid``,
                the client train set is randomly split into train/val.
            seed: Required for the per-client split; also offsets the
                permutation used to derive a missing server ``val.pt``.
            **kwargs: Extra ``DataLoader`` keyword arguments.

        Returns:
            A single ``DataLoader``, or a list ``[train_loader, val_loader]``
            when the per-client validation split is requested.

        Raises:
            AssertionError: invalid pool/partition, ``cid`` given for the
                server pool, or missing ``seed`` for the client split.
            FileNotFoundError: server val requested but ``train.pt`` is absent.
        """
        data_pool = str(data_pool).lower()
        partition = str(partition).lower()
        assert data_pool in ("server", "train", "test"), "Data pool must be in server, train, or test"
        assert partition in ("train", "val", "test"), "Partition must be train, val, or test"

        # Resolve path
        if path is not None and os.path.exists(path):
            prefix = path if cid is None else os.path.join(path, str(cid))
            pt_path = os.path.join(prefix, f"{partition}.pt")
        else:
            if data_pool == "server":
                assert cid is None
                pt_path = os.path.join(self.dataset_fl_root, f"{partition}.pt")
            else:
                if cid is None:
                    pt_path = os.path.join(self.dataset_fl_root, f"{partition}.pt")
                else:
                    pt_path = os.path.join(self.dataset_fl_root, str(cid), f"{partition}.pt")

        # If server val requested but not present, derive from server train now
        if data_pool == "server" and partition == "val" and not os.path.exists(pt_path):
            train_pt = os.path.join(self.dataset_fl_root, "train.pt")
            if not os.path.exists(train_pt):
                raise FileNotFoundError(f"Server train.pt not found at {train_pt}")
            # PyTorch >= 2.6 defaults to weights_only=True, which refuses the
            # pickled tuples of NumPy arrays written by this module.
            # SECURITY: full unpickling - only load trusted files.
            try:
                Xtr, Ytr = torch.load(train_pt, weights_only=False)
            except TypeError:
                # Older PyTorch releases do not know the weights_only keyword.
                Xtr, Ytr = torch.load(train_pt)
            n = len(Xtr)
            if n == 0:
                torch.save((Xtr, Ytr), pt_path, pickle_protocol=4)
            else:
                try:
                    val_ratio_cfg = float(self.ckp.config.data.args.server_val_ratio)
                except Exception:
                    val_ratio_cfg = 0.1
                seed_cfg = int(seed) if seed is not None else int(getattr(self.ckp.config, "seed", 123))
                rng = np.random.RandomState(seed_cfg + 999)
                perm = rng.permutation(n)
                n_val = max(1, int(round(val_ratio_cfg * n))) if n > 1 else 0
                val_idx = perm[:n_val]
                # NOTE: the selected rows are copied into val.pt but stay in
                # train.pt, so the derived val set overlaps the train set.
                Xv, Yv = Xtr[val_idx], Ytr[val_idx]
                torch.save((Xv, Yv), pt_path, pickle_protocol=4)

        # Optional per-client validation split from client train
        if val_ratio and partition == "train" and cid is not None:
            assert seed is not None, "Provide 'seed' for deterministic client val split"

            base_ds = SequenceDatasetFL(pt_path)
            n = len(base_ds)
            vlen = int(val_ratio * n)
            tlen = n - vlen
            g = torch.Generator().manual_seed(seed)
            tds, vds = torch.utils.data.random_split(base_ds, [tlen, vlen], generator=g)
            # Both loaders share the same settings, including `shuffle`.
            return [
                DataLoader(ds, batch_size=batch_size, num_workers=num_workers,
                           pin_memory=True, drop_last=False, shuffle=shuffle, **kwargs)
                for ds in (tds, vds)
            ]

        ds = SequenceDatasetFL(pt_path)
        return DataLoader(
            ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=False,
            shuffle=shuffle,
            **kwargs,
        )