# Pre-train a CNN to encode image data from Melting Pot
# Note that Melting Pot environments provide observations
# in sizes 88x88 (most common) and 40x40 (Cooking / Matrix envs).

import argparse
import glob
import random
import sys
import time
import wandb
import torch
import torchvision
from torch.optim import Adam
import torch.nn.functional as F

import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../"))
from config import Config
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

from torch import nn

class Autoencoder(nn.Module):
    """Convolutional autoencoder for square, 3-channel images.

    The encoder applies three stride-2 convolutions (each halving the
    spatial size), flattens the resulting 64-channel feature map and
    projects it to ``latent_dim`` values. The decoder mirrors this: a linear
    layer back to the feature map, three stride-2 transposed convolutions
    (each doubling the spatial size) and a final sigmoid.

    Args:
        image_width: Side length of the square input images. Must be a
            positive multiple of 8; otherwise the encoder's output size
            (which rounds up at every halving) would not equal
            ``image_width // 8`` and the linear layers would not fit.
        latent_dim: Size of the latent vector produced by ``encode``.

    Raises:
        ValueError: If ``image_width`` is not a positive multiple of 8.
    """

    def __init__(self, image_width, latent_dim=128):
        super().__init__()

        # Fail fast with a readable message instead of a shape mismatch
        # inside the first forward pass.
        if image_width <= 0 or image_width % 8 != 0:
            raise ValueError(
                "image_width must be a positive multiple of 8, "
                f"got {image_width}"
            )

        # Spatial size of the feature map after the three stride-2 layers.
        self.scaled_width = image_width // 8

        self.latent_dim = latent_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),  # 40 -> 20, 88 -> 44
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 20 -> 10, 44 -> 22
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 10 -> 5, 22 -> 11
            nn.ReLU()
        )

        flat_features = 64 * self.scaled_width * self.scaled_width
        self.fc_encoder = nn.Linear(flat_features, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, flat_features)

        # Decoder: output_padding=1 makes every layer exactly double the size.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Map a batch of images to latent vectors of size ``latent_dim``."""
        x = self.encoder(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        encoded = self.fc_encoder(x)
        return encoded

    def decode(self, z):
        """Map a batch of latent vectors back to reconstructed images."""
        z = self.fc_decoder(z)
        # Restore the (batch, 64, scaled_width, scaled_width) feature map.
        z = z.view(z.size(0), 64, self.scaled_width, self.scaled_width)
        decoded = self.decoder(z)
        return decoded

    def forward(self, x):
        """Encode then decode ``x`` and return the reconstruction."""
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        return decoded

class ImageLoader(Dataset):
    """Dataset that returns one image tensor per file path.

    Loads images ignoring set structure: every path is treated as an
    independent sample, whatever directory it came from.

    Args:
        paths: Sequence of image file paths.
    """

    def __init__(self, paths):
        self.list_of_paths = paths

    def __len__(self):
        """Return the number of image paths."""
        return len(self.list_of_paths)

    def __getitem__(self, x):
        """Load the image at index ``x`` and return it as a tensor."""
        image_path = self.list_of_paths[x]
        # The context manager closes the file handle deterministically.
        # convert() reads the pixel data while the file is still open and
        # guarantees three channels, which is what the encoder's first
        # Conv2d(3, ...) layer expects even for RGBA/palette/greyscale PNGs.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        return transforms.ToTensor()(image)

def evaluate(
        best_test_loss,
        scenario,
        model,
        train_dl,
        test_dl,
        criterion,
        train_loss,
        epoch,
        num_steps,
        device
):
    """Run a test pass, log metrics to wandb and checkpoint the model.

    Args:
        best_test_loss: Lowest summed test loss seen so far, as returned by
            a previous call to this function.
        scenario: Name used in the checkpoint file names.
        model: Model to evaluate; called on each batch moved to ``device``.
        train_dl: Unused; kept so existing callers keep working.
        test_dl: DataLoader yielding batches of test images.
        criterion: Loss called as ``criterion(outputs, images)``.
        train_loss: Training loss value to log alongside the test loss.
        epoch: Current epoch, logged as-is.
        num_steps: Total optimisation steps so far, logged as-is.
        device: Device the batches are moved to.

    Returns:
        The updated best loss. Note that the comparison uses the loss
        *summed* over test batches, while the logged ``test_loss`` is the
        per-batch mean.

    Raises:
        ValueError: If ``test_dl`` yields no batches.
    """
    if len(test_dl) == 0:
        raise ValueError("test_dl yields no batches; cannot evaluate")

    num_images = 4
    test_loss = 0
    viz = None

    # Evaluate without building autograd graphs, and put the model in eval
    # mode for the duration; the previous mode is restored afterwards so the
    # caller's training loop is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch_idx, data in enumerate(test_dl):
                images = data.to(device)
                outputs = model(images)
                test_loss += criterion(outputs, images).item()

                if batch_idx == 0:
                    # Reconstruct some random test images from the first
                    # batch: originals on the top row, reconstructions below.
                    indices = torch.randperm(len(outputs))[:num_images]
                    viz = torchvision.utils.make_grid(
                        torch.cat(
                            [
                                images[indices],
                                outputs[indices],
                            ],
                            dim=0,
                        ),
                        nrow=num_images
                    )
    finally:
        model.train(was_training)

    wandb.log({
        "train_loss": train_loss,
        "test_loss": test_loss / len(test_dl),
        "test_img": wandb.Image(viz),
        "epoch": epoch,
        "steps": num_steps,
    })

    # Make sure the checkpoint directory exists before saving.
    os.makedirs("weights", exist_ok=True)

    # Always overwrite the "latest" checkpoint on every evaluation.
    file_str = f"weights/cnn_{scenario}_latest.pt"
    torch.save(model.state_dict(), file_str)

    # Keep a separate copy of the best model seen so far.
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        torch.save(model.state_dict(), f"weights/cnn_{scenario}_best.pt")

    return best_test_loss

def train(
        data_path,
        image_width,
        latent_dim,
        scenario,
        batch_size,
        test_size,
        device,
):
    """Train the autoencoder on PNG images and log progress to wandb.

    Images are collected with the glob ``{data_path}*/*/*.png``, shuffled,
    and split into a training set and a held-out test set of ``test_size``
    images. The model is trained for 1000 epochs with Adam (lr 1e-3) on an
    MSE reconstruction loss. ``evaluate`` is called every 1000 batches
    within an epoch and once at the end of every epoch.

    Args:
        data_path: Prefix used to build the image glob pattern.
        image_width: Side length of the images, passed to ``Autoencoder``.
        latent_dim: Latent size, passed to ``Autoencoder``.
        scenario: Name used for the wandb run and checkpoint files.
        batch_size: Batch size for both data loaders.
        test_size: Number of images held out for testing.
        device: Device the model and batches are moved to.

    Raises:
        ValueError: If no images are found, or if ``test_size`` does not
            leave at least one image in both the train and the test split.
    """
    pattern = f'{data_path}*/*/*.png'
    image_list = sorted(glob.glob(pattern))
    if not image_list:
        raise ValueError(f"No images found matching {pattern!r}")
    # A test_size of 0 (or >= the number of images) would make a negative
    # slice select the whole list, silently overlapping train and test data.
    if not 0 < test_size < len(image_list):
        raise ValueError(
            f"test_size must be between 1 and {len(image_list) - 1} "
            f"for {len(image_list)} images, got {test_size}"
        )
    random.shuffle(image_list)

    train_size = len(image_list) - test_size

    train_ds = ImageLoader(image_list[:train_size])
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    # Slice from train_size so the two splits are always disjoint.
    test_ds = ImageLoader(image_list[train_size:])
    test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    model = Autoencoder(image_width=image_width, latent_dim=latent_dim).to(device)
    criterion = nn.MSELoss()
    optimizer = Adam(model.parameters(), lr=1e-3)

    num_epochs = 1000
    num_steps = 0
    best_test_loss = 99e9

    # Using the run as a context manager finishes it even when training
    # raises, so an interrupted run is not left open.
    with wandb.init(
        project=Config.WANDB_PROJECT,
        entity=Config.WANDB_ENTITY,
        name=f"train_cnn_{scenario}",
        sync_tensorboard=True,
    ):
        for epoch in range(num_epochs):

            train_loss = 0
            for i, data in enumerate(train_dl):
                num_steps += 1

                images = data.to(device)  # Assuming the images are already loaded and normalized
                optimizer.zero_grad()

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, images)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

                # Print the loss every 100 steps
                if (i + 1) % 100 == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                          .format(epoch + 1, num_epochs, i + 1, len(train_dl), loss.item()))

                # Evaluate every 1000 batches; the logged train loss is the
                # running mean over the epoch so far.
                if (i + 1) % 1000 == 0:
                    best_test_loss = evaluate(
                        best_test_loss,
                        scenario,
                        model,
                        train_dl,
                        test_dl,
                        criterion,
                        train_loss / (i + 1),
                        epoch,
                        num_steps,
                        device
                    )

            # Evaluate after epoch
            best_test_loss = evaluate(
                best_test_loss,
                scenario,
                model,
                train_dl,
                test_dl,
                criterion,
                train_loss / len(train_dl),
                epoch,
                num_steps,
                device
            )

if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to
    # train() as keyword arguments (option names match its parameters).
    parser = argparse.ArgumentParser(prog='Train CNN on sampled data')
    # Required: a missing path would otherwise be formatted as the literal
    # string "None" inside the glob pattern.
    parser.add_argument('--data_path', required=True, help='Path to Image sets')
    parser.add_argument('--image_width', default=88, type=int, help='Width of images')
    parser.add_argument('--latent_dim', default=128, type=int)
    parser.add_argument('--scenario', default=None, help='Scenario sampled')
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--test_size', default=1024, type=int)
    # Default to the GPU only when one is actually available.
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_args()

    train(**vars(args))
