import os
import struct
import json
import random
from argparse import ArgumentParser
import numpy as np
from itertools import accumulate 
from PIL import Image

def bias_sample_synthesis(images, bck):
    """Turn grayscale images into RGB images with a colored foreground.

    Pixels whose value is <= 0.5 are treated as background and stay black
    (0, 0, 0). Every other pixel is painted with the color ``bck``.

    Note that, despite its name, ``bck`` is the color written to the
    *foreground* pixels, not to the background.

    Args:
        images: Array of shape ``(n, H, W)`` with one grayscale image per
            entry along the first axis.
        bck: RGB color (sequence of three numbers) for the foreground.

    Returns:
        ``float32`` array of shape ``(n, H, W, 3)``.
    """
    images = np.asarray(images)
    # Anything that is not "<= 0.5" counts as foreground.
    foreground = ~(images <= 0.5)
    # Derive the canvas from the input instead of assuming 28x28 images.
    syn_img = np.zeros(images.shape + (3,), dtype=np.float32)
    syn_img[foreground] = np.asarray(bck)
    return syn_img

def _load_mnist_split(mnist_path, split, valid_ratio):
    """Load the ``(images, labels)`` arrays of ``split`` from ``mnist_path``."""
    if split in ("train", "valid"):
        images = np.load(os.path.join(mnist_path, "train-images.npy"))
        labels = np.load(os.path.join(mnist_path, "train-labels.npy"))
        # The leading (1 - valid_ratio) share of the train file is the train
        # split; the remaining tail is the valid split.
        train_size = int(labels.shape[0] * (1 - valid_ratio))
        if split == "train":
            return images[:train_size], labels[:train_size]
        return images[train_size:], labels[train_size:]
    if split == "test":
        images = np.load(os.path.join(mnist_path, "test-images.npy"))
        labels = np.load(os.path.join(mnist_path, "test-labels.npy"))
        return images, labels
    raise ValueError(
        f"Unknown split {split!r}; expected 'train', 'valid' or 'test'."
    )


def _class_feature_counts(n_cls, bias_feat, bias, n_ba_ls, n_bc_ls, n_classes):
    """Return how many of the ``n_cls`` samples of one class get each feature."""
    n_feats = len(bias)
    n_unbiased = sum(1 for target in bias if target is None)
    counts = [0] * n_feats
    for i in range(n_feats):
        if i == bias_feat:
            # The feature biased towards this class: its bias-aligned samples.
            counts[i] = n_ba_ls[i]
        elif bias[i] is not None:
            # Bias-conflicting samples are spread evenly over the other classes.
            counts[i] = int(n_bc_ls[i] / (n_classes - 1))
    # Unbiased features (if any) share whatever is left, never less than zero.
    if n_unbiased > 0:
        n_remain = n_cls - sum(counts)
        per_unbiased = max(0, int(n_remain / n_unbiased))
        for i in range(n_feats):
            if bias[i] is None:
                counts[i] = per_unbiased
    # Fix rounding so the counts add up to n_cls exactly, one sample at a
    # time in round-robin order starting at feature 0. When removing samples,
    # features that are already empty are skipped so no count goes negative.
    diff = n_cls - sum(counts)
    sign = 1 if diff > 0 else -1
    i = 0
    while diff != 0:
        if sign > 0 or counts[i] > 0:
            counts[i] += sign
            diff -= sign
        i = (i + 1) % n_feats
    return counts


def gen_bias_dataset(
    mnist_path=r"./dataset",
    corr=[0.99]*10,
    density=[1]*10,
    bias=[0,1,2,3,4,5,6,7,8,9],
    split="train",
    target_path="./data/SCMnist",
    valid_ratio=0.2
):
    """Synthesize a color-biased MNIST split and store it as JPEG files.

    Every feature ``i`` is one fixed RGB color. If ``bias[i]`` is a class,
    the color appears on a share ``density[i]`` of that class' frequency in
    the split, and a fraction ``corr[i]`` of those colored samples are taken
    from class ``bias[i]`` itself; the rest is spread evenly over the other
    classes. Features with ``bias[i] is None`` are unbiased and evenly take
    the samples of each class that the biased features leave over.

    Images are written to
    ``<target_path>/<split>/<digit>/<digit>_<feature>_<index>.jpeg`` where
    ``index`` is the sample position inside the loaded split.

    Args:
        mnist_path: Directory holding ``train-images.npy``,
            ``train-labels.npy``, ``test-images.npy`` and ``test-labels.npy``.
        corr: Per-feature correlation between the feature and its target
            class; each value must be <= 1. Only read, never modified.
        density: Per-feature density; ``corr[i] * density[i]`` must be <= 1.
            Only read, never modified.
        bias: Per-feature target class, or ``None`` for an unbiased feature.
            At most 10 entries (one per available color).
        split: ``"train"``, ``"valid"`` or ``"test"``. Train and valid are
            both cut from the train files.
        target_path: Root directory of the generated dataset.
        valid_ratio: Share of the train files reserved for the valid split.

    Raises:
        ValueError: If ``corr``/``density`` violate the limits above, if
            ``bias`` is empty or has more entries than there are colors, or
            if ``split`` is not one of the supported names.
    """
    if any(c * d > 1.0 for c, d in zip(corr, density)):
        raise ValueError("Correlation times density should be no larger than 1.0. The conditions could not be satisfied simultaneously otherwise.")
    elif any(c > 1 for c in corr):
        raise ValueError("Correlation should be no larger than 1.")

    # One RGB color per feature index.
    bias_feats = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (128, 0, 0), (0, 128, 0), (0, 0, 128), (255, 255, 0), (255, 0, 255), (0, 255, 255), (255, 255, 255)]
    n_classes = 10
    n_feats = len(bias)
    if not 1 <= n_feats <= len(bias_feats):
        raise ValueError(
            f"bias must have between 1 and {len(bias_feats)} entries, got {n_feats}."
        )

    # load original mnist data
    images, labels = _load_mnist_split(mnist_path, split, valid_ratio)
    n = images.shape[0]
    index = np.arange(n)

    # Relative frequency of every class present in this split.
    unique_labels, counts = np.unique(labels, return_counts=True)
    cls_dis = dict(zip(unique_labels, counts / n))

    # Per biased feature: number of bias-aligned (ba) and bias-conflicting
    # (bc) samples. Unbiased features keep None.
    n_ba_ls = [None] * n_feats
    n_bc_ls = [None] * n_feats
    for i, target_cls in enumerate(bias):
        if target_cls is None:
            continue
        # A target class absent from the split simply contributes no samples.
        p_bias_feat = cls_dis.get(target_cls, 0.0) * density[i]
        n_bias_feat = int(n * p_bias_feat)
        n_ba_ls[i] = int(n * p_bias_feat * corr[i])
        n_bc_ls[i] = n_bias_feat - n_ba_ls[i]

    dataset = [[None for _ in range(n_feats)] for _ in range(n_classes)]

    # assign color to samples according to statistic
    for cls in range(n_classes):
        bias_feat = bias.index(cls) if cls in bias else None
        idx_cls = index[labels == cls]
        n_cls = idx_cls.shape[0]
        n_feat_ls = _class_feature_counts(
            n_cls, bias_feat, bias, n_ba_ls, n_bc_ls, n_classes
        )
        assert sum(n_feat_ls) == n_cls

        # Consecutive chunks of the class' sample indices, one per feature.
        idx_feat_ls = [
            idx_cls[end - size:end]
            for end, size in zip(accumulate(n_feat_ls), n_feat_ls)
        ]
        print(f"for class {cls}, the distribution of features: {n_feat_ls}")
        # synthesis samples
        for i in range(n_feats):
            img = bias_sample_synthesis(images[idx_feat_ls[i]], bias_feats[i])
            dataset[cls][i] = (img, idx_feat_ls[i])

    # store synthesized samples
    split_path = os.path.join(target_path, split)
    os.makedirs(split_path, exist_ok=True)
    for cls in range(n_classes):
        digit_path = os.path.join(split_path, str(cls))
        os.makedirs(digit_path, exist_ok=True)
        for i in range(n_feats):
            imgs, idx = dataset[cls][i]
            print(f"for digit {cls} and color {i}, samples in total {imgs.shape[0]}")
            for img, sample_idx in zip(imgs, idx):
                path = os.path.join(digit_path, f"{cls}_{i}_{sample_idx}.jpeg")
                Image.fromarray(img.astype(np.uint8)).save(path)
            
def generate_train_set(root, densities, corrs, bias):
    """Generate the train and valid splits for every density/correlation pair.

    For each combination the data goes to ``<root>/<density>pct/<corr>pct``.
    The same density and correlation value is used for all 10 features.

    Args:
        root: Base output directory.
        densities: Density values to sweep over.
        corrs: Correlation values to sweep over.
        bias: Per-feature target classes, forwarded to ``gen_bias_dataset``.
    """
    for density, corr in ((d, c) for d in densities for c in corrs):
        out_dir = os.path.join(root, f"{density}pct", f"{corr}pct")
        os.makedirs(out_dir, exist_ok=True)
        print(f"=== generating training data with density: {density}, corr: {corr}, in path: {out_dir} ===")
        # Both splits share the exact same configuration and output root.
        for split_name in ("train", "valid"):
            gen_bias_dataset(
                corr=[corr] * 10,
                density=[density] * 10,
                bias=bias,
                target_path=out_dir,
                split=split_name,
            )
            

if __name__ == "__main__":
    # Each experiment: (number of biased colors, correlations to generate).
    #   10 biases -> LMLP, HMHP, Unbiased
    #    1 bias   -> HMLP
    experiments = [
        (10, [0.98, 0.5, 0.1]),
        (1, [0.98]),
    ]
    densities = [1]
    for n_b, corrs in experiments:
        # The first n_b colors are tied to digits 0..n_b-1, the rest are unbiased.
        bias = list(range(n_b)) + [None] * (10 - n_b)
        root = f"./dataset/CMNIST-{n_b}"
        print(f"===== generating train and valid set for {n_b} biases =====")
        generate_train_set(root=root, densities=densities, corrs=corrs, bias=bias)
        # The test set is written with no biased feature at all.
        print("===== generating test set =====")
        gen_bias_dataset(
            corr=[1/10]*10,
            density=[1]*10,
            bias=[None]*10,
            target_path=root,
            split="test",
        )
    
    
