import os
import sys
import argparse
import torch
import numpy as np
import pandas as pd
from pathlib import Path
import pickle as pkl
from tqdm import tqdm
import time
import random
import json
import multiprocessing
import re
from transformers import AutoTokenizer

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from evaluate.data_utils import load_data, DATA_MAP, get_text_columns, get_max_length_dict
from SFT.train_anollm import get_run_name

# vLLM runtime switches, set before the `from vllm import ...` line below.
# NOTE(review): presumably vLLM reads these variables at import / engine
# start-up, which would be why they precede the import — confirm against the
# installed vLLM version.
os.environ["VLLM_USE_FLASHINFER"] = "0"
os.environ["VLLM_USE_V1"] = "0"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

from vllm import LLM, SamplingParams
    

def get_args():
    """Parse and normalise the command-line options of this script.

    After parsing, two values are rewritten in place:

    * ``args.model``: the ``smol``, ``smol-360`` and ``smol-1.7b`` aliases are
      replaced by their ``HuggingFaceTB/SmolLM-*`` names; the other choices
      are left untouched.
    * ``args.exp_dir``: always a ``Path``. When the option is omitted it
      becomes ``exp/<dataset>/<setting>/split<n_splits>/split<split_idx>``.

    Returns:
        argparse.Namespace: the parsed, normalised arguments.
    """
    parser = argparse.ArgumentParser()

    # dataset / experiment location
    parser.add_argument("--dataset", type=str, default='vifd',
                        choices=[d.lower() for d in DATA_MAP.keys()],
                        help="Name of datasets")
    parser.add_argument("--base_model_dir", type=str, required=True)
    parser.add_argument("--exp_dir", type=str, default=None)
    parser.add_argument("--setting", type=str, default='semi_supervised')
    parser.add_argument("--iteration", type=int, default=0)

    # data splitting
    parser.add_argument("--data_dir", type=str, default='data')
    parser.add_argument("--n_splits", type=int, default=5)
    parser.add_argument("--split_idx", type=int, default=0)
    parser.add_argument("--train_ratio", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=42)

    # preprocessing
    parser.add_argument("--binning", type=str,
                        choices=['quantile', 'equal_width', 'language', 'none', 'standard'],
                        default='standard')
    parser.add_argument("--n_buckets", type=int, default=10)

    # model / training options
    parser.add_argument("--model", type=str,
                        choices=['gpt2', 'distilgpt2', 'smol', 'smol-360', 'smol-1.7b'],
                        default='smol-360')
    parser.add_argument("--lora", action='store_true', default=False)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--no_random_permutation", action='store_true', default=False)

    # generation parameters
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Output directory for generated samples")
    parser.add_argument("--batch_size", type=int, default=32)

    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.9)
    parser.add_argument("--max_model_len", type=int, default=2048)

    # `store_true` combined with default=True can never produce False, so the
    # original flag was a no-op. It is kept so existing command lines still
    # parse, and a `store_false` counterpart writing to the same destination
    # makes the value switchable. The counterpart's name starts with
    # `--disable` so it does not collide with any abbreviation of an
    # existing option.
    parser.add_argument("--use_normal_generation", dest='use_normal_generation',
                        action='store_true', default=True)
    parser.add_argument("--disable_normal_generation", dest='use_normal_generation',
                        action='store_false',
                        help="Set use_normal_generation to False")
    parser.add_argument("--generation_temperature", type=float, default=1.0)
    parser.add_argument("--generation_top_p", type=float, default=0.95)
    parser.add_argument("--n_target_features", type=int, default=2)

    parser.add_argument("--remove_metadata_20news", action='store_true', default=False)

    args = parser.parse_args()

    # Expand the short aliases; names without an alias pass through unchanged.
    model_aliases = {
        'smol': 'HuggingFaceTB/SmolLM-135M',
        'smol-360': 'HuggingFaceTB/SmolLM-360M',
        'smol-1.7b': 'HuggingFaceTB/SmolLM-1.7B',
    }
    args.model = model_aliases.get(args.model, args.model)

    if args.exp_dir is None:
        args.exp_dir = (
            Path('exp')
            / args.dataset
            / args.setting
            / "split{}".format(args.n_splits)
            / "split{}".format(args.split_idx)
        )
    else:
        args.exp_dir = Path(args.exp_dir)

    return args


def row_to_text(row_dict, feature_order):
    """Serialise a row as comma-joined `` <feature> is <value> `` clauses.

    Args:
        row_dict: mapping from feature name to value; every name listed in
            ``feature_order`` must be present.
        feature_order: feature names in the order they should appear.

    Returns:
        str: one clause per feature, each padded with a single space on both
        sides and built from the stripped string form of the value; clauses
        are joined with ``","``. Empty when ``feature_order`` is empty.
    """
    return ",".join(
        " {} is {} ".format(name, str(row_dict[name]).strip())
        for name in feature_order
    )


def prepare_samples_metadata(X_train, y_train, n_target_features, tokenizer=None, max_model_len=2048):
    """Build one generation-metadata record per normal (label 0) training row.

    A row with exactly one column is handled as a text sample when a
    tokenizer is supplied: the column is tokenised, the first 80% of the
    tokens become the context and the remainder the target. Every other row
    is handled as tabular: its column names are shuffled with the global
    ``random`` module and the trailing columns of the shuffled order become
    the targets.

    Args:
        X_train: feature table. Rows are read positionally through ``.iloc``
            and each row must expose ``.index`` and ``.to_dict()``.
        y_train: labels aligned with ``X_train``; entries equal to 0 select
            the rows that are used.
        n_target_features: requested number of target columns per tabular
            row. For rows with two or more columns it is capped at
            ``n_columns - 1`` so that at least one context column remains;
            values <= 0 produce no targets.
        tokenizer: object providing ``encode(text, add_special_tokens=...)``
            and ``decode(ids, skip_special_tokens=...)``; ``None`` disables
            the text branch.
        max_model_len: accepted for interface compatibility; not used here.

    Returns:
        list[dict]: one record per normal row with the keys ``original_idx``,
        ``chosen_row``, ``features_order``, ``context_features``,
        ``target_features`` and ``is_text_dataset``. Text records also carry
        ``context_part``, ``target_part`` and ``target_token_len``.
    """
    print("Preparing samples metadata for sequential generation...")

    normal_indices = np.where(y_train == 0)[0]
    print(f"Using {len(normal_indices)} normal samples to generate synthetic data")

    samples_metadata = []

    for idx in tqdm(normal_indices, desc="Preparing metadata"):
        row = X_train.iloc[idx]
        features = list(row.index)
        n_features = len(features)

        if n_features == 1 and tokenizer is not None:
            feature_name = features[0]
            original_text = str(row[feature_name]).strip()

            text_tokens = tokenizer.encode(original_text, add_special_tokens=False)

            # The leading 80% of the tokens are kept as context; the model is
            # later asked to continue from there.
            context_ratio = 0.8
            context_token_len = int(len(text_tokens) * context_ratio)

            context_tokens = text_tokens[:context_token_len]
            target_tokens = text_tokens[context_token_len:]

            samples_metadata.append({
                'original_idx': idx,
                'chosen_row': row.to_dict(),
                'features_order': features,
                'context_features': [],
                'target_features': [feature_name],
                'is_text_dataset': True,
                'context_part': tokenizer.decode(context_tokens, skip_special_tokens=False).strip(),
                'target_part': tokenizer.decode(target_tokens, skip_special_tokens=False).strip(),
                'target_token_len': len(target_tokens),
            })
        else:
            random.shuffle(features)

            if n_features <= 1:
                # A lone column cannot be split into context and target, so
                # it becomes the target and is generated without context.
                n_target = n_features
            else:
                # Leave at least one context column and never go below zero.
                n_target = max(0, min(n_target_features, n_features - 1))

            # Split at an explicit index: `features[:-n_target]` would flip
            # the split for n_target == 0, because `[:-0]` is empty while
            # `[-0:]` is the whole list.
            split_at = n_features - n_target

            samples_metadata.append({
                'original_idx': idx,
                'chosen_row': row.to_dict(),
                'features_order': features,
                'context_features': features[:split_at],
                'target_features': features[split_at:],
                'is_text_dataset': False,
            })

    print(f"Prepared metadata for {len(samples_metadata)} samples")
    return samples_metadata


def build_prompt_for_target(sample_meta, target_feature, generated_so_far):
    """Compose the prompt that asks the model for one target feature.

    Args:
        sample_meta: record holding ``chosen_row`` (feature -> value) and
            ``context_features`` (names whose original values are shown).
        target_feature: name of the feature to be generated next.
        generated_so_far: mapping of already generated feature -> value for
            this sample; appended after the context features in its own
            iteration order.

    Returns:
        str: the known `` <feature> is <value> `` clauses joined by ``","``,
        followed by ``", <target_feature> is "``; just
        ``" <target_feature> is "`` when nothing is known yet.
    """
    row = sample_meta['chosen_row']

    clauses = [
        f" {name} is {str(row[name]).strip()} "
        for name in sample_meta['context_features']
    ]
    clauses.extend(
        f" {name} is {str(value).strip()} "
        for name, value in generated_so_far.items()
    )

    # Guard clause: with nothing known the prompt is only the open slot.
    if not clauses:
        return f" {target_feature} is "

    known = ",".join(clauses)
    return f"{known}, {target_feature} is "


def _keep_last_tokens(token_ids, limit):
    """Return the trailing ``limit`` token ids, or ``[]`` when ``limit`` <= 0."""
    # A bare `token_ids[-limit:]` returns the whole list for limit == 0 and
    # drops a prefix for negative limits, so the bound is checked explicitly.
    return token_ids[-limit:] if limit > 0 else []


def _generate_text_continuations(llm, text_entries, tokenizer,
                                 generation_temperature, generation_top_p,
                                 max_model_len):
    """Generate one continuation for every text sample in a single batch.

    ``text_entries`` is a list of ``(metadata index, metadata)`` pairs.
    Returns ``{metadata index: {feature name: whitespace-normalised text}}``.
    """
    print(f"\nProcessing {len(text_entries)} text samples...")

    target_token_lens = [meta.get('target_token_len', 100) for _, meta in text_entries]
    avg_target_len = int(np.mean(target_token_lens))
    # Twice the average target length, capped at 300. The lower bound of 1
    # matters when the targets are (almost) all empty: the formula would give
    # 0, which vLLM's SamplingParams rejects.
    max_tokens = max(1, min(int(avg_target_len * 2.0), 300))
    max_prompt_tokens = max_model_len - max_tokens - 50

    sampling_params = SamplingParams(
        temperature=generation_temperature,
        top_p=generation_top_p,
        repetition_penalty=1.0,
        max_tokens=max_tokens,
        stop=["<|endoftext|>"],
        frequency_penalty=0.0
    )

    prompts = []
    truncation_count = 0
    for _, meta in text_entries:
        feature_name = meta['target_features'][0]
        context_part = meta['context_part']

        prefix = f" {feature_name} is "
        prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
        max_context_tokens = max_prompt_tokens - len(prefix_tokens)

        context_tokens = tokenizer.encode(context_part, add_special_tokens=False)
        if len(context_tokens) > max_context_tokens:
            # Truncate from the left so the text closest to the continuation
            # point is the part that survives.
            kept_tokens = _keep_last_tokens(context_tokens, max_context_tokens)
            context_part = tokenizer.decode(kept_tokens, skip_special_tokens=False)
            truncation_count += 1

        prompts.append(f"{prefix}{context_part} ")

    if truncation_count > 0:
        print(f"  Truncated {truncation_count}/{len(text_entries)} text prompts (max_prompt_tokens={max_prompt_tokens})")

    outputs = llm.generate(prompts, sampling_params)

    # Relies on the outputs being returned in prompt order.
    results = {}
    for output, (meta_idx, meta) in zip(outputs, text_entries):
        generated_text = output.outputs[0].text.replace("</s>", "").strip()
        cleaned = re.sub(r'\s+', ' ', generated_text).strip()
        results[meta_idx] = {meta['target_features'][0]: cleaned}

    print(f"Text samples completed. Generated {len(text_entries)} continuations.")
    return results


def _generate_tabular_values(llm, tabular_entries, tokenizer,
                             generation_temperature, generation_top_p,
                             max_model_len):
    """Generate the target features of every tabular sample, one per round.

    ``tabular_entries`` is a list of ``(metadata index, metadata)`` pairs.
    In round ``k`` the ``k``-th target feature of each sample is generated,
    with the values produced in earlier rounds appended to its prompt.
    Returns ``{metadata index: {target feature: generated value}}``.
    """
    print(f"\nProcessing {len(tabular_entries)} tabular samples...")

    max_targets = max(len(meta['target_features']) for _, meta in tabular_entries)
    print(f"Max target features: {max_targets}")

    sampling_params = SamplingParams(
        temperature=generation_temperature,
        top_p=generation_top_p,
        repetition_penalty=1.0,
        max_tokens=50,
        # "," ends a clause in the serialised row format, so generation stops
        # once a single value has been produced.
        stop=[",", "\n", "<|endoftext|>"],
        frequency_penalty=0.0
    )

    # Loop-invariant prompt budget.
    max_prompt_tokens = max_model_len - 100
    generated_so_far = [{} for _ in tabular_entries]

    for round_idx in range(max_targets):
        print(f"\n  Round {round_idx + 1}/{max_targets}...")

        prompts = []
        positions = []
        round_features = []

        for position, (_, meta) in enumerate(tabular_entries):
            target_features = meta['target_features']

            # Samples with fewer targets than the current round are finished.
            if round_idx >= len(target_features):
                continue

            target_feature = target_features[round_idx]
            prompt = build_prompt_for_target(meta, target_feature, generated_so_far[position])

            prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
            if len(prompt_tokens) > max_prompt_tokens:
                # Left-truncate so the trailing "<target> is " slot is kept.
                prompt = tokenizer.decode(
                    _keep_last_tokens(prompt_tokens, max_prompt_tokens),
                    skip_special_tokens=False,
                )

            prompts.append(prompt)
            positions.append(position)
            round_features.append(target_feature)

        if not prompts:
            continue

        start_time = time.time()
        outputs = llm.generate(prompts, sampling_params)
        elapsed = time.time() - start_time
        # Guard the rate against a zero elapsed time (coarse clocks).
        rate = len(prompts) / elapsed if elapsed > 0 else float('inf')
        print(f"    Generated {len(prompts)} values in {elapsed:.2f}s ({rate:.1f} samples/sec)")

        for output, position, target_feature in zip(outputs, positions, round_features):
            generated_text = output.outputs[0].text.replace("</s>", "").strip()

            cleaned = generated_text.strip().strip('"\'()[]{}')
            cleaned = re.sub(r'\s+', ' ', cleaned).strip()

            # If the model echoed a "<name> is <value>" clause, keep the value.
            if " is " in cleaned:
                cleaned = cleaned.split(" is ", 1)[-1].strip()

            generated_so_far[position][target_feature] = cleaned

    results = {
        meta_idx: dict(values)
        for (meta_idx, _), values in zip(tabular_entries, generated_so_far)
    }

    print(f"\nTabular samples completed. Generated values for {len(tabular_entries)} samples.")
    return results


def generate_sequential_with_vllm(llm, samples_metadata, tokenizer,
                                   generation_temperature, generation_top_p,
                                   max_model_len):
    """Generate the target values of every sample with a vLLM engine.

    Text samples (``is_text_dataset`` true) get a single continuation each,
    generated in one batch. Tabular samples get their target features
    generated one feature per round, feeding earlier results into later
    prompts.

    Args:
        llm: engine exposing ``generate(prompts, sampling_params)``.
        samples_metadata: list of metadata records; each needs
            ``target_features`` and, depending on its kind, ``context_part``
            or ``chosen_row`` / ``context_features``.
        tokenizer: object providing ``encode`` / ``decode``; used to measure
            and left-truncate prompts.
        generation_temperature: sampling temperature.
        generation_top_p: nucleus-sampling threshold.
        max_model_len: context length used to derive the prompt budgets.

    Returns:
        dict: position of the record in ``samples_metadata`` mapped to a
        ``{feature name: generated string}`` dict.
    """
    print("\nStarting SEQUENTIAL generation with vLLM...")

    # Record each sample's position while partitioning. Looking it up later
    # with `samples_metadata.index(meta)` costs O(n) per sample and depends
    # on dict equality.
    text_entries = []
    tabular_entries = []
    for meta_idx, meta in enumerate(samples_metadata):
        if meta.get('is_text_dataset', False):
            text_entries.append((meta_idx, meta))
        else:
            tabular_entries.append((meta_idx, meta))

    generated_values_dict = {}

    if text_entries:
        generated_values_dict.update(_generate_text_continuations(
            llm, text_entries, tokenizer,
            generation_temperature, generation_top_p, max_model_len,
        ))

    if tabular_entries:
        generated_values_dict.update(_generate_tabular_values(
            llm, tabular_entries, tokenizer,
            generation_temperature, generation_top_p, max_model_len,
        ))

    return generated_values_dict


def construct_spin_pairs(samples_metadata, generated_values_dict):
    """Pair every original row ("chosen") with its generated variant ("rejected").

    Args:
        samples_metadata: list of metadata records; the position of a record
            in this list is the key used to look up its generated values.
        generated_values_dict: position -> ``{feature name: generated value}``.
            Missing positions and empty values leave the original value in
            place.

    Returns:
        list[dict]: one entry per record holding the serialised chosen and
        rejected rows, both row dicts, the context-only prompt (empty when
        the record has no context features) and bookkeeping fields.
    """
    print("\nConstructing SPIN pairs...")

    pairs = []

    for position, sample in enumerate(samples_metadata):
        original = sample['chosen_row']
        targets = sample['target_features']
        order = sample['features_order']
        produced = generated_values_dict.get(position, {})

        # Start from a copy of the real row and overwrite only the targets.
        altered = original.copy()

        if sample.get('is_text_dataset', False):
            name = targets[0]
            prefix = sample['context_part']
            continuation = produced.get(name, "")

            # Without a continuation the rejected text falls back to the
            # original text.
            full_text = (prefix + " " + continuation) if continuation else str(original[name])
            altered[name] = full_text.strip()
        else:
            for name in targets:
                value = produced.get(name)
                if value:
                    altered[name] = value

        context = sample.get('context_features', [])
        if context:
            prompt = row_to_text({f: original[f] for f in context}, context)
        else:
            prompt = ""

        pairs.append({
            'chosen_text': row_to_text(original, order),
            'rejected_text': row_to_text(altered, order),
            'chosen_row': original,
            'rejected_row': altered,
            'prompt': prompt,
            'original_idx': sample['original_idx'],
            'features_order': order,
            'context_features': context,
            'target_features': targets,
            'generated_values': produced,
            'generation_method': 'sequential_vllm'
        })

    print(f"Constructed {len(pairs)} SPIN pairs")
    return pairs


def main():
    """Run one SPIN generation iteration end to end.

    Steps: resolve the model directory, load the data, check that the model
    directory looks like a HuggingFace checkpoint, prepare per-sample
    metadata, generate target values with vLLM, build the chosen/rejected
    pairs and write them (pickle) together with a JSON summary to
    ``<output_dir>/iter<iteration>/``.

    Raises:
        ValueError: if the model directory, ``config.json`` or
            ``tokenizer_config.json`` is missing.
    """
    args = get_args()

    # A `config.json` directly under --base_model_dir means the path already
    # is the model; otherwise the model is looked up under models/<run name>.
    base_path = Path(args.base_model_dir)
    if (base_path / 'config.json').exists():
        model_dir = base_path
        print(f"Using direct model path: {model_dir}")
    else:
        model_dir = base_path / 'models' / get_run_name(args)
        print(f"Using derived model path: {model_dir}")

    print("Loading data...")
    X_train, _, y_train, _ = load_data(args)
    print(f"Training samples: {len(X_train)}")
    n_normal = np.sum(y_train == 0)
    n_anomaly = np.sum(y_train == 1)
    print(f"Normal samples: {n_normal}, Anomaly samples: {n_anomaly}")

    print(f"Generating synthetic data from {n_normal} normal samples")

    # Only reported; the constraints are not applied in this function.
    max_length_dict = get_max_length_dict(args.dataset)
    print(f"Max length constraints: {max_length_dict}")

    if not model_dir.exists():
        raise ValueError(f"Model directory not found at {model_dir}. "
                        f"Make sure train_anollm.py saved the model in HuggingFace format.")

    for file in ('config.json', 'tokenizer_config.json'):
        if not (model_dir / file).exists():
            raise ValueError(f"Required file {file} not found in {model_dir}. "
                           f"Model must be in HuggingFace format for vLLM.")

    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
    tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): args.seed is not applied to `random` in this function;
    # whether load_data() seeds it is not visible here, so the feature
    # shuffle inside prepare_samples_metadata may differ between runs.
    samples_metadata = prepare_samples_metadata(
        X_train,
        y_train,
        n_target_features=args.n_target_features,
        tokenizer=tokenizer,
        max_model_len=args.max_model_len,
    )

    print(f"\nInitializing vLLM with model: {model_dir}")
    # NOTE(review): tensor_parallel_size is fixed at 1; args.world_size is
    # not consulted here.
    engine_options = {
        'model': str(model_dir),
        'tensor_parallel_size': 1,
        'dtype': "bfloat16",
        'max_model_len': args.max_model_len,
        'gpu_memory_utilization': args.gpu_memory_utilization,
        'trust_remote_code': True,
        'enforce_eager': True,
        'disable_custom_all_reduce': True,
    }
    llm = LLM(**engine_options)

    generated_values_dict = generate_sequential_with_vllm(
        llm=llm,
        samples_metadata=samples_metadata,
        tokenizer=tokenizer,
        generation_temperature=args.generation_temperature,
        generation_top_p=args.generation_top_p,
        max_model_len=args.max_model_len,
    )

    spin_pairs = construct_spin_pairs(samples_metadata, generated_values_dict)

    iteration_dir = Path(args.output_dir) / f"iter{args.iteration}"
    iteration_dir.mkdir(parents=True, exist_ok=True)

    output_path = iteration_dir / f'spin_dataset_iter{args.iteration}.pkl'
    with output_path.open('wb') as pairs_file:
        pkl.dump(spin_pairs, pairs_file)

    print(f"\nDataset saved to {output_path}")
    print(f"SPIN iteration {args.iteration} generation completed!")

    # Summary of the run, stored next to the dataset.
    info = {
        'iteration': args.iteration,
        'n_pairs': len(spin_pairs),
        'n_original_samples': len(X_train),
        'generation_method': 'sequential_vllm',
        'model_path': str(model_dir),
        'dataset': args.dataset,
        'setting': args.setting,
        'split_idx': args.split_idx,
        'use_normal_generation': args.use_normal_generation,
        'temperature': args.generation_temperature,
        'top_p': args.generation_top_p,
        'n_target_features': args.n_target_features,
    }

    info_path = iteration_dir / 'generation_info.json'
    with info_path.open('w') as info_file:
        json.dump(info, info_file, indent=2)

    print("\nGeneration info:")
    for key, value in info.items():
        print(f"  {key}: {value}")


if __name__ == '__main__':
    # Force the 'spawn' start method (overriding any method already set)
    # before main() creates the vLLM engine.
    # NOTE(review): presumably so that worker processes do not inherit
    # CUDA state through fork — confirm against the vLLM setup in use.
    multiprocessing.set_start_method('spawn', force=True)
    main()
