import os
import json
import glob
import noise_embed.torch_utils as torch_utils
import random
from datasets import Dataset
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, GenerationConfig
from dataclasses import dataclass
from typing import List, Dict, Any, Generator
from tqdm import tqdm
from datasets import Features, Value, Sequence
import argparse
import time
import math
import ray

from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, get_masked_input_and_mask
from vllm.distributed import tensor_model_parallel_all_reduce

# Token ids perturbed when patch_vocab_embedding() is called without an explicit
# `noisy_ids`. Presumably Llama-3 specials (128000/128001/128008/128009) and Qwen
# specials (151643/151645) — confirm against the tokenizers in use.
_DEFAULT_NOISY_TOKEN_IDS = (128000, 128001, 128008, 128009, 151645, 151643)


def patch_vocab_embedding(noise_std, noisy_ids=None):
    """Monkey-patch vLLM's ``VocabParallelEmbedding.forward`` to add embedding noise.

    After patching, each embedding lookup adds ``randn_like(output) * noise_std``
    to the output rows whose input token id is in ``noisy_ids``; all other rows
    are returned unchanged. Apart from the noise, the replacement performs the
    same steps as the code it replaces: tensor-parallel input masking, the
    embedding lookup, zero-filling of rows this shard does not own, and an
    all-reduce.

    Args:
        noise_std: Scale applied to the standard-normal noise.
        noisy_ids: Iterable of token ids to perturb. ``None`` (the default)
            uses ``_DEFAULT_NOISY_TOKEN_IDS``; an empty iterable disables noise.

    The pre-patch ``forward`` is stored once as
    ``VocabParallelEmbedding.original_forward``, so calling this function
    repeatedly never overwrites the genuine original.

    NOTE(review): this patches the class only in the current process; with
    tensor_parallel_size > 1 other ranks may live in separate processes —
    confirm the patch is active on every rank.
    """
    if not hasattr(VocabParallelEmbedding, 'original_forward'):
        VocabParallelEmbedding.original_forward = VocabParallelEmbedding.forward

    id_list = sorted(set(_DEFAULT_NOISY_TOKEN_IDS if noisy_ids is None else noisy_ids))
    # One id tensor per device, built lazily, so forward() does not rebuild the
    # list and copy it to the device on every call.
    ids_by_device = {}

    def _ids_on(device):
        ids = ids_by_device.get(device)
        if ids is None:
            ids = torch_utils.tensor(id_list, device=device)
            ids_by_device[device] = ids
        return ids

    def noisy_forward(self, input_):
        if self.tp_size > 1:
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_

        output_parallel = self.quant_method.embedding(self, masked_input.long())

        if id_list:
            # Match against the *unmasked* ids so the real token ids are targeted.
            noise_mask = torch_utils.isin(input_, _ids_on(input_.device))
            if noise_mask.any():
                noise = torch_utils.randn_like(output_parallel) * noise_std
                output_parallel = torch_utils.where(
                    noise_mask.unsqueeze(-1), output_parallel + noise, output_parallel)

        if self.tp_size > 1:
            # Rows not owned by this shard are zeroed *after* noise is added, so
            # only the owning shard's noisy row contributes to the all-reduce.
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)

        output = tensor_model_parallel_all_reduce(output_parallel)
        return output

    VocabParallelEmbedding.forward = noisy_forward
    print("VocabParallelEmbedding.forward has been successfully patched with noisy_forward.")

@dataclass
class Processor:
    """Builds vLLM prompt rows from problem records and decodes the completions.

    Attributes (dataclass fields):
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        system: Optional system prompt; left out of the chat when empty.
        n: Number of copies of each prompt that ``process_samples`` emits.
    """
    tokenizer: AutoTokenizer
    system: str = ""
    n: int = 256

    def encode_sample(self, problem: str) -> List[int]:
        """Render *problem* as a single-turn chat and return the prompt token ids."""
        user_turn = {"role": "user", "content": problem}
        if self.system:
            messages = [{"role": "system", "content": self.system}, user_turn]
        else:
            messages = [user_turn]
        return self.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)

    def process_samples(self, samples: List[Dict]) -> Generator[Dict[str, Any], None, None]:
        """Yield ``self.n`` rows per record, each with string metadata and the prompt ids.

        The prompt is tokenized once per record (``record["problem"]`` must
        exist); every yielded row gets its own freshly built metadata dict.
        """
        for record in samples:
            prompt_ids = self.encode_sample(record["problem"])
            for _copy in range(self.n):
                meta = {key: str(record.get(key, "")) for key in ("answer", "source", "problem")}
                yield {"sample": meta, "model_inputs": {"prompt_token_ids": prompt_ids}}

    def collect_results(self, results: List[Any]) -> Generator[Dict[str, List[List[int]]], None, None]:
        """Yield ``{"predicts": [...]}`` with the token ids of every output of each result."""
        for request_output in results:
            yield {"predicts": [completion.token_ids for completion in request_output.outputs]}

    def decode_sample(self, outputs: Dict[str, List[List[int]]]) -> Dict[str, List[str]]:
        """Detokenize each sequence in ``outputs["predicts"]``, skipping special tokens."""
        return {
            "predicts": [
                self.tokenizer.decode(token_ids, skip_special_tokens=True)
                for token_ids in outputs["predicts"]
            ]
        }

# Stop tokens appended to each model's own EOS id(s). Presumably Llama-3
# <|eot_id|> (128009) and Qwen <|endoftext|>/<|im_start|>/<|im_end|>
# (151643/151644/151645) — confirm when adding new model families.
_EXTRA_STOP_TOKEN_IDS = (128009, 151643, 151644, 151645)


def _resolve_stop_token_ids(model_id, tokenizer):
    """Return a de-duplicated list of stop token ids for ``model_id``.

    Starts from ``GenerationConfig.eos_token_id`` (an int or a list), falling
    back to ``tokenizer.eos_token_id`` when the checkpoint has no generation
    config. Drops ``None`` entries and appends ``_EXTRA_STOP_TOKEN_IDS``. A new
    list is always built, so the config's own ``eos_token_id`` list is never mutated.
    """
    try:
        eos = GenerationConfig.from_pretrained(model_id).eos_token_id
    except OSError:
        # transformers raises OSError when generation_config.json is missing.
        eos = tokenizer.eos_token_id
    base_ids = list(eos) if isinstance(eos, (list, tuple)) else [eos]
    stop_ids = []
    for token_id in base_ids + list(_EXTRA_STOP_TOKEN_IDS):
        if token_id is not None and token_id not in stop_ids:
            stop_ids.append(token_id)
    return stop_ids


def _group_predictions(inputs, outputs, n):
    """Fold every ``n`` consecutive output rows back into one record per problem.

    Row ``k * n`` of ``inputs`` supplies the sample metadata. Output rows
    ``k * n`` .. ``k * n + n - 1`` hold that sample's generations; each row's
    ``predicts`` list is flattened into the record's ``result`` list. Assumes
    ``outputs`` is in the same order as ``inputs``.
    """
    grouped = []
    for start in range(0, len(outputs), n):
        sample = inputs[start]["sample"]
        rows = outputs[start:start + n]["predicts"]
        sample["result"] = [pred for row in rows for pred in row]
        grouped.append(sample)
    return grouped


@ray.remote(num_gpus=2)
def process_chunk_ray(task, args_dict):
    """Generate ``args_dict['n']`` completions for every problem in one chunk.

    Runs as a Ray task reserving 2 GPUs. It optionally applies the embedding-noise
    patch, loads the tokenizer and a vLLM engine with tensor_parallel_size=2,
    and samples one completion per duplicated prompt row. It then decodes the
    completions and regroups them per problem.

    Args:
        task: ``{'chunk': list of records with 'problem' (optionally 'answer',
            'source'), 'model_id': str, 'chunk_id': ...}``.
        args_dict: ``vars()`` of the CLI args; reads ``use_noise``,
            ``noise_std``, ``n`` and ``temperature``.

    Returns:
        On success ``{'success': True, 'model_id', 'chunk_id', 'data'}``.
        ``data`` is a list of ``{'answer', 'source', 'problem', 'result'}``
        records. On any exception it returns ``{'success': False, 'model_id',
        'chunk_id', 'error'}`` with the formatted traceback; nothing is re-raised.
    """
    import gc
    import traceback

    import noise_embed.torch_utils as torch_utils
    from datasets import Dataset, Features, Value, Sequence

    vllm_engine = None
    tokenizer = None
    processor = None
    try:
        if args_dict.get('use_noise'):
            patch_vocab_embedding(args_dict['noise_std'])
            print(f"Ray worker for chunk {task.get('chunk_id')}: Applied embedding noise patch.")

        chunk = task['chunk']
        model_id = task['model_id']
        chunk_id = task['chunk_id']
        n = args_dict['n']

        torch_utils.cuda.empty_cache()
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        processor = Processor(tokenizer, system="", n=n)
        stop_token_ids = _resolve_stop_token_ids(model_id, tokenizer)

        vllm_engine = LLM(
            model=model_id,
            enforce_eager=True,
            tensor_parallel_size=2,
            max_model_len=20000
        )
        features = Features({
            'sample': {
                'answer': Value('string'),
                'source': Value('string'),
                'problem': Value('string')
            },
            'model_inputs': {
                'prompt_token_ids': Sequence(Value('int64'))
            }
        })
        inputs = Dataset.from_generator(
            processor.process_samples,
            gen_kwargs={"samples": chunk},
            features=features
        )
        # Every prompt is already repeated n times in `inputs`, so ask vLLM
        # for a single completion per row.
        sampling_params = SamplingParams(
            n=1,
            temperature=args_dict['temperature'],
            max_tokens=18000,
            detokenize=False,
            stop_token_ids=stop_token_ids
        )
        results = vllm_engine.generate(inputs["model_inputs"], sampling_params)
        outputs = Dataset.from_generator(
            processor.collect_results,
            gen_kwargs={"results": results}
        )
        outputs = outputs.map(
            processor.decode_sample,
            batched=False,
            num_proc=8,
            desc=f"Decode outputs (Chunk {chunk_id})"
        )

        result = {
            'success': True,
            'model_id': model_id,
            'chunk_id': chunk_id,
            'data': _group_predictions(inputs, outputs, n)
        }
    except Exception as e:
        # Task boundary: report the failure to the driver instead of raising.
        result = {
            'success': False,
            'model_id': task.get('model_id', ''),
            'chunk_id': task.get('chunk_id', ''),
            'error': f"{str(e)}\n{traceback.format_exc()}"
        }
    finally:
        # Drop references and collect garbage first, so that empty_cache()
        # can actually return the freed blocks.
        vllm_engine = tokenizer = processor = None
        gc.collect()
        torch_utils.cuda.empty_cache()
    return result

def save_results(results: Dict[str, List[Dict]], output_dir: str, params: dict):
    """Write one JSON file per model containing its collected results.

    Each non-empty ``results[model_id]`` is written to
    ``<output_dir>/<tag>.json`` as ``{"meta": params, "model_id": ..., "data": ...}``.
    ``<tag>`` encodes temperature, n, noise settings and the model directory
    name. Data goes to a ``.tmp`` sibling first and is moved into place with
    ``os.replace``, so an interrupted write does not leave a truncated result
    file. Errors are printed, not raised, and the temp file is removed.

    Args:
        results: Mapping of model id to its list of result records.
        output_dir: Target directory; ``""`` means the current directory.
        params: Run metadata; must contain ``temperature``, ``n``,
            ``use_noise`` and ``noise_std``.
    """
    # os.makedirs("") raises FileNotFoundError, and "" is the CLI's default
    # --save-path, so only create a directory when one was actually given.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    for model_id, data in results.items():
        if not data:
            continue
        tag = f"T{params['temperature']}_N{params['n']}_NOISE{params['use_noise']}_STD{params['noise_std']}_{os.path.basename(model_id).replace('/', '-')}"
        output_path = os.path.join(output_dir, f"{tag}.json")
        temp_path = output_path + '.tmp'
        try:
            with open(temp_path, 'w') as f:
                json.dump({
                    "meta": params,
                    "model_id": model_id,
                    "data": data
                }, f, indent=4)
            os.replace(temp_path, output_path)
        except Exception as e:
            print(f"Error saving results for {model_id}: {str(e)}")
            if os.path.exists(temp_path):
                os.remove(temp_path)

def load_existing_results(output_dir: str, model_id: str, params: dict):
    """Load previously saved generations for ``model_id`` from ``output_dir``.

    Prefers the merged file ``<tag>.json`` (the same name ``save_results``
    writes). If that file is missing or cannot be parsed, it falls back to
    concatenating the ``data`` of every ``<tag>_chunk*.json`` file, in sorted
    filename order. ``<tag>`` encodes temperature, n, noise settings and the
    model directory name.

    Args:
        output_dir: Directory holding result files (``""`` = current directory).
        model_id: Model path; only its final path component enters the tag.
        params: Must contain ``temperature``, ``n``, ``use_noise``, ``noise_std``.

    Returns:
        ``(existing_results, processed_problems)``: the loaded records and the
        set of their ``problem`` values.
    """
    tag_base = f"T{params['temperature']}_N{params['n']}_NOISE{params['use_noise']}_STD{params['noise_std']}_{os.path.basename(model_id).replace('/', '-')}"

    merged_path = os.path.join(output_dir, f"{tag_base}.json")
    if os.path.exists(merged_path):
        try:
            with open(merged_path, 'r') as f:
                existing_data = json.load(f)
            existing_results = existing_data.get('data', [])
            print(f"Loaded {len(existing_results)} existing results from merged file {merged_path}")
            processed_problems = {item['problem'] for item in existing_results}
            print(f"Found {len(processed_problems)} unique problems in existing results")
            return existing_results, processed_problems
        except Exception as e:
            print(f"Error loading existing merged results: {str(e)}")

    existing_results = []
    # Escape the literal parts so characters such as '[' in a model or
    # directory name are not interpreted as glob syntax; only the trailing
    # chunk index is a wildcard. Sort for a deterministic load order.
    chunk_pattern = os.path.join(glob.escape(output_dir), f"{glob.escape(tag_base)}_chunk*.json")
    chunk_files = sorted(glob.glob(chunk_pattern))

    if chunk_files:
        print(f"Found {len(chunk_files)} chunk files matching pattern {tag_base}_chunk*.json")
        for chunk_file in chunk_files:
            try:
                with open(chunk_file, 'r') as f:
                    chunk_data = json.load(f)
                chunk_results = chunk_data.get('data', [])
                existing_results.extend(chunk_results)
                print(f"Loaded {len(chunk_results)} results from {os.path.basename(chunk_file)}")
            except Exception as e:
                print(f"Error loading chunk file {chunk_file}: {str(e)}")

    processed_problems = {item['problem'] for item in existing_results}
    print(f"Loaded total of {len(existing_results)} results from {len(chunk_files)} chunk files")
    print(f"Found {len(processed_problems)} unique problems in existing chunk results")

    return existing_results, processed_problems

def main():
    """CLI entry point: sample ``--n`` completions per problem with vLLM on Ray.

    Steps:
    1. Load ``--data``, a JSON list of records that each have a ``problem`` key.
    2. Unless ``--force`` is given, skip problems that already appear in
       earlier result files under ``--save-path``.
    3. Split the remaining records into at most ``--num-processes`` chunks and
       run each one as a ``process_chunk_ray`` task.
    4. Rewrite the merged result file after every finished task.

    Chunks that fail are listed at the end.
    """
    from ray.exceptions import RayError

    parser = argparse.ArgumentParser()
    parser.add_argument('--use-noise', action='store_true', help="Enable noise embedding")
    parser.add_argument('--noise-std', type=float, default=0.01, help="Noise std if using noise embedding")
    parser.add_argument('--model-id', type=str, required=True, help="Model checkpoint path")
    parser.add_argument('--n', type=int, default=256, help="Sample count per input")
    parser.add_argument('--temperature', type=float, default=0.3, help="Sampling temperature")
    parser.add_argument('--save-path', type=str, default="", help="Output directory")
    parser.add_argument('--num-processes', type=int, default=16, help="Number of Ray tasks (each using 2 GPUs)")
    parser.add_argument('--data', type=str, default="", help="Input data file")
    parser.add_argument('--force', action='store_true', help="Force reprocessing of all data")
    parser.add_argument('--ray-address', type=str, default=None, help="Ray cluster address, e.g. 'auto' or 'ray://...'. If None, use local.")

    args = parser.parse_args()

    # Fail fast with a usage message, rather than with a FileNotFoundError,
    # a division by zero, or a zero-step range deep inside the run.
    if not args.data:
        parser.error("--data is required")
    if args.num_processes < 1:
        parser.error("--num-processes must be at least 1")
    if args.n < 1:
        parser.error("--n must be at least 1")

    # Harmless in the driver; the Ray workers apply their own patch.
    if args.use_noise:
        patch_vocab_embedding(args.noise_std)

    ray_kwargs = {"address": args.ray_address} if args.ray_address else {}
    ray.init(**ray_kwargs, ignore_reinit_error=True)

    with open(args.data, "r") as f:
        raw_data = json.load(f)

    all_problems = {item["problem"] for item in raw_data}
    print(f"Found {len(all_problems)} unique problems in raw data")

    model_id = args.model_id
    num_processes = args.num_processes
    total_gpus = num_processes * 2

    print(f"Using {num_processes} remote Ray tasks with 2 GPUs each (total {total_gpus} GPUs required)")

    meta_params = {
        "use_noise": args.use_noise,
        "noise_std": args.noise_std,
        "model_id": args.model_id,
        "n": args.n,
        "temperature": args.temperature,
        "save_path": args.save_path,
        "data_file": args.data,
        "num_processes": args.num_processes,
        "gpus_per_process": 2,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    }

    if args.force:
        print("Force reprocessing enabled, ignoring existing results")
        # Discard the old records as well as the "done" set. Otherwise the
        # rerun's output is appended to them and every problem is saved twice.
        existing_results, processed_problems = [], set()
    else:
        existing_results, processed_problems = load_existing_results(args.save_path, model_id, meta_params)

    remaining_problems = all_problems - processed_problems
    remaining_data = [item for item in raw_data if item["problem"] in remaining_problems]
    random.shuffle(remaining_data)

    print(f"Total unique problems in dataset: {len(all_problems)}")
    print(f"Already processed unique problems: {len(processed_problems)}")
    print(f"Remaining unique problems to process: {len(remaining_problems)}")
    print(f"Remaining data points to process: {len(remaining_data)}")

    if not remaining_data:
        print("All data has already been processed. Exiting.")
        return

    results = {model_id: existing_results}
    failed_tasks = []

    # ceil() guarantees at most num_processes non-empty contiguous chunks.
    samples_per_process = math.ceil(len(remaining_data) / num_processes)
    tasks = [
        {'chunk': remaining_data[start:start + samples_per_process], 'model_id': model_id, 'chunk_id': chunk_id}
        for chunk_id, start in enumerate(range(0, len(remaining_data), samples_per_process))
    ]

    args_dict = vars(args)
    future_to_chunk = {process_chunk_ray.remote(task, args_dict): task['chunk_id'] for task in tasks}
    pending = list(future_to_chunk)

    with tqdm(total=len(tasks)) as pbar:
        while pending:
            ready, pending = ray.wait(pending, num_returns=1, timeout=1500)
            if not ready:
                print("No results received for 25 minutes, checking worker status...")
                continue
            ref = ready[0]
            try:
                result = ray.get(ref)
            except RayError as e:
                # A crashed or killed worker raises here instead of returning a
                # {'success': False} dict; record it and keep collecting the rest.
                result = {
                    'success': False,
                    'chunk_id': future_to_chunk[ref],
                    'error': f"Ray task failed: {e}"
                }
            pbar.update(1)
            if result['success']:
                results[model_id].extend(result['data'])
            else:
                failed_tasks.append({
                    'model_id': model_id,
                    'chunk_id': result['chunk_id'],
                    'error': result['error']
                })
            # Persist after every finished task so completed work survives a driver crash.
            save_results(results, args.save_path, meta_params)

    if failed_tasks:
        print("\nFailed tasks:")
        for task in failed_tasks:
            print(f"Chunk: {task['chunk_id']}, Error: {task['error']}")

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()