import os
import queue
import threading
from matplotlib import pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset


# Directory that receives config.txt and the emergence CSV / PDF.
OUTPUT_FOLDER = "experiments"

# Number of sequences produced per generated batch.
SEQUENCES_PER_BATCH = 64
# Backward passes per optimizer step; the SGD learning rate is divided by it.
GRADIENT_ACCUMULATION = 1
# Recorded in the saved config; not otherwise used in this file.
MAX_NUM_PAIRS = 32
# Width of the content part and of the positional part of every token.
X_DIM = 256
P_DIM = 256
# Query tokens appended after the (x, y) pairs.
NUM_QUERIES = 1
# Upper bound on optimizer steps per training run.
NUM_STEPS = 2500001
# Recorded in the saved config; not otherwise used in this file.
NUM_BATCHES = 1000
# Training progress is printed every PLOT_EVERY steps.
PLOT_EVERY = 5
# Recorded in the saved config; not otherwise used in this file.
SAVE_EVERY = 15
LEARNING_RATE = 100
# When True every query asks for the last pair instead of a random one.
QUERY_LAST = True
# When True only basis coefficients 2, 4 and 17 are kept after each step.
MASK = True
# When True the random content / positional vectors are orthogonalised.
ORTHOGONAL = True


class Transformer(nn.Module):
    """Two-layer, attention-only model with a fixed attention mask.

    Each layer uses ``x @ W`` as the attention query while the keys and the
    values are the unprojected layer input. The output of a layer is
    concatenated to its input along the feature axis, so the feature width
    grows ``input_dim -> 2 * input_dim -> 4 * input_dim`` before the final
    linear read-out ``WO``. All weights start at zero.

    Args:
        input_dim: Feature width of every input token.
        output_dim: Feature width of every output token.
        num_pairs: Number of context pairs; the context holds
            ``2 * num_pairs`` tokens.
        num_queries: Number of query tokens that follow the context.

    The ``mask`` buffer is a boolean ``(T, T)`` matrix with
    ``T = 2 * num_pairs + num_queries`` where ``True`` means "may attend":
    context tokens attend causally, query tokens attend to the whole context
    and to themselves but not to other query tokens.
    """

    def __init__(self, input_dim, output_dim, num_pairs, num_queries):
        super().__init__()

        self.num_pairs = num_pairs
        self.num_queries = num_queries

        self.W1 = nn.Parameter(torch.zeros(input_dim, input_dim))
        self.W2 = nn.Parameter(torch.zeros(2 * input_dim, 2 * input_dim))
        self.WO = nn.Parameter(torch.zeros(4 * input_dim, output_dim))

        context_len = num_pairs * 2
        seq_len = context_len + num_queries
        mask = torch.tril(torch.ones(seq_len, seq_len), diagonal=0)
        # Queries must not attend to each other: inside the query block only
        # the diagonal (a query looking at itself) stays enabled.
        mask[context_len:, context_len:] = torch.eye(num_queries)

        # Stored as bool on purpose: scaled_dot_product_attention *adds* a
        # floating-point mask to the scores, so a 0/1 float mask would not
        # block anything. A boolean mask drops the False positions.
        self.register_buffer("mask", mask.bool())

    def attention(self, x, W):
        """Masked single-head attention with query ``x @ W`` and key = value = ``x``.

        Args:
            x: Tensor of shape ``(batch, T, features)`` where ``T`` matches
                the side length of ``self.mask``.
            W: Square ``(features, features)`` query projection.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x.unsqueeze(1)  # add a singleton head axis
        # The mask already encodes causality, so `is_causal` is left at its
        # default; combining it with an explicit `attn_mask` is an error.
        y = F.scaled_dot_product_attention(x @ W, x, x, attn_mask=self.mask)
        return y.squeeze(1)

    def forward(self, x):
        """Apply both attention layers and the linear read-out.

        Args:
            x: Tensor of shape ``(batch, T, input_dim)``.

        Returns:
            Tensor of shape ``(batch, T, output_dim)``.
        """
        x1 = self.attention(x, self.W1)
        x1 = torch.cat([x, x1], dim=2)
        x2 = self.attention(x1, self.W2)
        x2 = torch.cat([x1, x2], dim=2)
        return x2 @ self.WO

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def plot_model(W1, W2, W3, filename=None):
    """
    Draw three weight matrices next to each other as heat maps.

    All panels share one symmetric, zero-centred colour scale and a single
    colourbar. Cells are square, ticks are hidden and light dashed guide lines
    are drawn every 16 cells. When `filename` is given the figure is saved
    there and closed; otherwise it is shown.
    """
    matrices = (W1, W2, W3)
    titles = ("$W^{(1)}$", "$W^{(2)}$", "$W^{(3)}$")

    # Largest absolute entry over all three matrices. The small offset
    # presumably keeps vmin < 0 < vmax when every entry is zero.
    extremes = [abs(bound) for W in matrices for bound in (W.min(), W.max())]
    scale = max(extremes) + 1e-4
    norm = TwoSlopeNorm(vmin=-scale, vcenter=0, vmax=scale)

    # Panel widths follow each matrix's width / height so that square cells
    # fill the panel; the trailing zero-width panel is hidden below.
    width_ratios = [n_cols / n_rows for n_rows, n_cols in (W.shape for W in matrices)]
    width_ratios.append(0)

    height = 4
    fig, axs = plt.subplots(
        1, 4, figsize=(height * sum(width_ratios) + 2, height),
        gridspec_kw={'width_ratios': width_ratios},
        constrained_layout=True
    )
    axs[3].axis('off')

    for ax, W, title in zip(axs[0:3], matrices, titles):
        im = ax.imshow(W, cmap="bwr", interpolation="nearest", norm=norm, aspect='equal', rasterized=True)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        # Dashed guides on the cell boundaries of every 16-cell block.
        n_rows, n_cols = W.shape
        for col in range(16, n_cols, 16):
            ax.axvline(col - 0.5, color='#ddd', linewidth=1, linestyle='--')
        for row in range(16, n_rows, 16):
            ax.axhline(row - 0.5, color='#ddd', linewidth=1, linestyle='--')

    # One colourbar on the right, shared by every panel.
    fig.colorbar(im, ax=axs, orientation='vertical', fraction=0.025, pad=0.02)

    if not filename:
        plt.show()
        return
    plt.savefig(filename, dpi=200, bbox_inches='tight')
    plt.close(fig)

def make_orthogonal(x):
    """Orthonormalise the rows of each matrix in a batch.

    `x` is a 3-D tensor. Every `x[b]` is transposed, factorised with a reduced
    QR decomposition, and the Q factor is transposed back, so the rows of the
    result are orthonormal.
    """
    q_factor, _ = torch.linalg.qr(torch.transpose(x, 1, 2), mode='reduced')
    return torch.transpose(q_factor, 1, 2)

def generate_data(N, num_pairs, x_dim, p_dim, num_queries):
    """Generate one batch of interleaved (x, y) pair sequences plus queries.

    Every token is the concatenation ``[content | positional]`` of widths
    ``x_dim`` and ``p_dim``. The sequence is laid out as
    ``x_0, y_0, x_1, y_1, ..., x_{n-1}, y_{n-1}, q_0, ..., q_{m-1}``:

    * ``x_i`` / ``y_i`` carry random content vectors. ``x_i`` carries a random
      positional vector and ``y_i`` the same vector with its two halves
      swapped.
    * Each query token repeats the content of one ``x_i`` and has an all-zero
      positional part. The target is the content of the matching ``y_i``.

    When the module-level ``ORTHOGONAL`` flag is set, content and positional
    vectors are orthogonalised within a sequence and rescaled to norm
    ``sqrt(x_dim)`` / ``sqrt(p_dim)``. When ``QUERY_LAST`` is set every query
    refers to the last pair; otherwise pairs are drawn uniformly at random.

    Args:
        N: Number of sequences in the batch.
        num_pairs: Number of (x, y) pairs per sequence.
        x_dim: Width of the content part of a token.
        p_dim: Width of the positional part of a token.
        num_queries: Number of query tokens per sequence.

    Returns:
        Tuple ``(tokens, targets)`` moved to the module-level ``device``:
        ``tokens`` has shape ``(N, 2 * num_pairs + num_queries, x_dim + p_dim)``
        and ``targets`` has shape ``(N, num_queries, x_dim)``.

    Raises:
        ValueError: If ``ORTHOGONAL`` is set but the requested number of
            mutually orthogonal vectors does not fit in the given dimensions.
    """
    if ORTHOGONAL and (2 * num_pairs > x_dim or num_pairs > p_dim):
        # The reduced QR used by make_orthogonal would silently return fewer
        # rows than requested, which only surfaces later as an opaque
        # gather / cat shape error.
        raise ValueError(
            f"ORTHOGONAL data needs 2 * num_pairs <= x_dim and num_pairs <= p_dim, "
            f"got num_pairs={num_pairs}, x_dim={x_dim}, p_dim={p_dim}"
        )

    # Rows [0, num_pairs) are the x contents, rows [num_pairs, 2 * num_pairs)
    # the matching y contents.
    x = torch.randn(N, num_pairs * 2, x_dim)
    if ORTHOGONAL:
        # Scale by the *argument* x_dim (not the global X_DIM) so the function
        # stays correct when called with other dimensions.
        x = make_orthogonal(x) * (x_dim ** 0.5)

    q_ix = torch.randint(0, num_pairs, (N, num_queries))
    if QUERY_LAST:
        q_ix.fill_(num_pairs - 1)
    q_ix = q_ix.unsqueeze(2).expand(-1, -1, x_dim)

    q_x = torch.gather(x, 1, q_ix)
    q_y = torch.gather(x, 1, q_ix + num_pairs)

    p1 = torch.randn(N, num_pairs, p_dim)
    if ORTHOGONAL:
        p1 = make_orthogonal(p1) * (p_dim ** 0.5)
    # Positional part of y_i: the positional part of x_i with halves swapped.
    p2 = torch.cat([p1[:, :, p_dim // 2:], p1[:, :, :p_dim // 2]], dim=2)
    p = torch.cat([p1, p2], dim=1)

    p_q = torch.zeros(N, num_queries, p_dim)

    # Attach positions, then interleave the x block and the y block so that
    # every x_i is immediately followed by its y_i.
    x = torch.cat([x, p], dim=2)
    x = x.view(N, 2, num_pairs, x_dim + p_dim)
    x = x.permute(0, 2, 1, 3).contiguous().view(N, num_pairs * 2, x_dim + p_dim)

    q_x = torch.cat([q_x, p_q], dim=2)
    x = torch.cat([x, q_x], dim=1)

    return x.to(device), q_y.to(device)

# Line colours for plot_dependency: alpha, beta, gamma and loss, in that order.
colors = [
    "#e05657",
    "#ff9a3f",
    "#2b93db",
    "#36c636",
]
import numpy as np
import matplotlib.pyplot as plt

# One keep-flag per basis coefficient (3 for W1, 12 for W2, 4 for WO = 19).
# A coefficient is written back by ablate_weights only where the flag is > 0.
# With MASK enabled only coefficients 2, 4 and 17 survive, the ones tracked as
# alpha, beta and gamma during training; otherwise all of them are kept.
if MASK:
    mask_param = np.zeros(19)
    mask_param[[2, 4, 17]] = 1.0
else:
    mask_param = np.ones(19)

def ablate_weights(params):
    """
    Rebuild the three weight matrices from their basis coefficients.

    Starts from all-zero matrices of shape (2, 2), (4, 4) and (8, 1) blocks of
    X_DIM x X_DIM and, for every basis entry whose `mask_param` flag is
    positive, writes `params[ix] * pattern * X_DIM` into the entry's block.
    `ix` counts the entries of W1_basis, W2_basis and W3_basis in that order.
    Everything outside the kept basis entries stays zero.

    Returns the tuple (W1, W2, W3) on `device`.
    """
    block_shapes = ((2, 2), (4, 4), (8, 1))
    weights = [
        torch.zeros(X_DIM * n_rows, X_DIM * n_cols, device=device)
        for n_rows, n_cols in block_shapes
    ]

    # Flatten the three bases into one sequence so the running index lines up
    # with `params` and `mask_param`.
    entries = [
        (weight, entry)
        for weight, basis in zip(weights, (W1_basis, W2_basis, W3_basis))
        for entry in basis
    ]
    for ix, (weight, (block_row, block_col, pattern)) in enumerate(entries):
        if not mask_param[ix] > 0:
            continue
        rows = slice(block_row * X_DIM, (block_row + 1) * X_DIM)
        cols = slice(block_col * X_DIM, (block_col + 1) * X_DIM)
        weight[rows, cols] = params[ix] * pattern * X_DIM

    return tuple(weights)

class PrefetchDataset(IterableDataset):
    """Endless stream of batches produced ahead of time by a daemon thread.

    A background thread repeatedly calls ``generate_data`` and pushes the
    result into a bounded queue; iterating the dataset pops batches from that
    queue.

    Compared with a bare producer loop this class also:

    * forwards an exception raised while generating data to the consumer, so
      iteration fails instead of blocking forever on an empty queue;
    * stops the producer when an iterator is closed or garbage-collected (or
      when ``stop()`` is called), so abandoned datasets do not keep a thread
      and a queue full of batches alive.

    Args:
        num_pairs: Number of (x, y) pairs per generated sequence.
        prefetch_batches: Maximum number of batches buffered in the queue.
    """

    # Seconds a blocked queue operation waits before re-checking for shutdown.
    _POLL_SECONDS = 0.1

    def __init__(self, num_pairs, prefetch_batches=10):
        self.prefetch_batches = prefetch_batches
        self.queue = queue.Queue(maxsize=self.prefetch_batches)
        # Unused; kept so the attribute remains available to existing code.
        self.stop_signal = object()
        self.num_pairs = num_pairs
        self.thread = None
        self._stop_event = threading.Event()

    def _offer(self, item, stop_event):
        """Put `item` on the queue, giving up once `stop_event` is set."""
        while not stop_event.is_set():
            try:
                self.queue.put(item, timeout=self._POLL_SECONDS)
                return
            except queue.Full:
                continue

    def _produce(self, stop_event):
        """Producer loop: generate batches until `stop_event` is set."""
        try:
            while not stop_event.is_set():
                data = generate_data(SEQUENCES_PER_BATCH, self.num_pairs, X_DIM, P_DIM, NUM_QUERIES)
                self._offer(data, stop_event)
        except Exception as exc:
            # Hand the failure to the consumer; it is re-raised in __iter__.
            self._offer(exc, stop_event)

    def data_generator(self):
        """Run the producer loop in the calling thread until stopped."""
        self._produce(self._stop_event)

    def start(self):
        """Start the producer thread if none is currently registered."""
        if self.thread is None:
            # A fresh event per thread: a previously stopped producer keeps
            # seeing its own (set) event and cannot be revived by accident.
            self._stop_event = threading.Event()
            self.thread = threading.Thread(
                target=self._produce, args=(self._stop_event,), daemon=True
            )
            self.thread.start()

    def stop(self):
        """Ask the producer thread to exit; it does so within a poll interval.

        The thread is not joined, so this never blocks the caller.
        """
        self._stop_event.set()
        self.thread = None

    def __iter__(self):
        try:
            while True:
                # No-op while a producer is registered; restarts it after a
                # stop() issued elsewhere.
                self.start()
                try:
                    item = self.queue.get(timeout=self._POLL_SECONDS)
                except queue.Empty:
                    continue
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            self.stop()

def get_loader(num_pairs, prefetch_batches=10):
    """Return a DataLoader over a fresh PrefetchDataset.

    `batch_size=None` is passed because each item coming out of the dataset is
    already a complete batch from `generate_data()`; `num_workers=0` because
    prefetching is done by the dataset's own thread.
    """
    return DataLoader(
        PrefetchDataset(num_pairs=num_pairs, prefetch_batches=prefetch_batches),
        batch_size=None,
        num_workers=0,
    )


# Create the output directory up front. `exist_ok=True` avoids the
# check-then-create race of testing os.path.exists() first.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Basis patterns, each an X_DIM x X_DIM block scaled by 1 / X_DIM:
#   I - the identity;
#   M - identities of size X_DIM // 2 in the two off-diagonal quadrants.
I = torch.eye(X_DIM, device=device) / X_DIM
M = torch.zeros(X_DIM, X_DIM, device=device)
_half_eye = torch.eye(X_DIM // 2, device=device) / X_DIM
M[X_DIM // 2:, :X_DIM // 2] = _half_eye
M[:X_DIM // 2, X_DIM // 2:] = _half_eye

# Each entry is (block_row, block_col, pattern): the pattern lives in the
# X_DIM x X_DIM block at that block position of the weight matrix.
W1_basis = [
    (0, 0, I),
    (1, 1, I), (1, 1, M),
]
W2_basis = [
    (0, 0, I), (0, 2, I),
    (1, 1, I), (1, 1, M), (1, 3, I), (1, 3, M),
    (2, 0, I), (2, 2, I),
    (3, 1, I), (3, 1, M), (3, 3, I), (3, 3, M),
]
W3_basis = [
    (0, 0, I),
    (2, 0, I),
    (4, 0, I),
    (6, 0, I),
]


def extract_basis(W1, W2, W3):
    """
    Read the basis coefficients out of the three weight matrices.

    For every (block_row, block_col, pattern) entry of W1_basis, W2_basis and
    W3_basis, in that order, the matching X_DIM x X_DIM block is multiplied
    element-wise with the pattern and summed. Returns the sums as a flat list
    of Python floats.
    """
    coefficients = []
    for weight, basis in ((W1, W1_basis), (W2, W2_basis), (W3, W3_basis)):
        for block_row, block_col, pattern in basis:
            rows = slice(block_row * X_DIM, (block_row + 1) * X_DIM)
            cols = slice(block_col * X_DIM, (block_col + 1) * X_DIM)
            coefficients.append(torch.sum(weight[rows, cols] * pattern).item())
    return coefficients



if __name__ == "__main__":
    # Snapshot of the hyper-parameters for this run. Every key is the name of
    # the constant it records (the former "NM_STEPS" key was a typo), and the
    # ORTHOGONAL switch is included so the data distribution can be recovered
    # from the saved config.
    config = {
        "SEQUENCES_PER_BATCH": SEQUENCES_PER_BATCH,
        "MAX_NUM_PAIRS": MAX_NUM_PAIRS,
        "X_DIM": X_DIM,
        "P_DIM": P_DIM,
        "NUM_QUERIES": NUM_QUERIES,
        "NUM_STEPS": NUM_STEPS,
        "NUM_BATCHES": NUM_BATCHES,
        "PLOT_EVERY": PLOT_EVERY,
        "SAVE_EVERY": SAVE_EVERY,
        "LEARNING_RATE": LEARNING_RATE,
        "GRADIENT_ACCUMULATION": GRADIENT_ACCUMULATION,
        "QUERY_LAST": QUERY_LAST,
        "MASK": MASK,
        "ORTHOGONAL": ORTHOGONAL,
    }
    # Save the config as one "KEY = value" line per entry.
    with open(f"{OUTPUT_FOLDER}/config.txt", "w") as f:
        f.writelines(f"{key} = {value}\n" for key, value in config.items())

def main(num_pairs):
    """Train one model and report when the loss and key coefficients emerge.

    After every SGD step the weights are replaced by their reconstruction
    from the basis coefficients (`extract_basis` followed by
    `ablate_weights`). Training stops as soon as all four events below have
    been observed, or after NUM_STEPS steps.

    Events (the first step at which the condition holds):
        loss:  last micro-batch loss < 0.5
        alpha: coefficient 2 > 0.1
        beta:  coefficient 4 > 0.1
        gamma: coefficient 17 > 0.5

    Args:
        num_pairs: Number of (x, y) pairs per training sequence.

    Returns:
        Tuple ``(loss_step, alpha_step, beta_step, gamma_step)``. An entry is
        ``None`` when its event was not observed within NUM_STEPS steps, so
        callers can always unpack four values.
    """
    batches = iter(get_loader(num_pairs))

    model = Transformer(X_DIM + P_DIM, X_DIM, num_pairs, NUM_QUERIES).to(device)

    # The learning rate is divided by the number of accumulated micro-batches
    # because their gradients are summed, not averaged.
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=LEARNING_RATE / GRADIENT_ACCUMULATION,
        weight_decay=0,
        momentum=0,
    )

    loss_emergence = None
    alpha_emergence = None  # coefficient 2
    beta_emergence = None  # coefficient 4
    gamma_emergence = None  # coefficient 17

    for step in range(NUM_STEPS):
        for _ in range(GRADIENT_ACCUMULATION):
            x, y = next(batches)

            y_pred = model(x)[:, -NUM_QUERIES:]
            batch_loss = F.mse_loss(y_pred, y)
            batch_loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # Replace the weights by their reconstruction from the basis
        # coefficients. Done without autograd tracking; the parameters are
        # overwritten in place.
        with torch.no_grad():
            p = extract_basis(model.W1, model.W2, model.WO)
            W1, W2, W3 = ablate_weights(p)
            model.W1.copy_(W1)
            model.W2.copy_(W2)
            model.WO.copy_(W3)

        loss_value = batch_loss.item()

        if loss_value < 0.5 and loss_emergence is None:
            loss_emergence = step
        if p[2] > 0.1 and alpha_emergence is None:
            alpha_emergence = step
        if p[4] > 0.1 and beta_emergence is None:
            beta_emergence = step
        if p[17] > 0.5 and gamma_emergence is None:
            gamma_emergence = step

        emergence = (loss_emergence, alpha_emergence, beta_emergence, gamma_emergence)
        if all(event is not None for event in emergence):
            return emergence

        if step % PLOT_EVERY == 0:
            print(f"Step {step}, Loss: {loss_value}, alpha: {p[2]}, beta: {p[4]}, gamma: {p[17]}")

    # Step budget exhausted: report what was observed instead of falling
    # through and returning a bare None.
    return loss_emergence, alpha_emergence, beta_emergence, gamma_emergence

# Sweep results, appended to in lock-step: the number of pairs of each run
# and the emergence step of the loss and of the alpha / beta / gamma
# coefficients.
N, loss, alpha, beta, gamma = [], [], [], [], []

def plot_dependency():
    """Plot emergence step against N and save it as emergence_by_N.pdf.

    Reads the module-level result lists. The left panel shows the alpha, beta
    and gamma emergence steps on linear axes; the right panel shows the loss
    emergence step on log-log axes together with a 0.07 * N^2 reference curve.
    """
    fig, (ax_coef, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))
    fig.subplots_adjust(wspace=0.2)

    coefficient_series = (
        (alpha, "$\\alpha_3$", colors[0]),
        (beta, "$\\beta_2$", colors[1]),
        (gamma, "$\\gamma_3$", colors[2]),
    )
    for steps, label, color in coefficient_series:
        ax_coef.plot(N, steps, label=label, color=color, marker="o", markersize=3)
    ax_coef.set_xlabel("$N$")
    ax_coef.set_ylabel("Emergence step")
    ax_coef.legend(loc="upper left")

    ax_loss.plot(N, loss, label="$\\mathcal{L}$", color=colors[3], marker="o", markersize=3)
    ax_loss.set_xscale("log")
    ax_loss.set_yscale("log")
    ax_loss.set_xlabel("$N$ (log)")
    ax_loss.set_ylabel("Emergence step (log)")
    # Quadratic reference curve for comparison with the measured loss curve.
    reference = np.array(N)**2 * 0.07
    ax_loss.plot(N, reference, label="$N^2$", color="gray", linestyle="--")
    ax_loss.legend(loc="upper left")

    fig.savefig(f"{OUTPUT_FOLDER}/emergence_by_N.pdf", dpi=400, bbox_inches="tight")
    plt.close(fig)

if __name__ == "__main__":
    # The sweep only runs when the file is executed as a script. `config` is
    # built under the same guard earlier in the file, so at module level this
    # code raised NameError (and started training) on import.
    #
    # NOTE(review): with ORTHOGONAL = True the generated sequences need
    # 2 * num_pairs mutually orthogonal X_DIM-dimensional vectors, so entries
    # above X_DIM // 2 look infeasible with the current constants - confirm.
    sweep = [2, 3, 4, 5, 6, 8, 12, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 288, 320, 384, 448, 512]

    # Append mode keeps the results of earlier runs; the context manager
    # closes the file even if a run fails part-way through the sweep.
    with open(f"{OUTPUT_FOLDER}/emergence_by_N.csv", "a") as fout:
        fout.write("\n\nconfig = {" + str(config) + "}\n")
        fout.write("num_pairs,loss,alpha,beta,gamma\n")

        for num_pairs in sweep:
            print(f"Num pairs: {num_pairs}")
            l, a, b, g = main(num_pairs)
            loss.append(l)
            alpha.append(a)
            beta.append(b)
            gamma.append(g)
            N.append(num_pairs)
            # Flush after every run and redraw the plot so partial results
            # survive an interrupted sweep.
            fout.write(f"{num_pairs},{l},{a},{b},{g}\n")
            fout.flush()
            plot_dependency()