import argparse
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

class ImageToSingleToken(nn.Module):
    """Summarise an image into one feature vector via single-query attention.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution that produces keys and values. A learned
    class token acts as the only query; the output is the attention-weighted
    sum of the values, with heads concatenated.

    A learned positional encoding is either added to keys and values
    (``concat_pe=False``) or interleaved with their channels
    (``concat_pe=True``). In the latter case the convolution only produces
    ``feat_dim // 2`` channels each for keys and values, and the other half
    comes from the positional encoding.

    Args:
        image_size: Side length of the (square) input images.
        feat_dim: Dimension of the output token (summed over all heads).
        nbr_heads: Number of attention heads; must divide ``feat_dim``.
        color_channels: Number of input channels.
        patch_size: Side length of one patch; must divide ``image_size``.
        concat_pe: Interleave the positional encoding with the key/value
            channels instead of adding it. Requires an even ``feat_dim``.

    Raises:
        ValueError: If the divisibility requirements above are violated.
    """

    def __init__(self, image_size=14, feat_dim=512, nbr_heads=2, color_channels=3, patch_size=2, concat_pe=False):
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by patch_size ({patch_size})")
        if feat_dim % nbr_heads != 0:
            raise ValueError(
                f"feat_dim ({feat_dim}) must be divisible by nbr_heads ({nbr_heads})")
        if concat_pe and feat_dim % 2 != 0:
            raise ValueError(
                f"feat_dim ({feat_dim}) must be even when concat_pe is enabled")
        feat_size = image_size // patch_size
        self.feat_dim = feat_dim
        self.feat_size = feat_size  # number of patches along one image side
        self.nbr_heads = nbr_heads
        self.feat_per_head = feat_dim // nbr_heads

        # One convolution produces keys and values together; forward() splits
        # its output channels in two halves. The parameters below are created
        # in a fixed order so that seeded initialisation stays reproducible.
        self.kv_embed = nn.Conv2d(
            color_channels, feat_dim if concat_pe else 2*feat_dim, kernel_size=patch_size, stride=patch_size, bias=False)
        self.positional_encoding = nn.Parameter(torch.randn([feat_dim//2 if concat_pe else feat_dim, feat_size, feat_size]))
        self.class_token = nn.Parameter(torch.randn([feat_dim]))
        self.concat_pe = concat_pe

    def forward(self, x, return_attention=False):
        """Attend from the class token to all image patches.

        Args:
            x: Image batch, expected as
                ``(batch, color_channels, image_size, image_size)``.
            return_attention: Also return the attention weights.

        Returns:
            A ``(batch, feat_dim)`` tensor, or a tuple of that tensor and the
            attention weights of shape
            ``(batch, nbr_heads, 1, feat_size**2)`` when ``return_attention``
            is true.
        """
        head_shape = (-1, self.nbr_heads, self.feat_per_head, self.feat_size**2)
        q = self.class_token.reshape(1, self.nbr_heads, self.feat_per_head, 1)

        k, v = self.kv_embed(x).chunk(2, dim=1)
        pe = self.positional_encoding[None]
        if self.concat_pe:
            # Stacking on a new axis right after the channel axis and then
            # flattening interleaves content and position channels, so every
            # head receives both kinds of features.
            pe = pe.expand(k.shape[0], -1, -1, -1)
            k = torch.stack([k, pe], dim=2)
            v = torch.stack([v, pe], dim=2)
        else:
            k = k + pe
            v = v + pe
        k = k.reshape(head_shape)
        v = v.reshape(head_shape)

        # Raw dot-product scores (no 1/sqrt(d) scaling), one query per head.
        scores = torch.einsum("...fm,...fn->...mn", q, k)
        attention = nn.functional.softmax(scores, dim=-1)  # ..., heads, queries, keys
        out = torch.einsum("...mn,...fn->...fm", attention, v)
        # There is a single query, so the last query index is the only one.
        out_class = out[..., -1].reshape(-1, self.feat_dim)

        if return_attention:
            return out_class, attention
        return out_class


class NeuralNetwork(nn.Module):
    """Image-to-image network with a single-token bottleneck.

    ``two_layer`` compresses the image into one token with
    ``ImageToSingleToken`` and projects it linearly to ``internal_dim``.
    ``last_layers`` is a GELU MLP that maps that vector to
    ``image_size * image_size * color_channels`` values, which ``unflatten``
    reshapes to ``(color_channels, image_size, image_size)``.
    """

    def __init__(self, image_size=14, feat_dim=512, internal_dim=2048, color_channels=3, patch_size=2, nbr_heads=2, concat_pe=False):
        super().__init__()

        # Encoder half: attention pooling followed by a bias-free projection.
        tokenizer = ImageToSingleToken(
            image_size=image_size,
            feat_dim=feat_dim,
            color_channels=color_channels,
            patch_size=patch_size,
            nbr_heads=nbr_heads,
            concat_pe=concat_pe,
        )
        lift = nn.Linear(feat_dim, internal_dim, bias=False)
        self.two_layer = nn.Sequential(tokenizer, lift)

        # Decoder half: the first GELU acts on the output of ``two_layer``.
        hidden = nn.Linear(internal_dim, internal_dim, bias=True)
        to_pixels = nn.Linear(internal_dim, image_size * image_size * color_channels, bias=True)
        self.last_layers = nn.Sequential(nn.GELU(), hidden, nn.GELU(), to_pixels)

        self.color_channels = color_channels
        self.unflatten = nn.Unflatten(-1, (color_channels, image_size, image_size))

    def forward(self, x, return_intermediate=False):
        """Return the output image, optionally with the ``two_layer`` output.

        Args:
            x: Input batch passed to ``two_layer``.
            return_intermediate: If true, return ``(output, intermediate)``
                where ``intermediate`` is the pre-activation output of
                ``two_layer``.
        """
        bottleneck = self.two_layer(x)
        image = self.unflatten(self.last_layers(bottleneck))
        return (image, bottleneck) if return_intermediate else image

def train_loop(
    dataloader,
    model,
    loss_fn,
    optimizer,
    device,
    batch_size,
    data_aug=False,
    equi_loss=False,
    equi_loss_coeff=1.0,
):
    """Train ``model`` for one pass over ``dataloader`` on a reconstruction loss.

    The target of ``loss_fn`` is the input batch itself; the labels yielded by
    the dataloader are moved to ``device`` but otherwise unused.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing ``.dataset``.
        model: Module called as ``model(X, return_intermediate=False)``.
        loss_fn: Callable ``loss_fn(prediction, X)`` returning a scalar.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        batch_size: Nominal batch size, used only for the progress printout.
        data_aug: With probability 0.5 per batch, flip the whole batch along
            its last axis before the forward pass.
        equi_loss: Add an MSE penalty between the prediction and the
            un-flipped prediction of the flipped input.
        equi_loss_coeff: Weight of that penalty.
    """
    n_samples = len(dataloader.dataset)
    model.train()

    for step, (X, y) in enumerate(dataloader):
        with torch.no_grad():
            X = X.to(device)
            y = y.to(device)
            # The random draw only happens when augmentation is enabled.
            if data_aug and torch.rand(1).item() > .5:
                X = X.flip(-1)

        recon = model(X, return_intermediate=False)
        if equi_loss:
            # Flip input, predict, flip back: equals ``recon`` exactly for a
            # model that is equivariant to flips of the last axis.
            mirrored = model(X.flip(-1), return_intermediate=False).flip(-1)
            equi_term = nn.functional.mse_loss(recon, mirrored)
        else:
            equi_term = 0
        total = loss_fn(recon, X) + equi_loss_coeff*equi_term

        total.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 100 != 0:
            continue
        seen = step * batch_size + len(X)
        print(f"loss: {total.item():>7f}  [{seen:>5d}/{n_samples:>5d}]")

def test_loop(dataloader, model, loss_fn, device):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, X).item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Avg loss: {test_loss:>8f} \n")
    return correct

def run_training(args):
    """Train a ``NeuralNetwork`` on image reconstruction and save a checkpoint.

    Builds the dataset selected by ``args.dataset``, optionally restores model
    weights from ``args.checkpoint_path``, trains for ``args.epochs`` epochs
    with AdamW and a cosine learning-rate schedule, evaluates on the test
    split after every epoch, and finally writes ``{args.run_name}_final.pt``
    containing ``args`` and the model ``state_dict``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block for
            the available options).

    Raises:
        NotImplementedError: If ``args.dataset`` is ``"MNIST"``.
        ValueError: If ``args.dataset`` is not a known dataset name.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device. Hyperparameters: {vars(args)}")

    # Data loading
    # NOTE(review): the same random transforms (crop, colour jitter) are
    # applied to the test split as to the training split, so test losses are
    # not deterministic -- confirm this is intended.
    image_size = args.image_size
    if args.dataset == "CIFAR10":
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.ColorJitter(.2, .2, .2, .2),
                transforms.Resize((32, 32)),
                transforms.RandomResizedCrop(image_size),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        training_data_cpu = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
        test_data_cpu = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
        color_channels = 3
    elif args.dataset == "FMNIST":
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((32, 32)),
                transforms.RandomResizedCrop(image_size),
                # One-element tuples: a bare ``(0.5)`` is just the float 0.5.
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        training_data_cpu = datasets.FashionMNIST(
            root="data", train=True, download=True, transform=transform)
        test_data_cpu = datasets.FashionMNIST(
            root="data", train=False, download=True, transform=transform)
        color_channels = 1
    elif args.dataset == "MNIST":
        raise NotImplementedError(
            "The MNIST dataset is not implemented; use 'CIFAR10' or 'FMNIST'.")
    else:
        raise ValueError(
            f"Unknown dataset {args.dataset!r}; expected 'CIFAR10' or 'FMNIST'.")

    train_dataloader = DataLoader(training_data_cpu, batch_size=args.batch_size, shuffle=True, num_workers=8)
    test_dataloader = DataLoader(test_data_cpu, batch_size=args.batch_size)

    # Model Initialization (patch_size is left at the model's default)
    model = NeuralNetwork(image_size=image_size,
                          feat_dim=args.feat_dim,
                          internal_dim=args.internal_dim,
                          color_channels=color_channels,
                          nbr_heads=args.nbr_heads,
                          concat_pe=args.concat_pe,
                         ).to(device)

    if args.checkpoint_path:
        if os.path.exists(args.checkpoint_path):
            print(f"Loading model from checkpoint: {args.checkpoint_path}")
            # SECURITY: weights_only=False unpickles arbitrary objects (the
            # checkpoints written below also contain the argparse namespace).
            # Only load checkpoints from trusted sources.
            model.load_state_dict(torch.load(args.checkpoint_path, map_location=device, weights_only=False)['sd'])
        else:
            print(f"WARNING: Checkpoint path '{args.checkpoint_path}' not found. Starting with a new model.")

    # Loss and Optimizer
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    # With a late equivariance loss the cosine period is halved, so the
    # schedule reaches its minimum mid-run and rises again afterwards.
    late_equi = args.equi_loss and args.late_equi_loss
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=args.epochs/2 if late_equi else args.epochs,
    )

    # Training loop
    for t in range(args.epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        # The equivariance loss is active either from the start, or only for
        # epochs with index strictly greater than epochs/2 when it is "late".
        use_equi_loss = args.equi_loss and (not args.late_equi_loss or t > args.epochs/2)
        train_loop(
            train_dataloader,
            model,
            loss_fn,
            optimizer,
            device,
            args.batch_size,
            data_aug=args.data_aug,
            equi_loss=use_equi_loss,
            equi_loss_coeff=args.equi_loss_coeff,
        )
        test_loop(test_dataloader, model, loss_fn, device)
        scheduler.step()

    print("Training finished!")

    # Save the model checkpoint
    ckpt = {
        'args': args,
        'sd': model.state_dict(),
    }
    checkpoint_name = f"{args.run_name}_final.pt"
    torch.save(ckpt, checkpoint_name)
    print(f"Model saved to {checkpoint_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PyTorch Training")

    # Command-line options as (flag, add_argument keywords), registered in
    # this order so the --help listing keeps its layout.
    cli_options = [
        # Training hyperparameters
        ("--learning-rate", dict(type=float, default=1e-3, help="Optimizer learning rate")),
        ("--weight-decay", dict(type=float, default=1e-2, help="Optimizer weight decay")),
        ("--epochs", dict(type=int, default=10, help="Number of training epochs")),
        ("--feat-dim", dict(type=int, default=64, help="Intermediate feature dimension")),
        ("--internal-dim", dict(type=int, default=2048, help="Internal feature dimension")),
        ("--nbr-heads", dict(type=int, default=2, help="Number of attention heads")),
        ("--concat-pe", dict(action="store_true", help="Concatenate rather than add positional embedding")),
        ("--image-size", dict(type=int, default=16, help="Image size")),
        ("--batch-size", dict(type=int, default=128, help="Training batch size")),
        ("--dataset", dict(type=str, default="CIFAR10")),
        ("--data-aug", dict(action="store_true", help="Hflip data aug.")),
        ("--equi-loss", dict(action="store_true", help="Equivariance loss.")),
        ("--late-equi-loss", dict(action="store_true", help="Equivariance loss only at last part.")),
        ("--equi-loss-coeff", dict(type=float, default=1.0, help="Coefficient of the equi loss")),
        ("--checkpoint-path", dict(type=str, default=None, help="Path to a model checkpoint to load before training.")),
        ("--run-name", dict(type=str, default="training_run", help="Name to use for saving the model and slurm log")),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    run_training(args)
