
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib
import torchvision
import random
import os

from tqdm import tqdm
from torchvision import datasets, transforms

import argparse

# Command-line interface: both experiment knobs are required floats.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--noise_level",
    type=float,
    required=True,
    help="Noise level",
)
parser.add_argument(
    "--signal_norm",
    type=float,
    required=True,
    help="Signal normalization factor",
)
args = parser.parse_args()

# noise_level scales the standard-normal noise patch; signal_norm divides raw pixels.
noise_level, signal_norm = args.noise_level, args.signal_norm

# Results are appended as "noise_level,signal_norm,test_accuracy" rows.
# "GD_results.csv" was presumably used for the plain-GD (no label noise) variant.
output_file = "LNGD_results.csv"


class CNN(nn.Module):
    """Two-patch network with a positive (``Wp``) and a negative (``Wn``) filter bank.

    Both patches ``x1`` and ``x2`` go through the same filters. The output is the
    filter-averaged activation of the ``Wp`` bank (summed over both patches) minus
    the same quantity for the ``Wn`` bank.

    Args:
        m: Number of filters per bank (network width).
        d: Patch dimension (number of rows of each weight matrix).
        q: Exponent of the polynomial ReLU activation ``relu(z) ** q``.
        linear: If True, use the identity activation instead.
    """

    def __init__(self, m=50, d=1000, q=2, linear=False):
        super().__init__()
        self.q = q
        self.linear = linear

        # Both banks are first drawn with randn and then re-initialised, so the
        # global RNG stream is consumed exactly as before (seeded runs match).
        self.Wp = nn.Parameter(torch.randn(d, m))
        self.Wn = nn.Parameter(torch.randn(d, m))
        for weight in (self.Wp, self.Wn):
            nn.init.normal_(weight, std=0.001)

    def act(self, input):
        """Return ``input`` unchanged if ``linear``, else ``relu(input) ** q``."""
        return input if self.linear else F.relu(input).pow(self.q)

    def _bank_response(self, weight, x1, x2):
        """Sum over both patches of the filter-averaged activation for one bank."""
        per_patch = [self.act(torch.mm(x, weight)).mean(dim=1) for x in (x1, x2)]
        return per_patch[0] + per_patch[1]

    def forward(self, x1, x2):
        """Return the positive-bank response minus the negative-bank response."""
        positive = self._bank_response(self.Wp, x1, x2)
        negative = self._bank_response(self.Wn, x1, x2)
        return positive - negative


def _two_patch_split(images, labels, n_total, signal_norm, noise_level):
    """Build the two-patch binary (digit 0 vs digit 1) tensors for one MNIST split.

    Keeps the first ``n_total // 2`` images of digit 0 followed by the first
    ``n_total // 2`` images of digit 1 (fewer if the split has fewer).

    Args:
        images: Raw image tensor of the split, one image per leading index.
        labels: Digit labels aligned with ``images``.
        n_total: Requested number of examples, split evenly between the digits.
        signal_norm: Divisor applied to the raw pixel values to form patch 1.
        noise_level: Scale of the standard-normal noise forming patch 2.

    Returns:
        ``(x1, x2, y)``: flattened signal patches, flattened noise patches of the
        same shape, and labels with digit 0 mapped to -1 and digit 1 kept as 1.
    """
    half = n_total // 2
    is_zero = labels == 0
    is_one = labels == 1

    pixels = images.type(torch.float32)
    data = torch.cat((pixels[is_zero][:half], pixels[is_one][:half]))
    targets = torch.cat((labels[is_zero][:half], labels[is_one][:half]))

    x1 = data / signal_norm
    x2 = noise_level * torch.randn_like(x1)  # draws from the global torch RNG

    # Digit 0 -> -1; digit 1 already equals its +1 label.
    y = targets.clone()
    y[y == 0] = -1
    # flatten(1) == view(n, -1) but also works when a split has zero rows.
    return x1.flatten(1), x2.flatten(1), y


def prepare_data(root="~/mnist_data"):
    """Load MNIST digits 0/1 and build the two-patch train/test tensors.

    Every example has a signal patch (the flattened image divided by the
    module-level ``signal_norm``) and a noise patch (``noise_level`` times
    standard-normal noise). The train set has up to the module-level
    ``n_train`` examples and the test set up to ``n_test``, split evenly
    between digits 0 and 1.

    Args:
        root: MNIST directory. The files must already exist (``download=False``).
            A leading ``~`` is expanded to the user's home directory.

    Returns:
        ``(train_x1, train_x2, train_y, test_x1, test_x2, test_y)``.
    """
    # Both splits share one (expanded) root so they always come from the same directory.
    root = os.path.expanduser(root)
    # The transform is never applied: the raw ``.data``/``.targets`` tensors are used below.
    to_tensor = transforms.Compose([transforms.ToTensor()])
    train_dataset = torchvision.datasets.MNIST(root=root, train=True, transform=to_tensor, download=False)
    test_dataset = torchvision.datasets.MNIST(root=root, train=False, transform=to_tensor, download=False)

    # Train is built before test so the noise draws keep the same RNG order.
    train_x1, train_x2, train_y = _two_patch_split(
        train_dataset.data, train_dataset.targets, n_train, signal_norm, noise_level
    )
    test_x1, test_x2, test_y = _two_patch_split(
        test_dataset.data, test_dataset.targets, n_test, signal_norm, noise_level
    )
    return train_x1, train_x2, train_y, test_x1, test_x2, test_y

# Experiment configuration (3407 is an alternative seed that was used before).
seed = 2023
n_train = 100
n_test = 1000
d = 784  # flattened 28x28 MNIST image = dimension of each patch
n_epoch = 200000

np.random.seed(seed)
torch.manual_seed(seed)

train_x1, train_x2, train_y, test_x1, test_x2, test_y = prepare_data()

sample_size = n_train
# batch_size > n_train, so every epoch is one shuffled full batch (full-batch GD).
batch_size = 250
data_loader = DataLoader(TensorDataset(
    train_x1,
    train_x2,
    train_y
), batch_size=batch_size, shuffle=True)

width = 20
learning_rate = 0.001
# Each training label is kept with this probability at every step (flipped otherwise).
label_keep_prob = 0.85

model = CNN(m=width, d=d)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)


def _logistic_loss(margin):
    """Return the element-wise logistic loss ``log(1 + exp(-margin))``.

    ``F.softplus`` evaluates the same function but switches to its linear branch
    for large arguments. ``exp`` therefore cannot overflow to ``inf``; an overflow
    would make the backward pass produce NaN gradients and corrupt the weights.
    """
    return F.softplus(-margin)


train_f_preds = []      # per-epoch model outputs on the full training set
loss_derivatives = []   # per-epoch sigmoid(-margin) = -dloss/dmargin for each sample

train_loss_values = []
test_loss_values = []
train_acc_values = []

test_acc_values = []
feature_learning = []
# The recorded values are float32 tensors, so float32 storage is lossless and halves
# memory (each noise array is still width * n_train * n_epoch * 4 bytes, about 1.6 GB).
noise_memorization_p = np.zeros((width, n_train, n_epoch), dtype=np.float32)
feature_learning_p = np.zeros((width, n_epoch), dtype=np.float32)

noise_memorization_n = np.zeros((width, n_train, n_epoch), dtype=np.float32)
feature_learning_n = np.zeros((width, n_epoch), dtype=np.float32)

for ep in range(n_epoch):
    train_loss = 0.0
    loss_d = []

    model.train()
    for sample_x1, sample_x2, sample_y in data_loader:
        optimizer.zero_grad()
        f_pred = model(sample_x1, sample_x2)
        # Label noise: +1 with probability label_keep_prob, -1 (flipped label) otherwise.
        noise = 2 * torch.bernoulli(torch.ones_like(sample_y) * label_keep_prob) - 1
        margin = f_pred * sample_y * noise
        loss = _logistic_loss(margin).mean()
        # Detached so the stored diagnostics do not keep autograd graphs alive.
        loss_d.append(torch.sigmoid(-margin).detach())

        loss.backward()
        optimizer.step()
        # Weight by the real batch size so the epoch mean is right for any batch_size.
        train_loss += len(sample_y) * loss.item()
    model.eval()

    train_loss /= len(data_loader.dataset)
    train_loss_values.append(train_loss)

    # Diagnostics only: no autograd graph is needed (or retained) here.
    with torch.no_grad():
        # Inner products of every filter with the first training signal patch
        # and with every training noise patch.
        feature_learning_p[:, ep] = torch.matmul(model.Wp.T, train_x1[0]).numpy()
        noise_memorization_p[:, :, ep] = torch.matmul(model.Wp.T, train_x2.T).numpy()

        feature_learning_n[:, ep] = torch.matmul(model.Wn.T, train_x1[0]).numpy()
        noise_memorization_n[:, :, ep] = torch.matmul(model.Wn.T, train_x2.T).numpy()

        f_pred_test = model(test_x1, test_x2)
        f_pred_train = model(train_x1, train_x2)
        test_loss = _logistic_loss(f_pred_test * test_y).mean()

    train_f_preds.append(f_pred_train)
    loss_derivatives.append(torch.cat(loss_d))

    # Map raw outputs to +-1 predictions and compare them with the +-1 labels.
    pred_binary_train = (f_pred_train > 0).float() * 2 - 1
    pred_binary_test = (f_pred_test > 0).float() * 2 - 1

    correct_preds_train = (pred_binary_train == train_y).float().mean()
    correct_preds_test = (pred_binary_test == test_y).float().mean()

    train_acc_values.append(correct_preds_train.item())
    test_acc_values.append(correct_preds_test.item())
    test_loss_values.append(test_loss.item())

# Append one CSV row with the final test accuracy as a plain float
# (formatting the 0-d tensor itself would write "tensor(...)" into the CSV).
final_test_acc = test_acc_values[-1] if test_acc_values else float("nan")
with open(output_file, "a", encoding="utf-8") as f:
    f.write(f"{noise_level},{signal_norm},{final_test_acc}\n")

#print(f"Saved results to {output_file}")

# model = CNN(m=width, d=d)
#
# optimizer = torch.optim.SGD(model.parameters(), lr =learning_rate)
#
# train_loss_values_sgd = []
# test_loss_values_sgd = []
# train_acc_values_sgd = []
#
# test_acc_values_sgd = []
# feature_learning = []
#
# noise_memorization_p_sgd = np.zeros(( width, n_train, n_epoch))
# feature_learning_p_sgd = np.zeros(( width,  n_epoch))
#
# noise_memorization_n_sgd = np.zeros(( width, n_train, n_epoch))
# feature_learning_n_sgd = np.zeros(( width,  n_epoch))
#
# for ep in range(n_epoch):
#     train_loss = 0
#     for sample_x1, sample_x2, sample_y in data_loader:
#
#         model.train()
#         optimizer.zero_grad()
#         f_pred = model.forward(sample_x1, sample_x2)
#         loss = torch.log(torch.add(torch.exp(-f_pred * sample_y ), 1)).mean()
#
#         loss.backward()
#         optimizer.step()
#         model.eval()
#         train_loss += sample_size * loss.item()
#
#
#     feature_learning_p_sgd[:, ep] =  (torch.matmul(model.Wp.T, train_x1[0])).detach().numpy()
#     noise_memorization_p_sgd[:,:, ep] =  (torch.matmul(model.Wp.T, train_x2.T)).detach().numpy()
#
#     feature_learning_n_sgd[:, ep] =  (torch.matmul(model.Wn.T, train_x1[0])).detach().numpy()
#     noise_memorization_n_sgd[:,:, ep] =  (torch.matmul(model.Wn.T, train_x2.T)).detach().numpy()
#
#
#     train_loss /= n_train
#     train_loss_values_sgd.append(train_loss)
#     f_pred_test = model.forward(test_x1, test_x2)
#     f_pred_train = model.forward(train_x1, train_x2)
#
#     pred_binary_train = (f_pred_train > 0).float() * 2 - 1
#     pred_binary_test = (f_pred_test > 0).float() * 2 - 1
#
#
#     correct_preds_train = (pred_binary_train == train_y).float().mean()
#     correct_preds_test = (pred_binary_test == test_y).float().mean()
#
#     train_acc_values_sgd.append(correct_preds_train.item())
#     test_acc_values_sgd.append(correct_preds_test.item())
#
#     test_loss = torch.log(torch.add(torch.exp(-f_pred_test * test_y), 1)).mean()
#     test_loss_values_sgd.append(test_loss.item())
#
#     # print(f'[{ep+1}|{n_epoch}] train_loss={train_loss:0.5e}, test_loss={test_loss:0.5e}')
#
# with open(output_file, "a") as f:
#     f.write(f"{noise_level},{signal_norm},{correct_preds_test}\n")
#


# matplotlib.rcParams['pdf.fonttype'] = 42
# matplotlib.rcParams['ps.fonttype'] = 42
# matplotlib.rcParams['font.size'] = 16
#
# # Create a 1x4 grid of subplots
# fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(25, 4.5))
#
# # Plot the matrix_norm_numpy values in the first subplot
#
# feature_learning_tensor = torch.tensor(feature_learning)
# feature_learning_numpy_array = feature_learning_tensor.detach().numpy()

# Plot the loss_values in the first subplot

# noise_pseris = np.max(np.abs(noise_memorization_p), axis=0)
# noise_nseris = np.max(np.abs(noise_memorization_n), axis=0)
# noise_mseris = np.maximum(noise_pseris, noise_nseris)
#
# feature_pseris = np.max(np.abs(feature_learning_p), axis=0)
# feature_nseris = np.max(np.abs(feature_learning_n), axis=0)
# feature_mseris = np.maximum(feature_pseris, feature_nseris)
#
#
# noise_pseris_sgd = np.max(np.abs(noise_memorization_p_sgd), axis=0)
# noise_nseris_sgd = np.max(np.abs(noise_memorization_n_sgd), axis=0)
# noise_mseris_sgd = np.maximum(noise_pseris_sgd, noise_nseris_sgd)
#
# feature_pseris_sgd = np.max(np.abs(feature_learning_p_sgd), axis=0)
# feature_nseris_sgd = np.max(np.abs(feature_learning_n_sgd), axis=0)
# feature_mseris_sgd = np.maximum(feature_pseris_sgd, feature_nseris_sgd)
#
#
#
# ax1.plot((noise_mseris_sgd[0].T), color = 'tab:red',linestyle='--', linewidth =2, label = r'$\max_{j,r} \rho_{j,r,i}$ (GD)' )
# ax1.plot((feature_mseris_sgd), linewidth =2, label = r'$\max_{j,r} \gamma_{j,r}$ (GD)', color='tab:red')
# ax1.plot((noise_mseris[0].T), linewidth =2, color = 'tab:blue', linestyle='--', label = r'$\max_{j,r} \rho_{j,r,i}$ (Label Noise GD)' )
# ax1.plot((feature_mseris), linewidth =2, label = r'$\max_{j,r} \gamma_{j,r}$ (Label Noise GD)', color='tab:blue')
#
#
# ax1.set_title('Feature Learning')
# ax1.set_xlabel('t',fontsize=20)
# ax1.tick_params(axis='both', which='major', labelsize=20)
# ax1.legend()
#
#
# ax2.plot(noise_mseris[0].T/feature_mseris, linewidth =2, color = 'tab:blue', label='Label Noise GD')
# ax2.plot(noise_mseris_sgd[0].T/feature_mseris_sgd, linewidth =2, color = 'tab:red', label='GD')
# ax2.set_title(r'Feature Learning Ratio ($\rho/\gamma$)')
# ax2.set_xlabel('t',fontsize=20)
# ax2.tick_params(axis='both', which='major', labelsize=20)
# ax2.legend()
#
# ax3.plot(train_loss_values, linewidth =2, label='Label Noise GD', color = 'tab:blue')
# ax3.plot(train_loss_values_sgd, linewidth =2,label='GD',  color='tab:red')
# ax3.set_title('Train Loss')
# ax3.set_xlabel('t',fontsize=20)
#
# ax3.legend()
# ax3.tick_params(axis='both', which='major', labelsize=20)
# ax4.plot(test_acc_values,linewidth =2, label='Label Noise GD',color = 'tab:blue')
# ax4.plot(test_acc_values_sgd, linewidth =2, label='GD', color='tab:red')
# ax4.set_title('Test Accuracy')
# ax4.set_xlabel('t',fontsize=20)
# ax4.legend()
# ax4.tick_params(axis='both', which='major', labelsize=20)
# fig.savefig('mnist_patch.png', dpi=300, bbox_inches='tight')
#
# plt.show()