import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, LlamaTokenizer, AutoTokenizer, LlamaForCausalLM
from datetime import datetime
import os


def find_layers(module, layers=None, name=''):
    """
    Recursively collect the submodules of `module` whose type is listed in `layers`.

    Matching uses the exact type (`type(m) in layers`), so subclasses of a
    listed type are NOT matched. The search does not descend into a module
    once it has matched.

    Args:
        module: The root module to search.
        layers: Collection of layer types to look for. Defaults to
                (nn.Conv2d, nn.Linear) when None.
        name:   Prefix name used to construct hierarchical keys.

    Returns:
        dict mapping dotted "layer_name" -> module reference. If `module`
        itself matches, the result is `{name: module}`.
    """
    # Resolve the default here rather than in the signature so that no mutable
    # object is shared between calls.
    if layers is None:
        layers = (nn.Conv2d, nn.Linear)

    if type(module) in layers:
        return {name: module}

    found = {}
    for child_name, child in module.named_children():
        # Build the dotted path; the root level has no prefix.
        child_path = f"{name}.{child_name}" if name else child_name
        found.update(find_layers(child, layers=layers, name=child_path))
    return found

def save_student_pt(student, tokenizer, save_path: str):
    """
    Save a student model (and tokenizer) as a PyTorch checkpoint.

    Steps:
      - Deep copy the model so the live instance is left untouched.
      - Strip forward/backward hooks from the copy (they may not be picklable).
      - Move the copy to CPU for portability.
      - Store {"model": shadow, "tokenizer": tokenizer} to `save_path`.

    The whole model object is pickled via `torch.save`, not just a state_dict.

    Args:
        student:   nn.Module, the trained student model.
        tokenizer: Tokenizer object (stored as-is in the payload).
        save_path: Target path to save the checkpoint (.pt). May be a bare
                   filename, in which case it is written to the current
                   working directory.
    """
    import copy
    import gc
    import os
    from collections import OrderedDict

    import torch

    # os.path.dirname() returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when there is one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    was_training = student.training

    shadow = copy.deepcopy(student)

    # Replace the hook registries on the copy with empty ones to avoid pickle
    # errors. Attributes missing on a given torch version are skipped.
    hook_attrs = (
        "_forward_hooks",
        "_forward_pre_hooks",
        "_backward_hooks",
        "_backward_pre_hooks",
    )
    for m in shadow.modules():
        for attr in hook_attrs:
            if isinstance(getattr(m, attr, None), dict):
                setattr(m, attr, OrderedDict())

    # Reset extra gradient-checkpointing related attributes if the model
    # carries them (presumably set by HuggingFace transformers models).
    for attr in ("_require_grads_hook", "_gradient_checkpointing_func"):
        if hasattr(shadow, attr):
            try:
                setattr(shadow, attr, None)
            except Exception as e:
                # Best effort: keep going, but do not hide the failure.
                print(f"[warn] could not reset {attr} before saving: {e}")

    shadow = shadow.to("cpu")
    shadow.train(was_training)   # keep the same training/eval state as the source model

    payload = {"model": shadow, "tokenizer": tokenizer}

    torch.save(payload, save_path)

    # Drop the copy promptly; it can be as large as the model itself.
    del shadow, payload
    gc.collect()

def get_model_from_huggingface_llama(model_id):
    """
    Load a pretrained LLaMA causal-LM and its tokenizer via `from_pretrained`.

    The model is requested on CPU in bfloat16, and "./" is passed as the
    cache directory for both the tokenizer and the model.

    NOTE(review): `trust_remote_code=True` is passed to both loaders; only use
    this with model identifiers you trust.

    Args:
        model_id: model identifier (e.g., "jeffwan/llama-7b-hf").

    Returns:
        model, tokenizer pair. The model carries an extra `seqlen`
        attribute set to 2048.
    """
    # Keyword arguments shared by the tokenizer and the model loader.
    shared_kwargs = {
        "device_map": "cpu",
        "trust_remote_code": True,
        "cache_dir": "./",
    }

    tokenizer = LlamaTokenizer.from_pretrained(model_id, **shared_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        **shared_kwargs,
    )

    # Manually attach the default sequence length.
    model.seqlen = 2048
    return model, tokenizer

def save_student_checkpoint(student, tokenizer, args, distiller=None):
    """
    Save a distilled student checkpoint along with run metadata.

    - Closes the distiller (if given) and clears hooks on the model.
    - Converts the model to fp16 + CPU + eval mode for compact storage.
    - Creates a run directory "runs/student_distilled_<timestamp>".
    - Saves {"model", "tokenizer"} as student_distilled.pt.
    - Writes run arguments into args.txt.

    Notes:
      - `student` itself is modified: its hooks are cleared and
        `.half().to("cpu").eval()` is called directly on it (no copy is made).
      - `args.save_path` only decides WHETHER to save (None disables saving);
        the output location is always the timestamped directory under "runs".

    Args:
        student:   student model to save.
        tokenizer: tokenizer object.
        args:      argparse.Namespace containing run arguments. Must have
                   `save_path`; `DEV` is recorded if present.
        distiller: optional Distiller object (closed before saving).

    Returns:
        Path of the created run directory, or None when saving is disabled.
    """
    if args.save_path is None:
        return None

    # Clean up distiller if provided (best effort).
    if distiller is not None:
        try:
            distiller.close()
        except Exception as e:
            print(f"[warn] distiller.close() failed: {e}")

    # Clean gradient hooks. A failure here must not prevent the handle cleanup
    # below, so the two steps are guarded separately.
    if hasattr(student, "disable_input_require_grads"):
        try:
            student.disable_input_require_grads()
        except Exception as e:
            print(f"[warn] disable_input_require_grads failed: {e}")

    # Handle both spellings of the hook-handle attribute: "_require_grads_hook"
    # is the name used elsewhere in this file, "_requires_grad_hook" is the
    # name this function originally checked.
    for attr in ("_require_grads_hook", "_requires_grad_hook"):
        handle = getattr(student, attr, None)
        if handle is None:
            continue
        try:
            handle.remove()
        except Exception as e:
            # All hook registries are cleared below anyway, so just report it.
            print(f"[warn] removing {attr} failed: {e}")
        try:
            setattr(student, attr, None)
        except Exception as e:
            print(f"[warn] resetting {attr} failed: {e}")

    # Clear forward/backward hooks on every submodule.
    hook_attrs = (
        "_forward_pre_hooks",
        "_forward_hooks",
        "_backward_hooks",
        "_backward_pre_hooks",
    )
    for m in student.modules():
        for attr in hook_attrs:
            registry = getattr(m, attr, None)
            if registry is not None:
                registry.clear()

    # Prepare model for saving (operates on `student` itself).
    student_to_save = student.half().to("cpu").eval()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join("runs", f"student_distilled_{timestamp}")
    os.makedirs(run_dir, exist_ok=True)

    save_file = os.path.join(run_dir, "student_distilled.pt")
    txt_file = os.path.join(run_dir, "args.txt")

    # Save model checkpoint (full objects are pickled, not a state_dict).
    torch.save({"model": student_to_save, "tokenizer": tokenizer}, save_file)

    # Save run metadata.
    param_count = sum(p.numel() for p in student_to_save.parameters())
    # The checkpoint is already on disk; a missing DEV should not crash here.
    device = getattr(args, "DEV", "unknown")
    with open(txt_file, "w", encoding="utf-8") as f:
        f.write("[Distill Run Info]\n")
        f.write(f"Time: {timestamp}\n")
        f.write(f"Device: {device}\n")
        f.write("Student dtype: fp16\n")
        f.write(f"Student params: {param_count}\n\n")

        f.write("[Args]\n")
        for k, v in sorted(vars(args).items()):
            f.write(f"{k}: {v}\n")

    print(f"[Distill] Saved PT checkpoint to: {save_file}")
    print(f"[Distill] Saved args info to:     {txt_file}")
    return run_dir

def get_model_from_local(model_id):
    """
    Load a (model, tokenizer) pair from a checkpoint written by
    save_student_pt() or save_student_checkpoint().

    The checkpoint is expected to be a mapping with the keys 'tokenizer'
    and 'model'; everything is mapped onto the CPU while loading.

    NOTE(review): `weights_only=False` unpickles arbitrary Python objects,
    so only load checkpoint files from sources you trust.

    Args:
        model_id: Path to checkpoint file (.pt).

    Returns:
        model, tokenizer pair.
    """
    checkpoint = torch.load(model_id, map_location='cpu', weights_only=False)
    tokenizer = checkpoint['tokenizer']
    model = checkpoint['model']
    return model, tokenizer