# Evaluation of the conditional performance of the EM + Moment Matching method

import inox
import io
import zipfile

from dawgz import job, schedule
from functools import partial
from torch import Tensor
from datasets import Array3D, Features, load_from_disk
from torch.utils import data
from torch_fidelity.fidelity import calculate_metrics
from tqdm import tqdm
from typing import *
import torchvision

# isort: split
from utils import *


# Hyperparameters of the evaluated run. In this file, `generate_data` reads
# 'sde', 'corruption', 'batch_size', 'sampler', 'discrete' and 'maxiter'.
# The remaining keys presumably mirror the training setup of the checkpoint
# being evaluated (NOTE(review): confirm).
CONFIG = dict(
    # Data: suffix of the `cifar-mask-<corruption>` dataset directory
    corruption=75,
    # Architecture
    hid_channels=(128, 256, 384),
    hid_blocks=(5, 5, 5),
    kernel_size=(3, 3),
    emb_features=256,
    heads={1: 4},
    dropout=0.1,
    # Sampling
    sampler='ddpm',
    sde=dict(a=1e-3, b=1e2),  # keyword arguments of VESDE
    heuristic=None,
    discrete=256,  # forwarded to the sampler as `steps`
    maxiter=1,
    # Training
    epochs=256,
    batch_size=256,
    scheduler='constant',
    lr_init=2e-4,
    lr_end=1e-6,
    lr_warmup=0.0,
    optimizer='adam',
    weight_decay=None,
    clip=1.0,
    ema_decay=0.9999,
)

class ZipDataset(data.Dataset):
    r"""Zip image dataset.

    Every image stored in the archive is decoded into memory as an RGB PIL
    image when the dataset is constructed. Items are served as tensors
    produced by ``torchvision.transforms.PILToTensor``.

    Arguments:
        archive: Path to the zip archive to read.
    """

    def __init__(self, archive: Path):
        self.images = []
        # PILToTensor holds no per-item state, so build it once here rather
        # than on every __getitem__ call.
        self.to_tensor = torchvision.transforms.PILToTensor()

        with zipfile.ZipFile(archive, mode='r') as file:
            for name in file.namelist():
                # Directory entries carry no image data.
                if name.endswith('/'):
                    continue

                # Named `stream` so it does not shadow the `torch.utils.data`
                # module this class derives from.
                with file.open(name) as stream:
                    # `convert` decodes the pixels while the member is still
                    # open, so the stored image does not depend on `stream`.
                    img = Image.open(stream).convert('RGB')

                self.images.append(img)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, i: int) -> Tensor:
        return self.to_tensor(self.images[i])

    @staticmethod
    def zip(archive: Path, images: List):
        r"""Write `images` into a new zip archive as PNG files.

        Members are named ``IMG_0.png``, ``IMG_1.png``, ... in iteration
        order. An existing file at `archive` is overwritten.

        Arguments:
            archive: Destination path of the zip archive.
            images: Objects with a PIL-style ``save(fp, format)`` method,
                e.g. PIL images.
        """
        with zipfile.ZipFile(archive, mode='w') as file:
            for i, img in enumerate(images):
                buffer = io.BytesIO()
                img.save(buffer, 'png')
                file.writestr(f'IMG_{i}.png', buffer.getvalue())

def generate(model, dataset, rng, batch_size, **kwargs):
    r"""Replace every ``(y, A)`` row of `dataset` with a sample ``x``.

    Each batch is handed to the project's ``sample`` helper together with a
    fresh key split off `rng`.

    Arguments:
        model: Model forwarded to ``sample``.
        dataset: A ``datasets`` dataset with (at least) columns ``'y'`` and
            ``'A'``.
        rng: PRNG object. ``rng.split()`` is called once per batch.
        batch_size: Number of rows per ``sample`` call. A trailing partial
            batch is dropped.
        kwargs: Extra keyword arguments forwarded to ``sample``.

    Returns:
        An in-memory dataset in which ``'y'`` and ``'A'`` are replaced by a
        column ``'x'`` typed as ``Array3D(shape=(32, 32, 3), dtype='float32')``.
    """
    schema = Features({'x': Array3D(shape=(32, 32, 3), dtype='float32')})

    def sample_batch(batch):
        samples = sample(model, batch['y'], batch['A'], rng.split(), **kwargs)
        return {'x': np.asarray(samples)}

    return dataset.map(
        sample_batch,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
        keep_in_memory=True,
        remove_columns=['y', 'A'],
        features=schema,
    )

def generate_data(
    archive,
    checkpoint='/data/vision/___/scratch/___ht/cifar_backup_original_paper_75/checkpoint_31.pkl',
):
    r"""Sample images from a checkpoint for the train split and zip them.

    Loads the ``cifar-mask-<corruption>`` dataset, replicates the checkpoint's
    parameters across all JAX devices, and draws one sample per
    ``(y, A)`` row via :func:`generate`. It then converts each sample with
    ``to_pil`` and writes the results to `archive` using
    :meth:`ZipDataset.zip`.

    Arguments:
        archive: Destination zip path. Missing parent directories are created.
        checkpoint: Path of the module checkpoint passed to ``load_module``.
    """
    import zlib

    # Sharding
    jax.config.update('jax_threefry_partitionable', True)

    mesh = jax.sharding.Mesh(jax.devices(), 'i')
    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # RNG. The built-in `hash` of a str is salted per interpreter process
    # (PYTHONHASHSEED), so it yields a different seed on every run. CRC32 of
    # the same text is stable, which makes the evaluation reproducible.
    seed = zlib.crc32(b'evaluation of EM + Moment Matching') % 2**16
    rng = inox.random.PRNG(seed)

    # SDE
    sde = VESDE(**CONFIG.get('sde'))

    # Data. Use double quotes inside the f-string replacement field: reusing
    # the enclosing quote type is a SyntaxError before Python 3.12.
    dataset = load_from_disk(
        f'/data/vision/___/scratch/___ht/diffusion-priors/experiments/cifar/hf/cifar-mask-{CONFIG["corruption"]}'
    )
    dataset.set_format('numpy')

    trainset_yA = dataset['train']

    # Model, with its arrays replicated on every device.
    previous = load_module(checkpoint)

    static, arrays = previous.partition()
    arrays = jax.device_put(arrays, replicated)
    previous = static(arrays)

    trainset = generate(
        model=previous,
        dataset=trainset_yA,
        rng=rng,
        batch_size=CONFIG['batch_size'],
        shard=True,
        sampler=CONFIG['sampler'],
        sde=sde,
        steps=CONFIG['discrete'],
        maxiter=CONFIG['maxiter'],
    )

    images = list(map(to_pil, trainset['x']))

    archive = Path(archive)
    archive.parent.mkdir(parents=True, exist_ok=True)
    ZipDataset.zip(archive, images)


def fid(archive: Path):
    r"""Compute the FID and Inception Score of the images in a zip archive.

    The images are read with :class:`ZipDataset` and compared against
    torch-fidelity's registered ``'cifar10-train'`` input.

    Arguments:
        archive: Path to a zip archive of images.

    Returns:
        A ``(fid, isc)`` tuple holding the Fréchet Inception Distance and the
        mean Inception Score. Both are also printed as ``archive,fid,isc``.
    """
    stats = calculate_metrics(
        input1=ZipDataset(archive),
        input2='cifar10-train',
        fid=True,
        isc=True,
    )

    # Distinct local names so this function is not shadowed inside itself.
    fid_value = stats['frechet_inception_distance']
    isc_value = stats['inception_score_mean']

    # Report the result; previously the metrics were computed and discarded.
    print(f'{archive},{fid_value},{isc_value}')

    return fid_value, isc_value


if __name__ == "__main__":
    # Sample images into a zip archive, then score that archive.
    output_zip = Path('/data/vision/___/scratch/___ht/eval_mm/eval_mm_archive_checkpoint_31.zip')

    generate_data(output_zip)
    fid(output_zip)

