import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
import matplotlib.pyplot as plt

# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model / data parameters
num_classes = 10
input_shape = (1, 28, 28)  # PyTorch layout: (channels, height, width)

# Optimisation hyper-parameters
batch_size = 2000  # alternative setting kept for reference: 64
num_epochs = 20
learning_rate = 1e-3
weight_decay = 1e-4

# Normalization mean/std shared by the training and testing pipelines.
_MNIST_MEAN = (0.1307,)
_MNIST_STD = (0.3081,)

# Training pipeline: light geometric augmentation (small rotation, random
# resized crop back to 28x28), then tensor conversion and normalization.
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(2),
        transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(_MNIST_MEAN, _MNIST_STD),
    ]
)

# Testing pipeline: no augmentation, only tensor conversion and the same normalization.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_MNIST_MEAN, _MNIST_STD),
    ]
)

# Load the datasets (root='./data', download=True).
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=train_transform)
# A second view of the *training* split that applies the un-augmented
# test_transform. The validation indices are read through this view, so that
# validation metrics are not computed on randomly rotated/cropped images.
val_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=test_transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=test_transform)

# Create validation split: hold out the first `validation_split` fraction of
# the (optionally seeded-shuffled) training indices.
validation_split = 0.1
shuffle_dataset = True
random_seed = 42
dataset_size = len(train_dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))

if shuffle_dataset:
    # NOTE: this seeds NumPy's global RNG as a side effect.
    np.random.seed(random_seed)
    np.random.shuffle(indices)

train_indices, val_indices = indices[split:], indices[:split]

# Create data loaders; the samplers restrict each loader to its index subset.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
# Same held-out indices as before, but served without augmentation.
valid_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=valid_sampler)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define the MLP module
class MLP(nn.Module):
    """Stack of Linear -> GELU -> Dropout stages, one stage per entry of ``hidden_units``.

    Args:
        in_features: width of the input's last dimension.
        hidden_units: output width of each successive Linear layer, in order.
        dropout_rate: dropout probability applied after every GELU.
    """

    def __init__(self, in_features, hidden_units, dropout_rate):
        super().__init__()
        # Consecutive (fan_in, fan_out) pairs: the input width followed by each hidden width.
        widths = [in_features, *hidden_units]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.GELU(), nn.Dropout(dropout_rate)]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Apply every stage in sequence to ``x``."""
        return self.mlp(x)

# Define the Patches module
class Patches(nn.Module):
    """Split a batch of images into flattened, non-overlapping square patches.

    Input is unpacked as (batch, channels, height, width). The output has shape
    (batch, num_patches, channels * patch_size * patch_size), with patches in
    row-major order over the patch grid. Each patch vector is laid out channel
    by channel, row by row within the patch.
    """

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, images):
        n, c, _, _ = images.size()
        p = self.patch_size
        # (n, c, grid_rows, grid_cols, p, p): one p x p window per grid cell.
        grid = images.unfold(2, p, p).unfold(3, p, p)
        # Merge the two grid axes, then move the patch axis in front of channels.
        per_patch = grid.reshape(n, c, -1, p, p).permute(0, 2, 1, 3, 4)
        return per_patch.reshape(n, -1, c * p * p)

# Define the Patch Encoder module
class PatchEncoder(nn.Module):
    """Linearly project flattened patches and add a learned per-position embedding.

    Args:
        num_patches: number of patch positions; one embedding row per position.
        projection_dim: width of the projected patch tokens.
        patch_dim: width of each flattened input patch.
    """

    def __init__(self, num_patches, projection_dim, patch_dim):
        super().__init__()
        self.num_patches = num_patches
        # Construction order (projection, then embedding) is preserved so that
        # seeded parameter initialization is unchanged.
        self.projection = nn.Linear(patch_dim, projection_dim)
        self.position_embedding = nn.Embedding(num_patches, projection_dim)

    def forward(self, patches):
        token_ids = torch.arange(self.num_patches, device=patches.device)
        # Leading singleton axis: one set of position vectors broadcast over the batch.
        pos = self.position_embedding(token_ids[None, :])
        return self.projection(patches) + pos

# Define the Transformer Block
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention then a GELU MLP, each wrapped in a residual.

    Args:
        dim: token embedding width.
        num_heads: number of heads for nn.MultiheadAttention.
        mlp_dim: hidden width of the feed-forward sub-layer.
        dropout_rate: dropout used inside attention, on the attention output,
            and inside the feed-forward sub-layer.
    """

    def __init__(self, dim, num_heads, mlp_dim, dropout_rate):
        super().__init__()
        # Submodules are created in the original order so seeded initialization
        # and state_dict keys are unchanged.
        self.layernorm1 = nn.LayerNorm(dim, eps=1e-6)
        self.mha = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout_rate)
        self.layernorm2 = nn.LayerNorm(dim, eps=1e-6)
        feed_forward = [
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout_rate),
        ]
        self.mlp = nn.Sequential(*feed_forward)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # The attention module is built without batch_first, so it expects
        # (seq_len, batch, dim); x is assumed to arrive as (batch, seq_len, dim).
        seq_first = self.layernorm1(x).permute(1, 0, 2)
        attended, _ = self.mha(seq_first, seq_first, seq_first)
        x = x + self.dropout(attended.permute(1, 0, 2))
        return x + self.mlp(self.layernorm2(x))

# Define the Vision Transformer model
class ViT(nn.Module):
    """Vision Transformer classifier for square images.

    The image is cut into ``patch_size`` x ``patch_size`` patches. Each patch is
    projected to ``dim`` and given a learned position embedding. The tokens pass
    through ``depth`` TransformerBlock layers. The layer-normed tokens are then
    flattened and fed to an MLP head that emits ``num_classes`` logits.

    Args:
        image_size: side length of the (square) input image.
        patch_size: side length of each square patch; must divide image_size.
        num_classes: number of output logits.
        dim: token embedding width.
        depth: number of TransformerBlock layers.
        heads: attention heads per block.
        mlp_dim: currently unused (see the NOTE in __init__).
        dropout_rate: dropout inside the transformer blocks (the head uses 0.5).
        channels: number of input image channels (1 for MNIST, the default).
    """

    def __init__(self, image_size=28, patch_size=14, num_classes=10, dim=96,
                 depth=16, heads=4, mlp_dim=2048, dropout_rate=0.1, channels=1):
        super(ViT, self).__init__()
        assert image_size % patch_size == 0, "Image size must be divisible by patch size."

        self.num_patches = (image_size // patch_size) ** 2
        # Each flattened patch holds patch_size x patch_size pixels from every channel.
        patch_dim = channels * patch_size * patch_size
        self.patches = Patches(patch_size)
        self.patch_encoder = PatchEncoder(self.num_patches, dim, patch_dim)

        # NOTE(review): `mlp_dim` is ignored; every block's hidden width is
        # dim * 2. Wiring mlp_dim through with its current default (2048) would
        # change the default architecture and make previously saved state_dicts
        # fail to load, so the existing behavior is kept deliberately.
        self.transformer = nn.ModuleList([
            TransformerBlock(dim=dim, num_heads=heads, mlp_dim=dim * 2, dropout_rate=dropout_rate)
            for _ in range(depth)
        ])

        self.layernorm = nn.LayerNorm(dim, eps=1e-6)
        # Head: flatten all tokens, then a 2048 -> 1024 MLP and a final linear classifier.
        self.mlp_head = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            MLP(in_features=dim * self.num_patches, hidden_units=[2048, 1024], dropout_rate=0.5),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        """Map a batch of images to class logits of shape (batch, num_classes)."""
        patches = self.patches(x)
        x = self.patch_encoder(patches)

        for transformer_block in self.transformer:
            x = transformer_block(x)

        x = self.layernorm(x)
        logits = self.mlp_head(x)
        return logits

# Build the network and move it to the selected device, then create the
# AdamW optimizer over its parameters and the cross-entropy loss.
model = ViT()
model = model.to(device)
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

# Functions to calculate accuracy
def calculate_accuracy(outputs, labels):
    """Return top-1 accuracy as a 0-dim float tensor.

    ``outputs`` holds one row of class scores per sample (argmax taken over
    dim 1); ``labels`` holds the target class index for each row.
    """
    predicted = torch.max(outputs, 1)[1]
    num_correct = (predicted == labels).sum().float()
    return num_correct / labels.size(0)

def calculate_top5_accuracy(outputs, labels, k=5):
    """Return top-k accuracy (k=5 by default) as a 0-dim float tensor.

    A sample counts as correct when its label is among the ``k`` highest-scoring
    classes in its row of ``outputs``. ``k`` is clamped to the number of classes,
    so models with fewer than ``k`` outputs no longer make ``topk`` raise.

    Args:
        outputs: per-sample class scores, one row per sample.
        labels: target class index for each row.
        k: how many top-scoring classes count as a hit.
    """
    k = min(k, outputs.size(1))
    _, top_classes = outputs.topk(k, dim=1)
    # Broadcast each label across its row's k predictions. topk indices within a
    # row are distinct, so each row contributes at most one hit.
    hits = top_classes.eq(labels.view(-1, 1)).sum()
    return hits.float() / labels.size(0)

# # Training and validation loop
# best_val_accuracy = 0.0
# for epoch in range(num_epochs):
#     model.train()
#     total_loss = 0
#     total_accuracy = 0
#     for images, labels in train_loader:
#         images = images.to(device)
#         labels = labels.to(device)
#
#         optimizer.zero_grad()
#         outputs = model(images)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()
#
#         total_loss += loss.item() * images.size(0)
#         total_accuracy += calculate_accuracy(outputs, labels).item() * images.size(0)
#
#     avg_loss = total_loss / len(train_loader.sampler)
#     avg_accuracy = total_accuracy / len(train_loader.sampler)
#
#     # Validation
#     model.eval()
#     val_total_loss = 0
#     val_total_accuracy = 0
#     val_total_top5_accuracy = 0
#     with torch.no_grad():
#         for images, labels in valid_loader:
#             images = images.to(device)
#             labels = labels.to(device)
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             val_total_loss += loss.item() * images.size(0)
#             val_total_accuracy += calculate_accuracy(outputs, labels).item() * images.size(0)
#             val_total_top5_accuracy += calculate_top5_accuracy(outputs, labels).item() * images.size(0)
#
#     val_avg_loss = val_total_loss / len(valid_loader.sampler)
#     val_avg_accuracy = val_total_accuracy / len(valid_loader.sampler)
#     val_avg_top5_accuracy = val_total_top5_accuracy / len(valid_loader.sampler)
#
#     print(f'Epoch [{epoch+1}/{num_epochs}], '
#           f'Train Loss: {avg_loss:.4f}, Train Acc: {avg_accuracy*100:.2f}%, '
#           f'Val Loss: {val_avg_loss:.4f}, Val Acc: {val_avg_accuracy*100:.2f}%, '
#           f'Val Top-5 Acc: {val_avg_top5_accuracy*100:.2f}%')
#
#     # Save the best model
#     if val_avg_accuracy > best_val_accuracy:
#         best_val_accuracy = val_avg_accuracy
#         torch.save(model.state_dict(), './checkpoint.pth')
#
# # Load the best model and evaluate on the test set
# model.load_state_dict(torch.load('./checkpoint.pth'))
# model.eval()
# test_total_loss = 0
# test_total_accuracy = 0
# test_total_top5_accuracy = 0
# with torch.no_grad():
#     for images, labels in test_loader:
#         images = images.to(device)
#         labels = labels.to(device)
#         outputs = model(images)
#         loss = criterion(outputs, labels)
#         test_total_loss += loss.item() * images.size(0)
#         test_total_accuracy += calculate_accuracy(outputs, labels).item() * images.size(0)
#         test_total_top5_accuracy += calculate_top5_accuracy(outputs, labels).item() * images.size(0)
#
# test_avg_loss = test_total_loss / len(test_loader.dataset)
# test_avg_accuracy = test_total_accuracy / len(test_loader.dataset)
# test_avg_top5_accuracy = test_total_top5_accuracy / len(test_loader.dataset)
#
# print(f'Test Loss: {test_avg_loss:.4f}, '
#       f'Test Accuracy: {test_avg_accuracy*100:.2f}%, '
#       f'Test Top-5 Accuracy: {test_avg_top5_accuracy*100:.2f}%')
#
# # Save the final model
# torch.save(model.state_dict(), './vit_mnist.pth')

# # Model summary (number of parameters)
# from torchsummary import summary
# summary(model, input_size=(1, 28, 28))
# Load the saved weights and evaluate on the test set, timing each forward pass.
import time

# map_location remaps tensors onto the active device, so a checkpoint saved on
# a GPU can be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load('./vit_mnist.pth', map_location=device))
model.eval()
test_total_loss = 0
test_total_accuracy = 0
test_total_top5_accuracy = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        # CUDA kernels are launched asynchronously. Synchronize before reading
        # each clock so the interval covers the forward pass itself, not just
        # its launch.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        outputs = model(images)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        print(t1 - t0)
        loss = criterion(outputs, labels)
        # Weight per-batch means by batch size so the final averages are per-sample.
        test_total_loss += loss.item() * images.size(0)
        test_total_accuracy += calculate_accuracy(outputs, labels).item() * images.size(0)
        test_total_top5_accuracy += calculate_top5_accuracy(outputs, labels).item() * images.size(0)

test_avg_loss = test_total_loss / len(test_loader.dataset)
test_avg_accuracy = test_total_accuracy / len(test_loader.dataset)
test_avg_top5_accuracy = test_total_top5_accuracy / len(test_loader.dataset)

print(f'Test Loss: {test_avg_loss:.4f}, '
      f'Test Accuracy: {test_avg_accuracy*100:.2f}%, '
      f'Test Top-5 Accuracy: {test_avg_top5_accuracy*100:.2f}%')