import torch
import torchvision
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Subset
import random
import os
# Let the process continue when more than one copy of the Intel OpenMP
# runtime (libiomp5) gets loaded, instead of aborting with
# "OMP: Error #15 ... already initialized".
# NOTE(review): Intel documents this as an unsafe workaround that can cause
# crashes or silently wrong results; prefer fixing the duplicate library
# in the environment.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"






class CNN(nn.Module):
    """Small LeNet-style CNN for single-channel 28x28 images.

    Layers, in forward order:
        conv1 (1 -> 3 channels, 5x5) -> max-pool 2 -> ReLU
        conv2 (3 -> 6 channels, 5x5) -> max-pool 2 -> ReLU
        flatten (6 * 4 * 4 = 96) -> fc_1 (128) -> ReLU -> fc_2 (64) -> ReLU
        -> fc_3 (output_dim)

    ``forward`` returns unnormalized logits (no softmax), which is the input
    ``nn.CrossEntropyLoss`` expects.

    Args:
        output_dim: number of output logits per sample (number of classes).
    """

    def __init__(self, output_dim):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=3,
                               kernel_size=5)

        self.conv2 = nn.Conv2d(in_channels=3,
                               out_channels=6,
                               kernel_size=5)

        # fc_1's input size assumes a 28x28 input:
        # 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, with 6 channels
        # => 6 * 4 * 4 = 96 features. Other input sizes will not fit fc_1.
        self.fc_1 = nn.Linear(6 * 4 * 4, 128)
        self.fc_2 = nn.Linear(128, 64)
        self.fc_3 = nn.Linear(64, output_dim)

    def forward(self, x):
        """Map a batch ``[batch, 1, 28, 28]`` to logits ``[batch, output_dim]``."""
        # x = [batch size, 1, 28, 28]
        x = self.conv1(x)
        # x = [batch size, 3, 24, 24]
        x = F.max_pool2d(x, kernel_size=2)
        # x = [batch size, 3, 12, 12]
        x = F.relu(x)

        x = self.conv2(x)
        # x = [batch size, 6, 8, 8]
        x = F.max_pool2d(x, kernel_size=2)
        # x = [batch size, 6, 4, 4]
        x = F.relu(x)

        x = x.view(x.shape[0], -1)
        # x = [batch size, 6*4*4 = 96]
        x = self.fc_1(x)
        # x = [batch size, 128]
        x = F.relu(x)

        x = self.fc_2(x)
        # x = [batch size, 64]
        x = F.relu(x)

        x = self.fc_3(x)
        # x = [batch size, output dim]
        return x



def train_epoch(dataloader, model, loss_fn, optimizer, *, device=None):
    """Run one training pass over ``dataloader``, updating ``model`` in place.

    Args:
        dataloader: iterable of ``(X, y)`` batches; ``len(dataloader.dataset)``
            is used only for the progress message.
        model: module to train; put into training mode once before the loop.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameterless model), so the
            function no longer depends on a module-level ``device`` global.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    size = len(dataloader.dataset)
    # Nothing inside the loop switches modes, so set training mode once.
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

def test_epoch(dataloader, model, loss_fn, *, device=None, losses=None, accuracies=None):
    """Evaluate ``model`` on ``dataloader``, record and print loss/accuracy.

    Args:
        dataloader: iterable of ``(X, y)`` batches with a ``dataset`` attribute;
            the average loss is taken over batches, accuracy over samples.
        model: module to evaluate; switched to eval mode, gradients disabled.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameterless model).
        losses: list the average loss is appended to. Defaults to the
            module-level ``eval_losses`` list if one exists.
        accuracies: list the accuracy (percent) is appended to. Defaults to
            the module-level ``eval_accu`` list if one exists.

    Returns:
        ``(avg_loss, accuracy_percent)``.

    Raises:
        ZeroDivisionError: if the dataloader/dataset is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    # Fall back to the script-level history lists so the original call
    # ``test_epoch(loader, model, loss_fn)`` keeps recording metrics.
    if losses is None:
        losses = globals().get("eval_losses")
    if accuracies is None:
        accuracies = globals().get("eval_accu")

    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    accu = 100 * (correct / size)

    if losses is not None:
        losses.append(test_loss)
    if accuracies is not None:
        accuracies.append(accu)
    print(f"Test Error: \n Accuracy: {accu:>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, accu


def filter_classes(dataset, classes=(0, 1), samples_per_class=100):
    """Return dataset indices of the first ``samples_per_class`` samples of each class.

    Args:
        dataset: object with a ``targets`` attribute holding integer labels
            (a tensor or any sequence ``torch.as_tensor`` accepts).
        classes: class labels to keep; the result is grouped in this order.
        samples_per_class: maximum number of samples kept per class. ``None``
            or any negative value keeps every sample of the class (a negative
            value used to slice ``[:-1]`` and silently drop the last sample).

    Returns:
        1-D ``torch.long`` tensor of indices, grouped by class in ``classes``
        order and ascending within each class (empty if ``classes`` is empty).
    """
    targets = torch.as_tensor(dataset.targets)
    if samples_per_class is not None and samples_per_class < 0:
        samples_per_class = None  # slice to the end => keep all samples

    # ``as_tuple=True`` always yields a 1-D index tensor, even for a single
    # match (``.squeeze()`` produced an unsliceable 0-d tensor in that case).
    class_indices = [
        torch.nonzero(targets == class_label, as_tuple=True)[0][:samples_per_class]
        for class_label in classes
    ]
    if not class_indices:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(class_indices)

class AddCustomNoise:
    """Transform that scales an image and adds non-negative noise to its border.

    The output is ``image * noise_scale + noise_amplitude * |N(0, 1)|``. The
    noise is zeroed inside the ``[5:23, 5:23]`` window of the last two
    dimensions, so for a 28x28 image only a 5-pixel frame is perturbed.

    Args:
        noise_scale: multiplier applied to the *image* (despite its name, it
            does not scale the noise).
        noise_amplitude: multiplier on the absolute Gaussian border noise.
    """

    def __init__(self, noise_scale=1, noise_amplitude=0.5):
        self.noise_scale = noise_scale
        self.noise_amplitude = noise_amplitude

    def __call__(self, image):
        """Return a new tensor; the input ``image`` is not modified."""
        # One standard-normal draw per pixel from torch's global RNG.
        noise = torch.randn_like(image)
        noise[..., 5:23, 5:23] = 0  # keep the central region noise-free
        return image * self.noise_scale + self.noise_amplitude * torch.abs(noise)

    def __repr__(self):
        return (f"{type(self).__name__}(noise_scale={self.noise_scale}, "
                f"noise_amplitude={self.noise_amplitude})")


if __name__ == '__main__':

    # ---- Reproducibility ----
    seed = 100
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # ---- Diffusion / patch parameters ----
    # NOTE(review): d, m, alpha and beta are not used by the training run
    # below in this file; they are only printed.
    d = 243  # dimension for each patch
    m = 100  # Number of weight vectors
    time = 0.2  # diffusion time

    alpha = np.exp(-time)
    beta = np.sqrt(1 - np.exp(-2 * time))
    print(alpha, beta)

    # Kept at module level so helpers that look up ``device`` as a global find it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # ---- Data ----
    # Train and test share the same preprocessing: ToTensor, then scale the
    # image by ``noise_scale`` and add |Gaussian| noise outside the central
    # [5:23, 5:23] window.
    noise_scale = 0.1
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        AddCustomNoise(noise_scale),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        AddCustomNoise(noise_scale),
    ])
    train_dataset = torchvision.datasets.MNIST(root="data", train=True, transform=train_transform, download=True)
    test_dataset = torchvision.datasets.MNIST(root="data", train=False, transform=test_transform, download=True)

    # Binary task on classes 0 and 1: 10 training images per class, every test
    # image of those classes. ``None`` slices to the end; the former ``-1``
    # sliced ``[:-1]`` and silently dropped the last test image of each class.
    train_indices = filter_classes(train_dataset, samples_per_class=10)
    test_indices = filter_classes(test_dataset, samples_per_class=None)

    train_dataset = Subset(train_dataset, train_indices)
    test_dataset = Subset(test_dataset, test_indices)

    # ---- Train ----
    model = CNN(2).to(device)

    # Shuffle the training data (order is otherwise grouped by class); keep
    # evaluation order fixed. With 20 training images and batch_size=128 each
    # epoch is a single full batch.
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25], gamma=0.1)

    loss_fn = nn.CrossEntropyLoss()

    # Per-epoch evaluation history; test_epoch records into these module-level lists.
    eval_losses = []
    eval_accu = []
    epochs = 200
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_epoch(train_dataloader, model, loss_fn, optimizer)
        test_epoch(test_dataloader, model, loss_fn)
        # scheduler.step()
    print("Done!")