import numpy as np
from pathlib import Path
from sklearn.decomposition import PCA
import os
from collections import defaultdict
import json
import random

### Legacy MNIST loader borrowed from an older script.
## TODO: Remove this.
def load_mnist(path, kind='train'):
    """Load gzip-compressed MNIST-format images and labels from `path`.

    Args:
        path: Directory containing ``<kind>-labels-idx1-ubyte.gz`` and
            ``<kind>-images-idx3-ubyte.gz``.
        kind: File-name prefix selecting which pair of files to read
            (default ``'train'``).

    Returns:
        Tuple ``(images, labels)`` where ``labels`` is a 1-D uint8 array and
        ``images`` is a uint8 array of shape ``(len(labels), 784)``.
    """
    # gzip is not imported at module level, so the import stays local.
    import gzip

    labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)

    # The first 8 bytes of the label file are header and are skipped.
    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)

    # The first 16 bytes of the image file are header; every image is
    # flattened into a row of 784 bytes.
    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)

    return images, labels

def create_mnist_like_dataset(data_path: str, m: int, d:int):
    """Build per-client feature splits from an MNIST-format dataset.

    Steps, in order:
      1. read the gzip'd train / t10k image files under ``data_path`` as
         float32 rows of 784 values;
      2. min-max scale them (scaler fitted on the train images);
      3. project them onto ``d`` PCA components (PCA fitted on train);
      4. keep only the rows whose label is greater than 4;
      5. split the remaining rows into ``m`` consecutive chunks.

    Args:
        data_path: Directory holding the four ``*-ubyte.gz`` files.
        m: Number of clients (chunks).
        d: Number of PCA components to keep.

    Returns:
        dict mapping client id ``0..m-1`` to
        ``{"train": ndarray, "test": ndarray}``.
    """
    import gzip

    def _read_ubyte(file_path, header_bytes):
        # Raw uint8 payload of a gzip'd idx file, header skipped.
        with gzip.open(file_path, "rb") as fh:
            return np.frombuffer(fh.read(), dtype=np.uint8, offset=header_bytes)

    root = Path(data_path)

    train_x = _read_ubyte(root / "train-images-idx3-ubyte.gz", 16)
    train_x = train_x.reshape(-1, 784).astype(np.float32)
    test_x = _read_ubyte(root / "t10k-images-idx3-ubyte.gz", 16)
    test_x = test_x.reshape(-1, 784).astype(np.float32)

    from sklearn.preprocessing import MinMaxScaler
    scaler = MinMaxScaler()
    train_x = scaler.fit_transform(train_x)
    test_x = scaler.transform(test_x)

    # Dimensionality reduction to d components, fitted on train only.
    reducer = PCA(n_components=d)
    train_x = reducer.fit_transform(train_x)
    test_x = reducer.transform(test_x)

    # Drop the classes 0-4: only samples labelled 5-9 are kept.
    train_y = _read_ubyte(root / "train-labels-idx1-ubyte.gz", 8).astype(int)
    test_y = _read_ubyte(root / "t10k-labels-idx1-ubyte.gz", 8).astype(int)
    train_x = train_x[train_y > 4]
    test_x = test_x[test_y > 4]

    per_client = zip(np.array_split(train_x, m), np.array_split(test_x, m))
    client_data = {}
    for client_id, (train_part, test_part) in enumerate(per_client):
        client_data[client_id] = {"train": train_part, "test": test_part}
    return client_data

def read_dir(data_dir):
    """Load and merge every ``.json`` file found directly in ``data_dir``.

    Files are processed in sorted file-name order so the result does not
    depend on the arbitrary order in which ``os.listdir`` reports entries
    (this affects the order of ``groups`` and which file wins when two
    files define the same user key).

    Args:
        data_dir: Directory to scan; sub-directories are not visited.

    Returns:
        Tuple ``(clients, groups, data)``:
          clients: sorted list of the keys of ``data``.
          groups: concatenation of each file's ``"hierarchies"`` list
              (files without that key contribute nothing).
          data: ``defaultdict`` merging each file's ``"user_data"`` mapping;
              looking up an unknown key yields ``None``.

    Raises:
        KeyError: if a file lacks the ``"users"`` or ``"user_data"`` key.
    """
    groups = []
    data = defaultdict(lambda: None)

    json_names = sorted(f for f in os.listdir(data_dir) if f.endswith(".json"))
    for name in json_names:
        with open(os.path.join(data_dir, name), "r") as inf:
            cdata = json.load(inf)
        # "users" is required by the file format even though the returned
        # client list is derived from the merged "user_data" keys.
        for required in ("users", "user_data"):
            if required not in cdata:
                raise KeyError(required)
        groups.extend(cdata.get("hierarchies", []))
        data.update(cdata["user_data"])

    clients = sorted(data)
    return clients, groups, data


def read_data(data_dir: str):
    """Load the ``train`` and ``test`` sub-directories of ``data_dir``.

    Each sub-directory is parsed with ``read_dir``, i.e. it is expected to
    contain ``.json`` files with the keys ``'users'`` and ``'user_data'``.
    The train and test splits must list the same clients and the same
    groups; this is checked with ``assert``.

    Returns:
        Tuple ``(clients, groups, train_data, test_data)``:
          clients: list of client ids
          groups: list of group ids; empty list if none found
          train_data: dictionary of train data
          test_data: dictionary of test data
    """
    root = Path(data_dir)
    clients, groups, train_data = read_dir(root / "train")
    test_clients, test_groups, test_data = read_dir(root / "test")

    # Both splits have to describe exactly the same population.
    assert clients == test_clients
    assert groups == test_groups

    return clients, groups, train_data, test_data


def create_femnist(m:int,  data_path:str, d:int):
    """Build scaled, PCA-reduced train/test features for ``m`` random clients.

    ``m`` clients are sampled without replacement from those found by
    ``read_data(data_path)``. For each one, the ``'x'`` entries are
    flattened to one row per sample and only samples whose ``'y'`` label is
    greater than 4 are kept. A min-max scaler and a ``d``-component PCA are
    fitted on the concatenated training rows of the sampled clients and
    applied to every client's train and test rows.

    Args:
        m: Number of clients to sample.
        data_path: Directory passed to ``read_data``.
        d: Number of PCA components to keep.

    Returns:
        dict mapping ``0..m-1`` (position in the sample, not the original
        client id) to ``{"train": ndarray, "test": ndarray}``.
    """
    from sklearn.preprocessing import MinMaxScaler

    def _filtered_features(split):
        # One flattened row per sample, restricted to labels > 4.
        features = np.array(split['x'])
        features = features.reshape(len(features), -1)
        labels = np.array(split['y'])
        return features[labels > 4]

    all_client_idx, _, train_data, test_data = read_data(data_path)
    assert len(all_client_idx) >= m
    client_idx = random.sample(all_client_idx, m)

    train_chunks, test_chunks = [], []
    for client_id in client_idx:
        train_chunks.append(_filtered_features(train_data[client_id]))
        # The test chunk must come from the client's test split (the
        # previous code appended the training features here).
        test_chunks.append(_filtered_features(test_data[client_id]))

    train_images = np.concatenate(train_chunks, axis=0)

    # Fit the scaler on all sampled training rows, then apply per client.
    scaler = MinMaxScaler()
    scaler.fit(train_images)
    train_chunks = [scaler.transform(chunk) for chunk in train_chunks]
    test_chunks = [scaler.transform(chunk) for chunk in test_chunks]

    # Reduce dimension to d components. As in the original code, the PCA
    # is fitted on the unscaled concatenated rows and applied to the
    # scaled chunks.
    pca = PCA(n_components=d)
    pca.fit(train_images)
    train_chunks = [pca.transform(chunk) for chunk in train_chunks]
    test_chunks = [pca.transform(chunk) for chunk in test_chunks]

    client_data = {
        i: {"train": train_chunk, "test": test_chunk}
        for i, (train_chunk, test_chunk) in enumerate(zip(train_chunks, test_chunks))
    }
    return client_data

def create_client_datasets(config:dict, task:str):
    """Create the per-client datasets for ``task`` from ``config``.

    ``"linreg"`` is served by the linear-regression builders; ``"kmeans"``
    and ``"power_iter"`` both use the k-means builders.

    Raises:
        ValueError: if ``task`` is not one of the three known tasks.
    """
    if task == "linreg":
        return create_client_datasets_linreg(config=config)
    # Power iteration reuses the k-means datasets.
    if task in ("kmeans", "power_iter"):
        return create_client_datasets_kmeans(config=config)
    raise ValueError(f"Task {task} has no datasets.")


def create_har(m:int, data_path: str, d: int):
    """Build per-client feature splits from the HAR text files.

    Reads ``train/X_train.txt`` and ``test/X_test.txt`` under ``data_path``,
    min-max scales them (fitted on train), projects them onto ``d`` PCA
    components (fitted on train) and splits the rows into ``m`` consecutive
    chunks.

    Returns:
        dict mapping client id ``0..m-1`` to
        ``{"train": ndarray, "test": ndarray}``.
    """
    root = Path(data_path)
    train_x = np.loadtxt(root / "train" / "X_train.txt")
    test_x = np.loadtxt(root / "test" / "X_test.txt")

    from sklearn.preprocessing import MinMaxScaler
    scaler = MinMaxScaler()
    train_x = scaler.fit_transform(train_x)
    test_x = scaler.transform(test_x)

    # Dimensionality reduction to d components, fitted on train only.
    reducer = PCA(n_components=d)
    train_x = reducer.fit_transform(train_x)
    test_x = reducer.transform(test_x)

    per_client = zip(np.array_split(train_x, m), np.array_split(test_x, m))
    client_data = {}
    for client_id, (train_part, test_part) in enumerate(per_client):
        client_data[client_id] = {"train": train_part, "test": test_part}

    return client_data


### KMeans and Power Iteration can use the same dataset as both of them require
### only features and no labels.
def create_client_datasets_kmeans(config:dict):
    """Create a feature-only client dataset selected by ``config["name"]``.

    ``config["params"]`` is forwarded as keyword arguments to the builder
    for the chosen dataset (``mnist`` / ``fashionmnist``, ``femnist`` or
    ``har``).

    Raises:
        ValueError: if ``config["name"]`` is not a known dataset.
    """
    dataset = config["name"]
    if dataset in ("mnist", "fashionmnist"):
        builder = create_mnist_like_dataset
    elif dataset == "femnist":
        builder = create_femnist
    elif dataset == "har":
        builder = create_har
    else:
        raise ValueError(f"Dataset {dataset} not a valid KMeans dataset.")
    return builder(**config["params"])
    

def create_client_datasets_linreg(
    config : dict):
    """Create the client datasets for linear regression.

    Args:
        config: Mapping with a ``"name"`` entry (``"synthetic"`` or
            ``"ujindoorloc"``) and a ``"params"`` entry whose items are
            forwarded as keyword arguments to the matching builder.

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: if ``config["name"]`` is not a known dataset.
    """
    name = config["name"]
    if name == "synthetic":
        return create_synthetic_dataset_linreg(**config["params"])
    elif name == "ujindoorloc":
        return create_ujindoorloc_linreg(**config["params"])
    else:
        # Raise (not return) the error so an invalid name cannot be
        # mistaken for a dataset by the caller.
        raise ValueError(f"Invalid dataset for linear regression :  {name}")





def create_synthetic_dataset_linreg(m:int , d:int, n:int, het: float, noise_var: float, B:float):
    """Generate synthetic linear-regression data for ``m`` clients.

    A shared weight vector of norm ``B`` is drawn once. Each client ``i``
    gets ``w_i = true_w + het * N(0, I)`` and ``2 * n`` samples
    ``y = <w_i, x> + noise`` with ``x ~ N(0, I)`` and
    ``noise ~ N(0, noise_var)``; the first ``n`` samples form the train
    split and the remaining ``n`` the test split.

    Args:
        m: Number of clients.
        d: Feature dimension.
        n: Number of train samples (and of test samples) per client.
        het: Scale of the per-client perturbation of the weight vector.
        noise_var: Variance of the additive label noise (must be >= 0).
        B: Euclidean norm of the shared weight vector.

    Returns:
        dict mapping client id ``0..m-1`` to
        ``{"train": (X, y), "test": (X, y)}`` with ``X`` of shape ``(n, d)``
        and ``y`` of shape ``(n,)``.

    Raises:
        ValueError: if ``noise_var`` is negative.
    """
    if noise_var < 0:
        raise ValueError(f"noise_var must be non-negative, got {noise_var}")
    # np.random.normal takes the standard deviation, not the variance.
    noise_std = np.sqrt(noise_var)

    true_w = np.random.normal(0, 1, size=(d))
    true_w = B * true_w / np.linalg.norm(true_w)

    client_data = {}
    for client_id in range(m):
        # Client-specific weights: shared vector plus a perturbation
        # whose scale is controlled by het.
        w_i = true_w + np.random.normal(0, 1, size=(d)) * het
        X_i = np.random.normal(0, 1.0, size=(2 * n, d))
        noise_i = np.random.normal(0, noise_std, size=(2 * n,))
        y_i = X_i @ w_i + noise_i
        client_data[client_id] = {
            "train": (X_i[:n, :], y_i[:n]),
            "test": (X_i[n:, :], y_i[n:]),
        }
    return client_data


def create_ujindoorloc_linreg(m : int, d:int, data_path:str):
    """Build per-client linear-regression splits from the UJIIndoorLoc CSVs.

    Reads ``trainingData.csv`` and ``validationData.csv`` under
    ``data_path``, keeps the first ``d`` feature columns plus column index
    520 as the label, min-max scales features and label together (scaler
    fitted on the training rows) and splits the rows into ``m`` consecutive
    chunks.

    Args:
        m: Number of clients (chunks).
        d: Number of leading feature columns to keep.
        data_path: Directory containing the two CSV files.

    Returns:
        dict mapping client id ``0..m-1`` to
        ``{"train": (X, y), "test": (X, y)}``.
    """
    import pandas as pd
    from sklearn.preprocessing import MinMaxScaler

    data_path = Path(data_path)

    train_df = pd.read_csv(data_path / 'trainingData.csv')
    test_df = pd.read_csv(data_path / 'validationData.csv')

    # First 520 columns are feature values.
    # The 521st column (index 520) is the Longitude, used as the label.
    # Only the first d features are kept.
    columns = list(range(d)) + [520]
    train_data = train_df.values[:, columns]
    test_data = test_df.values[:, columns]

    # The label column is part of the array, so it is scaled as well.
    scaler = MinMaxScaler()
    train_data = scaler.fit_transform(train_data)
    test_data = scaler.transform(test_data)

    client_data_train = np.array_split(train_data, m)
    # Split the validation rows here (the previous code split the training
    # rows a second time, so clients were evaluated on their training data).
    client_data_test = np.array_split(test_data, m)

    client_data = {
        client_id: {
            "train": (train_part[:, :-1], train_part[:, -1]),
            "test": (test_part[:, :-1], test_part[:, -1]),
        }
        for client_id, (train_part, test_part) in enumerate(zip(client_data_train, client_data_test))
    }
    return client_data


    