import gc
import torch
import pandas as pd 
from vllm import LLM
from peft import PeftModel
from tqdm.autonotebook import tqdm 

from .model_utils import create_tokenizer
from .prompt_hub import system_prompts, confidence_prompts
from .data_utils import LabeledStringDataCollator, get_loader


def wrapped_generate_output(
    model,
    tokenizer,
    generation_inputs,
    generation_config,
    track_logits=False,
    target_idx=None,
    top_k=10):
    """Run ``model.generate`` and optionally collect per-step scores.

    Args:
        model: Object exposing a Hugging Face style ``generate`` method.
        tokenizer: Tokenizer whose ``eos_token_id`` is used as the stop token.
        generation_inputs: Keyword inputs for ``generate`` (e.g. ``input_ids``,
            ``attention_mask``), already on the target device.
        generation_config: Passed through to ``generate``.
        track_logits: When true, request a dict output with ``scores`` and
            return the selected score values alongside the sequences.
        target_idx: Vocabulary index (or indices) to read scores for. When
            ``None``, the ``top_k`` highest-scoring entries per step are kept.
        top_k: Number of entries kept per step when ``target_idx`` is ``None``.
            It is clamped to the vocabulary size.

    Returns:
        When ``track_logits`` is false, whatever ``generate`` returns with
        ``return_dict_in_generate=False``, which is presumably the tensor of
        generated token ids.

        Otherwise, a tuple ``(sequences, logit_values, target_idx)``:
        ``logit_values`` is a nested Python list indexed
        ``[batch][step][...]``. ``target_idx`` is either the caller's value
        unchanged or, in top-k mode, the matching nested list of indices.
    """
    generation_outputs = model.generate(
        **generation_inputs,
        eos_token_id=[tokenizer.eos_token_id],
        generation_config=generation_config,
        return_dict_in_generate=bool(track_logits),
        output_scores=bool(track_logits),
    )
    if not track_logits:
        return generation_outputs

    # Per the HF generate docs, ``scores`` is a tuple with one
    # (batch, vocab) tensor per generated step. Stacking on dim=1 gives
    # (batch, steps, vocab).
    # NOTE(review): ``output_scores`` yields *processed* scores (after
    # temperature / top-k / top-p processors), not raw logits. Confirm that is
    # what downstream confidence analysis expects.
    logits = torch.stack(generation_outputs.scores, dim=1)
    if target_idx is not None:
        logit_values = logits[:, :, target_idx]
    else:
        # Clamp k so small vocabularies don't make topk raise.
        k = min(top_k, logits.size(-1))
        logit_values, target_idx = torch.topk(logits, k=k, dim=-1)
        target_idx = target_idx.detach().cpu().tolist()
    logit_values = logit_values.detach().cpu().tolist()
    return (generation_outputs.sequences, logit_values, target_idx)


def generate_outputs(
    accelerator,
    model,
    tokenizer,
    system_prompt,
    loader,
    generation_config,
    input_col_name="prompt",
    skip_template=False,
    think_mode=True,
    track_logits=False,
    target_idx=None,):
    """Generate a decoded completion for every example yielded by ``loader``.

    Each batch's ``input_col_name`` field is collated with
    ``LabeledStringDataCollator`` and moved to ``accelerator.device``. It is
    then passed to ``wrapped_generate_output``. Only the newly generated
    tokens are decoded: the first ``input_ids.size(-1)`` positions of each
    output row are sliced off.

    Args:
        accelerator: Object whose ``device`` the inputs are moved to.
        model: Generation model. If it is a ``PeftModel``, the ``"query"``
            adapter is activated before generating.
        tokenizer: Used for collation and decoding.
        system_prompt, skip_template, think_mode: Forwarded to the collator.
        loader: Iterable of batches keyed by ``input_col_name``.
        generation_config: Forwarded to ``generate``.
        input_col_name: Batch key holding the raw prompt strings.
        track_logits: Also collect per-step score values.
        target_idx: Fixed vocabulary index/indices to read scores for. When
            ``None``, the top-k entries per step are recorded instead.

    Returns:
        A list of decoded strings. When ``track_logits`` is true, a dict with
        ``"outputs"`` (the strings), ``"logits"`` (one entry per example) and
        ``"target_idx"`` (one entry per example) instead.
    """
    collate_fn = LabeledStringDataCollator(tokenizer,
                                           system_prompt=system_prompt,
                                           skip_template=skip_template,
                                           think_mode=think_mode)

    results = []
    results_logits, results_idx = [], []
    for inputs in tqdm(loader):
        inputs = inputs[input_col_name]
        generation_inputs = {
            k: v.to(accelerator.device) for k, v in collate_fn(inputs).items()
            }

        if isinstance(model, PeftModel):
            model.set_adapter("query")

        all_outputs = wrapped_generate_output(
            model,
            tokenizer,
            generation_inputs,
            generation_config,
            track_logits,
            target_idx,
            )

        if track_logits:
            # Use a loop-local name: rebinding ``target_idx`` here would feed
            # the previous batch's top-k indices back in as a "fixed" index.
            generation_outputs, logit_values, batch_idx = all_outputs
            results_logits.extend(logit_values)
            if target_idx is None:
                # Top-k mode: ``batch_idx`` already has one entry per example.
                results_idx.extend(batch_idx)
            else:
                # Fixed-index mode: the same indices apply to every example.
                # Replicate them so the result stays aligned with ``results``.
                results_idx.extend([batch_idx] * len(logit_values))
        else:
            generation_outputs = all_outputs

        prompt_len = generation_inputs["input_ids"].size(-1)
        generations = tokenizer.batch_decode(
            generation_outputs[:, prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
            )
        results.extend(generations)

        # Release per-batch generation tensors before the next batch.
        del generation_outputs
        gc.collect()
        torch.cuda.empty_cache()

    if track_logits:
        results = {
            "outputs": results,
            "logits": results_logits,
            "target_idx": results_idx
        }
    return results


def _vllm_generate_in_batches(model, prompts, batch_size, sampling_params):
    """Generate one completion per prompt with vLLM, ``batch_size`` at a time.

    Returns the text of ``outputs[0]`` for every request, in input order.
    The first prompt and its completion are printed once as a sanity check.
    """
    texts = []
    for start in tqdm(range(0, len(prompts), batch_size)):
        # Slicing clamps at the end of the list, so the final short batch
        # needs no special case.
        batch = prompts[start:start + batch_size]
        request_outputs = model.generate(batch,
                                         sampling_params,
                                         use_tqdm=False)
        texts.extend(out.outputs[0].text for out in request_outputs)

        if start == 0:
            print(f"Input prompt:\n {prompts[0]}")
            print(f"Output:\n {texts[0]}")
    return texts


def generate_outputs_vllm(
    model_name,
    base_prompt,
    dataset,
    batch_size,
    sampling_params,
    skip_template=False,
    suffix=False,
    think_mode=True):
    """Answer every question in ``dataset`` with a vLLM engine.

    Args:
        model_name: Model id/path, used to load the tokenizer and ``LLM`` and
            to pick a system prompt (matched case-insensitively).
        base_prompt: Prompt template containing a ``<question>`` placeholder.
            Unless ``skip_template`` is set, it is first wrapped in the model's
            chat template.
        dataset: Iterable of question strings substituted for ``<question>``.
        batch_size: Number of prompts per ``LLM.generate`` call.
        sampling_params: Passed to ``LLM.generate``.
        skip_template: Use ``base_prompt`` verbatim instead of applying the
            chat template.
        suffix: Also run a follow-up confidence turn on each answer.
        think_mode: Forwarded as ``enable_thinking`` to the Qwen3 template
            and to the confidence-turn template.

    Returns:
        A DataFrame with columns ``question`` (the fully formatted prompts)
        and ``pred_answer``. When ``suffix`` is true, it also has
        ``conf_input`` and ``conf``.
    """
    name = model_name.lower()
    if "qwen2" in name:
        system_prompt = system_prompts['qwen']
    elif "llama" in name:
        system_prompt = system_prompts['llama']
    else:
        system_prompt = None

    tokenizer = create_tokenizer(model_name)
    model = LLM(model_name)

    if not skip_template:
        if "qwen3" in name:
            base_prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": base_prompt}],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=think_mode
            )
        else:
            msgs = [{"role": "user", "content": base_prompt}]
            # Jinja chat templates render a ``None`` content literally as
            # "None", so only add a system turn when there is one.
            if system_prompt is not None:
                msgs.insert(0, {"role": "system", "content": system_prompt})
            base_prompt = tokenizer.apply_chat_template(
                msgs,
                tokenize=False,
                add_generation_prompt=True
                )

    input_prompts = [base_prompt.replace("<question>", q) for q in dataset]

    answers = _vllm_generate_in_batches(
        model, input_prompts, batch_size, sampling_params)

    df = pd.DataFrame({
        "question": input_prompts,
        "pred_answer": answers
    })

    if suffix:
        print("Generating confidence scores...")
        # The follow-up user turn is identical for every row; render it once.
        conf_turn = tokenizer.apply_chat_template(
            [{"role": "user", "content": confidence_prompts['suffix']}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=think_mode)
        conf_inputs = [q + p + conf_turn
                       for q, p in zip(input_prompts, answers)]
        conf_outputs = _vllm_generate_in_batches(
            model, conf_inputs, batch_size, sampling_params)

        df['conf_input'] = conf_inputs
        df['conf'] = conf_outputs
    return df


def generate_outputs_base(
    accelerator,
    model,
    tokenizer,
    base_prompt,
    dataset,
    batch_size,
    generation_config,
    skip_template=False,
    track_logits=False,
    target_idx=None,
    suffix=False,
    think_mode=True):
    """Answer every question in ``dataset`` with a Hugging Face model.

    Each ``question`` is substituted into ``base_prompt`` at ``<question>``,
    overwriting the column, and answered via ``generate_outputs``. When
    ``suffix`` is set, a second pass asks for a confidence score. That pass
    replays the templated question and the answer, then appends the
    ``confidence_prompts['suffix']`` turn.

    Args:
        accelerator: Passed to ``get_loader`` and ``generate_outputs``.
        model: Model with a ``name_or_path`` attribute, which picks the system
            prompt (matched case-insensitively).
        tokenizer: Used for collation, decoding and chat templating.
        base_prompt: Template containing a ``<question>`` placeholder.
        dataset: Presumably a Hugging Face ``datasets.Dataset`` with
            ``question`` and ``answer`` columns. ``map`` and ``add_column``
            are called on it.
        batch_size: Loader batch size for both passes.
        generation_config: Forwarded to ``generate``.
        skip_template: Forwarded to the collator for the answer pass.
        track_logits: Record score values for the *confidence* pass.
        target_idx: Fixed vocabulary index/indices for confidence scores.
            ``None`` records top-k entries instead.
        suffix: Run the confidence pass.
        think_mode: Forwarded as ``enable_thinking`` / to the collator.

    Returns:
        A DataFrame with ``question``, ``pred_answer`` and ``true_answer``.
        When ``suffix`` is set, it also has ``conf``, and with
        ``track_logits`` also ``conf_logits`` and ``conf_target_idx``.
    """
    model_name = model.name_or_path.lower()
    if "qwen2" in model_name:
        system_prompt = system_prompts['qwen']
    elif "llama" in model_name:
        system_prompt = system_prompts['llama']
    else:
        system_prompt = None

    dataset = dataset.map(
        lambda x: {"question":
            base_prompt.replace("<question>", x["question"])
            })

    loader = get_loader(dataset, batch_size=batch_size,
                pin_memory=True, accelerator=accelerator)

    # Answers are generated without logit tracking. Only confidence-pass
    # logits are stored in the returned frame, and tracking here would make
    # ``generate_outputs`` return a dict instead of the list of strings used
    # for the ``pred_answer`` column below.
    outputs = generate_outputs(
        accelerator,
        model,
        tokenizer,
        system_prompt,
        loader,
        generation_config,
        input_col_name="question",
        skip_template=skip_template,
        think_mode=think_mode,
        track_logits=False,
        target_idx=None)

    df = pd.DataFrame({
        "question": [d["question"] for d in dataset],
        "pred_answer": outputs,
        "true_answer": [d["answer"] for d in dataset]
        })
    if outputs:
        print(outputs[0])
    dataset = dataset.add_column("pred_answer", outputs)

    if suffix:
        print("Generating confidence scores...")
        # Drop lingering references first, then return cached GPU memory.
        gc.collect()
        torch.cuda.empty_cache()

        def _first_turn_messages(question):
            # Jinja chat templates render a ``None`` content literally as
            # "None", so only include a system turn when there is one.
            msgs = [{"role": "user", "content": question}]
            if system_prompt is not None:
                msgs.insert(0, {"role": "system", "content": system_prompt})
            return msgs

        # The follow-up user turn is identical for every row; render it once.
        conf_turn = tokenizer.apply_chat_template(
            [{"role": "user", "content": confidence_prompts['suffix']}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=think_mode)

        conf_dataset = dataset.map(
            lambda x: {"conf_input":
                tokenizer.apply_chat_template(
                    _first_turn_messages(x["question"]),
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=think_mode
                    )
                + x["pred_answer"] + conf_turn
            }
        )

        conf_loader = get_loader(conf_dataset, batch_size=batch_size,
                    pin_memory=True, accelerator=accelerator)

        # ``conf_input`` is already fully chat-templated (skip_template=True),
        # so no system prompt is passed to the collator.
        # NOTE(review): assumes the collator ignores ``system_prompt`` when
        # ``skip_template`` is set — confirm in LabeledStringDataCollator.
        conf_outputs = generate_outputs(
            accelerator,
            model,
            tokenizer,
            None,
            conf_loader,
            generation_config,
            input_col_name="conf_input",
            skip_template=True,
            think_mode=think_mode,
            track_logits=track_logits,
            target_idx=target_idx
        )

        if track_logits:
            df['conf'] = conf_outputs['outputs']
            df['conf_logits'] = conf_outputs['logits']
            df['conf_target_idx'] = conf_outputs['target_idx']
            if conf_outputs['outputs']:
                print(conf_outputs['outputs'][0])
        else:
            df['conf'] = conf_outputs
            if conf_outputs:
                print(conf_outputs[0])

    return df