
import os
import sys
import argparse
import torch.jit
from tqdm import tqdm
import torch
import torch.nn as nn
from datetime import datetime
from alternating_least_squares import *
from eval import *

import torch, os
from torch.utils.data import DataLoader
from distillation.kd import DistillConfig, Distiller

from utils.model_utils import *
from utils.data_utils import *
from component.svd_llama import SVD_LlamaAttention, SVD_LlamaMLP
# Allow TensorFloat32 (TF32) for CUDA matmuls and cuDNN convolutions.
# TF32 is only used on GPUs that support it (Ampere or newer); it trades a small
# amount of float32 precision for faster matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True 
torch.backends.cudnn.allow_tf32 = True

class _CalibrationCaptured(Exception):
    """Raised by the first-layer catcher to abort a forward pass once its input is stored."""


def _spectral_factor_from_gram(gram, dev, scale=256):
    """
    Turn an accumulated Gram matrix (sum of XᵀX) into a spectral factor.

    Computes the eigen-decomposition ``gram / scale = Q diag(λ) Qᵀ`` in float64 on
    ``dev`` and returns ``Q @ diag(sqrt(λ))``.

    Args:
        gram: Square symmetric tensor holding the accumulated XᵀX.
        dev: Device on which the decomposition is run.
        scale: Divisor applied to ``gram`` before the decomposition.
            NOTE(review): 256 equals the default ``--nsamples``; confirm whether it
            should follow the actual number of calibration samples.

    Returns:
        float64 CPU tensor with the same shape as ``gram``.
    """
    gram = gram.double().to(dev)
    try:
        eigenvalues, eigenvectors = torch.linalg.eigh(gram / scale)
    except RuntimeError:
        # eigh itself failed: retry on a diagonally shifted matrix whose smallest
        # eigenvalue is pushed to a small positive value.
        smallest_raw = torch.linalg.eigvalsh(gram)[0]
        shift = (-smallest_raw + 1e-6) * torch.eye(gram.shape[0], dtype=gram.dtype, device=gram.device)
        gram = gram + shift
        eigenvalues, eigenvectors = torch.linalg.eigh(gram / scale)
    # eigh does not raise when round-off makes eigenvalues slightly negative, and
    # sqrt would then silently produce NaNs. eigh returns eigenvalues in ascending
    # order, so checking the first one is enough. Apply the same spectrum shift
    # the fallback above uses (eigenvectors are unaffected by a diagonal shift).
    if eigenvalues[0] < 0:
        eigenvalues = eigenvalues - eigenvalues[0] + 1e-6 / scale
    return torch.matmul(eigenvectors, torch.diag(torch.sqrt(eigenvalues))).cpu()


@torch.no_grad()
def SIMT(model_name, model, calib_loader, dev):
    """
    Spectral-Informed Metric Transformation (SIMT).

    Runs the calibration batches through the model one decoder layer at a time,
    accumulates the input Gram matrix XᵀX of every submodule returned by
    ``find_layers`` and converts each one into a spectral factor
    ``Q @ diag(sqrt(λ))`` (see ``_spectral_factor_from_gram``). The factors are
    later consumed by the ALS compression.

    Args:
        model_name: Unused; kept so existing callers keep working.
        model: Model exposing ``model.model.layers``, ``embed_tokens``, ``norm``,
            ``seqlen`` and ``config.hidden_size``.
        calib_loader: Iterable of dict batches passed as ``model(**batch)``. The
            activation buffer holds one slot per batch, so each batch is expected
            to contain a single sequence of length ``model.seqlen``.
        dev: Device used for the forward passes and the eigen-decompositions.

    Returns:
        dict mapping layer index -> {submodule name -> float64 CPU spectral factor}.

    Raises:
        RuntimeError: If no calibration batch reached the first decoder layer.
    """
    layers = model.model.layers
    # Move embedding and norm to device for profiling
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    layers[0] = layers[0].to(dev)

    # Buffer for the inputs of the current layer: one slot per calibration batch
    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (len(calib_loader), model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    captured = {"count": 0, "attention_mask": [], "position_ids": []}

    # Wrapper that records what the first decoder layer receives and then aborts
    # the forward pass, so the rest of the model is never executed.
    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[captured["count"]] = inp
            captured["count"] += 1
            captured["attention_mask"].append(kwargs['attention_mask'].cpu())
            captured["position_ids"].append(kwargs['position_ids'].cpu())
            raise _CalibrationCaptured

    # Replace first layer temporarily with Catcher; always restore it afterwards
    layers[0] = Catcher(layers[0])
    try:
        for batch in calib_loader:
            batch = {k: v.to(dev) for k, v in batch.items()}
            try:
                model(**batch)
            except _CalibrationCaptured:
                pass  # Expected: input captured. Any other error propagates.
    finally:
        layers[0] = layers[0].module
    layers[0] = layers[0].cpu()
    # Move embedding and norm back to CPU
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()

    if captured["count"] == 0:
        raise RuntimeError("SIMT: no calibration batch reached the first decoder layer")
    # Concatenate once instead of growing the tensors batch by batch
    attention_masks = torch.cat(captured["attention_mask"], dim=0)
    position_ids = torch.cat(captured["position_ids"], dim=0)

    # Prepare output buffer
    outs = torch.zeros_like(inps)
    spectral_factor = {}

    # Forward hook: accumulate XᵀX of the hooked module's input in float32
    def hook(module, inputs, output):
        inp = inputs[0].detach().float()
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        adds = torch.matmul(inp.transpose(1, 2), inp)  # XᵀX per batch element
        adds_sum = torch.sum(adds, dim=0)
        module.scaling_diag_matrix += adds_sum
        # Free memory
        del inp, adds, adds_sum, output
        torch.cuda.empty_cache()

    # Profile each transformer layer
    for i in tqdm(range(len(layers))):
        layer_profile = {}
        layer = layers[i].to(dev)
        subset = find_layers(layer)
        # Register hooks on target submodules
        handles = []
        for name in subset:
            subset[name].scaling_diag_matrix = 0
            handles.append(subset[name].register_forward_hook(hook))
        # Run calibration samples through the layer
        for j in range(inps.shape[0]):
            outs[j] = layer(
                inps[j].unsqueeze(0),
                attention_mask=attention_masks[j].unsqueeze(0).to(dev),
                position_ids=position_ids[j].unsqueeze(0).to(dev),
            )[0]
        # Remove hooks
        for h in handles:
            h.remove()
        layer = layer.cpu()
        # Park the accumulated matrices on the CPU so that only one of them sits
        # on the device at a time during the decompositions below
        for name in subset:
            subset[name].scaling_diag_matrix = subset[name].scaling_diag_matrix.cpu()
        torch.cuda.empty_cache()
        # Eigen-decomposition of covariance to obtain spectral factor
        for name in subset:
            layer_profile[name] = _spectral_factor_from_gram(subset[name].scaling_diag_matrix, dev)
            # Drop the accumulator from the module; otherwise every profiled
            # submodule keeps an (in_features x in_features) matrix alive.
            del subset[name].scaling_diag_matrix
            torch.cuda.empty_cache()
        layers[i] = layer.cpu()
        spectral_factor[i] = layer_profile
        # Propagate to next layer. From here on inps and outs share storage; this
        # is safe because sample j is fully read before outs[j] is written.
        inps = outs
        torch.cuda.empty_cache()
    return spectral_factor

@torch.no_grad()
def Activation_aware_ALS(model_name, model, spectral_factor, ratio, dev, tau = 0.003, rho = 0.003, max_iter = 3):
    """
    Activation-aware Alternating Least Squares (ALS).

    For every decoder layer, builds fresh ``SVD_LlamaAttention`` / ``SVD_LlamaMLP``
    modules, factorizes each submodule returned by ``find_layers`` with
    ``alternating_least_squares`` (guided by the SIMT ``spectral_factor``) and
    stores the resulting U / V factors in the matching ``*_u_proj`` /
    ``*_v_proj`` weights. The layer's ``self_attn`` is swapped when ``o_proj`` is
    processed and its ``mlp`` when ``up_proj`` is processed. The model is
    modified in place; nothing is returned.
    """
    model.eval()

    decoder_layers = model.model.layers
    print("Start Activation-aware ALS after SIMT...")
    for layer_idx in tqdm(range(len(decoder_layers))):
        block = decoder_layers[layer_idx]
        linears = find_layers(block)
        # Fresh low-rank replacements for this layer
        low_rank_attn = SVD_LlamaAttention(config=model.config, ratio=ratio)
        low_rank_mlp = SVD_LlamaMLP(
            hidden_size=block.hidden_size,
            intermediate_size=model.config.intermediate_size,
            hidden_act=model.config.hidden_act,
            ratio=ratio,
        )
        # (name fragment, receiving module, weight prefix, layer attribute to swap).
        # Entries are tried in order and the first fragment found in the name wins.
        routing = (
            ("q_proj", low_rank_attn, "q", None),
            ("k_proj", low_rank_attn, "k", None),
            ("v_proj", low_rank_attn, "v", None),
            ("o_proj", low_rank_attn, "o", "self_attn"),
            ("gate_proj", low_rank_mlp, "gate", None),
            ("down_proj", low_rank_mlp, "down", None),
            ("up_proj", low_rank_mlp, "up", "mlp"),
        )

        for proj_name in linears:
            # Low-rank factors U, V for this projection
            factor_u, factor_v = alternating_least_squares(
                linears, proj_name, dev, spectral_factor, layer_idx, ratio, max_iter, tau, rho
            )
            for fragment, receiver, prefix, swap_attr in routing:
                if fragment not in proj_name:
                    continue
                getattr(receiver, prefix + "_u_proj").weight.data = factor_u
                getattr(receiver, prefix + "_v_proj").weight.data = factor_v
                if swap_attr is not None:
                    # Install the low-rank module on the decoder layer
                    setattr(block, swap_attr, receiver)
                break
        # Free resources
        del block
        torch.cuda.empty_cache()


if __name__ == '__main__':
    # Command-line driver. --step selects the stage to run:
    #   1 = SIMT profiling + activation-aware ALS compression,
    #   2 = distillation of the compressed student against the original model,
    #   3 = perplexity / accuracy evaluation of a locally saved model.
    parser = argparse.ArgumentParser()
    # -------------------- General settings --------------------
    parser.add_argument('--model', type=str, default='jeffwan/llama-7b-hf', help='Name or path of the pretrained LLaMA model to load (e.g., `jeffwan/llama-7b-hf`).')
    parser.add_argument('--model_path', type=str, default=None, help='Path to a locally saved compressed model or whitening statistics.')
    # '%' must be written as '%%' in argparse help strings, otherwise --help crashes
    parser.add_argument('--ratio', type=float, default=0.6, help='Target compression ratio (e.g., 0.6 means 60%% of parameters are removed).')
    parser.add_argument('--dataset', type=str, default='wikitext2', help='Dataset used to extract calibration samples.')
    parser.add_argument('--nsamples', type=int, default=256, help='Number of calibration samples to collect.')
    parser.add_argument('--save_path', type=str, default=None, help='Path to save the compressed model checkpoints.')
    parser.add_argument('--spectral_factor_path', type=str, default=None, help='Path to load precomputed profiling matrices (e.g., activation covariance).')
    parser.add_argument('--seed', type=int, default=3, help='Random seed for calibration data sampling.')
    parser.add_argument('--DEV', type=str, default="cuda", help='Device to use for computation (e.g., "cuda" or "cpu").')
    parser.add_argument('--model_seq_len', type=int, default=2048, help='Default maximum sequence length of the LLM.')
    parser.add_argument('--eval_batch_size', type=int, default=32, help='Batch size for evaluation during inference.')
    parser.add_argument('--step', type=int, default=4, help='Compression step index to execute.')

    # -------------------- ALS hyperparameters --------------------
    parser.add_argument('--tau', type=float, default=0.003, help='Ridge regularization coefficient in ALS.')
    parser.add_argument('--rho', type=float, default=0.003, help='Proximal regularization coefficient in ALS.')
    parser.add_argument('--iter', type=int, default=3, help='Number of ALS iterations to perform.')

    # -------------------- Distillation hyperparameters --------------------
    parser.add_argument('--student_ckpt', type=str, default=None, help='compressed model')
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate for updating low-rank matrices.')
    parser.add_argument('--epochs', type=int, default=50)

    parser.add_argument('--kd_loader_batch_size', type=int, default=3, help='Batch size for the knowledge distillation dataloader.')
    parser.add_argument('--kd_T', type=float, default=2.0, help='Temperature T for knowledge distillation.')
    parser.add_argument('--kd_w_qkv', type=float, default=1.0, help='Weight for Q/K/V feature distillation.')
    parser.add_argument('--kd_w_att', type=float, default=0.25, help='Weight for attention probability distillation.')
    parser.add_argument('--kd_w_head', type=float, default=0.5, help='Weight for attention head output distillation.')
    parser.add_argument('--kd_w_mha', type=float, default=1.0, help='Weight for multi-head attention output distillation.')
    parser.add_argument('--kd_w_ff1', type=float, default=0.5, help='Weight for feed-forward intermediate representation distillation.')
    parser.add_argument('--kd_w_ff2', type=float, default=1.0, help='Weight for feed-forward output representation distillation.')
    parser.add_argument('--kd_use_kl',action='store_true',default=True,help='Use KL divergence for KD loss (default: True, otherwise MSE).')
    # store_true with default=True can never be switched off; this flag provides the off switch
    parser.add_argument('--no_kd_use_kl', dest='kd_use_kl', action='store_false', help='Use MSE instead of KL divergence for the KD loss.')
    parser.add_argument('--kd_stride', type=int, default=4, help='If kd_layers is not explicitly specified, select layers at intervals (0, stride, 2*stride, ...).')

    # Parse all arguments
    args = parser.parse_args()
    # Keep the user-supplied "removed ratio" for file names: recomputing it as
    # 1 - (1 - ratio) introduces floating-point noise (e.g. 0.19999999999999996).
    removed_ratio = args.ratio
    # Convert "removed ratio" convention to "retained ratio" convention
    args.ratio = 1 - args.ratio

    # -------------------- Step 1: Compression with ALS --------------------
    if args.step == 1:
        model_tag = args.model.replace("/", "_").replace("-", "_")
        # Create the output directory up front so a missing folder does not
        # waste the profiling / compression work when saving fails
        if args.save_path is not None:
            os.makedirs(args.save_path, exist_ok=True)

        # Load pretrained model and tokenizer
        model, tokenizer = get_model_from_huggingface_llama(model_id=args.model)
        model = model.eval()

        # ---------- SIMT profiling ----------
        if args.spectral_factor_path is None:
            # Collect calibration samples for spectral factor estimation; the seed
            # is passed explicitly so it matches the one embedded in the file name
            cali_white_data = get_calib_train_data(
                args.dataset, tokenizer, args.nsamples,
                seqlen=args.model_seq_len, seed=args.seed
            )
            # Compute spectral-informed metric transformation
            spectral_factor = SIMT(args.model, model, cali_white_data, args.DEV)
            # Save profiling result if a path is given
            if args.save_path is not None:
                profile_name = f"{model_tag}_profiling_{args.dataset}_{args.nsamples}_{args.seed}eigdecomposition.pt"
                torch.save(spectral_factor, os.path.join(args.save_path, profile_name))
        else:
            # Load precomputed spectral factor
            spectral_factor = torch.load(args.spectral_factor_path)

        # ---------- Activation-aware ALS compression ----------
        Activation_aware_ALS(args.model, model, spectral_factor, args.ratio, args.DEV, tau = args.tau, rho = args.rho, max_iter = args.iter)

        # Save compressed model checkpoint if path is provided
        if args.save_path is not None:
            ckpt_name = f"{model_tag}_ratio{removed_ratio}_tau{args.tau}_rho{args.rho}_iter{args.iter}.pt"
            torch.save({'model': model, 'tokenizer': tokenizer}, os.path.join(args.save_path, ckpt_name))

    # -------------------- Step 2: UM-MOD --------------------
    elif args.step == 2:
        if args.student_ckpt is None:
            parser.error("--student_ckpt is required for --step 2")

        # Load teacher model (full-precision pretrained model)
        teacher, teacher_tokenizer = get_model_from_huggingface_llama(model_id=args.model)
        teacher = teacher.to(args.DEV).eval()

        # Load student model from checkpoint.
        # NOTE(security): the checkpoint is a pickled Python object; only load
        # files from trusted sources.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which rejects pickled modules — confirm whether
        # weights_only=False has to be passed for the supported torch version.
        ckpt = torch.load(args.student_ckpt, map_location="cpu")
        student = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
        student = student.to(args.DEV).train()

        # Freeze all parameters by default
        student.requires_grad_(False)
        trainable_names, trainable_params = [], []
        # Only enable gradient for low-rank matrices ("u_proj" and "v_proj")
        for n, p in student.named_parameters():
            if ("u_proj" in n) or ("v_proj" in n):
                p.requires_grad_(True)
                trainable_names.append(n)
                trainable_params.append(p)
        print(f"[Distill] Trainable params: {len(trainable_params)} tensors")

        # Enable gradient checkpointing to reduce memory usage
        if hasattr(student, "gradient_checkpointing_enable"):
            student.gradient_checkpointing_enable()
        # Disable cache to allow gradient computation across layers
        if hasattr(student, "config") and hasattr(student.config, "use_cache"):
            student.config.use_cache = False
        # Ensure embeddings allow gradient flow
        if hasattr(student, "enable_input_require_grads"):
            student.enable_input_require_grads()
        # Prefer the tokenizer stored with the checkpoint; fall back to the
        # teacher's tokenizer instead of passing None to the data pipeline
        tokenizer = ckpt["tokenizer"] if isinstance(ckpt, dict) and "tokenizer" in ckpt else None
        if tokenizer is None:
            tokenizer = teacher_tokenizer
        # Build calibration dataset for training
        trainset = get_calib_train_data(
            args.dataset, tokenizer, args.nsamples,
            seqlen=args.model_seq_len, seed=args.seed, batch_size=1
        )
        # Deterministic data loading
        g = torch.Generator()
        g.manual_seed(args.seed)
        loader = DataLoader(trainset, batch_size=args.kd_loader_batch_size, shuffle=True, generator=g)
        # Select distillation layers (every kd_stride layers)
        kd_layers = list(range(0, student.config.num_hidden_layers, args.kd_stride))
        kd_cfg = DistillConfig(layers=kd_layers,T=args.kd_T,w_qkv=args.kd_w_qkv,w_att=args.kd_w_att,w_head=args.kd_w_head,w_mha=args.kd_w_mha,w_ff1=args.kd_w_ff1,w_ff2=args.kd_w_ff2,use_kl=args.kd_use_kl)
        # Initialize distiller with teacher and student
        distiller = Distiller(teacher, student, args, kd_cfg=kd_cfg)
        # Learnable uncertainty-based weights (log variances for CE, KD, feature loss)
        loss_log_vars = torch.nn.Parameter(torch.zeros(3, device=args.DEV))
        # Optimizer: AdamW for low-rank params and uncertainty weights
        opt = torch.optim.AdamW(
            [
                {"params": trainable_params, "lr": args.lr, "betas": (0.9, 0.95), "weight_decay": 0.01},
                {"params": [loss_log_vars],   "lr": 1e-2,   "weight_decay": 0.0},
            ]
        )

        student.train()
        use_autocast = args.DEV.startswith("cuda")
        step = 0
        log_every = getattr(args, "log_every", 20)
        # ---------- Training loop ----------
        for epoch in range(args.epochs):
            for batch in loader:
                # Move batch to device
                batch = {k: v.to(args.DEV) for k, v in batch.items()}

                # Forward pass (with autocast for efficiency if CUDA)
                if use_autocast:
                    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
                        losses = distiller(batch)
                else:
                    losses = distiller(batch)
                # Extract losses (default to zero if missing)
                ce   = losses.get("ce",   torch.tensor(0.0, device=args.DEV)).float()
                kd   = losses.get("kd",   torch.tensor(0.0, device=args.DEV)).float()
                feat = losses.get("feat", torch.tensor(0.0, device=args.DEV)).float()
                # Compute effective precision weights = exp(-log_vars)
                precision = torch.exp(-loss_log_vars)
                # Uncertainty-weighted objective: Σ (precision_i * loss_i) + Σ log_vars
                loss = (precision[0]*ce + precision[1]*kd + precision[2]*feat) + loss_log_vars.sum()
                # Backpropagation
                opt.zero_grad(set_to_none=True)
                loss.backward()
                # Gradient sanity check: report the first parameter with NaN/Inf gradients
                bad_grad = False
                for n, p in zip(trainable_names, trainable_params):
                    if p.grad is not None and not torch.isfinite(p.grad).all():
                        print("grad NaN/Inf at", n, p.shape, p.dtype)
                        bad_grad = True
                        break
                if loss_log_vars.grad is not None and not torch.isfinite(loss_log_vars.grad).all():
                    print("grad NaN/Inf at loss_log_vars")
                    bad_grad = True
                # Clip gradients to avoid explosion
                total_norm = torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                if not torch.isfinite(total_norm):
                    print("total_grad_norm =", total_norm)
                    bad_grad = True

                step += 1
                if bad_grad:
                    # Stepping with non-finite gradients would write NaN/Inf into
                    # the weights and the optimizer state; drop this batch instead
                    print(f"[{step}] non-finite gradients, skipping optimizer update")
                    opt.zero_grad(set_to_none=True)
                    continue

                # Update parameters
                opt.step()
                # Clamp log_vars to a reasonable range
                with torch.no_grad():
                    loss_log_vars.data.clamp_(min=-6.0, max=6.0)

                # Logging
                if step % log_every == 0:
                    w_eff = precision.detach().cpu().tolist()
                    print(f"[{step}] total={loss.item():.4f} | ce={float(ce):.4f} | kd={float(kd):.4f} | feat={float(feat):.4f} | "
                        f"w_eff≈[{w_eff[0]:.3g},{w_eff[1]:.3g},{w_eff[2]:.3g}]")
            # Periodically save intermediate checkpoint (epochs 0, 5, 10, ...)
            if epoch % 5 == 0:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                run_dir = os.path.join("runs", f"student_distilled_{timestamp}_ep{epoch:02d}")
                os.makedirs(run_dir, exist_ok=True)

                save_name = "student_distilled.pt"
                save_path = os.path.join(run_dir, save_name)
                save_student_pt(student, tokenizer, save_path)
                print(f"[save] student checkpoint saved to: {save_path}")
        # Save final distilled student model
        save_student_checkpoint(student, tokenizer, args, distiller)
    elif args.step == 3:
        if args.model_path is None:
            parser.error("--model_path is required for --step 3")
        print(f"evaluating {args.model_path}...")
        # Load compressed/distilled model from local path
        model, tokenizer = get_model_from_local(args.model_path)
        model.eval()
        model = model.float()
        model = model.to(args.DEV)
        # Evaluate perplexity on language modeling tasks
        ppl_eval(model, tokenizer, model_seq_len=args.model_seq_len, batch_size=args.eval_batch_size, device=args.DEV)
        # Evaluate accuracy on downstream reasoning benchmarks
        accuracy_eval(model, tokenizer, model_seq_len=args.model_seq_len, batch_size=args.eval_batch_size, device=args.DEV)