import copy
import json
import logging
import os
import threading
from typing import Any, Dict, List

import shortuuid
from omegaconf import DictConfig
from thefuzz import process
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# from .oracle_graph import KnowledgeGraph
from ..utils import language_action_to_subgoal

import fcntl
import time

def read_json_data_with_shared_lock(file_path):
    """Parse and return the JSON document stored at ``file_path``.

    A shared advisory ``flock`` is held while parsing, so the read waits for
    any process holding an exclusive ``flock`` on the same file to release it,
    while other shared-lock readers may proceed concurrently.
    """
    with open(file_path, "r") as handle:
        fcntl.flock(handle, fcntl.LOCK_SH)
        try:
            parsed = json.load(handle)
        finally:
            fcntl.flock(handle, fcntl.LOCK_UN)
    return parsed



# wp_to_sg_dir_path
# wp_to_sg_hypothesized_recipe_dir

class DecomposedMemory:
    """File-backed memory of plan outcomes, keyed by waypoint (item name).

    For each waypoint ``w`` the directory ``wp_to_sg_dir_path`` holds:

    * ``w.json`` -- ``{"action": {action_str: {"success": int, "failure": int,
      "reflection": {...}}}}`` (``reflection`` is optional). It is read under a
      shared advisory ``flock`` and rewritten under an exclusive one.
    * ``w.pt`` -- the sentence-encoder embedding of the name ``w`` (with a
      leading batch dimension), used for similarity retrieval.

    ``_lock`` is a class attribute, so it serializes writers across every
    instance in this process; the ``flock`` calls coordinate with other
    processes that use the same directory and also use ``flock``.
    """

    current_environment: str = ""
    _lock: threading.Lock = threading.Lock()

    def __init__(
        self,
        cfg: DictConfig,
        logger: logging.Logger | None = None,
        life_long_learning: bool = False,
    ) -> None:
        """Create the memory directories and rebuild the embedding/success caches.

        Args:
            cfg: Config providing ``version``, ``device_id``,
                ``memory.plan_failure_threshold``, ``memory.path`` and
                ``memory.waypoint_to_sg.path``.
            logger: Logger to use; a module-level logger is used when None.
            life_long_learning: Unused here; kept for backward compatibility.
        """
        self.cfg = cfg
        self.version = self.cfg["version"]

        # ``logger`` defaults to None; fall back to a module logger instead of
        # failing on the first ``self.logger.info`` call below.
        self.logger = logger if logger is not None else logging.getLogger(__name__)

        self.plan_failure_threshold = int(cfg["memory"]["plan_failure_threshold"])

        self.root_path = self.cfg["memory"]["path"]
        self.logger.info(f"DecomposedMemory root path: {self.root_path}")
        os.makedirs(self.root_path, exist_ok=True)

        self.wp_to_sg_dir_path = os.path.join(
            self.root_path, self.cfg["memory"]["waypoint_to_sg"]["path"]
        )
        self.logger.info(f"DecomposedMemory wp_to_sg_dir_path: {self.wp_to_sg_dir_path}")
        os.makedirs(self.wp_to_sg_dir_path, exist_ok=True)

        self.device = f'cuda:{cfg["device_id"]}'
        self.bert_encoder = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

        self._prepare_waypoint_embeddings()
        self._prepare_succeeded_waypoints()

        # self.crafting_graph = KnowledgeGraph( # Just for backward compatibility
        #     life_long_learning=life_long_learning,
        # )

    # ------------------------------------------------------------------
    # Path / storage helpers
    # ------------------------------------------------------------------

    def _json_path(self, waypoint):
        """Return the path of the JSON statistics file for ``waypoint``."""
        return os.path.join(self.wp_to_sg_dir_path, f"{waypoint}.json")

    def _embedding_path(self, waypoint):
        """Return the path of the cached embedding tensor for ``waypoint``."""
        return os.path.join(self.wp_to_sg_dir_path, f"{waypoint}.pt")

    def _list_waypoints(self):
        """Return the waypoint name of every ``*.json`` file in the memory directory."""
        return [
            file_name.removesuffix(".json")
            for file_name in os.listdir(self.wp_to_sg_dir_path)
            if file_name.endswith(".json")
        ]

    def _encode_waypoint(self, waypoint):
        """Embed ``waypoint`` with the sentence encoder, adding a batch dimension."""
        return torch.tensor(self.bert_encoder.encode(waypoint)).unsqueeze(0).to(self.device)

    def _write_waypoint_embedding(self, waypoint):
        """Encode ``waypoint`` and (re)write its ``.pt`` file.

        The caller must hold ``self._lock``. The tensor goes to a temporary
        file that is then moved into place with ``os.replace`` so a concurrent
        ``torch.load`` never sees a half-written file.
        """
        pt_path = self._embedding_path(waypoint)
        tmp_path = f"{pt_path}.{os.getpid()}.{threading.get_ident()}.tmp"
        torch.save(self._encode_waypoint(waypoint), tmp_path)
        os.replace(tmp_path, pt_path)

    def _load_waypoint_embedding(self, waypoint):
        """Return the cached embedding of ``waypoint``, encoding it if no ``.pt`` exists."""
        pt_path = self._embedding_path(waypoint)
        if os.path.exists(pt_path):
            # map_location: the tensor may have been saved from another device id.
            return torch.load(pt_path, map_location=self.device)
        return self._encode_waypoint(waypoint)

    def _read_waypoint_file(self, waypoint):
        """Return the parsed JSON for ``waypoint``, or None if it has no file."""
        json_path = self._json_path(waypoint)
        if not os.path.isfile(json_path):
            return None
        return read_json_data_with_shared_lock(json_path)

    def _update_waypoint_file(self, waypoint, mutate, create=False):
        """Read-modify-write ``waypoint``'s JSON file under an exclusive flock.

        Args:
            waypoint: Waypoint whose ``.json`` file is updated.
            mutate: Callable that edits the parsed document in place. It runs
                while ``self._lock`` is held and must not acquire it again.
            create: Create the file if it is missing. An empty file starts as
                ``{"action": {}}``.

        Returns:
            The updated document, or None when the file is missing and
            ``create`` is False.
        """
        flags = (os.O_RDWR | os.O_CREAT) if create else os.O_RDWR
        with self._lock:
            try:
                # os.open (not open(..., "w+")) so an existing file is never
                # truncated before the exclusive lock is held.
                fd = os.open(self._json_path(waypoint), flags, 0o666)
            except FileNotFoundError:
                if create:
                    raise
                return None
            with os.fdopen(fd, "r+", encoding="utf-8") as fp:
                fcntl.flock(fp, fcntl.LOCK_EX)
                try:
                    raw = fp.read()
                    wp_file_data = json.loads(raw) if raw.strip() else {"action": {}}
                    wp_file_data.setdefault("action", {})
                    mutate(wp_file_data)
                    fp.seek(0)
                    fp.truncate()
                    json.dump(wp_file_data, fp)
                    # Flush while still locked; otherwise the buffered JSON reaches
                    # the file only at close, after other processes may read it.
                    fp.flush()
                finally:
                    fcntl.flock(fp, fcntl.LOCK_UN)
        return wp_file_data

    # ------------------------------------------------------------------
    # Cache maintenance
    # ------------------------------------------------------------------

    def _prepare_waypoint_embeddings(self):
        """Recompute and save the embedding of every waypoint that has a JSON file."""
        waypoints = self._list_waypoints()
        with self._lock:
            for waypoint in waypoints:
                self._write_waypoint_embedding(waypoint)

    def _prepare_succeeded_waypoints(self):
        """Rebuild ``self.succeeded_waypoints`` from the files on disk.

        The list is built first and assigned once, so other threads never see
        a partially rebuilt list.
        """
        self.succeeded_waypoints = [
            waypoint
            for waypoint in self._list_waypoints()
            if self.is_succeeded_waypoint(waypoint)[0]
        ]

    # ------------------------------------------------------------------
    # Reflections
    # ------------------------------------------------------------------

    def retrieve_all_reflections(self, item_name):
        """Return every non-empty ``reflection`` dict recorded for ``item_name``.

        Returns an empty list when the item has no memory file.
        """
        wp_file_data = self._read_waypoint_file(item_name)
        if wp_file_data is None:
            return []
        return [
            action_history["reflection"]
            for action_history in wp_file_data["action"].values()
            if action_history.get("reflection")
        ]

    def save_reflection(self, item_name, action_str, inventory_before_action, reflection):
        """Store a failure reflection for ``action_str`` under ``item_name``.

        This does nothing when ``item_name`` has no memory file yet. An action
        not seen before gets zeroed success/failure counters.
        """

        def _attach(wp_file_data):
            entry = wp_file_data["action"].setdefault(action_str, {"success": 0, "failure": 0})
            entry["reflection"] = {
                "item_name": item_name,
                "inventory": inventory_before_action,
                "plan": action_str,
                "failure_analysis": reflection,
            }

        self._update_waypoint_file(item_name, _attach)

    def get_history_of_action(self, item_name, action_str):
        """Return the stored record of ``action_str`` for ``item_name``, or ``{}``."""
        wp_file_data = self._read_waypoint_file(item_name)
        if wp_file_data is None:
            return {}
        return wp_file_data["action"].get(action_str, {})

    # ------------------------------------------------------------------
    # Success / failure bookkeeping
    # ------------------------------------------------------------------

    def save_success_failure(self, waypoint, action_str, is_success):
        """Record one success (``is_success`` truthy) or failure of ``action_str``."""
        self._save_success_failure(waypoint, action_str, is_success)

    def _save_success_failure(self, waypoint, action_str, is_success):
        """Increment the counter, ensure an embedding exists, and refresh the success cache."""

        def _record(wp_file_data):
            counts = wp_file_data["action"].setdefault(action_str, {"success": 0, "failure": 0})
            key = "success" if is_success else "failure"
            counts[key] = counts.get(key, 0) + 1

        self._update_waypoint_file(waypoint, _record, create=True)

        # Retrieval needs a .pt for every succeeded waypoint; only this
        # waypoint's embedding can be missing as a result of this call.
        if not os.path.exists(self._embedding_path(waypoint)):
            with self._lock:
                self._write_waypoint_embedding(waypoint)

        self._prepare_succeeded_waypoints()

    def is_succeeded_waypoint(self, waypoint):
        """Return ``(True, subgoal_str)`` for the best surviving action, else ``(False, None)``.

        An action survives if it succeeded at least once and its net score
        ``success - failure`` is above ``-plan_failure_threshold``. The action
        with the highest net score wins, and ties go to the first one stored.
        """
        wp_file_data = self._read_waypoint_file(waypoint)
        if wp_file_data is None:
            return False, None

        scored_actions = [
            (action, action_history["success"] - action_history["failure"])
            for action, action_history in wp_file_data["action"].items()
            if action_history["success"] > 0
            and (action_history["success"] - action_history["failure"]) > -self.plan_failure_threshold
        ]
        if not scored_actions:
            return False, None

        best_action, _ = max(scored_actions, key=lambda pair: pair[1])
        _, succeeded_subgoal_str = language_action_to_subgoal(best_action, waypoint)
        return True, succeeded_subgoal_str

    def retrieve_similar_succeeded_waypoints(self, waypoint, topK=3):
        """Return ``{succeeded_waypoint: subgoal}`` for the ``topK`` most similar waypoints.

        Similarity is the cosine similarity between the embedding of
        ``waypoint`` and the cached embedding of each succeeded waypoint.
        Fewer than ``topK`` entries are returned when fewer waypoints have
        succeeded; an empty dict is returned when none have.
        """
        candidates = sorted(self.succeeded_waypoints)
        k = min(topK, len(candidates))
        if k <= 0:
            return {}

        embedding_matrix = torch.cat(
            [self._load_waypoint_embedding(name) for name in candidates], dim=0
        ).to(self.device)
        wp_embedding = self._encode_waypoint(waypoint)

        # Broadcasts the single query row against every candidate row, giving
        # a 1-D tensor even when there is only one candidate.
        cosine_similarities = F.cosine_similarity(embedding_matrix, wp_embedding, dim=1)
        _, topK_indices = torch.topk(cosine_similarities, k)

        wp_sg_dict = dict()
        for index in topK_indices.tolist():
            succeeded_wp = candidates[index]
            _, sg = self.is_succeeded_waypoint(succeeded_wp)
            wp_sg_dict[succeeded_wp] = sg
        return wp_sg_dict

    def retrieve_failed_subgoals(self, waypoint):
        """Return subgoals of actions whose net score is at or below ``-plan_failure_threshold``."""
        wp_file_data = self._read_waypoint_file(waypoint)
        if wp_file_data is None:
            return []
        return [
            language_action_to_subgoal(action, waypoint)[1]
            for action, action_history in wp_file_data["action"].items()
            if (action_history["success"] - action_history["failure"]) <= -self.plan_failure_threshold
        ]

    def retrieve_total_failed_counts(self, waypoint):
        """Return the sum of ``success - failure`` over all actions of ``waypoint``.

        Returns 0 when the waypoint has no memory file.
        """
        # NOTE(review): despite the name, this is a net score (success minus
        # failure), not a count of failures. Behavior is kept for existing
        # callers; confirm whether ``failure`` alone was intended.
        wp_file_data = self._read_waypoint_file(waypoint)
        if wp_file_data is None:
            return 0
        return sum(
            action_history["success"] - action_history["failure"]
            for action_history in wp_file_data["action"].values()
        )

    def reset_success_failure_history(self, item_name):
        """Zero every action's counters for ``item_name`` if it has not succeeded.

        Does nothing for an already-succeeded waypoint or one without a file.
        """
        is_succeeded, _ = self.is_succeeded_waypoint(item_name)
        if is_succeeded:
            return

        def _reset(wp_file_data):
            # NOTE(review): replacing each entry also discards any stored
            # ``reflection``; confirm that this is intended.
            wp_file_data["action"] = {
                action: {"success": 0, "failure": 0} for action in wp_file_data["action"]
            }

        if self._update_waypoint_file(item_name, _reset) is None:
            return

        self.logger.info(f"Reset success/failure history for {item_name}")
