import os
from matplotlib import pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import torch
import torch.nn as nn
import torch.nn.functional as F


# Directory that receives config.txt and the weight-matrix plots.
OUTPUT_FOLDER = "experiments/"

# --- Data shape ---
SEQUENCES_PER_BATCH = 512  # sequences generated per batch
NUM_PAIRS = 8  # pairs per sequence (each pair contributes two tokens)
X_DIM = 16  # width of the "x" part of every token
P_DIM = 16  # width of the "p" part of every token
NUM_QUERIES = 1  # query tokens appended after the pairs

# --- Training schedule ---
NUM_STEPS = 500000  # iterations of the training loop
NUM_BATCHES = 1000  # batches generated up front and cycled through
PLOT_EVERY = 250  # the weight matrices are plotted every this many steps
SAVE_EVERY = 15  # recorded in config.txt; not otherwise used in the visible code
LEARNING_RATE = 1  # learning rate handed to the SGD optimizer


class Transformer(nn.Module):
    """Two-layer, single-head, attention-only model with concatenated residuals.

    Each layer builds its queries as ``x @ W`` and uses the unprojected input
    as both keys and values.  The layer output is concatenated to the layer
    input, so the feature width grows ``input_dim -> 2 * input_dim ->
    4 * input_dim`` before the linear readout ``WO``.  All weights start at
    zero.

    Args:
        input_dim: Feature width of each input token.
        output_dim: Feature width of each output token.
        num_pairs: Number of pairs; the first ``2 * num_pairs`` positions of
            a sequence are the pair tokens.
        num_queries: Number of query tokens that follow the pair tokens.
    """

    def __init__(self, input_dim, output_dim, num_pairs, num_queries):
        super().__init__()

        self.num_pairs = num_pairs
        self.num_queries = num_queries

        self.W1 = nn.Parameter(torch.zeros(input_dim, input_dim))
        self.W2 = nn.Parameter(torch.zeros(2 * input_dim, 2 * input_dim))
        self.WO = nn.Parameter(torch.zeros(4 * input_dim, output_dim))

        T = num_pairs * 2 + num_queries
        # Boolean mask: True marks a (row=query position, col=key position)
        # pair that may take part in attention.  It has to be boolean: a
        # float mask is *added* to the attention scores, so a 0/1 float
        # mask would not block anything.
        mask = torch.ones(T, T).tril().bool()

        # Queries cannot attend to each other: inside the query/query block
        # only the diagonal (a query attending to itself) stays allowed.
        first_query = num_pairs * 2
        mask[first_query:, first_query:] = torch.eye(num_queries, dtype=torch.bool)

        self.register_buffer("mask", mask)

    def attention(self, x, W):
        """Apply one masked attention layer.

        Args:
            x: Tensor of shape ``(batch, T, dim)`` where ``T`` equals
                ``2 * num_pairs + num_queries`` (the size of ``self.mask``).
            W: ``(dim, dim)`` matrix applied to ``x`` to form the queries.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # scaled_dot_product_attention expects a head axis; use a single head.
        heads = x.unsqueeze(1)
        # The causal structure is already part of ``self.mask``; passing
        # ``is_causal=True`` together with an explicit mask is an error.
        out = F.scaled_dot_product_attention(heads @ W, heads, heads, attn_mask=self.mask)
        return out.squeeze(1)

    def forward(self, x):
        """Run both attention layers and the linear readout.

        Args:
            x: Tensor of shape ``(batch, T, input_dim)``.

        Returns:
            Tensor of shape ``(batch, T, output_dim)``.
        """
        x1 = self.attention(x, self.W1)
        x1 = torch.cat([x, x1], dim=2)
        x2 = self.attention(x1, self.W2)
        x2 = torch.cat([x1, x2], dim=2)
        return x2 @ self.WO

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def plot_model(W1, W2, W3, filename=None, block_size=16):
    """
    Plots three weight matrices side by side with a shared, zero-centred color
    scale and a single colorbar, square cells, no ticks, and dashed light-grey
    block lines every `block_size` cells.

    Args:
        W1, W2, W3: 2-D arrays to draw (the caller in this file passes NumPy
            arrays).
        filename: If provided, the figure is saved there and closed;
            otherwise it is shown.
        block_size: Spacing, in cells, between block separator lines.
            A falsy or non-positive value disables the lines.
    """
    mats = [(W1, "$W^{(1)}$"), (W2, "$W^{(2)}$"), (W3, "$W^{(3)}$")]

    # Symmetric color range around zero, shared by all three matrices.  The
    # small offset keeps vmin < vcenter < vmax when every entry is zero.
    scale = float(max(max(abs(W.min()), abs(W.max())) for W, _ in mats)) + 1e-4
    norm = TwoSlopeNorm(vmin=-scale, vcenter=0, vmax=scale)

    # Subplot widths follow each matrix's width/height ratio so that, with
    # aspect='equal', the panels share the same height.  A fourth, zero-width
    # axis is created as well and hidden right away.
    width_ratios = [W.shape[1] / W.shape[0] for W, _ in mats] + [0]

    height = 4
    fig, axs = plt.subplots(
        1, len(width_ratios), figsize=(height * sum(width_ratios) + 2, height),
        gridspec_kw={'width_ratios': width_ratios},
        constrained_layout=True
    )
    axs[-1].axis('off')  # Hide the last subplot

    draw_blocks = bool(block_size) and block_size > 0
    for ax, (W, title) in zip(axs, mats):
        im = ax.imshow(W, cmap="bwr", interpolation="nearest", norm=norm, aspect='equal', rasterized=True)
        ax.set_title(title)
        # Remove ticks
        ax.set_xticks([])
        ax.set_yticks([])
        if not draw_blocks:
            continue
        # Dashed light-grey lines on the cell boundaries between blocks
        rows, cols = W.shape
        for c in range(block_size, cols, block_size):
            ax.axvline(c - 0.5, color='#ddd', linewidth=1, linestyle='--')
        for r in range(block_size, rows, block_size):
            ax.axhline(r - 0.5, color='#ddd', linewidth=1, linestyle='--')

    # Add a single colorbar to the right spanning all subplots; every image
    # uses the same norm, so the last one stands in for all of them.
    fig.colorbar(im, ax=axs, orientation='vertical', fraction=0.025, pad=0.02)

    if filename:
        fig.savefig(filename, dpi=200, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

def generate_data(N, num_pairs, x_dim, p_dim, num_queries):
    """Build one batch of pair sequences followed by query tokens.

    Every token is an "x" part of width `x_dim` concatenated with a "p" part
    of width `p_dim`.  The two tokens of a pair sit next to each other in the
    sequence; the second token's p part is the first token's p part with its
    two halves swapped.  Each query token copies the x part of the first
    token of a randomly chosen pair and has an all-zero p part; its target is
    the x part of that pair's second token.

    Returns:
        A tuple `(sequence, targets)`, both moved to `device`:
        `sequence` has shape (N, 2 * num_pairs + num_queries, x_dim + p_dim)
        and `targets` has shape (N, num_queries, x_dim).
    """
    token_dim = x_dim + p_dim

    # First num_pairs rows: first elements; last num_pairs rows: second elements.
    values = torch.randn(N, 2 * num_pairs, x_dim)

    # Pick the queried pair per (sequence, query) and read both of its x parts.
    picked = torch.randint(0, num_pairs, (N, num_queries))
    gather_ix = picked.unsqueeze(2).expand(-1, -1, x_dim)
    query_inputs = values.gather(1, gather_ix)
    query_targets = values.gather(1, gather_ix + num_pairs)

    # p parts: random for the first elements, halves swapped for the second.
    half = p_dim // 2
    tags = torch.randn(N, num_pairs, p_dim)
    swapped_tags = torch.cat([tags[..., half:], tags[..., :half]], dim=2)
    tokens = torch.cat([values, torch.cat([tags, swapped_tags], dim=1)], dim=2)

    # Interleave so each pair's two tokens are adjacent: a1, b1, a2, b2, ...
    firsts, seconds = tokens[:, :num_pairs], tokens[:, num_pairs:]
    interleaved = torch.stack([firsts, seconds], dim=2).reshape(N, num_pairs * 2, token_dim)

    # Query tokens carry a zero p part and are appended at the end.
    query_tokens = torch.cat([query_inputs, torch.zeros(N, num_queries, p_dim)], dim=2)
    sequence = torch.cat([interleaved, query_tokens], dim=1)

    return sequence.to(device), query_targets.to(device)

# Create the output directory up front; exist_ok avoids the check-then-create
# race of testing os.path.exists() first.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

if __name__ == "__main__":
    # Snapshot of the hyperparameters, stored alongside the plots.
    config = {
        "SEQUENCES_PER_BATCH": SEQUENCES_PER_BATCH,
        "NUM_PAIRS": NUM_PAIRS,
        "X_DIM": X_DIM,
        "P_DIM": P_DIM,
        "NUM_QUERIES": NUM_QUERIES,
        "NUM_STEPS": NUM_STEPS,
        "NUM_BATCHES": NUM_BATCHES,
        "PLOT_EVERY": PLOT_EVERY,
        "SAVE_EVERY": SAVE_EVERY,
        "LEARNING_RATE": LEARNING_RATE
    }
    # Save the config to a file
    with open(os.path.join(OUTPUT_FOLDER, "config.txt"), "w", encoding="utf-8") as f:
        for key, value in config.items():
            f.write(f"{key} = {value}\n")

    # All batches are generated once and reused cyclically during training.
    print("Generating data...")
    batches = [
        generate_data(SEQUENCES_PER_BATCH, NUM_PAIRS, X_DIM, P_DIM, NUM_QUERIES) for _ in range(NUM_BATCHES)
    ]
    print("Done")

    model = Transformer(X_DIM + P_DIM, X_DIM, NUM_PAIRS, NUM_QUERIES).to(device)

    # Plain SGD: no momentum, no weight decay.
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, weight_decay=0, momentum=0)

    # NOTE: `transfom` and `random_unitary` are not used by the training loop
    # below.
    transfom = torch.zeros(X_DIM + P_DIM, X_DIM + P_DIM).to(device)

    def random_unitary(n):
        """Return the Q factor of the QR decomposition of an n x n standard-normal matrix."""
        # torch.linalg.qr replaces the deprecated torch.qr.
        q, _ = torch.linalg.qr(torch.randn(n, n))
        return q.to(device)

    for step in range(NUM_STEPS):
        if step % PLOT_EVERY == 0:
            # Plot the weights as they are *before* this step's update.
            W1 = model.W1.detach().cpu().numpy()
            W2 = model.W2.detach().cpu().numpy()
            WO = model.WO.detach().cpu().numpy()
            plot_model(W1, W2, WO, os.path.join(OUTPUT_FOLDER, f"{step:04d}.pdf"))

        x, y = batches[step % len(batches)]

        # Only the trailing query positions are scored against the targets.
        y_pred = model(x)[:, -NUM_QUERIES:]
        loss = F.mse_loss(y_pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        print(f"Step {step}, Loss: {loss.item()}")