from torch.utils.data import Dataset
import h5py
import os
import hydra
import numpy as np
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
from hydra.utils import to_absolute_path
import torch
import random
import re
import sys
import string
import yaml
from omegaconf import OmegaConf
from easydict import EasyDict
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.obs_utils as ObsUtils
from collections import OrderedDict
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.log_utils as LogUtils
from libero.lifelong.datasets import MySequenceDataset
from libero.lifelong.utils import get_task_embs
from libero.libero.benchmark import get_benchmark


class Logging:
    """Minimal "tee": every write is duplicated onto several text streams.

    ``DataModule.rewrite_task_description`` installs an instance as ``sys.stdout`` so that
    printed text reaches both the console and a log file.
    """

    def __init__(self, *files):
        # Tuple of target streams; writes are applied to them in this order.
        self.files = files

    @staticmethod
    def _write_and_flush(stream, text):
        """Write ``text`` to a single stream and flush it straight away."""
        stream.write(text)
        stream.flush()

    def write(self, obj):
        """Forward ``obj`` to each wrapped stream, flushing after every write."""
        for target in self.files:
            self._write_and_flush(target, obj)

    def flush(self):
        """Flush each wrapped stream in turn."""
        for target in self.files:
            target.flush()

class DataModule:
    """Read, paraphrase and package demonstration data stored as one HDF5 file per task.

    The per-task files live under ``<cfg.folder>/<benchmark.get_benchmark_path()>``.
    Each file has a ``data`` group whose ``bddl_file_name`` attribute names the task and
    whose child groups are the individual demonstrations. Demo names are expected to carry
    a 5-character prefix followed by an integer index (presumably ``demo_<N>``; the code
    only relies on ``int(name[5:])``).

    At most one HDF5 file is open at a time (``self.hdf5_file``); the methods below close
    it again before returning, also when an exception is raised.
    """

    def __init__(self, benchmark, cfg, logger=None):
        """
        Args:
            benchmark: benchmark object providing ``get_benchmark_path()``,
                ``get_task_demonstration(task_id)`` and ``n_tasks``.
            cfg: configuration; ``cfg.folder`` is the root directory of the datasets.
            logger: optional logger. It is stored on the instance but not used by this class.
        """
        self.benchmark = benchmark
        self.benchmark_path = os.path.join(cfg.folder, self.benchmark.get_benchmark_path())
        self.logger = logger
        # Path opened by the next __load_hdf5_file() call, and the currently open handle.
        self.hdf5_file_path = None
        self.hdf5_file = None

    # ------------------------------------------------------------------ dataset building

    def generate_dataset_by_task(self, obs_modality, dataset_modality, task_list, initialize_obs_utils=True, seq_len=1, para_task_description=True):
        """Build one dataset from every demonstration of the tasks in ``task_list``.

        Args:
            obs_modality: mapping ``modality name -> list of observation keys``.
            dataset_modality: mapping ``modality name -> list of keys`` read directly from
                each demo group (cast to float32).
            task_list: iterable of task ids whose demonstrations are loaded.
            initialize_obs_utils: if True, initialize robomimic's ObsUtils with ``obs_modality``.
            seq_len: forwarded to ``MySequenceDataset``.
            para_task_description: if True, use the ``demo_description_list`` stored in each
                demo (written by ``rewrite_task_description``); otherwise use the description
                derived from the task's BDDL file name.

        Returns:
            ``(dataset, shape_meta)``: the concatenated ``MySequenceDataset`` and the shape
            metadata read from task 0's file.
        """
        if initialize_obs_utils:
            ObsUtils.initialize_obs_utils_with_obs_specs({"obs": obs_modality})

        self.all_obs_keys = self.__parse_modality(obs_modality)
        self.all_dataset_keys = self.__parse_modality(dataset_modality)
        # Only low-dim observation keys are read into memory here.
        self.obs_keys_in_memory = self.__get_obs_keys_in_memory(self.all_obs_keys)

        # Shape metadata comes from task 0's file; this assumes all tasks share the same
        # observation/action shapes.
        self.__set_hdf5_file_path(self.__task_file_path(0))
        shape_meta = self.__get_shape_metadata(all_obs_keys=self.all_obs_keys, verbose=False)

        dataset_all = MySequenceDataset()
        for task_id in task_list:
            self.__set_hdf5_file_path(self.__task_file_path(task_id))
            self.__load_hdf5_file()
            try:
                data_group = self.hdf5_file["data"]
                task_description = self.__bddl_file_stem(data_group.attrs["bddl_file_name"])
                print(f"start to load task_{task_id}: {task_description}")
                demo_names = self.__sort_demo_names(data_group.keys())

                for demo_name in LogUtils.custom_tqdm(demo_names):
                    data = self.__extract_data_from_demo(demo_name, self.obs_keys_in_memory, self.all_dataset_keys, para_task_description)
                    dataset_temp = MySequenceDataset(self.all_obs_keys, self.obs_keys_in_memory, self.all_dataset_keys, data, seq_len)
                    dataset_all = dataset_all + dataset_temp
            finally:
                # Close the task file even if loading one of its demos fails.
                self.__close()

        return dataset_all, shape_meta

    def generate_dataset_by_demo(self, obs_modality, dataset_modality, demo_list, initialize_obs_utils=True, seq_len=1, para_task_description=True):
        """Build one dataset from individually selected demonstrations.

        Args:
            demo_list: iterable of global demo indices into ``get_idx_to_task_demo_ids_list()``.
            Other arguments: as in ``generate_dataset_by_task``.

        Returns:
            ``(dataset, shape_meta)``: the concatenated ``MySequenceDataset`` and the shape
            metadata read from task 0's file.
        """
        if initialize_obs_utils:
            ObsUtils.initialize_obs_utils_with_obs_specs({"obs": obs_modality})

        all_obs_keys = self.__parse_modality(obs_modality)
        all_dataset_keys = self.__parse_modality(dataset_modality)
        # Mirror generate_dataset_by_task: only low-dim observations are read into memory and
        # MySequenceDataset receives both the full and the in-memory key lists.
        obs_keys_in_memory = self.__get_obs_keys_in_memory(all_obs_keys)

        self.__set_hdf5_file_path(self.__task_file_path(0))
        shape_meta = self.__get_shape_metadata(all_obs_keys=all_obs_keys, verbose=False)

        dataset_all = MySequenceDataset()
        idx_to_task_demo_ids_list = self.get_idx_to_task_demo_ids_list()
        for demo_idx in demo_list:
            task_id, demo_name = idx_to_task_demo_ids_list[demo_idx]
            self.__set_hdf5_file_path(self.__task_file_path(task_id))
            self.__load_hdf5_file()
            try:
                # __extract_data_from_demo expects the demo *name*, not the h5py group.
                data = self.__extract_data_from_demo(demo_name, obs_keys_in_memory, all_dataset_keys, para_task_description)
                dataset_temp = MySequenceDataset(all_obs_keys, obs_keys_in_memory, all_dataset_keys, data, seq_len)
                dataset_all = dataset_all + dataset_temp
            finally:
                self.__close()

        return dataset_all, shape_meta

    def process_dataset(self, dataset, cfg):
        """Attach a language embedding to every description of every demo in ``dataset``.

        All descriptions are embedded in a single ``get_task_embs(cfg, descriptions)`` call;
        the result is split back so that ``dataset.data[i]['demo_emb_list']`` holds the
        embeddings of ``dataset.data[i]['demo_description_list']`` in the same order.

        Returns:
            ``dataset``, modified in place.
        """
        # Flatten with extend rather than repeated list concatenation (which is quadratic).
        descriptions = []
        for demo_data in dataset.data:
            descriptions.extend(demo_data['demo_description_list'])

        task_embs = get_task_embs(cfg, descriptions)

        start_pointer = 0
        for idx, demo_data in enumerate(dataset.data):
            n_descriptions = len(demo_data['demo_description_list'])
            dataset.data[idx]['demo_emb_list'] = task_embs[start_pointer:start_pointer + n_descriptions]
            start_pointer += n_descriptions

        # Every embedding row must have been assigned to exactly one description.
        assert start_pointer == task_embs.shape[0]

        return dataset

    # ------------------------------------------------------------------ description rewriting

    def rewrite_task_description(self, num_paraphrases=11, max_attempts=200):
        """Paraphrase each task description with an LLM and store the results in the HDF5 files.

        Every ``*.hdf5`` file directly inside ``self.benchmark_path`` is opened in append
        mode. For each demo group in its ``data`` group, a ``demo_description_list`` dataset
        is (re)written. It holds the original task description followed by the accepted
        paraphrases, and the group's ``task_description`` attribute is set. While this runs,
        everything printed is also appended to ``descriptions_<benchmark path>.txt``.

        Args:
            num_paraphrases: paraphrases to collect per demo (11 by default, as before).
            max_attempts: maximum generation calls per demo. Generation is greedy
                (``do_sample=False``), so without this cap a model whose answers are never
                accepted would loop forever.
        """
        log_file_path = f'descriptions_{self.benchmark.get_benchmark_path()}.txt'
        original_stdout = sys.stdout
        with open(log_file_path, 'a', encoding='utf-8') as log_file:
            # Tee stdout into the log file, and always undo it, even if generation fails.
            sys.stdout = Logging(original_stdout, log_file)
            try:
                self.__rewrite_hdf5_files(num_paraphrases, max_attempts)
            finally:
                sys.stdout = original_stdout

    def __rewrite_hdf5_files(self, num_paraphrases, max_attempts):
        """Load the paraphrasing LLM and rewrite the descriptions in every HDF5 file of the benchmark."""
        hdf5_file_paths = [
            os.path.join(self.benchmark_path, f)
            for f in os.listdir(self.benchmark_path)
            if f.endswith('.hdf5')
        ]

        pipe = self.__build_paraphrase_pipeline()
        generation_args = {
            "max_new_tokens": 500,
            "return_full_text": False,
            "temperature": 0.0,
            "do_sample": False,
        }
        # The prompt list and string dtype are identical for every file; build them once.
        self.__generate_prompt()
        string_dt = h5py.special_dtype(vlen=str)

        for hdf5_file_path in hdf5_file_paths:
            self.__set_hdf5_file_path(hdf5_file_path)
            self.__load_hdf5_file(permission='a')
            try:
                data = self.hdf5_file['data']
                # The task, and thus its description, is shared by all demos of one file.
                task_description = self.__bddl_task_description(data.attrs["bddl_file_name"])

                for key in data:
                    demo = data[key]
                    demo_description_list = self.__collect_paraphrases(pipe, generation_args, task_description, num_paraphrases, max_attempts)

                    if "demo_description_list" in demo:
                        del demo["demo_description_list"]
                    demo.create_dataset("demo_description_list", data=np.array(demo_description_list, dtype=string_dt))
                    demo.attrs["task_description"] = task_description

                    # Read back and print what was stored.
                    print(demo['demo_description_list'][:])
                    print("=====================================")
            finally:
                self.__close()

    @staticmethod
    def __build_paraphrase_pipeline():
        """Load Phi-3-mini-4k-instruct (cached under ./Phi-3-mini-4k-instruct) on CUDA as a text-generation pipeline."""
        cache_dir = to_absolute_path("./Phi-3-mini-4k-instruct")
        # NOTE: trust_remote_code=True runs code shipped with the model repository.
        model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct",
            device_map="cuda",
            torch_dtype="auto",
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", cache_dir=cache_dir)
        return pipeline("text-generation", model=model, tokenizer=tokenizer)

    def __collect_paraphrases(self, pipe, generation_args, task_description, num_paraphrases=11, max_attempts=200):
        """Ask the LLM for paraphrases of ``task_description`` until enough are accepted.

        Each attempt prefixes the description with a randomly chosen prompt from
        ``self.prompt_list``, which must already be built by ``__generate_prompt``.

        Args:
            pipe: callable invoked as ``pipe(messages, **generation_args)`` and returning a
                list whose first element has a ``"generated_text"`` entry.
            generation_args: keyword arguments forwarded to ``pipe``.
            task_description: the sentence to paraphrase.
            num_paraphrases: number of accepted answers wanted.
            max_attempts: maximum number of ``pipe`` calls.

        Returns:
            list whose first entry is ``task_description`` itself, followed by the accepted
            paraphrases (``num_paraphrases`` of them, or fewer if ``max_attempts`` ran out).
        """
        paraphrases = []
        attempts = 0
        while len(paraphrases) < num_paraphrases and attempts < max_attempts:
            attempts += 1
            content = random.choices(self.prompt_list, k=1)[0] + task_description
            print(content)

            messages = [{"role": "user", "content": content}]
            outputs = pipe(messages, **generation_args)

            answer = self.__clean_answer(outputs[0]["generated_text"])
            if self.__is_acceptable_answer(answer, task_description):
                paraphrases.append(answer)

        if len(paraphrases) < num_paraphrases:
            print(f"warning: only {len(paraphrases)} of {num_paraphrases} paraphrases accepted "
                  f"for '{task_description}' after {attempts} attempts")

        return [task_description] + paraphrases

    @staticmethod
    def __clean_answer(answer):
        """Drop every character of ``answer`` that is not in ``string.printable``."""
        return ''.join(char for char in answer if char in string.printable)

    @staticmethod
    def __is_acceptable_answer(answer, task_description):
        """Accept non-blank answers shorter than three times the original description."""
        # The length cap is presumably meant to reject answers where the model added commentary.
        return bool(answer.strip()) and len(answer) < 3 * len(task_description)

    # ------------------------------------------------------------------ task / demo queries

    def get_idx_to_task_demo_ids_list(self):
        """Return ``(task_id, demo_name)`` pairs for every demo of every task.

        The position in the returned list is the global demo index used by
        ``generate_dataset_by_demo``. Within a task, demos appear in h5py iteration order
        (not numerically sorted).
        """
        idx_to_task_demo_ids_list = []
        for task_id in range(self.benchmark.n_tasks):
            self.__set_hdf5_file_path(self.__task_file_path(task_id))
            self.__load_hdf5_file()
            try:
                idx_to_task_demo_ids_list.extend((task_id, demo_name) for demo_name in self.hdf5_file['data'])
            finally:
                self.__close()

        return idx_to_task_demo_ids_list

    def get_task_names(self):
        """Return a human-readable description (underscores -> spaces) for every task, ordered by task id."""
        return [
            self.__bddl_task_description(self.__read_bddl_file_name(task_id))
            for task_id in range(self.benchmark.n_tasks)
        ]

    def get_demo_names_and_ids(self, task_id: int):
        """Return the demo group names of task ``task_id`` and their integer indices.

        Returns:
            ``(demo_names, demo_ids)``: a list of demo group names in h5py iteration order,
            and the integers parsed from each name after its 5-character prefix.
        """
        # Open this task's own file, like every other per-task query in this class.
        self.__set_hdf5_file_path(self.__task_file_path(task_id))
        self.__load_hdf5_file()
        try:
            # Materialize the names: a keys view cannot be iterated once the file is closed.
            demo_names = list(self.hdf5_file['data'].keys())
        finally:
            self.__close()

        demo_ids = [int(demo_name[5:]) for demo_name in demo_names]
        return demo_names, demo_ids

    def get_task_description(self, task_id: int):
        """Return the BDDL file stem (underscores kept) that names task ``task_id``."""
        return self.__bddl_file_stem(self.__read_bddl_file_name(task_id))

    # ------------------------------------------------------------------ private helpers

    def __task_file_path(self, task_id):
        """Return the path of the HDF5 demonstration file of ``task_id`` inside ``self.benchmark_path``."""
        return os.path.join(
            self.benchmark_path,
            os.path.basename(self.benchmark.get_task_demonstration(task_id)),
        )

    def __read_bddl_file_name(self, task_id):
        """Return the ``bddl_file_name`` attribute of the ``data`` group in task ``task_id``'s file."""
        self.__set_hdf5_file_path(self.__task_file_path(task_id))
        self.__load_hdf5_file()
        try:
            return self.hdf5_file["data"].attrs["bddl_file_name"]
        finally:
            self.__close()

    @staticmethod
    def __bddl_file_stem(bddl_file_path):
        """Return the BDDL file name without its directory and without anything after the first '.'."""
        return os.path.basename(bddl_file_path).split(".")[0]

    @staticmethod
    def __bddl_task_description(bddl_file_path):
        """Turn a BDDL file path into a description: its stem with '_' replaced by ' '."""
        return " ".join(os.path.basename(bddl_file_path).split(".")[0].split("_"))

    @staticmethod
    def __sort_demo_names(demo_names):
        """Sort demo names numerically by the integer after their 5-character prefix (so demo_10 follows demo_2)."""
        return sorted(demo_names, key=lambda demo_name: int(demo_name[5:]))

    def __print_keys_and_attributes(self, hdf5_file, indent=0):
        """
        Recursively print all keys of an HDF5 file or group, plus the names of their attributes.

        Parameters:
        hdf5_file (h5py.File or h5py.Group): The HDF5 file or group to print.
        indent (int): The indentation level for nested groups/datasets.
        """
        indent_str = " " * indent
        for key in hdf5_file:
            item = hdf5_file[key]
            if isinstance(item, h5py.Group):
                print(f"{indent_str}Group: {key}")
                # Attribute names only; values are not printed.
                if item.attrs:
                    print(f"{indent_str}  Attributes:")
                    for attr_key in item.attrs.keys():
                        print(f"{indent_str}    {attr_key}:")
                # Recurse into the group's members.
                self.__print_keys_and_attributes(item, indent + 4)
            elif isinstance(item, h5py.Dataset):
                print(f"{indent_str}Dataset: {key}")
                if item.attrs:
                    print(f"{indent_str}  Attributes:")
                    for attr_key in item.attrs.keys():
                        print(f"{indent_str}    {attr_key}: ")

    def __generate_prompt(self):
        """Build ``self.prompt_list``, the instruction prefixes used to ask the LLM for a paraphrase."""
        # Order matters: prompts are drawn with random.choices over this list.
        self.prompt_list = [
            "rewrite this sentence in English and return the result only: ",
            "Rephrase the given sentence a little bit in English and return only the revised version: ",
            "Reconstruct this sentence in English and provide only the result: ",
            "Summarize the following sentence in English and return the result only: ",
            "Modify this sentence to be more polite in English and return the result only: ",
            "Paraphrase the following sentence in English politely and return the result only: ",
            "rewrite this sentence in English with a happy tone and return the result only: ",
            "Polish this sentence in English and return the result only: ",
        ]

    def __parse_modality(self, modality):
        """Concatenate the key lists of every modality in ``modality`` (a mapping name -> list of keys)."""
        return [key for modality_keys in modality.values() for key in modality_keys]

    def __get_obs_keys_in_memory(self, all_obs_keys):
        """Return the keys of ``all_obs_keys`` that ObsUtils classifies as the ``low_dim`` modality."""
        return [k for k in all_obs_keys if ObsUtils.key_is_obs_modality(k, "low_dim")]

    def __extract_data_from_demo(self, demo_name, all_obs_keys, all_dataset_keys, para_task_description=True):
        """Read one demo of the currently open HDF5 file into a plain dict.

        Args:
            demo_name: name of the demo group inside ``data``.
            all_obs_keys: observation keys read from ``<demo>/obs`` (cast to float32).
            all_dataset_keys: keys read from the demo group itself (cast to float32).
            para_task_description: if True, read the stored ``demo_description_list``;
                otherwise use only the description derived from the BDDL file name.

        Returns:
            dict with ``obs``, one entry per dataset key, ``demo_description_list``,
            ``length`` (the demo's ``num_samples`` attribute), ``hdf5_file_path`` and ``demo_name``.
        """
        demo = self.hdf5_file[f"data/{demo_name}"]
        data = {'obs': {}}

        for key in all_dataset_keys:
            data[key] = demo[key][()].astype('float32')

        for key in all_obs_keys:
            data['obs'][key] = demo['obs'][key][()].astype('float32')

        if para_task_description:
            raw_descriptions = demo["demo_description_list"][:]
            # h5py 3 returns variable-length strings as bytes; tolerate str as well.
            data['demo_description_list'] = [
                (s.decode("utf-8") if isinstance(s, bytes) else s).strip('"')
                for s in raw_descriptions
            ]
        else:
            bddl_file_path = self.hdf5_file["data"].attrs["bddl_file_name"]
            data['demo_description_list'] = [self.__bddl_task_description(bddl_file_path)]

        data['length'] = demo.attrs["num_samples"]
        data['hdf5_file_path'] = self.hdf5_file_path
        data['demo_name'] = demo_name

        return data

    def __get_shape_metadata(self, all_obs_keys=None, verbose=False):
        """
        Retrieve shape metadata from the HDF5 file at ``self.hdf5_file_path`` (must be set).

        Args:
            all_obs_keys (list): observation keys to describe. If None, every key under the
                first demo's ``obs`` group is used.
            verbose (bool): if True, print each observation key with its raw shape.

        Returns:
            shape_meta (dict): shape metadata with the following keys:

                :`'ac_dim'`: second dimension of the first demo's ``actions`` dataset
                :`'all_shapes'`: OrderedDict mapping each observation key (sorted) to its processed shape
                :`'all_obs_keys'`: list of the observation keys used
                :`'use_images'`: bool, whether an ``rgb`` modality is among the keys
        """
        shape_meta = {}

        self.__load_hdf5_file()
        try:
            f = self.hdf5_file
            # Shapes are taken from the first demo in the file.
            demo_id = list(f["data"].keys())[0]
            demo = f["data/{}".format(demo_id)]

            shape_meta['ac_dim'] = f["data/{}/actions".format(demo_id)].shape[1]

            all_shapes = OrderedDict()

            if all_obs_keys is None:
                all_obs_keys = list(demo["obs"])

            for k in sorted(all_obs_keys):
                # Drop the leading (time) axis to get the per-step shape.
                initial_shape = demo["obs/{}".format(k)].shape[1:]
                if verbose:
                    print("obs key {} with shape {}".format(k, initial_shape))
                all_shapes[k] = ObsUtils.get_processed_shape(
                    obs_modality=ObsUtils.OBS_KEYS_TO_MODALITIES[k],
                    input_shape=initial_shape,
                )
        finally:
            self.__close()

        shape_meta['all_shapes'] = all_shapes
        shape_meta['all_obs_keys'] = all_obs_keys
        shape_meta['use_images'] = ObsUtils.has_modality("rgb", all_obs_keys)

        return shape_meta

    def __repr__(self):
        """Print the key/attribute tree of the current HDF5 file, close it, and return a marker string."""
        if self.hdf5_file is None:
            self.__load_hdf5_file()

        try:
            self.__print_keys_and_attributes(self.hdf5_file)
        finally:
            self.__close()

        return "print finished"

    def __set_hdf5_file_path(self, hdf5_file_path):
        """Select the HDF5 file that the next __load_hdf5_file() call opens."""
        self.hdf5_file_path = hdf5_file_path

    def __load_hdf5_file(self, permission='r'):
        """Open ``self.hdf5_file_path`` with h5py mode ``permission``, closing any file already open."""
        if self.hdf5_file is not None:
            self.__close()

        self.hdf5_file = h5py.File(self.hdf5_file_path, permission, swmr=False, libver='latest')

    def __close(self):
        """
        Close the open HDF5 file handle, if any.
        """
        if self.hdf5_file is not None:
            self.hdf5_file.close()
        self.hdf5_file = None


@hydra.main(config_path="../configs", config_name="config", version_base=None)
def main(hydra_cfg):
    """Script entry point: paraphrase the task descriptions of the configured benchmark.

    The Hydra config is converted into a plain ``EasyDict`` (via a YAML round trip). The
    benchmark named by ``cfg.benchmark_name`` is instantiated with
    ``cfg.data.task_order_index``. ``DataModule.rewrite_task_description`` then writes the
    paraphrased descriptions back into that benchmark's HDF5 files.

    The same ``DataModule`` also offers ``generate_dataset_by_task`` and
    ``generate_dataset_by_demo`` for building datasets from those files.
    """
    # Round-trip through YAML so the config becomes a plain EasyDict with attribute access.
    cfg = EasyDict(yaml.safe_load(OmegaConf.to_yaml(hydra_cfg)))

    benchmark_factory = get_benchmark(cfg.benchmark_name)
    data_module = DataModule(benchmark_factory(cfg.data.task_order_index), cfg)
    data_module.rewrite_task_description()


if __name__ == '__main__':
    main()
    