import os
import struct
import json
import random
from argparse import ArgumentParser
import numpy as np
from itertools import accumulate
from PIL import Image
from torchvision.datasets import CIFAR10
import pickle
from corrupted_cifar10_protocol import CORRUPTED_CIFAR10_PROTOCOL
import torchvision.transforms as T
import torch
# Pin the global generators of Python's `random`, NumPy and PyTorch to one
# fixed seed so every run starts from the same RNG state (for instance, the
# shuffle in gen_multi_bias_dataset draws from np.random).
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

def bias_sample_synthesis(images, corrupted, syn):
    """Apply one corruption to every image of a batch, optionally to one half only.

    Args:
        images: batch indexed as ``images[k]``, each image indexed
            ``[row, column, channel]`` (see the half-width slicing below);
            presumably uint8 ``(N, H, W, C)`` as in ``CIFAR10.data`` — confirm.
        corrupted: key into ``CORRUPTED_CIFAR10_PROTOCOL``; the corruption is
            applied with ``severity=4``.
        syn: ``"left"`` keeps the corruption on the left half and restores the
            right half from the original image; ``"right"`` restores the left
            half and keeps the corruption on the right half; any other value
            keeps the fully corrupted image.

    Returns:
        ``np.uint8`` array stacking the corrupted images (presumably the same
        shape as ``images`` if the corruption preserves image size — confirm).
        An empty batch yields an empty uint8 array with the input's trailing
        shape, so callers can assign it back into a batch of the same layout.
    """
    # Tensor round-trip hands the corruption function a PIL image.
    to_pil = T.Compose([T.ToTensor(), T.ToPILImage()])
    corruption = CORRUPTED_CIFAR10_PROTOCOL[corrupted]
    out = []
    for image in images:
        half = image.shape[1] // 2
        corrupted_img = np.asarray(corruption(to_pil(image), severity=4))
        original = np.asarray(image)
        if syn == "left":  # keep the corrupted left half, recover the right part
            corrupted_img = np.concatenate(
                [corrupted_img[:, :half, :], original[:, half:, :]], axis=1
            )
        elif syn == "right":  # recover the left part, keep the corrupted right half
            corrupted_img = np.concatenate(
                [original[:, :half, :], corrupted_img[:, half:, :]], axis=1
            )
        # astype(np.uint8) wraps out-of-range values instead of saturating;
        # clip first in case the corruption returns values outside [0, 255].
        out.append(np.clip(corrupted_img, 0, 255).astype(np.uint8))
    if not out:
        # np.array([]) would be shape (0,) float64, which cannot be assigned
        # into images_[idx] of shape (0, H, W, C).
        return np.empty((0, *np.shape(images)[1:]), dtype=np.uint8)
    return np.array(out)


def _balance_counts(counts, target):
    """Adjust ``counts`` in place, one unit at a time, until ``sum(counts) == target``.

    Entries are visited round-robin starting at index 0. When the total has
    to shrink, entries that are already zero are skipped so that no count
    becomes negative (a negative count would make the consecutive index
    slices built from these counts overlap).

    Args:
        counts: non-empty list of non-negative ints; mutated in place.
        target: desired total; must be non-negative.

    Returns:
        The same ``counts`` list.
    """
    diff = target - sum(counts)
    i = 0
    while diff != 0:
        if diff > 0:
            counts[i] += 1
            diff -= 1
        elif counts[i] > 0:
            counts[i] -= 1
            diff += 1
        i = (i + 1) % len(counts)
    return counts


def gen_bias_dataset(
    images,
    labels,
    corr=(0.99,) * 10,
    density=(1,) * 10,
    bias=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
    split="train",
    target_path="./data/Cifar",
    valid_ratio=0.2,
    syn="full",
    bias_feats=(),
):
    """Assign every sample of classes 0-9 to one corruption "feature" and synthesize it.

    Feature ``i`` is the corruption ``bias_feats[i]``. When ``bias[i]`` is a
    class id, feature ``i`` is allotted about ``density[i]`` times that
    class's share of the dataset. A fraction ``corr[i]`` of those samples
    (``n_ba``) comes from class ``bias[i]``. The remainder (``n_bc``) is
    split evenly over the other ``n_classes - 1`` classes. Features with
    ``bias[i] is None`` are unbiased: inside each class they share equally
    whatever the biased features did not take. Rounding leftovers are spread
    round-robin so the groups of each class cover it exactly.

    Args:
        images: array with one sample per entry along axis 0 (presumably
            uint8 ``(N, H, W, C)`` as in ``CIFAR10.data`` — confirm).
        labels: class id of each sample; only ids 0-9 are processed.
        corr, density: per-feature correlation / density, indexed like ``bias``.
        bias: per-feature biased class id, or ``None`` for an unbiased feature.
        split, target_path, valid_ratio: not used here; kept so existing
            callers keep working.
        syn: forwarded to ``bias_sample_synthesis`` (``"left"``, ``"right"``,
            or anything else for whole-image corruption).
        bias_feats: one corruption name per feature.

    Returns:
        ``(images_, labels_, attrs_)`` aligned with the input order, where
        ``attrs_[k]`` is the feature index assigned to sample ``k``. Samples
        whose label is outside 0-9 are left as zeros.

    Raises:
        ValueError: on inconsistent correlation/density values, an empty
            ``bias``, or when ``bias_feats`` does not name every feature.
    """
    if any(c * density[i] > 1.0 for i, c in enumerate(corr)):
        raise ValueError("Correlation times density should be no larger than 1.0. The conditions could not be satisfied simultaneously otherwise.")
    elif any(c > 1 for c in corr):
        raise ValueError("Correlation should be no larger than 1.")
    # Negative values would produce negative per-feature sample counts below.
    if any(c < 0 for c in corr) or any(d < 0 for d in density):
        raise ValueError("Correlation and density should be non-negative.")

    n_classes = 10
    n_feats = len(bias)
    if n_feats == 0:
        raise ValueError("`bias` must describe at least one feature.")
    if len(bias_feats) != n_feats:
        # A feature without a corruption name would leave its samples as
        # all-zero images labelled 0 in the output.
        raise ValueError(
            f"Expected one corruption name per feature ({n_feats}), got {len(bias_feats)}."
        )

    labels = np.asarray(labels)
    n = images.shape[0]
    index = np.arange(n)
    n_bias = sum(b is not None for b in bias)

    # Fraction of the dataset that each class present in `labels` makes up.
    unique_labels, counts = np.unique(labels, return_counts=True)
    cls_dis = dict(zip(unique_labels, counts / n))

    # Per biased feature: samples taken from its own class (n_ba) and samples
    # spread over the other classes (n_bc); None for unbiased features.
    n_ba_ls, n_bc_ls = [], []
    for i, target_cls in enumerate(bias):
        if target_cls is None:
            n_ba_ls.append(None)
            n_bc_ls.append(None)
            continue
        p_bias_feat = cls_dis[target_cls] * density[i]
        n_bias_feat = int(n * p_bias_feat)
        n_ba = int(n * p_bias_feat * corr[i])
        n_ba_ls.append(n_ba)
        n_bc_ls.append(n_bias_feat - n_ba)

    images_ = np.zeros_like(images)
    labels_ = np.zeros_like(labels)
    attrs_ = np.zeros_like(labels)

    for cls in range(n_classes):
        bias_feat = bias.index(cls) if cls in bias else None
        idx_cls = index[labels == cls]
        n_cls = idx_cls.shape[0]

        # Biased features: n_ba for the feature tied to this class, an even
        # share of n_bc for every other biased feature, 0 for unbiased ones.
        n_feat_ls = []
        for i, target_cls in enumerate(bias):
            if i == bias_feat:
                n_feat_ls.append(n_ba_ls[i])
            elif target_cls is not None:
                n_feat_ls.append(int(n_bc_ls[i] / (n_classes - 1)))
            else:
                n_feat_ls.append(0)

        # Unbiased features share whatever is left of this class.
        if n_feats > n_bias:
            n_remain = n_cls - sum(n_feat_ls)
            # Clamp so a class already over-allotted cannot yield negative counts.
            n_unbias_feat = max(0, int(n_remain / (n_feats - n_bias)))
            n_feat_ls = [n_unbias_feat if b is None else c for b, c in zip(bias, n_feat_ls)]

        # Absorb rounding error so the groups cover the class exactly.
        _balance_counts(n_feat_ls, n_cls)
        assert sum(n_feat_ls) == n_cls

        # Consecutive, non-overlapping slices of this class's sample indices.
        idx_feat_ls = [idx_cls[end - size:end] for end, size in zip(accumulate(n_feat_ls), n_feat_ls)]
        print(f"for class {cls}, the distribution of features: {n_feat_ls}")

        for i, (corruption, idx) in enumerate(zip(bias_feats, idx_feat_ls)):
            if idx.size == 0:
                # Nothing to synthesize; skipping also avoids assigning an
                # empty batch whose shape may not match images_[idx].
                continue
            images_[idx] = bias_sample_synthesis(images[idx], corruption, syn)
            labels_[idx] = cls
            attrs_[idx] = i

    return images_, labels_, attrs_
    


def generate_train_set(root, densities, corrs, bias, cifar_path=r"./dataset/CIFAR10",
                       valid_ratio=0.2, bias_feats=None):
    """Generate train/valid splits for every (density, corr) combination.

    For each combination a directory ``root/{density}pct/{corr}pct`` is
    created, and ``gen_bias_dataset`` is run on the CIFAR-10 training data.
    The data is split into "train" (first ``1 - valid_ratio`` of the samples)
    and "valid" (the rest), the same way ``gen_multi_bias_dataset`` splits it.
    The dataset is loaded once and reused across the sweep.

    Args:
        root: output root directory.
        densities, corrs: values swept; each is broadcast to all 10 features.
        bias: per-feature biased class id (or ``None``), as for ``gen_bias_dataset``.
        cifar_path: CIFAR-10 root passed to ``torchvision.datasets.CIFAR10``.
        valid_ratio: fraction of the training data held out as the valid split.
        bias_feats: corruption name for each feature, forwarded to
            ``gen_bias_dataset``.

    Returns:
        dict mapping ``(density, corr, split)`` to the ``(images, labels, attrs)``
        tuple returned by ``gen_bias_dataset``. Apart from creating the
        directories, nothing is written to disk here.

    Raises:
        ValueError: if ``bias_feats`` is not given.
    """
    if bias_feats is None:
        raise ValueError("generate_train_set requires `bias_feats` (one corruption name per feature).")

    train_data = CIFAR10(cifar_path, train=True)
    images = train_data.data
    labels = train_data.targets
    train_size = int(images.shape[0] * (1 - valid_ratio))
    splits = {
        "train": (images[:train_size], labels[:train_size]),
        "valid": (images[train_size:], labels[train_size:]),
    }

    results = {}
    # generate train set with various sparsity and correlation
    for density in densities:
        for corr in corrs:
            target_path = os.path.join(root, f"{density}pct", f"{corr}pct")
            os.makedirs(target_path, exist_ok=True)
            print(f"=== generating training data with density: {density}, corr: {corr}, in path: {target_path} ===")
            for split, (split_images, split_labels) in splits.items():
                # gen_bias_dataset takes the images and labels as required
                # positional arguments.
                results[(density, corr, split)] = gen_bias_dataset(
                    split_images,
                    split_labels,
                    corr=[corr] * 10,
                    density=[density] * 10,
                    bias=bias,
                    target_path=target_path,
                    split=split,
                    valid_ratio=valid_ratio,
                    bias_feats=bias_feats,
                )
    return results

def gen_multi_bias_dataset(
    cifar_path=r"./dataset/CIFAR10",
    corr1=(0.99,) * 10,
    corr2=(0.99,) * 10,
    density=(1,) * 10,
    bias1=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
    bias2=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
    split="train",
    target_path="./data/Cifar",
    valid_ratio=0.2,
    save_previews=False,
):
    """Build a CIFAR-10 split carrying two corruption biases and save it as JPEGs.

    The first pass corrupts the left half of each image with one of
    ``bias_feats_1``, assigned by ``gen_bias_dataset`` from ``corr1``,
    ``density`` and ``bias1``. After a random shuffle, a second pass corrupts
    the right half with one of ``bias_feats_2`` (``corr2``, ``density``,
    ``bias2``).

    Output files go to ``target_path/split/<label>/``. Each is named
    ``<label>_<attr1>_<attr2>_<idx>.jpeg``, where:
    - ``attr1``/``attr2`` are the feature indices from the two passes;
    - ``idx`` is the sample's position within this split's image array
      before shuffling. For ``"valid"`` that position is offset from the
      CIFAR-10 training index by the train size.

    Args:
        cifar_path: CIFAR-10 root passed to ``torchvision.datasets.CIFAR10``.
        corr1, corr2: per-feature correlations for the first/second pass.
        density: per-feature density, shared by both passes.
        bias1, bias2: per-feature biased class ids (or ``None``) for each pass.
        split: ``"train"`` (first ``1 - valid_ratio`` of the CIFAR-10 training
            set), ``"valid"`` (the rest of it) or ``"test"``.
        target_path: output root directory.
        valid_ratio: fraction of the training set used for ``"valid"``.
        save_previews: if True, also write ``./test_left.png`` and
            ``./test_right.png`` showing the same sample after each pass
            (debug aid).

    Raises:
        ValueError: if ``split`` is not one of the three names above.
    """
    if split in ("train", "valid"):
        dataset = CIFAR10(cifar_path, train=True)
        images = dataset.data
        labels = dataset.targets
        train_size = int(images.shape[0] * (1 - valid_ratio))
        if split == "train":
            images, labels = images[:train_size], labels[:train_size]
        else:
            images, labels = images[train_size:], labels[train_size:]
    elif split == "test":
        dataset = CIFAR10(cifar_path, train=False)
        images = dataset.data
        labels = dataset.targets
    else:
        # Without this, `images` would be unbound below.
        raise ValueError(f"split must be 'train', 'valid' or 'test', got {split!r}")

    bias_feats_1 = ["Snow", "Frost", "Fog", "Brightness", "Contrast",
                    "Spatter", "Elastic", "JPEG", "Pixelate", "Saturate"]
    # NOTE(review): "Original" must be a key of CORRUPTED_CIFAR10_PROTOCOL — confirm.
    bias_feats_2 = ["Gaussian Noise", "Shot Noise", "Impulse Noise", "Speckle Noise", "Gaussian Blur",
                    "Defocus Blur", "Glass Blur", "Motion Blur", "Zoom Blur", "Original"]

    # Position of every sample in this split, carried through the shuffle so
    # output filenames can refer back to the source sample.
    idxs = np.arange(images.shape[0], dtype=int)

    # First pass: left-half corruptions.
    images, labels, attrs1 = gen_bias_dataset(images, labels, corr=corr1, density=density, bias=bias1, target_path=target_path, split=split, syn="left", bias_feats=bias_feats_1)
    if save_previews:
        Image.fromarray(images[0]).save("./test_left.png")

    # gen_bias_dataset carves each class into contiguous index ranges, so
    # without a shuffle both passes would group the same samples together.
    perm = np.random.permutation(images.shape[0])
    images = images[perm]
    labels = labels[perm]
    attrs1 = attrs1[perm]
    idxs = idxs[perm]

    # Second pass: right-half corruptions.
    images, labels, attrs2 = gen_bias_dataset(images, labels, corr=corr2, density=density, bias=bias2, target_path=target_path, split=split, syn="right", bias_feats=bias_feats_2)
    if save_previews:
        # Same source sample as the left preview (original position 0).
        Image.fromarray(images[perm == 0][0]).save("./test_right.png")

    # Store synthesized samples, one directory per class label.
    split_path = os.path.join(target_path, split)
    for cls in range(10):
        os.makedirs(os.path.join(split_path, str(cls)), exist_ok=True)

    for image, label, attr1, attr2, src_idx in zip(images, labels, attrs1, attrs2, idxs):
        path = os.path.join(split_path, str(label), f"{label}_{attr1}_{attr2}_{src_idx}.jpeg")
        print(path)
        Image.fromarray(image.astype(np.uint8)).save(path)
    


if __name__ == "__main__":
    # Two-bias setting: the first bias (applied by gen_multi_bias_dataset to
    # the left half) is tied to classes 0..n_b_1-1, the second (right half) to
    # classes 0..n_b_2-1; all remaining features are unbiased (None).
    n_b_1 = 10
    bias1 = list(range(n_b_1)) + [None] * (10 - n_b_1)
    n_b_2 = 1
    bias2 = list(range(n_b_2)) + [None] * (10 - n_b_2)
    root = f"./dataset/Cifar10C-MB-{n_b_1}-{n_b_2}"
    density = 1
    corr1 = 0.5
    corr2 = 0.98
    # Train/valid output lives under root/<density>pct/<corr1>pct/<corr2>pct.
    target_path = os.path.join(root, f"{density}pct", f"{corr1}pct", f"{corr2}pct")
    os.makedirs(target_path, exist_ok=True)

    print("===== generating train and valid set for 10 biases =====")
    print(f"=== generating training data with density: {density}, corr1: {corr1}, corr2: {corr2}, in path: {target_path} ===")
    gen_multi_bias_dataset(corr1=[corr1]*10, corr2=[corr2]*10, density=[density]*10, bias1=bias1, bias2=bias2, target_path=target_path, split="train")
    gen_multi_bias_dataset(corr1=[corr1]*10, corr2=[corr2]*10, density=[density]*10, bias1=bias1, bias2=bias2, target_path=target_path, split="valid")

    # The test set is unbiased (bias None for every feature) and is written
    # directly under root rather than the per-setting target_path.
    print("===== generating test set =====")
    gen_multi_bias_dataset(corr1=[1/10]*10, corr2=[1/10]*10, density=[1]*10, bias1=[None]*10, bias2=[None]*10, target_path=root, split="test")


