import os
import json
import einops
import torch
import re

from .utils import (
    generate_api,
    get_mask_by_text,
    get_mask_by_role,
    apply_repetition_penalty,
    preprocess_ctx,
)
from ..utils import remove_date_block, DATE_BLOCK_PATTERN
from ..entropy.calculate_entropy import calculate_entropy
from .cad_baseline import CAD

# Module-level default device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_greedy_response(tokenizer, model, model_family, input_ids, args):
    """Produce a response through ``generate_api`` and report its lengths.

    Args:
        tokenizer: Tokenizer handed to ``generate_api`` and used for decoding.
        model: Model handed to ``generate_api``.
        model_family: Family name; the last ``generate_api`` argument is
            ``True`` for every family except ``"gemma3"``.
        input_ids: Tokenized prompt exposing an ``input_ids`` tensor.
        args: Run configuration; only ``response_len`` is read here.

    Returns:
        dict with the decoded ``response``, ``length_response`` (number of
        generated ids) and ``length_full`` (prompt length + generated ids).
    """
    prompt_len = input_ids.input_ids.shape[1]
    total_budget = prompt_len + args.response_len
    family_flag = model_family != "gemma3"

    with torch.no_grad():
        # NOTE(review): the positional False/None presumably mean "no sampling,
        # no top-p" -- confirm against generate_api's signature.
        generation = generate_api(tokenizer, model, input_ids, False, None, total_budget, family_flag)

    # Keep only the ids produced after the prompt.
    new_ids = generation.sequences[0][prompt_len:]
    n_new = len(new_ids)
    return {
        "response": tokenizer.decode(new_ids, skip_special_tokens=True),
        "length_response": n_new,
        "length_full": prompt_len + n_new,
    }


def generate_nucleus_response(tokenizer, model, model_family, input_ids, args):
    """Produce a sampled response through ``generate_api`` and report its lengths.

    Args:
        tokenizer: Tokenizer handed to ``generate_api`` and used for decoding.
        model: Model handed to ``generate_api``.
        model_family: Family name; the last ``generate_api`` argument is
            ``True`` for every family except ``"gemma3"``.
        input_ids: Tokenized prompt exposing an ``input_ids`` tensor.
        args: Run configuration; only ``response_len`` is read here.

    Returns:
        dict with the decoded ``response``, ``length_response`` (number of
        generated ids) and ``length_full`` (prompt length + generated ids).
    """
    prompt_len = input_ids.input_ids.shape[1]
    total_budget = prompt_len + args.response_len
    family_flag = model_family != "gemma3"

    with torch.no_grad():
        # NOTE(review): the positional True/0.95 presumably mean "sample with
        # top-p = 0.95" -- confirm against generate_api's signature.
        generation = generate_api(tokenizer, model, input_ids, True, 0.95, total_budget, family_flag)

    # Keep only the ids produced after the prompt.
    new_ids = generation.sequences[0][prompt_len:]
    n_new = len(new_ids)
    return {
        "response": tokenizer.decode(new_ids, skip_special_tokens=True),
        "length_response": n_new,
        "length_full": prompt_len + n_new,
    }


def select_best_token_heads(
    model, input_embeds, last_token_logits, top_probs, top_indices, attn_indexes, args, relevance_mask=None
):
    """Pick the candidate token with the largest relevance over selected attention heads.

    For every candidate the candidate's logit is back-propagated and a
    gradient-times-activation score is accumulated over the first
    ``args.top_heads_num`` entries of ``attn_indexes``.

    Args:
        model: Model exposing ``config.num_attention_heads`` and
            ``model.layers[i].self_attn.{attn_output, attn_weights}`` tensors
            whose ``.grad`` is populated by ``backward``.
        input_embeds: Input embeddings; only their gradient is zeroed here.
        last_token_logits: Logits indexed by the candidate token ids.
        top_probs: Unused by this function.
        top_indices: Candidate token ids.
        attn_indexes: Flat head indexes, decoded as
            ``layer = idx // num_heads`` and ``head = idx % num_heads``.
        args: Configuration; ``top_heads_num`` and ``clamp`` are read.
        relevance_mask: Optional mask; when given, attention-weight relevance
            is restricted to positions where the mask equals 1. When ``None``
            the relevance of the head's attention output is used instead.

    Returns:
        The element of ``top_indices`` with the highest accumulated score
        (the sole element when there is only one candidate).
    """
    if len(top_indices) == 1:  # nothing to rank
        return top_indices[0]

    heads_per_layer = model.config.num_attention_heads
    # Decode flat indexes into (layer, head) pairs once, up front.
    chosen_heads = [
        (flat_idx // heads_per_layer, flat_idx % heads_per_layer) for flat_idx in attn_indexes[: args.top_heads_num]
    ]

    candidate_scores = []
    for candidate in top_indices:
        last_token_logits[candidate].backward(retain_graph=True)

        accumulated = 0.0
        for layer_idx, head in chosen_heads:
            attention = model.model.layers[layer_idx].self_attn
            if relevance_mask is not None:
                weights = attention.attn_weights
                per_head = (weights.grad * weights)[:, head, ...].float()
                if args.clamp:
                    per_head = per_head.clamp(min=0)
                # Sum over dim 1, then keep only the masked-in positions of the first batch row.
                accumulated += per_head.sum(1)[0, relevance_mask == 1].sum().item()
            else:
                output = attention.attn_output
                per_head = (output * output.grad).float()[..., head, :]
                if args.clamp:
                    per_head = per_head.clamp(min=0)
                accumulated += per_head.sum().item()

        # Clear gradients so the next candidate starts from a clean slate.
        for block in model.model.layers:
            block.self_attn.attn_weights.grad = None
            block.self_attn.attn_output.grad = None
        input_embeds.grad.zero_()
        candidate_scores.append(accumulated)

    winner = torch.tensor(candidate_scores).argmax().item()
    return top_indices[winner]


def generate_contex_aware_response(input_ids, relevance_mask, args, cad=None):
    """Generate a response with the ``CAD`` baseline.

    Args:
        input_ids: Tokenized prompt exposing an ``input_ids`` tensor.
        relevance_mask: Mask forwarded unchanged to ``cad.generate``.
        args: Run configuration; ``model_name`` and ``response_len`` are read
            here and the whole object is forwarded to ``cad.generate``.
        cad: Optional ready-made ``CAD`` instance; a new one is built on the
            module-level ``device`` when this is falsy.

    Returns:
        dict with the decoded ``response``, ``length_response`` (length of the
        first generated sequence) and ``length_full`` (prompt + generated).
    """
    prompt_len = input_ids.input_ids.shape[1]
    cad = cad if cad else CAD(model_name=args.model_name, device=device)
    # The EOS token doubles as the padding token for generation.
    cad.tokenizer.pad_token = cad.tokenizer.eos_token

    generated = cad.generate(
        input_ids,
        relevance_mask,
        max_length=args.response_len,
        alpha=1,
        decoding_strategy="top_p",
        top_p_value=0.9,
        use_repetition_penalty=False,
        args=args,
    )
    first_sequence = generated[0]
    n_generated = len(first_sequence)
    return {
        "response": cad.tokenizer.decode(first_sequence, skip_special_tokens=True),
        "length_response": n_generated,
        "length_full": prompt_len + n_generated,
    }


def select_best_token(model, input_embeds, last_token_logits, top_probs, top_indices, relevance_mask, args):
    """Re-rank candidate tokens by how much relevance they assign to the masked context.

    For each candidate the candidate's logit is back-propagated and a
    relevance map is computed, either from the last layer's attention weights
    (``args.explanation_method == "attnLRP_gradcam"``) or as
    gradient-times-input on ``input_embeds``. The relevance landing on
    positions where ``relevance_mask == 1`` is aggregated into one score per
    candidate, and the final choice mixes candidate probability and
    softmax-normalised relevance with weight ``args.alpha``.

    Args:
        model: Model; only ``model.model.layers[-1].self_attn.attn_weights``
            is read, and only for the ``"attnLRP_gradcam"`` method.
        input_embeds: Input embeddings whose ``.grad`` is filled by ``backward``.
        last_token_logits: Logits indexed by the candidate token ids.
        top_probs: Probabilities of the candidates, aligned with ``top_indices``.
        top_indices: Candidate token ids, in the order they should be scored.
        relevance_mask: Mask; positions equal to 1 count as context.
        args: Configuration; reads ``explanation_method``, ``clamp``,
            ``agg_method``, ``rel_diff_threshold`` and ``alpha``.

    Returns:
        ``(token_id, token_count)`` where ``token_count`` is the number of
        candidates that were back-propagated.

    Raises:
        ValueError: If ``args.agg_method`` is neither ``"sum"`` nor ``"max"``.
    """
    token_count = 0
    if len(top_indices) == 1:  # no need to compare if only one token
        return top_indices[0], token_count

    relevance_scores = []
    token_ids = []
    for token_id in top_indices:
        last_token_logits[token_id].backward(retain_graph=True)
        token_count += 1

        if args.explanation_method == "attnLRP_gradcam":
            attn_weights = model.model.layers[-1].self_attn.attn_weights
            # Take index -1 along dim 2 and sum over dim 1
            # (presumably last query position, summed over heads -- TODO confirm layout).
            relevance = (attn_weights.grad * attn_weights)[:, :, -1].float().sum(1)
            relevance = relevance.clamp(min=0)  # keep only positive relevance, as sum could be negative
            attn_weights.grad = None
        else:
            relevance = (input_embeds.grad * input_embeds).float()
            if args.clamp:
                relevance = relevance.clamp(min=0)
            relevance = relevance.sum(-1)

        input_embeds.grad.zero_()  # gradients accumulate, so clear them before the next candidate
        relevance_sum = relevance.sum()

        # Stop scoring once a candidate has no meaningfully positive total relevance.
        # This is an explicit check rather than an ``assert`` so the guard is not
        # stripped when Python runs with ``-O``; the zero is created on the same
        # device/dtype as the sum to avoid cross-device comparisons.
        has_positive_relevance = bool(relevance_sum > 0) and not bool(
            torch.isclose(relevance_sum, torch.zeros_like(relevance_sum))
        )
        if not has_positive_relevance:
            break

        context_relevance = relevance[0, relevance_mask == 1]
        if args.agg_method == "sum":
            context_relevance = context_relevance.sum()
        elif args.agg_method == "max":
            context_relevance = context_relevance.max()
        else:
            raise ValueError(f"Unsupported aggregation method: {args.agg_method}")

        relevance_scores.append(context_relevance.cpu().item())
        token_ids.append(token_id)

    if not relevance_scores:
        return top_indices[0], token_count  # if no relevance scores were computed

    if (
        args.rel_diff_threshold > 0.0
        and torch.tensor(relevance_scores).max() - relevance_scores[0] <= args.rel_diff_threshold
    ):
        # The best relevance score is not significantly better than the first token's.
        return top_indices[0], token_count

    relevance_probs = torch.nn.functional.softmax(torch.tensor(relevance_scores), dim=0)
    # Only the candidates that were actually scored take part in the mix.
    scores = args.alpha * top_probs[: len(relevance_probs)].cpu() + (1 - args.alpha) * relevance_probs
    return token_ids[torch.argmax(scores).item()], token_count


def generate_relevance_response(
    tokenizer, stop_token_id, model, input_ids, args, relevance_mask=None, attn_indexes=None
):
    """Decode a response token by token, re-ranking candidates by relevance.

    At each step the model is run on the embeddings of the current sequence,
    a candidate set is built from the next-token distribution according to
    ``args.filtering_strategy``, and one candidate is chosen by
    ``select_best_token`` (mask only) or ``select_best_token_heads``
    (attention heads, with or without a mask). When ``args.entropy > 0`` and
    the distribution's entropy is below it, the arg-max token is taken
    directly instead.

    Args:
        tokenizer: Tokenizer used to decode the generated ids.
        stop_token_id: Token id that ends generation once produced.
        model: Model providing ``get_input_embeddings()`` and a forward pass
            that accepts ``inputs_embeds``.
        input_ids: Prompt token ids; indexed as ``[batch, position]``.
        args: Configuration; reads ``response_len``, ``cut_off_len``,
            ``repetition_penalty``, ``repetition_penalty_ignore_prompt``,
            ``entropy``, ``filtering_strategy``, ``n_logits`` and ``p``.
        relevance_mask: Optional mask over prompt positions; extended by a
            zero for every generated token.
        attn_indexes: Optional flat attention-head indexes.

    Returns:
        dict with ``response``, ``length_response`` (tokens appended after the
        prompt), ``length_full`` (total sequence length) and
        ``avg_tokens_considered`` (candidates back-propagated per decoding
        step of this loop, rounded to 2 decimals; 0 when none were counted).

    Raises:
        ValueError: For an unsupported ``args.filtering_strategy`` or when
            neither ``relevance_mask`` nor ``attn_indexes`` is provided.
    """
    input_len = input_ids.shape[1]
    token_count_stats = 0
    steps_done = 0  # decoding steps completed by this loop (one appended token each)

    for token_pos in range(args.response_len):
        if token_pos == args.response_len - 1:
            print("Achieved max response length")
        if token_pos > 0 and token_pos % 300 == 0:
            print(f"Processing token position {token_pos}")
        if args.cut_off_len is not None and token_pos >= args.cut_off_len:
            print(f"Cut off length reached at token position {token_pos}")
            with torch.no_grad():
                # NOTE(review): the other call sites pass seven positional arguments to
                # generate_api (and a tokenized batch rather than a raw id tensor);
                # confirm this four-argument call matches the helper's signature.
                relevance_output = generate_api(tokenizer, model, input_ids, input_len + args.response_len)
            input_ids = relevance_output.sequences
            break

        input_embeds = model.get_input_embeddings()(input_ids)
        output_logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False, do_sample=False).logits
        last_token_logits = output_logits[:, -1]
        if args.repetition_penalty is not None and args.repetition_penalty != 1.0:
            last_token_logits = apply_repetition_penalty(
                last_token_logits,
                input_ids,
                args.repetition_penalty,
                input_len if args.repetition_penalty_ignore_prompt else None,
            )
        probs = torch.nn.functional.softmax(last_token_logits[0], dim=-1)
        if relevance_mask is not None and token_pos > 0:
            # Grow the mask by one zero so it keeps the same length as the sequence;
            # generated tokens are never marked as context.
            relevance_mask = torch.cat([relevance_mask, torch.tensor([0], dtype=torch.int8)])

        if args.entropy > 0.0 and calculate_entropy(probs) < args.entropy:
            # Low-entropy step: take the most probable token without re-ranking.
            best_token_id = torch.argmax(probs)

        else:
            if args.filtering_strategy in ("min_p", "min_p_scaled"):
                top_probs, top_indices = torch.topk(probs, args.n_logits)
                keep_mask = top_probs >= (args.p if args.filtering_strategy == "min_p" else args.p * top_probs.max())
                keep_mask[0] = True  # always keep the top token
                top_probs, top_indices = top_probs[keep_mask], top_indices[keep_mask]
            elif args.filtering_strategy == "top_p":
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = sorted_probs.cumsum(dim=-1)
                # NOTE(review): the cumulative mass is compared with (1 - p), not p -- confirm intended.
                keep_mask = cumulative_probs <= (1 - args.p)
                keep_mask[0] = True  # always keep the top token
                filtered_indices = sorted_indices[keep_mask]
                filtered_probs = probs[filtered_indices]
                top_probs, top_indices = filtered_probs[: args.n_logits], filtered_indices[: args.n_logits]
            else:
                # Report the attribute that was actually dispatched on.
                raise ValueError(f"Unsupported filtering strategy: {args.filtering_strategy}")
            # Scale the selected probabilities to sum to 1
            top_probs = top_probs / top_probs.sum()

            if relevance_mask is not None and attn_indexes is not None:
                best_token_id = select_best_token_heads(
                    model,
                    input_embeds,
                    output_logits[0, -1],
                    top_probs,
                    top_indices,
                    attn_indexes,
                    args,
                    relevance_mask=relevance_mask,
                )
            elif relevance_mask is not None:
                best_token_id, token_count = select_best_token(
                    model,
                    input_embeds,
                    output_logits[0, -1],
                    top_probs,
                    top_indices,
                    relevance_mask,
                    args,
                )
                token_count_stats += token_count
            elif attn_indexes is not None:
                best_token_id = select_best_token_heads(
                    model, input_embeds, output_logits[0, -1], top_probs, top_indices, attn_indexes, args
                )
            else:
                raise ValueError("Relevance mask or attention indexes must be provided for relevance generation.")

        new_token = einops.rearrange(best_token_id, "-> 1 1")
        input_ids = torch.cat([input_ids, new_token], dim=1)
        steps_done += 1
        if new_token == stop_token_id:
            print(f"Reached end of sequence at token position {token_pos}")
            break

    relevance_response = tokenizer.decode(input_ids[0][input_len:], skip_special_tokens=True)
    # Lengths are derived from the final sequence instead of the loop variable so they stay
    # correct after the cut-off branch and when the loop body never runs (response_len == 0).
    return {
        "response": relevance_response,
        "length_response": input_ids.shape[1] - input_len,
        "length_full": input_ids.shape[1],
        "avg_tokens_considered": round(token_count_stats / steps_done, 2) if token_count_stats > 0 else 0,
    }


# === Base and Derived Classes ===
class BaseDatasetProcessor:
    """Common driver for running the configured generation methods on one dataset.

    Subclasses supply the dataset-specific pieces: ``prepare_chats``,
    ``get_relevance_mask`` and ``write_results``.
    """

    def __init__(self, args, tokenizer, model, model_family, device):
        """Store the run configuration and resolve the stop token id.

        The stop token is the tokenizer's EOS id, except for the ``"gemma3"``
        family where the id of ``"<end_of_turn>"`` is used.
        """
        self.args = args
        self.tokenizer = tokenizer
        self.stop_token_id = (
            tokenizer.eos_token_id if model_family != "gemma3" else tokenizer.convert_tokens_to_ids("<end_of_turn>")
        )
        self.model = model
        self.model_family = model_family
        self.device = device

    def load_dataset(self, dataset_path, results_dir):
        """Read ``results_dir/dataset_path`` as JSON Lines and return the parsed records."""
        with open(os.path.join(results_dir, dataset_path), "r") as f:
            dataset = [json.loads(line) for line in f]
        return dataset

    def prepare_chats(self, example):
        """
        Returns a list of tuples: (chat, extra_info) where extra_info is used in get_relevance_mask.
        """
        raise NotImplementedError

    def get_relevance_mask(self, input_tokenized, extra_info):
        """
        Computes the relevance mask from the tokenized prompt.
        """
        raise NotImplementedError

    def _build_relevance_mask(self, prompt, extra_info):
        """Tokenize ``prompt``, compute its relevance mask and validate it.

        Raises:
            AssertionError: If the mask is ``None`` or has no truthy entry.
        """
        input_tokenized = self.tokenizer.tokenize(prompt, add_special_tokens=False)
        relevance_mask = self.get_relevance_mask(input_tokenized, extra_info)
        # Check for None first: any(None) would raise TypeError and hide the real problem.
        # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
        if relevance_mask is None or not any(relevance_mask):
            print(f"Failed instruction: {prompt}")
            raise AssertionError("Constraint not found in tokenized prompt")
        return relevance_mask

    def process_chat_request(self, chat, extra_info, cad=None):
        """Run every method in ``args.generation_methods`` on one chat.

        Args:
            chat: Chat messages passed to the tokenizer's chat template.
            extra_info: Dataset-specific data forwarded to ``get_relevance_mask``;
                for ``llama3`` with ``date_block == "keep_no_relevance"`` the
                tokenized date block is added under ``"tokenized_date_block"``.
            cad: Optional ``CAD`` instance reused by the ``"context_aware"`` method.

        Returns:
            dict mapping each method name to the dict returned by its generator.

        Raises:
            AssertionError: If a relevance mask is required but empty or missing.
            ValueError: If a generation method is not supported.
        """
        prompt = self.tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, tokenize=False, return_tensors="pt"
        )
        if self.model_family == "llama3":
            if self.args.date_block == "remove":
                prompt = remove_date_block(prompt)
            elif self.args.date_block == "keep_no_relevance":
                extra_info["tokenized_date_block"] = self.tokenizer.tokenize(
                    re.search(DATE_BLOCK_PATTERN, prompt).group(0)
                )

        input_ids = self.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(self.device)

        responses_dict = {}
        # Computed on first use and shared, so "context_aware" no longer depends on
        # "relevance" having run earlier in the method list.
        relevance_mask = None
        for method in self.args.generation_methods:
            if method == "relevance":
                uses_mask = self.args.task_type != "heads"
                if uses_mask and relevance_mask is None:
                    relevance_mask = self._build_relevance_mask(prompt, extra_info)
                responses_dict[method] = generate_relevance_response(
                    self.tokenizer,
                    self.stop_token_id,
                    self.model,
                    input_ids.input_ids,
                    self.args,
                    relevance_mask=relevance_mask if uses_mask else None,
                    attn_indexes=self.attn_indexes if self.args.task_type != "context" else None,
                )
            elif method == "greedy":
                responses_dict[method] = generate_greedy_response(
                    self.tokenizer, self.model, self.model_family, input_ids, self.args
                )
            elif method == "random":
                responses_dict[method] = generate_nucleus_response(
                    self.tokenizer, self.model, self.model_family, input_ids, self.args
                )
            elif method == "context_aware":
                if relevance_mask is None:
                    relevance_mask = self._build_relevance_mask(prompt, extra_info)
                responses_dict[method] = generate_contex_aware_response(input_ids, relevance_mask, self.args, cad)
            else:
                raise ValueError(f"Unsupported generation method: {method}")
        return responses_dict

    def write_results(self, file_handles, example, responses):
        """
        Write the results for a given example. This method should be overridden by subclasses
        to format the output accordingly.
        """
        raise NotImplementedError


class HotPotQAProcessor(BaseDatasetProcessor):
    """Processor for HotPotQA-style records with ``question``, ``context`` and ``answer`` keys."""

    def load_dataset(self, dataset_path, results_dir):
        """Load attention-head indexes (unless the task is ``"context"``) and the JSONL dataset."""
        if self.args.task_type != "context":
            # The first line of the heads file is a comma-separated list of integers.
            with open(self.args.heads_path, "r") as heads_file:
                first_line = heads_file.readline()
            self.attn_indexes = [int(piece.strip()) for piece in first_line.split(",")]

        records = []
        with open(os.path.join(results_dir, dataset_path), "r") as data_file:
            for raw_line in data_file:
                records.append(json.loads(raw_line))
        return records

    def get_relevance_mask(self, input_tokenized, extra_info):
        """Mask the prompt tokens matching the preprocessed context; ``None`` for parametric heads."""
        if self.args.heads_type == "parametric":
            return None
        context_tokens = self.tokenizer.tokenize(preprocess_ctx(extra_info["context"]))
        return get_mask_by_text(input_tokenized, context_tokens)

    def prepare_chats(self, example):
        """Build one system+user chat; the context is prepended unless heads are parametric."""
        parametric = self.args.heads_type == "parametric"
        if parametric:
            user_content = example["question"]
        else:
            user_content = f"{preprocess_ctx(example['context'])} \n\nBased on this text, answer this question:\nQ: {example['question']}\nA:"

        chat = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_content},
        ]
        extra_info = {} if parametric else {"context": example["context"]}
        return [chat], extra_info

    def write_results(self, file_handles, example, responses):
        """Append one JSON line per generation method to that method's file handle, if any."""
        first_chat = responses[0]
        for method in self.args.generation_methods:
            result = first_chat[method]
            record = {
                "prompt": example["question"],
                "context": example.get("context", ""),
                "response": result["response"],
                "answer": example["answer"],
                "length_response": result["length_response"],
                "length_full": result["length_full"],
                "avg_tokens_considered": result.get("avg_tokens_considered", None),
            }
            sink = file_handles.get(method)
            if not sink:
                continue
            sink.write(json.dumps(record) + "\n")


class TRIVIAQAProcessor(BaseDatasetProcessor):
    """Processor for TriviaQA-style records whose question/answers live under ``qas[0]``."""

    def load_dataset(self, dataset_path, results_dir):
        """Load attention-head indexes (unless the task is ``"context"``) and the JSONL dataset."""
        if self.args.task_type != "context":
            # The first line of the heads file is a comma-separated list of integers.
            with open(self.args.heads_path, "r") as heads_file:
                first_line = heads_file.readline()
            self.attn_indexes = [int(piece.strip()) for piece in first_line.split(",")]

        records = []
        with open(os.path.join(results_dir, dataset_path), "r") as data_file:
            for raw_line in data_file:
                records.append(json.loads(raw_line))
        return records

    def prepare_chats(self, example):
        """Build one system+user chat; the context is prepended unless heads are parametric."""
        parametric = self.args.heads_type == "parametric"
        if parametric:
            user_content = example["qas"][0]["question"]
        else:
            user_content = f"{preprocess_ctx(example['context'])} \n\nBased on this text, answer this question:\nQ: {example['qas'][0]['question']}\nA:"

        chat = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_content},
        ]
        extra_info = {} if parametric else {"context": example["context"]}
        return [chat], extra_info

    def get_relevance_mask(self, input_tokenized, extra_info):
        """Mask the prompt tokens matching the preprocessed context; ``None`` for parametric heads."""
        if self.args.heads_type == "parametric":
            return None
        context_tokens = self.tokenizer.tokenize(preprocess_ctx(extra_info["context"]))
        return get_mask_by_text(input_tokenized, context_tokens)

    def write_results(self, file_handles, example, responses):
        """Append one JSON line per generation method to that method's file handle, if any."""
        first_chat = responses[0]
        for method in self.args.generation_methods:
            result = first_chat[method]
            record = {
                "prompt": example["qas"][0]["question"],
                "context": example.get("context", ""),
                "response": result["response"],
                "answer": example["qas"][0]["answers"],
                "length_response": result["length_response"],
                "length_full": result["length_full"],
                "avg_tokens_considered": result.get("avg_tokens_considered", None),
            }
            sink = file_handles.get(method)
            if not sink:
                continue
            sink.write(json.dumps(record) + "\n")


class NQProcessor(BaseDatasetProcessor):
    """Processor for NQ-style records with ``question``, ``context`` and ``answer`` keys."""

    def load_dataset(self, dataset_path, results_dir):
        """Load attention-head indexes (unless the task is ``"context"``) and the JSONL dataset."""
        if self.args.task_type != "context":
            # The first line of the heads file is a comma-separated list of integers.
            with open(self.args.heads_path, "r") as heads_file:
                first_line = heads_file.readline()
            self.attn_indexes = [int(piece.strip()) for piece in first_line.split(",")]

        records = []
        with open(os.path.join(results_dir, dataset_path), "r") as data_file:
            for raw_line in data_file:
                records.append(json.loads(raw_line))
        return records

    def get_relevance_mask(self, input_tokenized, extra_info):
        """Mask the prompt tokens matching the preprocessed context; ``None`` for parametric heads."""
        if self.args.heads_type == "parametric":
            return None
        context_tokens = self.tokenizer.tokenize(preprocess_ctx(extra_info["context"]))
        return get_mask_by_text(input_tokenized, context_tokens)

    def prepare_chats(self, example):
        """Build one system+user chat; the context is prepended unless heads are parametric."""
        parametric = self.args.heads_type == "parametric"
        if parametric:
            user_content = example["question"]
        else:
            user_content = f"{preprocess_ctx(example['context'])} \n\nBased on this text, answer this question:\nQ: {example['question']}\nA:"

        chat = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_content},
        ]
        extra_info = {} if parametric else {"context": example["context"]}
        return [chat], extra_info

    def write_results(self, file_handles, example, responses):
        """Append one JSON line per generation method to that method's file handle, if any."""
        first_chat = responses[0]
        for method in self.args.generation_methods:
            result = first_chat[method]
            record = {
                "question": example["question"],
                "context": example.get("context", ""),
                "response": result["response"],
                "answer": example["answer"],
                "length_response": result["length_response"],
                "length_full": result["length_full"],
                "avg_tokens_considered": result.get("avg_tokens_considered", None),
            }
            sink = file_handles.get(method)
            if not sink:
                continue
            sink.write(json.dumps(record) + "\n")


class IFEVALProcessor(BaseDatasetProcessor):
    """Processor for IFEval-style records with ``prompt`` and ``constraint`` keys."""

    def prepare_chats(self, example):
        """Build a single-turn user chat and pass the constraint along as extra info."""
        # example["constraint"] is forwarded untouched as the constraint token list;
        # presumably it is already tokenized in the dataset -- depends on the input format.
        chat = [{"role": "user", "content": example["prompt"]}]
        extra_info = {"constraint_tokenized": example["constraint"]}
        return [chat], extra_info

    def get_relevance_mask(self, input_tokenized, extra_info):
        """Mask the prompt tokens that match the constraint tokens."""
        return get_mask_by_text(input_tokenized, extra_info["constraint_tokenized"])

    def write_results(self, file_handles, example, responses):
        """Append one JSON line per generation method to that method's file handle, if any."""
        first_chat = responses[0]
        for method in self.args.generation_methods:
            result = first_chat[method]
            record = {
                "prompt": example["prompt"],
                "response": result["response"],
                "length_response": result["length_response"],
                "length_full": result["length_full"],
                "avg_tokens_considered": result.get("avg_tokens_considered", None),
            }
            sink = file_handles.get(method)
            if not sink:
                continue
            sink.write(json.dumps(record) + "\n")


class IFEVALSimpleProcessor(IFEVALProcessor):
    """IFEval variant whose ``constraint`` is raw text that is tokenized here."""

    def prepare_chats(self, example):
        """Build a single-turn user chat and tokenize the constraint text."""
        # A leading " \n" is dropped before tokenizing (flagged in the original as a tokenization issue).
        constraint_text = example["constraint"].removeprefix(" \n")
        chat = [{"role": "user", "content": example["prompt"]}]
        extra_info = {"constraint_tokenized": self.tokenizer.tokenize(constraint_text)}
        return [chat], extra_info

    def get_relevance_mask(self, input_tokenized, extra_info):
        """Mask the whole user turn when ``relevance_role == "full"``, else only the constraint."""
        if self.args.relevance_role == "full":
            return get_mask_by_role(input_tokenized, "user", self.model_family)
        return super().get_relevance_mask(input_tokenized, extra_info)


class IHEVALProcessor(BaseDatasetProcessor):
    """Processor for IHEval-style records with ``system`` and ``prompt`` keys."""

    def load_dataset(self, dataset_path, results_dir):
        """Load attention-head indexes (unless the task is ``"context"``), then the dataset."""
        if self.args.task_type != "context":
            # The first line of the heads file is a comma-separated list of integers.
            with open(self.args.heads_path, "r") as heads_file:
                first_line = heads_file.readline()
            self.attn_indexes = [int(piece.strip()) for piece in first_line.split(",")]
        return super().load_dataset(dataset_path, results_dir)

    def prepare_chats(self, example):
        """Build the chat; datasets whose path contains ``mt_aligned`` add one prior exchange."""
        if "mt_aligned" in self.args.dataset_path:
            history = example["conversation_history"]
            multi_turn = [
                {"role": "system", "content": example["system"]},
                {"role": "user", "content": history[0]},
                {"role": "assistant", "content": history[1]},
                {"role": "user", "content": example["prompt"]},
            ]
            return [multi_turn], {}

        single_turn = [
            {"role": "system", "content": example["system"]},
            {"role": "user", "content": example["prompt"]},
        ]
        # For gemma3 the system text is handed over so the mask can be located by text.
        extra_info = {"system": example["system"]} if self.model_family == "gemma3" else {}
        return [single_turn], extra_info

    def get_relevance_mask(self, input_tokenized, extra_info):
        """Mask the system turn: by role for most families, by matching text for gemma3."""
        if self.model_family == "gemma3":
            return get_mask_by_text(input_tokenized, self.tokenizer.tokenize(extra_info["system"]))
        return get_mask_by_role(input_tokenized, "system", self.model_family)

    def write_results(self, file_handles, example, responses):
        """Append one JSON line per generation method to that method's file handle, if any."""
        first_chat = responses[0]
        for method in self.args.generation_methods:
            result = first_chat[method]
            record = {
                "system": example["system"],
                "prompt": example["prompt"] if self.args.relevance_role != "full" else example["original_prompt"],
                "response": result["response"],
                "length_response": result["length_response"],
                "length_full": result["length_full"],
                "avg_tokens_considered": result.get("avg_tokens_considered", None),
            }
            sink = file_handles.get(method)
            if not sink:
                continue
            sink.write(json.dumps(record) + "\n")


def get_dataset_processor(dataset_path, args, tokenizer, model, model_family, device):
    """Instantiate the processor whose key is a substring of ``dataset_path``.

    Keys are tried in declaration order and the first match wins, so
    ``"ifeval_simple"`` is listed ahead of its prefix ``"ifeval"``.

    Raises:
        ValueError: If no key occurs in ``dataset_path``.
    """
    registry = (
        ("triviaqa", TRIVIAQAProcessor),
        ("nq", NQProcessor),
        ("hotpotqa", HotPotQAProcessor),
        ("ifeval_simple", IFEVALSimpleProcessor),
        ("ifeval", IFEVALProcessor),
        ("iheval", IHEVALProcessor),
    )
    chosen = next((cls for key, cls in registry if key in dataset_path), None)
    if chosen is None:
        raise ValueError("Dataset not implemented")
    return chosen(args, tokenizer, model, model_family, device)
