import numpy as np
import os
import torch


def read_data(dataset, idx, is_train=True):
    """
    Read raw data for a specific client from disk.

    Loads the client's preprocessed split from the NumPy archive
    ``../dataset/<dataset>/<train|test>/<idx>.npz`` (relative to the current
    working directory) and returns the object stored under the archive's
    ``'data'`` key.

    Args:
        dataset (str): Name of the dataset to load
        idx (int): Index of the client for which to load data
        is_train (bool): If True, load training data; else, load testing data

    Returns:
        The archive's ``'data'`` entry converted with ``.tolist()``. The other
        loaders in this file index the result with ``'x'`` (input data) and
        ``'y'`` (labels), i.e. they expect a dict.

    Raises:
        FileNotFoundError: If the client's archive does not exist.
        KeyError: If the archive has no ``'data'`` entry.
    """
    split = 'train' if is_train else 'test'
    data_file = os.path.join('../dataset', dataset, split, str(idx) + '.npz')

    # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary Python
    # objects, which can execute code. Only point this at trusted dataset files.
    with open(data_file, 'rb') as f:
        # Index the archive while the file handle is still open: .npz members
        # are read lazily on access.
        return np.load(f, allow_pickle=True)['data'].tolist()


def read_client_data(dataset, idx, is_train=True):
    """
    Load and preprocess the data of a specific client.

    Dispatches on the dataset name: names containing "News" are handled by
    read_client_data_text and names containing "Shakespeare" by
    read_client_data_Shakespeare. Every other dataset is loaded here: inputs are
    converted to float32 tensors and labels to int64 tensors.

    Args:
        dataset (str): Name of the dataset to load
        idx (int): Index of the client for which to load data
        is_train (bool): If True, load training data; else, load testing data

    Returns:
        list: List of tuples where each tuple contains input tensor and label tensor
            (for "News" datasets the input is itself a (sequence, length) pair)
    """
    if "News" in dataset:
        return read_client_data_text(dataset, idx, is_train)
    if "Shakespeare" in dataset:
        # Forward is_train: omitting it made this branch always return the
        # training split, even when the test split was requested.
        return read_client_data_Shakespeare(dataset, idx, is_train)

    data = read_data(dataset, idx, is_train)
    X = torch.Tensor(data['x']).type(torch.float32)
    y = torch.Tensor(data['y']).type(torch.int64)

    # Iterating a tensor walks its first dimension, so this pairs each sample
    # with its label.
    return list(zip(X, y))


def read_client_data_text(dataset, idx, is_train=True):
    """
    Load and preprocess text data for a specific client.

    Reads text data (either training or testing) for a client, splits each 'x'
    entry into its input sequence and its length, converts everything to int64
    PyTorch tensors, and returns a list of ((input, length), label) tuples.

    Args:
        dataset (str): Name of the dataset to load
        idx (int): Index of the client for which to load data
        is_train (bool): If True, load training data; else, load testing data

    Returns:
        list: List of tuples where each tuple contains:
            - (x, lens): Input sequence tensor and its length tensor
            - y: Label tensor
        An empty list is returned when the client has no samples.
    """
    data = read_data(dataset, idx, is_train)
    samples = data['x']

    # zip(*[]) yields nothing, so the two-way unpack below would raise
    # ValueError for a client without samples; return an empty list instead,
    # as the other loaders in this file do for empty data.
    if len(samples) == 0:
        return []

    # Each 'x' entry is a (sequence, length) pair; transpose into two columns.
    sequences, lengths = zip(*samples)

    X = torch.Tensor(sequences).type(torch.int64)
    X_lens = torch.Tensor(lengths).type(torch.int64)
    y = torch.Tensor(data['y']).type(torch.int64)

    return [((x, lens), label) for x, lens, label in zip(X, X_lens, y)]


def read_client_data_Shakespeare(dataset, idx, is_train=True):
    """
    Load one client's Shakespeare split as (input, label) tensor pairs.

    Fetches the requested split through read_data, casts both the 'x' and the
    'y' entries to int64 PyTorch tensors, and pairs them up sample by sample.

    Args:
        dataset (str): Name of the dataset to load
        idx (int): Index of the client for which to load data
        is_train (bool): If True, load training data; else, load testing data

    Returns:
        list: List of tuples where each tuple contains input tensor and label tensor
    """
    raw = read_data(dataset, idx, is_train)

    inputs = torch.Tensor(raw['x']).type(torch.int64)
    targets = torch.Tensor(raw['y']).type(torch.int64)

    # Both splits are processed identically; only the file read above differs.
    return list(zip(inputs, targets))

