import os
from matplotlib import pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# Directory that receives config.txt and the generated plot.
# (Plain string: the original f-string had no placeholders.)
OUTPUT_FOLDER = "experiments"

# --- Data / model dimensions ---
SEQUENCES_PER_BATCH = 256  # sequences in each pre-generated batch
NUM_PAIRS = 8              # pairs per sequence; each pair contributes two tokens
X_DIM = 32                 # width of the `x` part of a token
P_DIM = 32                 # width of the `p` part concatenated to each token
NUM_QUERIES = 1            # query tokens appended to every sequence

# --- Training schedule ---
NUM_STEPS = 1501           # number of optimizer steps
NUM_BATCHES = 1000         # batches generated up front and cycled through
PLOT_EVERY = 250           # re-draw the parameter/loss figure every this many steps
SAVE_EVERY = 15            # only written to config.txt by the code visible in this file
LEARNING_RATE = 1          # learning rate handed to SGD


class Transformer(nn.Module):
    """Two-layer, attention-only network with concatenated (not added) residuals.

    Each layer attends over the sequence using ``x @ W`` as queries and the
    unprojected input as both keys and values, then concatenates the attention
    output to its input along the feature axis, so the feature width doubles
    per layer. ``WO`` maps the final ``4 * input_dim`` features to
    ``output_dim``. All three weight matrices start at zero.

    Args:
        input_dim: feature width of every input token.
        output_dim: feature width of every output token.
        num_pairs: number of pairs; the first ``2 * num_pairs`` positions of a
            sequence are treated as context tokens.
        num_queries: number of query positions at the end of the sequence.
    """

    def __init__(self, input_dim, output_dim, num_pairs, num_queries):
        super().__init__()

        self.num_pairs = num_pairs
        self.num_queries = num_queries

        self.W1 = nn.Parameter(torch.zeros(input_dim, input_dim))
        self.W2 = nn.Parameter(torch.zeros(2 * input_dim, 2 * input_dim))
        self.WO = nn.Parameter(torch.zeros(4 * input_dim, output_dim))

        num_context = num_pairs * 2
        T = num_context + num_queries

        # Boolean attention mask, True = "may attend". It has to be boolean:
        # scaled_dot_product_attention ADDS a float mask to the scores, so a
        # 0/1 float mask would not block anything.
        mask = torch.tril(torch.ones(T, T), diagonal=0).bool()

        # Queries cannot attend to each other: in every query row, clear the
        # entries that point at earlier query positions (the diagonal stays).
        for i in range(num_context, T):
            mask[i, num_context:i] = False

        self.register_buffer("mask", mask)

    def attention(self, x, W):
        """Masked attention with queries ``x @ W`` and keys/values ``x``.

        Args:
            x: tensor of shape (batch, T, features), T matching the mask size.
            W: square (features, features) matrix applied to form the queries.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Add a singleton head axis for scaled_dot_product_attention.
        x = x.unsqueeze(1)
        # Causality is already encoded in self.mask; attn_mask and
        # is_causal=True must not be combined.
        y = F.scaled_dot_product_attention(x @ W, x, x, attn_mask=self.mask)
        return y.squeeze(1)

    def forward(self, x):
        """Apply both attention layers and the output projection.

        Args:
            x: tensor of shape (batch, T, input_dim).

        Returns:
            Tensor of shape (batch, T, output_dim).
        """
        x1 = self.attention(x, self.W1)
        x1 = torch.cat([x, x1], dim=2)       # width 2 * input_dim
        x2 = self.attention(x1, self.W2)
        x2 = torch.cat([x1, x2], dim=2)      # width 4 * input_dim
        return x2 @ self.WO

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def plot_model(W1, W2, W3, filename=None):
    """
    Draw the three weight matrices next to each other as heatmaps.

    All panels share one symmetric, zero-centred color scale and a single
    colorbar; cells are square, ticks are hidden, and dashed light-gray guide
    lines are drawn every 16 cells. When `filename` is given the figure is
    saved there and closed; otherwise it is shown.
    """
    matrices = (W1, W2, W3)
    titles = ("$W^{(1)}$", "$W^{(2)}$", "$W^{(3)}$")

    # Largest absolute entry over all three matrices; the small offset keeps
    # the color range from collapsing when every entry is zero.
    extremes = [abs(bound) for W in matrices for bound in (W.min(), W.max())]
    scale = max(extremes) + 1e-4
    norm = TwoSlopeNorm(vmin=-scale, vcenter=0, vmax=scale)

    # One column per matrix, sized by its width/height ratio so that square
    # cells fit, plus a zero-width fourth column that is switched off below.
    width_ratios = []
    for W in matrices:
        n_rows, n_cols = W.shape
        width_ratios.append(n_cols / n_rows)
    width_ratios.append(0)

    height = 4
    fig_width = height * sum(width_ratios) + 2
    fig, axs = plt.subplots(
        1, 4,
        figsize=(fig_width, height),
        gridspec_kw={'width_ratios': width_ratios},
        constrained_layout=True,
    )
    axs[3].axis('off')

    for ax, W, title in zip(axs[0:3], matrices, titles):
        im = ax.imshow(
            W, cmap="bwr", interpolation="nearest", norm=norm,
            aspect='equal', rasterized=True,
        )
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

        # Dashed guides on the cell boundaries at every multiple of 16.
        n_rows, n_cols = W.shape
        for col in range(16, n_cols, 16):
            ax.axvline(col - 0.5, color='#ddd', linewidth=1, linestyle='--')
        for row in range(16, n_rows, 16):
            ax.axhline(row - 0.5, color='#ddd', linewidth=1, linestyle='--')

    # One vertical colorbar for all panels (they share `norm`).
    fig.colorbar(im, ax=axs, orientation='vertical', fraction=0.025, pad=0.02)

    if not filename:
        plt.show()
        return

    plt.savefig(filename, dpi=200, bbox_inches='tight')
    plt.close(fig)

def generate_data(N, num_pairs, x_dim, p_dim, num_queries):
    """
    Sample one batch of sequences with trailing query tokens, plus targets.

    A sequence is built from `2 * num_pairs` random `x` vectors. Vector `i`
    (first half) and vector `i + num_pairs` (second half) form a pair and end
    up at adjacent positions `2 * i` and `2 * i + 1`. Each token also carries a
    `p` vector: random for the first member of a pair, and the same vector
    with its two halves swapped for the second member. A query token holds
    the `x` vector of a randomly chosen first member with an all-zero `p`
    part; its target is the `x` vector of the matching second member.

    Returns:
        (tokens, targets), both moved to `device`:
        tokens  - (N, 2 * num_pairs + num_queries, x_dim + p_dim)
        targets - (N, num_queries, x_dim)
    """
    seq_len = 2 * num_pairs
    token_dim = x_dim + p_dim

    content = torch.randn(N, seq_len, x_dim)

    # One pair index per query, repeated along the feature axis for gather.
    picks = torch.randint(0, num_pairs, (N, num_queries))
    picks = picks.unsqueeze(2).expand(-1, -1, x_dim)
    query_content = content.gather(1, picks)
    targets = content.gather(1, picks + num_pairs)

    # `p` vectors: random for the first half, halves swapped for the second.
    tags = torch.randn(N, num_pairs, p_dim)
    half = p_dim // 2
    swapped_tags = torch.cat([tags[:, :, half:], tags[:, :, :half]], dim=2)
    all_tags = torch.cat([tags, swapped_tags], dim=1)

    tokens = torch.cat([content, all_tags], dim=2)
    # Interleave the two halves: (half, pair) order -> (pair, half) order.
    tokens = tokens.view(N, 2, num_pairs, token_dim).permute(0, 2, 1, 3)
    tokens = tokens.contiguous().view(N, seq_len, token_dim)

    query_tokens = torch.cat(
        [query_content, torch.zeros(N, num_queries, p_dim)], dim=2
    )
    sequence = torch.cat([tokens, query_tokens], dim=1)

    return sequence.to(device), targets.to(device)

# Hex color palette. Not referenced by the other code visible in this file;
# plot_params spells the same four color literals out inline.
colors = ['#e05657', '#ff9a3f', '#2b93db', '#36c636']
import numpy as np
import matplotlib.pyplot as plt

def plot_params(loss, params):
    """
    Save a two-panel figure: tracked parameters (left) and training loss (right).

    `params` holds one parameter vector per step; only entries 2, 4 and 17
    are drawn. `loss` holds one value per step and is smoothed with a 10-step
    moving average once more than 10 values are available. The figure is
    written to `{OUTPUT_FOLDER}/H2_params_and_loss.pdf` and then closed.
    """
    series = np.array(params).T  # one row per parameter
    fig, axs = plt.subplots(1, 2, figsize=(12, 2.5))
    param_ax, loss_ax = axs[0], axs[1]

    # --- Right panel: loss ---
    if len(loss) > 10:
        window = np.ones(10) / 10
        smoothed = np.convolve(loss, window, mode='valid')
        # Left-pad with the first smoothed value so the curve keeps one point
        # per step.
        loss = np.concatenate([smoothed[0] * np.ones(9), smoothed])
    loss_ax.plot(loss, label="Training Loss", color='#36c636')
    loss_ax.set_xlabel("Step")
    loss_ax.legend(loc="upper left")
    loss_ax.set_ylim(-0.04, 1.4)

    # --- Left panel: selected parameters ---
    # Only these parameter indices are drawn, each with a fixed color.
    highlighted = {2: "#e05657", 4: "#ff9a3f", 17: "#2b93db"}
    for i, row in enumerate(series):
        color = highlighted.get(i)
        if color is None:
            continue

        # Label by group: indices 0-2 are alpha, 3-14 beta, the rest gamma.
        if i < 3:
            name = f"$\\alpha_{i+1}$"
        elif i < 15:
            name = f"$\\beta_{{{i-2}}}$"
        else:
            name = f"$\\gamma_{i-14}$"

        param_ax.plot(row, label=name, color=color, alpha=1.0, zorder=100)

    param_ax.set_xlabel("Step")
    param_ax.legend(loc="upper left")
    param_ax.set_ylim(-0.99, 2.9)

    plt.savefig(f"{OUTPUT_FOLDER}/H2_params_and_loss.pdf", dpi=400, bbox_inches="tight")
    plt.close(fig)

# One 0/1 switch per extracted parameter (19 in total); only the parameters
# at indices 2, 4 and 17 are switched on.
mask_param = np.zeros(19)
mask_param[[2, 4, 17]] = 1.0

def ablate_weights(params):
    """
    Rebuild W1, W2 and W3 from only the parameters enabled in `mask_param`.

    Starting from all-zero matrices, every basis entry `(i, j, m)` whose
    `mask_param` flag is positive contributes `params[ix] * m * X_DIM` to the
    X_DIM x X_DIM block at block-row `i`, block-column `j`. Everything else
    stays zero, i.e. all disabled components are ablated.

    Args:
        params: flat sequence of coefficients, ordered like the concatenation
            of W1_basis, W2_basis and W3_basis (the order extract_basis uses).

    Returns:
        Tuple `(W1, W2, W3)` of tensors on `device` with shapes
        (2*X_DIM, 2*X_DIM), (4*X_DIM, 4*X_DIM) and (8*X_DIM, X_DIM).
    """
    W1 = torch.zeros(X_DIM * 2, X_DIM * 2).to(device)
    W2 = torch.zeros(X_DIM * 4, X_DIM * 4).to(device)
    W3 = torch.zeros(X_DIM * 8, X_DIM).to(device)

    ix = 0  # running index into params / mask_param across all three bases
    for W, basis in zip([W1, W2, W3], [W1_basis, W2_basis, W3_basis]):
        for i, j, m in basis:
            if mask_param[ix] > 0:
                rows = slice(i * X_DIM, (i + 1) * X_DIM)
                cols = slice(j * X_DIM, (j + 1) * X_DIM)
                # Accumulate rather than assign: several basis matrices can
                # target the same block (e.g. both I and M at (1, 1)), and
                # plain assignment would let the later one erase the earlier.
                W[rows, cols] += params[ix] * m * X_DIM
            ix += 1

    return W1, W2, W3


# Make sure the output directory exists. exist_ok=True replaces the separate
# existence check, which could race with another process creating the folder.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Basis matrices for a single X_DIM x X_DIM block, both scaled by 1 / X_DIM:
#   I - the identity;
#   M - identity blocks in the two off-diagonal quadrants.
I = torch.eye(X_DIM).to(device) / X_DIM
M = torch.zeros(X_DIM, X_DIM).to(device)
M[X_DIM // 2:, :X_DIM // 2] = torch.eye(X_DIM // 2).to(device) / X_DIM
M[:X_DIM // 2, X_DIM // 2:] = torch.eye(X_DIM // 2).to(device) / X_DIM

# Each entry is (block_row, block_col, basis_matrix). The order matters: the
# flat parameter index runs through W1_basis, then W2_basis, then W3_basis.
W1_basis = [(0, 0, I), (1, 1, I), (1, 1, M)]
W2_basis = [
    (0, 0, I), (0, 2, I),
    (1, 1, I), (1, 1, M), (1, 3, I), (1, 3, M),
    (2, 0, I), (2, 2, I),
    (3, 1, I), (3, 1, M), (3, 3, I), (3, 3, M),
]
W3_basis = [(0, 0, I), (2, 0, I), (4, 0, I), (6, 0, I)]


def extract_basis(W1, W2, W3):
    """
    Read the basis coefficients off the weight matrices W1, W2 and W3.

    For every `(i, j, m)` in W1_basis, W2_basis and W3_basis (in that order)
    the coefficient is the sum of the elementwise product between `m` and the
    X_DIM x X_DIM block of the matching matrix at block-row `i`, block-column
    `j`. Returns the coefficients as a flat list of Python numbers.
    """
    coefficients = []
    for weight, basis in ((W1, W1_basis), (W2, W2_basis), (W3, W3_basis)):
        for block_row, block_col, pattern in basis:
            r0 = block_row * X_DIM
            c0 = block_col * X_DIM
            block = weight[r0:r0 + X_DIM, c0:c0 + X_DIM]
            coefficients.append((block * pattern).sum().item())
    return coefficients



if __name__ == "__main__":
    # Hyper-parameters of this run, recorded next to the results.
    config = {
        "SEQUENCES_PER_BATCH": SEQUENCES_PER_BATCH,
        "NUM_PAIRS": NUM_PAIRS,
        "X_DIM": X_DIM,
        "P_DIM": P_DIM,
        "NUM_QUERIES": NUM_QUERIES,
        "NUM_STEPS": NUM_STEPS,  # key was misspelled "NM_STEPS"
        "NUM_BATCHES": NUM_BATCHES,
        "PLOT_EVERY": PLOT_EVERY,
        "SAVE_EVERY": SAVE_EVERY,
        "LEARNING_RATE": LEARNING_RATE,
    }
    # Save the config to a file
    with open(f"{OUTPUT_FOLDER}/config.txt", "w", encoding="utf-8") as f:
        for key, value in config.items():
            f.write(f"{key} = {value}\n")

    # All batches are generated once up front and cycled through in training.
    print("Generating data...")
    batches = [
        generate_data(SEQUENCES_PER_BATCH, NUM_PAIRS, X_DIM, P_DIM, NUM_QUERIES)
        for _ in range(NUM_BATCHES)
    ]
    print("Done")

    model = Transformer(X_DIM + P_DIM, X_DIM, NUM_PAIRS, NUM_QUERIES).to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, weight_decay=0, momentum=0)

    # NOTE(review): `transfom` and `random_unitary` are not used by the code
    # visible in this block; kept so anything relying on them still works.
    transfom = torch.zeros(X_DIM + P_DIM, X_DIM + P_DIM).to(device)

    def random_unitary(n):
        """Return the Q factor of the QR decomposition of a random n x n Gaussian matrix."""
        # torch.qr is deprecated; torch.linalg.qr is its replacement.
        q, _ = torch.linalg.qr(torch.randn(n, n))
        return q.to(device)

    params = []  # extracted basis coefficients, one list per step
    losses = []  # training loss, one value per step

    for step in range(NUM_STEPS):
        if step % PLOT_EVERY == 0:
            plot_params(losses, params)

        x, y = batches[step % len(batches)]

        # Only the trailing query positions are scored against the targets.
        y_pred = model(x)[:, -NUM_QUERIES:]
        loss = F.mse_loss(y_pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        p = extract_basis(model.W1, model.W2, model.WO)
        params.append(p)
        losses.append(loss.item())

        # Project the freshly updated weights back onto the enabled basis
        # components, so every step starts from the restricted weights.
        W1, W2, W3 = ablate_weights(p)
        with torch.no_grad():
            model.W1.copy_(W1)
            model.W2.copy_(W2)
            model.WO.copy_(W3)

        print(f"Step {step}, Loss: {loss.item()}")