import json
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import random
# Seed the stdlib `random` module's global generator at import time so that any
# later use of `random` starts from a fixed state. This call does not seed torch.
random.seed(0)
from datasets import load_dataset

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download
################################
##### Activation functions #####
################################

def forward_with_cache(model, inputs, module, no_grad=True):
    """Run ``model(**inputs)`` and return the output captured from ``module``.

    A temporary forward hook is registered on ``module`` for the duration of
    the forward pass and is always removed afterwards, even if the forward
    pass raises.

    Args:
        model: Callable invoked as ``model(**inputs)``; its return value is
            discarded.
        inputs: Mapping of keyword arguments for the forward pass.
        module: Object exposing ``register_forward_hook`` whose output should
            be captured.
        no_grad: When true (the default) the forward pass runs inside
            ``torch.no_grad()``; otherwise it runs in the caller's grad mode.

    Returns:
        The first output recorded for ``module`` during the pass. When that
        output is a tuple, only its first element is returned.

    Raises:
        IndexError: If the hook never fired, i.e. ``module`` produced no
            output during ``model(**inputs)``.
    """
    # Outputs seen by the hook, in call order; only the first one is returned.
    cache = []

    def hook(_module, _inputs, output):
        cache.append(output[0] if isinstance(output, tuple) else output)
        # Returning None leaves the module's output unchanged.
        return None

    hook_handle = module.register_forward_hook(hook)
    try:
        if no_grad:
            with torch.no_grad():
                model(**inputs)
        else:
            model(**inputs)
    finally:
        # Always detach the hook so a failed forward pass does not leave it
        # registered on the module.
        hook_handle.remove()

    if not cache:
        raise IndexError(
            "forward hook was never called: module produced no output "
            "during the forward pass"
        )
    return cache[0]
    
#######################################
##### Model and data loading code #####
#######################################

def get_params(model, layer_ids, param_ids):
    """Collect selected parameters from selected layers of ``model``.

    For every index in ``layer_ids`` the layer ``model.model.layers[index]``
    is walked with ``named_parameters()``; a parameter is kept when its
    position in that enumeration is contained in ``param_ids``. The name of
    each kept parameter is printed to stdout.

    Returns:
        A list of the kept parameter objects, ordered by layer and then by
        position within the layer.
    """
    selected = []
    for layer_idx in layer_ids:
        named = model.model.layers[layer_idx].named_parameters()
        for position, (param_name, tensor) in enumerate(named):
            if position not in param_ids:
                continue
            selected.append(tensor)
            print(param_name)
    return selected

# def get_params(model, layer_ids, param_ids):
#     params = []
#     for layer_id in layer_ids:
#         for i, p in enumerate(model.model.layers[layer_id].parameters()):
#             if i in param_ids:
#                 params.append(p)
#     return params


def load_model(model_name_or_path):
    """Load a causal language model and its tokenizer.

    The model is created with ``AutoModelForCausalLM.from_pretrained`` using
    ``device_map="auto"`` and ``trust_remote_code=True``. ``torch_dtype`` is
    ``"auto"`` when CUDA is available and reports bf16 support, and
    ``torch.float16`` otherwise.

    The slow tokenizer is requested (``use_fast=False``), padding is set to
    the left, and the pad, mask, sep and cls token ids are all set to the
    tokenizer's EOS token id.

    Args:
        model_name_or_path: Identifier or path forwarded to both
            ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    weights_dtype = "auto" if bf16_ok else torch.float16

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=weights_dtype,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        use_fast=False,
        trust_remote_code=True,
    )
    tokenizer.padding_side = "left"
    # Reuse EOS for every special-token slot set by this loader.
    for special_attr in ("pad_token_id", "mask_token_id", "sep_token_id", "cls_token_id"):
        setattr(tokenizer, special_attr, tokenizer.eos_token_id)

    return model, tokenizer


def load_model_half(model_name_or_path, n_layers):
    """Load a truncated causal LM with ``n_layers + 1`` hidden layers, plus its tokenizer.

    The model config's ``num_hidden_layers`` is overridden, an empty model is
    built from that config, and weights are then loaded from a downloaded
    snapshot of the full checkpoint via Accelerate's
    ``load_checkpoint_and_dispatch`` with ``device_map="auto"``.

    Weights are loaded as ``torch.bfloat16`` when CUDA is available and
    reports bf16 support, otherwise as ``torch.float16``.

    The tokenizer is configured as in ``load_model``: slow tokenizer, left
    padding, and pad/mask/sep/cls token ids all set to the EOS token id.

    Args:
        model_name_or_path: Hub repo id passed to ``snapshot_download`` and to
            the config/tokenizer ``from_pretrained`` calls.
        n_layers: The config is set to ``n_layers + 1`` hidden layers.
            NOTE(review): presumably ``n_layers`` is the zero-based index of
            the last layer the caller needs - confirm against callers.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        AssertionError: If the ``HF_TOKEN`` environment variable is unset or
            empty.
    """
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    weights_dtype = torch.bfloat16 if bf16_ok else torch.float16
    kept_layers = n_layers + 1

    # Raised explicitly (rather than via `assert`) so the check still runs
    # under `python -O`; the exception type and message are unchanged.
    token = os.getenv("HF_TOKEN")
    if not token:
        raise AssertionError("HF_TOKEN is not set")

    # 0) Get a local checkpoint dir for the base model
    ckpt_dir = snapshot_download(repo_id=model_name_or_path, token=token)

    # 1) Make a config with only the first `kept_layers` layers
    cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True, token=token)
    cfg.num_hidden_layers = kept_layers

    # 2) Build an empty truncated model (no weights yet)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(cfg, trust_remote_code=True)

    # 3) Load weights into the truncated model and let Accelerate place it
    #    on the available devices. Decoder-layer classes listed here are not
    #    split across devices (safe superset of common architectures).
    no_split = ["LlamaDecoderLayer", "YiDecoderLayer", "MistralDecoderLayer", "MixtralDecoderLayer"]
    model = load_checkpoint_and_dispatch(
        model,
        checkpoint=ckpt_dir,            # local snapshot of the full checkpoint
        device_map="auto",
        dtype=weights_dtype,
        no_split_module_classes=no_split,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=True, use_fast=False, token=token,
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    tokenizer.mask_token_id = tokenizer.eos_token_id
    tokenizer.sep_token_id = tokenizer.eos_token_id
    tokenizer.cls_token_id = tokenizer.eos_token_id

    return model, tokenizer



def get_data(forget_corpora, retain_corpora, min_len=50, max_len=2000, batch_size=4):
    """Load the forget and retain corpora as lists of text batches.

    Each corpus name is resolved as follows:

    * ``"wikitext"`` loads the ``test`` split of ``wikitext-2-raw-v1``.
    * Any other name loads the ``train`` split of ``cais/wmdp-{name}`` using
      the token from the ``HF_TOKEN`` environment variable.

    Only records whose ``text`` field is strictly longer than ``min_len``
    characters are kept. The kept texts are grouped, in order, into lists of
    ``batch_size`` items; the final batch may be shorter.

    Args:
        forget_corpora: Iterable of corpus names for the forget set.
        retain_corpora: Iterable of corpus names for the retain set.
        min_len: Minimum (exclusive) text length in characters.
        max_len: Currently unused; kept for interface compatibility.
        batch_size: Number of texts per batch.

    Returns:
        A tuple ``(forget, retain)``. Each element is a list with one entry
        per corpus name, and each entry is a list of batches (lists of str).

    Raises:
        AssertionError: If a non-``"wikitext"`` corpus is requested while
            ``HF_TOKEN`` is unset or empty.
    """
    def get_dataset(name):
        if name == "wikitext":
            records = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        else:
            # Raised explicitly (rather than via `assert`) so the check still
            # runs under `python -O`; exception type and message are unchanged.
            token = os.getenv("HF_TOKEN")
            if not token:
                raise AssertionError("HF_TOKEN is not set")
            records = load_dataset(f"cais/wmdp-{name}", split="train", token=token)

        texts = [str(row['text']) for row in records if len(row['text']) > min_len]
        return [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]

    return (
        [get_dataset(c) for c in forget_corpora],
        [get_dataset(c) for c in retain_corpora]
    )