# continue training
import torch
import torchvision
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch import nn
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from torch.utils.data import DataLoader, Subset
import random
import os
from mpl_toolkits.axes_grid1 import AxesGrid

from mpl_toolkits.axes_grid1 import make_axes_locatable

import argparse
# Allow more than one OpenMP runtime to be loaded into the process.
# NOTE(review): presumably set to avoid the Intel OpenMP "libiomp5 already
# initialized" abort when PyTorch and MKL-linked libraries coexist — confirm
# it is still needed before removing.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

# Plot colours (not used by the training / merging functions in this file).
color_noise, color_feature = 'tab:blue', 'lightcoral'
color_loss = 'tab:blue'
color_acc_train, color_acc_test = 'tab:blue', 'lightcoral'


def filter_classes(dataset, classes=(0, 1), samples_per_class=100):
    """Return dataset indices for the requested classes, grouped by class.

    For each label in ``classes`` (in the given order) the indices of the
    samples whose target equals that label are collected, and the first
    ``samples_per_class`` of them are kept. The per-class groups are then
    concatenated into a single 1-D index tensor.

    Args:
        dataset: Object with a ``targets`` attribute holding one integer label
            per sample (a tensor for torchvision's MNIST; lists/arrays are
            converted with ``torch.as_tensor``).
        classes: Labels to keep.
        samples_per_class: Maximum number of samples kept per class. ``None``
            or any negative value (e.g. ``-1``) keeps every sample of the class.

    Returns:
        1-D ``torch.long`` tensor of indices into ``dataset`` (empty when
        ``classes`` is empty).
    """
    targets = torch.as_tensor(dataset.targets)
    # A negative slice bound would silently drop samples ([:-1] loses the last
    # one of each class), so "take everything" is handled explicitly.
    keep_all = samples_per_class is None or samples_per_class < 0

    class_indices = []
    for class_label in classes:
        # as_tuple=True always yields a 1-D index tensor; .squeeze() on the
        # (k, 1) result would collapse a single match into an unsliceable 0-d tensor.
        indices = torch.nonzero(targets == class_label, as_tuple=True)[0]
        class_indices.append(indices if keep_all else indices[:samples_per_class])

    if not class_indices:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(class_indices)

class NormalizeAndSegmentVerticallyTransform:
    """Turn a grayscale image into a two-patch ``[signal, noise]`` sample.

    Despite the name, no vertical segmentation happens: the normalised image
    becomes patch 0 (scaled by the signal scale) and an i.i.d. standard normal
    tensor of the same shape becomes patch 1 (scaled by the noise scale). The
    two are stacked and flattened to shape ``(2, C*H*W)`` — ``(2, 784)`` for a
    28x28 single-channel image.
    """

    def __init__(self, signal_scale=None, noise_scale=None):
        """Create the transform.

        Args:
            signal_scale: Multiplier for the image patch. ``None`` (default)
                falls back to the module-level ``signal_scale`` at call time.
            noise_scale: Multiplier for the noise patch. ``None`` (default)
                falls back to the module-level ``noise_scale`` at call time.
        """
        self.to_tensor = transforms.ToTensor()
        # Single-channel normalisation: maps ToTensor's [0, 1] range to [-1, 1].
        self.normalize = transforms.Normalize((0.5,), (0.5,))
        self.signal_scale = signal_scale
        self.noise_scale = noise_scale

    def __call__(self, img):
        """Return the ``(2, -1)`` tensor ``[img * signal_scale, noise * noise_scale]``.

        Draws fresh noise from torch's global RNG on every call.
        """
        # The global fallbacks are only evaluated when no explicit scale was given.
        s_scale = signal_scale if self.signal_scale is None else self.signal_scale
        n_scale = noise_scale if self.noise_scale is None else self.noise_scale

        img = self.normalize(self.to_tensor(img))
        noise_patch = torch.randn_like(img)

        patches = torch.cat([img * s_scale, noise_patch * n_scale], dim=0)
        return patches.view(2, -1)


class CNN_diff(nn.Module):
    """Weight-tied two-layer patch network used as the diffusion denoiser.

    For every patch ``x_n`` the output is
    ``sum_r act(<w_r, x_n>) * w_r / sqrt(m)``: the same ``(m, d)`` matrix
    ``W`` serves as both the hidden and the output layer. In this file it is
    trained (see the MSE against ``eps``) to predict the injected noise.

    Args:
        d: Patch dimension.
        m: Number of neurons (rows of ``W``).
        act: ``'quad'`` (square) or ``'relu'``; anything else raises
            ``NotImplementedError`` on the first forward pass.
    """

    def __init__(self, d, m, act='quad'):
        super().__init__()
        self.act = act
        self.m = m
        # Small random init; 'W' is also the state_dict key used by checkpoints.
        self.W = nn.Parameter(0.01 * torch.randn(m, d))

    def activation(self, x):
        """Apply the element-wise nonlinearity named by ``self.act``."""
        if self.act == 'quad':
            return x ** 2
        if self.act == 'relu':
            return F.relu(x)
        raise NotImplementedError

    def forward(self, x):
        """Map patches ``(..., num_patches, d)`` to a tensor of the same shape."""
        hidden = self.activation(torch.einsum('...nd,md->...nm', x, self.W))
        return torch.einsum('...nm,md->...nd', hidden, self.W) / np.sqrt(self.m)


class CNN_class(nn.Module):
    """Two-bank patch classifier ``F(x) = F_+(x) - F_-(x)``.

    Each bank ``F_±`` applies the activation to ``<w_r, x_n>`` for every neuron
    ``r`` and patch ``n``, averages over the ``m`` neurons and sums over the
    patches. The sign of the scalar output is the predicted label.

    Args:
        d: Patch dimension.
        m: Neurons per bank.
        act: ``'quad'`` (square) or ``'relu'``; anything else raises
            ``NotImplementedError`` on the first forward pass.
    """

    def __init__(self, d, m, act='quad'):
        super().__init__()
        self.m = m
        self.act = act
        # Created Wp first, then Wn: keeps seeded inits and state_dict key order stable.
        self.Wp = nn.Parameter(0.01 * torch.randn(m, d))
        self.Wn = nn.Parameter(0.01 * torch.randn(m, d))

    def activation(self, x):
        """Apply the element-wise nonlinearity named by ``self.act``."""
        if self.act == 'quad':
            return x ** 2
        if self.act == 'relu':
            return F.relu(x)
        raise NotImplementedError

    def _bank(self, x, weight):
        """Score of one bank: activation, mean over neurons, sum over patches."""
        responses = self.activation(torch.einsum('...nd,md->...nm', x, weight))
        return responses.mean(-1).sum(-1)

    def forward(self, x):
        """Map ``(batch_size, num_patches, d)`` input to one score per sample."""
        return self._bank(x, self.Wp) - self._bank(x, self.Wn)



def _alignment_summary(*trajectories):
    """Collapse ``(m, n_samples, epochs)`` inner-product histories to a per-epoch curve.

    Each array is reduced with ``max |.|`` over its neurons (axis 0); when
    several arrays are given (one per filter bank) their reductions are combined
    with an element-wise maximum. The result is averaged over samples, giving a
    1-D array of length ``epochs``.
    """
    per_bank = [np.max(np.abs(t), axis=0) for t in trajectories]
    return np.maximum.reduce(per_bank).mean(axis=0)


def _train_diff(model, optimizer, epochs):
    """Full-batch denoising training; return (losses, feature curve, noise curve)."""
    n_train = train_images.shape[0]
    losses = []
    feature_traj = np.zeros((model.m, n_train, epochs))
    noise_traj = np.zeros((model.m, n_train, epochs))

    for epoch in range(epochs):
        model.train()
        # n_eps noise draws per training sample, shape (N, n_eps, patches, d).
        # Drawn directly on the data's device rather than on the CPU and copied.
        eps = torch.randn(train_images.shape[0], n_eps, train_images.shape[1],
                          train_images.shape[2], device=train_images.device)
        noisy_images = alpha * train_images.unsqueeze(1) + beta * eps
        outputs = model(noisy_images)

        optimizer.zero_grad()
        loss = F.mse_loss(outputs, eps)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            # Patch 0 is the (scaled) image, patch 1 the noise patch.
            feature_traj[:, :, epoch] = (model.W @ train_images[:, 0, :].T).cpu().numpy()
            noise_traj[:, :, epoch] = (model.W @ train_images[:, 1, :].T).cpu().numpy()

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {losses[-1]:.4f}")

    return losses, _alignment_summary(feature_traj), _alignment_summary(noise_traj)


def _train_cnn(model, optimizer, epochs):
    """Full-batch logistic training; return losses, accuracies and both curves."""
    n_train = train_images.shape[0]
    losses, train_accs, test_accs = [], [], []
    shape = (model.m, n_train, epochs)
    feature_p, feature_n = np.zeros(shape), np.zeros(shape)
    noise_p, noise_n = np.zeros(shape), np.zeros(shape)

    for epoch in range(epochs):
        model.train()
        f_pred = model(train_images)
        # log(1 + exp(-y f)) via softplus, which cannot overflow for large negative margins.
        loss = F.softplus(-f_pred * train_labels).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            signal = train_images[:, 0, :].T
            noise = train_images[:, 1, :].T
            feature_p[:, :, epoch] = (model.Wp @ signal).cpu().numpy()
            feature_n[:, :, epoch] = (model.Wn @ signal).cpu().numpy()
            noise_p[:, :, epoch] = (model.Wp @ noise).cpu().numpy()
            noise_n[:, :, epoch] = (model.Wn @ noise).cpu().numpy()
            train_acc = calculate_accuracy(train_images, train_labels, model)
            test_acc = calculate_accuracy(test_images, test_labels, model)
        train_accs.append(train_acc)
        test_accs.append(test_acc)

        print(
            f"Epoch [{epoch + 1}/{epochs}], Loss: {losses[-1]:.4f}, Train acc: {train_acc:.2f}, Test acc: {test_acc:.2f}")

    return (losses, train_accs, test_accs,
            _alignment_summary(feature_p, feature_n),
            _alignment_summary(noise_p, noise_n))


def train():
    """Resume training of a saved model and write a ``*_continue.pth`` checkpoint.

    ``args.model`` selects the model:

    * ``'cnn'``: a ``CNN_class`` trained with the logistic loss.
    * ``'diff'``: a ``CNN_diff`` trained to predict the injected diffusion noise.

    Weights are restored from ``{class|diff}_model_checkpoint_{signal_scale}_{time}.pth``.
    Training uses full-batch SGD (lr 0.5) for ``max_epoch_cnn`` or ``max_epoch_diff``
    epochs. The result is saved to ``..._{signal_scale}_{time}_continue.pth`` and holds:

    * the weights and per-epoch losses;
    * train/test accuracies ('cnn' only);
    * per-epoch "feature_learning" and "noise_memorization" curves. Each is the mean
      over training samples of the largest |<w, patch>| over neurons (and over both
      banks for 'cnn'). Feature uses patch 0 (signal); noise uses patch 1.

    Reads module globals set in the ``__main__`` block (args, d, m, signal_scale,
    time, device, train/test tensors, n_eps, alpha, beta, epoch limits).

    Raises:
        NotImplementedError: if ``args.model`` is neither 'cnn' nor 'diff'.
    """
    if args.model == 'cnn':
        model = CNN_class(d=d, m=m)
        prefix = 'class'
        epochs = max_epoch_cnn
    elif args.model == 'diff':
        model = CNN_diff(d=d, m=m)
        prefix = 'diff'
        epochs = max_epoch_diff
    else:
        raise NotImplementedError
    lr = 0.5

    # The checkpoints also hold numpy arrays (the curves), which torch.load's
    # weights_only mode (the default since PyTorch 2.6) refuses to unpickle.
    # Full unpickling is acceptable only because these files are assumed to be
    # produced locally by this project — never point this at untrusted files.
    saved = torch.load(f'{prefix}_model_checkpoint_{signal_scale}_{time}.pth',
                       map_location='cpu', weights_only=False)
    model.load_state_dict(saved['model_state_dict'])

    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    if args.model == 'diff':
        losses, feature_learning, noise_memorization = _train_diff(model, optimizer, epochs)
        checkpoint = {
            'epoch': epochs - 1,  # zero-based index of the last epoch run
            'model_state_dict': model.state_dict(),
            'loss': losses,
            'feature_learning': feature_learning,
            'noise_memorization': noise_memorization
        }
    else:
        (losses, train_accs, test_accs,
         feature_learning, noise_memorization) = _train_cnn(model, optimizer, epochs)
        checkpoint = {
            'epoch': epochs - 1,  # zero-based index of the last epoch run
            'model_state_dict': model.state_dict(),
            'loss': losses,
            'test_acc': test_accs,
            'train_acc': train_accs,
            'feature_learning': feature_learning,
            'noise_memorization': noise_memorization
        }

    # Both branches include the diffusion time in the name, like the file they
    # were loaded from, so continuations for different times do not overwrite
    # each other.
    torch.save(checkpoint, f'{prefix}_model_checkpoint_{signal_scale}_{time}_continue.pth')



def calculate_accuracy(images, labels, model):
    """Return the fraction of samples whose predicted sign matches the label.

    Args:
        images: Model input; in this script ``(N, 2, 784)`` signal/noise patches.
        labels: ``(N,)`` tensor of ``-1`` / ``+1`` targets.
        model: Module producing one score per sample (assumed; true for CNN_class).

    Returns:
        Python float in [0, 1]. A sample counts as correct only when
        ``label * score`` is strictly positive, so a zero score is wrong.

    Note: leaves ``model`` in eval mode.
    """
    model.eval()
    with torch.no_grad():
        margins = labels * model(images)
    return (margins > 0).sum().item() / images.shape[0]



def parse_args(args):
    """Parse this script's command-line options.

    Args:
        args: List of argument strings, or ``None`` to let argparse read
            ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``model`` ('cnn' or 'diff'), ``time``
        (diffusion time, float) and ``scale`` (signal scale, float).
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--model", {"type": str, "default": 'cnn', "choices": ['cnn', 'diff']}),
        ("--time", {"type": float, "default": 0.2}),
        ("--scale", {"type": float, "default": 0.1}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args(args)



def load_and_save():
    """Merge a continued diffusion run into its original checkpoint, in place.

    The first-stage file is ``diff_model_checkpoint_{signal_scale}_{time}.pth``;
    the continuation is the same name with a ``_continue`` suffix, as written by
    ``train()``. The merged checkpoint:

    * concatenates the two runs' per-epoch loss and feature/noise curves;
    * keeps the continued model weights;
    * overwrites the first-stage file.

    NOTE: not idempotent. Running it twice appends the continuation twice.
    Relies on the module-level ``signal_scale`` and ``time`` globals.
    """
    base_path = f'diff_model_checkpoint_{signal_scale}_{time}.pth'
    cont_path = f'diff_model_checkpoint_{signal_scale}_{time}_continue.pth'

    # weights_only=False: the checkpoints contain numpy arrays, which the
    # weights_only loader (default since PyTorch 2.6) rejects. These files are
    # assumed to be produced locally by this project (trusted input).
    base = torch.load(base_path, map_location='cpu', weights_only=False)
    cont = torch.load(cont_path, map_location='cpu', weights_only=False)

    merged = {
        # 'epoch' stores the zero-based index of a run's last epoch, so the
        # index of the last epoch overall is base + 1 + cont (not base + cont).
        'epoch': base['epoch'] + 1 + cont['epoch'],
        'model_state_dict': cont['model_state_dict'],
        # list() so array-valued histories concatenate instead of adding element-wise.
        'loss': list(base['loss']) + list(cont['loss']),
        'feature_learning': np.concatenate((base['feature_learning'], cont['feature_learning'])),
        'noise_memorization': np.concatenate((base['noise_memorization'], cont['noise_memorization'])),
    }

    torch.save(merged, base_path)




if __name__ == '__main__':
    # Script entry point: seed every RNG, build the two-class (0 vs 1) MNIST
    # signal+noise data set, then either resume training (train(), currently
    # commented out) or merge a finished continuation into its original
    # checkpoint (load_and_save()).
    seed = 100
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    args = parse_args(None)  # None -> argparse reads sys.argv[1:]

    # Everything assigned below is a module global that the transform, train()
    # and load_and_save() read directly.
    d = 784  # dimension for each patch
    m = 100  # Number of weight vectors
    # Shadows the stdlib module name "time"; harmless here because that module is not imported.
    time = args.time  # diffusion time
    signal_scale = args.scale
    noise_scale = 1
    max_epoch_cnn = 200
    max_epoch_diff = 10000
    n_eps = 2000  # noise draws per training image per epoch in the diffusion branch of train()

    # Forward-diffusion coefficients at time t: x_t = alpha * x_0 + beta * eps,
    # with alpha = exp(-t) and beta = sqrt(1 - exp(-2t)), so alpha^2 + beta^2 = 1.
    alpha = np.exp(-time)
    beta = np.sqrt(1 - np.exp(-2 * time))
    print(alpha, beta)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    train_sample_per_class = 50
    # NOTE(review): batchsize is only referenced by the commented-out DataLoader lines in train().
    batchsize = train_sample_per_class * 2

    # Custom transform to first normalize and then split into vertical patches
    transform = NormalizeAndSegmentVerticallyTransform()

    # Load the MNIST dataset with the custom transform
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    # Get the indices for classes 0 and 1 in training and test sets
    train_indices = filter_classes(train_dataset, samples_per_class=train_sample_per_class)
    # NOTE(review): -1 is presumably meant as "all test samples"; check how
    # filter_classes treats negative values (a plain [: -1] slice drops the last
    # sample of each class).
    test_indices = filter_classes(test_dataset, samples_per_class=-1)

    # Create subsets of the dataset
    train_dataset = Subset(train_dataset, train_indices)
    test_dataset = Subset(test_dataset, test_indices)

    # Materialise the subsets as (N, 2, 784) patch tensors and (N,) labels.
    # NOTE(review): every dataset[i] call re-runs the transform and draws new
    # noise, so each sample is transformed twice (images pass + labels pass).
    # The order of these four lines therefore fixes which seeded noise ends up
    # in train_images/test_images and what the RNG state is afterwards — keep it
    # unless you deliberately want different noise than earlier runs.
    train_images = torch.stack([train_dataset[i][0] for i in range(len(train_dataset))]).to(device)
    train_labels = torch.tensor([train_dataset[i][1] for i in range(len(train_dataset))], dtype=torch.long).to(device)
    test_images = torch.stack([test_dataset[i][0] for i in range(len(test_dataset))]).to(device)
    test_labels = torch.tensor([test_dataset[i][1] for i in range(len(test_dataset))], dtype=torch.long).to(device)
    # Map class labels {0, 1} to {-1, +1} for the sign-based loss and accuracy.
    train_labels = 2 * train_labels - 1
    test_labels = 2 * test_labels - 1

    # Uncomment to resume training; load_and_save() then merges its output.
    # train()

    load_and_save()