import torch
import pickle
import numpy as np
import faiss
import os
import gc

class ResourceManager:
    """Singleton resource manager for managing shared resources such as embedding models and faiss indices.

    Every ``ResourceManager()`` call returns the same instance until
    :meth:`reset_instance` is called. Resources are loaded lazily on first
    access through :meth:`get_embed_model` / :meth:`get_faiss_resources` and
    cached on that instance.
    """

    _instance = None

    # Cached resource attributes. __new__ initialises them to None and
    # release_resources() clears them; keeping one list keeps both in sync.
    _RESOURCE_ATTRS = (
        "embed_model",
        "faiss_index",
        "next_actions",
        "retrieve_rtg",
        "sim_trajectories",
        "trajectories_indices",
        "trajectories_position",
    )

    # Dimension of the empty IndexFlatL2 built when loading fails.
    # NOTE(review): presumably matches the embedding size - confirm.
    _FALLBACK_INDEX_DIM = 64

    def __init__(self):
        # Intentionally empty: state is initialised once in __new__. __init__
        # runs on every ResourceManager() call and must not reset cached state.
        pass

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            for attr in cls._RESOURCE_ATTRS:
                setattr(instance, attr, None)
            instance.faiss_loaded = False
            cls._instance = instance
        return cls._instance

    @staticmethod
    def _project_root():
        """Return the directory three levels above the directory containing this file."""
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.dirname(os.path.dirname(os.path.dirname(current_file_dir)))

    def get_embed_model(self, EmbeddingTransformer):
        """Get embedding model, load if not already loaded.

        Args:
            EmbeddingTransformer: Model class to instantiate. It is called as
                ``EmbeddingTransformer(state_dim=16, act_dim=1, state_mean=...,
                state_std=..., K=48)``, and the result must provide
                ``load_net(path)`` and ``eval()``.

        Returns:
            The cached model. Once a model is cached, the argument is ignored.

        Raises:
            Whatever ``open``/``pickle.load``/the model constructor or
            ``load_net`` raise. On failure nothing is cached, so the next call
            retries.
        """
        if self.embed_model is None:
            embed_dir = os.path.join(self._project_root(), "saved_model", "DT_embedding")
            model_path = os.path.join(embed_dir, "dt.pt")
            pickle_path = os.path.join(embed_dir, "normalize_dict.pkl")
            print(f"Loading embedding model from {model_path}")

            # NOTE: pickle.load can execute arbitrary code; normalize_dict.pkl
            # must come from a trusted source.
            with open(pickle_path, 'rb') as f:
                normalize_dict = pickle.load(f)

            model = EmbeddingTransformer(state_dim=16, act_dim=1,
                                         state_mean=normalize_dict["state_mean"],
                                         state_std=normalize_dict["state_std"], K=48)
            model.load_net(model_path)
            model.eval()
            # Publish only after weights are loaded. Otherwise a failing
            # load_net() would leave an untrained model cached for good.
            self.embed_model = model
            print("Embedding model loaded.")
        return self.embed_model

    def get_faiss_resources(self):
        """Get faiss index and related resources, load if not already loaded.

        Files are read from ``<project_root>/encodings_auction``.

        If loading fails, the error is printed and an empty
        ``IndexFlatL2`` fallback is used. In that case ``next_actions`` and
        ``retrieve_rtg`` are empty ``(0, 1)`` float32 arrays and the trajectory
        attributes are ``None``. ``faiss_loaded`` stays False, so the next call
        retries loading.

        Returns:
            Tuple ``(faiss_index, next_actions, retrieve_rtg, sim_trajectories,
            trajectories_indices, trajectories_position)``.
        """
        if not self.faiss_loaded:
            base_path = os.path.join(self._project_root(), "encodings_auction")
            try:
                print("Loading Faiss index...")
                # Load everything into locals first, so a failure part-way
                # never leaves a mix of loaded and missing data on self.
                next_actions = np.load(os.path.join(base_path, "next_actions.npy"))
                retrieve_rtg = np.load(os.path.join(base_path, "retrieve_rtgs.npy"))
                trajectories_indices = np.load(os.path.join(base_path, "trajectory_indices.npy"))
                trajectories_position = np.load(os.path.join(base_path, "position_indices.npy"))
                # NOTE: pickle.load can execute arbitrary code; trajectories.pkl
                # must come from a trusted source.
                with open(os.path.join(base_path, "trajectories.pkl"), "rb") as f:
                    sim_trajectories = pickle.load(f)
                faiss_index = faiss.read_index(os.path.join(base_path, "encodings.index"))
            except Exception as e:
                # Broad catch is a deliberate best-effort boundary: callers
                # continue with an empty index instead of crashing.
                print(f"Error loading Faiss index: {e}")
                # Create empty index as fallback; reset every companion
                # attribute so no stale partial data is paired with it.
                self.next_actions = np.zeros((0, 1), dtype=np.float32)
                self.retrieve_rtg = np.zeros((0, 1), dtype=np.float32)
                self.sim_trajectories = None
                self.trajectories_indices = None
                self.trajectories_position = None
                self.faiss_index = faiss.IndexFlatL2(self._FALLBACK_INDEX_DIM)
            else:
                self.next_actions = next_actions
                self.retrieve_rtg = retrieve_rtg
                self.trajectories_indices = trajectories_indices
                self.trajectories_position = trajectories_position
                self.sim_trajectories = sim_trajectories
                self.faiss_index = faiss_index
                self.faiss_loaded = True
                print(f"Faiss index loaded with {self.faiss_index.ntotal} vectors.")

        return (self.faiss_index, self.next_actions, self.retrieve_rtg,
                self.sim_trajectories, self.trajectories_indices,
                self.trajectories_position)

    def release_resources(self):
        """Release resources.

        Drops this instance's references to every cached resource, marks the
        faiss resources as not loaded, and runs ``gc.collect()``. The next
        ``get_*`` call reloads from disk.
        """
        for attr in self._RESOURCE_ATTRS:
            setattr(self, attr, None)
        self.faiss_loaded = False
        gc.collect()
        print("Resources released.")

    @classmethod
    def reset_instance(cls):
        """Reset singleton instance to force reloading resources on next access.

        Releases the current instance's resources and forgets it, so the next
        ``ResourceManager()`` call creates a fresh instance.
        """
        if cls._instance is not None:
            cls._instance.release_resources()
            cls._instance = None
            gc.collect()