
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib
import torchvision
import random

from tqdm import tqdm
from tensorflow.keras.datasets import mnist
from torchvision import datasets, transforms


class CNN(nn.Module):
    """Two-branch model scoring a pair of flattened input patches.

    The module owns two weight matrices of shape ``(d, m)``: ``Wp`` and
    ``Wn``.  For each matrix, both patches are projected onto its ``m``
    columns, passed through the activation and averaged over the columns.
    The output is the ``Wp`` response minus the ``Wn`` response.

    Args:
        m: Number of columns (filters) in each weight matrix.
        d: Number of rows of each weight matrix, i.e. the input dimension.
        q: Exponent applied to the ReLU output by the non-linear activation.
        linear: When true, the activation is the identity.
    """

    def __init__(self, m=50, d=1000, q=2,linear=False):
        super().__init__()

        self.q = q
        self.linear = linear

        # Both matrices are first drawn with randn and only then
        # re-initialised, so the RNG is consumed in the order
        # Wp, Wn, Wp, Wn.  Keep this order: it determines the weights
        # obtained for a given seed.
        self.Wp = nn.Parameter(torch.randn(d, m))
        self.Wn = nn.Parameter(torch.randn(d, m))

        for weight in (self.Wp, self.Wn):
            nn.init.normal_(weight, std=0.001)

    def act(self,input):
        """Return ``input`` unchanged if ``linear`` is set, else ``relu(input) ** q``."""
        return input if self.linear else torch.pow(F.relu(input), self.q)

    def forward(self, x1, x2):
        """Return one score per row of the ``(batch, d)`` inputs ``x1`` and ``x2``."""

        def response(weight):
            # Mean activation over the filters, summed over the two patches.
            first = torch.mean(self.act(torch.mm(x1, weight)), 1)
            second = torch.mean(self.act(torch.mm(x2, weight)), 1)
            return first + second

        return response(self.Wp) - response(self.Wn)


def _build_two_patch_split(images, labels, max_samples, empty_message,
                           noise_level=5.0, signal_norm=64.0):
    """Build a balanced class-0 / class-1 subset with a signal and a noise patch.

    Args:
        images: Float tensor of images, indexed by sample along dim 0.
        labels: Integer tensor of class ids, one per image.
        max_samples: Upper bound on the subset size; at most
            ``max_samples // 2`` samples are taken from each class.
        empty_message: Message of the ``ValueError`` raised when no sample
            can be selected.
        noise_level: Scale of the Gaussian noise patch.
        signal_norm: Divisor applied to the raw pixel values.

    Returns:
        ``(x1, x2, y)`` where ``x1`` is the flattened, rescaled image,
        ``x2`` is Gaussian noise of the same shape and ``y`` holds ``-1``
        for class 0 and ``1`` for class 1.  Class-0 samples come first.

    Raises:
        ValueError: If zero samples per class would be selected.
    """
    is_class0 = (labels == 0)
    is_class1 = (labels == 1)

    per_class = min(is_class0.sum().item(), is_class1.sum().item(), max_samples // 2)
    if per_class == 0:
        raise ValueError(empty_message)

    signal = torch.cat((images[is_class0][:per_class], images[is_class1][:per_class]))
    targets = torch.cat((labels[is_class0][:per_class], labels[is_class1][:per_class]))

    x1 = signal / signal_norm
    # The noise is drawn with the un-flattened image shape, then flattened.
    x2 = noise_level * torch.randn_like(x1)

    x1 = x1.view(x1.size(0), -1)
    x2 = x2.view(x2.size(0), -1)

    # Only classes 0 and 1 are present, so mapping 0 -> -1 yields +/-1 labels.
    y = targets.clone()
    y[y == 0] = -1

    return x1, x2, y


def prepare_data():
    """Load CIFAR-10 and build the two-class, two-patch train and test sets.

    Only classes 0 and 1 (airplane and automobile) are kept.  The subset
    sizes are bounded by the module-level ``n_train`` and ``n_test``, which
    are read when the function is called.

    The raw ``uint8`` pixel arrays are used directly (cast to float and
    divided by 64).  No torchvision transform is involved: transforms only
    run on item access, and this function reads ``.data`` / ``.targets``.

    Returns:
        ``(train_x1, train_x2, train_y, test_x1, test_x2, test_y)``; see
        ``_build_two_patch_split`` for the meaning of each triple.

    Raises:
        ValueError: If either split would contain no samples.
    """
    train_dataset = torchvision.datasets.CIFAR10(root="~/cifar10_data", train=True, download=True)
    test_dataset = torchvision.datasets.CIFAR10(root="~/cifar10_data", train=False, download=True)

    # (N, H, W, C) uint8 arrays -> (N, C, H, W) float tensors.
    train_images = torch.tensor(train_dataset.data).permute(0, 3, 1, 2).float()
    train_labels = torch.tensor(train_dataset.targets)
    test_images = torch.tensor(test_dataset.data).permute(0, 3, 1, 2).float()
    test_labels = torch.tensor(test_dataset.targets)

    # Train is processed before test so the noise patches are drawn from the
    # RNG in a fixed order.
    train_x1, train_x2, train_y = _build_two_patch_split(
        train_images, train_labels, n_train,
        "Not enough data available for the selected classes.")
    test_x1, test_x2, test_y = _build_two_patch_split(
        test_images, test_labels, n_test,
        "Not enough test data available for the selected classes.")

    return train_x1, train_x2, train_y, test_x1, test_x2, test_y

# Experiment configuration.
seed = 2023  # alternative seed: 3407
n_train, n_test = 100, 1000  # upper bounds on the train / test subset sizes
d = 3 * 32 * 32  # input dimension handed to the model (a flattened 3x32x32 image)
n_epoch = 20000

# Seed both RNGs before any data or model is created.
torch.manual_seed(seed)
np.random.seed(seed)

(train_x1, train_x2, train_y,
 test_x1, test_x2, test_y) = prepare_data()



#print(train_x1[0])

#print(train_x1,train_x2,train_y)
#print(test_x1,test_x2,test_y)

sample_size = n_train

# Shuffled loader over the (signal patch, noise patch, label) triples.
data_loader = DataLoader(
    dataset=TensorDataset(train_x1, train_x2, train_y),
    batch_size=250,
    shuffle=True,
)

# Model and optimizer hyper-parameters.
width = 20
learning_rate = 0.001

model = CNN(m=width, d=d)
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

# ---- Training run with randomly flipped labels ("Label Noise GD") ----

# Per-epoch histories.  Everything stored here is detached from autograd so
# that no computation graph is kept alive across epochs.
train_f_preds = []
loss_derivatives = []

train_loss_values = []
test_loss_values = []
train_acc_values = []
test_acc_values = []
feature_learning = []  # not filled by this loop; kept so the name stays defined

# Inner products of every filter with the noise patches (width, n_train, epoch)
# and with the signal patch of training sample 0 (width, epoch), for Wp and Wn.
noise_memorization_p = np.zeros((width, n_train, n_epoch))
feature_learning_p = np.zeros((width, n_epoch))

noise_memorization_n = np.zeros((width, n_train, n_epoch))
feature_learning_n = np.zeros((width, n_epoch))

for ep in range(n_epoch):
    train_loss = 0
    n_seen = 0
    loss_d = []

    model.train()
    for sample_x1, sample_x2, sample_y in data_loader:
        optimizer.zero_grad()
        f_pred = model(sample_x1, sample_x2)

        # +1 with probability 0.85, -1 otherwise: each label is flipped
        # independently with probability 0.15 at every step.
        noise = 2 * torch.bernoulli(torch.ones_like(sample_y) * 0.85) - 1
        margin = f_pred * sample_y * noise

        # Logistic loss log(1 + exp(-margin)).
        # NOTE(review): exp overflows in float32 for strongly negative
        # margins; F.softplus(-margin) is a numerically stable equivalent.
        loss = torch.log(torch.add(torch.exp(-margin), 1)).mean()

        # Magnitude of the loss derivative w.r.t. the margin; detached
        # because it is only recorded, never back-propagated.
        loss_d.append((1 / (torch.exp(margin) + 1)).detach())

        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the real batch size so the epoch
        # average is correct for any batch size.
        batch_size = sample_y.size(0)
        train_loss += batch_size * loss.item()
        n_seen += batch_size

    model.eval()
    with torch.no_grad():
        feature_learning_p[:, ep] = torch.matmul(model.Wp.T, train_x1[0]).numpy()
        noise_memorization_p[:, :, ep] = torch.matmul(model.Wp.T, train_x2.T).numpy()

        feature_learning_n[:, ep] = torch.matmul(model.Wn.T, train_x1[0]).numpy()
        noise_memorization_n[:, :, ep] = torch.matmul(model.Wn.T, train_x2.T).numpy()

        f_pred_test = model(test_x1, test_x2)
        f_pred_train = model(train_x1, train_x2)

        # The test loss uses the clean labels.
        test_loss = torch.log(torch.add(torch.exp(-f_pred_test * test_y), 1)).mean()

    train_loss /= n_seen
    train_loss_values.append(train_loss)

    train_f_preds.append(f_pred_train)
    loss_derivatives.append(torch.cat(loss_d))

    # Predict +1 for a positive score and -1 otherwise.
    pred_binary_train = (f_pred_train > 0).float() * 2 - 1
    pred_binary_test = (f_pred_test > 0).float() * 2 - 1

    correct_preds_train = (pred_binary_train == train_y).float().mean()
    correct_preds_test = (pred_binary_test == test_y).float().mean()

    train_acc_values.append(correct_preds_train.item())
    test_acc_values.append(correct_preds_test.item())

    test_loss_values.append(test_loss.item())

    print(f'[{ep+1}|{n_epoch}] train_loss={train_loss:0.5e}, test_loss={test_loss:0.5e}')


# ---- Training run with clean labels ("GD" in the plots) ----

# Fresh model and optimizer with the same hyper-parameters.
model = CNN(m=width, d=d)

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

train_loss_values_sgd = []
test_loss_values_sgd = []
train_acc_values_sgd = []

test_acc_values_sgd = []
feature_learning = []  # not filled by this loop; kept so the name stays defined

# Inner products of every filter with the noise patches (width, n_train, epoch)
# and with the signal patch of training sample 0 (width, epoch), for Wp and Wn.
noise_memorization_p_sgd = np.zeros((width, n_train, n_epoch))
feature_learning_p_sgd = np.zeros((width, n_epoch))

noise_memorization_n_sgd = np.zeros((width, n_train, n_epoch))
feature_learning_n_sgd = np.zeros((width, n_epoch))

for ep in range(n_epoch):
    train_loss = 0
    n_seen = 0

    model.train()
    for sample_x1, sample_x2, sample_y in data_loader:
        optimizer.zero_grad()
        f_pred = model(sample_x1, sample_x2)

        # Logistic loss log(1 + exp(-y * f)).
        loss = torch.log(torch.add(torch.exp(-f_pred * sample_y), 1)).mean()

        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the real batch size so the epoch
        # average is correct for any batch size.
        batch_size = sample_y.size(0)
        train_loss += batch_size * loss.item()
        n_seen += batch_size

    # Tracking and evaluation need no gradients.
    model.eval()
    with torch.no_grad():
        feature_learning_p_sgd[:, ep] = torch.matmul(model.Wp.T, train_x1[0]).numpy()
        noise_memorization_p_sgd[:, :, ep] = torch.matmul(model.Wp.T, train_x2.T).numpy()

        feature_learning_n_sgd[:, ep] = torch.matmul(model.Wn.T, train_x1[0]).numpy()
        noise_memorization_n_sgd[:, :, ep] = torch.matmul(model.Wn.T, train_x2.T).numpy()

        f_pred_test = model(test_x1, test_x2)
        f_pred_train = model(train_x1, train_x2)

        test_loss = torch.log(torch.add(torch.exp(-f_pred_test * test_y), 1)).mean()

    train_loss /= n_seen
    train_loss_values_sgd.append(train_loss)

    # Predict +1 for a positive score and -1 otherwise.
    pred_binary_train = (f_pred_train > 0).float() * 2 - 1
    pred_binary_test = (f_pred_test > 0).float() * 2 - 1

    correct_preds_train = (pred_binary_train == train_y).float().mean()
    correct_preds_test = (pred_binary_test == test_y).float().mean()

    train_acc_values_sgd.append(correct_preds_train.item())
    test_acc_values_sgd.append(correct_preds_test.item())

    test_loss_values_sgd.append(test_loss.item())

    print(f'[{ep+1}|{n_epoch}] train_loss={train_loss:0.5e}, test_loss={test_loss:0.5e}')


import matplotlib.ticker as ticker

# Font settings for the exported figure (type 42 = TrueType in PDF/PS output).
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
matplotlib.rcParams['font.size'] = 16


def _max_abs_over_filters(positive, negative):
    """Reduce |values| over axis 0 for both arrays and take their element-wise max."""
    return np.maximum(np.max(np.abs(positive), axis=0),
                      np.max(np.abs(negative), axis=0))


# Largest absolute inner product over all filters of Wp and Wn.
# noise_*: shape (n_train, n_epoch); feature_*: shape (n_epoch,).
noise_mseris = _max_abs_over_filters(noise_memorization_p, noise_memorization_n)
feature_mseris = _max_abs_over_filters(feature_learning_p, feature_learning_n)

noise_mseris_sgd = _max_abs_over_filters(noise_memorization_p_sgd, noise_memorization_n_sgd)
feature_mseris_sgd = _max_abs_over_filters(feature_learning_p_sgd, feature_learning_n_sgd)

# Create a 1x4 grid of subplots.
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(25, 4.5))

# Panel 1: noise (training sample 0) and signal coefficients of both runs.
ax1.plot(noise_mseris_sgd[0], color='tab:red', linestyle='--', linewidth=2,
         label=r'$\max_{j,r} \rho_{j,r,i}$ (GD)')
ax1.plot(feature_mseris_sgd, linewidth=2, color='tab:red',
         label=r'$\max_{j,r} \gamma_{j,r}$ (GD)')
ax1.plot(noise_mseris[0], linewidth=2, color='tab:blue', linestyle='--',
         label=r'$\max_{j,r} \rho_{j,r,i}$ (Label Noise GD)')
ax1.plot(feature_mseris, linewidth=2, color='tab:blue',
         label=r'$\max_{j,r} \gamma_{j,r}$ (Label Noise GD)')

# Panel 2: ratio of the noise coefficient to the signal coefficient.
ax2.plot(noise_mseris[0] / feature_mseris, linewidth=2, color='tab:blue', label='Label Noise GD')
ax2.plot(noise_mseris_sgd[0] / feature_mseris_sgd, linewidth=2, color='tab:red', label='GD')

# Panel 3: training loss per epoch.
ax3.plot(train_loss_values, linewidth=2, label='Label Noise GD', color='tab:blue')
ax3.plot(train_loss_values_sgd, linewidth=2, label='GD', color='tab:red')

# Panel 4: test accuracy per epoch.
ax4.plot(test_acc_values, linewidth=2, label='Label Noise GD', color='tab:blue')
ax4.plot(test_acc_values_sgd, linewidth=2, label='GD', color='tab:red')

# Shared formatting; legends are added after plotting so they see every line.
panel_titles = (
    'Feature Learning',
    r'Feature Learning Ratio ($\rho/\gamma$)',
    'Train Loss',
    'Test Accuracy',
)
for axis, title in zip((ax1, ax2, ax3, ax4), panel_titles):
    # Limit the x axis to at most 3 tick intervals.
    axis.xaxis.set_major_locator(ticker.MaxNLocator(nbins=3))
    axis.set_title(title)
    axis.set_xlabel('t', fontsize=20)
    axis.tick_params(axis='both', which='major', labelsize=20)
    axis.legend()

fig.savefig('cifar_patch.png', dpi=300, bbox_inches='tight')

plt.show()