import torch
from transformers import Qwen2Tokenizer, Qwen2ForCausalLM
from joint_model import Qwen2JointConfig, Qwen2ForJointLM, print_time
from torch.utils.data import DataLoader, Dataset
from typing import List, Tuple, Any
import logging
from utils.data_split import read_jsonl, write_jsonl
from collections import defaultdict as ddict
from examples.gsm8k.evaluation.metrics import parse_answer_2, extract_answers, em
from tqdm import tqdm
import argparse
from examples.gsm8k.po_rewrite.test_po_ft import process_generated_prompt
from peft import PeftModel, PeftConfig
import time


# Configure the root logger at import time with INFO as the minimum level.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class InferenceDataset(Dataset):
    """Map-style dataset that pairs each meta prompt with its question.

    Item ``i`` is a dict with the keys ``"meta_prompt"`` and ``"question"``
    holding the ``i``-th entries of the two lists given to the constructor.
    """

    def __init__(self, meta_prompts: List[str], questions: List[str]):
        n_prompts, n_questions = len(meta_prompts), len(questions)
        assert n_prompts == n_questions, "meta_prompts and questions must have the same length."
        self.questions = questions
        self.meta_prompts = meta_prompts

    def __len__(self):
        # The two lists are asserted to have equal length at construction.
        return len(self.meta_prompts)

    def __getitem__(self, idx):
        prompt = self.meta_prompts[idx]
        question = self.questions[idx]
        return dict(meta_prompt=prompt, question=question)


def _decode_new_tokens(tokenizer, generated_ids, prompt_len):
    """Decode only the tokens that ``generate`` appended after the prompt."""
    return tokenizer.batch_decode(generated_ids[:, prompt_len:], skip_special_tokens=True)


def batch_inference(
    model: Qwen2ForJointLM,
    tokenizer: Qwen2Tokenizer,
    meta_prompts: List[str],
    questions: List[str],
    device: torch.device,
    batch_size: int = 8,
    pg_t: float = 0,
    pg_top_p: float = 1,
    pg_do_sample: bool = False,
) -> tuple[list[Any], list[Any]]:
    """Run the two-stage pipeline over all (meta_prompt, question) pairs.

    For every batch:

    1. ``model.model1`` is run on the tokenized meta prompts; the mean of its
       last hidden state over dim 1 is passed through
       ``model.hypernets[0].condition_proj`` and installed as the condition
       variable on the up/down hypernet of every entry in ``model.hypernets``.
    2. ``model.model1.generate`` produces a prompt (up to 100 new tokens)
       using the ``pg_*`` sampling settings.
    3. ``process_generated_prompt`` combines each generated prompt with its
       question, and ``model.model2.generate`` produces the answer greedily
       (up to 256 new tokens).

    Args:
        model: Joint model exposing ``model1``, ``model2`` and ``hypernets``.
        tokenizer: Tokenizer used for both stages.
        meta_prompts: One meta prompt per example.
        questions: One question per example; same length as ``meta_prompts``.
        device: Device the model and the input tensors are moved to.
        batch_size: Number of examples per batch.
        pg_t: ``temperature`` for the prompt-generation stage.
        pg_top_p: ``top_p`` for the prompt-generation stage.
        pg_do_sample: ``do_sample`` for the prompt-generation stage.

    Returns:
        ``(all_prompts, all_answers)``: the generated prompts and the final
        answers, both in input order.
    """
    dataset = InferenceDataset(meta_prompts, questions)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    model.eval()
    model.to(device)

    # Both kwargs dicts are loop-invariant, so build them once.
    prompt_generation_kwargs = {
        "max_new_tokens": 100,
        "temperature": pg_t,
        "top_p": pg_top_p,
        "num_return_sequences": 1,
        "num_beams": 1,
        "do_sample": pg_do_sample,
        "repetition_penalty": 1.0,
        "pad_token_id": tokenizer.eos_token_id
    }
    answer_generation_kwargs = {
        "max_new_tokens": 256,
        "temperature": 0,
        "top_p": 1,
        "num_return_sequences": 1,
        "num_beams": 1,
        "do_sample": False,
        "repetition_penalty": 1.0,
        "pad_token_id": tokenizer.eos_token_id
    }

    all_answers = []
    all_prompts = []

    with torch.no_grad():
        # tqdm already reports progress; no separate per-batch print is needed.
        for batch in tqdm(dataloader):
            meta_prompt_list = batch["meta_prompt"]
            question_list = batch["question"]

            meta_encodings = tokenizer(
                meta_prompt_list,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=512
            )
            meta_input_ids = meta_encodings["input_ids"].to(device)
            meta_attention_mask = meta_encodings["attention_mask"].to(device)

            outputs1 = model.model1(
                input_ids=meta_input_ids,
                attention_mask=meta_attention_mask,
                output_hidden_states=True
            )

            # NOTE(review): the mean runs over every position of dim 1, padded
            # ones included (the attention mask is not applied here). Confirm
            # this matches how the hypernets were trained before changing it.
            last_hidden_state = outputs1.hidden_states[-1]
            condition_var = last_hidden_state.mean(dim=1)
            condition_var = model.hypernets[0].condition_proj(condition_var)
            for hypernet in model.hypernets:
                hypernet.up_hypernet.set_condition_var(condition_var)
                hypernet.down_hypernet.set_condition_var(condition_var)

            generated_ids = model.model1.generate(
                input_ids=meta_input_ids,
                attention_mask=meta_attention_mask,
                **prompt_generation_kwargs
            )

            # Drop the prompt by token position instead of str.replace on the
            # decoded text: the decoded prompt does not always equal the
            # original string (truncation, skipped special tokens), in which
            # case replace() would leave the whole prompt in the output.
            # Assumes generate() returns prompt + continuation, as decoder-only
            # Hugging Face models do.
            generated_texts = _decode_new_tokens(
                tokenizer, generated_ids, meta_input_ids.shape[1]
            )
            all_prompts.extend(generated_texts)

            concatenated_texts = [
                process_generated_prompt(generated_prompt, question)
                for generated_prompt, question in zip(generated_texts, question_list)
            ]

            concatenated_encodings = tokenizer(
                concatenated_texts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=512
            )
            concatenated_input_ids = concatenated_encodings["input_ids"].to(device)
            concatenated_attention_mask = concatenated_encodings["attention_mask"].to(device)

            generated_ids = model.model2.generate(
                input_ids=concatenated_input_ids,
                attention_mask=concatenated_attention_mask,
                **answer_generation_kwargs
            )
            final_answers = _decode_new_tokens(
                tokenizer, generated_ids, concatenated_input_ids.shape[1]
            )
            all_answers.extend(final_answers)

    return all_prompts, all_answers


def hyper_eval(data_path, predict_path, score_path, acc_path):
    """Score predictions against ground truth and record the accuracy.

    Ground-truth answers are parsed from the ``'answer'`` field of each record
    in ``data_path``; predicted answers are extracted from the
    ``'predict_full'`` field of each record in ``predict_path``. A prediction
    scores 1 when it is non-empty and ``em`` matches it against the ground
    truth, otherwise 0.

    Args:
        data_path: JSONL file with the reference data.
        predict_path: JSONL file with the model predictions.
        score_path: JSONL file the per-example ``{'score': 0|1}`` records are
            written to.
        acc_path: Text file the accuracy summary is appended to.
    """
    data = read_jsonl(data_path)
    predicts_full = read_jsonl(predict_path)

    if len(predicts_full) != len(data):
        # zip() below stops at the shorter list, so surface the mismatch
        # instead of silently scoring only part of the data.
        logger.warning(
            "Number of predictions (%d) differs from number of examples (%d); "
            "unmatched entries are not scored.",
            len(predicts_full), len(data),
        )

    gts = [parse_answer_2(d['answer']) for d in data]
    predicts = [extract_answers(d['predict_full']) for d in predicts_full]

    predicts_with_scores = [
        {'score': 1 if (p and em(p, g)) else 0}
        for p, g in zip(predicts, gts)
    ]
    correct = sum(item['score'] for item in predicts_with_scores)

    write_jsonl(predicts_with_scores, score_path)

    total = len(data)
    # An empty data file would otherwise raise ZeroDivisionError.
    accuracy = correct / total if total else 0.0
    acc_str = f"Accuracy: {correct} / {total} = {accuracy}"
    with open(acc_path, 'a+', encoding='utf-8') as f:
        f.write("------------------------------------\n")
        f.write(predict_path + '\n' + score_path + '\n' + acc_str + '\n\n')

    print(acc_str)
    

def hyper_inference_and_eval(
        data_path,
        base_output_path,
        base_model_path,
        model_name,
        base_model1_path,
        model1_lora_name,
        generate_time,
        pg_t,
        pg_top_p,
        pg_do_sample,
        batch_size=4,
):
    """Load the joint model, run batch inference, save outputs and evaluate.

    Args:
        data_path: JSONL file whose records provide ``'meta_prompt'`` and
            ``'question'`` (and are later used as references by ``hyper_eval``).
        base_output_path: Directory receiving the prompt, prediction and score
            files plus ``acc.txt``. Created if it does not exist.
        base_model_path: Directory containing the joint model checkpoints.
        model_name: Checkpoint name under ``base_model_path``; also used in
            the output file names.
        base_model1_path: Base causal-LM checkpoint the LoRA adapter is
            applied to; only used when ``model1_lora_name`` is set.
        model1_lora_name: Optional LoRA adapter; when truthy, ``model.model1``
            is replaced by the base model wrapped with this adapter.
        generate_time: Tag used as a prefix of the output file names.
        pg_t: Temperature for the prompt-generation stage (also part of the
            output file names).
        pg_top_p: ``top_p`` for the prompt-generation stage.
        pg_do_sample: Whether the prompt-generation stage samples.
        batch_size: Number of examples per inference batch.
    """
    import os

    prompt_path = f"{base_output_path}/{generate_time}_{pg_t}_prompts_{model_name}.jsonl"
    predict_path = f"{base_output_path}/{generate_time}_{pg_t}_predicts_{model_name}.jsonl"
    score_path = f"{base_output_path}/{generate_time}_{pg_t}_scores_{model_name}.jsonl"
    acc_path = f"{base_output_path}/acc.txt"

    # Create the output directory before the expensive inference run so the
    # results are not lost to a missing directory when they are written.
    if base_output_path:
        os.makedirs(base_output_path, exist_ok=True)

    model_path = f"{base_model_path}/{model_name}"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Qwen2ForJointLM.from_pretrained(model_path)
    tokenizer = model.tokenizer

    if model1_lora_name:
        # Swap the prompt-generation model for a LoRA-adapted base model.
        base_model1 = Qwen2ForCausalLM.from_pretrained(base_model1_path)
        model.model1 = PeftModel.from_pretrained(base_model1, model1_lora_name)

    data = read_jsonl(data_path)
    meta_prompts = [d['meta_prompt'] for d in data]
    questions = [d['question'] for d in data]

    print("Starting batch inference...")
    generated_prompts, generated_answers = batch_inference(
        model=model,
        tokenizer=tokenizer,
        meta_prompts=meta_prompts,
        questions=questions,
        device=device,
        batch_size=batch_size,
        pg_t=pg_t,
        pg_top_p=pg_top_p,
        pg_do_sample=pg_do_sample
    )

    prompts_full = [{"generated_prompt": x} for x in generated_prompts]
    predicts_full = [{"predict_full": x} for x in generated_answers]
    write_jsonl(prompts_full, prompt_path)
    write_jsonl(predicts_full, predict_path)

    hyper_eval(data_path, predict_path, score_path, acc_path)


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='hyper_inference')

    # data_path
    parser.add_argument("--test_data_path", type=str, default='../data/gsm8k/test_rewrite.jsonl')
    parser.add_argument("--base_output_path", type=str, default="../data/gsm8k/results")

    # model_path
    parser.add_argument("--base_model_path", type=str, default='../data/gsm8k/adapters')
    parser.add_argument("--model_name", type=str, default='hyper_params_default')
    parser.add_argument("--base_model1_path", type=str, default='')
    parser.add_argument("--model1_lora_name", type=str, default=None)

    # inference_config
    parser.add_argument("--generate_time", type=str, default='0')
    parser.add_argument("--pg_t", type=float, default=0)
    parser.add_argument("--pg_top_p", type=float, default=1)
    # str2bool rather than bool, so that "--pg_do_sample False" is really False.
    parser.add_argument("--pg_do_sample", type=str2bool, default=False)

    args = parser.parse_args()
    print('Args in experiment:')
    print(args)

    hyper_inference_and_eval(
        data_path=args.test_data_path,
        base_output_path=args.base_output_path,
        base_model_path=args.base_model_path,
        model_name=args.model_name,
        base_model1_path=args.base_model1_path,
        model1_lora_name=args.model1_lora_name,
        generate_time=args.generate_time,
        pg_t=args.pg_t,
        pg_top_p=args.pg_top_p,
        pg_do_sample=args.pg_do_sample,
    )

    print("done.")
