# Dynamic allocation
import sys
from dataclasses import dataclass
import argparse
import ray
from vllm import LLM, SamplingParams
import time
import os
import json
from datasets import load_dataset
import random, copy
import yaml
# Define configuration class
# @dataclass
class LMCallingConfig:
    """Loose container for the keyword arguments used to build the model.

    Every keyword argument handed to the constructor is echoed to stdout and
    stored as an instance attribute, whether or not it is declared below.

    The annotated names are class-level defaults. They can be read through an
    instance, but they are NOT part of the instance's ``__dict__`` unless the
    caller passes them explicitly. ``ModelLoader`` unpacks ``__dict__`` into
    ``LLM(...)``, so only explicitly supplied values are forwarded there.
    """

    model: str
    max_model_len: int = 16384  # Maximum input length
    tensor_parallel_size: int = 1
    trust_remote_code: bool = True
    # ``dtype`` and ``gpu_memory_utilization`` are intentionally not declared;
    # pass them as keyword arguments when they are needed.

    def __init__(self, **kwargs):
        print("Model initialization parameters: ")
        for name in kwargs:
            setting = kwargs[name]
            print(f"{name}: {setting}")
            # Declared and undeclared names are treated the same way: the
            # value always ends up as an attribute on this instance.
            setattr(self, name, setting)

# Define remote class to load and manage models
class ModelLoader:
    """Holds one vLLM engine and runs chat inference on batches of prompts.

    Subclasses customise how prompts are turned into chat messages
    (``pre_process``) and how a response is attached to its prompt
    (``post_process``).
    """

    def __init__(self, llm_config: "LMCallingConfig", **kwargs):
        """Build the engine from the instance attributes of *llm_config*.

        Args:
            llm_config: Configuration object; its ``__dict__`` is unpacked
                into ``LLM(...)`` and its ``model`` attribute names the model.
            **kwargs: Accepted for subclass compatibility and ignored.
        """
        self.llm = LLM(**llm_config.__dict__)
        self.model_name = llm_config.model
        self.short_model_name = os.path.basename(self.model_name)
        self.gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "unknown")
        # Path of the per-loader backup file; set on each chat_inference call.
        self.save_path = ""

    def get_model(self):
        """Return the underlying engine object."""
        return self.llm

    def get_temp_save_path(self, output_path):
        """Compute and store the backup file path for *output_path*.

        The backup lives in ``<dir>/<output stem>_backup/`` and is named
        ``gpu_<first visible device>_<output filename>``. The directory is
        created if missing. The result is stored in ``self.save_path``.
        """
        # CUDA_VISIBLE_DEVICES may list several devices ("2,3"). Use the whole
        # first entry rather than the first character, so device "10" does not
        # collide with device "1"; fall back to "unknown" for an empty value.
        first_device = self.gpu_id.split(",")[0].strip() or "unknown"
        output_dir, filename = os.path.split(output_path)
        filename_wo_ext, _ = os.path.splitext(filename)
        backup_dir = os.path.join(output_dir, f'{filename_wo_ext}_backup')
        os.makedirs(backup_dir, exist_ok=True)
        self.save_path = f"{backup_dir}/gpu_{first_device}_{filename}"

    def chat_inference(self, prompt_list, output_path, sampling_params: "SamplingParams"):
        """Run chat generation for *prompt_list* and back the results up.

        Args:
            prompt_list: Prompt records; each is replaced in place by the
                value returned from ``post_process``.
            output_path: Final output path, used only to derive the backup
                file location.
            sampling_params: Sampling parameters forwarded to ``llm.chat``.

        Returns:
            The same *prompt_list*, after post-processing.
        """
        self.get_temp_save_path(output_path)
        msgs = self.pre_process(prompt_list)
        outputs = self.llm.chat(
            msgs,
            sampling_params=sampling_params,
            use_tqdm=True
        )
        print("Processing complete, starting post_process")
        for i, output in enumerate(outputs):
            # Only the first candidate of each request is kept.
            response = output.outputs[0].text
            prompt_list[i] = self.post_process(prompt_list[i], response)
        print("post_process completed")
        # Append this batch to the backup file so finished work can be
        # recognised on a later run.
        with open(self.save_path, 'a', encoding='utf-8') as f:
            for result in prompt_list:
                f.write(json.dumps(result, ensure_ascii=False) + '\n')
        return prompt_list

    def pre_process(self, prompt_list):
        """Return the ``'messages'`` entry of every prompt record."""
        return [it['messages'] for it in prompt_list]

    def post_process(self, prompt, response):
        """Store *response* on *prompt* under the short model name."""
        prompt[self.short_model_name] = response
        return prompt

    def get_save_path(self):
        """Return the backup file path set by the last inference call."""
        return self.save_path

# Create a custom ModelLoader subclass
class CustomModelLoader(ModelLoader):
    """ModelLoader variant that reads the user text from the ``'prompt'`` key."""

    def __init__(self, llm_config: LMCallingConfig, **kwargs):
        """Initialise the base loader and announce that this subclass is in use."""
        super().__init__(llm_config, **kwargs)
        print("Use CustomModelLoader.......")

    def pre_process(self, prompt_list):
        """Build one single-turn user message per prompt record.

        When ``record['prompt']`` is a sequence whose last element has a
        ``'content'`` entry, that content is used. Otherwise the ``'prompt'``
        value itself is used as the message content.

        Raises:
            KeyError: If a record has no ``'prompt'`` key.
        """
        msgs = []
        for it in prompt_list:
            prompt = it['prompt']
            try:
                content = prompt[-1]['content']
            except (KeyError, IndexError, TypeError):
                # Not a non-empty list of message dicts (e.g. a plain string):
                # use the raw value. Only lookup failures are caught, so
                # KeyboardInterrupt and similar are no longer swallowed.
                content = prompt
            msgs.append([
                {
                    'role': 'user',
                    'content': content
                }
            ])
        return msgs


# Define the main class that manages model inference
class ModelInferenceManager:
    """Runs batch chat inference across several Ray-hosted model loaders.

    One loader actor is created per data-parallel replica. Each call to
    ``batch_inference`` reads an input file, skips records already present in
    the backup directory, spreads the rest over the loaders, and writes the
    merged, ordered results to the output path.
    """

    def __init__(self, model_params: dict, model_loader_class=None):
        """Create the loader actors.

        Args:
            model_params: Must contain ``'model'`` and ``'total_gpus'``;
                ``'tensor_parallel_size'`` defaults to 1. Everything except
                ``'total_gpus'`` is passed to ``LMCallingConfig``.
            model_loader_class: Loader class to wrap as a Ray actor. ``None``
                (the default) selects ``ModelLoader``.

        Raises:
            ValueError: If ``total_gpus`` is smaller than
                ``tensor_parallel_size``, leaving no replica to run.
        """
        # Copy first so popping 'total_gpus' does not alter the caller's dict.
        model_params = dict(model_params)
        self.model_name = model_params['model']
        self.tensor_parallel_size = model_params.get('tensor_parallel_size', 1)
        self.total_gpus = model_params.pop('total_gpus')
        self.model_loader_class = ModelLoader if model_loader_class is None else model_loader_class
        self.llm_config = LMCallingConfig(**model_params)

        # Data-parallel replica count, e.g. 8 GPUs with tp=2 gives 4 replicas.
        self.ddp_count = self.total_gpus // self.tensor_parallel_size
        if self.ddp_count < 1:
            raise ValueError(
                f"total_gpus ({self.total_gpus}) must be at least "
                f"tensor_parallel_size ({self.tensor_parallel_size})"
            )

        # ray.init() raises when Ray is already running in this process.
        if not ray.is_initialized():
            ray.init()
        # Each actor reserves as many GPUs as one tensor-parallel replica needs.
        ModelLoaderRemote = ray.remote(num_gpus=self.tensor_parallel_size)(self.model_loader_class)
        print(f'Data parallelism: {self.ddp_count}')
        self.loaders = [ModelLoaderRemote.remote(self.llm_config) for _ in range(self.ddp_count)]

        # Key added to every record to remember its position in the input.
        self.ORDER_KEY = "__order_idx__"

    def get_prompt(self):
        """Load ``self.input_path`` into a list of dicts tagged with ORDER_KEY.

        Supports ``.parquet`` (via ``load_dataset``) and ``.jsonl``. Blank
        lines in a jsonl file are skipped.

        Raises:
            ValueError: For any other file extension.
        """
        file_extension = os.path.splitext(self.input_path)[1].lower()
        prompt_list = []

        if file_extension == '.parquet':
            dataset = load_dataset('parquet', data_files=self.input_path)['train']
            for i, item in enumerate(dataset):
                prompt_item = dict(item)
                prompt_item[self.ORDER_KEY] = i
                prompt_list.append(prompt_item)
        elif file_extension == '.jsonl':
            # Explicit encoding, matching how the output files are written.
            with open(self.input_path, encoding='utf-8') as fin:
                for i, line in enumerate(fin):
                    if not line.strip():
                        continue
                    record = json.loads(line)
                    record[self.ORDER_KEY] = i
                    prompt_list.append(record)
        else:
            raise ValueError(f"Unsupported file format: {file_extension}. Only .parquet and .jsonl files are supported.")

        return prompt_list

    def distribute_tasks(self, loaders, chunks, sampling_params):
        """Submit every chunk to a loader and return all results.

        The first ``len(loaders)`` chunks go to the loaders in order. Each
        later chunk waits until any in-flight task finishes and is then sent
        to the loader that just became free.

        Returns:
            The list of per-chunk results, in chunk order.
        """
        num_loaders = len(loaders)
        futures = []
        in_progress = {}  # in-flight future -> index of the loader running it
        for i, chunk in enumerate(chunks):
            print(f"Current progress: {i + 1}/{len(chunks)}")
            if len(in_progress) < num_loaders:
                # Initial phase: hand one chunk to each loader.
                loader_index = len(in_progress)
            else:
                # All loaders busy: block until one task completes.
                ready_futures, _ = ray.wait(list(in_progress), num_returns=1)
                loader_index = in_progress.pop(ready_futures[0])
            future = loaders[loader_index].chat_inference.remote(chunk, self.output_path, sampling_params)
            futures.append(future)
            in_progress[future] = loader_index

        return ray.get(futures)

    def _backup_dir(self):
        """Return the backup directory that belongs to ``self.output_path``."""
        output_dir, filename = os.path.split(self.output_path)
        filename_wo_ext, _ = os.path.splitext(filename)
        return os.path.join(output_dir, f'{filename_wo_ext}_backup')

    def load_backup_data(self):
        """Read every backed-up record for the current output path.

        Returns:
            ``(backup_data, backup_data_list)``: a dict keyed by
            ``(order_idx, rollout_idx)`` marking finished records, and the
            list of those records. Records without ORDER_KEY are ignored.
        """
        backup_data = {}
        backup_data_list = []
        output_dir = self._backup_dir()
        os.makedirs(output_dir, exist_ok=True)
        for root, dirs, files in os.walk(output_dir):
            # Do not descend into notebook checkpoint folders.
            dirs[:] = [d for d in dirs if not d.startswith('.ipynb')]
            for file in files:
                if not file.endswith(('.jsonl', '.json')):
                    continue
                with open(os.path.join(root, file), 'r', encoding='utf-8') as f:
                    for line in f:
                        if not line.strip():
                            continue
                        item = json.loads(line)
                        order_idx = item.get(self.ORDER_KEY)
                        if order_idx is None:
                            continue
                        # Default 0 keeps backups without rollout_idx usable.
                        rollout_idx = item.get('rollout_idx', 0)
                        backup_data[(order_idx, rollout_idx)] = True
                        backup_data_list.append(item)
        return backup_data, backup_data_list

    def filter_prompt_list(self, prompt_list, backup_data):
        """Return the prompts whose ``(order_idx, rollout_idx)`` is not backed up."""
        return [
            prompt for prompt in prompt_list
            if (prompt.get(self.ORDER_KEY), prompt.get('rollout_idx', 0)) not in backup_data
        ]

    def sort_order(self, original_prompt_list, backup_data_list):
        """Sort records by rollout_idx, then by ORDER_KEY.

        *original_prompt_list* is accepted for interface compatibility and is
        not used.
        """
        return sorted(backup_data_list, key=lambda x: (x.get('rollout_idx', 0), x[self.ORDER_KEY]))

    def _write_sorted_output(self, original_prompt_list):
        """Merge all backups, order them, and write them to the output path."""
        _, backup_data_list = self.load_backup_data()
        backup_data_list = self.sort_order(original_prompt_list, backup_data_list)
        with open(self.output_path, 'w', encoding='utf-8') as f:
            for item in backup_data_list:
                # Internal bookkeeping keys are not part of the final output.
                item.pop(self.ORDER_KEY, None)
                item.pop('rollout_idx', None)
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        print(f"Written to {self.output_path}")

    def _make_chunks(self, prompt_list):
        """Split *prompt_list* into batches and update ``self.batch_size``.

        The batch size is capped at ``len(prompt_list) // ddp_count`` and
        never drops below 1, so fewer items than replicas (or no items) no
        longer produces a zero step.
        """
        length = len(prompt_list)
        print(f"batch_size = {self.batch_size}")
        capped = max(1, min(self.batch_size, length // self.ddp_count))
        if capped != self.batch_size:
            self.batch_size = capped
            print(f"Since batch_size > data-parallel count, new batch_size = {self.batch_size}")
        return [prompt_list[i:i + self.batch_size] for i in range(0, length, self.batch_size)]

    def _cleanup_backups(self):
        """Delete the loaders' backup files and the backup directory if empty."""
        save_paths = ray.get([loader.get_save_path.remote() for loader in self.loaders])
        for path in save_paths:
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                    print(f"Deleted temporary file: {path}")
                except OSError as e:
                    print(f"Error deleting file {path}: {e}")
        # Derive the directory from the output path: a loader that did no work
        # for this file may report an empty or outdated save path.
        directory = self._backup_dir()
        try:
            if os.path.exists(directory) and os.path.isdir(directory):
                if not os.listdir(directory):
                    os.rmdir(directory)
                    print(f"Directory {directory} is empty, deleted.")
                else:
                    print(f"Directory {directory} is not empty, not deleted.")
            else:
                print(f"Specified path {directory} is not a valid directory.")
        except OSError as e:
            print(f"An error occurred: {e}")

    def batch_inference(self, data_params: dict, sampling_params: dict, rollout_n: int = 1):
        """Run inference for one input file and write the ordered output.

        Args:
            data_params: Needs ``'input_path'``, ``'output_path'`` and
                ``'batch_size'``.
            sampling_params: Keyword arguments for ``SamplingParams``.
            rollout_n: Number of times every prompt is repeated.

        Returns:
            ``0`` when the backups already cover every record; otherwise the
            list of per-chunk results from ``distribute_tasks``.
        """
        sampling_params = SamplingParams(**sampling_params)
        self.input_path = data_params['input_path']
        self.output_path = data_params['output_path']
        self.batch_size = data_params['batch_size']
        self.rollout_n = rollout_n

        base_prompts = self.get_prompt()
        for i, item in enumerate(base_prompts):
            item['prompt_idx'] = i
        # One independent copy of every prompt per rollout round.
        prompt_list = []
        for rollout_idx in range(self.rollout_n):
            rollout_batch = copy.deepcopy(base_prompts)
            for item in rollout_batch:
                item['rollout_idx'] = rollout_idx
            prompt_list.extend(rollout_batch)
        print(f"Rollout count={self.rollout_n}, total {len(prompt_list)} items after rollout")
        length = len(prompt_list)
        original_prompt_list = prompt_list
        print(f"output_path: {self.output_path}")
        print(f'Total data items: {length}')

        # Skip whatever an earlier run already finished.
        backup_data, _ = self.load_backup_data()
        prompt_list = self.filter_prompt_list(prompt_list, backup_data)
        if len(prompt_list) != length:
            print(f"Read data from backup, remaining items to process: {len(prompt_list)}")
            if len(prompt_list) == 0:
                print("All data processed, no further processing needed")
                self._write_sorted_output(original_prompt_list)
                return 0

        chunks = self._make_chunks(prompt_list)
        results = self.distribute_tasks(self.loaders, chunks, sampling_params)

        # The backups now hold every record; rebuild the output in order.
        self._write_sorted_output(original_prompt_list)
        self._cleanup_backups()
        return results

    

# Usage example
if __name__ == "__main__":
    # Command-line interface: a config file path plus optional overrides.
    parser = argparse.ArgumentParser(description='VLLM Inference with Config File')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Path to configuration file (default: config.yaml)')
    override_options = (
        ('--model', str, 'Override model path from config'),
        ('--model-name', str, 'Model name to override config'),
        ('--model-smn', str, 'Model short name to override config'),
        ('--gpus', int, 'Override total GPU count from config'),
        ('--input_dir', str, 'Override input directory from config'),
    )
    for flag, value_type, help_text in override_options:
        parser.add_argument(flag, type=value_type, default=None, help=help_text)

    args = parser.parse_args()

    # Load the YAML configuration from the path exactly as given.
    config_path = args.config
    print(config_path)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # Apply overrides. --model-name is applied after --model, so it wins when
    # both are supplied.
    model_section = config['model']
    if args.model:
        model_section['name'] = os.path.abspath(args.model)
    if args.model_name is not None:
        model_section['name'] = args.model_name
    if args.model_smn is not None:
        model_section['smn'] = args.model_smn
    if args.gpus:
        config['system']['total_gpus'] = args.gpus
    if args.input_dir:
        config['data']['input_dir'] = args.input_dir

    # Pull the run settings out of the configuration.
    model_name = config['model']['name']
    system_section = config['system']
    gpu_counts = system_section['total_gpus']
    batch_size = system_section['batch_size']
    data_section = config['data']
    input_dir = data_section['input_dir']
    output_base_dir = data_section['output_dir']
    input_files = data_section['input_files']
    rollout_n = config['sample']['rollout_n']

    # Results go under <output_dir>/<short model name>; the short name comes
    # from the config when present, else from the basename of the model name.
    smn = config['model'].get('smn', os.path.basename(model_name))
    output_dir = os.path.join(output_base_dir, smn)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    print(f"Model name: {smn}")
    print(f"Input directory: {input_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Number of GPUs: {gpu_counts}")
    print(f"Batch size: {batch_size}")
    print(f"Number of rollouts: {rollout_n}")

    # Arguments for the manager and the sampler.
    model_params = dict(
        model=model_name,
        max_model_len=config['model']['max_model_len'],
        tensor_parallel_size=config['model']['tensor_parallel_size'],
        total_gpus=gpu_counts,
        gpu_memory_utilization=config['model']['gpu_memory_utilization'],
    )
    sampling_params = config['sampling']

    # Build the manager once, with the custom loader, and reuse it for
    # every input file.
    manager = ModelInferenceManager(
        model_params=model_params,
        model_loader_class=CustomModelLoader
    )

    for input_file in input_files:
        input_path = os.path.join(input_dir, input_file)
        # The output keeps the input's name but always ends in .jsonl.
        output_stem = os.path.splitext(os.path.join(output_dir, input_file))[0]
        output_path = f'{output_stem}.jsonl'

        print(f"\nProcessing file: {input_file}")
        print(f"Input path: {input_path}")
        print(f"Output path: {output_path}")

        results = manager.batch_inference(
            data_params={
                'input_path': input_path,
                'output_path': output_path,
                'batch_size': batch_size
            },
            sampling_params=sampling_params,
            rollout_n=rollout_n
        )

