import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
from sklearn.metrics import mutual_info_score, normalized_mutual_info_score, adjusted_mutual_info_score

import matplotlib.pyplot as plt
import random

import warnings

warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

# 1. Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import argparse
from pytorch_msssim import ssim
import time
# # transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# transform = transforms.Compose([transforms.ToTensor()])
# train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
#
# batch_size = 64
# train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# CIFAR-10 preprocessing. The training pipeline adds augmentation (random
# 32x32 crop after 4-pixel zero padding, random horizontal flip); both
# pipelines convert to tensors and normalise each of the 3 channels with the
# mean/std tuples below.
# NOTE(review): after Normalize, pixel values are no longer confined to
# [0, 1] — check any metric here that assumes data_range=1 (e.g. SSIM).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test-time preprocessing: no augmentation, same normalisation as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Datasets are read from a fixed local directory and never downloaded
# (download=False), so the files must already exist at this path.
trainset = torchvision.datasets.CIFAR10(root='/home/gan/cifar_gan/data-cifar10', train=True, download=False, transform=transform_train)
train_loader = DataLoader(trainset, batch_size=100, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='/home/gan/cifar_gan/data-cifar10', train=False, download=False, transform=transform_test)
# Unshuffled so the evaluation order is the same on every pass.
test_loader = DataLoader(testset, batch_size=100, shuffle=False)


# Prefer the GPU when one is available; everything below moves data/models here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Model definitions: the convolutional CIFAR-10 classifier, the encoder variants, and the small decoder.



def load_classifier(path):
    """Build the CIFAR-10 ``ImprovedCNN`` classifier and load its weights.

    Architecture: four stages of conv(3x3, padding 1) -> BatchNorm -> ReLU ->
    2x2 max-pool (channels 3 -> 64 -> 128 -> 256 -> 512), then fully-connected
    layers 512*2*2 -> 1024 -> 512 -> 10 with Dropout(0.5) after the first two.

    Args:
        path: file holding a state dict for exactly this architecture.

    Returns:
        The classifier, on CPU and in the default (train) mode; callers move it
        to a device and switch to eval mode as needed.
    """

    class ImprovedCNN(nn.Module):
        def __init__(self):
            super(ImprovedCNN, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
            self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.dropout = nn.Dropout(0.5)
            self.fc1 = nn.Linear(512 * 2 * 2, 1024)
            self.fc2 = nn.Linear(1024, 512)
            self.fc3 = nn.Linear(512, 10)
            self.bn1 = nn.BatchNorm2d(64)
            self.bn2 = nn.BatchNorm2d(128)
            self.bn3 = nn.BatchNorm2d(256)
            self.bn4 = nn.BatchNorm2d(512)

        def forward(self, x):
            x = self.pool(F.relu(self.bn1(self.conv1(x))))
            x = self.pool(F.relu(self.bn2(self.conv2(x))))
            x = self.pool(F.relu(self.bn3(self.conv3(x))))
            x = self.pool(F.relu(self.bn4(self.conv4(x))))
            # Four 2x2 pools reduce a 32x32 input to 2x2.
            x = x.view(-1, 512 * 2 * 2)
            x = self.dropout(F.relu(self.fc1(x)))
            x = self.dropout(F.relu(self.fc2(x)))
            x = self.fc3(x)
            return x

    classifier = ImprovedCNN()
    # map_location='cpu' lets a checkpoint saved from CUDA tensors load on a
    # CPU-only machine (the script falls back to CPU when CUDA is missing).
    classifier.load_state_dict(torch.load(path, map_location='cpu'))

    return classifier

def load_encoder_from_classifier(path):
    """Build an image-to-image encoder initialised from a pretrained classifier.

    The encoder reuses the ``ImprovedCNN`` convolutional trunk (conv1-4 and
    bn1-4), whose weights are loaded from the state dict at ``path``. Every
    entry whose key contains ``'fc'`` is ignored. The fully-connected head is
    then replaced with freshly initialised layers
    512*2*2 -> 2048 -> 2048 -> 3*32*32, and the output is reshaped to
    ``(N, 3, 32, 32)``. The conv weights are frozen (``requires_grad=False``).
    The BatchNorm parameters and the new head stay trainable.

    Args:
        path: file holding the pretrained classifier's state dict.

    Returns:
        The encoder, on CPU and in the default (train) mode.
    """

    class ImprovedCNN(nn.Module):
        def __init__(self):
            super(ImprovedCNN, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
            self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.dropout = nn.Dropout(0.5)
            self.fc1 = nn.Linear(512 * 2 * 2, 1024)
            self.fc2 = nn.Linear(1024, 512)
            self.fc3 = nn.Linear(512, 10)
            self.bn1 = nn.BatchNorm2d(64)
            self.bn2 = nn.BatchNorm2d(128)
            self.bn3 = nn.BatchNorm2d(256)
            self.bn4 = nn.BatchNorm2d(512)

        def forward(self, x):
            x = self.pool(F.relu(self.bn1(self.conv1(x))))
            x = self.pool(F.relu(self.bn2(self.conv2(x))))
            x = self.pool(F.relu(self.bn3(self.conv3(x))))
            x = self.pool(F.relu(self.bn4(self.conv4(x))))
            x = x.view(-1, 512 * 2 * 2)
            # Dropout is deliberately not applied in the encoder head.
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            # fc3 is replaced below to emit 3*32*32 values, i.e. one image.
            return x.view(-1, 3, 32, 32)

    encoder = ImprovedCNN()

    # Load only the pretrained trunk. The classifier's fc head has a different
    # purpose and is replaced below, so its weights are skipped. Use the
    # caller's `path` (not a module global) and load onto CPU so CUDA-saved
    # checkpoints also work without a GPU.
    pretrained_dict = torch.load(path, map_location='cpu')
    model_dict = encoder.state_dict()
    model_dict.update({k: v for k, v in pretrained_dict.items() if 'fc' not in k})
    encoder.load_state_dict(model_dict)

    # New head: regress a full 3x32x32 image instead of 10 class scores.
    encoder.fc1 = nn.Linear(512 * 2 * 2, 2048)
    encoder.fc2 = nn.Linear(2048, 2048)
    encoder.fc3 = nn.Linear(2048, 32 * 32 * 3)

    # Freeze the pretrained convolutions; bn1-4 and the new head stay trainable.
    for conv in (encoder.conv1, encoder.conv2, encoder.conv3, encoder.conv4):
        for param in conv.parameters():
            param.requires_grad = False

    return encoder

def classifier_validation(classifier, test_loader, device):
    """Print and return the top-1 accuracy of ``classifier`` on ``test_loader``.

    The classifier is moved to ``device`` and run in eval mode under
    ``torch.no_grad()``, so Dropout is off and BatchNorm uses its running
    statistics (train mode would also overwrite them). The classifier's
    previous train/eval mode is restored afterwards.

    Args:
        classifier: module mapping a batch of inputs to per-class scores; the
            prediction is the argmax over dim 1.
        test_loader: iterable of ``(data, target)`` batches.
        device: device to evaluate on.

    Returns:
        Accuracy in percent (0.0 when the loader yields no samples).
    """
    classifier = classifier.to(device)
    was_training = classifier.training
    classifier.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                predicted = classifier(data).argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
    finally:
        classifier.train(was_training)

    accuracy = 100 * correct / total if total else 0.0
    print(f'Classifier Accuracy on the test set: {accuracy:.2f}%')
    return accuracy


def build_encoder():
    """Create a randomly initialised image-to-image encoder.

    The network applies four conv(3x3, padding 1) -> ReLU -> 2x2 max-pool
    stages (channels 3 -> 64 -> 128 -> 256 -> 512). It then flattens to
    512*2*2 features, runs a fully-connected stack, and reshapes the
    3*32*32 output to ``(N, 3, 32, 32)``. The layer sizes assume 32x32
    inputs: four 2x2 pools take 32 down to 2.

    Returns:
        A fresh ``Encoder`` module (on CPU, train mode).
    """

    class Encoder(nn.Module):
        def __init__(self):
            super(Encoder, self).__init__()
            # Creation order is kept fixed: it decides the order in which the
            # layers draw from the RNG during initialisation.
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
            self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            # dropout and bn1-bn4 are registered (so they appear in the state
            # dict) but forward() does not use them.
            self.dropout = nn.Dropout(0.5)
            self.fc1 = nn.Linear(512 * 2 * 2, 4096)
            self.fc2 = nn.Linear(4096, 4096)
            self.fc3 = nn.Linear(4096, 2048)
            self.fc4 = nn.Linear(2048, 32 * 32 * 3)
            self.bn1 = nn.BatchNorm2d(64)
            self.bn2 = nn.BatchNorm2d(128)
            self.bn3 = nn.BatchNorm2d(256)
            self.bn4 = nn.BatchNorm2d(512)

        def forward(self, x):
            for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
                x = self.pool(F.relu(conv(x)))
            features = x.view(-1, 512 * 2 * 2)

            hidden = F.relu(self.fc1(features))
            # NOTE(review): fc2 is applied twice with shared weights; this may
            # be intentional weight sharing or a copy-paste — confirm.
            for _ in range(2):
                hidden = F.relu(self.fc2(hidden))
            hidden = F.relu(self.fc3(hidden))

            return self.fc4(hidden).view(-1, 3, 32, 32)

    return Encoder()

def build_decoder():
    """Create the small MLP decoder applied to the classifier's 10 outputs.

    Architecture: Linear(10, 30) -> ReLU -> Linear(30, 30) -> ReLU ->
    Linear(30, 10); the result is returned without a final activation.

    Returns:
        A fresh ``Decoder`` module (on CPU, train mode).
    """

    class Decoder(nn.Module):
        def __init__(self):
            super(Decoder, self).__init__()
            self.fc1 = nn.Linear(10, 30)
            self.fc2 = nn.Linear(30, 30)
            self.fc3 = nn.Linear(30, 10)

        def forward(self, x):
            hidden = torch.relu(self.fc2(torch.relu(self.fc1(x))))
            return self.fc3(hidden)

    return Decoder()

# State-dict checkpoint of the pretrained CIFAR-10 classifier.
MODEL_PATH = '/home/gan/diffusion/cifar_diffusion/cifar_classifier/cifar_cnn.pth'
# Load the classifier and report its baseline test accuracy.
classifier = load_classifier(MODEL_PATH)
classifier_validation(classifier, test_loader, device)




# Encoder initialised from the classifier's conv trunk; alternatively use the
# randomly initialised one:
# encoder = build_encoder()
encoder = load_encoder_from_classifier(MODEL_PATH)
decoder = build_decoder()

# exit()

class Combined(nn.Module):
    """Pipeline ``decoder(classifier(encoder(x)))`` around a frozen classifier.

    The classifier is frozen in two ways. Its parameters get
    ``requires_grad = False``. It is also kept in eval mode even when the
    pipeline is switched to train mode, so its Dropout stays off and its
    BatchNorm running statistics are not changed by training forward passes.
    Disabling gradients alone does neither.
    """

    def __init__(self, encoder, classifier, decoder):
        super(Combined, self).__init__()
        self.encoder = encoder
        self.classifier = classifier
        self.decoder = decoder

        # Freeze the classifier's parameters...
        for param in self.classifier.parameters():
            param.requires_grad = False
        # ...and its buffers / stochastic layers.
        self.classifier.eval()

    def train(self, mode=True):
        """Set train/eval mode on the pipeline, keeping the classifier in eval mode.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super(Combined, self).train(mode)
        self.classifier.eval()
        return self

    def forward(self, x):
        encoded = self.encoder(x)
        logits = self.classifier(encoded)
        return self.decoder(logits)


def ssim_loss(img1, img2):
    """Return the batch-averaged SSIM between ``img1`` and ``img2``.

    Despite the name, this is the similarity itself (higher = more alike),
    not ``1 - SSIM``. Minimising a function of it therefore pushes the two
    images apart. ``data_range=1`` is passed to ``ssim``.
    """
    similarity = ssim(img1, img2, size_average=True, data_range=1)
    return similarity

# Full pipeline encoder -> classifier -> decoder; Combined sets
# requires_grad=False on the classifier's parameters.
model = Combined(encoder, classifier, decoder).to(device)

# Cross-entropy used by train() against the classifier's predicted labels.
criterion = nn.CrossEntropyLoss()
learning_rate = 0.001
# Optimiser over every parameter registered on the pipeline.
optimizer_all = optim.Adam(model.parameters(), lr=learning_rate)
# Optimiser over the encoder's parameters only; this is the one train() steps.
optimizer_encoder = optim.Adam(encoder.parameters(), lr=learning_rate)

# NOTE(review): this schedule wraps optimizer_all, but train() never calls
# optimizer_all.step(), so it does not change the learning rate actually used
# (optimizer_encoder's). Confirm whether it should wrap optimizer_encoder.
scheduler = StepLR(optimizer_all, step_size=400, gamma=0.01)


# Number of passes over train_loader in train().
epochs = 1

def train():
    """Fine-tune the encoder so the frozen classifier's decisions survive encoding.

    Per batch:
      * pseudo-labels are the classifier's argmax predictions on the clean image;
      * ``CE_loss`` is the cross-entropy between the full pipeline
        ``model(data)`` and those pseudo-labels;
      * ``loss_ssim`` is the SSIM between ``encoder(data)`` and ``data``, and
        ``0.001 * loss_ssim**2`` is also back-propagated. Minimising it makes
        the encoded image less similar to the input.
    Only ``optimizer_encoder`` is stepped, so only encoder parameters change.

    Every 100 batches it prints losses and runs ``test_ind()``. That call
    switches the networks to eval mode, so train mode is re-enabled afterwards.

    NOTE(review): the decoder receives gradients from ``CE_loss`` but is never
    updated (``optimizer_all.step()`` is disabled), yet its state dict is saved
    below. Confirm that is intended.

    Writes ``cifar10_encoder_ssim_2.pth`` and ``cifar10_decoder_ssim_2.pth``.
    """
    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, _target) in enumerate(train_loader):  # labels unused: supervision comes from the classifier
            optimizer_all.zero_grad()
            optimizer_encoder.zero_grad()

            data = data.to(device)
            output = model(data)

            # Pseudo-labels from the frozen classifier; no autograd graph needed.
            with torch.no_grad():
                classifier_predict_output = classifier(data).argmax(dim=1)
            image_generated = encoder(data)

            CE_loss = criterion(output, classifier_predict_output)
            loss_all = CE_loss

            loss_ssim = ssim_loss(image_generated, data)
            loss_encoder = 0.001 * loss_ssim * loss_ssim

            # Gradients from both objectives accumulate in the encoder's .grad.
            loss_encoder.backward(retain_graph=True)
            loss_all.backward()

            optimizer_encoder.step()
            scheduler.step()

            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(train_loader)}], CE_Loss: {CE_loss.item():.4f}, SSIM_Loss: {loss_ssim.item():.4f}')
                test_ind()
                # test_ind() may leave the networks in eval mode; without this
                # the remaining batches of the epoch would train with eval-mode
                # BatchNorm/Dropout.
                model.train()

    torch.save(encoder.state_dict(), 'cifar10_encoder_ssim_2.pth')
    torch.save(decoder.state_dict(), 'cifar10_decoder_ssim_2.pth')

def test_combined():
    """Evaluate the combined pipeline and print per-stage forward timings.

    For every test batch, prints the elapsed seconds of the encoder, the
    classifier and the decoder forward passes (in that order), then a
    separator line. At the end, prints:
      * accuracy: ``model(data).argmax(1)`` vs the ground-truth label;
      * fidelity: ``model(data).argmax(1)`` vs the classifier's own prediction
        on the clean image.
    Leaves ``model`` in eval mode.
    """

    def _timed(module, inputs):
        # Run `module(inputs)`, print the elapsed wall-clock seconds, return
        # the output. CUDA launches are asynchronous, so synchronise around the
        # call to time the actual compute rather than just the kernel launch.
        if inputs.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        result = module(inputs)
        if inputs.is_cuda:
            torch.cuda.synchronize()
        print(time.perf_counter() - start)
        return result

    model.eval()
    correct_acc = 0
    correct_fd = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            _timed(encoder, data)  # timing only; the output is not used

            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct_acc += (predicted == target).sum().item()

            classifier_output = _timed(classifier, data)
            _timed(decoder, classifier_output)

            print("===========================")
            classifier_predict_output = classifier_output.argmax(dim=1)
            correct_fd += (predicted == classifier_predict_output).sum().item()

    print(f'Combined Accuracy on the test set: {100 * correct_acc / total:.2f}%')
    print(f'Combined Fidelity on the test set: {100 * correct_fd / total:.2f}%')

def test_ssim(num_samples=1001):
    """Print the mean SSIM and mean squared SSIM between test images and their encodings.

    Images are read one at a time from ``testset``. Each is encoded with
    ``encoder`` and compared with the original using ``ssim``
    (``data_range=1``). The work runs under ``torch.no_grad()`` with the
    encoder in eval mode, so no autograd graphs are kept and BatchNorm
    statistics are untouched. The encoder's previous mode is restored
    afterwards.

    NOTE(review): ``testset`` images are mean/std-normalised, so they are not
    confined to [0, 1]; confirm ``data_range=1`` is the intended setting.

    Args:
        num_samples: number of test images to score. The default of 1001
            matches the previous hard-coded loop, which stopped after the
            1001st image.

    Raises:
        ValueError: if no image was scored (``num_samples`` < 1 or an empty
            test set).
    """
    from itertools import islice

    was_training = encoder.training
    encoder.eval()
    cnt = 0
    ssim_sum = 0.0
    ssim_sq_sum = 0.0
    try:
        with torch.no_grad():
            for image, _label in islice(testset, max(num_samples, 0)):
                image = image.to(device).reshape(1, 3, 32, 32)
                ob_data = encoder(image).reshape(1, 3, 32, 32)
                # .item() keeps plain floats, so nothing holds device memory.
                score = ssim(image, ob_data, data_range=1, size_average=True).item()
                cnt += 1
                ssim_sum += score
                ssim_sq_sum += score * score
    finally:
        encoder.train(was_training)

    if cnt == 0:
        raise ValueError("test_ssim: no test images were scored")

    print(f"avg SSIM: {ssim_sum/float(cnt)}")
    print(f"avg SSIM square: {ssim_sq_sum/float(cnt)}")


def test_ind():
    """Print the combined pipeline's accuracy and fidelity on the test set.

    * Accuracy: share of test images where ``model(data).argmax(1)`` equals
      the ground-truth label.
    * Fidelity: share where it equals the classifier's own prediction on the
      clean image.

    Runs in eval mode under ``torch.no_grad()``. The previous train/eval
    modes of ``model`` and ``encoder`` are restored afterwards, so calling
    this from inside a training loop does not silently switch training to
    eval mode.
    """
    model_was_training = model.training
    encoder_was_training = encoder.training
    model.eval()
    encoder.eval()
    correct_acc = 0
    correct_fd = 0
    total = 0
    try:
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                predicted = model(data).argmax(dim=1)
                total += target.size(0)
                correct_acc += (predicted == target).sum().item()

                classifier_predict_output = classifier(data).argmax(dim=1)
                correct_fd += (predicted == classifier_predict_output).sum().item()
    finally:
        model.train(model_was_training)
        encoder.train(encoder_was_training)

    print(f'Combined Accuracy on the test set: {100 * correct_acc / total:.2f}%')
    print(f'Combined Fidelity on the test set: {100 * correct_fd / total:.2f}%')


def test_output_mi():
    """Print the Pearson correlation (and its square) between two prediction streams.

    Over the whole test set it collects:
      * stream 1: combined-pipeline predictions ``model(data).argmax(1)``;
      * stream 2: classifier predictions on encoded images,
        ``classifier(encoder(data)).argmax(1)``.
    Both are concatenated across batches into 1-D arrays before correlating.
    Passing the per-batch list straight to ``np.array`` instead builds a 2-D
    array, and then ``np.corrcoef(...)[0, 1]`` correlates batch 0 with batch
    1 of the same stream.

    Runs in eval mode under ``torch.no_grad()``. The previous modes of
    ``model`` and ``encoder`` are restored afterwards.

    NOTE(review): the predictions are categorical class indices, so Pearson
    correlation depends on the arbitrary label order. sklearn's
    ``mutual_info_score`` (imported at the top of the file) may be what the
    name intends.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model_was_training = model.training
    encoder_was_training = encoder.training
    model.eval()
    encoder.eval()
    true_outputs = []
    ob_outputs = []
    try:
        with torch.no_grad():
            for data, _target in test_loader:
                data = data.to(device)
                true_outputs.append(model(data).argmax(dim=1).cpu())
                ob_outputs.append(classifier(encoder(data)).argmax(dim=1).cpu())
    finally:
        model.train(model_was_training)
        encoder.train(encoder_was_training)

    if not true_outputs:
        raise ValueError("test_output_mi: test_loader yielded no batches")

    true = torch.cat(true_outputs).numpy()
    ob = torch.cat(ob_outputs).numpy()

    r = np.corrcoef(true, ob)[0, 1]
    print(r)
    print(r * r)

def test_generation(num_samples=10):
    """Save and show encoder outputs for the first test batch.

    The first ``num_samples`` images of the first batch (capped at the batch
    size) are passed through ``encoder``. Each output is converted from CHW
    to HWC and min-max rescaled to [0, 1] for display only. The figure is
    written to ``./sampled_images.png`` and ``./sampled_images.pdf``, then
    shown and closed.

    Only the first batch is used; the function then returns. It previously
    called ``exit()``, a ``site`` helper intended for interactive sessions,
    which terminated the interpreter.

    Args:
        num_samples: number of images to plot (default 10, the previous fixed
            value).
    """
    model.eval()
    with torch.no_grad():
        for data, _target in test_loader:
            ob_data = encoder(data.to(device)).detach().cpu()

            count = min(num_samples, ob_data.size(0))
            if count <= 0:
                return

            # squeeze=False keeps `axes` 2-D even when count == 1.
            fig, axes = plt.subplots(1, count, figsize=(count * 2, 2), squeeze=False)

            for i in range(count):
                img = ob_data[i].permute(1, 2, 0)  # CHW -> HWC for imshow
                lo, hi = img.min(), img.max()
                # A constant image would divide by zero; show it as black.
                img = (img - lo) / (hi - lo) if hi > lo else torch.zeros_like(img)

                axes[0, i].imshow(img)
                axes[0, i].axis('off')

            plt.tight_layout()
            plt.savefig('./sampled_images.png')
            plt.savefig('./sampled_images.pdf')
            plt.show()
            plt.close(fig)
            return



# Training entry point (disabled in this configuration):
# train()

# Load previously saved encoder/decoder weights. The active pair is the
# "_fd" checkpoint; the other pairs are alternative runs.
# NOTE(review): torch.load has no map_location here; if these checkpoints
# hold CUDA tensors this fails on a CPU-only machine — confirm.
# encoder.load_state_dict(torch.load("./cifar10_encoder.pth"))
# decoder.load_state_dict(torch.load("./cifar10_decoder.pth"))
encoder.load_state_dict(torch.load("./cifar10_encoder_fd.pth"))
decoder.load_state_dict(torch.load("./cifar10_decoder_fd.pth"))
# encoder.load_state_dict(torch.load("./cifar10_encoder_ssim_2.pth"))
# decoder.load_state_dict(torch.load("./cifar10_decoder_ssim_2.pth"))

# Evaluation menu: uncomment the report to run.
# classifier_validation(classifier, test_loader, device)
# test_combined()
# test_ssim()
# test_ind()
# test_output_mi()
# Visualise encoder outputs for the first test batch.
test_generation()




