import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2TokenizerFast
from lion_pytorch import Lion
from zen import ZenGrad, ZenGrad_M
import math, pandas as pd, os
from tqdm import tqdm

# ========== Config ==========
if torch.cuda.is_available():
    device = 'cuda'   # a CUDA device is visible to PyTorch
else:
    device = 'cpu'

# Model / data shape
vocab_size = 50257    # expected to match the vocabulary of the 'gpt2' tokenizer loaded below
block_size = 128      # tokens per training window (context length)

# Optimization schedule
batch_size = 32
total_steps = 225_000  # total optimizer steps
max_grad_norm = 1.0    # gradient-norm clipping threshold passed to clip_grad_norm_
eval_every = 15_000    # steps between validation passes

# ========== Load GPT2 Tokenizer ==========
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# Use the end-of-text token as the padding token (the stock GPT-2 tokenizer defines none).
tokenizer.pad_token = tokenizer.eos_token

# ========== Load Dataset from Local ==========
# Placeholder locations -- point these at a local copy of WikiText-2 (raw) before running.
train_path = "/path/to/wikitext-2-raw/train.txt"
val_path   = "/path/to/wikitext-2-raw/validation.txt"


def _read_utf8(path):
    """Return the entire contents of the UTF-8 text file at *path*."""
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()


train_text = _read_utf8(train_path)
val_text = _read_utf8(val_path)

# Each split is tokenized as one continuous stream of GPT-2 token ids.
train_tokens, val_tokens = (
    tokenizer(text)['input_ids'] for text in (train_text, val_text)
)

# ========== Dataset ==========
class GPT2Dataset(Dataset):
    """Next-token-prediction windows over one flat stream of token ids.

    Item ``i`` is ``(x, y)`` with ``x = tokens[i : i + block_size]`` and
    ``y = tokens[i + 1 : i + block_size + 1]`` (``y`` is ``x`` shifted by one token),
    both returned as ``torch.long`` tensors. A window starts at every offset, so
    consecutive items overlap by ``block_size - 1`` tokens.

    Args:
        tokens: sequence of integer token ids.
        block_size: number of input tokens per window.
    """

    def __init__(self, tokens, block_size):
        self.tokens = tokens
        self.block_size = block_size

    def __len__(self):
        # A window needs block_size + 1 tokens, so a stream of N tokens has exactly
        # N - block_size valid start offsets (0 .. N - block_size - 1).
        return max(0, len(self.tokens) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` window starting at ``idx`` (negative indices allowed).

        Raises:
            IndexError: if ``idx`` does not address a full window. This also lets
                plain ``for`` iteration over the dataset terminate.
        """
        n = len(self)
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        chunk = self.tokens[pos:pos + self.block_size + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

# Pinned host memory only helps host->GPU copies; with no CUDA device it is useless,
# so request it only when training on CUDA.
_pin_memory = device == 'cuda'
train_dataset = GPT2Dataset(train_tokens, block_size)
val_dataset   = GPT2Dataset(val_tokens, block_size)
train_loader  = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=_pin_memory)
# NOTE: GPT2Dataset starts a window at every token offset (stride 1), so one validation
# pass visits roughly one window per validation token.
val_loader    = DataLoader(val_dataset, batch_size=batch_size, pin_memory=_pin_memory)

# ========== Model Components ==========
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with separate Q/K/V and output projections.

    Args:
        n_embd: model width; must be divisible by ``n_head``.
        n_head: number of attention heads.
        max_seq_len: longest sequence the causal mask supports. When omitted, the
            module-level ``block_size`` config value is used (the previous behaviour);
            pass it explicitly to decouple this layer from that global.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, n_embd, n_head, max_seq_len=None):
        super().__init__()
        if n_embd % n_head != 0:
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")
        if max_seq_len is None:
            max_seq_len = block_size  # module-level config value
        self.n_head = n_head
        self.key   = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj  = nn.Linear(n_embd, n_embd)
        self.attn_drop = nn.Dropout(0.1)
        self.resid_drop = nn.Dropout(0.1)
        # (1, 1, L, L) lower-triangular mask: query position i may attend to keys 0..i only.
        self.register_buffer("mask", torch.tril(torch.ones(max_seq_len, max_seq_len))
                             .unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C); returns (B, T, C).

        Raises:
            ValueError: if T is longer than the causal mask built in ``__init__``.
        """
        B, T, C = x.size()
        max_len = self.mask.size(-1)
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds attention mask length {max_len}")
        head_dim = C // self.n_head
        # (B, T, C) -> (B, n_head, T, head_dim)
        k = self.key(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        # Scaled dot-product scores; future positions are masked to -inf before softmax.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = torch.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        # Merge heads back: (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        return y

class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention, then an MLP, each residual.

    Computes ``h = x + attn(ln1(x))`` followed by ``h + mlp(ln2(h))``. The MLP expands
    to ``4 * n_embd``, applies ``LogLU``, projects back, and applies dropout (p=0.1).

    NOTE(review): ``LogLU`` is neither defined nor imported anywhere in the visible
    source; confirm it is provided elsewhere, otherwise constructing a Block raises
    NameError.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        hidden_dim = 4 * n_embd
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head)
        self.ln2 = nn.LayerNorm(n_embd)
        mlp_layers = [
            nn.Linear(n_embd, hidden_dim),
            LogLU(),
            nn.Linear(hidden_dim, n_embd),
            nn.Dropout(0.1),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        h = x + self.attn(self.ln1(x))
        return h + self.mlp(self.ln2(h))

class SmallGPT(nn.Module):
    """Small GPT-style decoder-only language model.

    Token embeddings plus learned absolute position embeddings feed ``n_layer``
    ``Block`` modules, a final LayerNorm, and a linear head producing next-token
    logits over ``vocab_size`` classes.

    Args:
        vocab_size: number of token ids / output classes.
        n_embd: embedding width.
        n_layer: number of transformer blocks.
        n_head: attention heads per block.
        block_size: maximum sequence length accepted by ``forward``.
    """

    def __init__(self, vocab_size, n_embd=256, n_layer=4, n_head=4, block_size=128):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, n_embd)
        self.pos_embed   = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.blocks      = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f        = nn.LayerNorm(n_embd)
        self.head        = nn.Linear(n_embd, vocab_size)
        # Block builds CausalSelfAttention without a length argument, so each causal mask
        # is sized from the module-level ``block_size`` config rather than this
        # constructor's ``block_size``. Rebuild every mask so the model honours its own
        # argument (identical values when the two agree).
        causal = torch.tril(torch.ones(block_size, block_size)).unsqueeze(0).unsqueeze(0)
        for module in self.blocks.modules():
            if isinstance(module, CausalSelfAttention):
                module.mask = causal.clone()

    def forward(self, idx):
        """Return logits of shape (B, T, vocab_size) for token ids ``idx`` of shape (B, T).

        Raises:
            ValueError: if T exceeds the ``block_size`` the model was built with.
        """
        _, T = idx.size()
        max_len = self.pos_embed.size(1)
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds block_size {max_len}")
        x = self.token_embed(idx) + self.pos_embed[:, :T, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        return self.head(x)

# ========== Evaluation ==========
def evaluate(loader, *, net=None, loss_fn=None):
    """Compute the token-weighted mean loss and perplexity of a model over ``loader``.

    The model is switched to eval mode for the pass, and its previous train/eval mode
    is restored afterwards. Calling this mid-training therefore does not leave dropout
    disabled.

    Args:
        loader: iterable of ``(x, y)`` batches; ``x`` is fed to the model and the
            flattened ``y`` is the target for the flattened logits.
        net: model to evaluate; defaults to the module-level ``model``.
        loss_fn: loss applied to ``(logits.reshape(-1, C), y.reshape(-1))``; defaults
            to the module-level ``criterion``. Assumed to return the mean over targets
            (the default ``nn.CrossEntropyLoss`` reduction).

    Returns:
        ``(avg_loss, perplexity)``. ``avg_loss`` is averaged over every target token,
        so a smaller final batch is not over-weighted. ``perplexity`` is
        ``exp(avg_loss)``, or ``inf`` if that overflows.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn
    # Run on whatever device the model's weights live on.
    first_param = next(net.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device(device)
    # Mixed precision only on CUDA; torch.cuda.amp.autocast() disables itself elsewhere too.
    use_amp = dev.type == "cuda"

    was_training = net.training
    total_loss = 0.0
    total_tokens = 0
    net.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(dev, non_blocking=True)
                y = y.to(dev, non_blocking=True)
                with torch.autocast(device_type=dev.type, enabled=use_amp):
                    logits = net(x)
                    loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
                n_tokens = y.numel()
                total_loss += loss.item() * n_tokens
                total_tokens += n_tokens
    finally:
        net.train(was_training)

    if total_tokens == 0:
        raise ValueError("evaluate() received a loader that yielded no batches")
    avg_loss = total_loss / total_tokens
    try:
        ppl = math.exp(avg_loss)
    except OverflowError:
        ppl = float("inf")
    return avg_loss, ppl

# ========== Training ==========
model = SmallGPT(vocab_size).to(device)
print(f"🧠 Total parameters: {sum(param.numel() for param in model.parameters()):,}")
# Wrap the model with torch.compile only after its parameters have been counted.
model = torch.compile(model)

criterion = nn.CrossEntropyLoss()
optimizer = ZenGrad(model.parameters())
scaler = torch.cuda.amp.GradScaler()  # loss scaling for the autocast training step

# One list per logged quantity; each evaluation appends one entry to every list.
metrics = {
    key: []
    for key in ("step", "train_loss", "val_loss", "train_ppl", "val_ppl")
}

step = 0
running_loss = 0.0

train_iter = iter(train_loader)
# Start explicitly in training mode (dropout active).
model.train()
# Mixed precision only on CUDA; torch.cuda.amp.autocast() also disables itself elsewhere.
use_amp = device == 'cuda'

while step < total_steps:
    # Cycle through the training set indefinitely: restart the iterator at epoch end.
    try:
        x, y = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        x, y = next(train_iter)

    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)

    with torch.autocast(device_type=device, enabled=use_amp):
        logits = model(x)
        B, T, C = logits.shape
        loss = criterion(logits.view(B * T, C), y.view(B * T))

    optimizer.zero_grad()
    scaler.scale(loss).backward()
    # Unscale before clipping so the norm threshold applies to the true gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()

    running_loss += loss.item()
    step += 1

    if step % eval_every == 0:
        # running_loss is reset after every log, so this is the mean over the last window.
        avg_train_loss = running_loss / eval_every
        try:
            train_ppl = math.exp(avg_train_loss)
        except OverflowError:
            # A diverged loss must not crash the run before metrics are saved.
            train_ppl = float("inf")
        val_loss, val_ppl = evaluate(val_loader)
        # evaluate() puts the model in eval mode; resume training with dropout active.
        model.train()

        print(f"[Step {step}] Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")
        print(f"               Train PPL: {train_ppl:.2f}, Val PPL: {val_ppl:.2f}")

        metrics["step"].append(step)
        metrics["train_loss"].append(avg_train_loss)
        metrics["val_loss"].append(val_loss)
        metrics["train_ppl"].append(train_ppl)
        metrics["val_ppl"].append(val_ppl)

        running_loss = 0.0  # reset after logging

# ========== Save ==========
output_path = "/workspace/smallgpt_LogLU_ZenGrad_metrics.xlsx"
metrics_df = pd.DataFrame(metrics)
# Create the destination directory if needed so a finished run is not lost at the end.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
try:
    metrics_df.to_excel(output_path, index=False)
except ImportError:
    # DataFrame.to_excel needs an optional Excel writer engine (e.g. openpyxl);
    # fall back to CSV rather than discarding the collected metrics.
    output_path = os.path.splitext(output_path)[0] + ".csv"
    metrics_df.to_csv(output_path, index=False)
print(f"✅ Metrics saved to {output_path}")
