import torch
import torch.nn as nn


import torch
import torch.nn as nn
import torch.nn.functional as F

class PromptEncoder(nn.Module):
    """Single-head cross-attention from node features to prompt embeddings.

    Queries are projected from ``Z1`` (last dimension ``d``); keys and values
    are projected from ``Z2`` (last dimension ``h * k``).  When ``args.LR`` is
    truthy, keys and values are additionally compressed along their length-``N``
    axis down to ``n`` rows by a learned ``(n, N)`` projection before attention
    is computed.  The attended result is mapped back to width ``d`` by a small
    feed-forward network.

    Args:
        args: Object exposing a boolean-like ``LR`` attribute that enables the
            low-rank projection of keys and values.
        d: Feature width of ``Z1`` and of the returned tensor.
        h: First factor of the prompt width (``Z2`` has width ``h * k``).
        k: Second factor of the prompt width.
        att_d: Width of the query/key/value projections.
        N: Number of rows in ``Z2`` that the low-rank projection consumes.
        n: Number of rows kept after the low-rank projection.
        dropout: Dropout probability used inside the feed-forward network;
            one tenth of it is applied to the attention output.
    """

    def __init__(self, args, d, h, k, att_d, N, n, dropout=0.5):
        super().__init__()
        self.args = args
        self.d = d
        self.h = h
        self.k = k
        self.att_d = att_d
        self.n = n
        self.dropout = dropout

        self.q_linear = nn.Linear(d, att_d)
        self.k_linear = nn.Linear(h * k, att_d)
        self.v_linear = nn.Linear(h * k, att_d)

        # Low-rank projection, initialised to the constant 1 / (n * N).
        self.E = nn.Parameter(torch.ones(n, N) / n / N)
        # The value projection is tied to the key projection: ``F`` is the
        # very same Parameter object as ``E`` (both names stay in the
        # state_dict).  No separate (n, N) tensor is allocated for it.
        self.F = self.E

        d_ff = att_d // 2
        self.feed_forward = nn.Sequential(
            nn.Linear(att_d, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d),
        )

        # Not used by forward(); kept so the module's parameter set and
        # state_dict keys stay compatible with existing checkpoints.
        self.layer_norm1 = nn.LayerNorm(att_d)
        self.layer_norm2 = nn.LayerNorm(d)

    def forward(self, Z1, Z2):
        """Attend from ``Z1`` to ``Z2`` and return features of width ``d``.

        Args:
            Z1: Tensor whose last dimension is ``d``, e.g. ``(N, d)``.
            Z2: Tensor whose last dimension is ``h * k``, e.g. ``(N, h * k)``.
                With ``args.LR`` enabled its second-to-last dimension must be
                ``N`` so that it can be multiplied by the ``(n, N)`` projection.

        Returns:
            Tensor with the leading dimensions of ``Z1`` and last dimension
            ``d``.
        """
        query = self.q_linear(Z1)   # (N, att_d)
        key = self.k_linear(Z2)     # (N, att_d)
        value = self.v_linear(Z2)   # (N, att_d)

        if self.args.LR:
            # Compress keys/values from N rows to n rows.
            key = torch.matmul(self.E, key)      # (n, att_d)
            value = torch.matmul(self.F, value)  # (n, att_d)

        # Scaled dot-product attention.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.att_d ** 0.5)
        probs = F.softmax(scores, dim=-1)  # (N, n) with LR, else (N, N)

        attended = torch.matmul(probs, value)  # (N, att_d)
        # F.dropout defaults to training=True; tie it to the module's mode so
        # that model.eval() really disables it and inference is deterministic.
        attended = F.dropout(
            attended, p=self.dropout / 10, training=self.training
        )

        return self.feed_forward(attended)  # (N, d)

if __name__ == '__main__':
    # Demo: run PromptEncoder once on random data, then visualise the
    # low-rank projection matrices and the singular-value spectrum of E.
    import argparse

    import matplotlib.pyplot as plt
    import numpy as np

    def _str2bool(text):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string, including
        ``"False"``, into ``True``; this parser accepts the usual spellings
        and rejects anything else.
        """
        lowered = text.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {text!r}"
        )

    # PARAMETER
    N = 69499   # number of nodes / prompt rows
    H = 32      # prompt embedding is (N, H, K) before flattening
    K = 2
    D = 128     # node feature width
    # NOTE(review): the original script called this ``num_heads`` but passed
    # it positionally as ``att_d`` (attention width); the value is unchanged.
    att_d = 4

    # Fall back to CPU so the demo also runs on hosts without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # INPUT EXAMPLE
    node_features = torch.randn(N, D, device=device)
    prompt_embedding = torch.randn(N, H, K, device=device)
    prompt_embedding = prompt_embedding.view(N, H * K)

    parser = argparse.ArgumentParser(description="Description of your program")
    parser.add_argument("--LR", type=_str2bool, default=True)
    args = parser.parse_args()

    model = PromptEncoder(args, D, H, K, att_d, N, N // 100).to(device)

    # Only the output shape is inspected, so no autograd graph is needed.
    with torch.no_grad():
        output = model(node_features, prompt_embedding)

    enhanced_features = node_features + output

    print(enhanced_features.shape)  # (69499, 128)

    # Use local names that do not shadow the module-level ``F``
    # (torch.nn.functional), which PromptEncoder.forward relies on.
    projections = (
        ("E", model.E.detach().cpu().numpy()),
        ("F", model.F.detach().cpu().numpy()),
    )

    plt.figure(figsize=(10, 4))
    for position, (label, matrix) in enumerate(projections, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(matrix, cmap='viridis')
        plt.title(f'Low-rank Matrix {label}')
        plt.colorbar()

    plt.tight_layout()
    plt.show()

    # Only the singular values are needed; torch.svd is deprecated in favour
    # of the torch.linalg routines.
    singular_values = torch.linalg.svdvals(model.E.detach()).cpu().numpy()

    plt.figure(figsize=(8, 4))
    plt.plot(singular_values, marker='o', linestyle='-', linewidth=1)
    plt.title('Singular Values of Low-rank Matrix E')
    plt.xlabel('Singular Value Index')
    plt.ylabel('Singular Value')
    plt.grid(True)
    plt.show()

    print(singular_values[:10])
    cumulative_variance = np.cumsum(singular_values) / np.sum(singular_values)
    plt.figure(figsize=(8, 4))
    plt.plot(cumulative_variance, marker='o', linestyle='-', linewidth=1)
    plt.title('Cumulative Proportion of Singular Values')
    plt.xlabel('Singular Value Index')
    plt.ylabel('Cumulative Proportion')
    plt.grid(True)
    plt.show()
