import os
import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import List

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from .fl_dataset import FederatedDataset
from .utils import VisionDataset_FL

logger = logging.getLogger(__name__)


# ----------------------------
# Helpers
# ----------------------------

def read_femnist_dir(data_dir: str):
    """Load and merge every LEAF-style ``*.json`` file found in *data_dir*.

    Each JSON file must contain a ``"user_data"`` mapping (user key -> record)
    and may contain a ``"hierarchies"`` list.

    Args:
        data_dir: Directory to scan (non-recursive) for ``.json`` files.

    Returns:
        A ``(clients, groups, data)`` tuple where

        * ``clients`` is the sorted list of user keys present in ``data``,
        * ``groups`` is the concatenation of all ``"hierarchies"`` lists,
        * ``data`` is a ``defaultdict`` mapping user key -> record; unknown
          keys yield ``None``.

    Raises:
        KeyError: If a JSON file has no ``"user_data"`` entry.
    """
    groups = []
    data = defaultdict(lambda: None)

    # Files are visited in sorted order so that the content of ``groups`` and
    # the winner among duplicate user keys do not depend on the arbitrary
    # ordering returned by ``os.listdir``.
    json_names = sorted(name for name in os.listdir(data_dir) if name.endswith(".json"))
    for name in json_names:
        with open(os.path.join(data_dir, name), "r", encoding="utf-8") as inf:
            cdata = json.load(inf)
        if "hierarchies" in cdata:
            groups.extend(cdata["hierarchies"])
        data.update(cdata["user_data"])

    # The client list is derived from the merged mapping, which keeps it
    # consistent with ``data`` even when a user appears in several files.
    clients = sorted(data)
    return clients, groups, data


def _ensure_3ch(x: torch.Tensor) -> torch.Tensor:
    """Tile a single-channel image tensor to three channels.

    Only a 3-D tensor whose first dimension has size 1 is expanded (by
    repeating it three times along that dimension); every other input is
    handed back untouched.
    """
    single_channel = x.ndim == 3 and x.size(0) == 1
    if not single_channel:
        return x
    tiled = x.repeat(3, 1, 1)
    return tiled


def femnistTransformation(augment):
    """Build the torchvision transform pipeline for FEMNIST samples.

    * ``augment == "jit"``: only ``ToTensor``; the remaining steps are left
      to the JIT pipeline on the client side.
    * any other truthy ``augment``: resize to 32x32, small random affine
      jitter, ``ToTensor``, 1->3 channel tiling, normalisation.
    * falsy ``augment``: same as above without the random affine step.
    """
    if augment == "jit":
        return transforms.Compose([
            transforms.ToTensor(),
        ])

    # MNIST/FEMNIST statistics, repeated for each of the three channels
    # produced by the 1->3 conversion.
    channel_mean = (0.1307,) * 3
    channel_std = (0.3081,) * 3

    steps = [transforms.Resize((32, 32))]
    if augment:
        # No flips for characters; a small affine jitter is safer.
        steps.append(transforms.RandomAffine(degrees=8, translate=(0.1, 0.1)))
    steps.extend([
        transforms.ToTensor(),
        transforms.Lambda(_ensure_3ch),
        transforms.Normalize(channel_mean, channel_std),
    ])
    return transforms.Compose(steps)


def _save_server_train_val(root_dir, train_x_list, train_y_list, val_ratio=0.1, seed=123):
    """Merge per-client train arrays and persist the server train/val split.

    The arrays are concatenated along axis 0 (an empty list yields empty
    ``(0, 28, 28, 1)`` / ``(0,)`` arrays). When there are no samples or
    ``val_ratio <= 0`` everything goes to ``train.pt`` and no ``val.pt`` is
    written. Otherwise a permutation seeded with *seed* sends the first
    ``round(val_ratio * n)`` samples (at least one, none when ``n == 1``) to
    ``val.pt`` and the rest to ``train.pt``.
    """
    train_file = os.path.join(root_dir, "train.pt")

    if not train_x_list:
        images = np.empty((0, 28, 28, 1), dtype=np.uint8)
        labels = np.empty((0,), dtype=np.int64)
    else:
        images = np.concatenate(train_x_list, axis=0)
        labels = np.concatenate(train_y_list, axis=0)

    total = images.shape[0]
    if total == 0 or val_ratio <= 0.0:
        # Nothing to hold out: the full set is the train set.
        torch.save((images, labels), train_file, pickle_protocol=4)
        return

    order = np.random.RandomState(seed).permutation(total)
    if total > 1:
        val_count = max(1, int(round(val_ratio * total)))
    else:
        val_count = 0
    held_out = order[:val_count]
    kept = order[val_count:]

    torch.save((images[kept], labels[kept]), train_file, pickle_protocol=4)
    val_file = os.path.join(root_dir, "val.pt")
    torch.save((images[held_out], labels[held_out]), val_file, pickle_protocol=4)


# ----------------------------
# Dataset
# ----------------------------

class FemnistDataset(FederatedDataset):
    """FEMNIST federated dataset backed by ``.pt`` shards on disk.

    Files written under ``self.dataset_fl_root``:

    * ``<cid>/train.pt`` and ``<cid>/test.pt`` -- per-client shards,
    * ``train.pt``, ``val.pt`` and ``test.pt`` -- global (server) sets built
      from the shards of the selected clients.

    Every file holds an ``(images, labels)`` tuple of NumPy arrays saved with
    ``torch.save`` (``uint8`` images with a trailing channel axis, ``int64``
    labels).

    ``self.ckp``, ``self.dataset_fl_root`` and ``self.path_to_data`` are read
    by this class but not assigned here; presumably the base class sets
    them -- TODO confirm against ``FederatedDataset``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the dataset and build the partitions if they are missing.

        All arguments are forwarded unchanged to the base-class constructor.
        """
        super().__init__(*args, **kwargs)

        # How many clients to simulate (from config).
        self.num_clients = int(self.ckp.config.simulation.num_clients)

        # JIT transforms (applied in client training loops when
        # augment == 'jit'). Both pipelines are currently identical.
        self.jit_augment = transforms.Compose([
            transforms.Lambda(_ensure_3ch),
            transforms.Normalize((0.1307, 0.1307, 0.1307),
                                 (0.3081, 0.3081, 0.3081)),
        ])
        self.jit_normalize = transforms.Compose([
            transforms.Lambda(_ensure_3ch),
            transforms.Normalize((0.1307, 0.1307, 0.1307),
                                 (0.3081, 0.3081, 0.3081)),
        ])

        # Build partitions once (any global file missing -> build).
        # NOTE(review): only the global train/test files are checked; the
        # per-client shards are not, so raising ``num_clients`` after a first
        # build does not trigger a rebuild.
        server_test = os.path.join(self.dataset_fl_root, "test.pt")
        server_train = os.path.join(self.dataset_fl_root, "train.pt")
        if not (os.path.exists(server_test) and os.path.exists(server_train)):
            self.create_fl_partitions()

    # -----------------------------------------------------
    # Private helpers
    # -----------------------------------------------------
    @staticmethod
    def _load_arrays(pt_path):
        """Load an ``(images, labels)`` tuple previously written by ``torch.save``.

        The shards contain pickled NumPy arrays, which the
        ``weights_only=True`` default of PyTorch >= 2.6 refuses to unpickle.
        Full unpickling is therefore requested explicitly; the files are
        produced by this class, so never point this at untrusted data.
        """
        try:
            return torch.load(pt_path, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not accept the ``weights_only`` keyword.
            return torch.load(pt_path)

    def _server_val_ratio(self):
        """Return ``config.data.args.server_val_ratio`` as a float (0.1 if unavailable)."""
        try:
            return float(self.ckp.config.data.args.server_val_ratio)
        except Exception:
            # Deliberately best effort: any failure to read the option falls
            # back to the default.
            return 0.1

    def _save_server_sets(self, train_x, train_y, test_x, test_y, val_ratio, seed):
        """Write the global ``train.pt``/``val.pt``/``test.pt`` from per-client arrays."""
        _save_server_train_val(
            self.dataset_fl_root,
            train_x, train_y,
            val_ratio=val_ratio, seed=seed
        )

        if test_x:
            gx = np.concatenate(test_x, axis=0)
            gy = np.concatenate(test_y, axis=0)
        else:
            gx = np.empty((0, 28, 28, 1), dtype=np.uint8)
            gy = np.empty((0,), dtype=np.int64)
        torch.save((gx, gy), os.path.join(self.dataset_fl_root, "test.pt"), pickle_protocol=4)

    # -----------------------------------------------------
    # Partition creation (LEAF JSON or HF fallback)
    # -----------------------------------------------------
    def create_fl_partitions(self):
        """Build FEMNIST client shards + global server train/val/test sets.

        Prefers LEAF JSONs under <path_to_data>/{train,test}.
        If not found, falls back to Hugging Face 'flwrlabs/femnist'
        with a deterministic 80/20 per-writer split and manual image decoding.

        Raises:
            ValueError: If more clients are requested than users are
                available, or if the LEAF train/test user sets differ.
        """
        print("Creating FEMNIST partitions...")

        os.makedirs(self.dataset_fl_root, exist_ok=True)

        # Config-driven val ratio/seed for the server split.
        val_ratio = self._server_val_ratio()
        try:
            seed = int(self.ckp.config.seed)
        except Exception:
            seed = 123

        root = Path(self.path_to_data)
        leaf_train = root / "train"
        leaf_test = root / "test"
        if leaf_train.exists() and leaf_test.exists():
            self._partition_from_leaf(str(leaf_train), str(leaf_test), val_ratio, seed)
        else:
            self._partition_from_hf(val_ratio, seed)

    def _partition_from_leaf(self, train_data_dir, test_data_dir, val_ratio, seed):
        """Create client shards and server sets from LEAF JSON directories."""
        train_clients, _, train_data = read_femnist_dir(train_data_dir)
        test_clients, _, test_data = read_femnist_dir(test_data_dir)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if set(train_clients) != set(test_clients):
            raise ValueError("Train/test client sets differ in FEMNIST data")

        total_clients = len(train_clients)
        if self.num_clients > total_clients:
            raise ValueError(f"Requested num_clients={self.num_clients} exceeds available users={total_clients}")

        selected_clients = sorted(train_clients)[: self.num_clients]
        print(f"Partitioning FEMNIST with {len(selected_clients)} clients (LEAF JSON)")

        train_size = 0
        test_size = 0
        # Accumulate the global sets while walking the clients.
        server_train_x, server_train_y = [], []
        server_test_x, server_test_y = [], []

        for cid, client_key in enumerate(selected_clients):
            client_dir = os.path.join(self.dataset_fl_root, str(cid))
            os.makedirs(client_dir, exist_ok=True)
            train_pt = os.path.join(client_dir, "train.pt")
            test_pt = os.path.join(client_dir, "test.pt")

            if os.path.exists(train_pt) and os.path.exists(test_pt):
                # Reuse shards that are already on disk.
                xtr, ytr = self._load_arrays(train_pt)
                xte, yte = self._load_arrays(test_pt)
            else:
                tdata = train_data[client_key]
                vdata = test_data[client_key]

                # presumably pixel values are floats in [0, 1] -- TODO confirm
                xtr = (np.array(tdata["x"]) * 255).astype(np.uint8).reshape(-1, 28, 28, 1)
                ytr = np.array(tdata["y"], dtype=np.int64)
                xte = (np.array(vdata["x"]) * 255).astype(np.uint8).reshape(-1, 28, 28, 1)
                yte = np.array(vdata["y"], dtype=np.int64)

                torch.save((xtr, ytr), train_pt)
                torch.save((xte, yte), test_pt)

            train_size += xtr.shape[0]
            test_size += xte.shape[0]

            if xtr.shape[0] > 0:
                server_train_x.append(xtr)
                server_train_y.append(ytr)
            if xte.shape[0] > 0:
                server_test_x.append(xte)
                server_test_y.append(yte)

        print(f"Train samples: {train_size} | Test samples: {test_size}")

        self._save_server_sets(
            server_train_x, server_train_y,
            server_test_x, server_test_y,
            val_ratio, seed,
        )

    def _partition_from_hf(self, val_ratio, seed):
        """Create client shards and server sets from Hugging Face 'flwrlabs/femnist'.

        HF's internal PIL decoding is disabled and images are decoded manually.
        Client shards are always rewritten on this path.
        """
        print("LEAF JSON not found; downloading 'flwrlabs/femnist' from Hugging Face...")
        # Imported lazily: only this fallback needs these packages.
        from io import BytesIO

        from datasets import Image as DSImage, load_dataset  # pip install datasets pillow
        from PIL import Image

        ds = load_dataset("flwrlabs/femnist")                # all rows are in 'train'
        ds = ds.cast_column("image", DSImage(decode=False))  # avoid PIL ExifTags issue
        full = ds["train"]

        # Group row indices by writer. Reading just the column avoids
        # materialising every full row (image bytes included) for this pass.
        by_writer = defaultdict(list)
        for i, writer_id in enumerate(full["writer_id"]):
            by_writer[str(writer_id)].append(i)

        writers = sorted(by_writer.keys())
        if self.num_clients > len(writers):
            raise ValueError(f"Requested num_clients={self.num_clients} exceeds available users={len(writers)}")

        selected = writers[: self.num_clients]
        print(f"Partitioning FEMNIST with {len(selected)} clients (Hugging Face)")

        # Deterministic 80/20 per-writer split.
        # NOTE(review): this seed is fixed and independent of the configured
        # seed, which is only used for the server train/val split.
        rng = np.random.RandomState(123)
        test_ratio = 0.2

        def _decode_rows(indices):
            """Return ``(images, labels)`` arrays for *indices*, one fetch per row."""
            if not indices:
                return (np.empty((0, 28, 28, 1), dtype=np.uint8),
                        np.empty((0,), dtype=np.int64))
            images, labels = [], []
            for idx in indices:
                row = full[idx]
                picture = Image.open(BytesIO(row["image"]["bytes"])).convert("L")
                images.append(np.array(picture, dtype=np.uint8)[:, :, None])  # HxWx1
                labels.append(int(row["character"]) if "character" in row else int(row["label"]))
            return np.stack(images).astype(np.uint8), np.array(labels, dtype=np.int64)

        train_size = 0
        test_size = 0
        server_train_x, server_train_y = [], []
        server_test_x, server_test_y = [], []

        for cid, wid in enumerate(selected):
            idxs = by_writer[wid][:]
            rng.shuffle(idxs)
            n = len(idxs)
            # At least one test sample per writer, unless the writer has a
            # single sample (which then stays in train).
            n_test = max(1, int(round(test_ratio * n))) if n > 1 else 0
            test_idx = idxs[:n_test]
            train_idx = idxs[n_test:]

            xtr, ytr = _decode_rows(train_idx)
            xte, yte = _decode_rows(test_idx)

            client_dir = os.path.join(self.dataset_fl_root, str(cid))
            os.makedirs(client_dir, exist_ok=True)
            torch.save((xtr, ytr), os.path.join(client_dir, "train.pt"))
            torch.save((xte, yte), os.path.join(client_dir, "test.pt"))

            train_size += xtr.shape[0]
            test_size += xte.shape[0]
            if xtr.shape[0] > 0:
                server_train_x.append(xtr)
                server_train_y.append(ytr)
            if xte.shape[0] > 0:
                server_test_x.append(xte)
                server_test_y.append(yte)

        print(f"Train samples: {train_size} | Test samples: {test_size}")

        self._save_server_sets(
            server_train_x, server_train_y,
            server_test_x, server_test_y,
            val_ratio, seed,
        )

    # -----------------------------------------------------
    # Standard Dataset API
    # -----------------------------------------------------
    def download(self):
        """No-op: LEAF FEMNIST is expected to be pre-downloaded."""
        return

    def get_available_training_clients(self) -> List[int]:
        """Return the client ids ``0 .. num_clients - 1``."""
        return list(range(self.num_clients))

    def get_dataloader(
        self,
        data_pool,
        partition,
        batch_size,
        num_workers,
        augment,
        shuffle=False,
        cid=None,
        path=None,
        val_ratio=0.0,
        seed=None,
        **kwargs,
    ):
        """
        Build a DataLoader for FEMNIST.

        Pools:
          - 'server' uses global {train,val,test}.pt under dataset_fl_root
          - 'train'/'test' use per-client shards; with cid=None, fall back to global {train,val,test}.pt

        partition: 'train' | 'val' | 'test'

        Args:
            data_pool: 'server', 'train' or 'test' (case-insensitive).
            partition: 'train', 'val' or 'test' (case-insensitive).
            batch_size: Forwarded to ``DataLoader``.
            num_workers: Forwarded to ``DataLoader``.
            augment: Forwarded to ``femnistTransformation``.
            shuffle: Forwarded to ``DataLoader``.
            cid: Client id selecting a per-client shard, or ``None``.
            path: Optional alternative root; used instead of
                ``dataset_fl_root`` only when it exists on disk.
            val_ratio: Fraction of a client's train shard to split off as
                validation data (only for partition 'train' with a ``cid``).
            seed: Seed for the client val split (required when it is
                requested); also used for a lazily created server val set.
            **kwargs: Extra keyword arguments for ``DataLoader``.

        Returns:
            A single ``DataLoader``, or ``[train_loader, val_loader]`` when a
            per-client validation split is requested.

        Raises:
            ValueError: On an unknown pool/partition, a ``cid`` given for the
                server pool, or a missing ``seed`` for the client val split.
            FileNotFoundError: If a server val set must be created but the
                server ``train.pt`` does not exist.
        """
        data_pool = data_pool.lower()
        partition = partition.lower()
        # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
        if data_pool not in ("server", "train", "test"):
            raise ValueError("Data pool must be in server, train, or test")
        if partition not in ("train", "val", "test"):
            raise ValueError("Partition must be train, val, or test")

        # Resolve path to the .pt file
        if path is not None and os.path.exists(path):
            prefix_path = path if cid is None else os.path.join(path, str(cid))
            pt_path = os.path.join(prefix_path, f"{partition}.pt")
        elif data_pool == "server":
            if cid is not None:
                raise ValueError("cid must be None for the server data pool")
            pt_path = os.path.join(self.dataset_fl_root, f"{partition}.pt")
        elif cid is None:
            # Global fallback (e.g., for evaluation code that asks without cid).
            pt_path = os.path.join(self.dataset_fl_root, f"{partition}.pt")
        else:
            # Same place the shards were saved: dataset_fl_root/<cid>/<partition>.pt
            pt_path = os.path.join(self.dataset_fl_root, str(cid), f"{partition}.pt")

        # If server val requested but file not present, create it now from server train.
        # NOTE(review): the samples are drawn from train.pt, which is left
        # untouched, so they also remain part of the server train set.
        if data_pool == "server" and partition == "val" and not os.path.exists(pt_path):
            train_pt = os.path.join(self.dataset_fl_root, "train.pt")
            if not os.path.exists(train_pt):
                raise FileNotFoundError(f"Server train.pt not found at {train_pt}")
            Xtr, Ytr = self._load_arrays(train_pt)
            n = len(Xtr)
            if n == 0:
                # Create an empty val set.
                torch.save((Xtr, Ytr), pt_path, pickle_protocol=4)
            else:
                val_ratio_cfg = self._server_val_ratio()
                seed_cfg = int(seed) if seed is not None else int(getattr(self.ckp.config, "seed", 123))
                rng = np.random.RandomState(seed_cfg + 999)  # offset to differ from build-time split
                perm = rng.permutation(n)
                n_val = max(1, int(round(val_ratio_cfg * n))) if n > 1 else 0
                val_idx = perm[:n_val]
                torch.save((Xtr[val_idx], Ytr[val_idx]), pt_path, pickle_protocol=4)

        # Optional per-client validation split.
        wants_client_val = bool(val_ratio) and partition == "train" and cid is not None
        if wants_client_val and seed is None:
            raise ValueError("Provide 'seed' for deterministic FEMNIST client val split")

        dataset = VisionDataset_FL(path_to_data=pt_path, transform=femnistTransformation(augment))
        loader_args = dict(
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=False,
            shuffle=shuffle,
            **kwargs,
        )

        if not wants_client_val:
            return DataLoader(dataset, **loader_args)

        val_len = int(val_ratio * len(dataset))
        train_len = len(dataset) - val_len
        splits = torch.utils.data.random_split(
            dataset,
            [train_len, val_len],
            generator=torch.Generator().manual_seed(seed),
        )
        return [DataLoader(part, **loader_args) for part in splits]