#!/usr/bin/env python

import hashlib
import io
import zipfile

from functools import partial
from typing import *

import inox
import torchvision
from dawgz import job, schedule
from torch import Tensor
from torch.utils import data
from torch_fidelity.fidelity import calculate_metrics
from tqdm import tqdm

# isort: split
from utils import *


class ZipDataset(data.Dataset):
    r"""Zip image dataset.

    Every file entry of the archive is decoded eagerly, converted to RGB and kept
    in memory as a PIL image. Items are returned as tensors built by torchvision's
    ``pil_to_tensor`` (channels-first; uint8 for RGB images).
    """

    def __init__(self, archive: Path):
        self.images = []

        with zipfile.ZipFile(archive, mode='r') as file:
            for info in file.infolist():
                # Directory entries carry no image data and would make Image.open fail.
                if info.is_dir():
                    continue

                # `convert` loads the pixel data into a new image, so both the zip
                # entry and the lazily-opened source image can be closed right away.
                with file.open(info) as stream, Image.open(stream) as img:
                    self.images.append(img.convert('RGB'))

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, i: int) -> Tensor:
        return torchvision.transforms.functional.pil_to_tensor(self.images[i])

    @staticmethod
    def zip(archive: Path, images: List):
        r"""Writes `images` to `archive` as PNG entries named ``IMG_{i}.png``.

        Arguments:
            archive: Destination path; an existing file is overwritten.
            images: Objects exposing PIL's ``save(fp, format)`` method.
        """
        with zipfile.ZipFile(archive, mode='w') as file:
            for i, img in enumerate(images):
                buffer = io.BytesIO()
                img.save(buffer, 'png')
                file.writestr(f'IMG_{i}.png', buffer.getvalue())


def generate(
    checkpoint: Path,
    archive: Path,
    seed: Optional[int] = None,
    num_samples: int = 50000,
    batch_size: int = 256,
):
    r"""Samples images from a checkpointed model and writes them to a zip archive.

    The model arrays are replicated on every visible JAX device, and images are
    drawn with ``sample_any`` (DDIM sampler, 256 steps, ``shard=True``) as flat
    vectors of size 32 * 32 * 3, reshaped with ``unflatten`` and converted with
    ``to_pil``.

    Arguments:
        checkpoint: Model file, loaded with ``load_module``.
        archive: Destination archive, written with ``ZipDataset.zip``.
        seed: Combined with ``checkpoint`` to derive a reproducible PRNG seed.
        num_samples: Number of images stored in the archive.
        batch_size: Number of images drawn per ``sample_any`` call.

    Raises:
        ValueError: If ``num_samples`` is negative or ``batch_size`` is not positive.
    """
    if num_samples < 0 or batch_size <= 0:
        raise ValueError(f'invalid {num_samples=} or {batch_size=}')

    # Sharding
    jax.config.update('jax_threefry_partitionable', True)

    mesh = jax.sharding.Mesh(jax.devices(), 'i')
    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # RNG: the builtin hash() of str/Path values is salted per process
    # (PYTHONHASHSEED), so it would yield a different seed on every run. A SHA-256
    # digest of the same inputs is stable across processes.
    digest = hashlib.sha256(f'{checkpoint}:{seed}'.encode()).digest()
    rng = inox.random.PRNG(int.from_bytes(digest[:8], 'big') % 2**16)

    # Model
    model = load_module(checkpoint)

    static, arrays = model.partition()
    arrays = jax.device_put(arrays, replicated)
    model = static(arrays)

    # Generate: every call uses a full batch (shard=True presumably splits it
    # across devices, so the shape is kept constant); the surplus of the last
    # batch is dropped afterwards.
    num_batches = -(-num_samples // batch_size)  # ceil division
    images = []

    for _ in tqdm(range(num_batches), ncols=88):
        x = sample_any(
            model=model,
            shape=(batch_size, 32 * 32 * 3),
            shard=True,
            sampler='ddim',
            steps=256,
            key=rng.split(),
        )
        x = unflatten(x, 32, 32)
        x = np.asarray(x)

        images.extend(map(to_pil, x))

    del images[num_samples:]

    # Archive
    ZipDataset.zip(archive, images)


def fid(archive: Path, run: str, lap: int, seed: int):
    r"""Computes FID and Inception Score of an image archive against CIFAR-10 train.

    The result is printed and appended as a ``run,lap,seed,fid,isc`` row to
    ``PATH / 'statistics.csv'``.

    Arguments:
        archive: Zip archive of images, read through ``ZipDataset``.
        run: Run label written in the first column.
        lap: Checkpoint index written in the second column.
        seed: Sampling seed written in the third column.
    """
    stats = calculate_metrics(
        input1=ZipDataset(archive),
        input2='cifar10-train',
        fid=True,
        isc=True,
    )

    # Distinct names so the function `fid` is not shadowed inside its own body.
    fid_score = stats['frechet_inception_distance']
    isc_score = stats['inception_score_mean']

    row = f'{run},{lap},{seed},{fid_score},{isc_score}'
    print(row)

    with open(PATH / 'statistics.csv', mode='a') as f:
        f.write(row + '\n')


if __name__ == '__main__':

    # Candidate run directories; `runpath` selects the one to evaluate.
    runpath1 = Path('/data/vision/___/scratch/___ht/cifar_backup_original_paper_75')
    runpath2 = Path('/data/vision/___/scratch/___ht/checkpoints_large_conditional')
    runpath3 = Path('/data/vision/___/scratch/___ht/unconditional_evaluation_model_cifar75')
    runpath = runpath1
    seed = 0

    # Run label written to statistics.csv: the name of the last directory in runpath.
    dirname = runpath.name

    for lap in [31]:
        checkpoint = runpath / f'checkpoint_{lap}.pkl'
        archive = runpath / f'archive_{lap}_{seed}.zip'

        if not checkpoint.exists():
            print(f'LOG: missing {checkpoint=}, stopping')
            break

        print(f'LOG: {checkpoint=} \n {archive=} \n')
        generate(checkpoint, archive, seed)
        fid(archive, dirname, lap, seed)
        
