# train_epng.py 
import argparse
import json
import yaml
import os
import random,  datetime, re, math
from dataclasses import dataclass, field
from typing import Dict, Tuple, List, Iterable, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, TrainerState, TrainerControl
from utils import get_formatted_input_and_target, get_examples_from_buffer_pad, load_hf_dataset_train, get_dataset_converter
import wandb
from peft import LoraConfig, get_peft_model

# Set TOKENIZERS_PARALLELISM to "false" for this process at import time.
# NOTE(review): presumably this silences the HF tokenizers fork/parallelism
# warning when data loading forks workers — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Name fragments that identify the expert projection modules LoRA is attached to.
# Only EXPERTS_PREFIX and PROJ_TARGETS are referenced by the code visible in this
# file; EXPERT_TARGET_FILTER and PROJ_REGEX are not used in the visible source.
EXPERT_TARGET_FILTER = ("experts.", "proj")
EXPERTS_PREFIX = "experts."
PROJ_TARGETS = ("up_proj", "gate_proj")
PROJ_REGEX = re.compile(r"experts\..*(?:up_proj|gate_proj)")

def _is_target_name(name: str) -> bool:
    """Return True when *name* contains the experts prefix and a projection target.

    Both checks are plain substring tests against ``EXPERTS_PREFIX`` and the
    entries of ``PROJ_TARGETS``.
    """
    if EXPERTS_PREFIX not in name:
        return False
    for target in PROJ_TARGETS:
        if target in name:
            return True
    return False

# Timestamp (local time, YYYYmmdd_HHMMSS) captured once at import; main() uses it
# as the run sub-directory under --output_dir and as the wandb run name.
TIME =  datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

def iter_target_linear_modules(model: nn.Module) -> Iterable[Tuple[str, nn.Module]]:
    """Yield ``(qualified_name, module)`` for each ``nn.Linear`` with a target name.

    Modules are visited in ``model.named_modules()`` order; a module is kept
    only if it is an ``nn.Linear`` and ``_is_target_name`` accepts its name.
    """
    yield from (
        (qualname, layer)
        for qualname, layer in model.named_modules()
        if isinstance(layer, nn.Linear) and _is_target_name(qualname)
    )

@dataclass
class PruneGrowConfig:
    """Hyper-parameters for the prune-and-grow LoRA schedule used in this file."""
    # Name of the initial LoRA adapter; its lora_A/lora_B tensors are the ones
    # zeroed by apply_prune and used as the "old" subspace when growing.
    base_adapter: str = "base_lora"
    # Passed as ``add_r`` (rank of each newly grown adapter) by prune_and_grow_step.
    r_base: int = 8
    # Passed as ``alpha_new`` (lora_alpha of each newly grown adapter).
    alpha_base: int = 16
    # Passed as ``dropout_new`` (lora_dropout of each newly grown adapter).
    lora_dropout: float = 0.05
    # Step interval between prune/grow cycles for the step-based schedule.
    prune_interval_steps: int = 500
    # Flag for the epoch-based schedule. NOTE(review): the callback visible in
    # this file reads its own constructor argument rather than this field.
    prune_interval_epoch: bool = False
    # Upper bound on prune/grow cycles; negative values are special-cased by the callback.
    max_prune_grow: int = 5
    # Fraction of the non-frozen scored modules pruned in each cycle.
    prune_frac: float = 0.2
    # Fraction of the non-frozen scored modules grown in each cycle.
    grow_frac: float = 0.2
    # When True, the [PRUNE] / [GROW] progress lines are printed.
    verbose: bool = True

@dataclass
class ExpertStats:
    """Mutable container for per-module routing statistics."""
    # Maps a module name (written by update_selection_freq with a
    # "base_model.model." prefix) to its accumulated selection weight.
    select_freq: Dict[str, float] = field(default_factory=dict)

class RouterTracer:
    """Converts router logits on a model output into expert module names.

    ``layer_prefix`` is the dotted path prepended to ``<layer>.mlp.experts.``
    when module names are built.
    """

    def __init__(self, layer_prefix: str = "model.layers"):
        self.layer_prefix = layer_prefix

    @torch.no_grad()
    def from_outputs(self, outputs, top_k: int = 8) -> Tuple[List[str], List[float]]:
        """Return parallel lists of routed module names and their gate probabilities.

        Args:
            outputs: Object that may carry a ``router_logits`` attribute, either
                a list/tuple of per-layer tensors or a single tensor that is
                unbound along dim 0 into per-layer tensors.
            top_k: Number of highest-logit experts kept per row (capped at the
                size of the last dimension).

        Returns:
            ``(names, probs)``. For every selected expert one entry per item in
            ``PROJ_TARGETS`` is emitted, each carrying the softmax probability
            of that expert. Both lists are empty when ``router_logits`` is
            missing or ``None``.
        """
        router_logits = getattr(outputs, "router_logits", None)
        if router_logits is None:
            return [], []

        if isinstance(router_logits, (list, tuple)):
            per_layer = list(router_logits)
        else:
            per_layer = list(router_logits.unbind(0))

        out_names: List[str] = []
        out_probs: List[float] = []
        for layer_idx, layer_logits in enumerate(per_layer):
            layer_logits = layer_logits.float()
            keep = min(top_k, layer_logits.shape[-1])
            top_vals, top_ids = torch.topk(layer_logits, k=keep, dim=-1)
            # softmax probability of the kept experts: exp(logit - logsumexp)
            log_norm = torch.logsumexp(layer_logits, dim=-1, keepdim=True)
            gate = (top_vals - log_norm).exp()
            flat_ids = top_ids.reshape(-1).cpu().tolist()
            flat_gate = gate.reshape(-1).cpu().tolist()
            base = f"{self.layer_prefix}.{layer_idx}.mlp.experts."
            for expert_id in flat_ids:
                for head in PROJ_TARGETS:
                    out_names.append(f"{base}{expert_id}.{head}")
            for prob in flat_gate:
                for _ in PROJ_TARGETS:
                    out_probs.append(prob)
        return out_names, out_probs

def update_selection_freq(stats: ExpertStats, routed_module_names: List[str], gate_probs: Optional[List[float]] = None, weight: float = 1.0, key_prefix: str = "base_model.model."):
    """Accumulate routing mass for each routed module into ``stats.select_freq``.

    Args:
        stats: Object exposing a mutable ``select_freq`` dict.
        routed_module_names: Module names, one entry per routing event.
        gate_probs: Optional probabilities aligned with ``routed_module_names``.
            When ``None`` every event contributes ``weight``; otherwise each
            event contributes ``float(p) * weight``.
        weight: Multiplier applied to every contribution.
        key_prefix: Prefix prepended to each name to form the dict key.

    The previous value is read from, and the new value written to, the same
    prefixed key, so repeated events actually accumulate.
    """
    freq = stats.select_freq
    if gate_probs is None:
        for name in routed_module_names:
            key = key_prefix + name
            freq[key] = freq.get(key, 0.0) + weight
    else:
        for name, p in zip(routed_module_names, gate_probs):
            key = key_prefix + name
            freq[key] = freq.get(key, 0.0) + float(p) * weight

def find_lora_layers(model: nn.Module) -> Iterable[Tuple[str, nn.Module]]:
    """Yield ``(qualified_name, module)`` for every sub-module with a target name.

    Unlike ``iter_target_linear_modules`` there is no type filter: any module
    whose name passes ``_is_target_name`` is yielded.
    """
    for qualname, submodule in model.named_modules():
        if not _is_target_name(qualname):
            continue
        yield qualname, submodule

def _normalize(values: Dict[str, float]) -> Dict[str, float]:
    if not values:
        return {}
    v = torch.tensor(list(values.values()), dtype=torch.float32)
    vmin, vmax = float(v.min()), float(v.max())
    if math.isclose(vmax, vmin):
        return {k: 0.0 for k in values}
    return {k: (val - vmin) / (vmax - vmin) for k, val in values.items()}

def compute_importance_scores(stats: ExpertStats, cfg: PruneGrowConfig, mode: str = "freq") -> Dict[str, float]:
    """Return a per-module importance score derived from selection frequency.

    The score is the min-max normalised ``stats.select_freq`` value. ``cfg``
    and ``mode`` are accepted for interface compatibility but are not
    consulted by this implementation.

    The result keeps the insertion order of ``stats.select_freq``. Iterating a
    ``set`` of the keys would make the order depend on string hashing, which
    changes how ties are broken by later stable sorts from run to run.
    """
    normalized = _normalize(stats.select_freq)
    return {name: float(score) for name, score in normalized.items()}

def select_prune_grow(scores: Dict[str, float], prune_frac: float, grow_frac: float, frozen: set[str]) -> Tuple[List[str], List[str]]:
    """Pick the modules to prune and the modules to grow for one cycle.

    Args:
        scores: Importance score per module name.
        prune_frac: Fraction of the non-frozen modules to prune (lowest scores).
        grow_frac: Fraction of the non-frozen modules to grow (highest scores).
        frozen: Names pruned in earlier cycles; they are excluded from ranking
            and always re-included in the returned prune list.

    Returns:
        ``(prune_list, grow_list)``. ``prune_list`` holds the newly pruned
        names in ascending score order followed by the frozen names in sorted
        order. ``grow_list`` holds the top-scoring names (ascending) that are
        not being pruned. Both are empty when there are no scores or every
        scored module is frozen.
    """
    if not scores:
        return [], []

    candidates = sorted(
        ((name, score) for name, score in scores.items() if name not in frozen),
        key=lambda item: item[1],
    )
    if not candidates:
        return [], []

    total = len(candidates)
    k_prune = max(0, int(total * prune_frac))
    k_grow = max(0, int(total * grow_frac))

    new_prunes = [name for name, _ in candidates[:k_prune]]
    # Set for O(1) membership tests; the returned list is built separately so
    # its order does not depend on set iteration order.
    pruned = set(new_prunes) | set(frozen)
    prune_list = new_prunes + sorted(set(frozen) - set(new_prunes))

    if k_grow > 0:
        grow_list = [name for name, _ in candidates[-k_grow:] if name not in pruned]
    else:
        grow_list = []

    return prune_list, grow_list
@torch.no_grad()
def apply_prune(model: nn.Module, cfg: PruneGrowConfig, prune_module_names: List[str]):
    """Zero and freeze the base adapter's LoRA tensors on the given modules.

    Args:
        model: Model searched with ``find_lora_layers``.
        cfg: Supplies ``base_adapter`` (which adapter to disable) and
            ``verbose`` (whether to print a summary line).
        prune_module_names: Exact module names to prune.

    For each matching module, every parameter whose name contains
    ``lora_A.<base_adapter>`` or ``lora_B.<base_adapter>`` is set to zero and
    has ``requires_grad`` turned off.
    """
    # Build the lookup set and name fragments once instead of per module/parameter.
    targets = set(prune_module_names)
    tag_a = f"lora_A.{cfg.base_adapter}"
    tag_b = f"lora_B.{cfg.base_adapter}"

    pruned = 0
    for name, module in find_lora_layers(model):
        if name not in targets:
            continue
        for pname, p in module.named_parameters(recurse=True):
            if tag_a not in pname and tag_b not in pname:
                continue
            p.data.zero_()
            p.requires_grad_(False)
            pruned += 1
    if cfg.verbose:
        # Each pruned module contributes one A and one B tensor, hence // 2.
        print(f"[PRUNE] disabled LoRA A/B pairs: {pruned//2}")

def _get_AB_for_adapter(mod: nn.Module, adapter: str):
    A = B = None
    for n, p in mod.named_parameters(recurse=True):
        if f"lora_A.{adapter}" in n: A = p
        if f"lora_B.{adapter}" in n: B = p
    return A, B

@torch.no_grad()
def _kaiming_init_all(A_w: torch.Tensor, B_w: torch.Tensor):
    nn.init.kaiming_uniform_(A_w, a=math.sqrt(5)); B_w.zero_()

@torch.no_grad()
def _orthogonalize_against_old(A_new, B_new, A_old, B_old):
    if A_old is not None and A_old.numel() > 0:
        Q, _ = torch.linalg.qr(A_old.T.to(torch.float32), mode='reduced')
        proj = (A_new.to(torch.float32) @ Q) @ Q.T
        A_new.sub_(proj.to(A_new.dtype))

    if B_old is not None and B_old.numel() > 0:
        Q, _ = torch.linalg.qr(B_old.to(torch.float32), mode='reduced')
        proj = Q @ (Q.T @ B_new.to(torch.float32))
        B_new.sub_(proj.to(B_new.dtype))
        
def grow_lora_rank_on_modules(model: nn.Module, adapter_name: str, grow_module_names: List[str], add_r: int, alpha_new: Optional[int] = None, dropout_new: float = 0.05, cycle:int=1):
    """Attach an extra LoRA adapter of rank ``add_r`` to the selected modules.

    Args:
        model: PEFT-wrapped model exposing ``add_adapter`` and
            ``base_model.set_adapter``.
        adapter_name: Name of the existing (base) adapter.
        grow_module_names: Full module names (as reported by
            ``model.named_modules()``) that receive the new adapter.
        add_r: Rank of the new adapter; must be positive.
        alpha_new: ``lora_alpha`` of the new adapter; defaults to ``add_r``.
        dropout_new: ``lora_dropout`` of the new adapter.
        cycle: 1-based cycle index; the new adapter is named ``lora_grow<cycle>``.

    Returns:
        ``(grow_adapter, r_old, r_old + add_r)`` where ``r_old`` is the row
        count of the base adapter's A matrix on the first selected module that
        has one.

    Raises:
        ValueError: If ``add_r`` is not positive.
        RuntimeError: If no selected module carries ``adapter_name``.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if add_r <= 0:
        raise ValueError(f"add_r must be positive, got {add_r}")

    # Set lookup: both loops below walk every module of the model.
    grow_set = set(grow_module_names)

    r_old = None
    for name, mod in model.named_modules():
        if name not in grow_set:
            continue
        A_old, _ = _get_AB_for_adapter(mod, adapter_name)
        if A_old is not None:
            r_old = A_old.shape[0]
            break
    if r_old is None:
        raise RuntimeError(f"Cannot find LoRA params for adapter '{adapter_name}'")
    if alpha_new is None:
        alpha_new = add_r

    grow_adapter = f"lora_grow{cycle}"
    # The new config targets names relative to the wrapped base model, so the
    # PEFT wrapper prefix is stripped.
    grow_module_names_lora = [n.replace("base_model.model.", "") for n in grow_module_names]
    lconf = LoraConfig(
        r=add_r, lora_alpha=alpha_new, lora_dropout=dropout_new,
        target_modules=grow_module_names_lora, bias="none", task_type="CAUSAL_LM"
    )
    model.add_adapter(grow_adapter, lconf)
    with torch.no_grad():
        for name, mod in model.named_modules():
            if name not in grow_set:
                continue
            A_new, B_new = _get_AB_for_adapter(mod, grow_adapter)
            if A_new is None or B_new is None:
                continue
            # Re-initialise, then remove the part already spanned by the base adapter.
            _kaiming_init_all(A_new, B_new)
            A_old, B_old = _get_AB_for_adapter(mod, adapter_name)
            _orthogonalize_against_old(A_new, B_new, A_old, B_old)

    # Activate the base adapter together with every grown adapter up to this cycle.
    adapters = [adapter_name] + [f"lora_grow{c}" for c in range(1, cycle+1)]
    model.base_model.set_adapter(adapters)
    return grow_adapter, r_old, r_old + add_r

def collect_new_adapter_params(model: nn.Module, adapter_name: str) -> List[torch.nn.Parameter]:
    """Return the trainable LoRA A/B parameters that belong to *adapter_name*.

    A parameter is selected when its name contains ``lora_A.<adapter_name>`` or
    ``lora_B.<adapter_name>`` followed by a dot or the end of the name. The
    boundary check stops ``lora_grow1`` from also collecting the parameters of
    ``lora_grow10``. Only ``nn.Parameter`` objects with ``requires_grad`` set
    are returned, each at most once, in ``named_parameters()`` order.
    """
    pattern = re.compile(rf"lora_[AB]\.{re.escape(adapter_name)}(?:\.|$)")

    params, seen = [], set()
    for n, p in model.named_parameters():
        if not pattern.search(n):
            continue
        if not (isinstance(p, torch.nn.Parameter) and p.requires_grad):
            continue
        pid = id(p)
        if pid in seen:
            continue
        seen.add(pid)
        params.append(p)
    return params

def prune_and_grow_step(model: nn.Module, cfg: PruneGrowConfig, stats: ExpertStats,
                        cycle: int, score_mode: str, frozen: set[str]):
    """Run one prune-and-grow cycle and return the names that were grown.

    Order of operations: score modules, select prune/grow sets, grow first,
    then prune (recording the pruned names in ``frozen``), and finally scale
    every accumulated selection frequency by 0.2 so later cycles favour recent
    routing statistics.
    """
    importance = compute_importance_scores(stats, cfg, mode=score_mode)
    to_prune, to_grow = select_prune_grow(importance, cfg.prune_frac, cfg.grow_frac, frozen)

    if to_grow:
        grown = grow_lora_rank_on_modules(
            model,
            cfg.base_adapter,
            to_grow,
            add_r=cfg.r_base,
            alpha_new=cfg.alpha_base,
            dropout_new=cfg.lora_dropout,
            cycle=cycle,
        )
        rank_before, rank_after = grown[1], grown[2]
        if cfg.verbose:
            print(f"[GROW] LoRA rank {rank_before} → {rank_after} on {len(to_grow)} modules")

    if to_prune:
        apply_prune(model, cfg, to_prune)
        frozen.update(to_prune)

    # Decay the running statistics in place (keys are unchanged while iterating).
    freq = stats.select_freq
    for key in freq:
        freq[key] *= 0.2
    return to_grow

class MoETrainer(Trainer):
    """``Trainer`` subclass that records router selections while computing the loss.

    When both ``tracer`` and ``expert_stats`` are supplied, every call to
    ``compute_loss`` feeds the model outputs through the tracer and adds the
    resulting module names / probabilities to the statistics.
    """

    def __init__(self, *args, tracer: Optional[RouterTracer] = None, expert_stats: Optional[ExpertStats] = None, router_topk: int = 8, **kwargs):
        super().__init__(*args, **kwargs)
        self.router_topk = router_topk
        self.expert_stats = expert_stats
        self.tracer = tracer

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch: int | None = None, **kwargs):
        """Forward the batch, update routing statistics and return the loss.

        The model is first called with ``output_router_logits=True``; if that
        raises ``TypeError`` the call is repeated without the flag. Returns
        ``(loss, outputs)`` when ``return_outputs`` is true, else ``loss``.
        ``num_items_in_batch`` and extra keyword arguments are accepted but
        not used here.
        """
        try:
            outputs = model(**inputs, output_router_logits=True)
        except TypeError:
            outputs = model(**inputs)

        if hasattr(outputs, "loss"):
            loss = outputs.loss
        else:
            loss = outputs[0]

        tracking_enabled = self.tracer is not None and self.expert_stats is not None
        if tracking_enabled:
            names, probs = self.tracer.from_outputs(outputs, top_k=self.router_topk)
            if names:
                update_selection_freq(self.expert_stats, names, probs)

        if return_outputs:
            return (loss, outputs)
        return loss

class PruneGrowCallback(TrainerCallback):
    """Trainer callback that runs prune-and-grow cycles during training.

    Cycles are triggered either every ``prune_interval_steps`` optimisation
    steps after ``warmup_steps`` (default) or at the start of each epoch after
    the first (``prune_interval_epoch=True``). Each cycle calls
    ``prune_and_grow_step`` and registers the newly created adapter
    parameters with the running optimizer.

    Args:
        cfg: Prune/grow hyper-parameters.
        stats: Shared routing statistics that drive module selection.
        score_mode: Forwarded to ``prune_and_grow_step``.
        warmup_steps: No cycle runs before this global step.
        prune_interval_steps: Step interval between cycles (step mode).
        prune_interval_epoch: Select epoch mode instead of step mode.
        use_wandb: When False, nothing is logged to wandb by this callback.
    """

    def __init__(self, cfg: PruneGrowConfig, stats: ExpertStats, score_mode: str = "freq", warmup_steps: int = 200, prune_interval_steps: int = 500, prune_interval_epoch: bool = False, use_wandb: bool = True):
        self.cfg = cfg
        self.stats = stats
        self.score_mode = score_mode
        self.warmup_steps = warmup_steps
        self.prune_interval_steps = prune_interval_steps
        self.prune_interval_epoch = prune_interval_epoch
        self.use_wandb = use_wandb
        # Names pruned so far; shared with prune_and_grow_step across cycles.
        self.frozen_modules: set[str] = set()

    def _grow_new_group(self, model, optimizer, lr_scheduler, cycle:int, step:int):
        """Run one cycle and add the new adapter's parameters to the optimizer."""
        prune_and_grow_step(model, self.cfg, self.stats, cycle=cycle,
                            score_mode=self.score_mode, frozen=self.frozen_modules)
        # Best effort: prefer the scheduler's current LR, else the first group's LR.
        try:
            current_lrs = lr_scheduler.get_last_lr()
            current_lr = current_lrs[-1] if current_lrs else optimizer.param_groups[0]["lr"]
        except Exception:
            current_lr = optimizer.param_groups[0]["lr"]
        new_params = collect_new_adapter_params(model, f"lora_grow{cycle}")
        if new_params:
            # Mirror the hyper-parameters of the first param group for the new one.
            optimizer.add_param_group({
                "params": new_params,
                "lr": current_lr,
                "betas": optimizer.param_groups[0].get("betas", (0.9, 0.95)),
                "eps": optimizer.param_groups[0].get("eps", 1e-8),
                "weight_decay": optimizer.param_groups[0].get("weight_decay", 0.0),
            })
            if hasattr(lr_scheduler, "base_lrs"):
                lr_scheduler.base_lrs.append(current_lr)
            optimizer.param_groups[-1]["initial_lr"] = current_lr
        model.print_trainable_parameters()

        # Honour the constructor flag: skip all wandb traffic when disabled.
        if not self.use_wandb:
            return
        tp = sum(p.numel() for p in model.parameters() if p.requires_grad)
        sf = self.stats.select_freq
        wandb.log({
            "expert/num_selected_modules": len(sf),
            "expert/avg_selection_freq": sum(sf.values())/len(sf) if sf else 0.0,
            "trainable_params": tp,
            "prune_grow/step": step,
            "cycle": cycle,
        }, step=step)

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Trigger a cycle on step boundaries when running in step mode."""
        if self.prune_interval_epoch:
            return control
        step = state.global_step
        # The same interval (this callback's own) is used for the window bound,
        # the trigger test and the cycle index.
        interval = self.prune_interval_steps
        past_window = step - self.warmup_steps > interval * self.cfg.max_prune_grow
        if step < self.warmup_steps or past_window:
            return control
        if step % interval == 0:
            cycle = int((step - self.warmup_steps) // interval) + 1
            # NOTE(review): step mode uses ``cycle < max_prune_grow`` while epoch
            # mode uses ``epoch <= max_prune_grow``; also, a negative
            # max_prune_grow never reaches this line because the window check
            # above already returned. Confirm which semantics are intended.
            if self.cfg.max_prune_grow < 0 or cycle < self.cfg.max_prune_grow:
                self._grow_new_group(kwargs["model"], kwargs["optimizer"], kwargs["lr_scheduler"], cycle, step)
        return control

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Trigger a cycle at the start of each epoch after the first (epoch mode)."""
        if not self.prune_interval_epoch:
            return control
        epoch = int(state.epoch or 0)
        if epoch <= 0:
            return control
        if self.cfg.max_prune_grow < 0 or epoch <= self.cfg.max_prune_grow:
            self._grow_new_group(kwargs["model"], kwargs["optimizer"], kwargs["lr_scheduler"], epoch, state.global_step)
        return control

def _read_jsonl(path: str) -> list:
    """Parse a JSON-lines file into a list of objects, skipping blank lines."""
    with open(path, encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


def main():
    """Command-line entry point: load data and model, attach LoRA, and train.

    Steps: parse arguments, read the YAML training config, seed the RNGs,
    initialise wandb, load and tokenise the dataset, build the
    ``TrainingArguments``, wrap the base model with a LoRA adapter on the
    expert projection layers, and run ``MoETrainer`` with the prune-and-grow
    callback. The model and tokenizer are saved under the timestamped output
    directory at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_path", type=str, required=True)
    parser.add_argument("--train_dataset", type=str, required=True, help="HF dataset name or local jsonl")
    parser.add_argument("--dataset_subset", type=str, default=None)
    parser.add_argument("--dataset_split", type=str, default='train')
    parser.add_argument("--output_dir", type=str, default="results/checkpoints")
    parser.add_argument("--train_config", type=str, required=True)
    parser.add_argument("--gpu_id", type=int, default=2)

    parser.add_argument("--prune_interval_steps", type=int, default=50)
    parser.add_argument("--prune_interval_epoch", action="store_true")
    parser.add_argument("--max_prune_grow", type=int, default=5, help="-1 means only using initial LoRAs")
    parser.add_argument("--png_warmup_steps", type=int, default=100)
    parser.add_argument("--prune_frac", type=float, default=0.2)
    parser.add_argument("--grow_frac", type=float, default=0.2)
    parser.add_argument("--rank", type=int, default=8)
    parser.add_argument("--score_mode", type=str, default="freq", choices=["freq","grad","weight"])
    parser.add_argument("--router_topk", type=int, default=8)
    parser.add_argument("--layer_prefix", type=str, default="model.layers")

    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    output_dir = args.output_dir + f"/{TIME}"
    base_model_path = args.base_model_path
    # Close the config file deterministically instead of leaking the handle.
    with open(args.train_config) as config_file:
        config = yaml.safe_load(config_file)
    os.makedirs(args.output_dir, exist_ok=True)

    seed = config['seed']
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    wandb.init(project="olmoe-experiments"+"-"+args.train_dataset, config=vars(args)|config, name=TIME)

    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    # Three data sources: a local jsonl (already in message format), the
    # "prefeval" jsonl that still needs conversion, or a Hugging Face dataset.
    if args.train_dataset.endswith('.jsonl') or os.path.exists(f"datasets/train/{args.train_dataset}.jsonl"):
        file_path = args.train_dataset if args.train_dataset.endswith('.jsonl') else f"datasets/train/{args.train_dataset}.jsonl"
        samples = _read_jsonl(file_path)
    elif args.train_dataset == "prefeval":
        file_path = "../prefeval/train_data.jsonl"
        dataset = _read_jsonl(file_path)
        converter = get_dataset_converter(args.train_dataset.split('/')[-1])
        samples = [converter(example) for example in dataset]
    else:
        dataset = load_hf_dataset_train(args.train_dataset, args.dataset_subset, args.dataset_split)
        converter = get_dataset_converter(args.train_dataset.split('/')[-1])
        samples = [converter(example) for example in dataset]
        print(f"Loaded {len(samples)} samples from {args.train_dataset}")
        if samples:
            print("Sample data structure:")
            print(json.dumps(samples[0], indent=2, ensure_ascii=False))

    buffer = []
    for instance in samples:
        if 'messages' not in instance:
            print(f"Warning: 'messages' key not found in instance: {instance}")
            continue
        input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100)
        buffer.append((input_ids, target_ids))
    print(f"Processed {len(buffer)} examples")

    seq_length = config['seq_length']
    random_concat_ratio = config['random_concat_ratio']
    concated_examples = get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio)
    dataset = TensorDataset(concated_examples['input_ids'], concated_examples['labels'])
    # 98% / 2% train-validation split.
    n_train = int(len(dataset) * 0.98)
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [n_train, len(dataset) - n_train])

    training_args = TrainingArguments(
        output_dir=output_dir,
        max_steps=config['steps'],
        per_device_train_batch_size=config['per_device_batch_size'],
        per_device_eval_batch_size=config['per_device_batch_size'],
        warmup_steps=config['warmup_steps'],
        weight_decay=config['weight_decay'],
        logging_dir=f"{output_dir}/logs",
        logging_steps=config['logging_steps'],
        save_steps=config['save_steps'],
        eval_strategy="steps",
        eval_steps=config['eval_steps'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        greater_is_better=False,
        bf16=True,
        report_to="wandb",
        lr_scheduler_type='constant',
        save_total_limit=5,
        learning_rate=config['learning_rate'],
        optim=config['optim'],
        adam_beta1=config['adam_beta1'],
        adam_beta2=config['adam_beta2'],
        gradient_checkpointing=config['gradient_checkpointing'],
        gradient_checkpointing_kwargs={"use_reentrant": False} if config['gradient_checkpointing'] else {},
    )

    def data_collator(data):
        """Stack (input_ids, labels) pairs from the TensorDataset into a batch dict."""
        input_ids = torch.stack([item[0] for item in data])
        labels = torch.stack([item[1] for item in data])
        return {"input_ids": input_ids, "labels": labels}

    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        device_map=None,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        attn_implementation="sdpa"
    )

    # Attach the initial LoRA adapter to every expert up/gate projection.
    target_names = [name for name, _ in iter_target_linear_modules(model)]
    lconf = LoraConfig(
        r=args.rank, lora_alpha=16, lora_dropout=0.05,
        target_modules=target_names, bias="none", task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, lconf, adapter_name="base_lora")
    model.print_trainable_parameters()

    cfg = PruneGrowConfig(
        base_adapter="base_lora",
        r_base=args.rank, alpha_base=16, lora_dropout=0.05,
        prune_interval_steps=args.prune_interval_steps,
        prune_interval_epoch=args.prune_interval_epoch,
        max_prune_grow=args.max_prune_grow,
        prune_frac=args.prune_frac,
        grow_frac=args.grow_frac,
        verbose=True,
    )
    stats = ExpertStats()
    tracer = RouterTracer(layer_prefix=args.layer_prefix)

    trainer = MoETrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=data_collator,
        tracer=tracer,
        expert_stats=stats,
        router_topk=args.router_topk,
        callbacks=[PruneGrowCallback(
            cfg=cfg, stats=stats, score_mode=args.score_mode,
            warmup_steps=args.png_warmup_steps,
            prune_interval_steps=args.prune_interval_steps,
            prune_interval_epoch=args.prune_interval_epoch,
            use_wandb=True
        )]
    )

    # Resume only if the (timestamped) output directory already holds checkpoints.
    if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 1:
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    # NOTE(review): the model and the tokenizer are written to two different
    # sub-directories ("esft_checkpoint" vs "esft_last_checkpoint"); confirm
    # this is intended before loading the checkpoint from a single path.
    trainer.save_model(output_dir + "/esft_checkpoint")
    tokenizer.save_pretrained(output_dir + "/esft_last_checkpoint")
    wandb.finish()
    print("Training complete")

# Run training only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()