import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import random
from torch import nn
import torch.nn.functional  as F
from sklearn.linear_model import LogisticRegression

class NN(nn.Module):
    """Single weight matrix applied to two inputs, followed by a shared activation.

    The only parameter is ``W`` with shape ``(d, m)``. ``forward`` multiplies
    each of its two inputs by ``W`` and passes the result through ``act``.

    Args:
        m: Number of columns of ``W`` (output width).
        d: Number of rows of ``W`` (input dimension).
        q: Exponent applied to the ReLU output in ``act``.
        linear: When true, ``act`` returns its input unchanged.
    """

    def __init__(self, m=50, d=2000, q=1, linear=False):
        super().__init__()

        self.linear = linear
        self.q = q

        # The standard-normal draw is overwritten right below by the
        # small-variance initialisation; only its shape matters.
        starting_values = torch.randn(d, m)
        self.W = nn.Parameter(starting_values)
        nn.init.normal_(self.W, std=0.0001)

    def act(self, input):
        """Return ``relu(input) ** q``, or ``input`` itself in linear mode."""
        if not self.linear:
            rectified = F.relu(input)
            return rectified.pow(self.q)
        return input

    def forward(self, x1, x2, verbose=False):
        """Return ``(act(x1 @ W), act(x2 @ W))``; ``verbose`` is not used."""
        projected_first = x1.mm(self.W)
        projected_second = x2.mm(self.W)
        return self.act(projected_first), self.act(projected_second)



def clip_prepare_data(n_train=None, n_test=None, d=None):
    """Generate the paired "vis"/"txt" training tensors for the CLIP-style run.

    Every sample has a signal part ``x1`` (label times a one-hot feature
    vector) and a noise part ``x2`` (i.i.d. Gaussian, scaled by ``xi = 1.0``).
    The txt feature sits on coordinate 0 with magnitude ``3 * mu = 15`` and the
    vis feature sits on coordinate 1 with magnitude ``mu = 5``. Labels are +1
    for the first ``n_train // 2`` samples and -1 for the rest.

    Args:
        n_train: Number of training samples. Defaults to the module-level
            ``n_train`` so existing zero-argument calls keep working.
        n_test: Size of the held-out noise draws (see the comment in the
            body). Defaults to the module-level ``n_test``.
        d: Input dimension (must be at least 2). Defaults to the module-level
            ``d``.

    Returns:
        ``(vis_train_x1, vis_train_x2, vis_train_y,
        txt_train_x1, txt_train_x2, txt_train_y)`` where the ``x`` tensors
        have shape ``(n_train, d)`` and the ``y`` tensors shape ``(n_train,)``.

    Raises:
        KeyError: If a size is omitted and the matching module-level name is
            not defined.
    """
    module_settings = globals()
    if n_train is None:
        n_train = module_settings["n_train"]
    if n_test is None:
        n_test = module_settings["n_test"]
    if d is None:
        d = module_settings["d"]

    xi = 1.0
    mu = 5.0

    def balanced_labels(count):
        # The second half absorbs the extra sample when count is odd, so the
        # label vector always has exactly `count` entries.
        n_pos = int(count) // 2
        return torch.cat((torch.ones(n_pos), -torch.ones(int(count) - n_pos)))

    vis_train_y = balanced_labels(n_train)
    txt_train_y = balanced_labels(n_train)

    txt_feature = torch.zeros(d)
    txt_feature[0] = mu * 3
    txt_train_x1 = torch.outer(txt_train_y, txt_feature)
    txt_train_x2 = torch.randn(n_train, d) * xi
    # The held-out split is never returned. Its noise is still drawn so that
    # the random stream, and therefore any seeded run, is unchanged.
    torch.randn(n_test, d)

    vis_feature = torch.zeros(d)
    vis_feature[1] = mu
    vis_train_x1 = torch.outer(vis_train_y, vis_feature)
    vis_train_x2 = torch.randn(n_train, d) * xi
    torch.randn(n_test, d)  # unused held-out draw, kept for the same reason

    return vis_train_x1, vis_train_x2, vis_train_y, txt_train_x1, txt_train_x2, txt_train_y



def create_test_data(d, n_train=None, n_test=None):
    """Generate the downstream-task data used to fit and score the linear probe.

    Samples have a signal part ``x1`` (label times a feature vector that is
    ``mu = 2`` on coordinate 1 and zero elsewhere) and a noise part ``x2``
    (i.i.d. Gaussian scaled by ``xi = 1``). In each split the first half of
    the labels is +1 and the rest is -1.

    Args:
        d: Input dimension (must be at least 2).
        n_train: Size of the probe-training split. Defaults to the
            module-level ``n_train`` so existing ``create_test_data(d)`` calls
            keep working.
        n_test: Size of the probe-test split. Defaults to the module-level
            ``n_test``.

    Returns:
        ``(train_x1, train_x2, train_y, test_x1, test_x2, test_y)``; the ``x``
        tensors have shape ``(split_size, d)`` and the ``y`` tensors shape
        ``(split_size,)``.

    Raises:
        KeyError: If a size is omitted and the matching module-level name is
            not defined.
    """
    module_settings = globals()
    if n_train is None:
        n_train = module_settings["n_train"]
    if n_test is None:
        n_test = module_settings["n_test"]

    mu = 2
    xi = 1

    def balanced_labels(count):
        # The second half absorbs the extra sample when count is odd, so the
        # labels always line up with the `count` noise rows.
        n_pos = int(count) // 2
        return torch.cat((torch.ones(n_pos), -torch.ones(int(count) - n_pos)))

    vis_train_y = balanced_labels(n_train)
    vis_test_y = balanced_labels(n_test)

    vis_feature = torch.zeros(d)
    vis_feature[1] = mu
    vis_train_x1 = torch.outer(vis_train_y, vis_feature)
    vis_test_x1 = torch.outer(vis_test_y, vis_feature)
    # Train noise is drawn before test noise; keep this order so seeded runs
    # see the same values.
    vis_train_x2 = torch.randn(n_train, d) * xi
    vis_test_x2 = torch.randn(n_test, d) * xi

    return vis_train_x1, vis_train_x2, vis_train_y, vis_test_x1, vis_test_x2, vis_test_y







def main_clip(n_train, n_test, d, n_epoch, tau, width, eval_every):
    """Train two encoders with a CLIP-style contrastive loss and log per-epoch statistics.

    One ``NN`` encodes the "vis" modality and a second ``NN`` encodes the
    "txt" modality. Each epoch runs SGD (learning rate 0.01, batch size
    ``n_train``) on the sum of the txt->vis and vis->txt contrastive losses.
    After the update a logistic-regression probe is fit on the vis encoder's
    embeddings of the data from ``create_test_data`` and scored on its test
    split.

    Args:
        n_train: Batch size and middle dimension of the noise-tracking array.
            The training tensors themselves come from ``clip_prepare_data()``,
            which is called without arguments.
        n_test: Not read by this function.
        d: Input dimension of both encoders; forwarded to ``create_test_data``.
        n_epoch: Number of epochs.
        tau: Temperature dividing the positive-pair logits.
        width: Number of columns of each encoder's weight matrix.
        eval_every: Not read; the probe is fit after every epoch.

    Returns:
        Tuple of four NumPy arrays:

        * ``(width, n_epoch)``: ``W^T x1[0]`` of the vis encoder per epoch.
        * ``(width, n_train, n_epoch)``: ``W^T x2^T`` of the vis encoder per
          epoch.
        * ``(n_epoch,)``: training loss of the last batch of each epoch.
        * ``(n_epoch,)``: probe accuracy on the test split (the fraction of
          correct predictions, despite the ``test_loss`` name).

        Training stops at the first NaN loss. Epochs that were not completed
        keep their fill value: 0 in the two tracking arrays and NaN in the two
        curves, so all four arrays always span ``n_epoch``.
    """
    vis_train_x1, vis_train_x2, vis_train_y, txt_train_x1, txt_train_x2, txt_train_y = clip_prepare_data()
    test_train_x1, test_train_x2, test_train_y, test_test_x1, test_test_x2, test_test_y = create_test_data(d)

    # model1 (vis) is created before model2 (txt); the order fixes which random
    # numbers each initialisation receives.
    model1 = NN(m=width, d=d)
    data_loader1 = DataLoader(TensorDataset(
        vis_train_x1,
        vis_train_x2,
        vis_train_y
    ), batch_size=int(n_train), shuffle=False)

    model2 = NN(m=width, d=d)
    data_loader2 = DataLoader(TensorDataset(
        txt_train_x1,
        txt_train_x2,
        txt_train_y
    ), batch_size=int(n_train), shuffle=False)

    optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.01)
    optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.01)

    vis_noise_memorization = np.zeros((width, n_train, n_epoch))
    vis_feature_learning = np.zeros((width, n_epoch))

    # Fixed-length histories: an early NaN exit leaves the tail as NaN instead
    # of returning arrays shorter than n_epoch.
    train_loss_values = np.full(n_epoch, np.nan)
    test_loss_values = np.full(n_epoch, np.nan)

    # The probe works with {0, 1} labels; these do not change across epochs.
    test_train_y_01 = test_train_y.cpu().numpy()
    test_train_y_01 = np.where(test_train_y_01 == -1, 0, test_train_y_01)
    test_test_y_01 = test_test_y.cpu().numpy()
    test_test_y_01 = np.where(test_test_y_01 == -1, 0, test_test_y_01)

    for ep in range(n_epoch):
        model1.train()
        model2.train()
        for (vis_sample_x1, vis_sample_x2, vis_sample_y), (txt_sample_x1, txt_sample_x2, txt_sample_y) in zip(data_loader1, data_loader2):
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            # mask[i, j] is True when sample i (vis) and sample j (txt) carry
            # different labels; only those pairs count as negatives.
            mask = torch.logical_not(torch.eq(vis_sample_y.unsqueeze(1), txt_sample_y.unsqueeze(0)))

            txt_f_pred_mu, txt_f_pred_xi = model2(txt_sample_x1, txt_sample_x2)
            vis_f_pred_mu, vis_f_pred_xi = model1(vis_sample_x1, vis_sample_x2)

            # Positive-pair logits are row-wise dot products of shape
            # (batch, 1). They must match the (batch, 1) negative sums: a
            # (batch, 1, 1) positive term would broadcast into a
            # (batch, batch, 1) grid that pairs every sample's positive with
            # every sample's negatives. The partner modality is detached, so
            # each term only sends gradients into its own encoder.
            txt_pos_logits = ((txt_f_pred_mu * vis_f_pred_mu.detach()).sum(dim=1, keepdim=True) / tau
                              + (txt_f_pred_xi * vis_f_pred_xi.detach()).sum(dim=1, keepdim=True) / tau)
            txt_sim_pos = torch.exp(txt_pos_logits)
            # NOTE(review): the negative logits are not divided by tau, and the
            # mask is applied inside exp(), so every masked pair adds exp(0)=1
            # to the sum. Both are kept as found -- confirm they are intended.
            txt_sim_neg_logits = (torch.matmul(txt_f_pred_mu, vis_f_pred_mu.detach().T)
                                  + torch.matmul(txt_f_pred_xi, vis_f_pred_xi.detach().T))
            txt_sim_neg = torch.exp(mask * txt_sim_neg_logits).sum(1, keepdim=True)

            vis_pos_logits = ((vis_f_pred_mu * txt_f_pred_mu.detach()).sum(dim=1, keepdim=True) / tau
                              + (vis_f_pred_xi * txt_f_pred_xi.detach()).sum(dim=1, keepdim=True) / tau)
            vis_sim_pos = torch.exp(vis_pos_logits)
            vis_sim_neg_logits = (torch.matmul(vis_f_pred_mu, txt_f_pred_mu.detach().T)
                                  + torch.matmul(vis_f_pred_xi, txt_f_pred_xi.detach().T))
            vis_sim_neg = torch.exp(mask * vis_sim_neg_logits).sum(1, keepdim=True)

            loss = (-torch.log(txt_sim_pos / (txt_sim_pos + txt_sim_neg)).mean()
                    - torch.log(vis_sim_pos / (vis_sim_pos + vis_sim_neg)).mean())

            if torch.isnan(loss):
                break

            loss.backward()
            optimizer2.step()
            optimizer1.step()

        # A NaN loss ends training; skip the probe for this unfinished epoch.
        if torch.isnan(loss):
            break

        # evaluate on downstream task
        model1.eval()
        model2.eval()
        with torch.no_grad():
            test_pred_mu, test_pred_xi = model1(test_train_x1, test_train_x2)
            test2_pred_mu, test2_pred_xi = model1(test_test_x1, test_test_x2)
            test_train_embed = test_pred_mu + test_pred_xi
            test_test_embed = test2_pred_mu + test2_pred_xi
            logistic_model = LogisticRegression(solver="liblinear")
            logistic_model.fit(test_train_embed.cpu().numpy(), test_train_y_01)

            # Fraction of correct predictions (an accuracy, not a loss).
            test_test_pred = logistic_model.predict(test_test_embed.cpu().numpy())
            test_loss = (test_test_pred == test_test_y_01).sum() / test_test_pred.shape[0]

        vis_feature_learning[:, ep] = (torch.matmul(model1.W.T, vis_train_x1[0])).detach().numpy()
        vis_noise_memorization[:, :, ep] = (torch.matmul(model1.W.T, vis_train_x2.T)).detach().numpy()

        train_loss_values[ep] = loss.item()
        test_loss_values[ep] = test_loss

        print(f'[{ep + 1}|{n_epoch}] train_loss={loss:0.5e}, test_loss={test_loss:0.5e}')

    return vis_feature_learning, vis_noise_memorization, train_loss_values, test_loss_values








def simclr_prepare_data(n_train=None, n_test=None, d=None):
    """Generate the single-modality training tensors for the SimCLR-style run.

    Every sample has a signal part ``x1`` (label times a feature vector that
    is ``mu = 5`` on coordinate 0 and zero elsewhere) and a noise part ``x2``
    (i.i.d. Gaussian scaled by ``xi = 1.0``). Labels are +1 for the first
    ``n_train // 2`` samples and -1 for the rest.

    Args:
        n_train: Number of training samples. Defaults to the module-level
            ``n_train`` so existing zero-argument calls keep working.
        n_test: Size of the held-out noise draw (see the comment in the body).
            Defaults to the module-level ``n_test``.
        d: Input dimension. Defaults to the module-level ``d``.

    Returns:
        ``(vis_train_x1, vis_train_x2, vis_train_y)`` with shapes
        ``(n_train, d)``, ``(n_train, d)`` and ``(n_train,)``.

    Raises:
        KeyError: If a size is omitted and the matching module-level name is
            not defined.
    """
    module_settings = globals()
    if n_train is None:
        n_train = module_settings["n_train"]
    if n_test is None:
        n_test = module_settings["n_test"]
    if d is None:
        d = module_settings["d"]

    xi = 1.0
    mu = 5.0

    # The second half absorbs the extra sample when n_train is odd, so labels
    # always line up with the n_train noise rows.
    n_pos = int(n_train) // 2
    vis_train_y = torch.cat((torch.ones(n_pos), -torch.ones(int(n_train) - n_pos)))

    vis_feature = torch.zeros(d)
    vis_feature[0] = mu
    vis_train_x1 = torch.outer(vis_train_y, vis_feature)
    vis_train_x2 = torch.randn(n_train, d) * xi
    # The held-out split is never returned. Its noise is still drawn so that
    # the random stream, and therefore any seeded run, is unchanged.
    torch.randn(n_test, d)

    return vis_train_x1, vis_train_x2, vis_train_y








def main_simclr(n_train, n_test, d, n_epoch, tau, width, eval_every):
    """Train one encoder with a SimCLR-style contrastive loss and log per-epoch statistics.

    Each epoch runs SGD (learning rate 0.01, batch size ``n_train``) on a
    contrastive loss between every sample and an augmented copy of itself.
    The copy keeps the signal part and adds Gaussian noise with standard
    deviation 0.1 to the noise part. After the update a logistic-regression
    probe is fit on the encoder's embeddings of the data from
    ``create_test_data`` and scored on its test split.

    Args:
        n_train: Batch size and middle dimension of the noise-tracking array.
            The training tensors themselves come from
            ``simclr_prepare_data()``, which is called without arguments.
        n_test: Not read by this function.
        d: Input dimension of the encoder; forwarded to ``create_test_data``.
        n_epoch: Number of epochs.
        tau: Temperature dividing the positive-pair logits.
        width: Number of columns of the encoder's weight matrix.
        eval_every: Not read; the probe is fit after every epoch.

    Returns:
        Tuple of four NumPy arrays:

        * ``(width, n_epoch)``: ``W^T x1[0]`` per epoch.
        * ``(width, n_train, n_epoch)``: ``W^T x2^T`` per epoch.
        * ``(n_epoch,)``: training loss of the last batch of each epoch.
        * ``(n_epoch,)``: probe accuracy on the test split (the fraction of
          correct predictions, despite the ``test_loss`` name).

        Training stops at the first NaN loss. Epochs that were not completed
        keep their fill value: 0 in the two tracking arrays and NaN in the two
        curves, so all four arrays always span ``n_epoch``.
    """
    train_x1, train_x2, train_y = simclr_prepare_data()
    test_train_x1, test_train_x2, test_train_y, test_test_x1, test_test_x2, test_test_y = create_test_data(d)

    model = NN(m=width, d=d)
    data_loader = DataLoader(TensorDataset(
        train_x1,
        train_x2,
        train_y
    ), batch_size=int(n_train), shuffle=False)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    noise_memorization = np.zeros((width, n_train, n_epoch))
    feature_learning = np.zeros((width, n_epoch))

    # Fixed-length histories: an early NaN exit leaves the tail as NaN instead
    # of returning arrays shorter than n_epoch.
    sim_train_loss_values = np.full(n_epoch, np.nan)
    sim_test_loss_values = np.full(n_epoch, np.nan)

    # The probe works with {0, 1} labels; these do not change across epochs.
    test_train_y_01 = test_train_y.cpu().numpy()
    test_train_y_01 = np.where(test_train_y_01 == -1, 0, test_train_y_01)
    test_test_y_01 = test_test_y.cpu().numpy()
    test_test_y_01 = np.where(test_test_y_01 == -1, 0, test_test_y_01)

    for ep in range(n_epoch):
        model.train()
        for sample_x1, sample_x2, sample_y in data_loader:
            optimizer.zero_grad()

            # mask[i, j] is True when samples i and j carry different labels;
            # only those pairs count as negatives.
            mask = torch.logical_not(torch.eq(sample_y.unsqueeze(1), sample_y.unsqueeze(0)))

            f_pred_mu, f_pred_xi = model(sample_x1, sample_x2)

            # obtain SimCLR augmented samples: same signal, perturbed noise
            aug_x1 = sample_x1
            aug_x2 = sample_x2 + torch.randn_like(sample_x2) * 0.1
            f_pred_aug_mu, f_pred_aug_xi = model(aug_x1, aug_x2)

            # Positive-pair logits are row-wise dot products of shape
            # (batch, 1). They must match the (batch, 1) negative sum: a
            # (batch, 1, 1) positive term would broadcast into a
            # (batch, batch, 1) grid that pairs every sample's positive with
            # every sample's negatives. The augmented view is detached, so
            # gradients flow through the clean view only.
            pos_logits = ((f_pred_mu * f_pred_aug_mu.detach()).sum(dim=1, keepdim=True) / tau
                          + (f_pred_xi * f_pred_aug_xi.detach()).sum(dim=1, keepdim=True) / tau)
            sim_pos = torch.exp(pos_logits)

            # NOTE(review): the negative logits are not divided by tau, and the
            # mask is applied inside exp(), so every masked pair adds exp(0)=1
            # to the sum. Both are kept as found -- confirm they are intended.
            sim_neg_logits = (torch.matmul(f_pred_mu, f_pred_aug_mu.detach().T)
                              + torch.matmul(f_pred_xi, f_pred_aug_xi.detach().T))
            sim_neg = torch.exp(mask * sim_neg_logits).sum(1, keepdim=True)

            loss = -torch.log(sim_pos / (sim_pos + sim_neg)).mean()

            if torch.isnan(loss):
                break

            loss.backward()
            optimizer.step()

        # A NaN loss ends training; skip the probe for this unfinished epoch.
        if torch.isnan(loss):
            break

        # evaluate
        model.eval()
        with torch.no_grad():
            test_pred_mu, test_pred_xi = model(test_train_x1, test_train_x2)
            test2_pred_mu, test2_pred_xi = model(test_test_x1, test_test_x2)
            test_train_embed = test_pred_mu + test_pred_xi
            test_test_embed = test2_pred_mu + test2_pred_xi
            logistic_model = LogisticRegression(solver="liblinear")
            logistic_model.fit(test_train_embed.cpu().numpy(), test_train_y_01)

            # Fraction of correct predictions (an accuracy, not a loss).
            test_test_pred = logistic_model.predict(test_test_embed.cpu().numpy())
            test_loss = (test_test_pred == test_test_y_01).sum() / test_test_pred.shape[0]

        feature_learning[:, ep] = (torch.matmul(model.W.T, train_x1[0])).detach().numpy()
        noise_memorization[:, :, ep] = (torch.matmul(model.W.T, train_x2.T)).detach().numpy()

        sim_train_loss_values[ep] = loss.item()
        sim_test_loss_values[ep] = test_loss

        print(f'[{ep + 1}|{n_epoch}] train_loss={loss:0.5e}, test_loss={test_loss:0.5e}')

    return feature_learning, noise_memorization, sim_train_loss_values, sim_test_loss_values

















if __name__ == "__main__":

    # Experiment settings. n_train, n_test and d have to stay module-level
    # names: the data-preparation helpers are called without sizes and look
    # them up as globals.
    n_train = 100
    n_test = 2000
    d = 2000
    n_epoch = 200
    tau = 1
    width = 50
    eval_every = 1

    # Seeding is currently disabled:
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)

    repeat = 10

    # Multi-modal run first, single-modal run second, in every repetition.
    methods = (("clip", main_clip), ("sim", main_simclr))
    metrics = ("feature", "noise", "train_loss", "test_loss")

    # One (repeat, n_epoch) table per (method, metric) pair.
    runs = {
        (tag, metric): np.zeros((repeat, n_epoch))
        for tag, _ in methods
        for metric in metrics
    }

    for rep in range(repeat):
        outcomes = {
            tag: run_experiment(n_train, n_test, d, n_epoch, tau, width, eval_every)
            for tag, run_experiment in methods
        }

        for tag, (feature_traj, noise_traj, train_curve, test_curve) in outcomes.items():
            # Largest absolute inner product over neurons (and, for the noise,
            # over training samples as well) at every epoch.
            runs[tag, "feature"][rep, :] = np.abs(feature_traj).max(axis=0)
            runs[tag, "noise"][rep, :] = np.abs(noise_traj).max(axis=0).max(axis=0)
            runs[tag, "train_loss"][rep, :] = train_curve
            runs[tag, "test_loss"][rep, :] = test_curve

    # Mean and standard deviation across repetitions, per epoch.
    mean_of = {key: table.mean(axis=0) for key, table in runs.items()}
    std_of = {key: table.std(axis=0) for key, table in runs.items()}

    import matplotlib.pyplot as plt
    import matplotlib

    matplotlib.rcParams.update({
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
        'font.size': 16,
    })

    # Create a 1x4 grid of subplots: training loss, probe accuracy,
    # signal learning and noise memorization.
    fig, axes = plt.subplots(1, 4, figsize=(24, 4))

    panels = (
        ("train_loss", 'Training Loss'),
        ("test_loss", 'Test Accuracy'),
        ("feature", 'Signal Learning'),
        ("noise", 'Noise Memorization'),
    )
    # Single-modal is drawn before multi-modal on every panel.
    curves = (
        ("sim", 'Single-Modal', 'tab:blue'),
        ("clip", 'Multi-Modal', 'tab:orange'),
    )
    epochs = np.arange(n_epoch)

    for panel_index, (ax, (metric, title)) in enumerate(zip(axes, panels)):
        for tag, label, color in curves:
            center = mean_of[tag, metric]
            spread = std_of[tag, metric]
            ax.plot(center, label=label, color=color)
            ax.fill_between(epochs, center - spread, center + spread,
                            color=color, alpha=0.2)

        ax.set_title(title)
        ax.set_xlabel('Epoch')
        # Only the first panel carries the legend.
        if panel_index == 0:
            ax.legend()

    # plt.tight_layout()
    plt.savefig("clip_final.pdf",bbox_inches='tight')
    # plt.show()



