# Generating Pictures to put in the Paper

import inox
import inox.nn as nn
import jax
import numpy as np
import optax
import wandb

import matplotlib.pyplot as plt
import matplotlib as mpl

from datasets import Array3D, Features, load_from_disk
from dawgz import job, schedule
from functools import partial
from tqdm import trange
from typing import *

# isort: split
from utils import *

def generate_conditional(model, dataset, rng, batch_size, **kwargs):
    """Draw conditional samples for every batch of ``dataset``.

    The dataset is mapped in batches of ``batch_size``. For each batch, a fresh
    key is split off ``rng`` and ``sample_conditional(model, batch['y'], key,
    **kwargs)`` is called. The result becomes the only column, ``'x'``.

    Args:
        model: Model forwarded to ``sample_conditional``.
        dataset: Dataset with a ``'y'`` column (the conditioning observations).
        rng: PRNG object; ``rng.split()`` is called once per batch.
        batch_size: Number of examples per ``map`` batch.
        **kwargs: Extra keyword arguments forwarded to ``sample_conditional``.

    Returns:
        The mapped dataset, whose only column ``'x'`` is declared as float32
        arrays of shape (32, 32, 3). A trailing incomplete batch is dropped.
    """
    out_features = Features({'x': Array3D(shape=(32, 32, 3), dtype='float32')})

    def _sample_batch(batch):
        key = rng.split()
        generated = sample_conditional(model, batch['y'], key, **kwargs)
        return {'x': np.asarray(generated)}

    return dataset.map(
        _sample_batch,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,  # only full batches are sampled
        keep_in_memory=True,
        remove_columns=dataset.column_names,  # output contains 'x' only
        features=out_features,
    )


# Root of the CIFAR experiment artifacts on the shared filesystem.
_CIFAR_DIR = Path('/data/vision/___/scratch/___ht/cifar_dir')

# Dataset used for conditioning (gaussian-blur variant).
# An earlier variant pointed at cifar_dir/hf/cifar-mask-75 instead.
DATASET_PATH = _CIFAR_DIR / 'hf' / 'cifar-mask-gaussian-blur-2'
# Default checkpoint directory.
PATH = _CIFAR_DIR / 'checkpoints_itnog'

def make_samples(path, lap):
    """Generate conditional samples from one checkpoint and save them as PNGs.

    Loads ``checkpoint_{lap}.pkl`` from ``path`` and selects a fixed subset of
    8 examples from the ``'test'`` split of ``DATASET_PATH``. It samples one
    image per example and writes image ``i`` to
    ``.../cifar_dir/art_blurry/checkpoint_{lap}/{i}.png``.

    Args:
        path: Directory containing the ``checkpoint_*.pkl`` files.
        lap: Checkpoint index to load.
    """
    import zlib
    from pathlib import Path

    model = load_module(path / f'checkpoint_{lap}.pkl')
    dataset = load_from_disk(DATASET_PATH)
    dataset.set_format('numpy')
    testset = dataset['test'].select([16, 24, 23, 19, 25, 14, 27, 29])

    # The built-in hash() of a str is salted per interpreter process
    # (PYTHONHASHSEED), so it would yield a different seed on every run.
    # crc32 is stable, so the figure is reproducible across runs.
    seed = zlib.crc32(b'___') % 2**16
    rng = inox.random.PRNG(seed)
    samples = generate_conditional(model, testset, rng, 8)['x']

    out_dir = Path('/data/vision/___/scratch/___ht/cifar_dir/art_blurry') / f'checkpoint_{lap}'
    # Create the per-lap directory so saving does not fail on a new checkpoint.
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, img in enumerate(samples):
        to_pil(img).save(out_dir / f'{i}.png')

def make_grid_picture():
    """Plot a 4x5 grid of samples with one column per checkpoint ("lap").

    Column ``c`` shows images ``0.png``..``3.png`` from
    ``art_blurry/checkpoint_{lap}`` for lap in (0, 4, 8, 12, 16). A bold title
    row ("Lap N") sits above the columns. The figure is written to
    ``cifar_dir/art/image_grid.png`` and then closed.

    Note: sets the global matplotlib font rcParams as a side effect.
    """
    from pathlib import Path

    mpl.rcParams['font.family'] = 'serif'
    mpl.rcParams['font.serif'] = ['DejaVu Serif']  # Default serif font in matplotlib

    laps = ['0', '4', '8', '12', '16']  # one column per checkpoint (5 columns)
    column_titles = [f'Lap {lap}' for lap in laps]
    rows, cols = 4, len(laps)
    src_root = '/data/vision/___/scratch/___ht/cifar_dir/art_blurry'
    out_path = Path('/data/vision/___/scratch/___ht/cifar_dir/art/image_grid.png')

    fig, axs = plt.subplots(
        rows + 1,  # +1 for the title row
        cols,
        figsize=(cols * 3, (rows + 0.3) * 3),  # extra 0.3 row height for the titles
        gridspec_kw={'height_ratios': [0.3] + [1] * rows},
    )

    for col, (lap, title) in enumerate(zip(laps, column_titles)):
        # Row 0 holds the column title, above the images.
        axs[0, col].text(
            0.5, 0.5,
            title,
            ha='center',
            va='center',
            fontsize=16,
            fontweight='bold',
        )
        axs[0, col].axis('off')

        for row in range(rows):
            ax = axs[row + 1, col]  # offset by 1 for the title row
            # The context manager releases the file handle once the image is
            # handed to imshow (the original leaked one handle per image).
            with Image.open(f'{src_root}/checkpoint_{lap}/{row}.png') as img:
                ax.imshow(img)
            ax.axis('off')

    plt.tight_layout()
    plt.subplots_adjust(hspace=0.2)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(str(out_path), dpi=300, bbox_inches='tight')
    plt.close(fig)  # free the figure; pyplot otherwise keeps it alive


def make_grid_picture2():
    """Plot conditioning images above generated samples in a 2x8 grid.

    Row 0 ("Conditioned on") shows ``art_blurry/ground_truth/{i}.png``.
    Row 1 ("Generated") shows ``art_blurry/checkpoint_16/{i}.png``, for
    i in 0..7. A narrow first column holds the rotated row titles. The figure
    is written to ``art_blurry/image_grid_with_titles.png`` and then closed.

    Note: sets the global matplotlib font rcParams as a side effect.
    """
    # Font settings
    mpl.rcParams['font.family'] = 'serif'
    mpl.rcParams['font.serif'] = ['DejaVu Serif']

    src_root = '/data/vision/___/scratch/___ht/cifar_dir/art_blurry'
    titles = ["Conditioned on", "Generated"]
    row_dirs = ['ground_truth', 'checkpoint_16']  # source directory per row
    n_images = 8
    rows, cols = len(titles), n_images + 1  # +1 column for the titles

    fig, axs = plt.subplots(
        rows,
        cols,
        figsize=(cols * 3, rows * 3),
        gridspec_kw={'width_ratios': [0.3] + [1] * n_images},
    )

    for row, (title, subdir) in enumerate(zip(titles, row_dirs)):
        # Column 0 holds the rotated row title.
        axs[row, 0].text(
            0.5, 0.5,
            title,
            ha='center',
            va='center',
            fontsize=16,
            fontweight='bold',
            rotation=90,
        )
        axs[row, 0].axis('off')

        for i in range(n_images):
            ax = axs[row, i + 1]  # offset by 1 because column 0 is the title
            # Close the file handle after handing the image to imshow.
            with Image.open(f'{src_root}/{subdir}/{i}.png') as img:
                ax.imshow(img)
            ax.axis('off')

    plt.tight_layout()
    plt.savefig(f'{src_root}/image_grid_with_titles.png', dpi=300, bbox_inches='tight')
    plt.close(fig)  # free the figure; pyplot otherwise keeps it alive

def trainset_images():
    """Save the conditioning observations ``y`` of the fixed subset as PNGs.

    Despite the name, the images come from the ``'test'`` split. It uses the
    same 8 indices that ``make_samples`` selects, so that image ``i`` here
    lines up with generated sample ``i``. Writes to
    ``.../cifar_dir/art_blurry/ground_truth/{i}.png``.
    """
    from pathlib import Path

    dataset = load_from_disk(DATASET_PATH)
    dataset.set_format('numpy')
    testset = dataset['test'].select([16, 24, 23, 19, 25, 14, 27, 29])

    out_dir = Path('/data/vision/___/scratch/___ht/cifar_dir/art_blurry/ground_truth')
    # Create the directory up front so the first save does not fail.
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, img in enumerate(testset['y']):
        to_pil(img).save(out_dir / f'{i}.png')

if __name__ == "__main__":
    # Earlier runs, kept for reference:
    #   make_samples(Path('/data/vision/___/scratch/___ht/cifar_dir/checkpoints_itnog'), lap)
    #       for lap in (17, 11, 5, 0)
    #   make_samples(Path('/data/vision/___/scratch/___ht/cifar_dir/checkpoints_itnog_blur_nw'), lap)
    #       for lap in (0, 4, 8, 12), followed by make_grid_picture()
    blur_checkpoints = Path('/data/vision/___/scratch/___ht/cifar_dir/checkpoints_itnog_blur_nw')

    # 1. Dump the conditioning images (top row of the final figure).
    trainset_images()
    # 2. Sample from checkpoint 16 of the blur run.
    make_samples(blur_checkpoints, 16)
    # 3. Assemble the two-row "Conditioned on" / "Generated" figure.
    make_grid_picture2()