#!/usr/bin/env python
# -*-coding:utf-8 -*-
import os
import numpy as np 
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch 
import data.prepare_data as prepare_data 


def _get_mnist(conf, root, split, transform, target_transform, download):
    """Build the torchvision MNIST dataset for the requested split.

    Args:
        conf: configuration object; only ``conf.pn_normalize`` is read. When
            it is truthy, ``Normalize((0.1307,), (0.3081,))`` is appended
            after ``ToTensor``.
        root: directory handed to ``datasets.MNIST`` as ``root``.
        split: split name; any value containing the substring ``"train"``
            selects the training set, anything else the test set.
        transform: ignored -- the pipeline is always rebuilt from ``conf``.
            The parameter is kept so existing callers keep working.
        target_transform: forwarded unchanged to ``datasets.MNIST``.
        download: forwarded unchanged to ``datasets.MNIST``.

    Returns:
        The constructed ``datasets.MNIST`` instance.
    """
    is_train = "train" in split

    # Train and test use the same pipeline (no train-only augmentation is
    # applied here), so a single list replaces the duplicated branches.
    steps = [transforms.ToTensor()]
    if conf.pn_normalize:
        steps.append(transforms.Normalize((0.1307,), (0.3081,)))

    return datasets.MNIST(
        root=root,
        train=is_train,
        transform=transforms.Compose(steps),
        target_transform=target_transform,
        download=download,
    )
    
    
def get_mnist_transform(conf):
    """Return the MNIST preprocessing pipeline as a pair.

    The pipeline is ``ToTensor`` followed, when ``conf.pn_normalize`` is
    truthy, by ``Normalize((0.1307,), (0.3081,))``.

    Args:
        conf: configuration object; only ``conf.pn_normalize`` is read.

    Returns:
        A 2-tuple whose two elements are the *same* ``transforms.Compose``
        object.
    """
    pipeline = [transforms.ToTensor()]
    if conf.pn_normalize:
        pipeline.append(transforms.Normalize((0.1307,), (0.3081,)))

    composed = transforms.Compose(pipeline)
    # One shared object fills both slots of the returned pair.
    return composed, composed

    
def get_dataset(conf, name, datasets_path, split="train", transform=None, target_transform=None,
                download=True, load_opt=None):
    """Return the dataset rooted at ``datasets_path/name``.

    Every call is delegated to ``_get_mnist``; ``name`` is only used to pick
    the sub-directory under ``datasets_path``.

    Args:
        conf: the configuration object, passed through to ``_get_mnist``.
        name: str, dataset name used as the sub-directory of
            ``datasets_path``.
        datasets_path: the location to save/load the dataset.
        split: split name, e.g. "train" / "test".
        transform: passed through to ``_get_mnist``.
        target_transform: passed through to ``_get_mnist``.
        download: bool variable, passed through to ``_get_mnist``.
        load_opt: accepted but not used by this function.

    Returns:
        Whatever ``_get_mnist`` returns for the given arguments.
    """
    # The dataset root is only computed here; nothing is created on disk by
    # this function itself.
    dataset_root = os.path.join(datasets_path, name)
    return _get_mnist(
        conf=conf,
        root=dataset_root,
        split=split,
        transform=transform,
        target_transform=target_transform,
        download=download,
    )


def get_synthetic_data_from_diffusion(conf, version=1, data_path=None):
    """Load synthetic images from an ``.npz`` archive and split them per client.

    Args:
        conf: configuration object. The attributes read are ``align_data``
            (must equal ``"add_fake_diffusion_sync"``), ``n_clients``,
            ``num_synthetic_images`` and ``random_shuffle``.
        version: accepted for backward compatibility; not used.
        data_path: path of the ``.npz`` archive to read. It must contain the
            images under ``"arr_0"`` and the labels under ``"arr_1"``.
            Defaults to the archive path that used to be hard-coded here.

    Returns:
        ``(im_per_client, cls_per_client)``: two lists of length
        ``conf.n_clients`` holding, per client, the images divided by 255.0
        as ``float32`` and the matching labels.

    Raises:
        ValueError: if ``conf.align_data`` is not
            ``"add_fake_diffusion_sync"``, or (from ``np.split``) if the
            final number of images is not divisible by ``conf.n_clients``.
    """
    if data_path is None:
        data_path = "../image_dataset/UNet_mnist-250-sampling_steps-50000_images-class_condn_True.npz"

    # Only this mode has a data source; any other value used to fall through
    # to an unbound local variable.
    if conf.align_data != "add_fake_diffusion_sync":
        raise ValueError(
            "get_synthetic_data_from_diffusion only supports "
            "conf.align_data == 'add_fake_diffusion_sync', got %r" % (conf.align_data,)
        )

    # Indexing an NpzFile materialises the arrays, so the archive can be
    # closed as soon as both have been read.
    with np.load(data_path) as data_group:
        im_group = data_group["arr_0"]
        label_group = data_group["arr_1"][:len(im_group)]

    if conf.num_synthetic_images * conf.n_clients < len(im_group):
        # Keep the first ``num_synthetic_images`` samples of every class,
        # grouped class by class in the order returned by ``np.unique``.
        per_class_index = [
            np.where(label_group == s_cls)[0][:conf.num_synthetic_images]
            for s_cls in np.unique(label_group)
        ]
        sub_set_index = np.concatenate(per_class_index, axis=0)
        im_group = im_group[sub_set_index]
        label_group = label_group[sub_set_index]

    # Split only after sub-sampling: splitting the raw index first could
    # raise on a size that is never actually used.
    data_index = np.arange(len(im_group))
    split_index = np.split(data_index, conf.n_clients)

    # Explicit comparison kept on purpose so that non-bool truthy config
    # values (e.g. strings) do not start triggering a shuffle.
    if conf.random_shuffle == True:  # noqa: E712
        print("shuffle the data index")
        shuffle_index = np.random.choice(data_index, len(data_index), replace=False)
        im_group = im_group[shuffle_index]
        label_group = label_group[shuffle_index]

    # presumably pixel values are in [0, 255]; dividing maps them to [0, 1]
    im_per_client = [(im_group[v] / 255.0).astype(np.float32) for v in split_index]
    cls_per_client = [label_group[v] for v in split_index]
    print([np.shape(v) for v in im_per_client])
    print([np.max(v) for v in im_per_client])
    print([np.shape(v) for v in cls_per_client])
    return im_per_client, cls_per_client


