import jax
import jax.numpy as jnp
import numpy as np
from datasets import load_dataset, Dataset, Array3D, Features, load_from_disk, concatenate_datasets
import inox
from tqdm import trange, tqdm


import wandb

import re

from utils import *
from jax.scipy.signal import convolve2d

# When True, the "_test" checkpoint directory and dataset are used instead of the real ones.
TEST_MODE = False
# Suffix of the checkpoint directory (and of the wandb run name) outside test mode.
RUN_NAME = "warmstart"

# Directory that holds the checkpoints; it is also handed to wandb as its working directory.
CHECKPOINT_PATH = Path('/data/vision/___/scratch/___ht/celeba_dir/checkpoints_test') if TEST_MODE \
else Path('/data/vision/___/scratch/___ht/celeba_dir/checkpoints_' + RUN_NAME)

# On-disk dataset location. Both branches are Path objects so the constant has the
# same type regardless of TEST_MODE (the test branch used to be a plain str).
DATASET_PATH = Path('/data/vision/___/scratch/___ht/celeba_64_mask50_test/') if TEST_MODE \
else Path('/data/vision/___/scratch/___ht/celeba_64_mask50/')

# Hyper-parameters of the run. The raw dict is passed to wandb.init as the run
# config; `config` below wraps the same values for attribute-style access
# (e.g. config.img_shape, config.sampler).
CONFIG = {
    # Data
    # 'duplicate': 2,
    # NOTE(review): presumably the percentage of masked pixels (cf. corrupt(), which
    # divides its `corruption` argument by 100) -- confirm against the consumer.
    'corruption': 50,
    # Image shape; corrupt() uses the first two entries for the mask shape and the
    # full tuple for the output feature type.
    'img_shape': (64, 64, 3),
    # Architecture (consumed outside this file; meanings not visible here)
    'hid_channels': (128, 256, 384, 512),
    'hid_blocks': (3, 3, 3, 3),
    'kernel_size': (3, 3),
    'emb_features': 256,
    'heads': {3: 4},
    'dropout': 0.1,
    # Sampling
    # 'sampler', 'discrete' and 'maxiter' are forwarded by generate_conditional() to
    # sample_conditional() as sampler=, steps= and maxiter= respectively.
    'sampler': 'ddpm',
    'heuristic': None,
    'sde': {'a': 1e-3, 'b': 1e2},
    'discrete': 64,
    'maxiter': 3,
    # Training
    'epochs': 64,
    'batch_size': 256,
    'scheduler': 'constant',
    'lr_init': 1e-4,
    'lr_end': 1e-6,
    'lr_warmup': 0.0,
    'optimizer': 'adam',
    'weight_decay': None,
    'clip': 1.0,
    'ema_decay': 0.999,
}
# Wrapper around CONFIG; MyDict is not defined in this file (presumably from utils).
config = MyDict(CONFIG)

def corrupt(rng, corruption, dataset: Dataset):
    """Mask the 'x' column of `dataset` with a random Bernoulli mask.

    For every row a mask of shape ``config.img_shape[:2] + (1,)`` is drawn with
    keep-probability ``1 - corruption / 100`` and multiplied with the image
    (assumes 'x' broadcasts against that shape, i.e. is laid out like
    ``config.img_shape`` -- TODO confirm). The image is shifted by +2 before
    masking and by -2 afterwards, so dropped entries come out as -2 instead
    of 0.

    Args:
        rng: Random generator exposing ``bernoulli(p=..., shape=...)``.
        corruption: Percentage of entries to drop (``0`` keeps everything).
        dataset: Dataset with an 'x' column.

    Returns:
        The mapped dataset, holding a single float32 'y' column of shape
        ``config.img_shape``; every original column is removed.
    """
    # Loop-invariant mask parameters, computed once instead of once per row.
    keep_prob = 1 - corruption / 100
    mask_shape = config.img_shape[:2] + (1, )

    def transform(row):
        x = np.asarray(row['x'])
        A = rng.bernoulli(p=keep_prob, shape=mask_shape)
        # Shift out-of-place: np.asarray does not copy, so an in-place `x += 2`
        # would write into the array owned by the dataset row (and raises when
        # that buffer is read-only).
        y = np.array(A * (x + 2))
        y -= 2
        return {'y': y}

    types = {
        'y': Array3D(shape=config.img_shape, dtype='float32'),
    }

    return dataset.map(
        transform,
        remove_columns=dataset.column_names,
        features=Features(types),
        keep_in_memory=True,
        num_proc=1
    )

def generate_conditional(model, dataset, rng, batch_size, sde, **kwargs):
    """Replace the 'y' column of `dataset` with samples drawn conditionally on it.

    The dataset is processed in batches of `batch_size` (the last incomplete
    batch is dropped). Each batch of 'y' is passed to ``sample_conditional``
    together with a fresh key from ``rng.split()`` and the sampler settings
    read from the module-level ``config`` (``sampler``, ``discrete``,
    ``maxiter``).

    Note: `sde` and `**kwargs` are accepted but not used by this function.

    Returns:
        The mapped dataset with the 'y' column removed and a float32 'x'
        column of shape (64, 64, 3) added.
    """
    out_features = Features({'x': Array3D(shape=(64, 64, 3), dtype='float32')})

    def sample_batch(batch):
        observations = np.asarray(batch['y'])
        samples = sample_conditional(
            model,
            observations,
            rng.split(),
            shard=True,
            sampler=config.sampler,
            steps=config.discrete,
            maxiter=config.maxiter,
        )
        # Convert back to NumPy so the result can be stored in the dataset.
        return {'x': np.asarray(samples)}

    return dataset.map(
        sample_batch,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
        remove_columns=['y'],
        features=out_features,
        keep_in_memory=True,
    )
    
def generate(model, dataset, rng, batch_size, **kwargs):
    """Replace the 'y' and 'A' columns of `dataset` with samples from `model`.

    The dataset is processed in batches of `batch_size` (the last incomplete
    batch is dropped). Each batch's 'y' and 'A' columns are passed to
    ``sample`` along with a fresh key from ``rng.split()``; any extra keyword
    arguments are forwarded to ``sample`` unchanged.

    Returns:
        The mapped dataset with 'y' and 'A' removed and a float32 'x' column
        of shape (64, 64, 3) added.
    """
    out_features = Features({'x': Array3D(shape=(64, 64, 3), dtype='float32')})

    def sample_batch(batch):
        observations = batch['y']
        operators = batch['A']
        samples = sample(model, observations, operators, rng.split(), **kwargs)
        # Convert back to NumPy so the result can be stored in the dataset.
        return {'x': np.asarray(samples)}

    return dataset.map(
        sample_batch,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
        remove_columns=['y', 'A'],
        features=out_features,
        keep_in_memory=True,
    )

def train_conditional(lap, runid):
    """Prepare one lap: start/resume the wandb run, build the data, fit moments.

    Steps performed:
      1. Start (or resume, via ``resume='allow'``) the wandb run `runid`.
      2. Seed a PRNG deterministically from `lap`.
      3. Load the dataset from ``DATASET_PATH`` and build ``ds``: an 'x' column
         (the dataset's 'y' column passed through ``corrupt`` with 0 %
         corruption, then rescaled by ``x * 4 / 256 - 2``) next to the
         remaining original columns.
      4. Fit ``mu_x`` / ``cov_x`` with ``fit_moments`` on the first 16384
         'y' / 'A' entries.

    Args:
        lap: Index of the current lap; only used to derive the RNG seed.
        runid: wandb run id shared by all laps of the same launch.

    Returns:
        None. NOTE(review): ``ds``, ``mu_x`` and ``cov_x`` are computed but not
        yet used or returned -- the function looks unfinished.

    Relies on names that are not defined in this function: ``start_lap`` (a
    module global assigned in the ``__main__`` block) and ``sde``,
    ``distributed``, ``measure``, ``flatten``, ``fit_moments`` (presumably
    provided by ``from utils import *`` -- TODO confirm).
    """
    import zlib  # local import: only needed for the deterministic seed below

    run = wandb.init(
            project='priors-celeba-mask-conditional',
            id=runid,
            resume='allow',
            dir=CHECKPOINT_PATH,
            config=CONFIG,
            name=f'celeba_itnog_lap[{start_lap},)' + ('_test' if TEST_MODE else '_' + RUN_NAME)
        )

    # RNG
    # Built-in hash() of a str is salted per process, so it would yield a different
    # seed on every launch/resume; CRC32 is stable across processes.
    seed = zlib.crc32(f'celebA:{lap}'.encode()) % (1 << 16)
    rng = inox.random.PRNG(seed)

    # DATASET_PATH already selects the test or full dataset depending on TEST_MODE.
    dataset = load_from_disk(str(DATASET_PATH), keep_in_memory=True)

    dataset.set_format('numpy')
    trainset = dataset
    # corrupt() reads column 'x' and writes column 'y', hence the renames around it.
    trainset = trainset.rename_column('y', 'x')
    trainset = corrupt(rng, 0, trainset)
    trainset.set_format(type='numpy', columns=['y'])

    trainset = trainset.rename_column('y', 'x')

    def normalize_transform(row):
        # Linear rescale: 0 maps to -2 and 256 maps to 2.
        x = row['x']
        x = x * 4 / 256 - 2
        return {'x': x}

    trainset = trainset.map(
        normalize_transform,
    )

    # Column-wise concatenation: normalized 'x' next to the original non-'y' columns.
    ds = concatenate_datasets([trainset, dataset.remove_columns('y')], axis = 1)
    ds.set_format('numpy')

    # ds = ds.select(range(1 << 10)) if TEST_MODE else ds.select(range(1 << 14)) # for the first lap, use a subset of the data

    y_fit, A_fit = jnp.array(dataset[:16384]['y']), jnp.array(dataset[:16384]['A'])
    y_fit, A_fit = jax.device_put((y_fit, A_fit), distributed)

    mu_x, cov_x = fit_moments(
        features=64 * 64 * 3,
        rank=64,
        shard=True,
        A=inox.Partial(measure, A_fit),
        y=flatten(y_fit),
        cov_y=1e-3**2, # TODO: or 1e-2**2?
        sampler='ddim',
        sde=sde,
        steps=256,
        maxiter=None,
        key=rng.split(),
    )


if __name__ == "__main__":
    CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)

    # Find the lap to resume from: one past the highest-numbered
    # `checkpoint_<n>.pkl` file in the checkpoint directory (0 if there is none).
    # The dot is escaped so that only a literal ".pkl" suffix matches.
    checkpoint_name_regex = re.compile(r'checkpoint_(\d+)\.pkl')
    max_checkpoint = -1
    for child in Path(CHECKPOINT_PATH).iterdir():
        if not child.is_file():
            continue
        found = checkpoint_name_regex.fullmatch(child.name)
        if found is None:
            continue
        max_checkpoint = max(max_checkpoint, int(found.group(1)))

    # `start_lap` must stay a module-level name: train_conditional reads it
    # when building the wandb run name.
    start_lap = max_checkpoint + 1

    # One wandb run id shared by every lap of this launch.
    runid = wandb.util.generate_id()

    for lap in range(start_lap, 64):
        train_conditional(lap, runid)
