from __future__ import annotations

import os
from collections.abc import Iterator
from random import Random

import torch
from torch import nn
import torchvision
import torchvision.transforms.v2 as tv
from tqdm.auto import tqdm

from models import AutoEncoder
from spriteworld import SpriteWorld

# CONSTANTS
OUTPUT_DIR = 'stack-images'  # where the evaluation PNGs are written

# Image parameters
IMAGE_SIZE = 32
IMAGE_CHANNELS = 3  # images are converted to RGB before use
IMAGE_SHAPE = (IMAGE_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)  # (C, H, W)

# Architecture parameters
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
LATENT_SIZE = 16  # autoencoder code size; also the stack's item size
# The CNN_* values are forwarded unchanged to models.AutoEncoder.
CNN_HIDDEN_LAYERS = 4
CNN_HIDDEN_CHANNELS = 64
CNN_CHANNEL_MULTIPLIER = 1

STACK_LATENT_SIZE = 64  # size of the single vector that encodes a whole stack
MLP_HIDDEN_SIZE = 256  # hidden width of the stack's push/pop MLPs

# Training hyper-parameters
BATCH_SIZE = 64
CLIP_GRADIENTS = 1  # bound for clip_grad_value_; a falsy value disables it
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
STEPS = 100_000
LOGGING_STEP = 10_000  # print the losses every this many steps
SEED = 0


# Transform pipeline turning a PIL image into a float32 tensor.
_PIL_TO_FLOAT_STEPS = [
    tv.PILToTensor(),                       # PIL image -> tensor, values unscaled
    tv.ToDtype(torch.float32, scale=True),  # cast to float32, rescaling values
]
pil_to_tensor = tv.Compose(_PIL_TO_FLOAT_STEPS)


class ImagesDataset(torch.utils.data.IterableDataset):
    """Endless stream of pre-batched images sampled from a sprite world.

    Every element the iterator yields is already a full batch: ``batch_size``
    images, each converted to RGB and then to a tensor by ``pil_to_tensor``,
    stacked along a new leading dimension. Wrap it in a ``DataLoader`` with
    ``batch_size=None`` so the loader does not batch a second time.
    """

    def __init__(self, spriteworld,
                 batch_size: int,
                 rng: Random | None = None) -> None:
        """Store the sampling configuration.

        Args:
            spriteworld: object whose ``sample(rng)`` returns a mapping with
                an ``'image'`` entry that has ``.convert('RGB')`` (presumably
                a PIL image -- confirm against ``SpriteWorld``).
            batch_size: number of images stacked into each yielded tensor.
            rng: source of randomness passed to ``spriteworld.sample``. If
                ``None``, one is created on first iteration, seeded from
                ``torch.initial_seed()``.
        """
        self.spriteworld = spriteworld
        self.batch_size = batch_size
        self.rng = rng

    def __iter__(self) -> Iterator[torch.Tensor]:
        """Yield batches forever; this iterator never stops on its own.

        NOTE(review): per the IterableDataset documentation each DataLoader
        worker gets its own replica of the dataset, so if an explicit ``rng``
        is supplied and several workers are used, every worker starts from
        the same ``rng`` state and would yield identical batches.
        """
        if self.rng is None:
            # Seeded lazily (at iteration, not construction), presumably so
            # that inside a DataLoader worker the seed is that worker's torch
            # seed. The generator is stored, so later iterations continue the
            # same random stream rather than restarting it.
            self.rng = Random(torch.initial_seed())
        while True:
            yield torch.stack([
                pil_to_tensor(
                    self.spriteworld.sample(self.rng)['image'].convert('RGB'))
                for _ in range(self.batch_size)
            ])


class Stack(nn.Module):
    """Learned, differentiable stack over fixed-size latent vectors.

    The whole stack is represented by one state vector of size
    ``latent_size``. ``push`` maps (state, item) to a new state and ``pop``
    maps a state back to (previous state, item). Both are two-layer MLPs and
    are only inverses of each other to the extent training makes them so.

    Learned sentinel parameters (initialised from U(0, 1), then trained):
        empty: state vector standing for the empty stack.
        none: item vector expected when popping the empty stack.
    """

    def __init__(self,
                 latent_size: int,
                 item_size: int,
                 hidden_size: int) -> None:
        super().__init__()

        self.latent_size = latent_size
        self.item_size = item_size
        self.hidden_size = hidden_size

        self.empty = nn.Parameter(torch.zeros(latent_size))
        self.none = nn.Parameter(torch.zeros(item_size))
        # Initialise the sentinels in this fixed order (empty, then none) so
        # seeded runs draw the same random numbers.
        for sentinel in (self.empty, self.none):
            nn.init.uniform_(sentinel)

        # Push: (state ++ item) -> state.  Pop: state -> (state ++ item).
        self.push_network = self._mlp(latent_size + item_size,
                                      hidden_size,
                                      latent_size)
        self.pop_network = self._mlp(latent_size,
                                     hidden_size,
                                     latent_size + item_size)

    @staticmethod
    def _mlp(in_features: int,
             hidden_features: int,
             out_features: int) -> nn.Sequential:
        """Build ``Linear -> LeakyReLU(0.1) -> Linear``."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_features, out_features),
        )

    def push(self, state: torch.Tensor, item: torch.Tensor) -> torch.Tensor:
        """Return the state obtained by pushing ``item`` onto ``state``.

        The two tensors are concatenated along their last dimension, so all
        other dimensions must agree.
        """
        return self.push_network(torch.cat([state, item], dim=-1))

    def pop(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(remaining_state, popped_item)`` decoded from ``state``."""
        remaining, popped = self.pop_network(state).split(
            [self.latent_size, self.item_size], dim=-1)
        return remaining, popped


def image_loss(pred, target):
    """Reconstruction loss: mean squared error plus mean absolute error.

    Args:
        pred: predicted images.
        target: reference images, expected to have the same shape as ``pred``.

    Returns:
        Scalar tensor ``mean((pred - target) ** 2) + mean(|pred - target|)``.
    """
    # The functional API keeps this self-contained: no reliance on a
    # module-level loss object, and no new L1Loss module built per call.
    return (nn.functional.mse_loss(pred, target)
            + nn.functional.l1_loss(pred, target))


# Shared mean-squared-error criterion (default mean reduction). Kept at module
# level because other code in this file looks it up by this global name.
mse_loss = nn.MSELoss()


def _compute_losses(autoencoder,
                    stack: Stack,
                    image_batch: torch.Tensor,
                    step: int) -> dict[str, torch.Tensor]:
    """Compute every named training loss for one batch of images.

    Args:
        autoencoder: model being trained; called for a full reconstruction
            and via ``encode`` for latent codes.
        stack: the learned stack being trained.
        image_batch: batch of images already moved to ``DEVICE``.
        step: optimiser steps taken so far; ``step % 7`` items of the batch
            are pre-pushed onto the stack before the push/pop check.

    Returns:
        Mapping from loss name to scalar tensor; the training loss is their
        sum.
    """
    losses = {}

    autoencoder_prediction = autoencoder(image_batch)
    losses['autoencoder'] = image_loss(autoencoder_prediction, image_batch)

    # The stack losses use encodings computed without gradient tracking.
    with torch.no_grad():
        encoded_image_batch = autoencoder.encode(image_batch)

    # Start with some items in the stack already
    initial_size = step % 7
    initial_items = encoded_image_batch[:initial_size]
    state_batch = encoded_image_batch[initial_size:]

    # NOTE(review): ``stack.empty.unsqueeze(0)`` is taken outside no_grad, so
    # when initial_size == 0 ``stack_batch`` stays attached to ``stack.empty``
    # and the push/pop losses below also train it; for other sizes
    # ``stack_batch`` is built under no_grad and is detached. Confirm whether
    # this asymmetry is intended.
    initial_stack = stack.empty.unsqueeze(0)
    with torch.no_grad():
        for item in initial_items:
            initial_stack = stack.push(initial_stack, item.unsqueeze(0))
    stack_batch = initial_stack.repeat(len(state_batch), 1)

    # Pushing then popping should return the same stack and item
    pushed_stack = stack.push(stack_batch, state_batch)
    popped_stack, popped_items = stack.pop(pushed_stack)
    losses['stack'] = mse_loss(popped_stack, stack_batch)
    losses['item'] = mse_loss(popped_items, state_batch)

    # Popping from empty stack should return the sentinel NONE item
    empty_stack = stack.empty.unsqueeze(0)
    popped_empty_stack, popped_empty_item = stack.pop(empty_stack)
    losses['empty_stack'] = mse_loss(popped_empty_stack, empty_stack)
    losses['empty_item'] = mse_loss(popped_empty_item,
                                    stack.none.unsqueeze(0))

    return losses


def _train(autoencoder, stack: Stack, dataloader) -> None:
    """Jointly optimise ``autoencoder`` and ``stack`` for ``STEPS`` batches."""
    models = torch.nn.ModuleList([autoencoder, stack])
    optimiser = torch.optim.AdamW(models.parameters(),
                                  lr=LEARNING_RATE,
                                  weight_decay=WEIGHT_DECAY)

    print('Initialised models')

    step = 0
    with tqdm(total=STEPS) as progress_bar:
        # The dataset is endless; the loop ends via the STEPS check below.
        for batch in dataloader:
            image_batch = batch.to(DEVICE)
            losses = _compute_losses(autoencoder, stack, image_batch, step)
            loss = sum(losses.values())

            optimiser.zero_grad()
            loss.backward()
            if CLIP_GRADIENTS:
                torch.nn.utils.clip_grad_value_(models.parameters(),
                                                CLIP_GRADIENTS)
            optimiser.step()

            progress_bar.update(1)
            step += 1
            if step % LOGGING_STEP == 0:
                print(f'Ep {step}\nLoss: {loss:g} <-',
                      ' '.join(f'{k}={v:g}' for k, v in losses.items()))

            if step >= STEPS:
                break


def _evaluate(autoencoder,
              stack: Stack,
              spriteworld) -> tuple[torch.Tensor, torch.Tensor]:
    """Push four held-out images onto an empty stack, then pop them back.

    Both models are switched to eval mode (and left there).

    Returns:
        ``(eval_images, output_images)``: the sampled images as produced by
        the dataset, and the decoded popped items in pop order (the last
        pushed image comes first).
    """
    eval_dataset = ImagesDataset(spriteworld, 4, Random(SEED + 1))
    eval_images = next(iter(eval_dataset))
    with torch.no_grad():
        autoencoder.eval()
        stack.eval()

        eval_items = autoencoder.encode(eval_images.to(DEVICE))
        eval_stack = stack.empty.unsqueeze(0)
        for item in eval_items:
            eval_stack = stack.push(eval_stack, item.unsqueeze(0))

        output_items = []
        for _ in eval_items:
            eval_stack, output_item = stack.pop(eval_stack)
            output_items.append(output_item.squeeze(0))

        output_images = autoencoder.decode(torch.stack(output_items))
    return eval_images, output_images


def _save_images(eval_images: torch.Tensor,
                 output_images: torch.Tensor) -> None:
    """Write originals and reconstructions as numbered PNGs in OUTPUT_DIR.

    ``output_images`` arrives in pop order, so it is reversed before numbering
    so that ``output-i.png`` corresponds to ``original-i.png``.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    for i, image in enumerate(eval_images, start=1):
        torchvision.utils.save_image(image, f'{OUTPUT_DIR}/original-{i}.png')
    for i, image in enumerate(reversed(output_images), start=1):
        torchvision.utils.save_image(image, f'{OUTPUT_DIR}/output-{i}.png')
    print(f'Images saved to {OUTPUT_DIR}')


def main() -> None:
    """Train the autoencoder and stack, then save an evaluation round trip."""
    torch.manual_seed(SEED)

    # NOTE(review): the leading 32 may be the image size (IMAGE_SIZE);
    # confirm against SpriteWorld's signature before replacing it.
    spriteworld = SpriteWorld(32, 1, (0, 0, 0), 8, (64, 255), (64, 255),
                              3, 3, 3, 3)
    dataset = ImagesDataset(spriteworld, BATCH_SIZE)

    # ImagesDataset already yields whole batches, hence batch_size=None.
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=None,
                                             num_workers=1,
                                             prefetch_factor=2)

    autoencoder = AutoEncoder(IMAGE_SIZE,
                              LATENT_SIZE,
                              CNN_HIDDEN_LAYERS,
                              CNN_HIDDEN_CHANNELS,
                              CNN_CHANNEL_MULTIPLIER).to(DEVICE).train()

    stack = Stack(STACK_LATENT_SIZE,
                  LATENT_SIZE,
                  MLP_HIDDEN_SIZE).to(DEVICE).train()

    _train(autoencoder, stack, dataloader)
    eval_images, output_images = _evaluate(autoencoder, stack, spriteworld)
    _save_images(eval_images, output_images)


# The DataLoader worker is a separate process; under the "spawn" start method
# it re-imports this module, so training must only start from this guard.
if __name__ == '__main__':
    main()
