"""
Training script for minimal transformer models on sequences.
"""
import os
import sys
import time
import math
import pickle
from contextlib import nullcontext
import argparse
import numpy as np
import torch

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model import GPTConfig, GPT

def get_batch(data, batch_size, block_size, device):
    """Sample a random batch of record-aligned (input, target) sequences.

    ``data`` is treated as a flat token stream made of back-to-back records of
    ``block_size + 1`` tokens; every sample starts on a record boundary and the
    target is the input shifted left by one token.

    Args:
        data: 1-D array of integer token ids (a ``np.memmap`` in this script).
        batch_size: number of records to sample (with replacement).
        block_size: length of each input / target sequence.
        device: torch device or device string the batch is moved to.

    Returns:
        ``(x, y)``: int64 tensors of shape ``(batch_size, block_size)``.

    Raises:
        ValueError: if ``data`` does not hold at least one full record.
    """
    data_size = block_size + 1
    # Count of complete records. The last one, starting at
    # (num_records - 1) * data_size, still fits inside ``data`` and is a
    # valid sample, so the upper bound is exclusive at num_records.
    num_records = len(data) // data_size
    if num_records < 1:
        raise ValueError(
            f"data has {len(data)} tokens, fewer than one record of {data_size} tokens"
        )
    starts = (torch.randint(num_records, (batch_size,)) * data_size).tolist()
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in starts])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in starts])

    if 'cuda' in str(device):
        # Pinned host memory allows the non-blocking (async) host->GPU copy.
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

def estimate_loss(model, train_data, val_data, eval_iters, batch_size, block_size, device, ctx):
    """Estimate the mean loss of ``model`` on the train and val splits.

    Draws ``eval_iters`` random batches per split with ``get_batch``, runs the
    model in eval mode with gradient tracking disabled, and restores the
    model's previous train/eval mode afterwards (even if an error occurs).

    Args:
        model: module called as ``model(X, Y)`` and returning ``(logits, loss)``.
        train_data: token array for the ``'train'`` split.
        val_data: token array for the ``'val'`` split.
        eval_iters: number of batches averaged per split; must be >= 1.
        batch_size: sequences per evaluation batch.
        block_size: sequence length passed to ``get_batch``.
        device: device the batches are moved to.
        ctx: context manager wrapped around each forward pass (autocast or
            ``nullcontext`` in this script).

    Returns:
        dict mapping ``'train'`` and ``'val'`` to 0-d float tensors holding the
        mean loss over the sampled batches.

    Raises:
        ValueError: if ``eval_iters`` < 1 (the mean would otherwise be NaN).
    """
    if eval_iters < 1:
        raise ValueError(f"eval_iters must be >= 1, got {eval_iters}")

    out = {}
    was_training = model.training
    model.eval()
    try:
        # No autograd graph is needed for evaluation; skipping it saves memory/compute.
        with torch.no_grad():
            for split, data in (('train', train_data), ('val', val_data)):
                losses = torch.zeros(eval_iters)
                for k in range(eval_iters):
                    X, Y = get_batch(data, batch_size, block_size, device)
                    with ctx:
                        _, loss = model(X, Y)
                    losses[k] = loss.item()
                out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out

def get_lr(it, warmup_iters, lr_decay_iters, learning_rate, min_lr):
    """Return the learning rate for iteration ``it``.

    Schedule:
      * ``it < warmup_iters``: linear warmup ``learning_rate * it / warmup_iters``
        (so iteration 0 uses a rate of 0).
      * ``warmup_iters <= it < lr_decay_iters``: cosine decay from
        ``learning_rate`` down to ``min_lr``.
      * ``it >= lr_decay_iters``: constant ``min_lr``.
    """
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # ">=" rather than ">": at it == lr_decay_iters the cosine term is exactly
    # 0, so the value is unchanged, but this avoids a 0/0 below when
    # warmup_iters == lr_decay_iters (e.g. max_iters == 0).
    if it >= lr_decay_iters:
        return min_lr

    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes 1 -> 0
    return min_lr + coeff * (learning_rate - min_lr)

def _str2bool(value):
    """Convert a command-line string such as "True", "false" or "1" to a bool.

    Used instead of ``type=bool``: argparse hands the raw string to ``type``
    and ``bool("False")`` is True, so a ``type=bool`` option can never be
    switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse CLI arguments, train a small GPT and checkpoint on val improvement.

    Reads ``meta.pkl`` (``block_size`` and ``vocab_size`` keys), ``val.bin`` and
    either ``train.bin`` (``--num_copies 0``) or ``train_<num_copies>.bin`` from
    ``data/<dataset>/<sorted|unsorted>/<min>-<max>/<permutation_type>/``.
    Checkpoints are written under ``out/`` whenever the estimated validation
    loss improves (the step-0 evaluation is never saved).
    """
    parser = argparse.ArgumentParser(description='Train minimal transformer on sequences')
    parser.add_argument('--dataset', type=str, default='sequences')
    parser.add_argument('--n_layer', type=int, default=1)
    parser.add_argument('--n_head', type=int, default=1)
    parser.add_argument('--n_embd', type=int, default=120)
    parser.add_argument('--max_iters', type=int, default=10000)
    parser.add_argument('--min_value', type=int, default=0)
    parser.add_argument('--max_value', type=int, default=100)
    parser.add_argument('--is_sorted', type=str, default="True")
    parser.add_argument('--num_copies', type=int, default=1)
    # type=bool would turn any non-empty string (including "False") into True.
    parser.add_argument('--use_identity_embeddings', type=_str2bool, default=False)
    parser.add_argument('--permutation_type', type=str, default="reversal")
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--learning_rate', type=float, default=5e-2)
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_args()

    # Setup paths. Only the exact string "True" selects the sorted dataset.
    sequence_type = "sorted" if args.is_sorted == "True" else "unsorted"
    data_dir = os.path.join('data', args.dataset, sequence_type,
                            f'{args.min_value}-{args.max_value}', args.permutation_type)

    # Load metadata. pickle can execute code on load: only use meta files you trust.
    with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
        meta = pickle.load(f)
    block_size = meta['block_size']
    vocab_size = meta['vocab_size']

    # Output directory name encodes dataset and model configuration.
    # NOTE(review): "_no_mlp" is always appended; presumably it mirrors the
    # model variant in model.py -- confirm.
    config_suffix = "_identity" if args.use_identity_embeddings else ""
    config_suffix += "_no_mlp"
    config = f"{args.n_layer}_{args.n_head}_{args.n_embd}{config_suffix}"
    out_dir = f'out/{args.dataset}_{sequence_type}_{args.permutation_type}_{config}_{args.min_value}-{args.max_value}'
    os.makedirs(out_dir, exist_ok=True)

    # Training hyperparameters derived from --max_iters. The max(1, ...) floors
    # keep the `%` checks in the loop and the eval-loss average well-defined
    # for small --max_iters (the raw quotients are 0 below 10 / 100 iters).
    eval_interval = max(1, args.max_iters // 10)
    log_interval = max(1, args.max_iters // 100)
    eval_iters = max(1, args.max_iters // 10)
    eval_batch_size = max(1, args.batch_size // 2)
    weight_decay = 0.0
    beta1, beta2 = 0.9, 0.95
    grad_clip = 1.0
    warmup_iters = args.max_iters // 20
    lr_decay_iters = args.max_iters
    min_lr = args.learning_rate / 10
    dropout = 0.0
    bias = False
    dtype = 'bfloat16'
    compile_model = True

    print("Training configuration:")
    print(f"- Model: {args.n_layer}L-{args.n_head}H-{args.n_embd}D")
    print(f"- Identity embeddings: {args.use_identity_embeddings}")
    print(f"- Learning rate: {args.learning_rate}")
    print(f"- Batch size: {args.batch_size}")
    print(f"- Max iterations: {args.max_iters}")

    # Setup device and precision
    torch.manual_seed(1337)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    device_type = 'cuda' if 'cuda' in args.device else 'cpu'
    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
    # Autocast is only used on CUDA; CPU runs the forward pass without it.
    ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

    # Load data: --num_copies 0 -> train.bin, otherwise train_<n>.bin.
    # Validation always reads val.bin.
    train_file = 'train.bin' if args.num_copies == 0 else f'train_{args.num_copies}.bin'
    train_data = np.memmap(os.path.join(data_dir, train_file), dtype=np.uint16, mode='r')
    val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')

    # Initialize model
    model_args = dict(
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        block_size=block_size,
        bias=bias,
        vocab_size=vocab_size,
        dropout=dropout,
        use_identity_embeddings=args.use_identity_embeddings
    )

    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    model.to(args.device)

    print(f"Model parameters: {model.get_num_params()/1e6:.2f}M")

    # Setup optimizer. Loss scaling is only active for float16.
    scaler = torch.amp.GradScaler('cuda', enabled=(dtype == 'float16'))
    optimizer = model.configure_optimizers(weight_decay, args.learning_rate, (beta1, beta2), device_type)

    # Keep a handle on the un-compiled module: its state_dict() is what gets saved.
    raw_model = model
    if compile_model:
        print("Compiling model...")
        model = torch.compile(model)

    # Training loop
    model.train()
    iter_num = 0
    best_val_loss = 1e9
    t0 = time.time()

    print("Starting training...")

    while True:
        # Learning rate scheduling
        lr = get_lr(iter_num, warmup_iters, lr_decay_iters, args.learning_rate, min_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Evaluation
        if iter_num % eval_interval == 0:
            losses = estimate_loss(model, train_data, val_data, eval_iters,
                                   eval_batch_size, block_size, args.device, ctx)
            print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

            # Save checkpoint if improved (the untrained step-0 model is not saved).
            if losses['val'] < best_val_loss:
                best_val_loss = losses['val']
                if iter_num > 0:
                    checkpoint = {
                        'model': raw_model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_args': model_args,
                        'iter_num': iter_num,
                        'best_val_loss': best_val_loss,
                    }
                    if args.num_copies == 0:
                        ckpt_name = f'{iter_num}_ckpt.pt'
                    else:
                        ckpt_name = f'{iter_num}_ckpt_{args.num_copies}.pt'
                    ckpt_path = os.path.join(out_dir, ckpt_name)

                    torch.save(checkpoint, ckpt_path)
                    print(f"Saved checkpoint: {ckpt_path}")

        if iter_num >= args.max_iters:
            break

        # Training step
        X, Y = get_batch(train_data, args.batch_size, block_size, args.device)
        with ctx:
            _, loss = model(X, Y)

        # Backward pass (scaler is a no-op when loss scaling is disabled)
        scaler.scale(loss).backward()

        # Gradient clipping; grads must be unscaled before measuring their norm.
        if grad_clip != 0.0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        # Optimizer step
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        # Logging. dt also includes any evaluation/checkpointing done this iteration.
        t1 = time.time()
        dt = t1 - t0
        t0 = t1

        if iter_num % log_interval == 0:
            lossf = loss.item()
            print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, lr {lr:.2e}")

        iter_num += 1

    print("Training completed")

# Start training only when this file is executed as a script, not when the
# module is imported.
if __name__ == "__main__":
    main()