import json

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import random
random.seed(0)

import torch.nn.functional as F
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download

################################
##### Activation functions #####
################################
def forward_with_cache_logits(model, inputs, module, no_grad=True, attention_temp=None, layers_id=None):
    """Run ``model`` once and return ``module``'s output plus per-token log-probs.

    A temporary forward hook is registered on ``module`` to capture its output
    while ``model(**inputs)`` runs. The hook is always removed afterwards, even
    when the forward pass raises.

    Args:
        model: Callable invoked as ``model(**inputs)``; the returned object
            must expose a ``.logits`` tensor.
        inputs: Mapping of keyword arguments for ``model``. Must contain
            ``"input_ids"``.
        module: Module whose forward output is captured.
        no_grad: When true, the forward pass runs under ``torch.no_grad()``.
            Otherwise it runs in the caller's current grad mode.
        attention_temp: When not ``None``, passed to ``model`` as a keyword
            argument together with ``layers_id``.
        layers_id: Only forwarded when ``attention_temp`` is not ``None``.

    Returns:
        A ``(activation, token_log_probs)`` tuple. ``activation`` is the first
        output captured from ``module`` (its first element if the module
        returned a tuple). ``token_log_probs`` holds, for every position but
        the first, the log-probability the logits at the previous position
        assign to the actual token; it is always computed without gradients.

    Raises:
        IndexError: If ``module`` was never called during the forward pass.
    """
    # Outputs seen by the hook, in call order.
    captured = []

    def hook(module, input, output):
        captured.append(output[0] if isinstance(output, tuple) else output)
        return None

    # The attention-temperature kwargs are only understood by the custom
    # model classes, so they are sent only when explicitly requested.
    extra_kwargs = {}
    if attention_temp is not None:
        extra_kwargs = {"attention_temp": attention_temp, "layers_id": layers_id}

    hook_handle = module.register_forward_hook(hook)
    try:
        if no_grad:
            with torch.no_grad():
                outputs = model(**inputs, **extra_kwargs)
        else:
            outputs = model(**inputs, **extra_kwargs)
    finally:
        # Never leave the hook attached: a leaked hook would keep appending
        # activations on every later forward pass through ``module``.
        hook_handle.remove()

    with torch.no_grad():
        # Logits at position t score the token at position t + 1.
        log_probs = F.log_softmax(outputs.logits[..., :-1, :], dim=-1)
        expanded_input_ids = inputs['input_ids'][:, 1:].unsqueeze(-1).to(log_probs.device)
        token_log_probs = log_probs.gather(dim=-1, index=expanded_input_ids).squeeze(-1)

    return captured[0], token_log_probs



def forward_with_cache(model, inputs, module, no_grad=True, attention_temp=None, layers_id=None):
    """Run ``model`` once and return the output produced by ``module``.

    A temporary forward hook is registered on ``module`` to capture its output
    while ``model(**inputs)`` runs. The hook is always removed afterwards, even
    when the forward pass raises.

    Args:
        model: Callable invoked as ``model(**inputs)``.
        inputs: Mapping of keyword arguments for ``model``.
        module: Module whose forward output is captured.
        no_grad: When true, the forward pass runs under ``torch.no_grad()``.
            Otherwise it runs in the caller's current grad mode.
        attention_temp: When not ``None``, passed to ``model`` as a keyword
            argument together with ``layers_id``.
        layers_id: Only forwarded when ``attention_temp`` is not ``None``.

    Returns:
        The first output captured from ``module`` during the forward pass
        (its first element if the module returned a tuple).

    Raises:
        IndexError: If ``module`` was never called during the forward pass.
    """
    # Outputs seen by the hook, in call order.
    captured = []

    def hook(module, input, output):
        captured.append(output[0] if isinstance(output, tuple) else output)
        return None

    # The attention-temperature kwargs are only understood by the custom
    # model classes, so they are sent only when explicitly requested.
    extra_kwargs = {}
    if attention_temp is not None:
        extra_kwargs = {"attention_temp": attention_temp, "layers_id": layers_id}

    hook_handle = module.register_forward_hook(hook)
    try:
        if no_grad:
            with torch.no_grad():
                _ = model(**inputs, **extra_kwargs)
        else:
            _ = model(**inputs, **extra_kwargs)
    finally:
        # Never leave the hook attached: a leaked hook would keep appending
        # activations on every later forward pass through ``module``.
        hook_handle.remove()

    return captured[0]


def forward_with_logits(model, inputs, module, no_grad=True, attention_temp=None, layers_id=None):
    """Run ``model`` once and return log-softmaxed logits without the last position.

    Args:
        model: Callable invoked as ``model(**inputs)``; the returned object
            must expose a ``.logits`` tensor.
        inputs: Mapping of keyword arguments for ``model``.
        module: Unused; accepted so the signature matches the other
            ``forward_with_*`` helpers in this file.
        no_grad: When true, the forward pass runs under ``torch.no_grad()``.
            Otherwise it runs in the caller's current grad mode.
        attention_temp: When not ``None``, passed to ``model`` as a keyword
            argument together with ``layers_id``.
        layers_id: Only forwarded when ``attention_temp`` is not ``None``.

    Returns:
        ``log_softmax`` over the last dimension of ``outputs.logits`` with the
        final position along the second-to-last dimension dropped.
    """
    # Identity check (not ``==``) so array-like temperatures cannot raise an
    # "ambiguous truth value" error during the comparison.
    extra_kwargs = {}
    if attention_temp is not None:
        extra_kwargs = {"attention_temp": attention_temp, "layers_id": layers_id}

    if no_grad:
        with torch.no_grad():
            outputs = model(**inputs, **extra_kwargs)
    else:
        outputs = model(**inputs, **extra_kwargs)

    return F.log_softmax(outputs.logits[..., :-1, :], dim=-1)

# def forward_with_cache(model, inputs, module, no_grad=True):
#     # define a tensor with the size of our cached activations
#     cache = []
#     def hook(module, input, output):
#         if isinstance(output, tuple):
#             cache.append(output[0])
#         else:
#             cache.append(output)
#         return None 
    
#     hook_handle = module.register_forward_hook(hook)
    
#     if no_grad:
#         with torch.no_grad():
#             _ = model(**inputs)
#     else:
#         _ = model(**inputs)
        
#     hook_handle.remove()

#     return cache[0]
    
#######################################
##### Model and data loading code #####
#######################################


def get_params(model, layer_ids, param_ids, all=False):
    """Collect parameters from ``model``.

    Args:
        model: Object exposing ``named_parameters()`` and, when ``all`` is
            false, ``model.layers`` indexable by the entries of ``layer_ids``.
        layer_ids: Indices into ``model.model.layers`` to draw parameters from.
            Ignored when ``all`` is true.
        param_ids: Positions, within each selected layer's
            ``named_parameters()`` enumeration, of the parameters to keep.
            Ignored when ``all`` is true.
        all: When true, return every parameter of ``model`` instead.

    Returns:
        A list of the selected parameters, ordered by layer and then by
        position within the layer.
    """
    if all:
        # Every parameter of the model, names discarded.
        return [param for _, param in model.named_parameters()]

    selected = []
    for layer_id in layer_ids:
        layer = model.model.layers[layer_id]
        selected.extend(
            param
            for position, (_, param) in enumerate(layer.named_parameters())
            if position in param_ids
        )
    return selected




from baselines.att.my_llama import LlamaForCausalLM
from baselines.att.my_mistral import MistralForCausalLM
from baselines.att.my_mixtral import MixtralForCausalLM
from baselines.att.my_qwen2 import Qwen2ForCausalLM
# Maps an architecture name, as listed in a checkpoint config's
# ``architectures`` field, to the project-local model class imported from
# ``baselines.att`` above. The loaders in this file look the class up with
# ``.get``, so an architecture missing from this table resolves to ``None``.
custom_class_map = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "MistralForCausalLM": MistralForCausalLM,
    "MixtralForCausalLM": MixtralForCausalLM,
    "Qwen2ForCausalLM": Qwen2ForCausalLM
}


def load_model(model_name_or_path):
    """Load a model with its project-local class, plus a matching tokenizer.

    The architecture is read from the checkpoint's config and resolved through
    ``custom_class_map``. The tokenizer is configured for left padding, and its
    pad/mask/sep/cls token ids are all set to the EOS token id.

    Args:
        model_name_or_path: Hub id or local path accepted by
            ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If the checkpoint's architecture has no entry in
            ``custom_class_map``, or if loading the weights fails (the
            original exception is chained as ``__cause__``).
    """
    # "auto" lets from_pretrained pick the dtype when bf16 is usable;
    # otherwise fall back to fp16.
    torch_dtype = "auto" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    architectures = getattr(config, "architectures", None) or []
    arch = architectures[0] if architectures else None
    model_class = custom_class_map.get(arch)
    if model_class is None:
        # Fail early with the real cause instead of an AttributeError on
        # ``None`` wrapped in an unrelated loading error.
        raise ValueError(
            f"Unsupported architecture {arch!r} for {model_name_or_path!r}; "
            f"expected one of {sorted(custom_class_map)}."
        )

    try:
        model = model_class.from_pretrained(
            model_name_or_path,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
            device_map="auto",
        )
    except Exception as e:
        raise ValueError(
            f"Error {e} while fetching model using {model_class.__name__}.from_pretrained()."
        ) from e
    print("Custom Model Loaded")

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=True, use_fast=False
    )
    # Reuse EOS for every special token id and pad on the left.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    tokenizer.mask_token_id = tokenizer.eos_token_id
    tokenizer.sep_token_id = tokenizer.eos_token_id
    tokenizer.cls_token_id = tokenizer.eos_token_id

    return model, tokenizer



def load_model_half(model_name_or_path, n_layers):
    """Load only the leading decoder layers of a checkpoint, plus its tokenizer.

    The checkpoint is downloaded (or found in the local cache) with
    ``snapshot_download``. An empty model with ``n_layers + 1`` hidden layers
    is then built from the project-local class, and the matching weights are
    streamed in and dispatched with ``device_map="auto"``.

    Args:
        model_name_or_path: Hub repo id of the base model.
        n_layers: The model is built with ``n_layers + 1`` hidden layers.
            NOTE(review): presumably ``n_layers`` is a 0-based layer index
            that must remain reachable; confirm against callers.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer pads on the left and
        uses the EOS token id for pad/mask/sep/cls.

    Raises:
        AssertionError: If the ``HF_TOKEN`` environment variable is unset or
            empty.
        ValueError: If the checkpoint's architecture has no entry in
            ``custom_class_map``, or if building the empty model fails (the
            original exception is chained as ``__cause__``).
    """
    torch_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
    num_layers_to_load = n_layers + 1

    token = os.getenv("HF_TOKEN")
    if not token:
        # Raised explicitly (same exception type as the former ``assert``) so
        # the check still runs under ``python -O``.
        raise AssertionError("HF_TOKEN is not set")

    # 0) Get a local checkpoint dir for the base model
    ckpt_dir = snapshot_download(repo_id=model_name_or_path, token=token)

    # 1) Make a config with only the first N layers
    cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True, token=token)
    cfg.num_hidden_layers = num_layers_to_load

    architectures = getattr(cfg, "architectures", None) or []
    arch = architectures[0] if architectures else None
    model_class = custom_class_map.get(arch)
    if model_class is None:
        # Fail early with the real cause instead of an AttributeError on
        # ``None`` wrapped in an unrelated loading error.
        raise ValueError(
            f"Unsupported architecture {arch!r} for {model_name_or_path!r}; "
            f"expected one of {sorted(custom_class_map)}."
        )

    # 2) Build an empty N-layer model (no weights yet)
    try:
        with init_empty_weights():
            model = model_class._from_config(cfg)
    except Exception as e:
        raise ValueError(
            f"Error {e} while building {model_class.__name__} from config."
        ) from e
    print("Custom Model Loaded")

    # 3) Stream weights for those N layers only and dispatch across devices.
    # Decoder layers must stay on a single device. The static list covers the
    # architectures in ``custom_class_map`` and is merged with whatever the
    # model class itself declares.
    no_split = [
        "LlamaDecoderLayer",
        "YiDecoderLayer",
        "MistralDecoderLayer",
        "MixtralDecoderLayer",
        "Qwen2DecoderLayer",
    ]
    for class_name in getattr(model, "_no_split_modules", None) or []:
        if class_name not in no_split:
            no_split.append(class_name)

    model = load_checkpoint_and_dispatch(
        model,
        checkpoint=ckpt_dir,            # local snapshot of the full checkpoint
        device_map="auto",              # let accelerate place the modules
        dtype=torch_dtype,
        no_split_module_classes=no_split,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=True, use_fast=False, token=token,
    )
    # Reuse EOS for every special token id and pad on the left.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    tokenizer.mask_token_id = tokenizer.eos_token_id
    tokenizer.sep_token_id = tokenizer.eos_token_id
    tokenizer.cls_token_id = tokenizer.eos_token_id

    return model, tokenizer



# def load_model(model_name_or_path):
#     torch_dtype = "auto" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

#     model = AutoModelForCausalLM.from_pretrained(
#         model_name_or_path,
#         torch_dtype=torch_dtype,
#         trust_remote_code=True,
#         device_map="auto",
#     )

#     tokenizer = AutoTokenizer.from_pretrained(
#         model_name_or_path, 
#         trust_remote_code=True, 
#         use_fast=False,
#     )
#     tokenizer.pad_token_id = tokenizer.eos_token_id
#     tokenizer.padding_side = "left"
#     tokenizer.mask_token_id = tokenizer.eos_token_id
#     tokenizer.sep_token_id = tokenizer.eos_token_id
#     tokenizer.cls_token_id = tokenizer.eos_token_id

#     return model, tokenizer



# def get_data(forget_corpora, retain_corpora, min_len=50, max_len=2000, batch_size=4):
#     def get_dataset(name):
#         data = []
#         if name == "wikitext":
#             from datasets import load_dataset
#             raw_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
#             for x in raw_data:
#                 if len(x['text']) > min_len:
#                     data.append(str(x['text']))
#         else:
#             for line in open(f"data/{name}.jsonl", "r"):
#                 if "bio-forget-corpus" in name:
#                     raw_text = json.loads(line)['text']
#                 else:
#                     raw_text = line
#                 if len(raw_text) > min_len:
#                     data.append(str(raw_text))
#         data = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
#         return data

#     return (
#         [get_dataset(c) for c in forget_corpora],
#         [get_dataset(c) for c in retain_corpora]
#     )

from datasets import load_dataset
import os
def get_data(forget_corpora, retain_corpora, min_len=50, max_len=2000, batch_size=4):
    """Load the forget and retain corpora as lists of text batches.

    Corpus names are resolved as follows:

    * ``"wikitext"``: the ``test`` split of ``wikitext-2-raw-v1``.
    * a name containing ``"forget"``: the ``train`` split of
      ``cais/wmdp-<name>``.
    * a name containing ``"retain"``: the ``train`` split of the ``<name>``
      config of ``cais/wmdp-corpora``.

    Args:
        forget_corpora: Iterable of corpus names to forget.
        retain_corpora: Iterable of corpus names to retain.
        min_len: Only texts strictly longer than this many characters are kept.
        max_len: Unused; kept for backward compatibility of the signature.
        batch_size: Number of texts per batch; the last batch may be shorter.

    Returns:
        A ``(forget_data, retain_data)`` tuple. Each element is a list with
        one entry per corpus, and each entry is a list of batches, each batch
        being a list of strings.

    Raises:
        AssertionError: If a forget/retain corpus is requested and the
            ``HF_TOKEN`` environment variable is unset or empty.
        ValueError: If a corpus name matches none of the rules above.
    """
    def hf_token():
        token = os.getenv("HF_TOKEN")
        if not token:
            # Raised explicitly (same exception type as the former ``assert``)
            # so the check still runs under ``python -O``.
            raise AssertionError("HF_TOKEN is not set")
        return token

    def get_dataset(name):
        if name == "wikitext":
            rows = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        elif 'forget' in name:
            rows = load_dataset(f"cais/wmdp-{name}", split="train", token=hf_token())
        elif 'retain' in name:
            rows = load_dataset("cais/wmdp-corpora", name, split="train", token=hf_token())
        else:
            # An unrecognised name used to produce an empty corpus silently.
            raise ValueError(
                f"Unknown corpus {name!r}: expected 'wikitext' or a name "
                "containing 'forget' or 'retain'."
            )

        texts = [str(row['text']) for row in rows if len(row['text']) > min_len]
        return [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]

    return (
        [get_dataset(c) for c in forget_corpora],
        [get_dataset(c) for c in retain_corpora]
    )