import os
import math
import time
import torch
import random
from datetime import datetime
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import (
    GPT2Tokenizer,
    Qwen2Config,
    Qwen2ForCausalLM,
    Qwen2Tokenizer,
    GPT2Tokenizer,
    GPT2Config, 
    GPT2LMHeadModel,
    get_cosine_schedule_with_warmup,
)
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from typing import Dict, List, Optional
import numpy as np
import wandb
from d_muon import D_Muon

def init_distributed():
    """Initialise torch.distributed when the launcher exported RANK/WORLD_SIZE.

    Returns:
        A ``(rank, world_size, local_rank)`` tuple. When ``RANK`` and
        ``WORLD_SIZE`` are not both present in the environment, no process
        group is created and ``(0, 1, 0)`` is returned.
    """
    env = os.environ
    launched_distributed = "RANK" in env and "WORLD_SIZE" in env
    if not launched_distributed:
        # Single-process fallback: nothing to initialise.
        return 0, 1, 0

    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))  # falls back to 0 when unset
    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank

def is_main_process(rank: int) -> bool:
    """Return True when ``rank`` is 0, the rank treated as the main process."""
    main_rank = 0
    return rank == main_rank


def disable_all_biases(model, keep_ln_bias=True):
    """Zero out and freeze the bias parameters of ``model`` in place.

    Args:
        model: Object exposing ``named_parameters()``.
        keep_ln_bias: When true, biases whose lower-cased name contains
            ``"ln"`` or ``"layernorm"`` are left untouched.

    Returns:
        The same ``model`` object that was passed in.
    """
    for param_name, param in model.named_parameters():
        # Only parameters named "<something>.bias" are candidates.
        if not param_name.endswith(".bias"):
            continue
        lowered = param_name.lower()
        if keep_ln_bias and ("ln" in lowered or "layernorm" in lowered):
            continue
        with torch.no_grad():
            param.zero_()
        param.requires_grad = False
    return model

class MyDataset(Dataset):
    """Dataset of fixed-length token windows cut from a tokenized text corpus.

    Every text is tokenized once and appended to one flat token list, which is
    cached on disk. Items are consecutive, non-overlapping windows of
    ``max_length`` tokens; a trailing partial window is dropped.
    """

    def __init__(self, dataset_name, dataset, tokenizer, split_name="train", max_length=512):
        """Tokenize ``dataset["text"]`` (or load the cached tokens).

        Args:
            dataset_name: Label used as the prefix of the cache file name.
            dataset: Object whose ``"text"`` entry is an iterable of texts.
            tokenizer: Object with ``encode(text, add_special_tokens=True)``
                returning a list of token ids.
            split_name: Split label; part of the cache file name and of the
                progress-bar description.
            max_length: Number of tokens per item.
        """
        self.dataset_name = dataset_name
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.texts = dataset["text"]
        self.max_length = max_length
        self.tokens = []
        self.split_name = split_name
        self._tokenize_texts()

    def _cache_path(self):
        """Return the cache file name for this dataset/split/tokenizer combination."""
        cache_name = getattr(self.tokenizer, "name_or_path", self.tokenizer.__class__.__name__)
        # Names such as "org/model" contain path separators. Flatten them so the
        # cache is a plain file in the working directory rather than a path into
        # a directory that does not exist (which made torch.save fail).
        safe_name = str(cache_name).replace("\\", "_").replace("/", "_")
        return f"{self.dataset_name}_{self.split_name}_{safe_name}.bin"

    def _tokenize_texts(self):
        """Fill ``self.tokens`` from the cache file, or tokenize and write the cache."""
        cache_file = self._cache_path()
        if os.path.exists(cache_file):
            # torch.load unpickles the file: only use cache files you trust.
            self.tokens = torch.load(cache_file)
            return
        # NOTE(review): nothing here coordinates multiple processes; if several
        # ranks start without a cache, each one appears to tokenize and write
        # the same file - confirm whether that is acceptable.
        for text in tqdm(self.texts, desc=f"Tokenizing texts [{self.split_name}]"):
            self.tokens.extend(self.tokenizer.encode(text, add_special_tokens=True))
        torch.save(self.tokens, cache_file)

    def __len__(self):
        """Number of complete ``max_length``-token windows."""
        return len(self.tokens) // self.max_length

    def __getitem__(self, idx):
        """Return window ``idx`` as a 1-D ``torch.long`` tensor.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self)
        if idx < 0:
            idx += n_items
        if idx < 0 or idx >= n_items:
            # Previously an out-of-range index silently produced an empty or
            # short tensor; raising also lets plain iteration terminate.
            raise IndexError(f"index out of range for dataset with {n_items} items")
        start_idx = idx * self.max_length
        token_slice = self.tokens[start_idx:start_idx + self.max_length]
        return torch.tensor(token_slice, dtype=torch.long)

from torch.utils.data.distributed import DistributedSampler

def get_model_and_dataloader(model_name, dataset_name, hidden_size,
                             batch_size=16, val_ratio=0.1, seed=42,
                             rank: int = 0, world_size: int = 1):
    """Build a freshly initialised language model and its train/val data loaders.

    Args:
        model_name: One of ``"qwen"``, ``"gpt2"`` or ``"gpt2-medium"``.
        dataset_name: Key of the internal name-to-hub-path table (currently
            only ``"openwebtext-100k"``).
        hidden_size: Model width. For ``"gpt2"``, ``None`` selects the stock
            ``gpt2`` config, otherwise the value must be a multiple of 64.
            Not used by ``"gpt2-medium"``, whose architecture is fixed here.
        batch_size: Batch size of both loaders (per process).
        val_ratio: Fraction held out for validation when the loaded dataset
            has no ``"validation"`` split.
        seed: Seed of that train/validation split.
        rank: Rank of this process, passed to the distributed samplers.
        world_size: Number of processes; values above 1 switch both loaders
            to ``DistributedSampler``.

    Returns:
        ``(model, train_loader, val_loader, train_sampler, val_sampler)``.
        Both samplers are ``None`` unless ``world_size > 1``.

    Raises:
        ValueError: If ``model_name`` is not supported.
        KeyError: If ``dataset_name`` is not in the path table.
    """
    # Fail fast, before any download or tokenization. Names such as
    # "gpt2-large" used to pass the tokenizer check ("gpt2" in model_name) and
    # then crash at the return statement because no branch built a model.
    if model_name not in ("qwen", "gpt2", "gpt2-medium"):
        raise ValueError(f"model {model_name} not supported")

    name2path = {"openwebtext-100k": "Elriggs/openwebtext-100k"}
    raw = load_dataset(name2path[dataset_name], trust_remote_code=True)

    # Pick the split that serves as the training pool.
    if isinstance(raw, dict) or hasattr(raw, "keys"):  # DatasetDict
        base = raw["train"] if "train" in raw else raw[next(iter(raw.keys()))]
    else:
        base = raw

    if isinstance(raw, dict) and "validation" in raw:
        train_split = raw["train"]
        val_split = raw["validation"]
    else:
        # No validation split available: carve one out of the training pool.
        splits = base.train_test_split(test_size=val_ratio, seed=seed, shuffle=True)
        train_split = splits["train"]
        val_split = splits["test"]

    if model_name == "qwen":
        tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", trust_remote_code=True)
    else:
        # Both GPT-2 variants share the base "gpt2" tokenizer.
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    train_dataset = MyDataset(dataset_name, train_split, tokenizer, split_name="train")
    val_dataset = MyDataset(dataset_name, val_split, tokenizer, split_name="val")

    distributed = world_size > 1
    if distributed:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank,
                                           shuffle=True, drop_last=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank,
                                         shuffle=False, drop_last=False)
    else:
        train_sampler = None
        val_sampler = None
    # With a sampler, shuffling is the sampler's job; DataLoader must not shuffle.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                              shuffle=not distributed, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler,
                            shuffle=False, drop_last=False)

    if model_name == "qwen":
        config = Qwen2Config(
            attention_dropout=0.0, bos_token_id=151643, eos_token_id=151643, hidden_act="silu",
            hidden_size=hidden_size, initializer_range=0.02, intermediate_size=4864,
            max_position_embeddings=513, max_window_layers=12, model_type="qwen2",
            num_attention_heads=16, num_hidden_layers=12, num_key_value_heads=16,
            rms_norm_eps=1e-06, rope_theta=1000000.0, sliding_window=1024,
            tie_word_embeddings=True, torch_dtype="bfloat16", use_cache=True,
            use_mrope=False, use_sliding_window=False, vocab_size=151936,
        )
        model = Qwen2ForCausalLM(config)
    else:
        if model_name == "gpt2-medium":
            config = GPT2Config(
                n_layer=24,
                n_embd=1024,
                n_head=16,
                n_positions=1024, n_ctx=1024,
                vocab_size=len(tokenizer),
                activation_function="gelu",
                attn_pdrop=0.1, embd_pdrop=0.1, resid_pdrop=0.1,
            )
        elif hidden_size is None:
            config = GPT2Config.from_pretrained("gpt2")
        else:
            assert hidden_size % 64 == 0, "hidden_size should be a multiple of 64 for GPT-2"
            config = GPT2Config(
                n_embd=hidden_size, n_layer=12, n_head=hidden_size // 64,
                n_positions=1024, n_ctx=1024, vocab_size=len(tokenizer)
            )
        model = GPT2LMHeadModel(config)
        # Keep the model config's special-token ids in line with the tokenizer.
        model.config.pad_token_id = tokenizer.pad_token_id
        if model.config.eos_token_id is None and tokenizer.eos_token_id is not None:
            model.config.eos_token_id = tokenizer.eos_token_id

    return model, train_loader, val_loader, train_sampler, val_sampler



@torch.compile
def zeropower_via_newtonschulz5(G, steps):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.

    Args:
        G: 2-D tensor (enforced by the assertion below).
        steps: Number of iterations to run.

    Returns:
        A tensor with the same shape as ``G`` in bfloat16 (the working copy is
        cast to bfloat16 and never cast back).
    """
    assert len(G.shape) == 2
    # Coefficients of the per-step update X <- a*X + (b*A + c*A@A) @ X, A = X @ X.T.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Work on the wide orientation so that A = X @ X.T is the smaller square matrix.
    if G.size(0) > G.size(1):
        X = X.T
    # Ensure spectral norm is at most 1: dividing by the whole-tensor norm
    # (which upper-bounds the spectral norm) achieves this; 1e-7 guards
    # against division by zero.
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = (
            b * A + c * A @ A
        )
        X = a * X + B @ X

    # Undo the transpose applied above so the result matches G's shape.
    if G.size(0) > G.size(1):
        X = X.T
    return X


@torch.no_grad()
def validate(model, val_loader, device, max_batches: Optional[int] = None, distributed: bool = False):
    """Compute the token-weighted average loss and perplexity over ``val_loader``.

    Args:
        model: Callable as ``model(input_ids=..., labels=...)`` returning an
            object with a ``loss`` attribute; must expose ``training``,
            ``eval()`` and ``train()``.
        val_loader: Iterable of batches supporting ``.to(device)`` and ``.numel()``.
        device: Device the batches and accumulators are placed on.
        max_batches: Optional cap on the number of batches evaluated.
        distributed: When true, loss and token sums are all-reduced (summed)
            across processes before averaging.

    Returns:
        ``(avg_loss, ppl, seen_batches)``. ``avg_loss`` is ``inf`` when no
        tokens were seen; ``ppl`` is ``inf`` unless ``avg_loss < 20``.
    """
    restore_train_mode = model.training
    model.eval()

    loss_sum = torch.tensor(0.0, device=device)
    token_count = torch.tensor(0.0, device=device)
    seen_batches = 0

    for index, batch in enumerate(val_loader):
        if max_batches is not None and index >= max_batches:
            break
        batch = batch.to(device)
        n_tokens = batch.numel()
        batch_loss = model(input_ids=batch, labels=batch).loss
        # Weight each batch's mean loss by its element count.
        loss_sum += batch_loss.detach() * n_tokens
        token_count += n_tokens
        seen_batches += 1

    if distributed:
        for accumulator in (loss_sum, token_count):
            dist.all_reduce(accumulator, op=dist.ReduceOp.SUM)

    if token_count.item() > 0:
        avg_loss = (loss_sum / token_count).item()
    else:
        avg_loss = float("inf")

    # Skip exp() for large losses and report infinity instead.
    if avg_loss < 20:
        ppl = math.exp(avg_loss)
    else:
        ppl = float("inf")

    if restore_train_mode:
        model.train()
    return avg_loss, ppl, seen_batches



def get_optimizer(args, model, main_process):
    """Build the optimizer selected by ``args.optimizer`` from name-based parameter groups.

    Trainable parameters are bucketed by substrings of their names:

    * "sign" group: token embeddings (``wte``), ``lm_head``, position
      embeddings (``wpe``) and layer-norm weights (layer-norm biases are
      excluded).
    * "muon" group: attention QKV projections, attention output projections
      and MLP weights.

    Apart from layer-norm weights, only parameters with ``ndim >= 2`` are
    picked.

    Args:
        args: Namespace providing ``optimizer``, ``max_lr``, ``weight_decay``,
            ``warmup_steps``, ``scale1`` and ``scale2``.
        model: Object exposing ``named_parameters()``.
        main_process: When true, a warning is printed if trainable parameters
            end up in no group.

    Returns:
        The constructed ``LANTON`` optimizer.

    Raises:
        ValueError: If ``args.optimizer`` is not ``'LANTON'``.
    """
    # Validate first. The previous ``assert 0`` disappears under ``python -O``
    # and then surfaced as an unbound-variable error at the return statement.
    if args.optimizer != 'LANTON':
        raise ValueError(f"optimizer {args.optimizer!r} not supported")

    from lanton import LANTON

    # Walk the parameters once and reuse the list for every group.
    named = list(model.named_parameters())

    def pick(name_pred):
        """Trainable parameters with ndim >= 2 whose name satisfies ``name_pred``."""
        return [p for n, p in named if name_pred(n) and p.requires_grad and p.ndim >= 2]

    embed_param = pick(lambda n: "wte" in n)
    pe_param = pick(lambda n: "wpe" in n)
    head_param = pick(lambda n: "lm_head" in n)
    # Layer-norm weights only; biases are left out of the optimizer.
    ln_param = [
        p for n, p in named
        if (".ln_" in n or "layernorm" in n.lower())
        and p.requires_grad
        and not n.endswith(".bias")
    ]

    if any("attn.c_attn" in n for n, _ in named):
        # Fused-QKV naming (c_attn / c_proj).
        qkv_param = pick(lambda n: "attn.c_attn" in n)
        o_param = pick(lambda n: "attn.c_proj" in n)
    else:
        # Separate q/k/v/o projection naming.
        qkv_param = pick(lambda n: "q_proj" in n or "k_proj" in n or "v_proj" in n)
        o_param = pick(lambda n: "o_proj" in n or "c_proj" in n)

    mlp_param = pick(lambda n: "mlp." in n)

    sign_params = embed_param + head_param + pe_param + ln_param
    muon_params = qkv_param + o_param + mlp_param

    if main_process:
        # Matching is purely name-based, so surface anything it missed instead
        # of silently leaving it out of the optimizer.
        grouped = {id(p) for p in sign_params + muon_params}
        leftover = [n for n, p in named if p.requires_grad and id(p) not in grouped]
        if leftover:
            print(f"[get_optimizer] warning: {len(leftover)} trainable parameter(s) "
                  f"are not passed to the optimizer, e.g. {leftover[:5]}")

    optimizer = LANTON(
        lr=args.max_lr,
        wd=args.weight_decay,
        muon_params=muon_params,
        sign_params=sign_params,
        adaptive_warmup_steps=args.warmup_steps,
        scale1=args.scale1,
        scale2=args.scale2,
    )
    return optimizer

if __name__ == "__main__":
    # Script entry point: parse CLI options, build model / data / optimizer,
    # then train with periodic validation, logging and checkpointing.
    import argparse

    parser = argparse.ArgumentParser()
    # wandb args
    parser.add_argument("--wandb_log", action='store_true', help="Use Wandb Log.")
    parser.add_argument("--wandb_project", default= 'gpt-openwebtext', type=str, help="Wandb project.")
    parser.add_argument("--wandb_run_name", type=str, default=None, help="optional custom run name")
    parser.add_argument("--log_interval", type=int, default=10, help="wandb log every N steps")
    # model, dataset, optimizer, training hyperparameters
    parser.add_argument("--model", type=str, default="gpt2-medium")
    parser.add_argument("--optimizer", type=str, default="LANTON")
    parser.add_argument("--wd", type=float, default=0.1)
    parser.add_argument("--dataset", type=str, default="openwebtext-100k")
    parser.add_argument("--hidden_size", type=int, default=1024)
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--warmup_steps", type=int, default=100, help="Number of warmup steps before adaptive adjustment")
    parser.add_argument("--val_ratio", type=float, default=0.1, help="Ratio of validation data")
    parser.add_argument("--val_interval", type=int, default=50, help="run validation every N steps")
    parser.add_argument("--save_interval", default=100, type=int, help="save checkpoint every N steps")
    parser.add_argument("--val_max_batches", type=int, default=100, help="cap validation batches per validation run")
    parser.add_argument("--max_lr", default=5e-3, type=float, help="max lr.")
    parser.add_argument("--beta1", default=0.9, type=float, help="beta1 in AdamW.")
    parser.add_argument("--beta2", default=0.95, type=float, help="beta2 in AdamW.")
    parser.add_argument("--weight_decay", default=0.1, type=float, help="weight decay.")
    # adaptive muon specific args
    parser.add_argument("--noise_momentum", default=0.9, type=float, help="noise momentum in adaptive muon.") 
    parser.add_argument("--scale1", default=300, type=float, help="learning rate scale for embed/head layers")
    parser.add_argument("--scale2", default=1.0, type=float, help="learning rate scale for embed/head layers")
    args = parser.parse_args()
    # NOTE: --wd, --beta1, --beta2 and --noise_momentum are parsed (and end up
    # in the wandb config) but are not read by the code in this file; weight
    # decay comes from --weight_decay.

    rank, world_size, local_rank = init_distributed()
    main_process = is_main_process(rank)
    distributed = world_size > 1

    # Create directories if they don't exist
    os.makedirs("ckpt", exist_ok=True)

    # Set random seeds for reproducibility
    os.environ["PYTHONHASHSEED"] = str(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Default wandb run name. It used to overwrite --wandb_run_name
    # unconditionally; a user-supplied name is now respected.
    if not args.wandb_run_name:
        args.wandb_run_name = f"{args.optimizer}_{args.max_lr}_1k"
    run_name = args.wandb_run_name

    # get model and dataloader
    model, train_loader, val_loader, train_sampler, val_sampler = get_model_and_dataloader(
        args.model, args.dataset, args.hidden_size, args.batch_size, args.val_ratio, args.seed, rank=rank, world_size=world_size
    )
    # disable all the bias
    model = disable_all_biases(model, keep_ln_bias=False)
    if main_process and args.wandb_log:
        wandb.init(
            project=args.wandb_project,
            name=run_name,
            mode="online",
            config=vars(args),
        )
    # build optimizer
    optimizer = get_optimizer(args, model, main_process)

    # Setup device
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # if use ddp
    if distributed:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
    print(f"Using device: {device}")

    model.train()

    # Setup learning rate scheduler
    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * args.num_epochs
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_steps,
        num_cycles=0.5,
    )

    start_time = time.time()

    all_training_loss = []
    all_val_loss = []
    # Pre-bind the loop variables so the final-checkpoint code below does not
    # hit a NameError when no training step runs (zero epochs / empty loader).
    epoch = step = 0
    for epoch in range(args.num_epochs):
        print(f"Epoch {epoch + 1}/{args.num_epochs}")
        print(f"Total steps per epoch: {steps_per_epoch}")
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        if val_sampler is not None:
            val_sampler.set_epoch(epoch)
        for step, batch in enumerate(train_loader):
            global_step = epoch * steps_per_epoch + (step + 1)
            input_ids = batch.to(device)
            outputs = model(input_ids=input_ids, labels=input_ids)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Read the scalar once; every .item() call forces a device sync.
            loss_value = loss.item()
            all_training_loss.append(loss_value)

            # ====== Validate  ======
            val_loss = None
            if step % args.val_interval == 0:
                # Every rank runs validation; under DDP the sums are all-reduced
                # so the reported loss covers all shards, not just this rank's.
                val_loss, val_ppl, val_seen = validate(
                    model, val_loader, device,
                    max_batches=args.val_max_batches, distributed=distributed,
                )
                if main_process:
                    print(f"[VAL] Epoch {epoch} Step {step} | val_loss={val_loss:.4f} ppl={val_ppl:.2f} batches={val_seen}")

            # log the training stats
            if main_process and step % args.log_interval == 0:
                last_lr = lr_scheduler.get_last_lr()[0] if hasattr(lr_scheduler, "get_last_lr") else optimizer.param_groups[0]["lr"]
                print(f"Step {step}: loss {loss_value:.4f}")
                if args.wandb_log:
                    metrics = {"iter": global_step, "train/loss": loss_value}
                    if val_loss is not None:
                        metrics["val/loss"] = val_loss
                    metrics["train/lr"] = last_lr
                    wandb.log(metrics, step=global_step)

            # Save checkpoints
            is_last_step = step == steps_per_epoch - 1
            if main_process and (step % args.save_interval == 0 or is_last_step):
                date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
                os.makedirs(f"ckpt/{args.dataset}", exist_ok=True)
                ckpt_path = f"ckpt/{args.dataset}/{args.optimizer}_lr{args.max_lr}_step{step}_{date_time}.pt"
                torch.save((model.module if isinstance(model, DDP) else model).state_dict(), ckpt_path)
                print(f"Checkpoint saved at epoch {epoch + 1}, step {step}")

    if main_process:
        print("Training completed!")
        final_time = (time.time() - start_time) / 3600
        print(f"Total training time: {final_time:.2f} hours")
        date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        os.makedirs(f"ckpt/{args.dataset}", exist_ok=True)
        ckpt_path = f"ckpt/{args.dataset}/{args.optimizer}_lr{args.max_lr}_step{step}_{date_time}_final.pt"
        torch.save((model.module if isinstance(model, DDP) else model).state_dict(), ckpt_path)
        print(f"Checkpoint saved at epoch {epoch + 1}, step {step}")
        # Only close a wandb run if one was opened.
        if args.wandb_log:
            wandb.finish()
    if distributed:
        dist.barrier()
        dist.destroy_process_group()