import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from safetensors import safe_open
from safetensors.torch import load_file
from peft import PeftModel, LoraConfig
from huggingface_hub import login
from accelerate import init_empty_weights, load_checkpoint_and_dispatch


# Dataset names accepted by get_prompt_list(); any other name fails its assertion.
DATASETS = ("NIFTY", "ACL18", "BigData22", "CIKM18")

def get_prompt_list(dataset : str, split = "val"):
    """Load every record of a dataset's JSONL prompt file.

    Args:
        dataset: Dataset name; must be one of the entries in ``DATASETS``.
        split: Requested split. Must be "train", "test" or "val", and for
            now only "val" is accepted (temporary restriction).

    Returns:
        A list holding one parsed JSON value per non-blank line of the file.

    Raises:
        AssertionError: If ``dataset`` or ``split`` is not accepted.
        OSError: If the dataset file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    assert dataset in DATASETS
    assert split in ("train", "test", "val")

    # TEMP: only the "val" split is wired up. The trailing comma matters:
    # ("val") is just the string "val", which would turn this check into a
    # substring test instead of tuple membership.
    assert split in ("val",)

    # NOTE(review): every entry points at a *_test.jsonl file although the
    # only accepted split is "val" -- confirm this is intended.
    dataset_files = {
        "NIFTY": "/home/nvinden/Work/teaching-w24-nvinden/rsfinkit/embeddings_nifty/datasets/nifty_test.jsonl",
        "ACL18": "/home/nvinden/Work/teaching-w24-nvinden/rsfinkit/embeddings_nifty/datasets/acl_test.jsonl",
        "BigData22": "/home/nvinden/Work/teaching-w24-nvinden/rsfinkit/embeddings_nifty/datasets/bigdata_test.jsonl",
        "CIKM18": "/home/nvinden/Work/teaching-w24-nvinden/rsfinkit/embeddings_nifty/datasets/cikm_test.jsonl",
    }
    # A dict lookup (instead of an if/elif chain) can never leave ``path``
    # unbound, even when assertions are stripped with ``python -O``.
    path = dataset_files[dataset]

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(path, 'r', encoding="utf-8") as file:
        # One JSON document per line; blank lines are skipped so that stray
        # empty lines do not abort the whole load.
        return [json.loads(line) for line in file if line.strip()]

def load_model(model_name : str):
    """Dispatch ``model_name`` to the loader responsible for that model family.

    Args:
        model_name: One of "l3-chat", "l3-base", "l3-acl", "l3-bigdata",
            "l3-cikm", "l3-nifty", "l2-RLMF", "l2-nifty", "l2-acl",
            "l2-bigdata", "l2-cikm", "l2-base" or "l2-chat".

    Returns:
        Whatever the selected loader returns (a ``(model, tokenizer)`` pair).

    Raises:
        ValueError: If ``model_name`` is not one of the names listed above.
    """
    if model_name in ("l3-chat", "l3-base"):
        return load_base_model(model_name)
    if model_name in ("l3-acl", "l3-bigdata", "l3-cikm", "l3-nifty"):
        return load_LORA_model(model_name)
    # Plain equality on purpose: `model_name in ("l2-RLMF")` is a substring
    # test against a string, so "", "l2" or "RLMF" would have matched too.
    if model_name == "l2-RLMF":
        return load_L2_RLMF_model(model_name)
    if model_name in ("l2-nifty", "l2-acl", "l2-bigdata", "l2-cikm"):
        return load_L2_LORA_model(model_name)
    if model_name in ("l2-base", "l2-chat"):
        return load_L2_base_model(model_name)

    # Fail loudly rather than silently returning None, which would only
    # surface later as a confusing unpacking error in the caller.
    raise ValueError(f"Unknown model name: {model_name!r}")
        
    
    
# Model loading
def load_base_model(model_name : str):
    """Load an unmodified Llama 3 8B checkpoint together with its tokenizer.

    Args:
        model_name: "l3-chat" (Meta-Llama-3-8B-Instruct) or "l3-base"
            (Meta-Llama-3-8B).

    Returns:
        A ``(model, tokenizer)`` tuple. The model is in eval mode; when CUDA
        is available it is also cast to half precision and moved to 'cuda'.

    Raises:
        AssertionError: If ``model_name`` is not one of the two names above.
    """
    # Only the two Llama 3 ids live here. The former "l2-base" branch could
    # never run (the assertion rejects that name), so it has been dropped.
    checkpoints = {
        "l3-chat": "meta-llama/Meta-Llama-3-8B-Instruct",
        "l3-base": "meta-llama/Meta-Llama-3-8B",
    }
    assert model_name in checkpoints
    base_model_path = checkpoints[model_name]

    # Load the tokenizer and base model
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    model = AutoModelForCausalLM.from_pretrained(base_model_path)

    # Set the model to evaluation mode
    model.eval()

    # Use half precision on the GPU when one is available
    if torch.cuda.is_available():
        model.half()
        model.to('cuda')

    return model, tokenizer
    

def load_LORA_model(model_name : str):
    """Load Llama 3 8B Instruct with one dataset-specific LoRA adapter attached.

    Args:
        model_name: "l3-acl", "l3-bigdata", "l3-cikm" or "l3-nifty".

    Returns:
        A ``(model, tokenizer)`` tuple. The adapter directory for
        ``model_name`` is attached with ``model.load_adapter``; the model is
        in eval mode and, when CUDA is available, cast to half precision and
        moved to 'cuda'.

    Raises:
        AssertionError: If ``model_name`` has no registered adapter.
    """
    # Adapter directory for each supported model name.
    adapter_dirs = {
        "l3-acl":       "/home/vision/projects/teaching-w24-nvinden/weights-hf-peft/llama3/weights-hf-peft/adapters_lora/Meta-Llama-3-8B-Instruct+acl18",
        "l3-bigdata":   "/home/vision/projects/teaching-w24-nvinden/weights-hf-peft/llama3/weights-hf-peft/adapters_lora/Meta-Llama-3-8B-Instruct+bigdata22",
        "l3-cikm":      "/home/vision/projects/teaching-w24-nvinden/weights-hf-peft/llama3/weights-hf-peft/adapters_lora/Meta-Llama-3-8B-Instruct+cikm18",
        "l3-nifty":     "/home/vision/projects/teaching-w24-nvinden/weights-hf-peft/llama3/weights-hf-peft/adapters_lora/Meta-Llama-3-8B-Instruct+nifty",
    }
    assert model_name in adapter_dirs

    checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"

    # Tokenizer first, then the base weights, then the adapter on top.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    model.load_adapter(adapter_dirs[model_name])

    # Inference only.
    model.eval()

    # Without a GPU the model is returned as loaded.
    if not torch.cuda.is_available():
        return model, tokenizer

    # Half precision on the GPU.
    model.half()
    model.to('cuda')
    return model, tokenizer

def load_L2_base_model(model_name : str):
    """Load an unmodified Llama 2 7B checkpoint together with its tokenizer.

    Args:
        model_name: "l2-base" (Llama-2-7b-hf) or "l2-chat"
            (Llama-2-7b-chat-hf).

    Returns:
        A ``(model, tokenizer)`` tuple. The model is in eval mode; when CUDA
        is available it is also cast to half precision and moved to 'cuda'.

    Raises:
        AssertionError: If ``model_name`` is neither "l2-base" nor "l2-chat".
    """
    # Checkpoint id for each supported model name.
    checkpoints = {
        "l2-base": "meta-llama/Llama-2-7b-hf",
        "l2-chat": "meta-llama/Llama-2-7b-chat-hf",
    }
    assert model_name in checkpoints
    checkpoint = checkpoints[model_name]

    # Tokenizer first, then the weights.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)

    # Inference only.
    model.eval()

    # Without a GPU the model is returned as loaded.
    if not torch.cuda.is_available():
        return model, tokenizer

    # Half precision on the GPU.
    model.half()
    model.to('cuda')
    return model, tokenizer

def load_L2_LORA_model(model_name : str):
    """Load Llama 2 7B chat with one dataset-specific LoRA adapter applied.

    Args:
        model_name: "l2-nifty", "l2-acl", "l2-bigdata" or "l2-cikm".

    Returns:
        A ``(model, tokenizer)`` tuple where ``model`` is a ``PeftModel``
        wrapping the base model. It is in eval mode and, when CUDA is
        available, cast to half precision and moved to 'cuda'.

    Raises:
        AssertionError: If ``model_name`` has no registered adapter.
    """
    # Function-scope import so the block does not depend on the module's
    # import list.
    import os

    model_paths = {
        "l2-nifty":     "/home/vision/projects/teaching-w24-nvinden/weights-hf-peft/llama2/weights-hf-peft/adapters_lora/llama-2-7b-chat-hf+nifty/adapter_model.safetensors",
        "l2-acl":       "/home/vision/projects/teaching-w24-nvinden/weights-hf-peft/llama2/weights-hf-peft/adapters_lora/llama-2-7b-chat-hf+acl18/adapter_model.safetensors",
        "l2-bigdata":   "/home/vision/projects/teaching-w24-nvinden/weights-hf-peft/llama2/weights-hf-peft/adapters_lora/llama-2-7b-chat-hf+bigdata22/adapter_model.safetensors",
        "l2-cikm":      "/home/vision/projects/teaching-w24-nvinden/weights-hf-peft/llama2/weights-hf-peft/adapters_lora/llama-2-7b-chat-hf+cikm18/adapter_model.safetensors",
    }

    assert model_name in model_paths

    base_model_path = "meta-llama/Llama-2-7b-chat-hf"

    # Load the tokenizer and base model
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    model = AutoModelForCausalLM.from_pretrained(base_model_path)

    # PEFT loads from the adapter *directory*, which holds both
    # adapter_config.json and adapter_model.safetensors.
    adapter_dir = os.path.dirname(model_paths[model_name])

    # Let PEFT load the adapter. Feeding the raw safetensors dict to
    # load_state_dict(strict=False) silently loads nothing: saved PEFT
    # checkpoints omit the adapter name from their keys, so none of them
    # match the wrapped model's parameter names and the mismatch is ignored.
    model = PeftModel.from_pretrained(model, adapter_dir)

    # Set the model to evaluation mode
    model.eval()

    # Use half precision on the GPU when one is available
    if torch.cuda.is_available():
        model.half()
        model.to('cuda')

    return model, tokenizer

def load_L2_RLMF_model(model_name: str):
    """Load the locally stored "l2-RLMF" checkpoint, dispatched with accelerate.

    Args:
        model_name: Must be "l2-RLMF".

    Returns:
        A ``(model, tokenizer)`` tuple. The model is built from the
        checkpoint's config with empty weights, filled and placed by
        ``load_checkpoint_and_dispatch`` with ``device_map="auto"``, and put
        in eval mode.

    Raises:
        KeyError: If ``model_name`` is not "l2-RLMF".
    """
    model_paths = {
        "l2-RLMF": "/home/vision/projects/teaching-w24-nvinden/weights-hf-peft/llama-2-7b-chat-hf_nifty",
    }

    path = model_paths[model_name]

    # Function-scope import so the block does not depend on the module's
    # import list.
    from transformers import AutoConfig

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(path)

    # Read only the config. The previous code called
    # AutoModelForCausalLM.from_pretrained(path) just to reach `.config`,
    # which instantiated the whole model a second time.
    config = AutoConfig.from_pretrained(path)

    # Build the architecture without allocating real weight tensors
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    # Load the model weights and dispatch layers to devices using accelerate
    model = load_checkpoint_and_dispatch(model, path, device_map="auto", no_split_module_classes=["LlamaDecoderLayer"])

    # Set the model to evaluation mode
    model.eval()

    return model, tokenizer
