""" Builds a (pytorch) dataset with data constructed according to the structural equation model

U1  := Unif({0,1,2,3,4,5}) where this is 0->Red, 1->Green, 2->Blue, 3->Yellow, 4->Magenta, 5->Cyan
U2  := Unif({0,1,2}) where 0->Thin, 1->Regular, 2-> Thick
D   := The digit [0..9]
W1  := an MNIST image with (digit:D, Color:U1, Thickness:U2)
W2a := a DIGIT, W1.digit
W2b := a MODIFIED COLOR CODE, taken from U1 // 3
X   := An MNIST image, rotated 90 degrees with (digit:W2a, color: W2b, thickness:U2)
Y   := An MNIST image, reflected, with (digit: X.digit, color: U1, thickness:X.thickness)


So the data is generated as follows:
1.  Sample U1, U2, D
2.  Sample W1: pick a random MNIST image of digit D, apply color:U1, thickness:U2 (both Massart-noised)
3a. Sample W2a: take digit D and apply Massart noise to it
3b. Sample W2b: take color U1, transform it to U1 // 3 (and apply Massart noise)
4.  Sample X: pick a random MNIST image with digit W2a, apply color W2b, thickness U2 (randomly Massart-noise these, too). Rotate 90 degrees
5.  Sample Y: take X and recolor it with U1. Reflect. Apply random Massart noise, too

(Note: generate_datadict_thread currently does not apply the rotation in step 4 or the reflection in step 5.)


(Note: all data preparation is done in numpy, because Morpho-MNIST is used for the thickness perturbations.)
"""


import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms

import sys
sys.path.append('Morpho-MNIST')

from morphomnist import morpho, perturb
import numpy as np
from collections import defaultdict
import random
from tqdm.auto import tqdm 
import skimage.transform as skt

from multiprocessing import Process, Manager
#import threading
import argparse
import os
import pickle

# ===============================================================
# =           Main Dataset that we'll use for loading           =
# ===============================================================

class NapkinMNIST(Dataset):
    """ Map-style dataset over a dict of equally-long, indexable arrays/tensors.

    Either pass `data` directly (e.g. the dict returned by generate_datadict) or
    `pkl_loc`, the path to a pickle of such a dict. Indexing returns a dict with the
    same keys, each holding the idx-th entry of the corresponding array.
    """

    def __init__(self, pkl_loc=None, data=None):
        if data is not None:
            self.data = data
            return
        if pkl_loc is None:
            raise ValueError("NapkinMNIST needs either `data` or `pkl_loc`")
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        with open(pkl_loc, 'rb') as pkl_file:
            self.data = pickle.load(pkl_file)

    def __len__(self):
        # Length is read from the 'U1' entry; every entry is assumed to be equally long.
        return self.data['U1'].shape[0]

    def __getitem__(self, idx):
        # Tensor indices (e.g. from samplers) are converted to python ints/lists first.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return {k: v[idx] for k, v in self.data.items()}



def load_data(dataset, batchsize: int, numworkers: int) -> tuple[DataLoader, DistributedSampler]:
    """ Wrap `dataset` in a DataLoader driven by a shuffling DistributedSampler.

    DistributedSampler is built without explicit num_replicas/rank, so a
    torch.distributed process group must already be initialised. The last
    incomplete batch is dropped.

    Returns:
        (loader, sampler) -- the sampler is presumably returned so callers can call
        sampler.set_epoch(epoch) each epoch.
    """
    sampler = DistributedSampler(dataset, shuffle=True)
    loader_kwargs = dict(batch_size=batchsize,
                         num_workers=numworkers,
                         sampler=sampler,
                         drop_last=True)
    return DataLoader(dataset, **loader_kwargs), sampler

def transback(data):
    """ Map values from [-1, 1] back to [0, 1] (inverse of the `x / 255 * 2 - 1` image scaling after dividing by 255). """
    halved = data / 2
    return halved + 0.5




# =============================================================
# =                Make full dataset                          =
# =============================================================

def get_base_dict(dataloc, train=True, num_workers=8):
    """ Load torchvision MNIST and group its images by digit label.

    Args:
        dataloc: directory holding the torchvision MNIST copy (downloaded there if missing).
        train: load the training split if True, otherwise the test split.
        num_workers: DataLoader worker processes used while reading.

    Returns:
        defaultdict(list) mapping {label (int): [np.uint8 array of shape (32, 32), ...]}.
        Images are zero-padded by 2 pixels on every side (transforms.Pad(2)).
    """
    xform = transforms.Compose([transforms.ToTensor(),
                                transforms.Pad(2)])
    pytorch_mnist = MNIST(dataloc, download=True, train=train, transform=xform)
    loader = DataLoader(pytorch_mnist, batch_size=1000, drop_last=False, num_workers=num_workers)
    base_dict = defaultdict(list)

    for x, y in loader:
        # ToTensor scaled pixels to [0, 1]. Round before the uint8 cast: the cast
        # truncates, and float32 k/255*255 can land just below the integer k.
        imgs = (x * 255).round().to(torch.uint8).view(-1, 32, 32).numpy()
        for img, label in zip(imgs, y.tolist()):
            base_dict[label].append(img)
    return base_dict


########### TRANSFORMS #############
def _color(img, color_code):
    """ Takes a (32, 32) image and makes it a (3,32,32) of a specified color 
        Color code looks like {0: R, 1: B, 2: G, 
                               3: RG (yellow), 
                               4: RB (magenta), 
                               5: BG (cyan)}
    """
    zeros = np.zeros_like(img)
    arr = [zeros, zeros, zeros]
    IDX_MAP = {0: [0],
               1: [1],
               2: [2],
               3: [0,1],
               4: [0,2],
               5: [1,2]}

    for idx in IDX_MAP[color_code]:
        arr[idx] = img
    return np.stack(arr)


def _thicken(img, thicken_code):
    """ Thin or thicken the stroke of a (32,32) or colored (3,32,32) image via Morpho-MNIST.

        Thickening code = {0: thin    (Thinning,   amount=0.5),
                           1: regular (Thickening, amount=0.2),
                           2: thick   (Thickening, amount=1.0)}

        For a colored (C, H, W) image, the first non-zero channel is used as the
        grayscale stroke. It is perturbed and written back into every channel that
        was non-zero, so two-channel colors (yellow/magenta/cyan) keep their color.
        An all-zero colored image is returned unchanged (as a copy).
        An unknown thicken_code raises KeyError.
    """
    perturb_cls, amount = {0: (perturb.Thinning, 0.5),
                           1: (perturb.Thickening, 0.2),
                           2: (perturb.Thickening, 1.0)}[thicken_code]

    lit_channels = None
    num_channels = None
    if img.ndim == 3:  # colored (C, H, W): recover the grayscale stroke
        num_channels = img.shape[0]
        lit_channels = np.flatnonzero(img.any(axis=(1, 2)))
        if lit_channels.size == 0:
            return img.copy()  # blank image: nothing to thin or thicken
        img = img[lit_channels[0]]

    morphology = morpho.ImageMorphology(img, scale=4)
    out = morphology.downscale(perturb_cls(amount=amount)(morphology))

    if lit_channels is None:
        return out

    # Recolor: same stroke in every originally-lit channel, zeros elsewhere.
    colored = np.zeros((num_channels,) + out.shape, dtype=out.dtype)
    colored[lit_channels] = out
    return colored
    

def _rotate(img, angle):
    """ Rotate an image by `angle` degrees about its center using skimage.transform.rotate.

        Accepts a (H, W) image or a channel-first (C, H, W) image such as those built
        by `_color`. skimage treats the first two axes as rows/cols, so channel-first
        images are rotated one channel at a time and re-stacked along axis 0.
        preserve_range=True keeps the pixel value range; skimage returns floats.
    """
    if img.ndim == 3:
        return np.stack([skt.rotate(channel, angle, preserve_range=True) for channel in img])
    return skt.rotate(img, angle, preserve_range=True)

def _reflectlr(img):
    if len(img.shape) == 3:
        return np.flip(img, 2)
    return np.fliplr(img)

def apply_massart(og, prob, high):
    """ Massart-style noise: with probability `prob` replace `og` by a uniform random
        integer in [0, high] (inclusive, so it may equal `og`); otherwise return `og`.
        Draws from the global `random` module state.
    """
    corrupt = random.random() <= prob
    return random.randint(0, high) if corrupt else og

############ /TRANSFORMS ##############


def set_seed(seed):
    """ Seed both Python's `random` module and NumPy's global RNG with `seed`. """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

def generate_datadict(base_dict, num_samples=30_000, num_threads=8, random_seed=None, massart=0.0):
    """ Generate NapkinMNIST samples, in parallel worker processes when num_threads > 1.

    Args:
        base_dict: {digit label: list of images}, as returned by get_base_dict.
        num_samples: at most this many base images (after shuffling) seed a sample each.
        num_threads: number of worker *processes* (multiprocessing.Process); 1 runs inline.
        random_seed: passed to set_seed before shuffling; None seeds nondeterministically.
        massart: noise probability forwarded to apply_massart for every noised variable.

    Returns:
        dict of torch tensors. 'W1', 'X', 'Y' are float images mapped by x / 255 * 2 - 1
        (i.e. to [-1, 1] assuming pixels in [0, 255]); all other keys are int64 tensors.
        Both the inline and the multi-process path return this same format.

    Raises:
        ValueError: if num_threads < 1.
        RuntimeError: if a worker process exits with a non-zero exit code.
    """
    if num_threads < 1:
        raise ValueError(f"num_threads must be >= 1, got {num_threads}")

    set_seed(random_seed)

    # Flatten {label: [imgs]} into (img, label) pairs and keep a random subset.
    base_list = [(img, lab) for lab, imglist in base_dict.items() for img in imglist]
    random.shuffle(base_list)
    base_list = base_list[:num_samples]

    # Strided split so each worker gets a near-equal share.
    sublists = [base_list[k::num_threads] for k in range(num_threads)]

    if num_threads == 1:
        # Run inline, but still convert so the return format matches the parallel path.
        thread_out = {}
        generate_datadict_thread(sublists[0], thread_out, base_dict,
                                 0, massart, random.randint(0, pow(2, 31)))
        return _finalize_datadict([thread_out])

    thread_seeds = [random.randint(0, pow(2, 31)) for _ in range(num_threads)]
    # Manager-backed dicts let child processes hand results back; `with` shuts the
    # manager's server process down once the results have been copied out.
    with Manager() as manager:
        thread_outs = [manager.dict() for _ in range(num_threads)]
        workers = [Process(target=generate_datadict_thread,
                           args=(sublist, thread_out, base_dict, i, massart, thread_seed))
                   for i, (sublist, thread_out, thread_seed)
                   in enumerate(zip(sublists, thread_outs, thread_seeds))]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        failed = [i for i, worker in enumerate(workers) if worker.exitcode != 0]
        if failed:
            raise RuntimeError(f"generate_datadict worker(s) {failed} exited abnormally")
        local_outs = [dict(thread_out) for thread_out in thread_outs]

    return _finalize_datadict(local_outs)


def _finalize_datadict(thread_outs):
    """ Merge per-worker {key: list} outputs and convert them to torch tensors (prints shapes). """
    base_out = {}
    for thread_out in thread_outs:
        for k, v in thread_out.items():
            base_out.setdefault(k, []).extend(v)
    base_out = {k: np.array(v) for k, v in base_out.items()}

    img_keys = ['W1', 'X', 'Y']
    long_keys = ['U1', 'U2', 'D',
                 'W1_color', 'W1_thickness', 'W1_digit',
                 'W2a', 'W2b',
                 'X_color', 'X_thickness', 'X_digit',
                 'Y_color', 'Y_thickness', 'Y_digit']

    data_dict = {}
    for k in img_keys:
        data_dict[k] = torch.Tensor(base_out[k]) / 255.0 * 2 - 1.
    for k in long_keys:
        data_dict[k] = torch.Tensor(base_out[k]).long()

    for k, v in data_dict.items():
        print(k, v.shape)
    return data_dict


def generate_datadict_thread(sublist, thread_out, base_dict, thread_num, massart, thread_seed):
    """ Worker: sample the structural equation model once per (img, label) in `sublist`.

    Args:
        sublist: list of (img, label) pairs; img is the base image for W1, label is D.
        thread_out: dict-like output (a plain dict, or a multiprocessing.Manager dict
            proxy) that receives one list per variable name, e.g. thread_out['X'].
        base_dict: {digit: list of images}; fresh base images for X and Y are drawn from it.
        thread_num: worker index; only worker 0 shows a tqdm progress bar.
        massart: noise probability passed to apply_massart.
        thread_seed: seed for this worker's `random` / `np.random` state.

    Returns:
        None; results are written into `thread_out` (all keys, even if sublist is empty).

    NOTE(review): the module docstring says X is rotated 90 degrees and Y reflected,
    but no _rotate/_reflectlr call is made here -- confirm whether that is intended.
    """
    set_seed(thread_seed)
    keys = ('U1', 'U2', 'D',
            'W1', 'W1_color', 'W1_thickness', 'W1_digit',
            'W2a', 'W2b',
            'X', 'X_color', 'X_thickness', 'X_digit',
            'Y', 'Y_color', 'Y_thickness', 'Y_digit')
    out = {k: [] for k in keys}

    iterator = tqdm(sublist) if thread_num == 0 else sublist
    for (img, lab) in iterator:
        # Every random.* / apply_massart call consumes RNG state; keep this order fixed
        # so a given thread_seed reproduces the same samples.
        u1 = random.randint(0, 5)  # color code
        u2 = random.randint(0, 2)  # thickness code
        d = lab

        # W1: the base image itself with noised color/thickness (digit is not noised).
        w1_color = apply_massart(u1, massart, 5)
        w1_thickness = apply_massart(u2, massart, 2)
        w1_digit = d
        w1 = _color(_thicken(img, w1_thickness), w1_color)

        # W2a: noised digit; W2b: noised color bucket (w1_color // 3 is 0 for codes 0-2, 1 for 3-5).
        w2a = apply_massart(w1_digit, massart, 9)
        w2b = apply_massart(w1_color // 3, massart, 1)

        # X: a fresh base image of the (noised) digit, colored by the 0/1 bucket code.
        x_digit = apply_massart(w2a, massart, 9)
        x_base = random.choice(base_dict[x_digit])
        x_thickness = apply_massart(u2, massart, 2)
        x_color = apply_massart(w2b, massart, 1)
        x = _color(_thicken(x_base, x_thickness), x_color)

        # Y: reuses X's base image unless the digit noise changed the digit.
        y_digit = apply_massart(x_digit, massart, 9)
        y_base = x_base if y_digit == x_digit else random.choice(base_dict[y_digit])
        y_thickness = apply_massart(x_thickness, massart, 2)
        y_nocolor = _thicken(y_base, y_thickness)
        y_color = apply_massart(u1, massart, 5)
        y = _color(y_nocolor, y_color)

        sample = {'U1': u1, 'U2': u2, 'D': d,
                  'W1': w1, 'W1_color': w1_color, 'W1_thickness': w1_thickness, 'W1_digit': w1_digit,
                  'W2a': w2a, 'W2b': w2b,
                  'X': x, 'X_color': x_color, 'X_thickness': x_thickness, 'X_digit': x_digit,
                  'Y': y, 'Y_color': y_color, 'Y_thickness': y_thickness, 'Y_digit': y_digit}
        for k in keys:
            out[k].append(sample[k])

    # One assignment per key: works for both plain dicts and Manager dict proxies.
    for k, v in out.items():
        thread_out[k] = v
        


# =======================================================
# =           Main scripting block                      =
# =======================================================

def main():
    """ CLI entry point: build the NapkinMNIST train/val splits and pickle them to --save_loc.

    Returns:
        (train_datadict, val_datadict) as produced by generate_datadict.
    """
    parser = argparse.ArgumentParser(description='makes napkin MNIST dataset')
    parser.add_argument('--mnist_dataloc', type=str, default='~/datasets')
    parser.add_argument('--save_loc', type=str, default='napkin_mnist4/base_data/')
    parser.add_argument('--num_samples', type=int, default=60_000)
    parser.add_argument('--num_threads', type=int, default=32)
    parser.add_argument('--massart', type=float, default=0.1)
    parser.add_argument('--seed', type=int, default=None,
                        help='RNG seed for reproducible generation (default: unseeded); '
                             'the val split uses seed + 1')
    args = parser.parse_args()

    args.mnist_dataloc = os.path.expanduser(args.mnist_dataloc)
    args.save_loc = os.path.expanduser(args.save_loc)
    base_train = get_base_dict(args.mnist_dataloc, train=True)
    base_val = get_base_dict(args.mnist_dataloc, train=False)

    # Distinct seeds so the train and val RNG streams are not identical.
    val_seed = None if args.seed is None else args.seed + 1
    # num_samples is an upper bound; each split is capped at the images it actually has.
    train_datadict = generate_datadict(base_train, num_samples=args.num_samples,
                                       num_threads=args.num_threads,
                                       random_seed=args.seed, massart=args.massart)
    val_datadict = generate_datadict(base_val, num_samples=args.num_samples,
                                     num_threads=args.num_threads,
                                     random_seed=val_seed, massart=args.massart)

    os.makedirs(args.save_loc, exist_ok=True)

    train_filename = os.path.join(args.save_loc, 'napkin_mnist_train.pkl')
    val_filename = os.path.join(args.save_loc, 'napkin_mnist_val.pkl')
    for dd, f in [(train_datadict, train_filename),
                  (val_datadict, val_filename)]:
        with open(f, 'wb') as pkl_file:
            pickle.dump(dd, pkl_file)

    return train_datadict, val_datadict


if __name__ == "__main__":
    # Bind the (train, val) dicts at module level so they stay inspectable, e.g. under `python -i`.
    out = main()










