##########################################################################################
# Machine Environment Config
# NOTE(review): USE_CUDA and CUDA_DEVICE_NUM are not read anywhere in the visible code; the
# device is chosen in the __main__ block from torch.cuda.is_available() alone.
USE_CUDA = True
CUDA_DEVICE_NUM = 0
##########################################################################################
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from torch.utils.data import Dataset
import random
from torch import nn
import torch.nn.functional  as F
import pickle
import random
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.ticker import MultipleLocator
from math import ceil
from torch.optim.lr_scheduler import MultiStepLR
from collections import namedtuple
# Embed fonts as Type 42 (TrueType) in PDF/PS output and use a 16pt default font size.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
matplotlib.rcParams['font.size'] = 16
import math

class Transformer(nn.Module):
    """Single-head softmax attention followed by a ReLU MLP with a fixed output layer.

    The last token of ``x`` is the query. The preceding tokens of ``x`` are the keys, and the
    matching tokens of ``y`` are the values.

    Args:
        d: token dimension.
        m: number of hidden MLP neurons. The fixed output weights ``a`` are ``+1/m`` for the
            first ``m // 2`` neurons and ``-1/m`` for the next ``m // 2``.
            NOTE(review): an odd ``m`` yields only ``m - 1`` output weights and a shape error
            in ``forward``.
        sigma_0: scale of the identity initialisation of ``W_K`` and ``W_Q``.
        sigma_1: standard deviation of the Gaussian initialisation of ``W_O``.
    """

    def __init__(self, d, m, sigma_0, sigma_1):
        super(Transformer, self).__init__()
        self.d = d
        self.m = m
        self.attention = None
        self.W_K = nn.Parameter(torch.eye(d) * sigma_0, requires_grad=True)
        self.W_Q = nn.Parameter(torch.eye(d) * sigma_0, requires_grad=True)
        self.W_O = nn.Parameter(torch.randn(m, d) * sigma_1, requires_grad=True)
        # Fixed (non-trainable) second layer.
        self.a = nn.Parameter(torch.tensor([1/m] * (m//2) + [-1/m] * (m//2)), requires_grad=False)

    def forward(self, x, y, K1, is_correct_sample, is_k_topic):
        """Run the model and measure how much attention lands on "correct" context positions.

        Args:
            x (torch.Tensor): x sequences of the prompts, shape (n, L+1, d); position L is the query.
            y (torch.Tensor): y sequences of the prompts, shape (n, L+1, d); positions 0..L-1 are values.
            K1 (int): number of topics.
            is_correct_sample (torch.Tensor): boolean mask over context positions, expected
                shape (n, L); row ``i`` indexes the attention row of prompt ``i``.
            is_k_topic (torch.Tensor): boolean topic membership, shape (n, K1).

        Returns:
            output (torch.Tensor): shape (n,), the scalar model output per prompt.
            correct_attention_weights (list): ``K1`` entries. Entry ``k`` is the mean, over
                topic-``k`` prompts with at least one correct position, of the total attention
                placed on correct positions (a 0-d numpy array). It is ``0.0`` when no such
                prompt exists.
        """
        key = torch.einsum('nld,md->nlm', x[:, :-1, :], self.W_K)
        query = torch.einsum('njd,md->njm', x[:, -1, :].unsqueeze(1), self.W_Q)
        value = y[:, :-1, :]
        attention = F.softmax(torch.einsum('njd,nld->nl', query, key), dim=-1)
        context = torch.einsum('nl,nld->nd', attention, value)

        hidden = F.relu(torch.einsum('nd,md->nm', context, self.W_O))
        output = torch.einsum('nm,m->n', hidden, self.a)

        has_correct = is_correct_sample.any(dim=1)
        correct_attention_weights = []
        for k in range(K1):
            selected = (is_k_topic[:, k] & has_correct).nonzero(as_tuple=True)[0]
            if selected.numel() == 0:
                correct_attention_weights.append(0.0)
                continue
            rows = attention[selected]
            # torch.where (instead of multiplying by the mask) keeps values at non-correct
            # positions out of the sum entirely, even if they are non-finite.
            correct_mass = torch.where(is_correct_sample[selected], rows, torch.zeros_like(rows)).sum(dim=1)
            correct_attention_weights.append(correct_mass.mean().cpu().detach().numpy())
        return output, correct_attention_weights

def dictionary_M(u, d, K1, K2, theta):
    """Build the input dictionary ``M``, whose rows are the latent feature vectors.

    Rows of the returned ``(2*K1 + K2, d)`` tensor:

    * rows ``2k`` and ``2k + 1`` (``k < K1``) are topic ``k``'s pair. Both lie in the
      coordinate plane ``(2k, 2k + 1)``, have norm ``u`` and have cosine similarity
      ``theta``: ``u * (sqrt((1+theta)/2), +/- sqrt((1-theta)/2))``;
    * rows ``2*K1 + j`` (``j < K2``) are ``u`` times the standard basis vector
      ``e_{2*K1 + j}``, so they are orthogonal to every topic row.

    Args:
        u (float): norm of every row.
        d (int): row dimension; must be at least ``2*K1 + K2`` or indexing raises IndexError.
        K1 (int): number of topics.
        K2 (int): number of topic-irrelevant rows.
        theta (float): cosine similarity within a topic pair, in ``[-1, 1]``.

    Returns:
        torch.Tensor: ``M`` with shape ``(2*K1 + K2, d)``.
    """
    M = torch.zeros(2 * K1 + K2, d)

    for topic in range(K1):
        shared = u * math.sqrt((1 + theta) / 2)
        split = u * math.sqrt((1 - theta) / 2)
        plus_row, minus_row = 2 * topic, 2 * topic + 1
        M[plus_row, plus_row] = shared
        M[minus_row, plus_row] = shared
        M[plus_row, minus_row] = split
        M[minus_row, minus_row] = -split

    # Topic-irrelevant rows sit on their own diagonal coordinates.
    for row in range(2 * K1, 2 * K1 + K2):
        M[row, row] = u

    return M

def dictionary_Q(q, d, K1, K2, theta):
    """Construct the output dictionary matrix ``Q``.

    Rows ``2k`` and ``2k + 1`` (``k < K1``) form topic ``k``'s pair: both have norm ``q`` and
    cosine similarity ``theta``, i.e. ``q * (sqrt((1+theta)/2), +/- sqrt((1-theta)/2))`` in the
    coordinate plane ``(2k, 2k + 1)``. The final ``K2`` rows (the topic-irrelevant slots)
    are all zero.

    Args:
        q (float): norm of each topic row.
        d (int): row dimension; must be at least ``2*K1`` or indexing raises IndexError.
        K1 (int): number of topics.
        K2 (int): number of (zero) topic-irrelevant rows.
        theta (float): cosine similarity within a topic pair, in ``[-1, 1]``.

    Returns:
        torch.Tensor: ``Q`` with shape ``(2*K1 + K2, d)``.
    """
    Q = torch.zeros(2 * K1 + K2, d)

    for k in range(K1):
        shared = q * math.sqrt((1 + theta) / 2)
        split = q * math.sqrt((1 - theta) / 2)
        Q[2 * k, 2 * k] = shared
        Q[2 * k + 1, 2 * k] = shared
        Q[2 * k, 2 * k + 1] = split
        Q[2 * k + 1, 2 * k + 1] = -split

    # Rows 2*K1 onward (all K2 topic-irrelevant rows) intentionally stay at their zero
    # initialisation, so no explicit reset is needed.
    return Q

def validate(M_concept, M_semantic, Q_concept, Q_semantic, K1):
    """Check that each concept/semantic direction is a positive multiple of one basis vector.

    For every topic ``k < K1``, row ``k`` of each of the four matrices must have exactly one
    non-zero coordinate, and that coordinate must be positive. Each offending row is reported
    with a printed line rather than an exception.

    Args:
        M_concept, M_semantic, Q_concept, Q_semantic (torch.Tensor): matrices with at least
            ``K1`` rows.
        K1 (int): number of topics to check.

    Returns:
        bool: True when every checked row is valid, False otherwise. Callers that ignore the
        return value are unaffected.
    """
    named_matrices = (
        ("M_concept", M_concept),
        ("M_semantic", M_semantic),
        ("Q_concept", Q_concept),
        ("Q_semantic", Q_semantic),
    )
    all_valid = True
    for k in range(K1):
        for name, matrix in named_matrices:
            row = matrix[k]
            nonzero = row[row != 0]
            non_zero_dims = nonzero.numel()
            # An all-zero row has no first entry; report value=None instead of indexing an
            # empty tensor.
            value = nonzero[0].item() if non_zero_dims > 0 else None
            if non_zero_dims != 1 or value <= 0:
                print(f"{name}[{k}] is not valid: non_zero_dims={non_zero_dims}, value={value}")
                all_valid = False
    return all_valid

def compute_M_concept_and_semantic(M, Q, K1, K2, d):
    """Split paired dictionary rows into concept/semantic directions and build their complements.

    For each topic ``k`` the paired rows give ``concept = (M[2k] + M[2k+1]) / 2`` and
    ``semantic = (M[2k] - M[2k+1]) / 2``, and likewise for ``Q``.

    Args:
        M (torch.Tensor): input dictionary, shape ``(2*K1 + K2, d)``.
        Q (torch.Tensor): output dictionary, shape ``(2*K1 + K2, d)``.
        K1 (int): number of topics.
        K2 (int): number of topic-irrelevant rows of ``M``.
        d (int): row dimension. The shapes are taken from ``M``/``Q``; the parameter is kept
            for interface compatibility.

    Returns:
        tuple: the following tensors, in order.

        - ``M_concept``, ``M_semantic``, ``Q_concept``, ``Q_semantic``: each of shape ``(K1, d)``.
        - ``Complement_space_M``: rows ``2*K1 + K2`` onward of ``Vh`` from the SVD of the stacked
          concept, semantic and topic-irrelevant rows of ``M``. These are orthonormal and span
          their orthogonal complement, provided the stacked rows are linearly independent.
        - ``Complement_space_Q``: the same construction for the concept and semantic rows of
          ``Q``, shape ``(d - 2*K1, d)``.
        - ``M_irrelevant``: the topic-irrelevant rows ``M[2*K1 : 2*K1 + K2]``, shape ``(K2, d)``.
    """
    M_even, M_odd = M[0:2 * K1:2], M[1:2 * K1:2]
    Q_even, Q_odd = Q[0:2 * K1:2], Q[1:2 * K1:2]
    M_concept = (M_even + M_odd) / 2
    M_semantic = (M_even - M_odd) / 2
    Q_concept = (Q_even + Q_odd) / 2
    Q_semantic = (Q_even - Q_odd) / 2

    # The K2 irrelevant rows follow the 2*K1 topic rows, so the slice end is offset by 2*K1.
    M_irrelevant = M[2 * K1:2 * K1 + K2]

    M_span = torch.cat((M_concept, M_semantic, M_irrelevant), dim=0)
    _, _, Vh = torch.linalg.svd(M_span, full_matrices=True)
    Complement_space_M = Vh[2 * K1 + K2:]

    Q_span = torch.cat((Q_concept, Q_semantic), dim=0)
    _, _, Vh = torch.linalg.svd(Q_span, full_matrices=True)
    Complement_space_Q = Vh[2 * K1:]

    return M_concept, M_semantic, Q_concept, Q_semantic, Complement_space_M, Complement_space_Q, M_irrelevant

def _quadratic_diag(W, V):
    """Return ``v_k^T W v_k`` for every row ``v_k`` of ``V``; shape ``(V.size(0),)``."""
    return torch.einsum('dk,kd->k', torch.einsum('md,kd->mk', W, V), V)


def _signed_max_abs(P):
    """Return, for each column of ``P``, the entry of largest magnitude with its sign kept."""
    idx = torch.argmax(P.abs(), dim=0)
    return torch.gather(P, 0, idx.unsqueeze(0)).squeeze(0)


def _max_bilinear(W_Q, W_K, queries, keys):
    """Return ``max_j keys[j]^T W_K W_Q q`` for each query row ``q``; shape ``(queries.size(0),)``.

    NOTE(review): the forward pass scores a key with ``key^T W_K^T W_Q q``. This diagnostic
    uses ``W_K`` untransposed, so the two agree only while ``W_K`` is symmetric. Confirm this
    is intended.
    """
    projected = torch.einsum('mk,dm->kd', torch.einsum('md,kd->mk', W_Q, queries), W_K)
    scores = torch.einsum('kd,md->km', projected, keys)
    return scores.max(dim=-1).values


def compute_W_Q_K_concept_semantic(model, M_concept, M_semantic, Q_concept, Q_semantic, Complement_space_M, Complement_space_Q, M_irrelevant, x_train_noise, y_train_noise, device):
    """Project the current weights onto the concept/semantic/noise directions (diagnostics only).

    Computed under ``torch.no_grad()``, so the outputs carry no autograd graph. The token
    dimension is taken from ``x_train_noise.size(-1)``.

    Returns:
        tuple of 14 tensors, in order:

        - ``W_Q_concept``, ``W_K_concept``: ``c_k^T W c_k`` for each row of ``M_concept``.
        - ``W_Q_semantic``, ``W_K_semantic``: the same for each row of ``M_semantic``.
        - ``W_Q_noise_memorize``, ``W_K_noise_memorize``: the same for each row of
          ``Complement_space_M``.
        - ``W_O_concept``, ``W_O_semantic``: for each row ``v`` of ``Q_concept`` /
          ``Q_semantic``, the signed entry of ``W_O v`` with the largest magnitude over
          hidden neurons.
        - ``W_O_noise_memorize_max``, ``W_O_noise_memorize_mean``: max / mean of ``|W_O c|``
          over neurons, for each row ``c`` of ``Complement_space_Q``.
        - ``W_O_noise_vec_memo_max``, ``W_O_noise_vec_memo_mean``: the same for every
          flattened token of ``y_train_noise``.
        - ``W_noise_vec_memo``: for each prompt's last noise token, the max bilinear score
          (see ``_max_bilinear``) over all context noise tokens of all prompts.
        - ``W_irrelevant_product``: for each row of ``M_irrelevant``, the max bilinear score
          over all rows of ``M_irrelevant``.
    """
    with torch.no_grad():
        W_Q = model.W_Q.to(device)
        W_K = model.W_K.to(device)
        W_O = model.W_O.to(device)
        M_concept = M_concept.to(device)
        M_semantic = M_semantic.to(device)
        Complement_space_M = Complement_space_M.to(device)
        M_irrelevant = M_irrelevant.to(device)
        x_noise = x_train_noise.to(device)
        y_noise = y_train_noise.to(device)
        # Token dimension comes from the data rather than a module-level global.
        d = x_noise.size(-1)

        W_Q_concept = _quadratic_diag(W_Q, M_concept)
        W_K_concept = _quadratic_diag(W_K, M_concept)
        W_Q_semantic = _quadratic_diag(W_Q, M_semantic)
        W_K_semantic = _quadratic_diag(W_K, M_semantic)
        W_Q_noise_memorize = _quadratic_diag(W_Q, Complement_space_M)
        W_K_noise_memorize = _quadratic_diag(W_K, Complement_space_M)

        W_O_concept = _signed_max_abs(torch.einsum('md,kd->mk', W_O, Q_concept.to(device)))
        W_O_semantic = _signed_max_abs(torch.einsum('md,kd->mk', W_O, Q_semantic.to(device)))
        W_O_noise_memorize = torch.einsum('md,kd->mk', W_O, Complement_space_Q.to(device))
        W_O_noise_memorize_max = _signed_max_abs(W_O_noise_memorize).abs()
        W_O_noise_memorize_mean = W_O_noise_memorize.abs().mean(dim=0)

        W_noise_vec_memo = _max_bilinear(W_Q, W_K, x_noise[:, -1, :].reshape(-1, d), x_noise[:, :-1, :].reshape(-1, d))
        W_irrelevant_product = _max_bilinear(W_Q, W_K, M_irrelevant, M_irrelevant.reshape(-1, d))

        W_O_noise_vec_memo = torch.einsum('md,kd->mk', W_O, y_noise.reshape(-1, d))
        W_O_noise_vec_memo_max = _signed_max_abs(W_O_noise_vec_memo).abs()
        W_O_noise_vec_memo_mean = W_O_noise_vec_memo.abs().mean(dim=0)

    return W_Q_concept, W_K_concept, W_Q_semantic, W_K_semantic, W_Q_noise_memorize, W_K_noise_memorize, W_O_concept, W_O_semantic, W_O_noise_memorize_max, W_O_noise_memorize_mean, W_O_noise_vec_memo_max, W_O_noise_vec_memo_mean, W_noise_vec_memo, W_irrelevant_product

class Tracker:
    """Per-epoch history of gradient norms and weight-projection diagnostics.

    Each ``*_list`` attribute gains one entry per ``update_grad_norms`` / ``update_values``
    call. Tensor inputs are stored as detached CPU numpy arrays.
    """

    def __init__(self):
        self.W_Q_concept_list = []
        self.W_K_concept_list = []
        self.W_Q_semantic_list = []
        self.W_K_semantic_list = []
        self.W_O_concept_list = []
        self.W_O_semantic_list = []
        self.W_K_grad_norm_list = []
        self.W_Q_grad_norm_list = []
        self.W_O_grad_norm_list = []
        self.W_Q_noise_memorize_list = []
        self.W_K_noise_memorize_list = []
        self.W_O_noise_memorize_max_list = []
        self.W_O_noise_memorize_mean_list = []
        self.W_noise_vec_memo_list = []
        self.W_O_noise_vec_memo_max_list = []
        self.W_O_noise_vec_memo_mean_list = []
        self.W_irrelevant_product_list = []

    @staticmethod
    def _to_numpy(tensor):
        """Return ``tensor`` detached from autograd as a CPU numpy array."""
        return tensor.cpu().detach().numpy()

    def update_grad_norms(self, W_K_grad, W_Q_grad, W_O_grad):
        """Append one epoch's gradient norms (stored as given, typically Python floats)."""
        self.W_K_grad_norm_list.append(W_K_grad)
        self.W_Q_grad_norm_list.append(W_Q_grad)
        self.W_O_grad_norm_list.append(W_O_grad)

    def update_values(self, W_Q_concept, W_K_concept, W_Q_semantic, W_K_semantic,
                      W_Q_noise_memorize, W_K_noise_memorize, W_O_concept, W_O_semantic,
                      W_O_noise_memorize_max, W_O_noise_memorize_mean, W_noise_vec_memo, W_O_noise_vec_memo_max, W_O_noise_vec_memo_mean, W_irrelevant_product):
        """Append one epoch's projection diagnostics, converting each tensor to numpy."""
        self.W_Q_concept_list.append(self._to_numpy(W_Q_concept))
        self.W_K_concept_list.append(self._to_numpy(W_K_concept))
        self.W_Q_semantic_list.append(self._to_numpy(W_Q_semantic))
        self.W_K_semantic_list.append(self._to_numpy(W_K_semantic))
        self.W_Q_noise_memorize_list.append(self._to_numpy(W_Q_noise_memorize))
        self.W_K_noise_memorize_list.append(self._to_numpy(W_K_noise_memorize))

        self.W_O_concept_list.append(self._to_numpy(W_O_concept))
        self.W_O_semantic_list.append(self._to_numpy(W_O_semantic))
        self.W_O_noise_memorize_max_list.append(self._to_numpy(W_O_noise_memorize_max))
        self.W_O_noise_memorize_mean_list.append(self._to_numpy(W_O_noise_memorize_mean))
        self.W_noise_vec_memo_list.append(self._to_numpy(W_noise_vec_memo))
        self.W_O_noise_vec_memo_max_list.append(self._to_numpy(W_O_noise_vec_memo_max))
        self.W_O_noise_vec_memo_mean_list.append(self._to_numpy(W_O_noise_vec_memo_mean))
        self.W_irrelevant_product_list.append(self._to_numpy(W_irrelevant_product))

def _format_progress_axis(ax, title, training_epoch):
    """Apply the legend, title, epoch x-axis and tick styling shared by panels 2-4."""
    ax.legend(fontsize=15)
    ax.set_title(title, fontsize=20)
    ax.set_xlabel('Epoch', fontsize=18)
    ax.set_xticks(range(0, training_epoch, 10))
    ax.tick_params(axis='both', which='major', labelsize=18)


def plot_learning_curves(train_losses, train_errors, test_losses, tracker, training_epoch, train_attention_weights, test_attention_weights, batch_size, weight_decay, gamma):
    """Save a 1x4 summary figure for one training run.

    Panels:

    1. per-epoch training loss and test 0-1 loss;
    2. per-topic attention mass on correct positions, for train and test;
    3. attention-weight projections recorded in ``tracker``;
    4. MLP-weight projections recorded in ``tracker``.

    The figure is written to ``figures_SGD-bs-{batch_size}-wd-{weight_decay}-gm-{gamma}.png``
    and then closed. Every per-epoch sequence is expected to have ``training_epoch`` entries.
    ``train_errors`` is accepted for interface compatibility but is not plotted.
    """
    epochs = range(training_epoch)
    num_concepts = np.array(tracker.W_Q_concept_list).shape[1]

    fig, axes = plt.subplots(1, 4, figsize=(24, 6))
    try:
        loss_ax, attn_ax, qk_ax, mlp_ax = axes

        loss_ax.plot(train_losses, label='Train Loss per batch', linewidth=4)
        loss_ax.plot(test_losses, label='Test 0-1 Loss', linewidth=4)
        loss_ax.set_xlabel('Epoch', fontsize=20)
        loss_ax.set_ylabel('Loss', fontsize=20)
        loss_ax.set_title('Training and Test Loss', fontsize=25)
        loss_ax.legend(fontsize=20)
        loss_ax.tick_params(axis='both', which='major', labelsize=20)
        loss_ax.grid()

        attn_ax.plot(epochs, train_attention_weights, label=[r'train attn weights ({})'.format(i) for i in range(num_concepts)], linewidth=4)
        attn_ax.plot(epochs, test_attention_weights, label=[r'test attn weights ({})'.format(i) for i in range(num_concepts)], linewidth=4)
        _format_progress_axis(attn_ax, 'Evolution of Attention Weights', training_epoch)

        qk_ax.plot(epochs, np.multiply(np.array(tracker.W_Q_concept_list), np.array(tracker.W_K_concept_list)), label=[r'$\alpha_Q*\alpha_K$ ({})'.format(i) for i in range(num_concepts)], linewidth=4)
        qk_ax.plot(epochs, np.multiply(np.array(tracker.W_Q_semantic_list), np.array(tracker.W_K_semantic_list)), label=[r'$\beta_Q*\beta_K$ ({})'.format(i) for i in range(num_concepts)], linewidth=4)
        qk_ax.plot(epochs, np.multiply(np.array(tracker.W_Q_noise_memorize_list), np.array(tracker.W_K_noise_memorize_list)).max(axis=-1), label='Complement Product Max', linewidth=4)
        qk_ax.plot(epochs, np.array(tracker.W_noise_vec_memo_list).max(axis=-1), label='Noise Product Max', linewidth=4)
        _format_progress_axis(qk_ax, 'Learning Progress of Attention', training_epoch)

        mlp_ax.plot(epochs, np.array(tracker.W_O_concept_list), label=[r'$\alpha_O^y$ ({})'.format(i) for i in range(num_concepts)], linewidth=4)
        mlp_ax.plot(epochs, np.abs(np.array(tracker.W_O_semantic_list)), label=[r'|$\beta_O^y$| ({})'.format(i) for i in range(num_concepts)], linewidth=4)
        mlp_ax.plot(epochs, np.array(tracker.W_O_noise_memorize_max_list).mean(axis=-1), label='Complement Coeff Max', linewidth=4)
        mlp_ax.plot(epochs, np.array(tracker.W_O_noise_vec_memo_mean_list).mean(axis=-1), label='Noise Memo Mean', linewidth=4)
        _format_progress_axis(mlp_ax, 'Learning Progress of MLP', training_epoch)

        fig.tight_layout()
        fig.savefig(f'figures_SGD-bs-{batch_size}-wd-{weight_decay}-gm-{gamma}.png', dpi=300)
    finally:
        # Close exactly this figure, even on error. A bare plt.close() targets whichever
        # figure is current, which may belong to the caller.
        plt.close(fig)

def _write_topic_token(z_row, k, first_of_pair, K1, p_other):
    """Overwrite one token's latent vector with topic ``k`` plus randomly drawn other topics.

    ``z_row`` must be a view into the latent tensor, so the writes land in place. The vector
    is zeroed. Topic ``k`` is then set to ``(1, 0)`` on columns ``(2k, 2k+1)`` when
    ``first_of_pair`` is true, and to ``(0, 1)`` otherwise. Each other topic ``k'`` (in
    ascending order) is switched on with probability ``p_other``, choosing ``(1, 0)`` or
    ``(0, 1)`` with equal chance. Random draws happen in a fixed order, so a fixed seed
    reproduces the same latents.
    """
    z_row.zero_()
    if first_of_pair:
        z_row[2 * k] = 1
    else:
        z_row[2 * k + 1] = 1
    for k_prime in range(K1):
        if k_prime == k:
            continue
        if torch.rand(1) < p_other:
            if torch.rand(1) < 1 / 2:
                z_row[2 * k_prime] = 1
                z_row[2 * k_prime + 1] = 0
            else:
                z_row[2 * k_prime] = 0
                z_row[2 * k_prime + 1] = 1


def concept_specific_prompt(M, Q, K1, K2, d, L, noise_sigma, n):
    """Generate ``n`` prompts, grouped into ``n // K1`` consecutive prompts per topic.

    Prompts ``[k * n_per_topic, (k + 1) * n_per_topic)`` belong to topic ``k``. The first
    ``n_per_topic // 2`` of them have a query using row ``2k`` (label ``+1``); the rest use
    row ``2k+1`` (label ``-1``). In every topic-``k`` prompt, context positions ``< L // 2``
    use row ``2k`` and positions ``L // 2 .. L-1`` use row ``2k+1``. A context position is
    "correct" when it uses the same row as the query. In any token, each other topic appears
    with probability ``2 / (2*K1 + K2)``, and each topic-irrelevant feature independently with
    probability ``1 / (2*K1 + K2)``. When ``n`` is not a multiple of ``K1``, the trailing
    ``n % K1`` prompts get no topic and label ``0``.

    Args:
        M (torch.Tensor): input dictionary, shape ``(2*K1 + K2, d)``.
        Q (torch.Tensor): output dictionary, shape ``(2*K1 + K2, d)``.
        K1 (int): number of topics.
        K2 (int): number of topic-irrelevant features.
        d (int): token dimension.
        L (int): number of context (x, y) pairs; each prompt has ``L + 1`` tokens, the last
            being the query.
        noise_sigma (float): standard deviation of the Gaussian noise added to x and y.
        n (int): number of prompts.

    Returns:
        tuple: the following, in order.

        - ``x``, ``y``: shape ``(n, L+1, d)``, equal to ``z @ M`` / ``z @ Q`` plus noise.
        - ``z``: sparse latents, shape ``(n, L+1, 2*K1 + K2)``.
        - ``labels``: float tensor, shape ``(n,)``.
        - ``x_noise``, ``y_noise``: the added noise, shape ``(n, L+1, d)``.
        - ``is_correct_sample``: bool, shape ``(n, L)``.
        - ``is_k_topic``: bool, shape ``(n, K1)``.
    """
    n_per_topic = n // K1
    n_k_half = n_per_topic // 2
    p_other = 2 / (2 * K1 + K2)

    is_k_topic = torch.zeros(n, K1, dtype=torch.bool)
    is_correct_sample = torch.zeros(n, L, dtype=torch.bool)
    z = torch.zeros(n, L + 1, 2 * K1 + K2)
    labels = torch.zeros(n)

    for k in range(K1):
        start = k * n_per_topic
        mid = start + n_k_half
        stop = start + n_per_topic

        # Query tokens (position L).
        for n_k in range(start, mid):
            _write_topic_token(z[n_k, L], k, True, K1, p_other)
            labels[n_k] = 1
        for n_k in range(mid, stop):
            _write_topic_token(z[n_k, L], k, False, K1, p_other)
            labels[n_k] = -1

        # Context tokens: the first half share row 2k, the second half row 2k+1.
        for n_k in range(start, stop):
            is_k_topic[n_k, k] = True
            positive_query = n_k < mid
            for l_n_k in range(L // 2):
                _write_topic_token(z[n_k, l_n_k], k, True, K1, p_other)
                is_correct_sample[n_k, l_n_k] = positive_query
            for l_n_k in range(L // 2, L):
                _write_topic_token(z[n_k, l_n_k], k, False, K1, p_other)
                is_correct_sample[n_k, l_n_k] = not positive_query

    # Topic-irrelevant features are drawn independently for every token.
    z[:, :, 2 * K1:] = (torch.rand(n, L + 1, K2) < 1 / (2 * K1 + K2)).float()

    x_noise = torch.randn(n, L + 1, d) * noise_sigma
    x = torch.einsum('kd,nlk->nld', M, z) + x_noise
    y_noise = torch.randn(n, L + 1, d) * noise_sigma
    y = torch.einsum('kd,nlk->nld', Q, z) + y_noise

    return x, y, z, labels, x_noise, y_noise, is_correct_sample, is_k_topic

# Lightweight record of one prompt's tensors; the field order matches the items returned by
# ConceptSpecificPromptDataset.__getitem__.
Prompt = namedtuple('Prompt', 'x y labels is_correct_sample is_k_topic')

class ConceptSpecificPromptDataset(Dataset):
    """Map-style dataset over pre-generated prompts.

    Item ``i`` is ``(x[i], y[i], labels[i], is_correct_sample[i], is_k_topic[i])``.

    Raises:
        ValueError: if the input tensors do not share the same first dimension.
    """

    def __init__(self, x, y, labels, is_correct_sample, is_k_topic):
        sample_counts = {t.size(0) for t in (x, y, labels, is_correct_sample, is_k_topic)}
        # An explicit check, unlike ``assert``, still runs under ``python -O``.
        if len(sample_counts) != 1:
            raise ValueError("All input tensors must have the same first dimension (number of samples)")
        self.x = x
        self.y = y
        self.labels = labels
        self.is_correct_sample = is_correct_sample
        self.is_k_topic = is_k_topic

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, idx):
        # Negative indices fall through to normal tensor indexing.
        if idx >= len(self.x):
            raise IndexError(f"Index {idx} is out of range for dataset of length {len(self.x)}")
        return self.x[idx], self.y[idx], self.labels[idx], self.is_correct_sample[idx], self.is_k_topic[idx]

def adjust_learning_rate(optimizer, epoch, initial_lr):
    """Set the step size of every parameter group for ``epoch`` and return it.

    For ``epoch < T1`` the step size is ``initial_lr``. From ``T1`` onward it is
    ``2 / (weight_decay * (epoch + 1 + gamma))``. ``T1``, ``weight_decay`` and ``gamma`` are
    read from module-level globals assigned in the ``__main__`` section; the decay-branch
    globals are only read once ``epoch >= T1``.
    """
    new_lr = initial_lr if epoch < T1 else 2/ (weight_decay * (epoch + 1 + gamma))
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr
    
def _evaluate(model, loader, K1, device):
    """Evaluate ``model`` over every batch of ``loader`` without tracking gradients.

    Args:
        model: a ``Transformer``; its forward returns ``(output, correct_attention_weights)``.
        loader: yields ``(x, y, labels, is_correct_sample, is_k_topic)`` batches.
        K1 (int): number of topics.
        device: device the batches are moved to.

    Returns:
        tuple: ``(mean_error, mean_weights)``. ``mean_error`` is the batch-averaged fraction of
        prompts with ``output * label < 0``; ``mean_weights`` is a list of ``K1``
        batch-averaged correct-attention weights.
    """
    total_error = 0.0
    total_weights = [0.0] * K1
    with torch.no_grad():
        for x, y, labels, is_correct, is_topic in loader:
            x, y, labels = x.to(device), y.to(device), labels.to(device)
            is_correct, is_topic = is_correct.to(device), is_topic.to(device)
            output, batch_weights = model(x, y, K1=K1, is_correct_sample=is_correct, is_k_topic=is_topic)
            total_error += ((output * labels) < 0).float().mean().item()
            if len(total_weights) == len(batch_weights):
                total_weights = [a + b for a, b in zip(total_weights, batch_weights)]
            else:
                print("Error: Lists do not have the same length.")
    num_batches = len(loader)
    return total_error / num_batches, [w / num_batches for w in total_weights]


def main():
    """Train and evaluate one model, printing per-epoch diagnostics and saving a summary plot.

    Settings (``u``, ``q``, ``d``, ``m``, ``K1``, ``K2``, ``theta``, ``L``, ``noise_sigma``,
    ``n_train``, ``n_test``, ``batch_size``, ``lr``, ``weight_decay``, ``gamma``, ``sigma_0``,
    ``sigma_1``, ``training_epoch``, ``device``) are read from module globals assigned in the
    ``__main__`` section.

    Each epoch takes a single SGD step on the next ``batch_size`` prompts of the shuffled
    training set. It then evaluates the 0-1 error and correct-attention weights on the full
    train and test sets. The per-epoch log line indexes topics 0 and 1, so it assumes
    ``K1 >= 2``.
    """
    M = dictionary_M(u, d, K1, K2, theta)
    Q = dictionary_Q(q, d, K1, K2, theta)

    (x_train, y_train, _, labels_train, x_train_noise, y_train_noise,
     is_correct_sample_train, is_k_topic_train) = concept_specific_prompt(M, Q, K1, K2, d, L, noise_sigma, n_train)
    # Prompts are generated grouped by topic; shuffle once so each per-epoch slice is mixed.
    indices = torch.randperm(n_train)
    x_train = x_train[indices]
    y_train = y_train[indices]
    labels_train = labels_train[indices]
    is_correct_sample_train = is_correct_sample_train[indices]
    is_k_topic_train = is_k_topic_train[indices]
    (x_test, y_test, _, labels_test, x_test_noise, y_test_noise,
     is_correct_sample_test, is_k_topic_test) = concept_specific_prompt(M, Q, K1, K2, d, L, noise_sigma, n_test)

    train_dataset = ConceptSpecificPromptDataset(x_train, y_train, labels_train, is_correct_sample_train, is_k_topic_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = ConceptSpecificPromptDataset(x_test, y_test, labels_test, is_correct_sample_test, is_k_topic_test)
    test_loader = DataLoader(test_dataset, batch_size=100000, shuffle=False)

    model = Transformer(d=d, m=m, sigma_0=sigma_0, sigma_1=sigma_1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    tracker = Tracker()

    (M_concept, M_semantic, Q_concept, Q_semantic,
     Complement_space_M, Complement_space_Q, M_irrelevant) = compute_M_concept_and_semantic(M, Q, K1, K2, d)
    validate(M_concept, M_semantic, Q_concept, Q_semantic, K1)

    def diagnostics():
        # NOTE(review): the noise-memorisation diagnostics are fed the *test* noise, although
        # the callee names its arguments x_train_noise / y_train_noise. Confirm this is intended.
        return compute_W_Q_K_concept_semantic(model, M_concept, M_semantic, Q_concept, Q_semantic, Complement_space_M, Complement_space_Q, M_irrelevant, x_test_noise, y_test_noise, device)

    (W_Q_concept, W_K_concept, W_Q_semantic, W_K_semantic, W_Q_noise_memorize, W_K_noise_memorize,
     W_O_concept, W_O_semantic, W_O_noise_memorize_max, W_O_noise_memorize_mean,
     W_O_noise_vec_memo_max, W_O_noise_vec_memo_mean, W_noise_vec_memo, W_irrelevant_product) = diagnostics()

    print(f"Initial beta_Q * beta_K 0: {(W_Q_semantic * W_K_semantic)[0].item():.4f}, beta_Q * beta_K 1: {(W_Q_semantic * W_K_semantic)[1].item():.4f}")

    train_losses = []
    train_errors = []
    test_losses = []
    train_attention_weights = []
    test_attention_weights = []
    current_lr = lr

    for epoch in range(training_epoch):
        current_lr = adjust_learning_rate(optimizer, epoch, current_lr)

        # One SGD step on the next slice of the shuffled training set. The batch flags use
        # their own names so the full-dataset tensors are never overwritten.
        model.train()
        start_index = epoch * batch_size
        end_index = min((epoch + 1) * batch_size, n_train)
        batch = slice(start_index, end_index)
        x = x_train[batch].to(device)
        y = y_train[batch].to(device)
        labels = labels_train[batch].to(device)
        batch_is_correct = is_correct_sample_train[batch].to(device)
        batch_is_topic = is_k_topic_train[batch].to(device)

        optimizer.zero_grad()
        output, _ = model(x, y, K1=K1, is_correct_sample=batch_is_correct, is_k_topic=batch_is_topic)
        # Logistic loss; softplus(z) == log(1 + exp(z)) without overflowing for large z.
        loss = F.softplus(-labels * output).mean()
        loss = loss + 0.5 * weight_decay * (model.W_O.norm() ** 2 + model.W_Q.norm() ** 2 + model.W_K.norm() ** 2)
        loss.backward()
        W_K_grad_norm = model.W_K.grad.norm().item()
        W_Q_grad_norm = model.W_Q.grad.norm().item()
        W_O_grad_norm = model.W_O.grad.norm().item()
        optimizer.step()

        train_loss = loss.item()
        train_losses.append(train_loss)
        tracker.update_grad_norms(W_K_grad_norm, W_Q_grad_norm, W_O_grad_norm)

        model.eval()
        train_error, train_correct_weight = _evaluate(model, train_loader, K1, device)
        train_errors.append(train_error)
        train_attention_weights.append(train_correct_weight)

        test_loss, test_correct_weight = _evaluate(model, test_loader, K1, device)
        test_attention_weights.append(test_correct_weight)
        test_losses.append(test_loss)

        (W_Q_concept, W_K_concept, W_Q_semantic, W_K_semantic, W_Q_noise_memorize, W_K_noise_memorize,
         W_O_concept, W_O_semantic, W_O_noise_memorize_max, W_O_noise_memorize_mean,
         W_O_noise_vec_memo_max, W_O_noise_vec_memo_mean, W_noise_vec_memo, W_irrelevant_product) = diagnostics()

        # Report the loader-averaged weights (the values that are plotted), not one batch's.
        print(
            f"Epoch [{epoch+1}/{training_epoch}], Train Loss: {train_loss:.4f}, Test 0-1 Loss: {test_loss:.4f}, "
            f"train_correct_weight 0:{train_correct_weight[0]:.4f}, test_correct_weight 0: {test_correct_weight[0]:.4f}, "
            f"train_correct_weight 1:{train_correct_weight[1]:.4f}, test_correct_weight 1: {test_correct_weight[1]:.4f}, "
            f"W_irrelevant_product 0: {W_irrelevant_product[0]:.4f}, W_irrelevant_product 1: {W_irrelevant_product[1]:.4f}, "
            f"W_K_grad_norm: {W_K_grad_norm:.4f}, W_Q_grad_norm: {W_Q_grad_norm:.4f}, W_O_grad_norm: {W_O_grad_norm:.4f}, "
            f"beta_Q * beta_K 0: {(W_Q_semantic * W_K_semantic)[0].item():.4f}, beta_Q * beta_K 1: {(W_Q_semantic * W_K_semantic)[1].item():.4f}, "
            f"W_O_concept 0 :{W_O_concept[0].item():.4f}, W_O_semantic 0:{W_O_semantic[0].item():.4f}, "
            f"W_O_concept 1 :{W_O_concept[1].item():.4f}, W_O_semantic 1:{W_O_semantic[1].item():.4f}, "
            f"W_Q_noise_memorize:{W_Q_noise_memorize.abs().mean().item():.4f}, W_K_noise_memorize:{W_K_noise_memorize.abs().mean().item():.4f}, "
            f"W_O_noise_memorize_max:{W_O_noise_memorize_max.mean().item():.4f}"
        )

        tracker.update_values(W_Q_concept, W_K_concept, W_Q_semantic, W_K_semantic,
                              W_Q_noise_memorize, W_K_noise_memorize, W_O_concept, W_O_semantic,
                              W_O_noise_memorize_max, W_O_noise_memorize_mean, W_noise_vec_memo, W_O_noise_vec_memo_max, W_O_noise_vec_memo_mean, W_irrelevant_product)

    plot_learning_curves(train_losses, train_errors, test_losses, tracker, training_epoch, train_attention_weights, test_attention_weights, batch_size, weight_decay, gamma)

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Experiment configuration. main() and adjust_learning_rate() read these as module globals.

    # Data / model parameters
    n_test = 5000
    d = 1000
    m = 50
    theta = 0.5
    K1 = 2
    K2 = 100
    u = 10
    q = 10
    sigma_0 = 0.1
    sigma_1 = 0.01
    noise_sigma = 0.01
    L = 4  # Number of context (x, y) pairs; each prompt has L + 1 tokens including the query.

    # Optimisation parameters
    weight_decay = 0.002
    batch_size = 16
    lr = 0.1
    T1 = 0  # Epochs at the constant initial lr before the 2 / (wd * (epoch + 1 + gamma)) schedule.
    gamma = 10000
    training_epoch = 100
    # main() consumes one batch per epoch, so this gives exactly one pass over the training set.
    n_train = training_epoch * batch_size

    # main() builds its own model and optimizer, so none are created here.
    main()