import json
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from utils.base_inference import BaseGenerator
from utils.prompt_template import *
from precompute_space import load_dataset_by_name


# Short model aliases (the value of ``args.base_model``) mapped to the model ids
# that ``inference`` passes to the generators as ``model_path``.
MODEL_PATH_DICT = {
    "meta-llama3": "meta-llama/Meta-Llama-3-8B-Instruct",
    "llama2": "meta-llama/Llama-2-7b-chat-hf",
    "llama2-13b": "meta-llama/Llama-2-13b-chat-hf",
    "llama2-70b": "meta-llama/Llama-2-70b-chat-hf",
    "llama3.1-8b": "meta-llama/Llama-3.1-8B-Instruct",
    "llama3.2-1b": "meta-llama/Llama-3.2-1B-Instruct",
    "llama3.2-3b": "meta-llama/Llama-3.2-3B-Instruct",
    "mistral": "mistralai/Mistral-7B-Instruct-v0.1",
    "opt-6.7b": "facebook/opt-6.7b",
    "opt-13b": "facebook/opt-13b",
    "opt-30b": "facebook/opt-30b",
    "falcon-7b": "tiiuae/falcon-7b-instruct",
    "falcon-40b": "tiiuae/falcon-40b-instruct",
    # Falcon3 checkpoints are published as "Falcon3-<size>-Instruct" (no hyphen
    # between "falcon" and "3"); the previous "falcon-3-..." ids do not resolve.
    "falcon3-7b": "tiiuae/Falcon3-7B-Instruct",
    "falcon3-10b": "tiiuae/Falcon3-10B-Instruct",
    "gemma-7b": "google/gemma-7b-it",
    "gemma-2b": "google/gemma-2b-it",
    "gemma2-2b": "google/gemma-2-2b-it",
    # NOTE(review): unlike the other Gemma entries this is the base (non "-it")
    # checkpoint — confirm that is intentional.
    "gemma2-27b": "google/gemma-2-27b",
    "gemma2-9b": "google/gemma-2-9b-it",
    "gemma3-1b": "google/gemma-3-1b-it",
    "gemma3-4b": "google/gemma-3-4b-it",
    "phi3-7b": "microsoft/Phi-3-small-8k-instruct",
    "phi3-3b": "microsoft/Phi-3-mini-4k-instruct",
    "phi3.5-3b": "microsoft/Phi-3.5-mini-instruct",
    "phi4-3b": "microsoft/Phi-4-mini-instruct",
    "glm": "THUDM/glm-4-9b-chat",
    "qwen-7b": "Qwen/Qwen2-7B-Instruct",
    "qwen-1.5b": "Qwen/Qwen2-1.5B-Instruct",
    "qwen2.5-0.5b": "Qwen/Qwen2.5-0.5B-Instruct",
    "qwen2.5-1.5b": "Qwen/Qwen2.5-1.5B-Instruct",
    "qwen2.5-3b": "Qwen/Qwen2.5-3B-Instruct",
    "qwen2.5-7b": "Qwen/Qwen2.5-7B-Instruct",
    "qwen2.5-14b": "Qwen/Qwen2.5-14B-Instruct",
    "qwen2.5-32b": "Qwen/Qwen2.5-32B-Instruct",
    "qwen2.5-72b": "Qwen/Qwen2.5-72B-Instruct",
    "qwen3-0.6b": "Qwen/Qwen3-0.6B",
    "qwen3-1.7b": "Qwen/Qwen3-1.7B",
    "qwen3-4b": "Qwen/Qwen3-4B",
    "qwen3-8b": "Qwen/Qwen3-8B",
    "mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407",
    "mistral-small": "mistralai/Mistral-Small-Instruct-2409",
    "mistral-large": "mistralai/Mistral-Large-Instruct-2407"
}

class VanillaLLM:
    """Plain generator: fill ``prompt_key`` with each question and decode normally.

    Args:
        model_path: Model identifier handed to ``BaseGenerator``.
        generation_config: Optional generation settings forwarded to ``BaseGenerator``.
        prompt_dir: Directory that ``PromptTemplateLoader`` reads templates from.
        prompt_key: Name of the template that is filled with ``{"question": ...}``.
        device: Device spec forwarded to ``BaseGenerator``.
    """

    def __init__(self, model_path, generation_config=None, prompt_dir="prompt_templates/", prompt_key="zero_shot", device="auto"):
        self.model = BaseGenerator(
            model_path=model_path,
            generation_config=generation_config,
            device=device,
        )
        self.prompt_loader = PromptTemplateLoader(template_dir_path=prompt_dir)
        self.prompt_key = prompt_key

    def inference_on_dataset(self, questions, batch_size=16):
        """Generate one response per question.

        Each question is rendered through the template, wrapped as chat input
        with the model's tokenizer, and the full list is handed to
        ``self.model.inference_on_data`` together with ``batch_size``.
        """
        loader = self.prompt_loader
        filled_templates = [
            loader.construct_prompt(self.prompt_key, {"question": q})
            for q in questions
        ]
        chat_inputs = [
            loader.construct_chat_input(text, tokenizer=self.model.base_tokenizer)
            for text in filled_templates
        ]
        return self.model.inference_on_data(chat_inputs, batch_size=batch_size)

    def generate(self, prompt, noisy_prompt=None):
        """Generate a single response for an already-formatted ``prompt``."""
        return self.model.inference_one_sample(prompt=prompt, noisy_prompt=noisy_prompt)

class IclInference:
    """Few-shot generator whose demonstrations are retrieved per question.

    For each target question, the most similar training questions (cosine
    similarity of sentence-transformer embeddings) and their answers are
    inserted into the ``prompt_key`` template as ``q1``/``a1`` ... ``qk``/``ak``
    alongside ``question``, then generation runs through ``BaseGenerator``.
    """

    def __init__(
        self,
        model_path,
        generation_config=None,
        prompt_dir="prompt_templates/",
        prompt_key="few_shot",
        device="auto",
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        train_indices=None,
        train_data='truthful_qa'
    ):
        """Load the generator, the embedding model and the demonstration pool.

        Args:
            model_path: Model identifier handed to ``BaseGenerator``.
            generation_config: Optional generation settings for ``BaseGenerator``.
            prompt_dir: Directory that ``PromptTemplateLoader`` reads templates from.
            prompt_key: Few-shot template name.
            device: Device spec forwarded to ``BaseGenerator`` only; the embedding
                model is placed on CUDA when available, otherwise on CPU.
            embedding_model: SentenceTransformer model name used for retrieval.
            train_indices: Indices into the loaded training data that form the
                demonstration pool; defaults to ``range(100)``.
            train_data: Dataset name passed to ``load_dataset_by_name``.
        """
        self.model = BaseGenerator(model_path=model_path, generation_config=generation_config, device=device)
        self.prompt_loader = PromptTemplateLoader(template_dir_path=prompt_dir)
        self.prompt_key = prompt_key

        # The embedding model is placed independently of ``device`` (whose default
        # "auto" is meant for BaseGenerator).
        embed_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.embedding_model = SentenceTransformer(embedding_model)
        self.embedding_model.to(embed_device)

        # Load training data
        self.train_data = train_data
        self.training_data = load_dataset_by_name(train_data, num_train=100)

        # Choose subset of training examples.
        # NOTE(review): the default assumes at least 100 loaded examples.
        self.train_indices = train_indices if train_indices is not None else list(range(100))

        # Precompute training question embeddings; row i corresponds to
        # self.train_indices[i].
        self.train_questions = [self.training_data[i]["question"] for i in self.train_indices]
        self.train_embeddings = self.embedding_model.encode(self.train_questions, convert_to_tensor=True)

        print(f"[ICL] Loaded {len(self.train_indices)} training examples with embeddings.")

    def get_top_k_examples(self, question, k):
        """Return positions (into ``self.train_indices``) of the most similar examples.

        Similarity is the cosine similarity between the embedding of
        ``question`` and each precomputed training-question embedding. ``k``
        is clamped to the pool size, because ``torch.topk`` raises when ``k``
        exceeds the number of candidates. Results are ordered from most to
        least similar.
        """
        with torch.no_grad():
            target_embedding = self.embedding_model.encode(question, convert_to_tensor=True).unsqueeze(0)  # [1, dim]
            scores = F.cosine_similarity(target_embedding, self.train_embeddings, dim=1)
            k = min(k, scores.numel())
            top_k_indices = torch.topk(scores, k=k).indices.tolist()
        return top_k_indices

    def build_prompt_template(self, question, top_k_indices):
        """Construct the template values: target question plus numbered (q, a) pairs.

        The answer comes from ``best_answer`` when the example has that key,
        otherwise from ``answer``.
        """
        prompt_template = {"question": question}
        for i, idx in enumerate(top_k_indices):
            example = self.training_data[self.train_indices[idx]]
            prompt_template[f"q{i+1}"] = example["question"]
            prompt_template[f"a{i+1}"] = example["best_answer"] if 'best_answer' in example else example["answer"]
        return prompt_template

    def inference_on_dataset(self, questions, batch_size=16):
        """Build a retrieval-augmented few-shot prompt per question and generate.

        Uses 3 demonstrations for the ``bio`` training set and 10 otherwise.
        """
        k = 3 if self.train_data == 'bio' else 10  # else: truthfulqa, wikiqa

        prompts = []
        for question in questions:
            top_k_indices = self.get_top_k_examples(question, k=k)
            prompt_template = self.build_prompt_template(question, top_k_indices)
            prompt = self.prompt_loader.construct_prompt(self.prompt_key, prompt_template)
            chat_input = self.prompt_loader.construct_chat_input(prompt, tokenizer=self.model.base_tokenizer)
            prompts.append(chat_input)

        responses = self.model.inference_on_data(prompts, batch_size=batch_size)

        return responses
    

class DoLaInference:
    """Zero-shot generator that turns on DoLa-style decoding in ``BaseGenerator``.

    The generation config always gets ``do_layers="high"`` and, unless the
    caller already set one, ``generation_penalty=1.2``. How these keys are
    interpreted is up to ``BaseGenerator``.
    """

    def __init__(self, model_path, generation_config=None, prompt_dir="prompt_templates/", prompt_key="zero_shot", device="auto"):
        # Work on a copy so the caller's dict is not mutated by the DoLa keys below.
        generation_config = {} if generation_config is None else dict(generation_config)
        # NOTE(review): Hugging Face transformers names the equivalent options
        # `dola_layers` and `repetition_penalty`; confirm BaseGenerator expects
        # `do_layers` / `generation_penalty`.
        generation_config["do_layers"] = "high"
        generation_config.setdefault("generation_penalty", 1.2)
        self.model = BaseGenerator(model_path=model_path, generation_config=generation_config, device=device)
        self.prompt_loader = PromptTemplateLoader(template_dir_path=prompt_dir)
        self.prompt_key = prompt_key

    def inference_on_dataset(self, questions, batch_size=16):
        """Render each question with ``prompt_key``, wrap it as chat input and generate."""
        prompts = [self.prompt_loader.construct_prompt(self.prompt_key, {"question": question}) for question in questions]
        prompts = [self.prompt_loader.construct_chat_input(prompt, tokenizer=self.model.base_tokenizer) for prompt in prompts]
        responses = self.model.inference_on_data(prompts, batch_size=batch_size)
        return responses

    def generate(self, prompt, noisy_prompt=None):
        """Generate a single response for an already-formatted ``prompt``."""
        response = self.model.inference_one_sample(prompt=prompt, noisy_prompt=noisy_prompt)

        return response


class InstructiveDecoding:
    """Generator that pairs each prompt with a contrasting "noisy" prompt.

    Every question is rendered twice: once with ``prompt_key`` and once with
    ``noisy_prompt_key``. Both lists go to ``BaseGenerator``, which does the
    actual contrastive decoding.
    """

    def __init__(self, model_path, generation_config=None, prompt_dir="prompt_templates/", prompt_key="zero_shot", noisy_prompt_key="opposite_zero", device="auto"):
        self.model = BaseGenerator(
            model_path=model_path,
            generation_config=generation_config,
            device=device,
        )
        self.prompt_loader = PromptTemplateLoader(template_dir_path=prompt_dir)
        self.prompt_key = prompt_key
        self.noisy_prompt_key = noisy_prompt_key

    def _chat_inputs_for(self, template_key, questions):
        """Render ``questions`` with ``template_key`` and wrap each result as chat input."""
        rendered = [
            self.prompt_loader.construct_prompt(template_key, {"question": q})
            for q in questions
        ]
        return [
            self.prompt_loader.construct_chat_input(text, tokenizer=self.model.base_tokenizer)
            for text in rendered
        ]

    def inference_on_dataset(self, questions, batch_size=16):
        """Generate responses using the clean and noisy prompts for every question."""
        clean_inputs = self._chat_inputs_for(self.prompt_key, questions)
        noisy_inputs = self._chat_inputs_for(self.noisy_prompt_key, questions)
        return self.model.inference_on_data(
            prompts=clean_inputs,
            noisy_prompts=noisy_inputs,
            batch_size=batch_size,
        )

    def instructive_logits(self, logits, noisy_logits, eta=0.3):
        """Return ``logits`` with ``eta`` times the noisy-prompt logits subtracted."""
        return logits - eta * noisy_logits

    def generate(self, prompt, noisy_prompt):
        """Generate a single response from a formatted prompt and its noisy counterpart."""
        return self.model.inference_one_sample(prompt=prompt, noisy_prompt=noisy_prompt)
        
        
class AdaptiveDecoding: # update from BaseGenerator, not here!
    """Zero-shot generator that forwards an adaptive-decoding config to ``BaseGenerator``.

    The same ``adaptive_decoding_config`` object is given to ``BaseGenerator``
    at construction time and again on every ``generate`` call.
    """

    # Sentinel default: an omitted config gets a fresh dict per instance (a
    # literal ``{}`` default would be one dict shared by every instance), while
    # an explicitly passed value, including ``None``, is forwarded unchanged.
    _UNSET = object()

    def __init__(self, model_path, generation_config=None, adaptive_decoding_config=_UNSET,
                 prompt_dir="prompt_templates/", prompt_key="zero_shot", device="auto"):
        if adaptive_decoding_config is self._UNSET:
            adaptive_decoding_config = {}
        self.model = BaseGenerator(model_path=model_path, generation_config=generation_config, adaptive_decoding_config=adaptive_decoding_config, device=device)
        self.prompt_loader = PromptTemplateLoader(template_dir_path=prompt_dir)
        self.prompt_key = prompt_key
        self.adaptive_decoding_config = adaptive_decoding_config

    def inference_on_dataset(self, questions, batch_size=16):
        """Render each question with ``prompt_key``, wrap it as chat input and generate."""
        prompts = [self.prompt_loader.construct_prompt(self.prompt_key, {"question": question}) for question in questions]
        prompts = [self.prompt_loader.construct_chat_input(prompt, tokenizer=self.model.base_tokenizer) for prompt in prompts]
        responses = self.model.inference_on_data(prompts, batch_size=batch_size)
        return responses

    def generate(self, prompt, noisy_prompt=None):
        """Generate one response, passing along the stored adaptive-decoding config."""
        response = self.model.inference_one_sample(prompt=prompt, noisy_prompt=noisy_prompt, adaptive_decoding_config=self.adaptive_decoding_config)

        return response
        

def _base_generation_config(temperature):
    """Return the decoding settings shared by the ``greedy`` and ``icl`` methods.

    A temperature of exactly ``0.0`` selects greedy decoding (``do_sample=False``,
    ``top_k=0``, ``top_p=1.0``, ``num_beams=1``); any other value enables
    sampling at that temperature.
    """
    if temperature == 0.0:
        # Spelled out explicitly because some models otherwise fall back to
        # their own default generation config.
        return {
            "do_sample": False,
            "top_k": 0,
            "top_p": 1.0,
            "num_beams": 1,
        }
    return {"do_sample": True, "temperature": temperature}


def inference(questions, args):
    """Run the decoding method selected by ``args.method`` over ``questions`` and save results.

    Args:
        questions: Sequence of question strings.
        args: Namespace providing ``base_model``, ``method``, ``prompt_dir``,
            ``standard_prompt_key``, ``device``, ``batch_size`` and
            ``output_path``, plus method-specific fields: ``temperature`` and
            ``max_new_tokens`` (greedy/icl), ``adaptive_decoding_config``
            (adaptive), ``noisy_prompt_key`` (instructive/cad).
            ``args.model_path`` is overwritten with the id looked up in
            ``MODEL_PATH_DICT``.

    Writes a JSON list of ``{"question", "generated_answer", "question_index"}``
    records to ``args.output_path``.

    Raises:
        ValueError: if ``args.base_model`` is not a key of ``MODEL_PATH_DICT``.
        NotImplementedError: if ``args.method`` is not a supported method.
    """
    if args.base_model not in MODEL_PATH_DICT:
        raise ValueError(f"Model {args.base_model} not supported")
    args.model_path = MODEL_PATH_DICT[args.base_model]

    # NOTE(review): temperature / max_new_tokens are only honored by the
    # greedy and icl methods; dola, adaptive and instructive/cad ignore them.
    if args.method == "greedy":
        generation_config = _base_generation_config(args.temperature)
        if args.base_model == "gemma3-4b":
            # Model-specific override that replaces the config built above.
            generation_config = {'do_sample': True, 'cache_implementation': 'hybrid', 'top_k': 64, 'top_p': 0.95, 'bos_token_id': 2}
        if args.max_new_tokens:
            generation_config["max_new_tokens"] = args.max_new_tokens

        model = VanillaLLM(
            model_path=args.model_path,
            generation_config=generation_config,
            prompt_dir=args.prompt_dir,
            prompt_key=args.standard_prompt_key,
            device=args.device
            )

    elif args.method == "icl":
        generation_config = _base_generation_config(args.temperature)
        if args.max_new_tokens:
            generation_config["max_new_tokens"] = args.max_new_tokens

        # train_data is not forwarded, so IclInference's own default is used.
        model = IclInference(
            model_path=args.model_path,
            generation_config=generation_config,
            prompt_dir=args.prompt_dir,
            prompt_key=args.standard_prompt_key,
            device=args.device
            )

    elif args.method == "dola":
        model = DoLaInference(
            model_path=args.model_path,
            prompt_dir=args.prompt_dir,
            prompt_key=args.standard_prompt_key,
            device=args.device
            )

    elif args.method == "adaptive":
        model = AdaptiveDecoding(
            adaptive_decoding_config=args.adaptive_decoding_config,
            generation_config={},
            model_path=args.model_path,
            prompt_dir=args.prompt_dir,
            prompt_key=args.standard_prompt_key,
            device=args.device,
            )

    elif args.method in ("instructive", "cad"):  # sharing the same code structure
        model = InstructiveDecoding(
            model_path=args.model_path,
            generation_config={},
            prompt_dir=args.prompt_dir,
            prompt_key=args.standard_prompt_key,
            noisy_prompt_key=args.noisy_prompt_key,
            device=args.device
            )

    else:
        raise NotImplementedError(f"Unsupported method: {args.method}")

    responses = model.inference_on_dataset(questions, batch_size=args.batch_size)
    # Index into responses (rather than zip) so a short response list still
    # fails loudly instead of silently dropping questions.
    output = [
        {
            "question": question,
            "generated_answer": responses[i],
            "question_index": i,
        }
        for i, question in enumerate(questions)
    ]

    print(args.output_path)
    with open(args.output_path, "w", encoding="utf-8") as f:
        json.dump(output, f)