import argparse
import os
import torch 
import wandb
import copy
import itertools
import numpy as np
import time 
import math
from torch import nn
from tqdm import tqdm
from data_utils import get_lm_corpus
from torch.optim import AdamW, SGD
from transformers.models.qwen3.modeling_qwen3 import Qwen3Config
from model import Qwen3MoEForCausalLM
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR, StepLR
from utils import set_seed, save_checkpoint, remove_old_checkpoints
from calculate import precompute_h0_values, conservation_log
def prepare_lm_batch(data: torch.Tensor, target: torch.Tensor):
    """Swap the two axes of an LM batch so it can be fed to the model.

    The corpus iterator presumably yields (seq_len, batch) tensors and the HF-style
    model expects (batch, seq_len) -- TODO confirm against data_utils.

    Args:
        data: input token ids from the corpus iterator.
        target: next-token labels aligned with ``data``.

    Returns:
        Tuple ``(input_ids, labels)``: ``data`` and ``target`` transposed with ``.T``.
    """
    return tuple(tensor.T for tensor in (data, target))
def _build_optimizer(args: argparse.Namespace, params) -> torch.optim.Optimizer:
    """Create the optimizer named by ``args.opt`` over ``params``.

    AdamW reads ``args.adamw_beta1`` / ``args.adamw_beta2`` / ``args.adamw_eps`` when
    present (falling back to PyTorch's AdamW defaults), so those CLI flags take effect.

    Raises:
        ValueError: if ``args.opt`` is not ``"sgd"``, ``"adamw"`` or ``"adagrad"``.
    """
    if args.opt == "sgd":
        return SGD(params, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    if args.opt == "adamw":
        betas = (getattr(args, "adamw_beta1", 0.9), getattr(args, "adamw_beta2", 0.999))
        return AdamW(
            params,
            lr=args.lr,
            weight_decay=args.weight_decay,
            betas=betas,
            eps=getattr(args, "adamw_eps", 1e-8),
        )
    if args.opt == "adagrad":
        return torch.optim.Adagrad(params, lr=args.lr, weight_decay=args.weight_decay)
    raise ValueError(f"Unsupported optimizer {args.opt!r}")


def _build_scheduler(optimizer: torch.optim.Optimizer, args: argparse.Namespace):
    """Create the LR scheduler named by ``args.lr_scheduler``.

    Every scheduler except ``"plateau"`` is stepped once per optimizer step by
    ``main``. ``"plateau"`` returns a ``ReduceLROnPlateau`` (factor
    ``args.decay_rate``) which ``main`` steps with the validation loss instead.

    Raises:
        ValueError: for an unknown scheduler name.
    """
    name = args.lr_scheduler
    if name == "linear":
        return LinearLR(optimizer, start_factor=args.warmup_lr / args.lr, end_factor=1.0, total_iters=args.warmup_step)
    if name == "step":
        # StepLR divides by step_size, so keep it >= 1 for very short runs.
        return StepLR(optimizer, step_size=max(1, args.max_step // 3), gamma=0.1)
    if name == "cosine":
        return CosineAnnealingLR(optimizer, T_max=args.max_step, eta_min=args.eta_min)
    if name == "cosine_warm":
        warmup_scheduler = LinearLR(optimizer, start_factor=args.warmup_lr / args.lr, end_factor=1.0, total_iters=args.warmup_step)
        # T_max must be >= 1: the cosine formula divides by it once the milestone is hit
        # (e.g. when max_step == warmup_step).
        cosine_scheduler = CosineAnnealingLR(optimizer, T_max=max(1, args.max_step - args.warmup_step), eta_min=args.eta_min)
        return SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[args.warmup_step])
    if name == "plateau":
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=args.decay_rate)
    raise ValueError(f"Unsupported lr scheduler {name!r}")


def _safe_exp(value: float) -> float:
    """Return ``math.exp(value)``, or ``inf`` instead of raising ``OverflowError``."""
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")


def _evaluate(model, data_iter, device: torch.device, use_amp: bool) -> float:
    """Mean LM loss of ``model`` over ``data_iter``; leaves the model in train mode.

    Non-finite batch losses are skipped. Returns ``inf`` when no batch produced a
    finite loss, so a broken evaluation can never be mistaken for the best one.
    """
    model.eval()
    losses = []
    with torch.no_grad():
        for eval_data, eval_target, _ in data_iter:
            ex, ey = prepare_lm_batch(eval_data, eval_target)
            ex = ex.to(device, non_blocking=True)
            ey = ey.to(device, non_blocking=True)
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                batch_loss = float(model(input_ids=ex, labels=ey).loss)
            if math.isfinite(batch_loss):
                losses.append(batch_loss)
    model.train()
    return sum(losses) / len(losses) if losses else float("inf")


def main(args: argparse.Namespace):
    """Train a Qwen3-style (optionally MoE) causal LM described by ``args``.

    Sets up wandb, data iterators, the model, optimizer and LR scheduler. Then it
    runs epochs (with gradient accumulation, optional CUDA AMP and gradient
    clipping) until ``args.max_step`` optimizer steps have been taken. Along the
    way it logs train metrics every ``args.logging_frequency`` steps and evaluates
    on the validation split every ``args.eval_frequency`` steps. It keeps a single
    ``best_<step>.pt`` checkpoint (by validation loss) and a rolling
    ``last_<step>.pt`` checkpoint under ``args.save_dir/<wandb run name>``.

    Raises:
        ValueError: for ``gradient_accumulation_steps < 1`` or an unsupported
            optimizer / scheduler name.
        RuntimeError: if the training iterator yields no batches before
            ``max_step`` is reached (would otherwise loop forever).
    """
    accumulation_steps = args.gradient_accumulation_steps
    if accumulation_steps < 1:
        raise ValueError(f"gradient_accumulation_steps must be >= 1, got {accumulation_steps}")

    torch.set_num_threads(8)
    torch.set_num_interop_threads(2)
    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        group=args.wandb_group,
        name=f"{args.opt}-lr{args.lr}-warmup{args.warmup_lr}-hidden{args.hidden_size}-sheduler{args.lr_scheduler}-decay{args.weight_decay}"
        f"-warmstep{args.warmup_step}-momentum{args.momentum}-type{args.ffn_type}-router{args.router_act}-seed{args.seed}",
        save_code=True,
    )
    wandb.config.update(vars(args))
    set_seed(args.seed)
    save_path = os.path.join(args.save_dir, wandb.run.name)
    os.makedirs(save_path, exist_ok=True)

    # Data
    corpus = get_lm_corpus(args.data_path, args.dataset)
    ntokens = len(corpus.vocab)
    args.n_token = ntokens
    eval_batch_size = 12
    tr_iter = corpus.get_iterator("train", args.batch_size, args.tgt_len, ext_len=args.ext_len)
    va_iter = corpus.get_iterator("valid", eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)
    te_iter = corpus.get_iterator("test", eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)

    # Model
    config = Qwen3Config(
        vocab_size=ntokens,
        max_position_embeddings=args.tgt_len,
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_hidden_layers,
        num_attention_heads=args.num_attention_heads,
        num_key_value_heads=args.num_key_value_heads,
        head_dim=args.hidden_size // args.num_attention_heads,
        intermediate_size=args.intermediate_size,
        hidden_act=args.hidden_act,
        router_act=args.router_act,
        num_experts=args.num_experts,
        num_experts_per_tok=args.num_experts_per_tok,
        ffn_type=args.ffn_type,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = args.amp and device.type == "cuda"
    # With enabled=False the scaler's scale/unscale_/update are no-ops and step()
    # just calls optimizer.step(), so one code path serves both AMP and fp32.
    scaler = torch.amp.GradScaler(device=device.type, enabled=use_amp)
    model = Qwen3MoEForCausalLM(config).to(device)
    print(model)
    model.config.save_pretrained(save_path)

    # Optimizer and scheduler
    optimizer = _build_optimizer(args, model.parameters())
    scheduler = _build_scheduler(optimizer, args)
    steps_on_val_loss = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    # The initial-weight snapshot is only needed to precompute the reference values
    # for conservation_log; drop our reference so the copy can be freed (unless
    # cached_h0 itself holds references into it).
    start_model = copy.deepcopy(model)
    cached_h0 = precompute_h0_values(start_model, config)
    del start_model

    train_step = 0
    micro_step = 0
    best_val_loss = float("inf")
    log_start_time = time.time()
    eval_start_time = time.time()
    grad_clip = getattr(args, "grad_clip", 0.0) or 0.0

    effective_batch_size = args.batch_size * accumulation_steps
    print(f"Micro batch size: {args.batch_size}")
    print(f"Gradient accumulation steps: {accumulation_steps}")
    print(f"Effective batch size: {effective_batch_size}")
    print("Starting training...\n")

    # Gradients and the loss window are reset only after an optimizer step, not at
    # epoch boundaries: micro-batches left over at the end of an epoch carry into
    # the next one, so every optimizer step sums exactly `accumulation_steps`
    # scaled losses.
    optimizer.zero_grad(set_to_none=True)
    train_losses_window = []

    for epoch in itertools.count(start=1):
        print(f"Epoch {epoch}")
        model.train()
        train_iter = tr_iter.get_varlen_iter() if getattr(args, "varlen", False) else tr_iter
        saw_batch = False

        for data, target, _ in tqdm(train_iter):
            saw_batch = True
            if train_step >= args.max_step:
                break

            x, y = prepare_lm_batch(data, target)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                out = model(input_ids=x, labels=y)
                # Dividing makes the accumulated gradient the mean over micro-batches.
                loss = out.loss / accumulation_steps
            scaler.scale(loss).backward()

            # Log the unscaled per-micro-batch loss.
            train_losses_window.append(out.loss.detach().float().item())
            micro_step += 1
            if micro_step % accumulation_steps != 0:
                continue

            # Optimizer step
            if grad_clip > 0:
                scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            if not steps_on_val_loss:
                scheduler.step()
            train_step += 1
            conservation_log(global_step=train_step, cached_h0=cached_h0, current_model=model, log_type="step", config=config, args=args)

            # Logging
            if train_step % args.logging_frequency == 0:
                mean_train_loss = float(sum(train_losses_window) / max(1, len(train_losses_window)))
                train_losses_window = []
                train_ppl = math.exp(mean_train_loss) if mean_train_loss < 20 else float("inf")
                train_bpc = mean_train_loss / math.log(2)
                curr_lr = float(optimizer.param_groups[0]["lr"])
                elapsed = time.time() - log_start_time

                if args.dataset in ["wt103", "lm1b"]:
                    print(
                        f"| epoch {epoch:3d} step {train_step:8d} | "
                        f"{micro_step:6d} micro-batches | lr {curr_lr:.3g} "
                        f"| ms/step {elapsed * 1000 / args.logging_frequency:5.2f} | "
                        f"loss {mean_train_loss:5.2f} | ppl {train_ppl:9.3f}"
                    )
                    wandb.log({"loss": mean_train_loss, "ppl": train_ppl, "learning_rate": curr_lr}, step=train_step)
                elif args.dataset in ["enwik8", "text8", "ptb"]:
                    print(
                        f"| epoch {epoch:3d} step {train_step:8d} | "
                        f"{micro_step:6d} micro-batches | lr {curr_lr:.3g} "
                        f"| ms/step {elapsed * 1000 / args.logging_frequency:5.2f} | "
                        f"loss {mean_train_loss:5.2f} | bpc {train_bpc:9.3f}"
                    )
                    wandb.log({"loss": mean_train_loss, "bpc": train_bpc, "learning_rate": curr_lr}, step=train_step)
                log_start_time = time.time()

            # Evaluation
            if train_step % args.eval_frequency == 0:
                val_loss = _evaluate(model, va_iter, device, use_amp)
                val_ppl = _safe_exp(val_loss)
                val_bpc = val_loss / math.log(2)
                if steps_on_val_loss:
                    scheduler.step(val_loss)

                print("-" * 100)
                if args.dataset in ["wt103", "lm1b"]:
                    print(
                        f"| Eval {train_step // args.eval_frequency:3d} at step {train_step:8d} | "
                        f"time: {time.time() - eval_start_time:5.2f}s | "
                        f"valid loss {val_loss:5.2f} | valid ppl {val_ppl:9.3f}"
                    )
                    wandb.log({"eval_loss": val_loss, "eval_ppl": val_ppl}, step=train_step)
                elif args.dataset in ["enwik8", "text8", "ptb"]:
                    print(
                        f"| Eval {train_step // args.eval_frequency:3d} at step {train_step:8d} | "
                        f"time: {time.time() - eval_start_time:5.2f}s | "
                        f"valid loss {val_loss:5.2f} | valid bpc {val_bpc:9.3f}"
                    )
                    wandb.log({"eval_loss": val_loss, "eval_bpc": val_bpc}, step=train_step)
                print("-" * 100)

                # Keep only the single best checkpoint (inf never beats inf).
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    remove_old_checkpoints(save_path, "best_")
                    best_path = os.path.join(save_path, f"best_{train_step}.pt")
                    save_checkpoint(best_path, model, optimizer, scheduler, train_step, best_val_loss)
                    print(f"Best checkpoint saved at {save_path}")

                eval_start_time = time.time()

            # Periodic checkpoint (rolling, only the latest is kept)
            if train_step % args.save_frequency == 0:
                remove_old_checkpoints(save_path, "last_")
                last_path = os.path.join(save_path, f"last_{train_step}.pt")
                save_checkpoint(last_path, model, optimizer, scheduler, train_step, best_val_loss)
                print(f"Last checkpoint saved at {save_path}")

            if train_step >= args.max_step:
                break

        if train_step >= args.max_step:
            print("-" * 100)
            print("End of training")
            break
        if not saw_batch:
            raise RuntimeError("Training iterator yielded no batches; cannot reach max_step.")
 
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this training script.

    NOTE(review): ``--dtype``, ``--mem_len``, ``--wandb-run-dir``,
    ``--model-save-dir`` and ``--restore-checkpoint-path`` are parsed but not read
    by ``main()`` in this file.
    """
    parser = argparse.ArgumentParser()
    # Model
    parser.add_argument("--ffn_type", type=str, default="mlp", choices=["mlp", "smoe", "dmoe"], help="MLP type")
    parser.add_argument("--hidden_size", type=int, default=768, help="Model hidden size")
    parser.add_argument("--intermediate_size", type=int, default=3072, help="MLP intermediate size")
    parser.add_argument("--num_hidden_layers", type=int, default=12, help="Number of transformer layers")
    parser.add_argument("--num_attention_heads", type=int, default=12, help="Number of attention heads")
    parser.add_argument("--num_key_value_heads", type=int, default=12, help="KV heads (GQA/MQA).")
    parser.add_argument("--num_experts", type=int, default=12, help="Number of Experts")
    parser.add_argument("--hidden_act", type=str, default="gelu", help="Activation function")
    parser.add_argument("--router_act", type=str, default="softmax", choices=["softmax", "nsigmoid"], help="Router activation")
    parser.add_argument("--num_experts_per_tok", type=int, default=2, help="Number of experts selected per token (top-k routing).")
    # Data / train
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--data-path", type=str, default="")
    parser.add_argument("--dataset", type=str, default="wt103", choices=["wt103", "lm1b", "enwik8", "text8", "ptb"])
    parser.add_argument("--max_step", type=int, default=500000)
    parser.add_argument("--warmup_step", type=int, default=2000)
    parser.add_argument("--batch-size", type=int, default=96)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of gradient accumulation steps")
    parser.add_argument("--tgt_len", type=int, default=256)
    parser.add_argument("--eval_tgt_len", type=int, default=256)
    parser.add_argument("--ext_len", type=int, default=0)
    parser.add_argument("--mem_len", type=int, default=0)
    # Optim
    parser.add_argument("--lr-scheduler", type=str, default="cosine", choices=["linear", "step", "cosine", "plateau", "cosine_warm"])
    parser.add_argument('--lr', type=float, default=0.75, help='initial learning rate (0.00025|5 for adam|sgd)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--opt', default='sgd', type=str, choices=['adamw', 'sgd', 'adagrad'], help='optimizer to use.')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum for sgd')
    # Weight decay is an optimizer hyper-parameter (passed to SGD/AdamW), not a scheduler one.
    parser.add_argument("--weight_decay", type=float, default=0.01, help="weight decay rate for the optimizer")
    parser.add_argument('--eta_min', type=float, default=1.0e-8, help='min learning rate for cosine scheduler')
    parser.add_argument("--adamw-beta1", type=float, default=0.9)
    parser.add_argument("--adamw-beta2", type=float, default=0.999)
    parser.add_argument("--adamw-eps", type=float, default=1e-8)
    parser.add_argument('--decay_rate', type=float, default=0.01, help='decay factor when ReduceLROnPlateau is used')
    parser.add_argument('--grad_clip', type=float, default=0.25, help='gradient clipping')
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16", help="model datatype")
    parser.add_argument("--amp", action="store_true", help="Use mixed precision on CUDA")
    # Logging / ckpt
    parser.add_argument("--save-dir", type=str, required=True)
    parser.add_argument("--wandb-entity", default=None)
    parser.add_argument("--wandb-group", default=None)
    parser.add_argument("--wandb-project", default=None)
    parser.add_argument("--wandb-run-dir", default=".wandb")
    parser.add_argument("--logging-frequency", type=int, default=200)
    parser.add_argument("--eval-frequency", type=int, default=4000)
    parser.add_argument("--save-frequency", type=int, default=4000)
    parser.add_argument("--model-save-dir", type=str, default="artifacts/")
    parser.add_argument("--restore-checkpoint-path", type=str, default=None)
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())