import argparse
import glob
import random

import sys
import time

import torch
import torchvision
from PIL import Image
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../"))
from scenario_config import SCENARIO_CONFIG
from config import Config

from sae.model import AutoEncoder as PISA

class ImageSetLoader(Dataset):
    """Dataset whose items are all PNG images of one folder, stacked together.

    Each entry of ``paths`` is a folder. Indexing the dataset globs
    ``<folder>/*.png``, converts every image with ``transforms.ToTensor()``
    and stacks the results along a new leading dimension.
    """

    def __init__(self, paths):
        # One folder path per dataset item.
        self.list_of_paths = paths

    def __len__(self):
        """Return the number of folders (i.e. image sets)."""
        return len(self.list_of_paths)

    def __getitem__(self, x):
        """Load the image set stored in folder number ``x``.

        Returns:
            A tensor of shape ``[num_images, ...]`` produced by
            ``torch.stack`` over the per-image tensors.

        Raises:
            RuntimeError: If the folder contains no ``.png`` files.
        """
        folder = self.list_of_paths[x]

        # glob returns files in arbitrary, filesystem-dependent order; sort so
        # the element order inside a set is reproducible across runs/machines.
        image_paths = sorted(glob.glob(f'{folder}/*.png'))
        if not image_paths:
            # torch.stack([]) would also raise RuntimeError, but without any
            # hint about which folder is at fault.
            raise RuntimeError(f"No PNG images found in '{folder}'")

        to_tensor = transforms.ToTensor()  # built once per item, not per image
        images = []
        for image_path in image_paths:
            # Context manager guarantees the file handle is released.
            with Image.open(image_path) as image:
                images.append(to_tensor(image))

        return torch.stack(images)

def _evaluate(model, cnn_encoder, test_dl, batch, batch_size, device):
    """Run one pass over ``test_dl`` and return the mean loss components.

    Every component is divided by ``batch_size`` per batch and then averaged
    over the number of test batches, so the values are directly comparable
    with the training metrics.
    """
    totals = {"loss": 0.0, "mse_loss": 0.0, "size_loss": 0.0}
    # No parameter update happens here, so skip building autograd graphs.
    with torch.no_grad():
        for data in test_dl:
            images = torch.flatten(data.to(device), start_dim=0, end_dim=1)
            latents = cnn_encoder.encode(images)
            model(latents, batch=batch)
            loss_vars = model.loss()
            for key in totals:
                totals[key] += float(loss_vars[key]) / batch_size
    num_batches = len(test_dl)
    return {key: total / num_batches for key, total in totals.items()}


def _reconstruction_grid(model, cnn_encoder, test_dl, batch, num_images, device):
    """Build an image grid comparing inputs and reconstructions.

    Uses the first batch of ``test_dl``. The grid has three groups of
    ``num_images`` images: the original images, the CNN-only reconstructions
    (both reordered by the encoder's permutation) and the reconstructions
    decoded from the set autoencoder output.
    """
    with torch.no_grad():
        data = next(iter(test_dl))
        images = torch.flatten(data.to(device), start_dim=0, end_dim=1)
        latents = cnn_encoder.encode(images)
        xr, _ = model(latents, batch=batch)
        set_recon = cnn_encoder.decode(xr)
        cnn_recon = cnn_encoder.decode(latents)
        perm = model.encoder.get_x_perm()
    return torchvision.utils.make_grid(
        torch.cat(
            [
                images[perm][:num_images],
                cnn_recon[perm][:num_images],
                set_recon[:num_images],
            ],
            dim=0,
        ),
        nrow=num_images,
    )


def train(
        data_path,
        cnn_path,
        image_width,
        data_dim,
        scenario,
        batch_size,
        test_size,
        device,
        num_epochs=20000,
        log_interval=100,
        eval_interval=1000,
):
    """Train the set autoencoder on CNN-encoded image sets.

    Args:
        data_path: Prefix used in the glob pattern ``{data_path}*/*/``; every
            matched folder is one image set. Note it is used as a raw prefix,
            so include a trailing separator to name a directory.
        cnn_path: Path to the state dict of the pre-trained CNN autoencoder.
        image_width: Image width passed to the CNN autoencoder constructor.
        data_dim: Latent dimension of the CNN autoencoder and element
            dimension of the set autoencoder.
        scenario: Key into ``SCENARIO_CONFIG``; its ``"num_agents"`` entry is
            used as the set size.
        batch_size: Number of image sets per batch.
        test_size: Number of folders held out for testing.
        device: Torch device for the models and the data.
        num_epochs: Number of passes over the training set.
        log_interval: Print the loss every this many steps within an epoch.
        eval_interval: Evaluate, log to wandb and checkpoint every this many
            global optimisation steps.

    Raises:
        ValueError: If the train or test split cannot fill a single batch.
    """
    folder_list = sorted(glob.glob(f'{data_path}*/*/'))
    random.shuffle(folder_list)

    train_size = len(folder_list) - test_size
    # Both loaders use drop_last=True, so a split smaller than one batch would
    # yield no batches at all (silent no-op training / division by zero later).
    if train_size < batch_size or test_size < batch_size:
        raise ValueError(
            f"Found {len(folder_list)} image sets under '{data_path}': need at "
            f"least batch_size={batch_size} sets in both the train split "
            f"({train_size}) and the test split ({test_size})"
        )

    train_ds = ImageSetLoader(folder_list[:train_size])
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)

    # Slice from train_size (not -test_size) so the splits can never overlap.
    test_ds = ImageSetLoader(folder_list[train_size:])
    test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, drop_last=True)

    set_size = SCENARIO_CONFIG[scenario]["num_agents"]

    from train_cnn import Autoencoder
    cnn_encoder = Autoencoder(image_width=image_width, latent_dim=data_dim).to(device)
    # map_location lets a checkpoint saved on another device load onto `device`.
    cnn_encoder.load_state_dict(torch.load(cnn_path, map_location=device))
    # NOTE(review): the encoder is left in its default (train) mode as before;
    # if it contains dropout/batch-norm layers, consider cnn_encoder.eval().
    model = PISA(dim=data_dim, hidden_dim=data_dim * set_size).to(device)
    optimizer = Adam(model.parameters())

    # torch.save does not create missing directories.
    os.makedirs("weights", exist_ok=True)

    import wandb
    run = wandb.init(
        project=Config.WANDB_PROJECT,
        entity=Config.WANDB_ENTITY,
        name=f"train_pisa_{scenario}",
        sync_tensorboard=True,
    )

    # Precompute the set index of every element: [0,..,0, 1,..,1, ...].
    # Assumes every folder holds exactly `set_size` images — TODO confirm.
    batch = torch.arange(batch_size, device=device).repeat_interleave(
        set_size
    )

    num_steps = 0
    best_test_loss = float("inf")
    # Running sums since the last evaluation (kept across epoch boundaries).
    train_loss, train_mse, train_size_loss = 0.0, 0.0, 0.0
    try:
        for epoch in range(num_epochs):
            for i, data in enumerate(train_dl):
                num_steps += 1

                images = torch.flatten(data.to(device), start_dim=0, end_dim=1)  # [batches * agents, ...]
                # The CNN is frozen (not in the optimizer): no graph needed.
                with torch.no_grad():
                    latents = cnn_encoder.encode(images)

                optimizer.zero_grad()

                # Forward pass
                model(latents, batch=batch)
                loss_vars = model.loss()
                loss = loss_vars["loss"]

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                step_loss = loss.item() / batch_size
                train_loss += step_loss
                # float() detaches the values so no autograd graph is retained
                # (assumes the loss components are scalars).
                train_mse += float(loss_vars["mse_loss"]) / batch_size
                train_size_loss += float(loss_vars["size_loss"]) / batch_size

                if (i + 1) % log_interval == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                          .format(epoch + 1, num_epochs, i + 1, len(train_dl), step_loss))

                # Keyed on the global step so evaluation also happens when an
                # epoch has fewer than `eval_interval` batches.
                if num_steps % eval_interval != 0:
                    continue

                test_metrics = _evaluate(
                    model, cnn_encoder, test_dl, batch, batch_size, device
                )
                viz = _reconstruction_grid(
                    model, cnn_encoder, test_dl, batch, set_size * 2, device
                )

                wandb.log({
                    "train_loss": train_loss / eval_interval,
                    "train_mse_loss": train_mse / eval_interval,
                    "train_size_loss": train_size_loss / eval_interval,
                    "test_loss": test_metrics["loss"],
                    "test_mse_loss": test_metrics["mse_loss"],
                    "test_size_loss": test_metrics["size_loss"],
                    "test_img": wandb.Image(viz),
                    "steps": num_steps,
                })
                train_loss, train_mse, train_size_loss = 0.0, 0.0, 0.0

                # Always keep the latest weights ...
                torch.save(model.state_dict(), f"weights/pisa_{scenario}_latest.pt")

                # ... and separately the weights with the lowest test loss.
                if test_metrics["loss"] < best_test_loss:
                    best_test_loss = test_metrics["loss"]
                    torch.save(model.state_dict(), f"weights/pisa_{scenario}_best.pt")
    finally:
        # Close the wandb run even when training is interrupted or fails.
        run.finish()


if __name__ == "__main__":
    # Command-line entry point: every option maps 1:1 onto a `train` parameter.
    # The sentence belongs in `description`; `prog` is the program name shown
    # in the usage line.
    parser = argparse.ArgumentParser(description='Train PISA on sampled data')
    # These three have no usable default, so fail fast with a usage error
    # instead of crashing later inside train() on a None value.
    parser.add_argument('--data_path', required=True, help='Path to Image sets')
    parser.add_argument('--cnn_path', required=True, help='Path to pre-trained convolutional autoencoder')
    parser.add_argument('--image_width', default=88, type=int, help='Width of images')
    parser.add_argument('--data_dim', default=128, type=int, help='Dimension of CNN encoder output')
    parser.add_argument('--scenario', required=True, help='Scenario sampled')
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--test_size', default=1024, type=int)
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()

    train(**vars(args))