import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
from sklearn.metrics import mutual_info_score, normalized_mutual_info_score, adjusted_mutual_info_score
from pytorch_msssim import ssim
import matplotlib.pyplot as plt
import time
import random

# Images are only converted to tensors (ToTensor scales pixels to [0, 1]); no
# mean/std normalisation is applied. The SSIM computations pass data_range=1,
# which assumes exactly this pixel range.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset, test_dataset = (
    datasets.MNIST(root='../data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

batch_size = 64
# Only the training data is shuffled; the test order is fixed across runs.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Model definitions: a loader for the pretrained MLP classifier, plus builders for the convolutional encoder and the small MLP decoder.



def load_classifier(path, map_location='cpu'):
    """Build the 784-500-300-100-10 MLP classifier and load its weights from *path*.

    Args:
        path: file containing a state dict saved with ``torch.save``.
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved from a GPU run can still be loaded on a CPU-only
            machine; the model itself is created on the CPU either way, so this
            does not change where the returned module lives.

    Returns:
        The MLP (on the CPU) with the loaded parameters.
    """
    class MLP(nn.Module):
        def __init__(self):
            super(MLP, self).__init__()
            self.fc1 = nn.Linear(784, 500)
            self.fc2 = nn.Linear(500, 300)
            self.fc3 = nn.Linear(300, 100)
            self.fc4 = nn.Linear(100, 10)

        def forward(self, x):
            # Any input with 784 elements per sample (e.g. N x 1 x 28 x 28) is
            # flattened to (N, 784).
            x = x.view(-1, 784)
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            x = torch.relu(self.fc3(x))
            # Raw class scores (no softmax).
            return self.fc4(x)

    classifier = MLP()
    state_dict = torch.load(path, map_location=map_location)
    classifier.load_state_dict(state_dict)
    return classifier


def classifier_validation(classifier, test_loader, device):
    """Print and return the top-1 accuracy (in percent) of *classifier* on *test_loader*.

    The classifier is moved to *device* and switched to eval mode.

    Args:
        classifier: module mapping a batch of inputs to per-class scores.
        test_loader: iterable of ``(data, target)`` batches.
        device: device the batches are moved to.

    Returns:
        Accuracy in percent, or 0.0 if the loader yields no samples.
    """
    classifier = classifier.to(device)
    classifier.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            predicted = classifier(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    if total == 0:
        # Avoid ZeroDivisionError on an empty loader.
        print('Classifier Accuracy on the test set: no samples')
        return 0.0

    accuracy = 100 * correct / total
    print(f'Classifier Accuracy on the test set: {accuracy:.2f}%')
    return accuracy


def build_encoder():
    """Return a freshly initialised encoder mapping 1x28x28 images to 1x28x28 images.

    Three conv(3x3, padding 1) + ReLU + 2x2 max-pool stages shrink the spatial
    size 28 -> 14 -> 7 -> 3 (hence 128 * 3 * 3 features). One linear layer then
    expands to 784 values, reshaped to an image. There is no output activation,
    so values are not restricted to [0, 1].
    """
    class Encoder(nn.Module):
        def __init__(self):
            super().__init__()
            # Creation order (and therefore init RNG use and state_dict keys)
            # must stay conv1, conv2, conv3, pool, fc1 for saved checkpoints.
            self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
            self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
            self.fc1 = nn.Linear(128 * 3 * 3, 784)

        def forward(self, x):
            for conv in (self.conv1, self.conv2, self.conv3):
                x = self.pool(torch.relu(conv(x)))
            # view(-1, ...) also accepts a single unbatched (1, 28, 28) image,
            # which test_ssim relies on.
            features = x.view(-1, 128 * 3 * 3)
            return self.fc1(features).view(-1, 1, 28, 28)

    return Encoder()

def build_decoder():
    """Return a freshly initialised 10 -> 64 -> 64 -> 10 MLP with ReLU hidden layers.

    It is applied to the classifier's 10 class scores and outputs 10 new class
    scores. The attribute names fc1/fc2/fc3 define the state_dict keys used by
    the saved checkpoints, so they must not change.
    """
    class Decoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 64)
            self.fc2 = nn.Linear(64, 64)
            self.fc3 = nn.Linear(64, 10)

        def forward(self, x):
            for hidden in (self.fc1, self.fc2):
                x = torch.relu(hidden(x))
            return self.fc3(x)

    return Decoder()

# Pretrained target classifier (loaded from disk) and freshly initialised
# encoder/decoder; build_encoder() runs before build_decoder(), as before.
classifier = load_classifier("./mnist_classifier.pth")
encoder, decoder = build_encoder(), build_decoder()

class Combined(nn.Module):
    """Pipeline ``decoder(classifier(encoder(x)))`` with the classifier frozen.

    At construction every parameter of *classifier* gets ``requires_grad=False``,
    so during training only the encoder's and decoder's parameters receive
    gradients. Gradients still flow *through* the classifier to the encoder.
    """

    def __init__(self, encoder, classifier, decoder):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier
        self.decoder = decoder
        self.classifier.requires_grad_(False)

    def forward(self, x):
        encoded = self.encoder(x)
        class_scores = self.classifier(encoded)
        return self.decoder(class_scores)

def ssim_loss(img1, img2):
    """Return the batch-averaged SSIM between *img1* and *img2* (pixel range [0, 1]).

    Despite the name this is the similarity itself, not ``1 - SSIM``: higher
    means more alike. train() minimises the square of this value, which pushes
    the encoder's output to be structurally *dissimilar* to its input.
    """
    similarity = ssim(img1, img2, data_range=1, size_average=True)
    return similarity



# Full encoder -> classifier -> decoder pipeline; Combined freezes the classifier.
# .to(device) moves the shared encoder/classifier/decoder modules in place.
model = Combined(encoder, classifier, decoder).to(device)

# Cross-entropy on the combined model's class scores.
criterion = nn.CrossEntropyLoss()

learning_rate = 1e-3
# NOTE(review): model.parameters() already contains the encoder's parameters
# (model.encoder is `encoder`), so the encoder is registered in BOTH optimizers
# and train() steps both of them each batch -- two Adam updates per step for
# the encoder. Confirm this is intended before changing it.
optimizer_all = optim.Adam(model.parameters(), lr=learning_rate)
optimizer_encoder = optim.Adam(encoder.parameters(), lr=learning_rate)

epochs = 10

def train():
    """Train the encoder and decoder for ``epochs`` epochs, then save their weights.

    Two losses are combined for every batch:

    * ``CE_loss`` -- cross-entropy between the combined model's output and the
      labels the frozen classifier predicts on the *clean* images (the dataset
      labels are not used);
    * ``loss_encoder`` -- ``0.001 * SSIM(encoder(data), data) ** 2``; minimising
      it pushes the encoder's output to be structurally dissimilar to its input.

    The encoder/decoder state dicts are written to 'mnist_encoder_ssim_2.pth'
    and 'mnist_decoder_ssim_2.pth' after the final epoch.
    """
    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, _) in enumerate(train_loader):
            optimizer_all.zero_grad()
            optimizer_encoder.zero_grad()

            data = data.to(device)

            # Target labels are the frozen classifier's own predictions. No
            # gradient flows through an argmax, so skip building a graph here.
            with torch.no_grad():
                classifier_predict_output = classifier(data).argmax(dim=1)

            output = model(data)
            image_generated = encoder(data)

            CE_loss = criterion(output, classifier_predict_output)
            loss_ssim = ssim_loss(image_generated, data)
            loss_encoder = 0.001 * loss_ssim * loss_ssim

            # Gradients add up, so one backward pass on the sum accumulates the
            # same .grad values as two separate backward calls, and no graph has
            # to be retained.
            (CE_loss + loss_encoder).backward()

            # NOTE(review): the encoder's parameters are in both optimizers
            # (model.parameters() includes them), so each batch applies two Adam
            # updates to the encoder from the same gradient. This is kept to
            # preserve the existing training behaviour; consider a single optimizer.
            optimizer_encoder.step()
            optimizer_all.step()

            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(train_loader)}], CE_Loss: {CE_loss.item():.4f}, SSIM_Loss: {loss_ssim.item():.4f}')

    torch.save(encoder.state_dict(), 'mnist_encoder_ssim_2.pth')
    torch.save(decoder.state_dict(), 'mnist_decoder_ssim_2.pth')

def test_combined():
    """Print the combined pipeline's accuracy and fidelity on ``test_loader``.

    Accuracy is the share of test images where the top-1 class of model(data)
    matches the ground-truth label. Fidelity is the share where it matches the
    frozen classifier's top-1 prediction on the clean image.
    """
    model.eval()
    n_seen = 0
    n_match_label = 0
    n_match_classifier = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            combined_pred = torch.max(model(images).data, 1)[1]
            classifier_pred = torch.max(classifier(images).data, 1)[1]

            n_seen += labels.size(0)
            n_match_label += (combined_pred == labels).sum().item()
            n_match_classifier += (combined_pred == classifier_pred).sum().item()

    print(f'Combined Accuracy on the test set: {100 * n_match_label / n_seen:.2f}%')
    print(f'Combined Fidelity on the test set: {100 * n_match_classifier / n_seen:.2f}%')

def test_ssim():
    """Print the mean SSIM and mean squared SSIM between test images and their encodings.

    Every sample of ``test_dataset`` is encoded one image at a time and compared
    with the original via ``ssim(..., data_range=1)``.
    """
    encoder.eval()
    count = 0
    ssim_total = 0.0
    ssim_sq_total = 0.0
    # no_grad is essential: otherwise each SSIM value carries a grad_fn, and
    # accumulating them keeps the autograd graph of every test image alive.
    with torch.no_grad():
        for image, _ in test_dataset:
            image = image.to(device).reshape(1, 1, 28, 28)
            ob_data = encoder(image).reshape(1, 1, 28, 28)
            # .item() turns the value into a plain float for accumulation and printing.
            value = ssim(image, ob_data, data_range=1, size_average=True).item()
            count += 1
            ssim_total += value
            ssim_sq_total += value * value

    if count == 0:
        print("avg SSIM: no samples")
        return
    print(f"avg SSIM: {ssim_total / count}")
    print(f"avg SSIM square: {ssim_sq_total / count}")

def test_output_mi(max_batches=None):
    """Print the Pearson correlation (and its square) between two label predictions.

    For each test batch the function compares:

    * the combined prediction: argmax of ``model(data)``
      (encoder -> classifier -> decoder), and
    * the obfuscated prediction: argmax of ``classifier(encoder(data))``, i.e.
      the raw classifier applied to the encoded image without the decoder.

    Args:
        max_batches: evaluate at most this many batches of ``test_loader``.
            ``None`` (default) uses the whole test set. Pass 156 to reproduce an
            evaluation restricted to the first 156 full batches.
    """
    model.eval()
    encoder.eval()
    true_outputs = []
    ob_outputs = []
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(test_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            data = data.to(device)
            true_outputs.append(model(data).argmax(dim=1).cpu())
            ob_outputs.append(classifier(encoder(data)).argmax(dim=1).cpu())

    if not true_outputs:
        print("no batches evaluated")
        return

    # Concatenate into flat 1-D arrays. np.corrcoef treats each *row* of a 2-D
    # input as a separate variable, so a (batches, batch_size) array would make
    # [0, 1] the correlation between two batches of the same series.
    true = torch.cat(true_outputs).numpy()
    ob = torch.cat(ob_outputs).numpy()

    corr = np.corrcoef(true, ob)[0, 1]
    print(corr)
    print(corr * corr)

def test_generation(num_samples=10):
    """Plot the encoder's output for the first *num_samples* images of the first test batch.

    Each image is min-max scaled independently for display. The figure is
    saved to './sampled_images.png' and './sampled_images.pdf' and then shown.

    Args:
        num_samples: number of images to plot. It is clamped to the batch size
            and must be at least 1.

    Raises:
        ValueError: if *num_samples* < 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    model.eval()
    with torch.no_grad():
        # Only the first batch is needed; take it directly rather than looping
        # and terminating the process with exit().
        data, _ = next(iter(test_loader))
        ob_data = encoder(data.to(device)).cpu()

    num_samples = min(num_samples, ob_data.size(0))
    # squeeze=False keeps `axes` 2-D, so indexing also works for a single image.
    fig, axes = plt.subplots(1, num_samples, figsize=(num_samples * 2, 2), squeeze=False)

    for ax, img in zip(axes[0], ob_data[:num_samples]):
        img = img.squeeze(0)  # (1, 28, 28) -> (28, 28) for imshow
        lo, hi = img.min().item(), img.max().item()
        # Min-max scale to [0, 1]; a constant image would otherwise divide by zero.
        img = (img - lo) / (hi - lo) if hi > lo else torch.zeros_like(img)
        ax.imshow(img.numpy(), cmap='Greys', interpolation='nearest')
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('./sampled_images.png')
    plt.savefig('./sampled_images.pdf')
    plt.show()
    plt.close(fig)

# To retrain from scratch (and overwrite the checkpoints), uncomment train().
# train()

# The checkpoints are written by train() from `device`, which may be a GPU.
# map_location=device lets them load on whatever device this run uses,
# including CPU-only machines.
encoder.load_state_dict(torch.load("./mnist_encoder_ssim_2.pth", map_location=device))
decoder.load_state_dict(torch.load("./mnist_decoder_ssim_2.pth", map_location=device))

# Evaluation entry points; enable the ones you need.
# classifier_validation(classifier, test_loader, device)
# test_combined()
# test_ssim()
# test_output_mi()
test_generation()

