import os
import pathlib
import sys

# The project root has to be on sys.path before any `Craftax.*` import below,
# so this setup must stay ahead of them.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '.', '.'))
sys.path.append(project_root)

from Craftax.craftax.lm import LanguageModel
from Craftax.craftax.craftax_classic.constants import *
from Craftax.craftax.craftax_classic.model import CREATE_SBERT_MODEL
import numpy as np
import jax.numpy as jnp
import random
from Craftax.craftax.craftax_classic.encoder import SbertEncoder
from Craftax.craftax.craftax_classic.best_goal import AdvancedGoalGenerator
from Craftax.craftax.craftax_classic.nopri_goal import BestGoal
import torch
from Craftax.craftax.craftax_classic.cache.cache import Cache

# Torch device shared by the encoders in this module. Prefer the GPU, but fall
# back to the CPU so the module remains usable on hosts without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Goal strings known to this module, in their canonical order.
FEASIBLE_GOALS = [
    # Single-step interaction goals.
    "eat plant",
    "attack zombie",
    "attack skeleton",
    "attack cow",
    "chop tree",
    "mine stone",
    "mine coal",
    "mine iron",
    "mine diamond",
    "drink water",
    "chop grass",
    "sleep",
    # "place ..." goals.
    "place stone",
    "place crafting table",
    "place furnace",
    "place plant",
    # "make ..." goals.
    "make wood pickaxe",
    "make stone pickaxe",
    "make iron pickaxe",
    "make wood sword",
    "make stone sword",
    "make iron sword",
    # Remaining goal names (presumably composite / higher-level goals).
    "plant row",
    "chop grass with wood pickaxe",
    "vegetarianism",
    "make workshop",
    "survival",
    "deforestation",
    "work and sleep",
    "gardening",
    "wilderness survival",
]

# Action names; the position of a name in this list is the index it occupies
# in the action-mask vectors built from it.
ACTION = [
    'noop',
    'left',
    'right',
    'up',
    'down',
    'do',
    'sleep',
    'place_stone',
    'place_table',
    'place_furnace',
    'place_plant',
    'make_wood_pickaxe',
    'make_stone_pickaxe',
    'make_iron_pickaxe',
    'make_wood_sword',
    'make_stone_sword',
    'make_iron_sword',
]


class PixelsTextGoal:
    """Base class holding the shared goal-generation state and text rendering.

    It creates the goal generator, the SBERT tokenizer and encoder, and the
    goal/mask cache, and provides a method that turns one environment's
    observation arrays into a textual description.
    """

    # Lookup tables: the position of a name is the id it stands for.
    _MOB_NAMES = ("zombie", "cows", "skeletons", "arrows")
    _BLOCK_NAMES = (
        "invalid", "out of bounds", "grass", "water", "stone", "tree", "wood",
        "path", "coal", "iron", "diamond", "crafting table", "furnace", "sand",
        "lava", "plant", "ripe plant",
    )
    _INVENTORY_NAMES = (
        "wood", "stone", "coal", "iron", "diamond", "sapling",
        "wood pickaxe", "stone pickaxe", "iron pickaxe",
        "wood sword", "stone sword", "iron sword",
    )
    _STATUS_NAMES = (
        "Health", "Fullness", "Hydration", "Wakefulness", "Sky brightness level",
    )

    def __init__(self, num_envs, goal_type):
        """Set up tokenizer, encoder, cache and bookkeeping attributes.

        Args:
            num_envs: Forwarded to the ``Cache`` constructor.
            goal_type: Name of the goal-generation variant; stored and printed.
        """
        self.goal_type = goal_type
        print(f"Goal: {self.goal_type}")
        self.best_goal = AdvancedGoalGenerator()
        self.state = None
        self.reset = True
        self.tokenizer = CREATE_SBERT_MODEL().get_sbert_tokenizer()
        self.embedding_cache = {}
        self.lang_goal_encoder = SbertEncoder(512, device)
        self.lang_goal_encoder.to(device)
        self.previous_text_state = []
        self.previous_goal = []
        self.previous_long_goals = []
        # Indirection so callers/subclasses use one name for the renderer.
        self.text_describ = self.render_craftax_text_describ
        self.goal_pool = []
        self.get_goal_times = 0
        # The cache file lives in a `cache/` directory next to this source file.
        module_dir = pathlib.Path(os.path.dirname(os.path.realpath(__file__)))
        self.cache_path = module_dir / 'cache/goal_mask.jsonl'
        print(self.cache_path)
        self.cache = Cache(self.cache_path, num_envs)

    def render_craftax_text_describ(self, view_arr, index):
        """Describe one environment's observation as text.

        Args:
            view_arr: Tuple ``(map_view, mob_map, inventory_values,
                status_values)`` of batched arrays; each is indexed with
                ``index`` along its first axis. ``map_view`` holds integer
                block ids and the last axis of ``mob_map`` holds one score per
                mob type.
            index: Which batch entry to describe.

        Returns:
            A three-line string: the distinct block and mob names in view
            (sorted alphabetically so the text is reproducible), the inventory
            items with a positive count, and the five status values.
        """
        map_view, mob_map, inventory_values, status_values = view_arr

        # Distinct block names present anywhere in the local map.
        block_ids = np.asarray(map_view[index])
        visible = set(np.asarray(self._BLOCK_NAMES)[block_ids].ravel().tolist())

        # A cell contains a mob when its strongest mob channel exceeds 0.5;
        # the mob type is the index of that strongest channel.
        mob_scores = np.asarray(mob_map[index])
        mob_present = mob_scores.max(axis=-1) > 0.5
        mob_names = np.asarray(self._MOB_NAMES)[mob_scores.argmax(axis=-1)]
        visible.update(mob_names[mob_present].ravel().tolist())

        # Sorting removes the run-to-run variation of set iteration order.
        text_view = ", ".join(sorted(visible))

        text_obs_inv = ", ".join(
            f"{name}: {value}"
            for name, value in zip(self._INVENTORY_NAMES, inventory_values[index])
            if value > 0
        )

        status = ", ".join(
            f"{name}: {int(value / 0.09)}%"
            for name, value in zip(self._STATUS_NAMES, status_values[index])
        )

        return (
            "You see: " + text_view
            + "\nInventory: " + text_obs_inv
            + "\nStatus: " + status
        )


class GoalandMask(PixelsTextGoal):
    """Goal provider that returns goal embeddings together with action masks.

    For every environment a goal (or list of goals) is generated from the text
    description of its observation. Each goal is mapped to a binary mask over
    ``ACTION`` by asking the language model which actions are related to it;
    masks are cached per goal string.
    """

    # System-prompt location handed to the language model's prompt formatter.
    # NOTE(review): this is a machine-specific absolute path; override this
    # attribute when running elsewhere.
    PRUNER_PROMPT_PATH = "/home/rd-111s/QYJ/code/Craftax_text/SGRL/Craftax/craftax/prompt/prompt_for_pruner.txt"

    def __init__(self, num_envs, goal_type, lm_name, env_name, alg_name):
        """Create the language model and the two goal generators.

        Args:
            num_envs: Forwarded to the base class.
            goal_type: Goal-generation variant; ``"SGRL_nopri"`` selects the
                single-goal generator, anything else the multi-goal one.
            lm_name, env_name, alg_name: Forwarded to ``LanguageModel``.
        """
        super().__init__(num_envs, goal_type)
        self.lm = LanguageModel(lm_name=lm_name, env_name=env_name, alg_name=alg_name)
        self.get_goal = self.best_goal.get_augmentation_more_goal_text
        self.goal_nopri = BestGoal()
        self.get_goal_nopri = self.goal_nopri
        self.count = 0
        self.count_mask = 0

    def _lookup_mask(self, text_state, goal):
        """Return the action mask for ``goal``, querying the LM on a cache miss."""
        if self.cache.check_in_cache(goal):
            return self.cache.retrieve_from_cache(goal)
        mask = self.get_mask(text_state, goal)
        self.cache.store_in_cache(goal, mask)
        return mask

    def get_pixels_goals(self, view_arr, shape):
        """Compute goal embeddings and action masks for a batch of environments.

        Args:
            view_arr: Observation tuple accepted by ``self.text_describ``.
            shape: Number of environments to process (indices ``0..shape-1``).

        Returns:
            ``(embeddings, masks)`` as JAX arrays: the encoder output for each
            environment's goal text, and one mask of length ``len(ACTION)``
            per environment.
        """
        text_states = [self.text_describ(view_arr, i) for i in range(shape)]
        goals = []
        masks = []
        for text_state in text_states:
            if self.goal_type == "SGRL_nopri":
                goal = self.get_goal_nopri.determine_goal(text_state)
                mask = np.asarray(self._lookup_mask(text_state, goal), dtype=np.float32)
            else:
                goal, _ = self.get_goal(text_state)
                # A fresh accumulator per environment: sharing one array would
                # leak actions between environments and alias every entry of
                # `masks` to the same object.
                mask = np.zeros(len(ACTION), dtype=np.float32)
                for g in goal:
                    mask += np.asarray(self._lookup_mask(text_state, g['goal']), dtype=np.float32)
                # Union of the per-goal masks: any action allowed at least once.
                mask = (mask != 0).astype(np.float32)
            masks.append(mask)
            goals.append(goal)

        if self.goal_type in ["Best3_and_Mask", "SGRL_nopri"]:
            goal_texts = goals
        else:
            goal_texts = [",".join(f"{g['goal']}: {g['priority']:.1f}" for g in goal) for goal in goals]

        inputs = self.tokenizer(goal_texts, return_tensors='pt', padding=True, truncation=True)
        input_ids = inputs['input_ids'].to(device)
        with torch.no_grad():
            text_embs = self.lang_goal_encoder(input_ids)

        return (jnp.array(text_embs.cpu().numpy(), dtype=np.float32), jnp.array(masks))

    def get_mask(self, text_state, goal):
        """Ask the language model which actions relate to ``goal``.

        Args:
            text_state: Text description of the agent's current observation.
            goal: Goal string to find related actions for.

        Returns:
            The action mask parsed from the model's reply
            (see ``text_to_action_vector``).
        """
        self.system_prompt = self.lm.prompt_format.get_system_prompt(
            system_prompt=self.PRUNER_PROMPT_PATH)
        self.user_prompt = f"Agent state: {text_state}, Agent have a goal: {goal}. List actions that are related to achieving the goal. Just list the actions."
        response = self.lm.try_query(self.system_prompt, self.user_prompt)
        return self.text_to_action_vector(response)

    def text_to_action_vector(self, text):
        """Convert a free-text list of actions into a binary mask over ``ACTION``.

        Matching is case-insensitive and on whole words, and accepts both the
        underscore and the space spelling of an action ("place_table" /
        "place table"). If no action is recognised, every action is allowed.

        Args:
            text: Language-model reply listing actions.

        Returns:
            A float32 JAX array of length ``len(ACTION)`` with 1 for each
            mentioned action.
        """
        # Turn punctuation and underscores into spaces, collapse whitespace and
        # pad with spaces so that every action can be searched as " words ".
        cleaned = ''.join(ch if ch.isalnum() else ' ' for ch in text.lower())
        haystack = f" {' '.join(cleaned.split())} "

        action_vector = [
            1 if f" {action.replace('_', ' ')} " in haystack else 0
            for action in ACTION
        ]

        if not any(action_vector):
            # Re-parsing the same text cannot give a different result, so fall
            # back to the permissive mask right away.
            action_vector = [1] * len(ACTION)
            self.count_mask = 0
        print(f"text: {text}, action_vector: {action_vector}")
        return jnp.array(action_vector, dtype=np.float32)

class LlmGoal(PixelsTextGoal):
    """Goal provider whose goals are produced directly by a language model.

    When ``alg_name`` is ``'AdaRefiner'`` a per-environment score comparing
    the goal embedding with the state embedding is kept in ``self.l_scores``
    and passed to the language model on the next call.
    """

    def __init__(self, lm_name, env_name, num_envs, alg_name, goal_type):
        """Create the language model and initialise the per-environment scores.

        Args:
            lm_name, env_name, alg_name: Forwarded to ``LanguageModel``.
            num_envs: Forwarded to the base class; also the length of
                ``self.l_scores``.
            goal_type: Forwarded to the base class.
        """
        super().__init__(num_envs, goal_type)
        self.lm = LanguageModel(lm_name=lm_name, env_name=env_name, alg_name=alg_name)
        self.get_goal = self.lm.predict_options
        self.alg_name = alg_name
        self.l_scores = [0] * num_envs

    @staticmethod
    def _as_2d_tensor(x):
        """Return ``x`` as a torch tensor with a leading batch axis."""
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        return x

    def compute_l_score(self, a, b):
        """Score how similar two batches of embeddings are, row by row.

        Args:
            a, b: Torch tensors (or anything ``torch.tensor`` accepts) holding
                one embedding per row; 1-D inputs are treated as a single row.

        Returns:
            A list with one float per row:
            ``sigmoid(10 * (cosine_similarity - 0.1))``.
        """
        a = self._as_2d_tensor(a)
        b = self._as_2d_tensor(b)

        a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
        b_norm = torch.nn.functional.normalize(b, p=2, dim=1)

        # Dot product of unit vectors == cosine similarity per row.
        similarity = (a_norm * b_norm).sum(dim=1)
        # Same logistic as 1 / (1 + exp(-10 * (s - 0.1))), in its stable form.
        scores = torch.sigmoid(10 * (similarity - 0.1))
        return scores.detach().cpu().tolist()

    def _embed(self, texts):
        """Tokenize ``texts`` and run them through the goal encoder without gradients."""
        inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        input_ids = inputs['input_ids'].to(device)
        with torch.no_grad():
            return self.lang_goal_encoder(input_ids)

    def get_pixels_goals(self, view_arr, shape):
        """Generate and embed one language-model goal per environment.

        Args:
            view_arr: Observation tuple accepted by ``self.text_describ``.
            shape: Number of environments to process (indices ``0..shape-1``).

        Returns:
            The goal embeddings as a float32 JAX array.
        """
        text_states = [self.text_describ(view_arr, i) for i in range(shape)]
        goal_texts = [
            self.get_goal(text_state, l_score)
            for text_state, l_score in zip(text_states, self.l_scores)
        ]
        goal_embs = self._embed(goal_texts)
        if self.alg_name == 'AdaRefiner':
            state_embs = self._embed(text_states)
            # Score on the torch tensors directly; converting them to JAX
            # first would only be undone inside compute_l_score.
            self.l_scores = self.compute_l_score(goal_embs, state_embs)
        return jnp.array(goal_embs.cpu().numpy(), dtype=np.float32)







