import os
from random import randint
from matplotlib.colors import TwoSlopeNorm
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from model import GPT, GPTConfig

SEQUENCES_PER_BATCH = 512   # sequences consumed per optimizer step (across all micro-batches)
VOCAB_SIZE = 32             # number of token ids
BLOCK_SIZE = 32             # number of position ids; positions wrap modulo this
EMBED_DIM = 2048
NUM_PAIRS = 8               # key/value pairs per sequence (sequence length = 2 * NUM_PAIRS + 1)
NUM_EPOCHS = 50001          # optimizer steps run by the training loop
NUM_BATCHES = 500           # dataset size is SEQUENCES_PER_BATCH * NUM_BATCHES sequences
PLOT_EVERY = 20             # cadence (in steps) for plots, CSV logs and checkpoints
SAVE_EVERY = 15             # NOTE(review): not referenced by the code in this script

GRAD_ACC_STEPS = 1
# Integer division below would otherwise silently shrink the effective batch.
assert SEQUENCES_PER_BATCH % GRAD_ACC_STEPS == 0, "SEQUENCES_PER_BATCH must be divisible by GRAD_ACC_STEPS"
REAL_BATCH_SIZE = SEQUENCES_PER_BATCH // GRAD_ACC_STEPS  # sequences per micro-batch

OUTPUT_FOLDER = "experiments/gpt/007"

img_path = f"{OUTPUT_FOLDER}/imgs"
weights_path = f"{OUTPUT_FOLDER}/weights"
losses_path = f"{OUTPUT_FOLDER}/losses"

# Create every output directory up front so later savefig/torch.save/open calls succeed.
for _out_dir in (img_path, weights_path, losses_path):
    os.makedirs(_out_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def generate_data(vocab_size, block_size, num_pairs, batch_size, target_device=None):
    """
    Generate a batch of synthetic key/value recall sequences.

    Each sequence has length ``2 * num_pairs + 1``. Slots ``2*i`` and ``2*i + 1``
    (``i < num_pairs``) hold a key token and its value token. The final slot is a
    query that repeats one randomly chosen key, and the target is that key's value.
    Apart from the query, all tokens in a sequence are distinct ids from
    ``range(vocab_size)``. Position ids are consecutive, start at a uniformly
    random offset in ``[0, block_size)`` and wrap modulo ``block_size``.

    Args:
        vocab_size: number of token ids; must be at least ``2 * num_pairs + 1``.
        block_size: number of position ids; must be at least ``2 * num_pairs + 1``.
        num_pairs: key/value pairs per sequence (>= 1).
        batch_size: number of sequences to generate.
        target_device: device for the returned tensors. ``None`` (the default)
            uses the module-level ``device``.

    Returns:
        ``(x, pos, y)``. ``x`` and ``pos`` are int64 tensors of shape
        ``(batch_size, 2 * num_pairs + 1)``. ``y`` is an int64 tensor of shape
        ``(batch_size,)``.

    Raises:
        AssertionError: if ``block_size`` or ``vocab_size`` is too small.
    """
    seq_len = 2 * num_pairs + 1
    assert block_size >= seq_len, "Block size must be at least 2 * num_pairs + 1"
    # Without this check randperm(vocab_size)[:seq_len] silently yields a shorter,
    # malformed sequence (or an IndexError when the chosen key lies past its end).
    assert vocab_size >= seq_len, "Vocab size must be at least 2 * num_pairs + 1"

    x_list, pos_list, y_list = [], [], []
    for _ in range(batch_size):
        # Distinct random tokens; the last slot is overwritten by the query below.
        tokens = torch.randperm(vocab_size)[:seq_len]

        # Pick which pair is queried.
        q_ix = torch.randint(0, num_pairs, (1,)).item()
        tokens[-1] = tokens[q_ix * 2]
        target = tokens[q_ix * 2 + 1]

        # torch.randint's upper bound is exclusive, so every offset in [0, block_size)
        # is equally likely. random.randint(0, block_size) would include block_size,
        # which wraps to 0 and doubles the probability of that offset.
        offset = torch.randint(0, block_size, (1,)).item()
        positions = (torch.arange(seq_len) + offset) % block_size

        x_list.append(tokens)
        pos_list.append(positions)
        y_list.append(target)

    dev = device if target_device is None else target_device
    x = torch.stack(x_list, dim=0).to(dev)
    pos = torch.stack(pos_list, dim=0).to(dev)
    y = torch.stack(y_list, dim=0).to(dev)
    return x, pos, y



def plot_model(W1, W2, W3, filename=None, block=16):
    """
    Plot three matrices side by side with a shared, zero-centred colour scale.

    All panels share a single ``TwoSlopeNorm`` that is symmetric around 0, drawn
    with the diverging ``bwr`` colormap. Cells are square (``aspect='equal'``) and
    ticks are removed. Dashed light-grey guide lines are drawn every ``block``
    cells, and one colorbar is attached for all panels.

    Args:
        W1, W2, W3: 2-D arrays (must support ``.shape``, ``.min()``, ``.max()``
            and be accepted by ``imshow``).
        filename: if given, the figure is saved there at 200 dpi and closed;
            otherwise it is displayed with ``plt.show()``.
        block: spacing, in cells, of the dashed guide lines. Defaults to 16.
            Pass ``None`` or 0 to draw no guide lines.
    """
    mats = [(W1, "$W^{(1)}$"), (W2, "$W^{(2)}$"), (W3, "$W^{(3)}$")]

    # Symmetric range so that 0 sits at the centre of the diverging colormap.
    # The epsilon keeps the norm valid when every entry is 0.
    scale = max(max(abs(W.min()), abs(W.max())) for W, _ in mats) + 1e-4
    norm = TwoSlopeNorm(vmin=-scale, vcenter=0, vmax=scale)

    # Panel widths proportional to cols/rows so that cells stay square.
    # The trailing 0 is a hidden, zero-width fourth column.
    width_ratios = [W.shape[1] / W.shape[0] for W, _ in mats] + [0]

    height = 4
    fig, axs = plt.subplots(
        1, 4, figsize=(height * sum(width_ratios) + 2, height),
        gridspec_kw={'width_ratios': width_ratios},
        constrained_layout=True,
    )
    axs[3].axis('off')  # hide the spacer panel

    for ax, (W, title) in zip(axs[:3], mats):
        im = ax.imshow(W, cmap="bwr", interpolation="nearest", norm=norm, aspect='equal', rasterized=True)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        if block:
            # Dashed light-grey separators between groups of `block` rows/columns.
            rows, cols = W.shape
            for col in range(block, cols, block):
                ax.axvline(col - 0.5, color='#ddd', linewidth=1, linestyle='--')
            for row in range(block, rows, block):
                ax.axhline(row - 0.5, color='#ddd', linewidth=1, linestyle='--')

    # All images share `norm`, so the last one can drive a single colorbar.
    fig.colorbar(im, ax=axs, orientation='vertical', fraction=0.025, pad=0.02)

    if filename:
        fig.savefig(filename, dpi=200, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

# Dataset cache. The file name encodes VOCAB_SIZE, BLOCK_SIZE, NUM_PAIRS and the
# sample count, so changing any of them generates (and caches) a new dataset.
data_path = f"data/{VOCAB_SIZE}_{BLOCK_SIZE}_{NUM_PAIRS}_{SEQUENCES_PER_BATCH * NUM_BATCHES}.pt"
if os.path.exists(data_path):
    print(f"Loading data from {data_path}")
    # generate_data returns tensors on `device`, so a cache written on a GPU machine
    # stores CUDA tensors. map_location makes it loadable on the current device.
    data = torch.load(data_path, map_location=device)
else:
    print("Generating data")
    data = generate_data(VOCAB_SIZE, BLOCK_SIZE, NUM_PAIRS, SEQUENCES_PER_BATCH * NUM_BATCHES)
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(data_path), exist_ok=True)
    torch.save(data, data_path)

x, pos, y = data
x = x.to(device)
pos = pos.to(device)
y = y.to(device)

# Cut the dataset into micro-batches of REAL_BATCH_SIZE. Each optimizer step consumes
# GRAD_ACC_STEPS of them, so keep every complete micro-batch (NUM_BATCHES * GRAD_ACC_STEPS
# for the generated size) rather than only the first NUM_BATCHES.
num_micro_batches = len(x) // REAL_BATCH_SIZE
batches = []
for i in range(num_micro_batches):
    sl = slice(i * REAL_BATCH_SIZE, (i + 1) * REAL_BATCH_SIZE)
    batches.append((x[sl], pos[sl], y[sl]))
print("Done")

print("Sample data:")
for i in range(min(10, len(x))):
    print("x:", x[i])
    print("pos:", pos[i])
    print("y:", y[i])

# Two-layer, single-head GPT trained on the key/value recall sequences.
config = GPTConfig(
    block_size = BLOCK_SIZE,
    vocab_size = VOCAB_SIZE, # size of the synthetic token vocabulary used by generate_data
    n_layer = 2,
    n_head = 1,
    n_embd = EMBED_DIM,
    dropout = 0.0,
    bias = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
)
model = GPT(config)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)


def _circuit_matrices(model):
    """
    Return the (W1, W2, W3) matrices that are plotted and traced during training.

    The rows are a growing stack of vectors. It starts as [wte rows; wpe rows].
    After each attention layer, that layer's c_proj applied to the value vectors of
    the current stack is appended. No LayerNorm is applied.

    W1: q @ k.T / sqrt(d) of layer 0 over [wte; wpe].
    W2: q @ k.T / sqrt(d) of layer 1 over [wte; wpe; layer-0 outputs].
    W3: the stack [wte; wpe; layer-0 outputs; layer-1 outputs] @ lm_head.weight.T.

    Computed without autograd tracking; returns numpy arrays.
    """
    mats = []
    with torch.no_grad():
        stack = torch.cat([model.transformer.wte.weight, model.transformer.wpe.weight], dim=0)
        for layer in (model.transformer.h[0], model.transformer.h[1]):
            attn = layer.attn
            q, k, v = attn.c_attn(stack).split(attn.n_embd, dim=1)
            mats.append(q @ k.T / (k.shape[-1] ** 0.5))
            stack = torch.cat([stack, attn.c_proj(v)], dim=0)
        mats.append(stack @ model.lm_head.weight.T)
    return tuple(m.cpu().numpy() for m in mats)


def _trace_metrics(W1, W2, W3):
    """
    Return the scalar diagnostics (a, b, c, d) logged every epoch.

    a: sum of W1[BLOCK_SIZE:, BLOCK_SIZE:][j-1, j], i.e. its super-diagonal plus the
       wrap-around corner (trace after rolling the rows down by one).
    b: trace of W2[:BLOCK_SIZE, 2*BLOCK_SIZE : 2*BLOCK_SIZE + VOCAB_SIZE].
    c: trace of W3[4*BLOCK_SIZE : 4*BLOCK_SIZE + VOCAB_SIZE].
    d: trace of W1[BLOCK_SIZE:, BLOCK_SIZE:].

    NOTE(review): the offsets are in units of BLOCK_SIZE although the first region of
    each stack holds the wte rows. They line up with the wte/wpe regions only when
    wte has BLOCK_SIZE rows (presumably VOCAB_SIZE == BLOCK_SIZE). Confirm before
    changing either constant.
    """
    sub = W1[BLOCK_SIZE:, BLOCK_SIZE:]
    trace_a = np.trace(np.roll(sub, 1, axis=0))
    trace_b = np.trace(W2[:BLOCK_SIZE, BLOCK_SIZE * 2:BLOCK_SIZE * 2 + VOCAB_SIZE])
    trace_c = np.trace(W3[BLOCK_SIZE * 4:BLOCK_SIZE * 4 + VOCAB_SIZE])
    trace_d = np.trace(sub)
    return trace_a, trace_b, trace_c, trace_d


def _write_lines(path, lines):
    """Overwrite `path` with `lines` joined by newlines (no trailing newline)."""
    with open(path, "w") as fh:
        fh.write("\n".join(lines))


losses = []
a, b, c, d = [], [], [], []

for epoch in range(NUM_EPOCHS):
    # Each "epoch" is one optimizer step over GRAD_ACC_STEPS micro-batches.
    avg_loss = 0.0
    for i in range(GRAD_ACC_STEPS):
        xb, posb, yb = batches[(epoch * GRAD_ACC_STEPS + i) % len(batches)]

        logits, _ = model(xb, posb)
        # Only the final position (the query token) is scored.
        loss = F.cross_entropy(logits[:, -1], yb)
        # Scale before backward so the accumulated gradient is the mean, not the sum,
        # over micro-batches.
        (loss / GRAD_ACC_STEPS).backward()
        avg_loss += loss.item() / GRAD_ACC_STEPS

    optimizer.step()
    optimizer.zero_grad()
    print(f"Epoch {epoch}, Loss: {avg_loss}")

    losses.append(avg_loss)

    W1, W2, W3 = _circuit_matrices(model)
    for series, value in zip((a, b, c, d), _trace_metrics(W1, W2, W3)):
        series.append(value)

    if epoch % PLOT_EVERY == 0:
        fig = plt.figure(figsize=(12,6))
        plt.plot(losses)
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training Loss Over Time")
        plt.grid(True)
        plt.savefig(f"{img_path}/loss.png")
        plt.close(fig)

        plot_model(W1, W2, W3, f"{img_path}/epoch_{epoch:04d}.pdf")
        _write_lines(f"{losses_path}/loss.csv", [str(val) for val in losses])
        _write_lines(
            f"{losses_path}/traces.csv",
            [f"{va},{vb},{vc},{vd}" for va, vb, vc, vd in zip(a, b, c, d)],
        )

        # NOTE(review): SAVE_EVERY is defined but not used here; the checkpoint is
        # overwritten on the PLOT_EVERY cadence.
        torch.save(model.state_dict(), f"{weights_path}/model.pt")