# dgm_eval cannot compute the Inception Score, so this script computes it (and FID) with torch_fidelity instead.

#!/usr/bin/env python
import time
import inox
import io
import zipfile

from dawgz import job, schedule
from functools import partial
from torch import Tensor
from torch.utils import data
from torch_fidelity.fidelity import calculate_metrics
from tqdm import tqdm
from typing import *
import torchvision

from datasets import load_from_disk

# isort: split
from utils import *

# Smoke-test switch: when enabled, the '_test' variant of the dataset is
# loaded and only a small number of rows is processed.
TEST_MODE = True

# Number of dataset rows selected for generation (256 in test mode).
NUM_DATA_GENERATE = 256 if TEST_MODE else 50_000

class ZipDataset(data.Dataset):
    r"""Dataset of RGB images read eagerly from a zip archive.

    Every file entry of the archive is opened once at construction time,
    converted to an RGB PIL image and kept in memory. Indexing converts the
    stored image to a tensor with torchvision's `PILToTensor`.
    """

    def __init__(self, archive: Path):
        r"""Load all images stored in `archive`.

        Arguments:
            archive: Path of the zip file to read.
        """
        self.images = []

        with zipfile.ZipFile(archive, mode='r') as file:
            for info in file.infolist():
                # Directory entries hold no image data; handing them to PIL
                # would raise, so they are skipped.
                if info.is_dir():
                    continue

                # Named `stream` rather than `data` so the module-level
                # `torch.utils.data` import is not shadowed.
                with file.open(info) as stream:
                    # Convert while the archive member is still open; the
                    # resulting RGB image is kept after the stream closes.
                    img = Image.open(stream).convert('RGB')

                self.images.append(img)

    def __len__(self) -> int:
        r"""Return the number of images held in memory."""
        return len(self.images)

    def __getitem__(self, i: int) -> Tensor:
        r"""Return the `i`-th image converted by `PILToTensor`."""
        return torchvision.transforms.PILToTensor()(self.images[i])

    @staticmethod
    def zip(archive: Path, images: List):
        r"""Write `images` to a new zip archive as PNG files.

        Arguments:
            archive: Path of the zip file to create (opened in 'w' mode).
            images: Objects exposing `save(buffer, 'png')`; entry `i` is
                stored under the name `IMG_{i}.png`.
        """
        with zipfile.ZipFile(archive, mode='w') as file:
            for i, img in enumerate(images):
                # Encode into memory first, then store the bytes as one entry.
                buffer = io.BytesIO()
                img.save(buffer, 'png')
                file.writestr(f'IMG_{i}.png', buffer.getvalue())


def fid(ref_path: Path, gen_path: Path, name:str, lap: int):
    r"""Print FID, Inception Score and DINOv2-based FID as one CSV line.

    `calculate_metrics` from torch_fidelity is run twice on the same pair of
    inputs: once with its default feature extractor (FID and Inception Score)
    and once with the 'dinov2-vit-l-14' feature extractor (FID only).

    Arguments:
        ref_path: Location of the reference images (passed as `input2`).
        gen_path: Location of the generated images (passed as `input1`).
        name: Label printed as the first CSV field.
        lap: Value printed as the second CSV field.

    The printed line has the form `name,lap,fid,isc,fid_dinov2` followed by
    an empty line. Nothing is returned.
    """
    # torch_fidelity takes directory inputs as plain strings, so path objects
    # are converted; any other kind of input is passed through untouched.
    input1 = str(gen_path) if isinstance(gen_path, Path) else gen_path
    input2 = str(ref_path) if isinstance(ref_path, Path) else ref_path

    stats = calculate_metrics(
        input1=input1,
        input2=input2,
        fid=True,
        isc=True,
    )

    # Local names deliberately differ from `fid` so the function itself is
    # not shadowed inside its own body.
    fid_inception = stats['frechet_inception_distance']
    isc = stats['inception_score_mean']

    stats_dinov2 = calculate_metrics(
        input1=input1,
        input2=input2,
        fid=True,
        feature_extractor='dinov2-vit-l-14',
        # other options: 'dinov2-vit-b-14' 'dinov2-vit-s-14' 'dinov2-vit-g-14'
    )

    fid_dinov2 = stats_dinov2['frechet_inception_distance']

    print(f'{name},{lap},{fid_inception},{isc},{fid_dinov2}\n')

def normalize_transform(batch):
    r"""Apply the affine rescaling `y -> y / 64 - 2` to a batch.

    Arguments:
        batch: Mapping with the keys 'y' and 'A'.

    Returns:
        A new dict holding the rescaled 'y' and the unchanged 'A'.
    """
    # Same map as `y * 4 / 256 - 2`, but dividing directly avoids the
    # intermediate product `y * 4`, which wraps around for small integer
    # dtypes such as uint8. Results are identical for floating-point input.
    y = batch['y'] / 64 - 2
    return {'y': y, 'A': batch['A']}


def generate(checkpoint_path: Path|str, seed: int = 42) -> list[Image]:
    r"""Sample images from a checkpoint while mapping over the masked dataset.

    NOTE(review): this function reads `corruption`, `NUM_DATA`, `save_dir`
    and `batch_size`, none of which are defined in the visible part of this
    file. Unless `from utils import *` provides them, calling the function
    raises NameError. `NUM_DATA` may have been meant to be
    `NUM_DATA_GENERATE` -- TODO confirm.

    Arguments:
        checkpoint_path: Checkpoint location; a string is converted to `Path`
            before being handed to `load_module`.
        seed: Seed of the `inox.random.PRNG` from which sampling keys are
            split.

    Returns:
        Whatever `dataset.map(transform, ...)` returns. NOTE(review): the
        `list[Image]` annotation does not match this -- TODO confirm intent.
    """
    if isinstance(checkpoint_path, str):
        checkpoint_path = Path(checkpoint_path)

    model = load_module(checkpoint_path)
    # The '_test' dataset variant is loaded when TEST_MODE is set.
    # NOTE(review): `corruption` is not a parameter of this function.
    dataset = load_from_disk(f'/data/vision/___/scratch/___ht/celeba_64_mask{corruption}' + ('_test' if TEST_MODE else ''))
    dataset.set_format('numpy')
    dataset = dataset.map(
        normalize_transform,
        batched=True,
        batch_size=1024,
        num_proc=4,
        desc = 'normalizing'
    )

    # NOTE(review): `Counter` must offer `get()` without arguments and
    # `inc()`; presumably a helper from `utils` -- TODO confirm.
    counter = Counter()
    rng = inox.random.PRNG(seed = seed)

    def transform(batch):
        # Per-batch callback: draws samples and writes them to disk as PNG
        # files named after the running counter. It never returns a value.
        if counter.get() == NUM_DATA:
            return
        # NOTE(review): `y_cond` is computed but never passed to `sample`.
        y_cond = np.asarray(batch['y'])

        x = sample(
                model=model,
                key=rng.split(),
                shard=True,
                sampler='ddpm',
                steps=256,
                maxiter=1,
            )

        x = np.asarray(x)

        for each in x:
            # NOTE(review): `save_dir` is not defined in this function.
            to_pil(each).save(save_dir / f'{counter.get()}.png')
            counter.inc()
            # Stop as soon as the requested number of images is written.
            if counter.get() == NUM_DATA:
                break


    # NOTE(review): `batch_size` is not defined in this function.
    return dataset.map(
        transform,
        keep_in_memory=True,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
        num_proc = 1
    )

def generate_conditional(checkpoint_path:str, corruption: int, seed: int = 42) -> list[Image]:
    r"""Draw conditional samples for the first rows of a masked dataset.

    The dataset for the given `corruption` is loaded (its '_test' variant
    when `TEST_MODE` is set), truncated to `NUM_DATA_GENERATE` rows,
    rescaled with `normalize_transform`, and then mapped batch-wise through
    `sample_conditional`.

    Arguments:
        checkpoint_path: Checkpoint passed to `load_module`.
        corruption: Value interpolated into the dataset directory name
            (`celeba_64_mask{corruption}`).
        seed: Seed of the `inox.random.PRNG` from which one key is split per
            batch.

    Returns:
        The dataset produced by `dataset.map`, in which the 'y' and 'A'
        columns are replaced by a single column 'x' holding the samples.
        NOTE(review): the `list[Image]` annotation is kept for interface
        stability but does not describe this object -- TODO confirm intent.
    """
    model = load_module(checkpoint_path)
    dataset = load_from_disk(f'/data/vision/___/scratch/___ht/celeba_64_mask{corruption}' + ('_test' if TEST_MODE else ''))
    dataset = dataset.select(range(NUM_DATA_GENERATE))
    dataset.set_format('numpy')
    dataset = dataset.map(
        normalize_transform,
        batched=True,
        batch_size=1024,
        num_proc=4,
        desc='normalizing',
    )

    rng = inox.random.PRNG(seed)

    def transform(batch):
        # One fresh key per batch; the batch's 'y' column is the condition.
        x = sample_conditional(
            model,
            batch['y'],
            rng.split(),
            shard=True,
            sampler='ddpm',
            steps=64,
            maxiter=3,
        )

        return {'x': np.asarray(x)}

    # With drop_last_batch=True only full batches of 256 rows are sampled, so
    # a row count that is not a multiple of 256 yields fewer samples than rows.
    # The result is returned (instead of stopping in a leftover debugger
    # breakpoint) so callers can actually use the generated samples.
    return dataset.map(
        transform,
        keep_in_memory=True,
        batched=True,
        batch_size=256,
        drop_last_batch=True,
        num_proc=1,
        remove_columns=['y', 'A'],
    )

if __name__ == '__main__':
    # Script entry point: run conditional generation for one fixed
    # checkpoint on the `mask50` dataset variant (corruption=50).
    checkpoint_path = (
        '/data/vision/___/scratch/___ht/celeba_dir/'
        'checkpoints_debugged/checkpoint_14.pkl'
    )
    generate_conditional(checkpoint_path, corruption=50)

    # Metric evaluation is currently disabled; kept for reference.
    # gen_path = Path('/data/vision/___/scratch/___ht/celeba_dir/datasets_for_eval/conditional/mask50/mask50_14/')
    # fid(ref_path=ref_path, gen_path=gen_path, name='mask50_14')
        
