import os
import sys
import time

import json
import torch

from tqdm import tqdm

# The project root is read from the PROJECT_PATH environment variable; a KeyError
# is raised right here when that variable is not set. The path is appended to
# sys.path, presumably so the project-local imports that follow can be resolved.
project_root_path = os.environ["PROJECT_PATH"]
sys.path.append(project_root_path)
from Data.load_data import DatasetInfo
from prompt_pool import *
from score import OutputScoreInfo

from typing import Tuple


# Run on CUDA whenever torch reports it as available; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Root directory for everything this script writes (placeholder value — point it
# at a real location before running).
storage_root_path = "/your/path/to/data-feature-storage"

# Ids of the two marker tokens searched for by retrieve_think_token_index
# (presumably the tokenizer ids of the opening/closing "think" tags — confirm
# against the tokenizer in use). Exactly one pair must be active.
#
# Active pair — Deepseek models:
start_token = 151648
end_token = 151649
# Alternative pair — LLaMA models:
# start_token, end_token = 128013, 128014
# Alternative pair — QwQ models:
# start_token, end_token = 151667, 151668


def retrieve_think_token_index(output_ids: torch.Tensor, start_id=None, end_id=None) -> Tuple[int, int]:
    """Locate the start marker and the end marker that closes it in a token sequence.

    Args:
        output_ids: 1-D tensor of token ids.
        start_id: Id of the opening marker. Defaults to the module-level
            ``start_token`` when omitted.
        end_id: Id of the closing marker. Defaults to the module-level
            ``end_token`` when omitted.

    Returns:
        ``(start_index, end_index)``: the position of the first start marker
        and the position of the first end marker located after it.

    Raises:
        IndexError: If the start marker is absent, or if no end marker
            follows it.
    """
    # Resolve the defaults lazily so explicit ids never touch the module globals.
    start_id = start_token if start_id is None else start_id
    end_id = end_token if end_id is None else end_id

    start_positions = (output_ids == start_id).nonzero(as_tuple=True)[0]
    if start_positions.numel() == 0:
        raise IndexError(f"start token {start_id} not found in output_ids")
    start_index = int(start_positions[0])

    # Only an end marker *after* the start marker can close the span; pairing
    # with an earlier one would produce a negative span length.
    end_positions = (output_ids == end_id).nonzero(as_tuple=True)[0]
    end_positions = end_positions[end_positions > start_index]
    if end_positions.numel() == 0:
        raise IndexError(
            f"end token {end_id} not found after start token {start_id} "
            f"(position {start_index}) in output_ids"
        )
    end_index = int(end_positions[0])

    return start_index, end_index


def extract_hidden_states_pt(hidden_states, output_len):
    """Collect one hidden vector per generation step and layer into a CPU tensor.

    Args:
        hidden_states: Indexed as ``hidden_states[step][layer]``; each entry is
            a tensor indexed ``[batch, position, hidden]``. Assumed to follow
            the layout returned by ``generate(..., output_hidden_states=True)``,
            where the first step spans every prompt position and later steps
            span a single position — TODO confirm for the model in use.
        output_len: Number of leading generation steps to collect.

    Returns:
        Tensor of shape ``[num_layers, output_len, hidden_size]`` on the CPU,
        built from batch element 0 only.
    """
    layer_num = len(hidden_states[0])  # e.g., 29
    hs_all_layers = []

    for layer in range(layer_num):
        # Take the LAST position of each step. For single-position steps this is
        # the only position; for a step that spans the whole prompt it is the
        # position that produces the next token, whereas index 0 would pick the
        # very first prompt token.
        step_vectors = [
            hidden_states[pos][layer][0, -1]  # shape: [hidden_size]
            for pos in range(output_len)
        ]
        # Stack on the source device and transfer once per layer instead of
        # once per (layer, step) pair.
        hs_all_layers.append(torch.stack(step_vectors, dim=0).cpu())  # [output_len, hidden_size]

    return torch.stack(hs_all_layers, dim=0)  # [num_layers, output_len, hidden_size]


class Inference:
    """Generate an answer for every dataset sample and store the results.

    For each sample a prompt is built from ``DATASET_PROMPTS``, the model
    generates a continuation, and — when the returned token sequence contains
    ``end_token`` — a JSON line and four mean-pooled hidden-state tensors are
    written below ``storage_root_path/<model_name>/<DATASET_PROMPTS_KEY>/``.
    """

    # Folder-name suffixes of the four pooled hidden-state variants saved per sample.
    _FEATURE_VARIANTS = (
        "with_special_token",
        "without_special_token",
        "last_token",
        "output_tokens",
    )

    def __init__(self, model_info: dict, dataset_info: dict):
        """Unpack the model/dataset settings and load the dataset.

        Args:
            model_info: Must contain the keys "model_ckpt", "model_name",
                "model_config", "generation_config", "tokenizer",
                "max_output_token" and "do_sample".
            dataset_info: Must contain the keys "dataset_name" and "language".
        """
        self.model_info = model_info
        self.dataset_info = dataset_info

        self.model = self.model_info["model_ckpt"]
        self.model_name = self.model_info["model_name"]
        self.config = self.model_info["model_config"]
        self.generation_config = self.model_info["generation_config"]
        self.tokenizer = self.model_info["tokenizer"]
        self.max_output_token = self.model_info["max_output_token"]

        self.do_sample = self.model_info["do_sample"]

        self.dataset_name = self.dataset_info["dataset_name"]

        self.language = self.dataset_info["language"]

        # Text before the first "_" of the dataset name, e.g. "theoremqa".
        self.dataset_prefix = self.dataset_name.split("_")[0]

        # Key into DATASET_PROMPTS; non-English runs get a "_<language>" suffix.
        if self.language != "en":
            self.DATASET_PROMPTS_KEY = f"{self.dataset_prefix}_{self.language}"
        else:
            self.DATASET_PROMPTS_KEY = self.dataset_prefix

        self.data_loader = DatasetInfo(self.dataset_name)
        self.data_all = self.data_loader.data
        self.data_size = self.data_loader.data_size

        # Scratch record of the sample currently being processed.
        self.sample_info = {}

    def dataset_inference(self):
        """Run inference over the whole dataset (delegates to ``greedy_inference``)."""
        self.greedy_inference()

    def greedy_inference(self):
        """Generate for every sample and persist text, score statistics and features.

        Results go to ``<storage_root_path>/<model_name>/<DATASET_PROMPTS_KEY>/``:
        one JSON line per stored sample in ``<dataset_name>.jsonl`` (opened in
        append mode, so lines from earlier runs are kept) plus one ``.pt`` file
        per sample in each of the four feature sub-folders. Samples whose
        returned sequence does not contain ``end_token`` are skipped.
        """
        storage_folder = os.path.join(storage_root_path, self.model_name)
        os.makedirs(storage_folder, exist_ok=True)

        output_folder = os.path.join(storage_folder, self.DATASET_PROMPTS_KEY)
        os.makedirs(output_folder, exist_ok=True)

        output_jsonl_path = os.path.join(output_folder, f"{self.dataset_name}.jsonl")

        with open(output_jsonl_path, 'a', encoding='utf-8') as f:

            for i in tqdm(range(self.data_size)):
                sample = self.data_all[i]
                idx = sample["id"]
                print("*"*30 + f" index {str(idx)} " + "*"*30)

                input_data, output_data, model_input, input_ids = self.parse_input(sample)
                self.sample_info = {
                    "input": {
                        "raw_input_data": input_data,
                        "model_input": model_input,
                        "model_input_ids": input_ids,
                    },
                    "output": {
                        "raw_output_data": output_data,
                    }
                }

                with torch.no_grad():
                    generation_output = self.model_inference()
                    output_info = self.sample_info["output"]
                    # Indexed [generation step][layer] -> tensor [batch, position, hidden].
                    output_info["all_token_hidden_states"] = generation_output.hidden_states
                    output_info["output_scores"] = generation_output.scores
                    output_info["output_len"] = min(self.max_output_token, len(generation_output.scores))
                    output_info["output_seq"] = generation_output.sequences

                    output_ids = generation_output.sequences[0]  # cuda tensor

                    if (output_ids == end_token).any():
                        # ========== 1. save output texts ==========
                        self._write_json_record(f, sample, idx, output_ids)
                        # ========== 2.-5. save pooled hidden states ==========
                        self._save_hidden_state_features(output_folder, idx, output_ids)

                    # Drop every reference to the per-step hidden states for ALL
                    # samples, including skipped ones. The local
                    # `generation_output` must go too: otherwise it keeps the
                    # hidden states and scores alive while the next sample is
                    # generated, and empty_cache() has nothing to release.
                    output_info.pop("all_token_hidden_states", None)
                    del generation_output, output_ids
                    torch.cuda.empty_cache()

    def _write_json_record(self, f, sample, idx, output_ids):
        """Append one JSON line with the decoded output and score statistics of the current sample."""
        output_text = self.tokenizer.decode(output_ids, skip_special_tokens=False)
        output_seq, maxprob, ppl, entropy = self.print_output()
        output = {
            'id': idx,
            # Same condition parse_input uses when it fills "{answer_type}" into
            # the prompt, so language-suffixed dataset names are covered as well.
            'answer_type': sample["answer_type"] if self.dataset_prefix == "theoremqa" else "",
            'input_seq': self.sample_info["input"]["model_input"],
            'output_seq': output_seq,
            'output_text': output_text,
            'maxprob': maxprob,
            'ppl': ppl,
            'entropy': entropy,
        }

        print(f"Writing JSON for ID {idx}")
        f.write(json.dumps(output, ensure_ascii=False) + '\n')
        # Flush per sample so finished records reach the file even if a later sample fails.
        f.flush()

    def _save_hidden_state_features(self, output_folder, idx, output_ids):
        """Mean-pool four token ranges of the current sample's hidden states and save each as a ``.pt`` file."""
        start_index, end_index = retrieve_think_token_index(output_ids)
        cot_length = end_index - start_index
        print("CoT length: ", cot_length)

        output_info = self.sample_info["output"]
        # hs_tensor_all: [layer, token, hidden]
        hs_tensor_all = extract_hidden_states_pt(
            output_info["all_token_hidden_states"], output_info["output_len"]
        )
        print("Before slicing: ", hs_tensor_all.shape)

        key = self.DATASET_PROMPTS_KEY
        paths = {
            variant: os.path.join(output_folder, f"{key}_{variant}", f"{key}_{idx}.pt")
            for variant in self._FEATURE_VARIANTS
        }
        for path in paths.values():
            os.makedirs(os.path.dirname(path), exist_ok=True)

        # NOTE(review): start_index/end_index are positions in `output_ids`,
        # while hs_tensor_all is indexed by generation step. The slices below
        # line up with the marker tokens only if the start marker is produced
        # at the first generation step — confirm for the chat template in use.

        # ========== 2. with special token ==========
        start_extract_time = time.time()
        hs_w_think = hs_tensor_all[:, :cot_length + 1, :]  # [layer, cot_length+1, hidden]
        end_extract_time = time.time()
        print(f"1 - With special token: {hs_w_think.shape} | Extract time: {end_extract_time - start_extract_time:.4f}s")

        start_save_time = time.time()
        torch.save(hs_w_think.mean(dim=1).cpu(), paths["with_special_token"])  # → [layer, hidden]
        end_save_time = time.time()
        print(f"Saved with special token | Save time: {end_save_time - start_save_time:.4f}s")

        # ========== 3. without special token ==========
        hs_wo_think = hs_tensor_all[:, 1:cot_length, :]  # [layer, cot_length-1, hidden]
        print(f"2 - Without special token: {hs_wo_think.shape}")
        torch.save(hs_wo_think.mean(dim=1).cpu(), paths["without_special_token"])  # → [layer, hidden]

        # ========== 4. last token ==========
        hs_last_token = hs_tensor_all[:, -1:, :]  # [layer, 1, hidden]
        print(f"3 - Last token: {hs_last_token.shape}")
        torch.save(hs_last_token.mean(dim=1).cpu(), paths["last_token"])  # → [layer, hidden]

        # ========== 5. all output tokens ==========
        print(f"4 - All output tokens: {hs_tensor_all.shape}")
        torch.save(hs_tensor_all.mean(dim=1).cpu(), paths["output_tokens"])  # → [layer, hidden]

    def model_inference(self):
        """Call ``generate`` on the current sample's input ids and return its output object.

        Hidden states and scores are requested; attentions are not. For model
        names containing "Llama", the "<|eot_id|>" token is accepted as an
        extra terminator next to the tokenizer's EOS token. The EOS id is also
        used as the padding id.
        """
        input_ids = self.sample_info["input"]["model_input_ids"]
        self.model.eval()
        terminators = [self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids("<|eot_id|>")] \
            if "Llama" in self.model_name else self.tokenizer.eos_token_id

        time_start = time.time()
        generation_output = self.model.generate(
            input_ids=input_ids.to(device),
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=terminators,
            generation_config=self.generation_config,
            return_dict_in_generate=True,
            max_new_tokens=self.max_output_token,
            output_attentions=False,
            output_hidden_states=True,
            output_scores=True,
            do_sample=self.do_sample,
        )
        time_end = time.time()
        print(f'inference time: {round(time_end - time_start, 4)}')

        output_ids = generation_output.sequences
        print('\n')
        print("Output IDs:", output_ids)

        return generation_output

    def parse_input(self, sample):
        """Build the prompt for one sample and tokenize it with the chat template.

        Args:
            sample: Record providing the question under the configured language
                key, the reference answer under "answer" and, for theoremqa
                datasets, "answer_type".

        Returns:
            ``(input_data, output_data, model_input, input_ids)``: the raw
            question, the reference answer, the filled-in prompt text and the
            tokenized single-turn chat input.
        """
        input_data = sample[self.language]
        output_data = sample["answer"]

        model_input = DATASET_PROMPTS[self.DATASET_PROMPTS_KEY].replace("{input_data}", input_data)

        if self.dataset_prefix == "theoremqa":
            model_input = model_input.replace("{answer_type}", sample["answer_type"])
        input_ids = self.tokenizer.apply_chat_template([{"role": "user", "content": model_input}],
                            tokenize=True, add_generation_prompt=True, return_tensors="pt")
        input_len = len(input_ids[0])

        print(f"********** Input Text (length: {input_len}) **********\n{input_data}\n")
        print(f"********** Input ID **********\n{input_ids}\n")

        return input_data, output_data, model_input, input_ids

    def print_output(self):
        """Decode the generated tail of the current sample and compute score statistics.

        Returns:
            ``(output_seq, maxprob, ppl, entropy)``: the decoded last
            ``output_len`` tokens of the sequence and the three values computed
            by ``OutputScoreInfo`` from the generation scores.
        """
        output_scores = self.sample_info["output"]["output_scores"]
        output_seq = self.sample_info["output"]["output_seq"]
        true_output = self.sample_info["output"]["raw_output_data"]
        output_len = self.sample_info["output"]["output_len"]

        # Keep only the trailing output_len ids of the returned sequence.
        output_seq = self.tokenizer.decode(output_seq[0][-output_len:])
        print(f"********** Model-generated Text (length: {output_len}) **********\n{output_seq}\n")
        print(f"********** True Output Text **********\n{true_output}\n")

        outputinfo = OutputScoreInfo(output_scores)
        maxprob = outputinfo.compute_maxprob()
        ppl = outputinfo.compute_ppl()
        entropy = outputinfo.compute_entropy()
        print(f"********** Output Info: **********\nmaxprob {maxprob}; perplexity {ppl}; entropy {entropy}\n")

        return output_seq, maxprob, ppl, entropy