import inox
import inox.nn as nn
import jax
import numpy as np
import optax
import wandb

from datasets import Array3D, Features, load_from_disk, Dataset
from dawgz import job, schedule
from functools import partial
from tqdm import trange
from typing import *

import matplotlib.pyplot as plt

from torch import Tensor
from torch.utils import data
from torch_fidelity.fidelity import calculate_metrics
from torchvision.transforms.functional import pil_to_tensor

import io
import zipfile


from utils import *

def _make_config(
    *,
    hid_channels=(128, 256, 384),
    hid_blocks=(5, 5, 5),
    emb_features=256,
):
    """Build a fresh experiment config dict.

    The configs below differ only in network width/depth, so those are the only
    knobs; every other setting is shared. A brand-new dict (including the nested
    ``'heads'`` and ``'sde'`` dicts) is returned on every call, so configs built
    from this factory never alias each other.
    """
    return {
        # Data
        'corruption': 75,
        # Architecture
        'hid_channels': hid_channels,
        'hid_blocks': hid_blocks,
        'kernel_size': (3, 3),
        'emb_features': emb_features,
        'heads': {1: 4},
        'dropout': 0.1,
        # Sampling
        'sampler': 'ddpm',
        'sde': {'a': 1e-3, 'b': 1e2},
        'heuristic': None,
        'discrete': 256,
        'maxiter': 1,
        # Training
        'epochs': 256,
        'batch_size': 256,
        'scheduler': 'constant',
        'lr_init': 2e-4,
        'lr_end': 1e-6,
        'lr_warmup': 0.0,
        'optimizer': 'adam',
        'weight_decay': None,
        'clip': 1.0,
        'ema_decay': 0.9999,
    }


# Smaller config: the original baseline architecture.
CONFIG_ORIGINAL = _make_config()

# Larger config: wider channels, deeper blocks and a bigger embedding.
# This is the config wrapped as the module-level `config` below.
CONFIG = _make_config(
    hid_channels=(256, 384, 512),
    hid_blocks=(6, 6, 6),
    emb_features=512,
)



def corrupt(corruption, dataset, rng):
    """Map each row's ``'x'`` to a masked, lightly-noised observation ``'y'``.

    For every row a mask of shape (32, 32, 1) is drawn; an entry is kept where
    ``rng.uniform`` exceeds ``corruption / 100``. Assuming uniform samples in
    [0, 1) (TODO confirm against inox), ``corruption`` is the percentage of
    pixels dropped. The observation is ``mask * x`` plus ``1e-3``-scaled
    Gaussian noise.

    Args:
        corruption: corruption level, in percent.
        dataset: presumably a ``datasets.Dataset`` with an ``'x'`` column
            (likely (32, 32, 3) images -- TODO confirm).
        rng: provides ``uniform(shape=...)`` and ``normal(shape=...)``, e.g.
            ``inox.random.PRNG``.

    Returns:
        The result of ``dataset.map``: ``'x'`` removed, ``'y'`` added as a
        float32 (32, 32, 3) array column.
    """
    threshold = corruption / 100
    output_features = Features({
        'y': Array3D(shape=(32, 32, 3), dtype='float32'),
    })

    def transform(row):
        image = row['x']
        mask = rng.uniform(shape=(32, 32, 1)) > threshold
        # NOTE(review): the noise uses the mask's (32, 32, 1) shape, so the same
        # noise value is broadcast to every channel of a pixel -- confirm intended.
        noise = rng.normal(shape=mask.shape)
        return {'y': np.array(1e-3 * noise + mask * image)}

    return dataset.map(
        transform,
        features=output_features,
        remove_columns=['x'],
        keep_in_memory=True,
        num_proc=1,
    )


class MyDict():
    """Attribute-style access to a dictionary: ``MyDict(d).key`` is ``d['key']``.

    The wrapped dictionary is stored by reference (not copied) as
    ``self.dictionary``. Missing keys raise ``AttributeError`` rather than
    ``KeyError``, so ``hasattr``, ``getattr(obj, name, default)``, ``copy`` and
    ``pickle`` work as expected.
    """

    def __init__(self, dictionary: dict):
        self.dictionary = dictionary

    def __getattr__(self, name: str):
        # Only invoked when normal attribute lookup fails. The backing dict is
        # read through __dict__ so that an instance which has no 'dictionary'
        # yet (e.g. one being rebuilt by copy/pickle before its state is
        # restored) cannot re-enter __getattr__ and recurse forever.
        dictionary = self.__dict__.get('dictionary')
        if dictionary is not None and name in dictionary:
            return dictionary[name]
        raise AttributeError(
            f'{type(self).__name__!r} object has no attribute {name!r}'
        )
    
# Attribute-style view of the active (larger) config, e.g. `config.maxiter`.
config = MyDict(dictionary=CONFIG)



class ZipDataset(data.Dataset):
    r"""Map-style dataset over the images stored in a zip archive.

    Every archive member is opened at construction time, converted to RGB and
    kept in memory; ``__getitem__`` returns it through ``pil_to_tensor``.
    ``ZipDataset.zip`` writes a list of images into such an archive.
    """

    def __init__(self, archive: Path):
        self.images = []

        with zipfile.ZipFile(archive, mode='r') as zf:
            for member in zf.namelist():
                with zf.open(member) as stream:
                    # convert() is applied while the member stream is still open.
                    self.images.append(Image.open(stream).convert('RGB'))

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, i: int) -> Tensor:
        return pil_to_tensor(self.images[i])

    @staticmethod
    def zip(archive: Path, images: List):
        """Write ``images`` into ``archive`` as PNG members ``IMG_0.png``, ``IMG_1.png``, ..."""
        with zipfile.ZipFile(archive, mode='w') as zf:
            for index, image in enumerate(images):
                with io.BytesIO() as buffer:
                    image.save(buffer, 'png')
                    zf.writestr(f'IMG_{index}.png', buffer.getvalue())





def fid(archive: Path, checkpoint_name, reference='cifar10-train'):
    """Compute FID and Inception Score for the images stored in ``archive``.

    Args:
        archive: zip archive of images readable by ``ZipDataset`` (e.g. one
            written by ``ZipDataset.zip``).
        checkpoint_name: label used only in the printed summary line.
        reference: ``input2`` passed to ``torch_fidelity.calculate_metrics``;
            defaults to ``'cifar10-train'``, the previously hard-coded value.

    Returns:
        The full metrics dict returned by ``calculate_metrics``, so callers can
        record the scores instead of only reading them off stdout.
    """
    stats = calculate_metrics(
        input1=ZipDataset(archive),
        input2=reference,
        fid=True,
        isc=True,
    )

    fid_value = stats['frechet_inception_distance']
    isc_value = stats['inception_score_mean']
    # Identical output to the former f'{fid=}, {isc=}' (the `=` specifier uses repr),
    # without shadowing this function's name with a local.
    print(f'! {checkpoint_name}: fid={fid_value!r}, isc={isc_value!r}')

    return stats



def checkpoint_seed(checkpoint) -> int:
    """Return a deterministic 16-bit sampling seed derived from a checkpoint path.

    CRC-32 of the path string is used instead of ``hash()``: string hashes (and
    therefore ``Path`` hashes) are randomized per interpreter process unless
    ``PYTHONHASHSEED`` is fixed, which made the previous seeds irreproducible.
    """
    import zlib

    return zlib.crc32(str(checkpoint).encode('utf-8')) % 2**16


def measure(
    checkpoint_name,
    archive_name,
    runpath,
    *,
    dataset_path='/data/vision/___/scratch/___ht/cifar_dir/hf/cifar-mask-gaussian-blur-2',
    num_samples=50000,
    batch_size=256,
    steps=256,
):
    """Sample from a checkpoint conditioned on training observations, archive, and score.

    Loads the model stored at ``runpath / checkpoint_name``, replicates it on all
    devices, then draws ``num_samples`` images with ``sample_conditional``. The
    images are conditioned batch by batch on the ``'y'`` column of the saved
    dataset's ``'train'`` split. They are zipped to ``runpath / archive_name``
    and scored with ``fid``.

    Args:
        checkpoint_name: checkpoint file name inside ``runpath``.
        archive_name: output zip file name inside ``runpath``.
        runpath: ``Path`` of the run directory.
        dataset_path: directory passed to ``datasets.load_from_disk``.
        num_samples: number of conditioning rows / generated images; the
            count is capped by the train split's length.
        batch_size: rows sampled per ``sample_conditional`` call.
        steps: sampler steps per call.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    checkpoint = runpath / checkpoint_name
    archive = runpath / archive_name

    if not checkpoint.exists():
        raise FileNotFoundError(checkpoint)

    # Sharding
    jax.config.update('jax_threefry_partitionable', True)

    mesh = jax.sharding.Mesh(jax.devices(), 'i')
    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # Conditioning observations
    dataset = load_from_disk(dataset_path)
    dataset.set_format('numpy')
    trainset_yA = dataset['train']

    # Model, replicated on every device
    model = load_module(checkpoint)
    static, arrays = model.partition()
    arrays = jax.device_put(arrays, replicated)
    model = static(arrays)

    rng = inox.random.PRNG(checkpoint_seed(checkpoint))

    # Generate
    images = []

    for i in trange(0, num_samples, batch_size, ncols=88):
        x = sample_conditional(
            model=model,
            # Slice rows before selecting the column so only this batch is
            # formatted, rather than the whole 'y' column on every iteration.
            y_cond=trainset_yA[i:i + batch_size]['y'],
            key=rng.split(),
            shard=True,
            sampler='ddpm',
            steps=steps,
            maxiter=config.maxiter,
        )
        x = np.asarray(x)
        images.extend(map(to_pil, x))

    # Archive and score
    ZipDataset.zip(archive, images)
    fid(archive, checkpoint_name)

if __name__ == "__main__":
    # Evaluate checkpoint_20.pkl down to checkpoint_0.pkl (highest lap first).
    # An earlier run evaluated checkpoint_{30..0}.pkl into archive_{lap}.zip under
    # /data/vision/___/scratch/___ht/checkpoints_1x_conditional instead.
    run_dir = Path('/data/vision/___/scratch/___ht/cifar_dir/checkpoints_itnog_blur_nw')
    for lap in reversed(range(21)):
        measure(
            f'checkpoint_{lap}.pkl',
            f'archive_checkpoint_{lap}.zip',
            run_dir,
        )

