import torch
import torch.nn.functional as F
import params
import os
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# --------------------------
# RCD Functions
# --------------------------
def get_context_logits(
    sims,
    data,
    all_labels,
    top_n=10,
    gamma=0.01,
    sim_threshold=0,
    temperature=1,
    vocab_size=None,
    device=None,
):
    """Blend the stored logits of the training contexts most similar to the current one.

    Candidates are entries labelled ``'correct'`` whose similarity is strictly greater
    than ``sim_threshold``. The ``top_n`` most similar candidates are weighted by
    ``softmax(sim / temperature)``; weights at or below ``gamma`` are dropped and the
    survivors renormalised. The result is the weighted sum of their ``data[i]["logits"]``.

    Args:
        sims: 1-D similarity tensor, aligned index-for-index with ``data`` and
            ``all_labels``.
        data: sequence of dicts holding a ``"logits"`` tensor (presumably 1-D,
            ``[vocab]`` — the stacking below assumes that).
        all_labels: label per entry; only ``'correct'`` entries are used.
        top_n: maximum number of neighbours to blend.
        gamma: minimum softmax weight a neighbour needs to be kept.
        sim_threshold: minimum (exclusive) similarity for a candidate.
        temperature: softmax temperature applied to the similarities.
        vocab_size: length of the zero vector returned when nothing qualifies;
            inferred from ``data[0]["logits"]`` when omitted.
        device: output device. Defaults to the device of the best candidate's logits,
            or to ``sims.device`` when there is no candidate.

    Returns:
        A float32 tensor: the blended logits, or zeros when no neighbour qualifies.

    Raises:
        ValueError: if a zero vector must be returned but neither ``vocab_size`` nor
            any ``data`` entry is available to size it.
    """
    # Single host transfer instead of one `.item()` sync per training entry.
    sim_list = sims.tolist()
    candidates = [
        (i, sim)
        for i, (label, sim) in enumerate(zip(all_labels, sim_list))
        if label == 'correct' and sim > sim_threshold
    ]
    # sorted() is stable, so equal similarities keep their original index order.
    top = sorted(candidates, key=lambda pair: -pair[1])[:top_n]

    # Resolve the device BEFORE any tensor is created so weights and logits agree.
    if device is None:
        device = data[top[0][0]]["logits"].device if top else sims.device

    def _zeros():
        """Return the neutral (all-zero) logit vector on the resolved device."""
        size = vocab_size
        if size is None:
            if not data:
                raise ValueError("vocab_size is required when data is empty")
            size = data[0]["logits"].shape[-1]
        return torch.zeros(size, dtype=torch.float32, device=device)

    if not top:
        return _zeros()

    indices = [i for i, _ in top]
    sim_values = torch.tensor([sim for _, sim in top], dtype=torch.float32, device=device)

    # Softmax over the selected neighbours, then drop the negligible ones.
    weights = F.softmax(sim_values / temperature, dim=0)
    keep = weights > gamma
    if not keep.any():
        return _zeros()

    kept_indices = [i for i, k in zip(indices, keep.tolist()) if k]
    weights = weights[keep]
    weights = weights / (weights.sum() + 1e-8)

    logits = torch.stack(
        [data[i]["logits"].to(dtype=torch.float32).to(device) for i in kept_indices]
    )
    return (logits * weights.unsqueeze(1)).sum(dim=0)

def adaptive_generation(
    model_name,
    model,
    tokenizer,
    embed_model,
    input_ids,
    alpha=1.0,
    max_new_tokens=32,
    chunk_mode='chunk_8',
    embed_model_name='all-MiniLM-L6-v2',
    top_n=10,
    gamma=0.01,
    num_train=100,
    train_data=''
):
    """Greedy decoding steered by logits retrieved from similar training contexts (RCD).

    At every step the model's next-token logits are shifted by
    ``alpha * get_context_logits(...)``, computed from the training items whose context
    embeddings are most similar to the embedding of the current context window.

    Args:
        model_name: model identifier used in the training-file name.
        model: causal LM called as ``model(ids)`` whose output exposes ``.logits``.
        tokenizer: tokenizer providing ``decode`` and ``eos_token_id``.
        embed_model: object with ``encode(text, convert_to_tensor=True)``.
        input_ids: prompt ids; only row 0 is read (shape ``[1, prompt_len]``).
        alpha: scale of the retrieved-logit correction.
        max_new_tokens: maximum number of tokens to generate.
        chunk_mode: ``'full'`` embeds the whole sequence; ``'chunk_N'`` embeds the last
            N tokens.
        embed_model_name: embedding-model identifier used in the training-file name.
        top_n: neighbours blended per step.
        gamma: minimum neighbour weight (see ``get_context_logits``).
        num_train: number of distinct ``question_id`` values (in file order) to use.
        train_data: prefix of the training-file name.

    Returns:
        The decoded continuation (prompt excluded, special tokens skipped).

    Raises:
        ValueError: for an unsupported ``chunk_mode`` or an empty training selection.
    """
    # Validate the context mode up front rather than inside the decoding loop.
    if chunk_mode == 'full':
        chunk_size = None
    elif chunk_mode.startswith('chunk_'):
        chunk_size = int(chunk_mode.split('_')[1])
    else:
        raise ValueError(f'{chunk_mode} is not supported')

    # Load the training data (a torch-saved list of dicts with question_id /
    # context_embedding / label / logits entries).
    data_path = os.path.join(params.output_dir, f"{train_data}{model_name}_{embed_model_name}_context_{chunk_mode}.pt")
    data = torch.load(data_path)

    # Collect the first `num_train` distinct question_ids in file order.
    seen = set()
    for item in data:
        seen.add(item['question_id'])
        if len(seen) >= num_train:
            break

    # Keep ONE filtered list and use it everywhere, so that index i in `sims` /
    # `all_labels` refers to the same item as `train_items[i]`.
    train_items = [item for item in data if item['question_id'] in seen]
    if not train_items:
        raise ValueError(f"No training items found in {data_path}")

    device = next(model.parameters()).device
    all_embeddings = torch.stack([item['context_embedding'] for item in train_items]).to(
        device=device, dtype=torch.float32
    )
    all_labels = [item['label'] for item in train_items]

    model.eval()
    generated = input_ids.clone()

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(generated)
        base_logit = outputs.logits[0, -1, :]

        context_input_ids = generated[0] if chunk_size is None else generated[0][-chunk_size:]
        context_text = tokenizer.decode(context_input_ids, skip_special_tokens=True)
        context_embedding = F.normalize(
            embed_model.encode(context_text, convert_to_tensor=True), dim=-1
        ).to(dtype=torch.float32).to(device)

        sims = F.cosine_similarity(context_embedding.unsqueeze(0), all_embeddings)

        pos_logit = get_context_logits(
            sims=sims,
            data=train_items,
            all_labels=all_labels,
            top_n=top_n,
            gamma=gamma,
            # Size the zero fallback like the actual logits (config.vocab_size can differ).
            vocab_size=base_logit.shape[-1],
            device=device,
        )

        adjusted_logits = base_logit + alpha * pos_logit
        best_token = torch.argmax(adjusted_logits)
        generated = torch.cat([generated, best_token.view(1, 1)], dim=1)

        if tokenizer.eos_token_id is not None and best_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)


def adaptive_generation_batch(
    model_name,
    model,
    tokenizer,
    embed_model,
    input_ids_batch,
    alpha=1.0,
    max_new_tokens=32,
    chunk_mode='chunk_8',
    top_n=10,
    embed_model_name='all-MiniLM-L6-v2',
    gamma=0.01,
    num_train=100,
    train_data=''
):
    """Run ``adaptive_generation`` independently on every row of ``input_ids_batch``.

    Each row is given a leading batch dimension and decoded on its own; the decoded
    strings are returned in row order.

    NOTE(review): rows come straight from a padded batch and no attention mask is
    forwarded, so any padding tokens are fed to the model as ordinary input. Also,
    ``adaptive_generation`` re-reads the training file on every call, i.e. once per row.
    """
    shared_kwargs = dict(
        model_name=model_name,
        model=model,
        tokenizer=tokenizer,
        embed_model=embed_model,
        alpha=alpha,
        max_new_tokens=max_new_tokens,
        chunk_mode=chunk_mode,
        top_n=top_n,
        embed_model_name=embed_model_name,
        gamma=gamma,
        num_train=num_train,
        train_data=train_data,
    )
    return [
        adaptive_generation(input_ids=row.unsqueeze(0), **shared_kwargs)
        for row in input_ids_batch
    ]


# --------------------------
# Instructive Decoding (ID) Functions
# --------------------------
def instructive_generation(
    model,
    tokenizer,
    input_ids,
    noisy_input_ids,
    eta=0.3,
    max_new_tokens=32,
):
    """Greedy Instructive Decoding (ID) for a single sequence.

    Each step picks ``argmax(logits(clean) - eta * logits(noisy))``. Every generated
    token is appended to BOTH the clean and the noisy context, so the two next-token
    distributions are always conditioned on the same response prefix.

    Args:
        model: causal LM called as ``model(ids)``; its output exposes ``.logits``
            indexed as ``[batch, seq, vocab]``.
        tokenizer: provides ``eos_token_id`` and ``decode``.
        input_ids: clean prompt ids, shape ``[1, prompt_len]`` (batch size 1 is
            required by the ``.item()`` EOS check).
        noisy_input_ids: noisy-instruction prompt ids, shape ``[1, noisy_len]``.
        eta: strength of the contrast against the noisy context.
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        The decoded continuation (prompt excluded, special tokens skipped).
    """
    generated = input_ids.clone()
    noisy_generated = noisy_input_ids.clone()

    for _ in range(max_new_tokens):
        with torch.no_grad():
            next_token_logits = model(generated).logits[:, -1, :]  # [1, vocab]
            noisy_token_logits = model(noisy_generated).logits[:, -1, :]  # [1, vocab]

        adjusted_logits = next_token_logits - eta * noisy_token_logits
        # softmax is monotonic, so argmax over the raw adjusted logits is identical.
        next_token = torch.argmax(adjusted_logits, dim=-1, keepdim=True)  # [1, 1]

        generated = torch.cat([generated, next_token], dim=1)
        # Keep the noisy branch in lock-step with the generated response.
        noisy_generated = torch.cat([noisy_generated, next_token], dim=1)

        if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)


def instructive_generation_batch(
    model,
    tokenizer,
    input_ids_batch,
    noisy_input_ids_batch,
    eta=0.3,
    max_new_tokens=32,
):
    """Apply ``instructive_generation`` row by row and return the decoded strings.

    Clean and noisy rows are paired positionally (``zip`` stops at the shorter batch);
    each pair gets a leading batch dimension before decoding.

    NOTE(review): rows come from padded batches and no attention mask is forwarded,
    so padding tokens are seen by the model as ordinary input.
    """
    return [
        instructive_generation(
            model=model,
            tokenizer=tokenizer,
            input_ids=clean_row.unsqueeze(0),
            noisy_input_ids=noisy_row.unsqueeze(0),
            eta=eta,
            max_new_tokens=max_new_tokens,
        )
        for clean_row, noisy_row in zip(input_ids_batch, noisy_input_ids_batch)
    ]

# --------------------------
# CAD Inference 
# --------------------------
def cad_generation(
    model,
    tokenizer,
    input_ids,
    noisy_input_ids,
    eta=0.5,  # alpha for context-aware decoding
    max_new_tokens=32,
):
    """Greedy Context-Aware Decoding (CAD) for a single sequence.

    Each step picks ``argmax(l + eta * (l - l_noisy))``, where ``l`` / ``l_noisy`` are the
    next-token logits for the full and the noisy (context-reduced) prompt. Every
    generated token is appended to BOTH contexts so they share the response prefix.

    Args:
        model: causal LM called as ``model(ids)``; its output exposes ``.logits``
            indexed as ``[batch, seq, vocab]``.
        tokenizer: provides ``eos_token_id`` and ``decode``.
        input_ids: full prompt ids, shape ``[1, prompt_len]`` (batch size 1 is
            required by the ``.item()`` EOS check).
        noisy_input_ids: noisy prompt ids, shape ``[1, noisy_len]``.
        eta: CAD contrast strength (alpha).
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        The decoded continuation (prompt excluded, special tokens skipped).
    """
    generated = input_ids.clone()
    noisy_generated = noisy_input_ids.clone()

    for _ in range(max_new_tokens):
        with torch.no_grad():
            next_token_logits = model(generated).logits[:, -1, :]  # [1, vocab]
            noisy_token_logits = model(noisy_generated).logits[:, -1, :]  # [1, vocab]

        adjusted_logits = next_token_logits + eta * (next_token_logits - noisy_token_logits)
        # softmax is monotonic, so argmax over the raw adjusted logits is identical.
        next_token = torch.argmax(adjusted_logits, dim=-1, keepdim=True)  # [1, 1]

        generated = torch.cat([generated, next_token], dim=1)
        # Keep the noisy branch in lock-step with the generated response.
        noisy_generated = torch.cat([noisy_generated, next_token], dim=1)

        if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)


def cad_generation_batch(
    model,
    tokenizer,
    input_ids_batch,
    noisy_input_ids_batch,
    eta=0.5,  # alpha for context-aware decoding
    max_new_tokens=32,
):
    """Apply ``cad_generation`` row by row and return the decoded strings.

    Full-context and noisy rows are paired positionally (``zip`` stops at the shorter
    batch); each pair gets a leading batch dimension before decoding.

    NOTE(review): rows come from padded batches and no attention mask is forwarded,
    so padding tokens are seen by the model as ordinary input.
    """
    decoded = []
    for row_index, clean_row in enumerate(input_ids_batch):
        if row_index >= len(noisy_input_ids_batch):
            break
        decoded.append(
            cad_generation(
                model=model,
                tokenizer=tokenizer,
                input_ids=clean_row.unsqueeze(0),
                noisy_input_ids=noisy_input_ids_batch[row_index].unsqueeze(0),
                eta=eta,
                max_new_tokens=max_new_tokens,
            )
        )
    return decoded


def load_model_from_path(base_model_path, device):
    """Load a float16 causal LM and its tokenizer, configured for left padding.

    The tokenizer's pad token is set to its EOS token (overriding any pad token the
    checkpoint defines). Paths containing ``"gemma3"`` (case-insensitive) load through
    ``Gemma3ForCausalLM``; everything else through ``AutoModelForCausalLM``.
    ``device`` is forwarded as ``device_map``.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path,
        device=device,
        padding_side="left",
        trust_remote_code=True,
    )
    tokenizer.pad_token = tokenizer.eos_token
    # Re-assert left padding in case a remote-code tokenizer ignored the kwarg above.
    tokenizer.padding_side = 'left'

    if "gemma3" in base_model_path.lower():
        from transformers import Gemma3ForCausalLM
        model_class = Gemma3ForCausalLM
    else:
        model_class = AutoModelForCausalLM

    model = model_class.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        device_map=device,
        trust_remote_code=True,
    )
    return model, tokenizer

# --------------------------
# Base Generator Class
# --------------------------
class BaseGenerator:
    """Wrap a causal LM with plain, contrastive (ID / CAD) and RCD decoding.

    For each request exactly one decoder runs, chosen in this order:

    1. a non-empty ``adaptive_decoding_config`` -> RCD (``adaptive_generation*``);
    2. a noisy prompt containing ``'opposite'`` -> Instructive Decoding, or one
       containing ``'cad'`` -> Context-Aware Decoding;
    3. otherwise -> ``model.generate`` with ``self.generation_config``.
    """

    # Contrast strengths for the two noisy-prompt decoders.
    _ID_ETA = 0.3
    _CAD_ETA = 0.5

    def __init__(self, model_path, device="cuda:0", generation_config=None, adaptive_decoding_config=None):
        """Load model/tokenizer and build the default ``GenerationConfig``.

        Args:
            model_path: checkpoint path or hub id passed to ``load_model_from_path``.
            device: ``"cuda:N"``, ``"auto"``, an int GPU index, or another torch device
                string (e.g. ``"cpu"``); also forwarded to the loader as ``device_map``.
            generation_config: optional dict of ``GenerationConfig`` kwargs. It is
                copied, so the caller's dict is not modified.
            adaptive_decoding_config: optional RCD settings used by ``inference_on_data``.
        """
        self.base_model_path = model_path
        self.base_model, self.base_tokenizer = load_model_from_path(base_model_path=self.base_model_path, device=device)
        self.base_model.eval()
        self.device = self._parse_device(device)
        print(f"self.device: {self.device}")

        generation_config = dict(generation_config) if generation_config else {}
        generation_config.setdefault('max_new_tokens', 256)
        # Always record the limit: the ID/CAD decoders read it even when the caller set it.
        self.max_new_tokens = generation_config['max_new_tokens']

        generation_config["pad_token_id"] = self.base_tokenizer.pad_token_id
        if 'glm-4' in model_path:
            # Hard-coded stop token for GLM-4 checkpoints (presumably the "<|user|>"
            # turn marker) — verify id 151336 against the tokenizer.
            generation_config["eos_token"] = "<|user|>"
            generation_config["eos_token_id"] = 151336
        else:
            generation_config["eos_token"] = self.base_tokenizer.eos_token
            generation_config["eos_token_id"] = self.base_tokenizer.eos_token_id
        self.generation_config = GenerationConfig(**generation_config)
        print("Generation config: ", self.generation_config)

        self.adaptive_decoding_config = {} if adaptive_decoding_config is None else adaptive_decoding_config

    @staticmethod
    def _parse_device(device):
        """Map ``"cuda:N"`` -> ``N`` and ``"auto"`` -> ``0``; pass ints and other strings through."""
        if isinstance(device, int):
            return device
        if device == "auto":
            return 0
        if device.startswith("cuda:"):
            return int(device.split(":", 1)[1])
        return device

    @staticmethod
    def _contrastive_method(noisy):
        """Return ``'instructive'``, ``'cad'`` or ``None`` from markers in the noisy prompt(s).

        ``noisy`` may be a single string or a list of strings; a marker in any of them
        selects the method, with ``'opposite'`` checked before ``'cad'``.
        NOTE(review): plain substring matching, so words such as "decade" also match 'cad'.
        """
        if noisy is None:
            return None
        texts = [noisy] if isinstance(noisy, str) else noisy
        if any('opposite' in text for text in texts):
            return 'instructive'
        if any('cad' in text for text in texts):
            return 'cad'
        return None

    @staticmethod
    def _adaptive_kwargs(cfg):
        """Translate an RCD config dict into ``adaptive_generation*`` keyword arguments."""
        return dict(
            model_name=cfg.get('model_name'),
            train_data=cfg.get('train_data', ''),
            embed_model=cfg.get('embed_model'),
            embed_model_name=cfg.get('embed_model_name', 'all-MiniLM-L6-v2'),
            top_n=cfg.get('top_n', 10),
            alpha=cfg.get('alpha', 0.5),
            max_new_tokens=cfg.get('max_new_tokens', 256),
            chunk_mode=cfg.get('chunk_mode', 'chunk_8'),
            gamma=cfg.get('gamma', 0.01),
            num_train=cfg.get('num_train', 100),
        )

    def _encode(self, texts):
        """Tokenize a list of texts (padded) and move every tensor to ``self.device``."""
        encoded = self.base_tokenizer(texts, return_tensors="pt", padding=True)
        return {k: v.to(self.device) for k, v in encoded.items()}

    def _generate(self, base_inputs):
        """Run ``model.generate`` and decode only the newly generated tokens per row."""
        outputs = self.base_model.generate(
            **base_inputs,
            generation_config=self.generation_config,
            return_dict_in_generate=True,
        )
        # Left padding puts every prompt's end at the same column, so one slice suffices.
        prompt_len = base_inputs["input_ids"].shape[-1]
        return self.base_tokenizer.batch_decode(outputs.sequences[:, prompt_len:], skip_special_tokens=True)

    def inference_on_data(self,
                          prompts,
                          batch_size=16,  # change to 1 for debugging
                          noisy_prompts=None):
        """Run batched inference over ``prompts`` and return one response per prompt.

        ``noisy_prompts``, when given, must be aligned with ``prompts``; it is sliced
        with the same window so each batch is contrasted with its own noisy prompts.
        """
        responses = []
        for start in tqdm(range(0, len(prompts), batch_size)):
            stop = start + batch_size
            noisy_batch = None if noisy_prompts is None else noisy_prompts[start:stop]
            responses.extend(
                self.inference_on_one_batch(
                    prompts[start:stop],
                    adaptive_decoding_config=self.adaptive_decoding_config,
                    noisy_prompts=noisy_batch,
                )
            )
        return responses

    def inference_on_one_batch(self, prompts, adaptive_decoding_config=None, noisy_prompts=None):
        """Decode one batch of prompts with the decoder selected as described on the class.

        NOTE(review): the ID/CAD/RCD paths receive padded ``input_ids`` rows without an
        attention mask, so padding tokens are fed to the model as ordinary input.
        """
        base_inputs = self._encode(prompts)

        if adaptive_decoding_config:
            return adaptive_generation_batch(
                model=self.base_model,
                tokenizer=self.base_tokenizer,
                input_ids_batch=base_inputs['input_ids'],
                **self._adaptive_kwargs(adaptive_decoding_config),
            )

        method = self._contrastive_method(noisy_prompts)
        if method is not None:
            noisy_inputs = self._encode(noisy_prompts)
            if method == 'instructive':
                decode_batch, eta = instructive_generation_batch, self._ID_ETA
            else:
                decode_batch, eta = cad_generation_batch, self._CAD_ETA
            return decode_batch(
                model=self.base_model,
                tokenizer=self.base_tokenizer,
                input_ids_batch=base_inputs['input_ids'],
                noisy_input_ids_batch=noisy_inputs['input_ids'],
                eta=eta,
                max_new_tokens=self.max_new_tokens,
            )

        return self._generate(base_inputs)

    def inference_one_sample(self, prompt, adaptive_decoding_config=None, noisy_prompt=None):
        """Decode a single prompt with the decoder selected as described on the class."""
        base_inputs = self._encode([prompt])

        if adaptive_decoding_config:
            return adaptive_generation(
                model=self.base_model,
                tokenizer=self.base_tokenizer,
                input_ids=base_inputs['input_ids'],
                **self._adaptive_kwargs(adaptive_decoding_config),
            )

        method = self._contrastive_method(noisy_prompt)
        if method is not None:
            noisy_inputs = self._encode([noisy_prompt])
            if method == 'instructive':
                decode, eta = instructive_generation, self._ID_ETA
            else:
                decode, eta = cad_generation, self._CAD_ETA
            return decode(
                model=self.base_model,
                tokenizer=self.base_tokenizer,
                input_ids=base_inputs['input_ids'],
                noisy_input_ids=noisy_inputs['input_ids'],
                eta=eta,
                max_new_tokens=self.max_new_tokens,
            )

        return self._generate(base_inputs)[0]