import os
import struct
import json
import random
from argparse import ArgumentParser
import numpy as np
from itertools import accumulate
from PIL import Image
from torchvision.datasets import CIFAR10
import pickle
from corrupted_cifar10_protocol import CORRUPTED_CIFAR10_PROTOCOL
import torchvision.transforms as T


def bias_sample_synthesis(images, corrupted, severity=4):
    """Apply one named corruption to every image in ``images``.

    Args:
        images: iterable of images. Each one is round-tripped through
            ``T.ToTensor()`` then ``T.ToPILImage()`` before being corrupted
            (presumably HxWx3 uint8 arrays such as ``CIFAR10.data`` rows —
            TODO confirm for other callers).
        corrupted: key into ``CORRUPTED_CIFAR10_PROTOCOL`` selecting the
            corruption function (e.g. "Snow", "Fog").
        severity: severity level forwarded to the corruption function.
            Defaults to 4, the value that used to be hard-coded.

    Returns:
        ``np.ndarray`` stacking the corrupted images, each cast to ``uint8``.

    Raises:
        KeyError: if ``corrupted`` is not a known corruption name.
    """
    # Resolve the corruption once instead of once per image.
    corrupt = CORRUPTED_CIFAR10_PROTOCOL[corrupted]
    to_pil = T.Compose([T.ToTensor(), T.ToPILImage()])
    corrupted_imgs = [
        np.array(corrupt(to_pil(image), severity=severity)).astype(np.uint8)
        for image in images
    ]
    return np.array(corrupted_imgs)


def _load_cifar_split(cifar_path, split, valid_ratio):
    """Return ``(images, labels)`` of the raw CIFAR-10 data for ``split``.

    "train" and "valid" are disjoint slices of the official training set (the
    last ``valid_ratio`` fraction is held out for validation); "test" is the
    full official test set.

    Raises:
        ValueError: if ``split`` is not "train", "valid" or "test".
    """
    if split in ("train", "valid"):
        # Both splits come from the official training set so that validation
        # images never overlap with the images of the "test" split.
        cifar = CIFAR10(cifar_path, train=True)
        images, labels = cifar.data, cifar.targets
        train_size = int(images.shape[0] * (1 - valid_ratio))
        if split == "train":
            return images[:train_size], labels[:train_size]
        return images[train_size:], labels[train_size:]
    if split == "test":
        cifar = CIFAR10(cifar_path, train=False)
        return cifar.data, cifar.targets
    raise ValueError(f"Unknown split {split!r}; expected 'train', 'valid' or 'test'.")


def _bias_feature_budgets(n, cls_dis, corr, density, bias):
    """Return per-feature bias-aligned / bias-conflicting counts (None for unbiased features)."""
    n_ba_ls = []
    n_bc_ls = []
    for i, target_cls in enumerate(bias):
        if target_cls is None:
            n_ba_ls.append(None)
            n_bc_ls.append(None)
            continue
        # Fraction of the whole split that carries feature i.
        p_bias_feat = cls_dis[target_cls] * density[i]
        n_bias_feat = int(n * p_bias_feat)
        # Of those, a `corr[i]` fraction is aligned with class bias[i].
        n_ba = int(n * p_bias_feat * corr[i])
        n_ba_ls.append(n_ba)
        n_bc_ls.append(n_bias_feat - n_ba)
    return n_ba_ls, n_bc_ls


def _class_feature_counts(cls, n_cls, bias, n_ba_ls, n_bc_ls, n_classes):
    """Split the ``n_cls`` samples of class ``cls`` into per-feature counts summing to ``n_cls``."""
    n_feats = len(bias)
    n_bias = sum(b is not None for b in bias)
    bias_feat = bias.index(cls) if cls in bias else None
    n_feat_ls = [0] * n_feats
    for i in range(n_feats):
        if i == bias_feat:
            n_feat_ls[i] = n_ba_ls[i]
        elif bias[i] is not None:
            # Bias-conflicting samples of feature i are spread evenly over
            # the classes other than bias[i].
            n_feat_ls[i] = int(n_bc_ls[i] / (n_classes - 1))
    # Unbiased features (if any) share whatever samples remain.
    if n_feats > n_bias:
        n_remain = n_cls - sum(n_feat_ls)
        n_unbias_feat = max(0, int(n_remain / (n_feats - n_bias)))
        for i in range(n_feats):
            if bias[i] is None:
                n_feat_ls[i] = n_unbias_feat
    # Absorb rounding error round-robin so the counts sum exactly to n_cls.
    diff = n_cls - sum(n_feat_ls)
    step = 1 if diff > 0 else -1
    i = 0
    while diff != 0:
        # When trimming an excess, never push a count below zero; a positive
        # entry always exists while the total still exceeds n_cls >= 0.
        if step > 0 or n_feat_ls[i] > 0:
            n_feat_ls[i] += step
            diff -= step
        i = (i + 1) % n_feats
    return n_feat_ls


def _save_samples(samples, target_path, split, n_classes, n_corruptions):
    """Write every synthesized image to ``<target_path>/<split>/<cls>/<cls>_<feat>_<idx>.jpeg``."""
    split_path = os.path.join(target_path, split)
    os.makedirs(split_path, exist_ok=True)
    for cls in range(n_classes):
        digit_path = os.path.join(split_path, str(cls))
        os.makedirs(digit_path, exist_ok=True)
        for i in range(n_corruptions):
            imgs, idx = samples[cls][i]
            print(f"for digit {cls} and color {i}, samples in total {imgs.shape[0]}")
            for j in range(imgs.shape[0]):
                path = os.path.join(digit_path, f"{cls}_{i}_{idx[j]}.jpeg")
                Image.fromarray(imgs[j].astype(np.uint8)).save(path)


def gen_sparse_bias_cmnist(
    cifar_path=r"./dataset/CIFAR10",
    corr=(0.99,) * 10,
    density=(1,) * 10,
    bias=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
    split="train",
    target_path="./data/Cifar",
    valid_ratio=0.2
):
    """Synthesize a corruption-biased CIFAR-10 split and save it as JPEG files.

    Feature ``i`` is the corruption ``bias_feats[i]`` (Snow, Frost, ...).

    Args:
        cifar_path: root directory passed to ``torchvision.datasets.CIFAR10``.
        corr: ``corr[i]`` is the fraction of feature-``i`` samples that belong
            to the aligned class ``bias[i]``; the rest are spread evenly over
            the other classes.
        density: ``density[i]`` scales the aligned class's share of the split
            to give the fraction of samples that carry feature ``i``.
        bias: ``bias[i]`` is the class aligned with feature ``i``, or ``None``
            for an unbiased feature that only absorbs leftover samples.
        split: "train" or "valid" (disjoint slices of the CIFAR-10 training
            set) or "test" (the CIFAR-10 test set).
        target_path: output root; images go to
            ``<target_path>/<split>/<cls>/<cls>_<feat>_<idx>.jpeg``.
        valid_ratio: fraction of the training set held out as "valid".

    Raises:
        ValueError: if ``corr[i] * density[i] > 1``, ``corr[i] > 1``, or
            ``split`` is not recognized.
    """
    if any(corr[i] * density[i] > 1.0 for i in range(len(corr))):
        raise ValueError("Correlation times density should be no larger than 1.0. The conditions could not be satisfied simultaneously otherwise.")
    if any(corr[i] > 1 for i in range(len(corr))):
        raise ValueError("Correlation should be no larger than 1.")

    images, labels = _load_cifar_split(cifar_path, split, valid_ratio)
    labels = np.asarray(labels)
    index = np.arange(images.shape[0])

    # Corruption applied for each feature slot i.
    bias_feats = ["Snow", "Frost", "Fog", "Brightness", "Contrast",
                  "Spatter", "Elastic", "JPEG", "Pixelate", "Saturate"]

    n_classes = 10
    n = images.shape[0]
    # Empirical class distribution of this split.
    unique_labels, counts = np.unique(labels, return_counts=True)
    cls_dis = dict(zip(unique_labels, counts / n))
    n_ba_ls, n_bc_ls = _bias_feature_budgets(n, cls_dis, corr, density, bias)

    # samples[cls][i] = (corrupted images, their indices within this split)
    samples = [[None] * len(bias) for _ in range(n_classes)]
    for cls in range(n_classes):
        idx_cls = index[labels == cls]
        n_feat_ls = _class_feature_counts(
            cls, idx_cls.shape[0], bias, n_ba_ls, n_bc_ls, n_classes
        )
        # Consecutive, non-overlapping slices of this class's indices.
        idx_feat_ls = [
            idx_cls[end - size:end]
            for end, size in zip(accumulate(n_feat_ls), n_feat_ls)
        ]
        print(f"for class {cls}, the distribution of features: {n_feat_ls}")
        for i, corruption in enumerate(bias_feats):
            img = bias_sample_synthesis(images[idx_feat_ls[i]], corruption)
            samples[cls][i] = (img, idx_feat_ls[i])

    _save_samples(samples, target_path, split, n_classes, len(bias_feats))
    


def generate_train_set(root, densities, corrs, bias):
    """Build the biased "train" and "valid" splits for every (density, corr) pair.

    Each combination is written under ``<root>/<density>pct/<corr>pct``; the
    same ``density``/``corr`` value is used for all 10 feature slots.
    """
    for density in densities:
        for corr in corrs:
            out_dir = os.path.join(root, f"{density}pct", f"{corr}pct")
            os.makedirs(out_dir, exist_ok=True)
            print(f"=== generating training data with density: {density}, corr: {corr}, in path: {out_dir} ===")
            # Train first, then valid, with identical bias settings.
            for split_name in ("train", "valid"):
                gen_sparse_bias_cmnist(
                    corr=[corr] * 10,
                    density=[density] * 10,
                    bias=bias,
                    target_path=out_dir,
                    split=split_name,
                )


if __name__ == "__main__":
    # Number of classes that get a dedicated biasing corruption; the
    # remaining 10 - n_b feature slots are left unbiased (None).
    n_b = 10
    bias = list(range(n_b)) + [None] * (10 - n_b)
    root = f"./dataset/Cifar10C-{n_b}"
    densities = [1]
    corrs = [0.98, 0.5, 0.1]

    # Biased train/valid splits for each (density, correlation) pair.
    print("===== generating train and valid set =====")
    generate_train_set(root=root, densities=densities, corrs=corrs, bias=bias)

    # Test split: every feature is unbiased, so each class is spread evenly
    # over all corruptions.
    print("===== generating test set =====")
    n_slots = 10
    gen_sparse_bias_cmnist(
        corr=[1 / 10] * n_slots,
        density=[1] * n_slots,
        bias=[None] * n_slots,
        target_path=root,
        split="test",
    )


