import os
import json
import math
import string
import random
from typing import List, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import wandb

from rollthedice.triangle_discovery.utils import (
  TransformerPolicy, MODEL_PATH, HIDDEN_DIM, DATA_ROOT, HASH_STR_LEN, TriangleTokenizer, dataset_dir, load_json, T_TRIANGLES, build_ar_sequences, MAX_LEN, parse_triangle_sequence, is_valid_triangle, save_model, validate_triangle_generation, init_wandb
)
from rollthedice.triangle_discovery.data_utils import generate_and_save_dataset

# ----------------------
# Hyperparameters
# ----------------------
# Module-level training defaults; main() reads these when wiring up the run.
BATCH_SIZE     = 64      # sequences per batch for the training DataLoader
LEARNING_RATE  = 3e-4    # learning rate handed to train_policy (AdamW)
NUM_EPOCHS     = 25      # number of full passes over the training DataLoader
DEBUG = False            # when True, main() leaves wandb_run as None (no wandb init)

# ----------------------
# Training loop
# ----------------------
def train_policy(
    policy: nn.Module,
    dataloader: DataLoader,
    num_epochs: int,
    learning_rate: float,
    device: torch.device,
    pad_id: int,
    tokenizer=None,
    edges_data=None,
    wandb_run=None,
) -> None:
    """Train ``policy`` in place with masked next-token cross-entropy.

    Each batch yielded by ``dataloader`` must unpack into
    ``(input_ids, labels, attn_mask, loss_mask)``. The loss is the per-token
    cross-entropy (ignoring ``pad_id`` labels) averaged over positions where
    ``loss_mask`` is non-zero. The mean predictive entropy over the same
    positions is tracked as a diagnostic only.

    Args:
        policy: Model called as ``policy(input_ids, attn_mask)``; its output
            is treated as logits whose last dimension is the vocabulary.
        dataloader: Iterable of 4-tuples of tensors as described above.
        num_epochs: Number of passes over ``dataloader``.
        learning_rate: AdamW learning rate (constant for the whole run).
        device: Device the model and every batch are moved to.
        pad_id: Label id excluded from the cross-entropy.
        tokenizer: If given together with ``edges_data``, validation runs
            after every epoch via ``validate_triangle_generation``.
        edges_data: Graph edge data forwarded to validation.
        wandb_run: When not ``None``, epoch metrics are logged with
            ``wandb.log`` and the run is forwarded to validation.
    """
    calc_loss = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='none')
    # Named `optimizer` so it does not shadow the module-level `optim` import.
    optimizer = torch.optim.AdamW(
        policy.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=0.01
    )
    policy.to(device)

    # Report the batch size of the loader actually in use, not a module global.
    batch_size = getattr(dataloader, "batch_size", None)
    num_batches = len(dataloader)
    run_validation = tokenizer is not None and edges_data is not None

    print(f"\nStarting training for {num_epochs} epochs with bs {batch_size}...")
    for epoch in range(1, num_epochs + 1):
        # Re-enter train mode every epoch in case validation switched modes.
        policy.train()
        total, n = 0.0, 0
        total_entropy = 0.0
        print(f"Epoch {epoch:03d}/{num_epochs}")
        for batch in dataloader:
            input_ids, labels, attn_mask, loss_mask = [
                t.to(device, non_blocking=True) for t in batch
            ]
            optimizer.zero_grad(set_to_none=True)
            logits = policy(input_ids, attn_mask)  # expected [B, T, V]

            # Per-token losses, reshaped back to the label layout so the
            # loss mask can be applied position by position.
            per_token_losses = calc_loss(
                logits.view(-1, logits.size(-1)), labels.view(-1)
            ).view(labels.shape)

            mask = loss_mask.float()
            # Clamp so a batch with no supervised positions yields 0, not NaN.
            denom = mask.sum().clamp(min=1)
            loss = (per_token_losses * mask).sum() / denom

            # Entropy of the predictive distribution (diagnostic, no grad).
            with torch.no_grad():
                log_probs = torch.log_softmax(logits, dim=-1)
                entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # [B, T]
                total_entropy += ((entropy * mask).sum() / denom).item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
            optimizer.step()

            total += loss.item()
            n += 1
            print(f" Batch {n}/{num_batches} | Loss: {total / n:.4f}", end='\r')

        avg_loss = total / max(1, n)
        avg_entropy = total_entropy / max(1, n)
        print(f"Epoch {epoch:03d}/{num_epochs} | Loss: {avg_loss:.4f} | Entropy: {avg_entropy:.4f}")

        # Log training metrics to wandb
        if wandb_run is not None:
            wandb.log({
                "train/epoch": epoch,
                "train/loss": avg_loss,
                "train/entropy": avg_entropy,
                "train/learning_rate": learning_rate,
            })

        # Periodic validation
        if run_validation:
            print(f"\n--- Validation at Epoch {epoch} ---")
            validate_triangle_generation(
                policy, tokenizer, edges_data, device,
                num_samples=10, wandb_run=wandb_run, epoch=epoch,
            )
            print("--- End Validation ---\n")

    print("Training complete.")



# ----------------------
# Main
# ----------------------
def main():
    """Pretrain the triangle-discovery policy end to end.

    Steps: pick a device, generate the dataset if ``DATA_ROOT`` does not
    exist, load ``train.json`` / ``vocab.json``, build the autoregressive
    tensors, train a ``TransformerPolicy``, save it to ``MODEL_PATH`` and,
    unless ``DEBUG`` is set, log metrics and the model artifact to wandb.

    Raises:
        FileNotFoundError: If ``train.json`` or ``vocab.json`` is missing
            from the dataset directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # One graph count shared by dataset generation and the edge-file lookup
    # so the two cannot drift apart.
    num_graphs = 10

    # ensure dataset exists
    if not os.path.exists(DATA_ROOT):
        os.makedirs(DATA_ROOT, exist_ok=True)
        print(f"Saving data to root directory: {DATA_ROOT}")
        generate_and_save_dataset(num_graphs=num_graphs, fixed_hash_per_graph=True)

    ddir = dataset_dir(DATA_ROOT, HASH_STR_LEN, T_TRIANGLES)
    train_path = os.path.join(ddir, "train.json")
    vocab_path = os.path.join(ddir, "vocab.json")

    if not os.path.exists(train_path):
        raise FileNotFoundError(
            f"Could not find {train_path}. Make sure you ran the dataset export to {ddir}."
        )
    if not os.path.exists(vocab_path):
        raise FileNotFoundError(
            f"Could not find {vocab_path}. The dataset must include the saved vocab."
        )

    train_items = load_json(train_path)
    # The saved vocab is unpacked as keyword arguments, so it must be a
    # JSON object (mapping) matching TriangleTokenizer's constructor.
    saved_vocab = load_json(vocab_path)
    tokenizer = TriangleTokenizer(**saved_vocab)

    print(f"Loaded {len(train_items)} items; tokenizer vocab_size={tokenizer.vocab_size}")

    # autoregressive dataset: inputs, labels, attention mask, loss mask
    inputs, targets, attn_masks, loss_masks = build_ar_sequences(
        train_items, tokenizer, max_len=MAX_LEN
    )
    print(f"AR dataset: {inputs.shape[0]} sequences  |  length={inputs.shape[1]}  |  vocab={tokenizer.vocab_size}")

    # print sample sequences
    print(f"Sample input sequence: {inputs[0]}")
    print(f"Sample output sequence: {targets[0]}")
    print(f"Sample attention mask: {attn_masks[0]}")
    print(f"Sample loss mask: {loss_masks[0]}")

    dataset = TensorDataset(inputs, targets, attn_masks, loss_masks)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        # Pinned host memory only helps host->GPU copies; on a CPU-only
        # machine it has no benefit and PyTorch emits a warning.
        pin_memory=(device.type == "cuda"),
    )

    # --- transformer policy ---
    # Defined once and reused for the wandb config so the logged
    # hyperparameters always match the model that was actually built.
    model_config = {
        "d_model": 256,
        "n_layer": 4,
        "n_head": 8,
        "dim_ff": 1024,
        "dropout": 0.1,
        "max_len": MAX_LEN,
        "tie_weights": True,
    }
    policy = TransformerPolicy(vocab_size=tokenizer.vocab_size, **model_config)

    # Load edges data for validation (only the files that actually exist)
    edges_paths = [os.path.join(ddir, f"edges_{i}.json") for i in range(num_graphs)]
    edges_data = [load_json(path) for path in edges_paths if os.path.exists(path)]
    if not edges_data:
        print(f"Warning: no edges_*.json files found in {ddir}; validation will receive no edge data.")

    # Initialize wandb
    wandb_config = {
        "algo": "pretrain",
        "env": "TriangleDiscovery",
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "num_epochs": NUM_EPOCHS,
        "vocab_size": tokenizer.vocab_size,
        **model_config,
        # The next three mirror the values hard-coded inside train_policy.
        "weight_decay": 0.01,
        "betas": (0.9, 0.95),
        "max_grad_norm": 1.0,
    }

    # DEBUG runs stay offline: no wandb run is created.
    wandb_run = None if DEBUG else init_wandb(
        run_name="pretrain_triangle_discovery", config=wandb_config
    )

    # train + save
    train_policy(
        policy, dataloader, NUM_EPOCHS, LEARNING_RATE, device,
        tokenizer.pad_id, tokenizer, edges_data, wandb_run,
    )
    save_model(policy, tokenizer, MODEL_PATH)

    # Log model artifact to wandb (best effort: the model is already on disk,
    # so a wandb failure is reported instead of aborting the run).
    if wandb_run is not None:
        try:
            artifact = wandb.Artifact("pretrain_triangle_discovery_model", type="model")
            artifact.add_file(MODEL_PATH)
            wandb.log_artifact(artifact)
        except Exception as e:
            print(f"[wandb] log_artifact failed: {e}")
        try:
            wandb.finish()
        except Exception as e:
            # Report instead of silently swallowing the failure.
            print(f"[wandb] finish failed: {e}")

# Run the pretraining pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()