import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
from sklearn.metrics import mutual_info_score, normalized_mutual_info_score, adjusted_mutual_info_score
from pytorch_msssim import ssim
import matplotlib.pyplot as plt
import time

# Input pipeline: images are only converted to tensors. The variant with
# Normalize below is kept disabled; the SSIM calls later in this file are
# made with data_range=1.
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
transform = transforms.Compose([transforms.ToTensor()])
# MNIST train/test splits stored under ../data (download=True is requested for both).
train_dataset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='../data', train=False, download=True, transform=transform)

# Mini-batch loaders: training batches are shuffled, test batches keep dataset order.
batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Model builders: a pretrained Vision Transformer classifier, a convolutional encoder and an MLP decoder



def load_classifier(path):
    """Build the ViT MNIST classifier and load its pretrained weights.

    The architecture (attribute names, layer order, layer sizes) defines the
    ``state_dict`` keys and tensor shapes, so it must stay in sync with the
    checkpoint stored at ``path``.

    Args:
        path: Location of a ``state_dict`` checkpoint saved with ``torch.save``.

    Returns:
        The classifier, moved to the module-level ``device`` and switched to
        evaluation mode (dropout disabled). Call ``.train()`` on the result if
        training-mode behaviour is explicitly wanted.
    """

    class MLP(nn.Module):
        """Stack of ``Linear -> GELU -> Dropout`` triples, one per hidden size."""

        def __init__(self, in_features, hidden_units, dropout_rate):
            super().__init__()
            layers = []
            for units in hidden_units:
                layers.append(nn.Linear(in_features, units))
                layers.append(nn.GELU())
                layers.append(nn.Dropout(dropout_rate))
                in_features = units
            self.mlp = nn.Sequential(*layers)

        def forward(self, x):
            return self.mlp(x)

    class Patches(nn.Module):
        """Cut an image batch into non-overlapping square patches and flatten each."""

        def __init__(self, patch_size):
            super().__init__()
            self.patch_size = patch_size

        def forward(self, images):
            batch_size, channels, height, width = images.size()
            patch_size = self.patch_size
            # Two unfolds tile the height and width axes:
            # (B, C, H/p, W/p, p, p).
            patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
            # Merge the two tile-index axes into a single patch axis.
            patches = patches.contiguous().view(batch_size, channels, -1, patch_size, patch_size)
            # Put the patch axis before channels so each patch can be flattened.
            patches = patches.permute(0, 2, 1, 3, 4)
            patch_dim = channels * patch_size * patch_size
            return patches.contiguous().view(batch_size, -1, patch_dim)

    class PatchEncoder(nn.Module):
        """Linear patch projection plus a learned position embedding."""

        def __init__(self, num_patches, projection_dim, patch_dim):
            super().__init__()
            self.num_patches = num_patches
            self.projection = nn.Linear(patch_dim, projection_dim)
            self.position_embedding = nn.Embedding(num_patches, projection_dim)

        def forward(self, patches):
            # Shape (1, num_patches): the embedding broadcasts over the batch.
            positions = torch.arange(0, self.num_patches, device=patches.device).unsqueeze(0)
            return self.projection(patches) + self.position_embedding(positions)

    class TransformerBlock(nn.Module):
        """Pre-norm block: self-attention and an MLP, each with a residual add."""

        def __init__(self, dim, num_heads, mlp_dim, dropout_rate):
            super().__init__()
            self.layernorm1 = nn.LayerNorm(dim, eps=1e-6)
            self.mha = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout_rate)
            self.layernorm2 = nn.LayerNorm(dim, eps=1e-6)
            self.mlp = nn.Sequential(
                nn.Linear(dim, mlp_dim),
                nn.GELU(),
                nn.Dropout(dropout_rate),
                nn.Linear(mlp_dim, dim),
                nn.Dropout(dropout_rate),
            )
            self.dropout = nn.Dropout(dropout_rate)

        def forward(self, x):
            x_norm = self.layernorm1(x)
            # MultiheadAttention is built without batch_first, so it expects
            # (seq_len, batch_size, embed_dim).
            x_norm = x_norm.permute(1, 0, 2)
            attn_output, _ = self.mha(x_norm, x_norm, x_norm)
            attn_output = attn_output.permute(1, 0, 2)
            x = x + self.dropout(attn_output)
            x_norm = self.layernorm2(x)
            x = x + self.mlp(x_norm)
            return x

    class ViT(nn.Module):
        """Vision Transformer for single-channel images with an MLP head.

        NOTE: the ``mlp_dim`` constructor argument is accepted but not used;
        every transformer block is built with ``mlp_dim=dim * 2``. This is
        left as is because the checkpoint's tensor shapes depend on it.
        """

        def __init__(self, image_size=28, patch_size=14, num_classes=10, dim=96,
                     depth=16, heads=4, mlp_dim=2048, dropout_rate=0.1):
            super().__init__()
            assert image_size % patch_size == 0, "Image size must be divisible by patch size."

            self.num_patches = (image_size // patch_size) ** 2
            patch_dim = 1 * patch_size * patch_size  # Since MNIST images have 1 channel
            self.patches = Patches(patch_size)
            self.patch_encoder = PatchEncoder(self.num_patches, dim, patch_dim)

            self.transformer = nn.ModuleList([
                TransformerBlock(dim=dim, num_heads=heads, mlp_dim=dim*2, dropout_rate=dropout_rate)
                for _ in range(depth)
            ])

            self.layernorm = nn.LayerNorm(dim, eps=1e-6)
            # The head consumes all patch tokens flattened together.
            self.mlp_head = nn.Sequential(
                nn.Flatten(),
                nn.Dropout(0.5),
                MLP(in_features=dim * self.num_patches, hidden_units=[2048, 1024], dropout_rate=0.5),
                nn.Linear(1024, num_classes),
            )

        def forward(self, x):
            patches = self.patches(x)
            x = self.patch_encoder(patches)

            for transformer_block in self.transformer:
                x = transformer_block(x)

            x = self.layernorm(x)
            return self.mlp_head(x)

    classifier = ViT().to(device)
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # host (and vice versa) instead of failing during deserialization.
    classifier.load_state_dict(torch.load(path, map_location=device))
    # The weights are pretrained and used for inference: disable dropout so
    # predictions are deterministic.
    classifier.eval()

    return classifier


def classifier_validation(classifier, test_loader, device):
    """Print the top-1 accuracy of ``classifier`` over ``test_loader``.

    The model is evaluated in eval mode (dropout/batch-norm in inference
    behaviour) and its previous train/eval mode is restored afterwards.

    Args:
        classifier: Module mapping a batch of inputs to per-class scores.
        test_loader: Iterable yielding ``(data, target)`` batches.
        device: Device on which the model and the batches are placed.
    """
    classifier = classifier.to(device)
    was_training = classifier.training
    # Without eval(), dropout layers stay active and distort the accuracy.
    classifier.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                predicted = classifier(data).argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
    finally:
        # Leave the caller's model in the mode it was handed over in.
        classifier.train(was_training)

    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        print('Classifier Accuracy on the test set: n/a (no samples)')
        return

    print(f'Classifier Accuracy on the test set: {100 * correct / total:.2f}%')


def build_encoder():
    """Create the convolutional encoder that maps a 1x28x28 image to a 1x28x28 image.

    Three conv+ReLU+max-pool stages reduce the input to 128 feature maps of
    size 3x3; a single linear layer then produces 784 values that are
    reshaped back into image form.
    """
    class Encoder(nn.Module):
        def __init__(self):
            super().__init__()
            # 3x3 convolutions with padding 1 keep the spatial size.
            self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
            self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
            self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
            # 2x2 pooling halves height and width (28 -> 14 -> 7 -> 3).
            self.pool = nn.MaxPool2d(2)
            self.fc1 = nn.Linear(128 * 3 * 3, 784)

        def forward(self, x):
            for conv in (self.conv1, self.conv2, self.conv3):
                x = self.pool(torch.relu(conv(x)))
            flat = x.view(-1, 128 * 3 * 3)
            image = self.fc1(flat)
            return image.view(-1, 1, 28, 28)

    return Encoder()

def build_decoder():
    """Create the decoder: a 10 -> 64 -> 64 -> 10 MLP with ReLU between layers.

    It takes the 10 class scores produced by the classifier and returns 10
    new scores; no activation is applied to the final layer.
    """
    class Decoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 64)
            self.fc2 = nn.Linear(64, 64)
            self.fc3 = nn.Linear(64, 10)

        def forward(self, x):
            hidden = torch.relu(self.fc1(x))
            hidden = torch.relu(self.fc2(hidden))
            return self.fc3(hidden)

    return Decoder()

# Pretrained ViT classifier restored from a local checkpoint file.
classifier = load_classifier("./vit_mnist.pth")
# Optional sanity check of the restored classifier (disabled).
# classifier_validation(classifier, test_loader, device)

# Freshly initialised encoder and decoder; their trained weights are loaded
# from disk at the bottom of this file.
encoder = build_encoder()
decoder = build_decoder()

class Combined(nn.Module):
    """Pipeline ``encoder -> frozen classifier -> decoder``.

    The classifier's parameters are frozen and the classifier is kept in
    evaluation mode at all times, so it behaves as a fixed, deterministic
    function even while the surrounding encoder and decoder are trained.

    Args:
        encoder: Module applied to the raw input.
        classifier: Pretrained module applied to the encoder output; it is
            frozen (``requires_grad=False``) by this constructor.
        decoder: Module applied to the classifier output.
    """

    def __init__(self, encoder, classifier, decoder):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier
        self.decoder = decoder

        # Freeze the classifier's parameters
        for param in self.classifier.parameters():
            param.requires_grad = False
        self.classifier.eval()

    def train(self, mode=True):
        """Set train/eval mode on encoder and decoder; the classifier stays in eval.

        ``nn.Module.train`` recurses into every child, which would re-enable
        the frozen classifier's dropout layers and make its output random.
        """
        super().train(mode)
        self.classifier.eval()
        return self

    def forward(self, x):
        """Return ``decoder(classifier(encoder(x)))``."""
        x = self.encoder(x)
        x = self.classifier(x)
        x = self.decoder(x)
        return x

# SSIM term used by the training loop
def ssim_loss(img1, img2):
    """Return the batch-averaged SSIM of ``img1`` and ``img2`` (data_range=1).

    NOTE: the value returned is the SSIM itself, not ``1 - SSIM``.
    """
    similarity = ssim(img1, img2, data_range=1, size_average=True)
    return similarity



# Full pipeline (encoder -> frozen classifier -> decoder) on the chosen device.
model = Combined(encoder, classifier, decoder).to(device)

# Cross-entropy between the pipeline output and the classifier's own labels.
criterion = nn.CrossEntropyLoss()

learning_rate = 0.001
# NOTE(review): model.parameters() includes the encoder's parameters, so the
# encoder is registered in BOTH optimizers below and is stepped by each of
# them in train(). Confirm this double update is intended.
optimizer_all = optim.Adam(model.parameters(), lr=learning_rate)
optimizer_encoder = optim.Adam(encoder.parameters(), lr=learning_rate)

# Number of passes over the training set performed by train().
epochs = 10

def train():
    """Train the encoder and decoder around the frozen classifier, then save them.

    Two losses are back-propagated for every batch:

    * cross-entropy between the pipeline output ``model(data)`` and the label
      the classifier assigns to the raw image (dataset labels are not used);
    * ``0.001 * SSIM(encoder(data), data) ** 2``, which is minimised and thus
      pushes the SSIM between the encoded image and the input towards zero.

    Uses the module-level ``model``, ``encoder``, ``decoder``, ``classifier``,
    ``criterion``, both optimizers, ``train_loader``, ``device`` and
    ``epochs``. Writes the encoder and decoder ``state_dict`` files to the
    current directory when training finishes.

    NOTE(review): the encoder's parameters belong to both optimizers, so they
    are updated twice per batch from the same accumulated gradient. Kept as
    is; confirm it is intended.
    """
    for epoch in range(epochs):
        model.train()
        # model.train() also flips the frozen classifier into training mode.
        # Put it back in eval so dropout does not randomise its outputs and
        # the pseudo-labels derived from them.
        classifier.eval()
        for batch_idx, (data, _) in enumerate(train_loader):  # dataset labels are unused
            optimizer_all.zero_grad()
            optimizer_encoder.zero_grad()

            data = data.to(device)

            # Target labels: what the classifier predicts for the raw image.
            with torch.no_grad():
                classifier_predict_output = classifier(data).argmax(dim=1)

            output = model(data)
            CE_loss = criterion(output, classifier_predict_output)

            image_generated = encoder(data)
            loss_ssim = ssim_loss(image_generated, data)
            loss_encoder = 0.001 * loss_ssim * loss_ssim

            # The two losses come from separate forward passes, so their
            # graphs are independent and retain_graph is not needed.
            # Gradients of shared parameters accumulate across both calls.
            loss_encoder.backward()
            CE_loss.backward()

            optimizer_encoder.step()
            optimizer_all.step()

            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(train_loader)}], CE_Loss: {CE_loss.item():.4f}, SSIM_Loss: {loss_ssim.item():.4f}')

    torch.save(encoder.state_dict(), 'mnist_encoder_ssim_2.pth')
    torch.save(decoder.state_dict(), 'mnist_decoder_ssim_2.pth')

def test_combined():
    """Evaluate the full pipeline on the test set and print two percentages.

    * Accuracy: how often the pipeline's prediction equals the dataset label.
    * Fidelity: how often the pipeline's prediction equals the prediction the
      classifier makes on the raw image.

    Uses the module-level ``model``, ``classifier``, ``test_loader`` and
    ``device``.
    """
    model.eval()
    seen = 0
    label_matches = 0
    classifier_matches = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)

            # Prediction of encoder -> classifier -> decoder.
            predicted = model(data).argmax(dim=1)
            # Prediction of the classifier on the unmodified image.
            reference = classifier(data).argmax(dim=1)

            seen += target.size(0)
            label_matches += (predicted == target).sum().item()
            classifier_matches += (predicted == reference).sum().item()

    print(f'Combined Accuracy on the test set: {100 * label_matches / seen:.2f}%')
    print(f'Combined Fidelity on the test set: {100 * classifier_matches / seen:.2f}%')

def test_ssim():
    """Print the mean SSIM, and mean squared SSIM, between each test image and its encoding.

    Every image of the module-level ``test_dataset`` is passed through the
    module-level ``encoder`` one at a time and compared with the original
    using SSIM with ``data_range=1``.
    """
    encoder.eval()
    count = 0
    ssim_sum = 0.0
    ssim_sq_sum = 0.0
    # Inference only: without no_grad every forward pass builds an autograd
    # graph, and summing the resulting tensors would keep all of them alive.
    with torch.no_grad():
        for sample in test_dataset:
            # Add an explicit batch axis so the encoder gets a regular
            # (1, 1, 28, 28) batch.
            image = sample[0].to(device).reshape(1, 1, 28, 28)
            ob_data = encoder(image)

            # .item() yields a plain float, so the running sums hold no tensors.
            value = ssim(image, ob_data.reshape(1, 1, 28, 28), data_range=1, size_average=True).item()
            count += 1
            ssim_sum += value
            ssim_sq_sum += value * value

    if count == 0:
        # Nothing to average; avoid dividing by zero.
        print("avg SSIM: n/a (empty dataset)")
        return

    print(f"avg SSIM: {ssim_sum/float(count)}")
    print(f"avg SSIM square: {ssim_sq_sum/float(count)}")

def test_output_mi():
    """Print the Pearson correlation (r, then r squared) between two label sequences.

    For every test image two predicted class indices are collected:

    * ``true``: argmax of the full pipeline ``model(data)``;
    * ``ob``: argmax of ``classifier(encoder(data))``, i.e. the classifier's
      output on the encoded image before the decoder is applied.

    Both sequences are flattened over the whole test set before the
    correlation is taken. Uses the module-level ``model``, ``encoder``,
    ``classifier``, ``test_loader`` and ``device``.
    """
    model.eval()
    encoder.eval()
    true_outputs = []
    ob_outputs = []
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)

            true_outputs.append(model(data).argmax(dim=1).cpu())
            ob_outputs.append(classifier(encoder(data)).argmax(dim=1).cpu())

    if not true_outputs:
        print("correlation: n/a (empty test loader)")
        return

    # Concatenate into 1-D vectors. Passing 2-D (batches x batch_size) arrays
    # to np.corrcoef would treat every batch as a separate variable, and
    # element [0, 1] would then compare the first two batches of `true` with
    # each other instead of comparing `true` with `ob`. Concatenation also
    # handles a smaller final batch, so no batch-count cutoff is needed.
    true = torch.cat(true_outputs).numpy()
    ob = torch.cat(ob_outputs).numpy()

    correlation = np.corrcoef(true, ob)[0, 1]
    print(correlation)
    print(correlation * correlation)
#
# train()
#
# Restore the trained encoder/decoder weights written by train().
# map_location lets checkpoints saved on a GPU load on a CPU-only host.
encoder.load_state_dict(torch.load("./mnist_encoder_ssim_2.pth", map_location=device))
decoder.load_state_dict(torch.load("./mnist_decoder_ssim_2.pth", map_location=device))

# Run the evaluations; the standalone classifier check stays disabled.
# classifier_validation(classifier, test_loader, device)
test_combined()
test_ssim()
test_output_mi()


