import re

import torch
import numpy as np
from transformers import GenerationConfig
from saycanpay.lm_planner_base import CustomScorer
from saycanpay.data_utils import process_action

from embodied_cd.common.agent import BaseAgent
from embodied_cd.common.mixin import FewShotMixIn


class SayCanPayAgent(BaseAgent, FewShotMixIn):
    """Plan-then-execute agent using SayCanPay-style beam search.

    On the first ``get_action`` call after ``reset`` the agent searches for a
    complete plan:

    1. The LLM proposes candidate next steps with diverse (group) beam search.
    2. The candidates are mapped onto the environment's admissible actions by
       saycanpay's ``process_action``.
    3. ``CustomScorer.filter_beams`` (from ``saycanpay``) prunes and scores the
       partial plans.

    The best finished plan is cached in ``self.plan`` and replayed one action
    per call.

    Call ``load_few_shot_prompt()`` and ``load_can_pay_model()`` before
    ``reset()``.
    """

    name = "saycanpay"

    # Number of partial plans kept alive at the start of a search. It is also
    # the minimum number of candidate steps per beam; shorter candidate lists
    # are padded (see _select_unique_steps).
    _NUM_FINAL_BEAMS = 3
    # Captures the text of each "[Step N]" entry up to the next step marker.
    _STEP_PATTERN = re.compile(r"\[Step \d+\](.*?)(?=\[Step \d+\]|$)", re.DOTALL)
    # Separates the few-shot examples from the current task in the prompt
    # assembled by reset().
    _TASK_MARKER = "Here is the task."

    def __init__(self, env_name, llm, max_steps=8):
        super().__init__()

        self.env_name = env_name
        self.env_prompt = self._load_env_prompt(env_name)

        self.model = self._prepare_model(llm)
        self.few_shot_prompt = None

        self.max_steps = max_steps
        self.num_final_beams = self._NUM_FINAL_BEAMS
        self.num_return_sequences = 6
        self.custom_scorer = None
        # Group beam search: 6 beams split into 3 groups with a diversity
        # penalty. All 6 sequences are returned together with their sequence
        # scores so candidates can be re-ranked by the scorer.
        self.beam_generation_config = GenerationConfig(
            max_new_tokens=10,
            num_beams=6,
            num_beam_groups=3,
            num_return_sequences=self.num_return_sequences,
            diversity_penalty=2.0,
            pad_token_id=self.model.tokenizer.pad_token_id,
            output_scores=True,
            return_dict_in_generate=True,
        )

        self.plan = None
        self.goal = None
        self.admissible_actions = self._load_admissible_actions(env_name)

    def get_action(self, obs):
        """Return the next action to execute.

        The first call after ``reset`` runs the full beam search and caches
        the best plan. Later calls pop actions from that plan and return
        ``"done task"`` once it is exhausted.

        Args:
            obs: Observation text, inserted into the prompt as the initial
                state.

        Returns:
            The next action string. Returns ``"done task"`` once the plan is
            used up, and ``None`` if no beam finished within ``max_steps``
            search steps.
        """
        if self.plan:
            return self.plan.pop(0)
        if self.plan == []:
            return "done task"

        # Restore the beam budget *before* the first dedup: _select_unique_steps
        # pads candidate lists up to num_final_beams.
        self.num_final_beams = self._NUM_FINAL_BEAMS

        prompt = f"{self.prompt}\n[Goal] {self.goal} [Initial State] {obs} [Step 1] "
        context = f"[Goal] {self.goal} [Initial State] {obs}"

        first_steps, first_scores = self._beam_generate(prompt)
        first_steps = [
            process_action(seq, self.admissible_actions) for seq in first_steps
        ]
        first_steps, first_scores = self._select_unique_steps(first_steps, first_scores)
        # filter_beams takes one row of candidates per live beam. At step 1
        # there is only the single root beam.
        candidate_steps = np.reshape(first_steps, (1, len(first_steps))).tolist()
        candidate_scores = first_scores.view(1, len(first_steps))

        all_beams = []
        beam_prompts = [prompt]
        beam_scores = [0]
        contexts_can = [context]
        contexts_pro = [context]

        action_steps = 0
        while True:
            action_steps += 1
            (
                beam_scores,
                beam_prompts,
                done,
                contexts_can,
                contexts_pro,
                _,
            ) = self.custom_scorer.filter_beams(
                candidate_steps,
                candidate_scores,
                beam_prompts,
                beam_scores,
                contexts_can,
                contexts_pro,
                self.goal,
                action_steps,
                task_completed_desc="done task",
                num_final_beams=self.num_final_beams,
            )

            # Move finished beams into the result pool with a length-normalised
            # score, and shrink the number of beams still being searched.
            for ind, done_flag in enumerate(done):
                if done_flag:
                    gen_plan = self._postprocess_plans_from_beams(beam_prompts[ind])
                    all_beams.append((gen_plan, beam_scores[ind] / action_steps))
                    self.num_final_beams -= 1

            # Keep only unfinished beams for the next expansion. If every beam
            # finished there is nothing to unzip and the previous values stay.
            unfinished = [
                row
                for row in zip(
                    beam_prompts, beam_scores, done, contexts_can, contexts_pro
                )
                if not row[2]
            ]
            if unfinished:
                (
                    beam_prompts,
                    beam_scores,
                    done,
                    contexts_can,
                    contexts_pro,
                ) = zip(*unfinished)

            if self.num_final_beams == 0:
                break

            if action_steps > self.max_steps:
                if all_beams:
                    return self._adopt_best_plan(all_beams)
                return None

            candidate_steps, candidate_scores = self._expand_beams(beam_prompts)

        return self._adopt_best_plan(all_beams)

    def _expand_beams(self, beam_prompts):
        """Propose deduplicated next-step candidates for every live beam.

        Returns:
            ``(candidate_steps, candidate_scores)``: one list of steps and one
            1-D score tensor per beam, in the order of ``beam_prompts``.
        """
        sequences, scores = self._beam_generate(beam_prompts)
        sequences = [process_action(seq, self.admissible_actions) for seq in sequences]
        # The reshape assumes generate() returns num_return_sequences outputs
        # per prompt, grouped prompt by prompt.
        rows = np.reshape(
            sequences, (self.num_final_beams, self.num_return_sequences)
        ).tolist()
        score_rows = scores.view(self.num_final_beams, self.num_return_sequences)

        candidate_steps = []
        candidate_scores = []
        for row_steps, row_scores in zip(rows, score_rows):
            unique_steps, unique_scores = self._select_unique_steps(
                row_steps, row_scores
            )
            candidate_steps.append(unique_steps)
            candidate_scores.append(unique_scores)
        return candidate_steps, candidate_scores

    def _adopt_best_plan(self, all_beams):
        """Cache the highest-scoring finished plan and return its first action.

        Args:
            all_beams: ``(plan_steps, score)`` pairs. On ties, the earliest
                entry wins.

        Returns:
            The first step of the chosen plan. Returns ``"done task"`` if that
            plan is empty, instead of raising ``IndexError``.
        """
        best_plan, _ = max(all_beams, key=lambda beam: beam[1])
        self.plan = best_plan
        if not self.plan:
            return "done task"
        return self.plan.pop(0)

    def reset(self, task, goal):
        """Start a new episode for ``goal``.

        Args:
            task: Not used by this agent.
            goal: Goal text. It is embedded in the prompt and passed to the
                scorer.
        """
        assert self.few_shot_prompt, "Make sure to call load_few_shot_prompt() first."
        assert self.custom_scorer, "Make sure to call load_can_pay_model() first."

        # The goal must be set before the prompt is built, since the prompt
        # embeds it.
        self.goal = goal
        self.plan = None
        self.num_final_beams = self._NUM_FINAL_BEAMS
        self.prompt = (
            "Interact with a household to solve a task. "
            + self.env_prompt
            + "\n\nHere are two examples.\n\n"
            + self.few_shot_prompt
            + "\n\nHere is the task.\n\n"
            + f"Your task is to: {self.goal}.\n"
        )

    def _beam_generate(self, prompt):
        """Run beam search and return the step text generated after each prompt.

        Args:
            prompt: A single prompt string, or a sequence of prompt strings.

        Returns:
            ``(generated_sequences, beam_scores)``:

            - ``generated_sequences``: one string per returned sequence. Each
              is the decoded text after its own prompt, cut at the first
              ``"["`` or newline.
            - ``beam_scores``: the ``sequences_scores`` from ``generate``.
        """
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)

        device = self.model.device
        torch_dtype = self.model.torch_dtype

        # NOTE(review): for bf16 models the weights are upcast to float32 in
        # place (on every call) while the attention mask below is cast to
        # bf16. This is kept as-is; confirm it is intentional.
        if torch_dtype is torch.bfloat16:
            self.model.model.float()

        inputs = self.model.tokenizer(
            prompt if isinstance(prompt, str) else prompts,
            return_tensors="pt",
            padding=True,
        )
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device).to(torch_dtype)

        beam_search_output = self.model.model.generate(
            input_ids,
            attention_mask=attention_mask,
            generation_config=self.beam_generation_config,
        )
        decoded = self.model.tokenizer.batch_decode(
            beam_search_output.sequences,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        # Each prompt yields num_return_sequences outputs. Strip every output by
        # the length of *its own* prompt: batched beam prompts differ in length.
        per_prompt = self.beam_generation_config.num_return_sequences or 1
        last_prompt = len(prompts) - 1
        generated_sequences = [
            seq[len(prompts[min(idx // per_prompt, last_prompt)]) :]
            .strip()
            .split("[")[0]
            .split("\n")[0]
            .strip()
            for idx, seq in enumerate(decoded)
        ]
        return generated_sequences, beam_search_output.sequences_scores

    def _select_unique_steps(self, generated_sequences, beam_scores):
        """Deduplicate candidate steps, keeping the best score for each step.

        Steps keep their first-occurrence order. If at least one but fewer
        than ``num_final_beams`` distinct steps remain, the last step (and its
        score) is repeated until there are ``num_final_beams`` entries.

        Returns:
            ``(steps, scores)``, where ``scores`` is a 1-D tensor on
            ``beam_scores.device``.
        """
        gen_score_dict = {}
        for gen_seq, score in zip(generated_sequences, beam_scores):
            if gen_seq in gen_score_dict:
                gen_score_dict[gen_seq] = max(gen_score_dict[gen_seq], score)
            else:
                gen_score_dict[gen_seq] = score
        generated_sequences_processed = list(gen_score_dict.keys())
        beam_scores_processed = torch.tensor(
            list(gen_score_dict.values()), device=beam_scores.device
        ).view(-1)
        while 0 < len(generated_sequences_processed) < self.num_final_beams:
            generated_sequences_processed.append(generated_sequences_processed[-1])
            beam_scores_processed = torch.cat(
                (beam_scores_processed, beam_scores_processed[-1].view(-1))
            )

        return generated_sequences_processed, beam_scores_processed

    def _postprocess_plans_from_beams(self, beam_prompt):
        """Extract the ordered step texts from a finished beam prompt.

        Only text after the last task marker is parsed. The few-shot examples
        precede the marker in the prompt built by ``reset``, so their
        ``[Step N]`` entries are excluded. Empty steps are dropped.
        """
        task_section = beam_prompt.split(self._TASK_MARKER)[-1].strip(" \n")
        steps = self._STEP_PATTERN.findall(task_section)
        return [step.strip() for step in steps if step.strip()]

    def _load_admissible_actions(self, env_name):
        """Read one admissible action per line for env_name, plus "done task"."""
        admissible_action_path = (
            f"externals/saycanpay/{env_name}_admissible_actions.txt"
        )
        with open(admissible_action_path, "r", encoding="utf-8") as f:
            admissible_actions = f.read().splitlines()
        admissible_actions.append("done task")
        return admissible_actions

    def forward(
        self,
        instruction: str,
        state: str,
        history: str,
        few_shot_examples=None,
    ):
        """Agent entry point.

        Only ``state`` is used; it is passed to ``get_action``.
        """
        return self.get_action(state)

    def load_can_pay_model(self, checkpoint_dir, decoding_score):
        """Create the saycanpay ``CustomScorer`` used to filter beams."""
        self.custom_scorer = CustomScorer(checkpoint_dir, decoding_score)
