import torch
import torchvision
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch import nn
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from torch.utils.data import DataLoader, Subset
import random
import os
from mpl_toolkits.axes_grid1 import AxesGrid
from matplotlib.ticker import MaxNLocator

from mpl_toolkits.axes_grid1 import make_axes_locatable

import argparse
# Intel OpenMP switch: keep running when more than one OpenMP runtime gets
# loaded into the process instead of aborting.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Shared plot palette: noise / loss / train-accuracy curves are blue,
# feature / test-accuracy curves are coral.
color_noise = color_loss = color_acc_train = 'tab:blue'
color_feature = color_acc_test = 'lightcoral'
def filter_classes(dataset, classes=(0, 1), samples_per_class=100):
    """Return the dataset indices of the first samples of each requested class.

    Args:
        dataset: Object exposing a ``targets`` attribute holding one label per
            sample (a tensor, or anything ``torch.as_tensor`` accepts).
        classes: Labels to keep. The result is grouped in this order.
        samples_per_class: Maximum number of samples kept per class. ``None``
            or any negative value keeps every sample of the class.

    Returns:
        A 1-D integer tensor with the selected indices, class after class.
    """
    targets = torch.as_tensor(dataset.targets)
    # A negative limit means "no limit"; slicing with it (e.g. [:-1]) would
    # silently drop samples from the end instead.
    keep_all = samples_per_class is None or samples_per_class < 0

    class_indices = []
    for class_label in classes:
        # flatten() keeps the result 1-D even when exactly one sample matches
        # (squeeze() would collapse it to a 0-d tensor that cannot be sliced).
        indices = torch.nonzero(targets == class_label).flatten()
        class_indices.append(indices if keep_all else indices[:samples_per_class])

    if not class_indices:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(class_indices)

class NormalizeAndSegmentVerticallyTransform:
    """Turn an image into a two-row sample: a scaled signal row and a noise row.

    The image is converted to a tensor and normalized with mean 0.5 / std 0.5,
    then paired with a standard-Gaussian patch of the same shape. The two
    patches are weighted by the module-level ``signal_scale`` and
    ``noise_scale`` (read at call time) and flattened into two rows.
    """

    def __init__(self):
        # Single-channel normalization: (x - 0.5) / 0.5
        self.normalize = transforms.Normalize((0.5,), (0.5,))

    def __call__(self, img):
        signal_patch = self.normalize(transforms.ToTensor()(img))
        # Fresh noise on every call, so repeated access to the same image
        # yields a different noise row.
        noise_patch = torch.randn_like(signal_patch)

        stacked = torch.cat(
            (signal_patch * signal_scale, noise_patch * noise_scale), dim=0
        )
        # Row 0 is the signal patch, row 1 the noise patch (for a
        # single-channel input).
        return stacked.view(2, -1)


class CNN_diff(nn.Module):
    """Weight-tied two-layer network used as the noise predictor.

    Every patch ``x`` (last dimension ``d``) is mapped to
    ``W^T act(W x) / sqrt(m)``, where ``W`` has shape ``(m, d)`` and ``act``
    is either the square (``'quad'``) or ReLU (``'relu'``).
    """

    def __init__(self, d, m, act='quad'):
        super().__init__()
        self.m = m
        self.act = act
        # Small Gaussian initialization of the m weight vectors.
        self.W = nn.Parameter(0.01 * torch.randn(m, d))

    def activation(self, x):
        """Apply the configured element-wise activation."""
        if self.act == 'quad':
            return x ** 2
        if self.act == 'relu':
            return F.relu(x)
        raise NotImplementedError

    def forward(self, x):
        # x: [..., num_patches, d]; leading dimensions are carried through.
        hidden = self.activation(torch.einsum('...nd,md->...nm', x, self.W))
        # Project back to patch space with the same (tied) weights.
        projected = torch.einsum('...nm,md->...nd', hidden, self.W)
        return projected / np.sqrt(self.m)


class CNN_class(nn.Module):
    """Two-branch patch classifier with a scalar output per sample.

    The output is the positive-branch score minus the negative-branch score;
    each score averages the activated inner products over the ``m`` neurons
    and sums them over the patches.
    """

    def __init__(self, d, m, act='quad'):
        super().__init__()
        self.act = act
        self.m = m
        # Positive- and negative-branch weight vectors, small Gaussian init.
        self.Wp = nn.Parameter(0.01 * torch.randn(m, d))
        self.Wn = nn.Parameter(0.01 * torch.randn(m, d))

    def activation(self, x):
        """Apply the configured element-wise activation."""
        if self.act == 'quad':
            return x ** 2
        if self.act == 'relu':
            return F.relu(x)
        raise NotImplementedError

    def forward(self, x):
        # x: [..., num_patches, d]
        def branch_score(weights):
            acts = self.activation(torch.einsum('...nd,md->...nm', x, weights))
            # mean over neurons, then sum over patches
            return acts.mean(-1).sum(-1)

        return branch_score(self.Wp) - branch_score(self.Wn)



def _mean_max_alignment(weights, patches):
    """Average over samples of the largest |<w, x>| across all weight rows.

    Args:
        weights: Tensor of shape (num_neurons, d).
        patches: Tensor of shape (num_samples, d).

    Returns:
        A Python float.
    """
    per_sample_max = torch.matmul(weights, patches.T).abs().max(dim=0).values
    # Average in float64 to keep accumulation error negligible.
    return per_sample_max.double().mean().item()


def _train_diffusion(model, optimizer, epochs):
    """Fit the noise predictor with full-batch SGD and save its checkpoint."""
    signal_patches = train_images[:, 0, :]
    noise_patches = train_images[:, 1, :]
    losses, feature_learning, noise_memorization = [], [], []

    for epoch in range(epochs):
        model.train()
        # n_eps fresh Gaussian draws per training sample, one per patch entry.
        eps = torch.randn(train_images.shape[0], n_eps, train_images.shape[1], train_images.shape[2]).to(device)
        noisy_images = alpha * train_images.unsqueeze(1) + beta * eps
        outputs = model(noisy_images)

        optimizer.zero_grad()
        loss = F.mse_loss(outputs, eps)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # Only one scalar per epoch is kept; storing every neuron-sample inner
        # product for all epochs would need O(m * N * epochs) memory.
        model.eval()
        with torch.no_grad():
            feature_learning.append(_mean_max_alignment(model.W, signal_patches))
            noise_memorization.append(_mean_max_alignment(model.W, noise_patches))

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {losses[-1]:.4f}")

    checkpoint = {
        'epoch': epochs - 1,  # index of the last completed epoch
        'model_state_dict': model.state_dict(),
        'loss': losses,
        'feature_learning': np.asarray(feature_learning),
        'noise_memorization': np.asarray(noise_memorization)
    }

    torch.save(checkpoint, f'diff_model_checkpoint_{signal_scale}_{time}.pth')


def _train_classifier(model, optimizer, epochs):
    """Fit the classifier with full-batch SGD on the logistic loss and save it."""
    signal_patches = train_images[:, 0, :]
    noise_patches = train_images[:, 1, :]
    losses, train_accs, test_accs = [], [], []
    feature_learning, noise_memorization = [], []

    for epoch in range(epochs):
        model.train()

        f_pred = model(train_images)
        # Logistic loss log(1 + exp(-y * f)); softplus evaluates it without
        # overflowing exp() when the margin is large and negative.
        loss = F.softplus(-f_pred * train_labels).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            # Taking the max over the rows of both branches at once equals the
            # element-wise maximum of the per-branch maxima.
            both_branches = torch.cat((model.Wp, model.Wn), dim=0)
            feature_learning.append(_mean_max_alignment(both_branches, signal_patches))
            noise_memorization.append(_mean_max_alignment(both_branches, noise_patches))
            train_acc = calculate_accuracy(train_images, train_labels, model)
            test_acc = calculate_accuracy(test_images, test_labels, model)
        train_accs.append(train_acc)
        test_accs.append(test_acc)

        print(
            f"Epoch [{epoch + 1}/{epochs}], Loss: {losses[-1]:.4f}, Train acc: {train_acc:.2f}, Test acc: {test_acc:.2f}")

    checkpoint = {
        'epoch': epochs - 1,  # index of the last completed epoch
        'model_state_dict': model.state_dict(),
        'loss': losses,
        'test_acc': test_accs,
        'train_acc': train_accs,
        'feature_learning': np.asarray(feature_learning),
        'noise_memorization': np.asarray(noise_memorization)
    }

    torch.save(checkpoint, f'class_model_checkpoint_{signal_scale}.pth')


def train():
    """Train the model selected by ``args.model`` and write its checkpoint.

    Reads the module-level configuration (``args``, ``d``, ``m``, ``device``,
    the epoch limits and the stacked train/test tensors). ``'cnn'`` trains
    ``CNN_class`` and saves ``class_model_checkpoint_{signal_scale}.pth``;
    ``'diff'`` trains ``CNN_diff`` and saves
    ``diff_model_checkpoint_{signal_scale}_{time}.pth``. Each checkpoint holds
    the final weights, the per-epoch loss, and the per-epoch
    ``feature_learning`` / ``noise_memorization`` curves (mean over training
    samples of the largest absolute inner product between any weight vector
    and the signal / noise patch).

    Raises:
        NotImplementedError: If ``args.model`` is neither ``'cnn'`` nor ``'diff'``.
    """
    if args.model == 'cnn':
        model = CNN_class(d=d, m=m)
        epochs = max_epoch_cnn
    elif args.model == 'diff':
        model = CNN_diff(d=d, m=m)
        epochs = max_epoch_diff
    else:
        raise NotImplementedError

    lr = 0.5
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    if args.model == 'diff':
        _train_diffusion(model, optimizer, epochs)
    else:
        _train_classifier(model, optimizer, epochs)

def calculate_accuracy(images, labels, model):
    """Return the fraction of samples whose prediction has the label's sign.

    Args:
        images: Batch fed to ``model`` in one call; its first dimension is the
            number of samples.
        labels: Tensor of -1 / +1 targets, one per sample.
        model: Callable module; it is switched to eval mode.

    Returns:
        ``correct / total`` as a Python float. A sample counts as correct only
        when ``label * prediction`` is strictly positive.
    """
    model.eval()
    with torch.no_grad():
        margins = labels * model(images)
        n_correct = (margins > 0).sum().item()
    return n_correct / images.shape[0]


def visualize_images(image):
    """Show the signal and noise patches of one sample, stacked vertically.

    Args:
        image: Tensor holding 2 * 28 * 28 values per sample (e.g. shape
            [2, 784]). If it has more than two dimensions, only the first
            entry along dimension 0 is shown.

    The input tensor is not modified.
    """
    if len(image.shape) > 2:
        image = image[0]

    # Work on a private CPU copy: a view would alias the caller's storage, so
    # rescaling it in place would silently alter the dataset tensor.
    patches = image.detach().cpu().clone().reshape(2, 28, 28)
    # Inverse of Normalize((0.5,), (0.5,)); any signal_scale factor applied by
    # the transform is not undone here.
    signal = patches[0] * 0.5 + 0.5
    stacked = torch.cat((signal, patches[1]), dim=0)

    plt.imshow(stacked.numpy(), cmap='gray')
    plt.colorbar()
    plt.show()

def reverse_transform(image):
    """Build a (56, 28) picture from a two-row sample without touching the input.

    Row 0 (signal) is mapped back with ``x * 0.5 + 0.5``; row 1 (noise) is
    kept as is. The two 28x28 patches are stacked vertically.
    """
    assert len(image.shape) == 2
    # Detached private copy, so the caller's tensor is never aliased.
    signal, noise = image.detach().clone().view(2, 28, 28)
    return torch.cat((signal * 0.5 + 0.5, noise), dim=0)

def visualize_classification(images, label):
    """Input-gradient map of the trained classifier for a single sample.

    Loads ``class_model_checkpoint_{signal_scale}.pth`` and differentiates the
    score of the branch that corresponds to ``label`` with respect to the
    input.

    Args:
        images: One sample holding 2 * 28 * 28 float values (e.g. shape
            [2, 784]); the branch score must be a scalar.
        label: 0 selects the negative branch (``Wn``), 1 the positive branch
            (``Wp``).

    Returns:
        A (56, 28) NumPy array: the gradient w.r.t. the signal patch stacked
        on top of the gradient w.r.t. the noise patch.

    Raises:
        ValueError: If ``label`` is neither 0 nor 1.
    """
    if label not in (0, 1):
        raise ValueError(f"label must be 0 or 1, got {label!r}")

    # The checkpoint also stores NumPy arrays, which the weights-only
    # unpickler (the default from PyTorch 2.6 on) rejects. The file is
    # produced locally by train(), so full unpickling is acceptable here.
    checkpoint_class = torch.load(f'class_model_checkpoint_{signal_scale}.pth',
                                  map_location='cpu', weights_only=False)
    model = CNN_class(d=d, m=m)
    model.load_state_dict(checkpoint_class['model_state_dict'])

    # Detached weights: the gradient is taken w.r.t. the input only.
    weights = (model.Wp if label == 1 else model.Wn).detach()
    images = images.detach().requires_grad_(True)

    activations = model.activation(torch.einsum('...nd,md->...nm', images, weights))
    output = activations.mean(-1).sum(-1)  # mean over neurons, sum over patches

    (grad,) = torch.autograd.grad(output, images)
    # (2, 28, 28) stacked along the first spatial axis is just (56, 28).
    return grad.reshape(2 * 28, 28).cpu().numpy()



def visualize_diffusion(model, image):
    """One-step denoised reconstruction of ``image`` by the noise predictor.

    The sample is noised as ``x_t = alpha * x_0 + beta * eps`` with fresh
    Gaussian ``eps`` (``alpha`` / ``beta`` are module-level), the model
    predicts the noise from ``x_t``, and the clean sample is estimated by
    inverting the noising equation.

    Args:
        model: Noise-prediction network applied to the noisy sample.
        image: Two-row sample tensor accepted by ``reverse_transform``.

    Returns:
        The reconstruction as a NumPy array laid out by ``reverse_transform``.
    """
    with torch.no_grad():
        eps = torch.randn_like(image)
        noisy_image = alpha * image + beta * eps
        noise_pred = model(noisy_image)
        # x_0 = (x_t - beta * eps_hat) / alpha: the predicted noise has to be
        # removed from the noisy input x_t, not from the clean image.
        x0_pred = (noisy_image - beta * noise_pred) / alpha
    image = reverse_transform(x0_pred)
    return image.cpu().detach().numpy()



def visualize_all():
    """Plot training curves from the saved classifier and diffusion checkpoints.

    Reads ``class_model_checkpoint_{signal_scale}.pth`` and
    ``diff_model_checkpoint_{signal_scale}_{time}.pth`` from the working
    directory and writes four PNG files into ``figures/`` (created if
    missing): classifier loss, diffusion loss, classifier train/test accuracy,
    and the feature-learning / noise-memorization curves of both models.
    """
    print(f"visualizing signal scale = {signal_scale}, diffusion time = {time}")

    # Hand-tuned colour ranges per signal scale; only used by the disabled
    # per-sample grid plot below.
    norm_map = {'0.1': [mpl.colors.Normalize(vmin=-2, vmax=2),
                      mpl.colors.Normalize(vmin=-0.017, vmax=0.017),
                      mpl.colors.Normalize(vmin=-2, vmax=2)] ,
                '0.5': [mpl.colors.Normalize(vmin=-3.5, vmax=3.5),
                        mpl.colors.Normalize(vmin=-0.22, vmax=0.22),
                        mpl.colors.Normalize(vmin=-4.3, vmax=4.3)],
                }

    # The checkpoints also store NumPy arrays, which the weights-only
    # unpickler (the default from PyTorch 2.6 on) rejects. Both files are
    # produced locally by train(), so full unpickling is acceptable here.
    checkpoint_class = torch.load(f'class_model_checkpoint_{signal_scale}.pth',
                                  map_location='cpu', weights_only=False)
    model_class = CNN_class(d=d, m=m)
    model_class.load_state_dict(checkpoint_class['model_state_dict'])

    checkpoint_diff = torch.load(f'diff_model_checkpoint_{signal_scale}_{time}.pth',
                                 map_location='cpu', weights_only=False)
    model_diff = CNN_diff(d=d, m=m)
    model_diff.load_state_dict(checkpoint_diff['model_state_dict'])

    # savefig does not create missing directories.
    os.makedirs('figures', exist_ok=True)

    # Classification and diffusion plot
    # fig = plt.figure(figsize=(10,9))
    # grid = AxesGrid(fig, 111,  # similar to subplot(122)
    #                 nrows_ncols=(3, len(indices)),
    #                 axes_pad=0.2,
    #                 label_mode="1",
    #                 share_all=True,
    #                 cbar_location="right",
    #                 cbar_mode="edge",
    #                 # cbar_mode="each",
    #                 cbar_size="7%",
    #                 cbar_pad="4%",
    #                 )
    #
    # for i, idx in enumerate(indices):
    #     images, labels = test_dataset[idx]
    #     original_image = reverse_transform(images).detach().cpu().numpy()
    #
    #     class_visual = visualize_classification(images, labels)
    #     diff_visual = visualize_diffusion(model_diff, images)
    #
    #     if str(signal_scale) in norm_map:
    #         im1 = grid[i].imshow(original_image, cmap='RdBu', norm=norm_map[str(signal_scale)][0])
    #         im2 = grid[i+len(indices)].imshow(class_visual, cmap='RdBu', norm=norm_map[str(signal_scale)][1])
    #         im3 = grid[i+2*len(indices)].imshow(diff_visual, cmap='RdBu', norm=norm_map[str(signal_scale)][2])
    #
    #         if i == len(indices) - 1:
    #             grid.cbar_axes[0].colorbar(im1)
    #             grid.cbar_axes[1].colorbar(im2)
    #             grid.cbar_axes[2].colorbar(im3)
    #
    #     else:
    #         im1 = grid[i].imshow(original_image, cmap='RdBu', )
    #         im2 = grid[i + len(indices)].imshow(class_visual, cmap='RdBu',)
    #         im3 = grid[i + 2 * len(indices)].imshow(diff_visual, cmap='RdBu', )
    #         grid.cbar_axes[i].colorbar(im1)
    #         grid.cbar_axes[i + len(indices)].colorbar(im2)
    #         grid.cbar_axes[i + 2 * len(indices)].colorbar(im3)
    #
    #
    #
    #     grid[i].axis('off')
    #     grid[i+len(indices)].axis('off')
    #     grid[i + 2*len(indices)].axis('off')
    # fig.text(0.02, 0.82, 'Input', va='center', ha='center', rotation='vertical', fontsize=20)
    # fig.text(0.02, 0.5, 'Classification', va='center', ha='center', rotation='vertical', fontsize=20)
    # fig.text(0.02, 0.17, 'Diffusion', va='center', ha='center', rotation='vertical', fontsize=20)
    # plt.subplots_adjust(left=0.04, right=0.93, top=0.98, bottom=0.02,wspace=0.05, hspace=0.05)
    # plt.savefig(f'figures/diff_class_mnist_{signal_scale}_{time}.png', dpi=300)




    # # This affects all Axes because we set share_all = True.
    # grid.axes_llc.set_xticks([])
    # grid.axes_llc.set_yticks([])

    # plt.savefig('figures/diff_class_mnist.png', dpi=300)

    # loss curves
    plt.subplots(figsize=(7, 6))
    plt.plot(checkpoint_class['loss'], color=color_loss, linewidth=3)
    plt.xlabel('Iteration', fontsize=25)
    plt.tick_params(axis='both', which='major', labelsize=15)  # For major ticks
    plt.tick_params(axis='both', which='minor', labelsize=15)
    plt.tight_layout()
    plt.savefig(f'figures/class_loss_mnist_{signal_scale}_{time}.png', dpi=300)

    fig, ax=plt.subplots(figsize=(7, 6))
    plt.plot(torch.tensor(checkpoint_diff['loss']).detach().cpu(), color=color_loss, linewidth=3)
    plt.xlabel('Iteration', fontsize=25)
    plt.tick_params(axis='both', which='major', labelsize=15)  # For major ticks
    plt.tick_params(axis='both', which='minor', labelsize=15)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=7))
    plt.tight_layout()
    plt.savefig(f'figures/diff_loss_mnist_{signal_scale}_{time}.png', dpi=300)

    # acc
    plt.subplots(figsize=(7, 6))
    plt.plot(checkpoint_class['train_acc'], label='Train acc', color=color_acc_train, linewidth=3)
    plt.plot(checkpoint_class['test_acc'], label='Test acc', color=color_acc_test, linewidth=3)
    plt.xlabel('Iteration', fontsize=25)
    # plt.ylabel('ACC', fontsize=25)
    plt.tick_params(axis='both', which='major', labelsize=15)  # For major ticks
    plt.tick_params(axis='both', which='minor', labelsize=15)
    plt.legend(fontsize=22)
    plt.tight_layout()
    plt.savefig(f'figures/class_acc_mnist_{signal_scale}_{time}.png', dpi=300)

    # feature learning
    # The two models train for different numbers of iterations, so the
    # diffusion curves get their own x-axis on top (twiny).
    fig, ax1 = plt.subplots(figsize=(7, 6))
    line1, = ax1.plot(checkpoint_class['feature_learning'], color=color_feature, linewidth=3, label='Feature (Class)', linestyle='--')
    line2, = ax1.plot(checkpoint_class['noise_memorization'], color=color_noise, linewidth=3, label='Noise (Class)', linestyle='--')
    ax2 = ax1.twiny()
    line3, = ax2.plot(checkpoint_diff['feature_learning'], color=color_feature, linewidth=3, label='Feature (Diff)')
    line4, = ax2.plot(checkpoint_diff['noise_memorization'], color=color_noise, linewidth=3, label='Noise (Diff)')
    ax1.set_xlabel('Iteration (Class)', fontsize=25)
    ax2.set_xlabel('Iteration (Diff)', fontsize=25)
    ax1.tick_params(axis='both', which='major', labelsize=15)  # For major ticks
    ax2.tick_params(axis='both', which='major', labelsize=15)
    ax2.xaxis.set_major_locator(MaxNLocator(nbins=5))
    # One combined legend for the lines of both axes.
    lines = [line1, line2, line3, line4]
    labels_ = [line.get_label() for line in lines]
    ax1.legend(lines, labels_, loc='upper left', fontsize=23)

    plt.tight_layout()
    plt.savefig(f'figures/feat_mnist_{signal_scale}_{time}.png', dpi=300)

def parse_args(args):
    """Parse the command-line options of the script.

    Args:
        args: List of argument strings, or ``None`` to read ``sys.argv``.

    Returns:
        Namespace with ``model`` ('cnn' or 'diff'), ``time`` (float) and
        ``scale`` (float).
    """
    option_specs = (
        ("--model", dict(type=str, default='diff', choices=['cnn', 'diff'])),
        ("--time", dict(type=float, default=0.8)),
        ("--scale", dict(type=float, default=0.1)),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args(args)


# Script entry point: fixes the random seeds, defines the module-level
# configuration that the functions above read as globals, builds the two-class
# MNIST tensors, and then runs training and/or visualization.
if __name__ == '__main__':
    # Seed every random number generator in use so runs are repeatable.
    seed = 100
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    args = parse_args(None)

    d = 784  # dimension for each patch
    m = 100  # Number of weight vectors
    time = args.time  # diffusion time
    signal_scale = args.scale  # multiplier applied to the normalized image patch
    noise_scale = 1  # multiplier applied to the Gaussian noise patch
    max_epoch_cnn = 300  # full-batch SGD steps for the classifier
    max_epoch_diff = 50000  # full-batch SGD steps for the diffusion model
    n_eps = 2000  # noise draws per training sample in each diffusion step

    # Noising coefficients for x_t = alpha * x_0 + beta * eps; by construction
    # alpha**2 + beta**2 == 1.
    alpha = np.exp(-time)
    beta = np.sqrt(1 - np.exp(-2 * time))
    print(alpha, beta)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    train_sample_per_class = 50
    batchsize = train_sample_per_class * 2  # not referenced by any active code in this file

    # Transform that normalizes each image and pairs it with a Gaussian noise
    # patch, giving a [2, 784] sample (row 0: signal, row 1: noise).
    transform = NormalizeAndSegmentVerticallyTransform()

    # Load the MNIST dataset with the custom transform
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    # Get the indices for classes 0 and 1 in training and test sets
    train_indices = filter_classes(train_dataset, samples_per_class=train_sample_per_class)
    # NOTE(review): presumably -1 means "no per-class limit" -- confirm against
    # how filter_classes applies the limit.
    test_indices = filter_classes(test_dataset, samples_per_class=-1)

    # Create subsets of the dataset
    train_dataset = Subset(train_dataset, train_indices)
    test_dataset = Subset(test_dataset, test_indices)

    # Materialize everything as stacked tensors for full-batch training.
    # NOTE(review): the transform presumably runs on every dataset access, in
    # which case the labels pass below draws (and discards) a second set of
    # noise patches -- confirm.
    train_images = torch.stack([train_dataset[i][0] for i in range(len(train_dataset))]).to(device)
    train_labels = torch.tensor([train_dataset[i][1] for i in range(len(train_dataset))], dtype=torch.long).to(device)
    test_images = torch.stack([test_dataset[i][0] for i in range(len(test_dataset))]).to(device)
    test_labels = torch.tensor([test_dataset[i][1] for i in range(len(test_dataset))], dtype=torch.long).to(device)
    # Map the class labels {0, 1} to {-1, +1}, as the logistic loss and
    # calculate_accuracy expect.
    train_labels = 2 * train_labels - 1
    test_labels = 2 * test_labels - 1

    # Uncomment to train the model selected by --model. visualize_all() below
    # loads both the classifier and the diffusion checkpoint, so both must
    # have been trained beforehand.
    # train()

    # Test-set positions for the per-sample grid plot in visualize_all (that
    # plot is currently commented out there).
    indices = [1600, 1800, 1900, 0, 1, 2]
    visualize_all()
    #
    # image, label = train_dataset[13]
    # visualize_images(image)
