use anyhow::{Context, Result};
use clap::Parser;
use crossbeam_channel::{bounded, Sender, Receiver};
use serde::{Deserialize, Serialize};
use std::io::{self, Read, Write};
use std::sync::Arc;
use std::thread;
use tokenizers::Tokenizer;

// Command-line configuration, parsed once in `main` via `Args::parse()`.
//
// NOTE: the `///` comments on the fields below are picked up by clap's derive
// macro as the `--help` text, so they are user-facing output; maintainer notes
// are kept in plain `//` comments to avoid changing that output.
#[derive(Parser, Debug)]
#[command(author, version, about = "High-performance tokenizer worker for PR synthesis pipeline", long_about = None)]
struct Args {
    // Passed to `load_tokenizer`: used directly if it names an existing path,
    // otherwise resolved inside the local HuggingFace cache directory.
    /// Tokenizer model name or path to tokenizer.json
    #[arg(short = 'm', long, default_value = "Qwen/Qwen2.5-Coder-32B-Instruct")]
    model: String,

    // `main` replaces 0 with `num_cpus::get()` and sizes the global rayon pool
    // with the result; `worker_thread` reads that size back to decide how many
    // worker threads to spawn.
    /// Number of worker threads (0 = num_cpus)
    #[arg(short = 'w', long, default_value = "128")]
    workers: usize,

    // Controls how many `Tokenizer` instances `worker_thread` creates and which
    // instance each worker uses.
    /// Number of threads per tokenizer instance (0 = all share one, 1 = one per thread)
    #[arg(short = 't', long, default_value = "16")]
    threads_per_tokenizer: usize,

    // Capacity of the bounded job channel between the reader and the workers.
    /// Request channel buffer size (like Go channel buffer)
    #[arg(long, default_value = "1000")]
    request_buffer: usize,

    // Also determines the result channel capacity, which `main` sets to twice
    // this value.
    /// Response batch size (collect this many results before writing)
    #[arg(long, default_value = "100")]
    response_batch_size: usize,

    // Upper bound enforced by `read_message` on the length prefix of each
    // incoming message.
    /// Maximum message size in bytes (default: 100MB)
    #[arg(long, default_value = "104857600")]
    max_message_size: usize,
}

/// IPC Protocol: Length-prefixed MessagePack binary protocol
///
/// Request format:
/// [4 bytes: message length (u32 big-endian)][MessagePack encoded TokenizeRequest]
///
/// Response format:
/// [4 bytes: message length (u32 big-endian)][MessagePack encoded TokenizeResponse]
///
/// Each request carries a list of PRs; their results may be regrouped freely into response batches.

///
/// One incoming request, decoded from a single length-prefixed frame on stdin.
// NOTE(review): this struct is part of the wire format; rmp_serde's default
// struct encoding may be positional, so confirm with the peer implementation
// before renaming or reordering fields.
#[derive(Debug, Serialize, Deserialize)]
struct TokenizeRequest {
    /// Command name sent by the client; not inspected by the code in this file.
    command: String,
    /// PRs to tokenize; the reader turns each one into a separate `Job`.
    prs: Vec<PRText>,
    /// Token limit applied to every PR in this request: a PR that encodes to
    /// more tokens than this is reported as discarded.
    max_tokens: i32,
}

/// A single PR submitted for tokenization.
///
/// The identifying fields (`repo_id`, `repo_name`, `pr_id`) are copied verbatim
/// into the matching `TokenizedResult`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PRText {
    /// Repository identifier, echoed back in the result.
    repo_id: i64,
    /// Repository name, echoed back in the result and used in error logs.
    repo_name: String,
    /// PR identifier, echoed back in the result and used in error logs.
    pr_id: i64,
    /// The text that is passed to the tokenizer.
    text: String,
}

/// One outgoing response frame containing a batch of results.
///
/// The writer thread in this file always emits `status = "success"` and
/// `error = None`; per-PR failures are reported through
/// `TokenizedResult::discarded` instead.
#[derive(Debug, Serialize, Deserialize)]
struct TokenizeResponse {
    /// Status string for the batch.
    status: String,
    /// Results in the order the writer received them from the workers.
    results: Vec<TokenizedResult>,
    /// Optional error description for the batch.
    error: Option<String>,
}

/// Outcome of tokenizing one `PRText`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TokenizedResult {
    /// Copied from the originating `PRText`.
    repo_id: i64,
    /// Copied from the originating `PRText`.
    repo_name: String,
    /// Copied from the originating `PRText`.
    pr_id: i64,
    /// Token ids, or `None` when the result is discarded (over the token limit
    /// or tokenization failed).
    token_ids: Option<Vec<i32>>,
    /// Number of tokens produced; 0 when tokenization failed.
    token_count: i32,
    /// Length of the PR text in bytes (`String::len`).
    byte_size: i32,
    /// True when the token count exceeded the request's `max_tokens` or the
    /// tokenizer returned an error.
    discarded: bool,
}

/// Resolves `model` to a `tokenizer.json` file and loads it.
///
/// Lookup order:
/// 1. `model` itself, if it names an existing filesystem path.
/// 2. The snapshot referenced by `refs/main` in the local HuggingFace cache.
/// 3. The first snapshot directory in that cache containing `tokenizer.json`.
///
/// The cache root is `$HF_HOME`, else `$HOME/.cache/huggingface`, else the
/// relative path `.cache/huggingface`. Nothing is downloaded; an error with
/// download instructions is returned when no candidate file exists.
fn load_tokenizer(model: &str) -> Result<Tokenizer> {
    use std::path::Path;

    // Case 1: the argument is already a path on disk.
    if Path::new(model).exists() {
        eprintln!("[RUST] Loading tokenizer from file: {}", model);
        return Tokenizer::from_file(model)
            .map_err(|err| anyhow::anyhow!("Failed to load tokenizer from file: {}", err));
    }

    // Locate the model's directory inside the HuggingFace hub cache.
    let cache_root = match std::env::var("HF_HOME") {
        Ok(dir) => dir,
        Err(_) => match std::env::var("HOME") {
            Ok(home) => format!("{}/.cache/huggingface", home),
            Err(_) => ".cache/huggingface".to_string(),
        },
    };
    let repo_dir = format!("{}/hub/models--{}", cache_root, model.replace('/', "--"));

    // Case 2: follow refs/main to the snapshot it names, if that snapshot has a tokenizer.
    let pinned = std::fs::read_to_string(format!("{}/refs/main", repo_dir))
        .ok()
        .map(|rev| format!("{}/snapshots/{}/tokenizer.json", repo_dir, rev.trim()))
        .filter(|candidate| Path::new(candidate).exists());
    if let Some(candidate) = pinned {
        eprintln!("[RUST] Using tokenizer from: {}", candidate);
        return Tokenizer::from_file(candidate)
            .map_err(|err| anyhow::anyhow!("Failed to load tokenizer: {}", err));
    }

    // Case 3: take the first snapshot directory that contains a tokenizer.json.
    let scanned = std::fs::read_dir(format!("{}/snapshots", repo_dir))
        .ok()
        .and_then(|entries| {
            entries
                .flatten()
                .map(|entry| entry.path().join("tokenizer.json"))
                .find(|candidate| candidate.exists())
        });
    if let Some(candidate) = scanned {
        eprintln!("[RUST] Using tokenizer from: {}", candidate.display());
        return Tokenizer::from_file(candidate)
            .map_err(|err| anyhow::anyhow!("Failed to load tokenizer: {}", err));
    }

    anyhow::bail!(
        "Tokenizer '{}' not found in HuggingFace cache. Please download it first using Python:\n\
         from transformers import AutoTokenizer\n\
         AutoTokenizer.from_pretrained('{}')",
        model, model
    )
}

fn read_message<T: for<'de> Deserialize<'de>>(reader: &mut impl Read, max_size: usize) -> Result<T> {
    // Read 4-byte length prefix
    let mut len_buf = [0u8; 4];
    reader.read_exact(&mut len_buf).context("Failed to read message length")?;
    let msg_len = u32::from_be_bytes(len_buf) as usize;

    // Check message size limit
    if msg_len > max_size {
        anyhow::bail!("Message size {} bytes exceeds limit {} bytes", msg_len, max_size);
    }

    // Read message body
    let mut msg_buf = vec![0u8; msg_len];
    reader.read_exact(&mut msg_buf).context("Failed to read message body")?;

    // Deserialize using MessagePack
    // Note: rmp_serde has default limits that might be hit with very large messages
    match rmp_serde::from_slice::<T>(&msg_buf) {
        Ok(msg) => Ok(msg),
        Err(e) => {
            let preview_len = msg_len.min(64);
            let preview = &msg_buf[..preview_len];
            
            // Try to extract more details from the error
            let error_detail = format!("{:?}", e);
            
            anyhow::bail!(
                "Failed to deserialize message:\n\
                 - Length: {} bytes ({:.2} MB)\n\
                 - Error: {}\n\
                 - First {} bytes: {:02x?}\n\
                 - This may indicate a message that's too large or contains invalid data",
                msg_len, msg_len as f64 / (1024.0 * 1024.0), error_detail, preview_len, preview
            )
        }
    }
}

fn write_message<T: Serialize>(writer: &mut impl Write, msg: &T) -> Result<()> {
    // Serialize using MessagePack
    let msg_bytes = rmp_serde::to_vec(msg).context("Failed to serialize message")?;
    
    // Write 4-byte length prefix
    let len = msg_bytes.len() as u32;
    writer.write_all(&len.to_be_bytes()).context("Failed to write message length")?;
    
    // Write message body
    writer.write_all(&msg_bytes).context("Failed to write message body")?;
    writer.flush().context("Failed to flush writer")?;
    
    Ok(())
}

/// A single unit of work passed from the reader to the workers: one PR to
/// tokenize together with the token limit of the request it came from.
#[derive(Debug, Clone)]
struct Job {
    /// The PR whose `text` is tokenized.
    pr: PRText,
    /// Copied from `TokenizeRequest::max_tokens`; results with more tokens than
    /// this are marked as discarded.
    max_tokens: i32,
}

/// Entry point: parses the CLI arguments, loads the tokenizer, and runs a
/// three-stage pipeline connected by bounded channels:
///
/// stdin -> reader thread -> job channel -> worker pool -> result channel
/// -> writer thread -> stdout
///
/// All three stage threads are joined before returning.
///
/// # Errors
///
/// Returns an error if the thread pool cannot be built, the tokenizer cannot be
/// loaded or serialized, or any pipeline stage returned an error or panicked.
/// When several stages fail, each failure is logged and the first one (in
/// reader, worker, writer order) is returned.
fn main() -> Result<()> {
    let args = Args::parse();

    // Disable tokenizer's internal parallelism to avoid nested parallelism
    std::env::set_var("TOKENIZERS_PARALLELISM", "false");

    // Set rayon thread pool size; `worker_thread` reads this size back through
    // `rayon::current_num_threads()` to decide how many workers to spawn.
    let num_workers = if args.workers == 0 {
        num_cpus::get()
    } else {
        args.workers
    };

    rayon::ThreadPoolBuilder::new()
        .num_threads(num_workers)
        .build_global()
        .context("Failed to build thread pool")?;

    let sharing = match args.threads_per_tokenizer {
        0 => "all share one".to_string(),
        1 => "one per thread".to_string(),
        n => n.to_string(),
    };

    eprintln!("[RUST] Tokenizer Worker starting...");
    eprintln!("[RUST] Configuration:");
    eprintln!("[RUST]   Model: {}", args.model);
    eprintln!("[RUST]   Workers: {}", num_workers);
    eprintln!("[RUST]   Threads per tokenizer: {}", sharing);
    eprintln!("[RUST]   Request buffer: {}", args.request_buffer);
    eprintln!("[RUST]   Response batch size: {}", args.response_batch_size);
    eprintln!("[RUST]   Max message size: {} MB", args.max_message_size / (1024 * 1024));

    // The tokenizer is loaded once and handed to the workers in serialized
    // form so that they can build their own instances from it.
    let tokenizer = load_tokenizer(&args.model)?;
    let tokenizer_json = Arc::new(
        tokenizer.to_string(false)
            .map_err(|e| anyhow::anyhow!("Failed to serialize tokenizer: {}", e))?
    );

    eprintln!("[RUST] Tokenizer loaded, ready for requests");
    eprintln!("[RUST] Using async channel-based architecture");

    // Create bounded channels (like Go channel buffers) using crossbeam.
    // `saturating_mul` keeps an extreme batch size from overflowing.
    let (job_tx, job_rx): (Sender<Job>, Receiver<Job>) = bounded(args.request_buffer);
    let (result_tx, result_rx): (Sender<TokenizedResult>, Receiver<TokenizedResult>) =
        bounded(args.response_batch_size.saturating_mul(2));

    // Spawn reader thread (reads requests from stdin, puts jobs in channel)
    let max_msg_size = args.max_message_size;
    let reader_handle = thread::spawn(move || reader_thread(job_tx, max_msg_size));

    // Spawn worker pool (processes jobs from channel, sends results to result channel)
    let threads_per_tokenizer = args.threads_per_tokenizer;
    let worker_handle = thread::spawn(move || {
        worker_thread(job_rx, result_tx, tokenizer_json, threads_per_tokenizer)
    });

    // Spawn writer thread (collects results, batches, writes to stdout)
    let response_batch_size = args.response_batch_size;
    let writer_handle = thread::spawn(move || writer_thread(result_rx, response_batch_size));

    // Wait for all threads before reporting, so every failure gets logged.
    let reader_result = join_stage("Reader", reader_handle);
    let worker_result = join_stage("Worker", worker_handle);
    let writer_result = join_stage("Writer", writer_handle);

    eprintln!("[RUST] Tokenizer Worker shutting down");

    // Surface the first failure instead of always exiting successfully.
    reader_result.and(worker_result).and(writer_result)
}

/// Joins one pipeline thread, logging and returning its error or panic.
fn join_stage(name: &str, handle: thread::JoinHandle<Result<()>>) -> Result<()> {
    match handle.join() {
        Ok(Ok(())) => Ok(()),
        Ok(Err(e)) => {
            eprintln!("[RUST] {} thread failed: {:#}", name, e);
            Err(e.context(format!("{} thread failed", name)))
        }
        Err(panic) => {
            eprintln!("[RUST] {} thread panicked: {:?}", name, panic);
            Err(anyhow::anyhow!("{} thread panicked", name))
        }
    }
}

/// Reader stage: decodes `TokenizeRequest` frames from stdin and forwards every
/// PR they contain as a separate `Job` on `job_tx`.
///
/// The loop ends when stdin can no longer be read (end of input or any other
/// I/O error) or when the job channel has been closed by the receiving side.
/// Messages that are rejected for their size or that fail to decode are logged
/// and skipped. `job_tx` is dropped on exit, which signals the workers that no
/// more jobs will arrive.
///
/// `max_message_size` is forwarded to `read_message` as the per-message limit.
/// The function currently always returns `Ok(())`.
fn reader_thread(job_tx: Sender<Job>, max_message_size: usize) -> Result<()> {
    let stdin = io::stdin();
    let mut reader = stdin.lock();
    // Explicitly 64-bit: an untyped counter would default to i32 and could
    // overflow in a long-running process.
    let mut total_jobs: u64 = 0;

    'requests: loop {
        let request: TokenizeRequest = match read_message(&mut reader, max_message_size) {
            Ok(req) => req,
            Err(e) => {
                // Classify by error type rather than by message text: an
                // `io::Error` means the stream itself failed, whereas decode
                // errors (whose text may happen to mention EOF) do not.
                if let Some(io_err) = e.downcast_ref::<io::Error>() {
                    if io_err.kind() == io::ErrorKind::UnexpectedEof {
                        // Normal shutdown condition - stdin closed (EOF)
                        eprintln!("[RUST] Reader: stdin closed after {} jobs, shutting down", total_jobs);
                    } else {
                        eprintln!(
                            "[RUST] Reader: I/O error on stdin ({:#}) after {} jobs, shutting down",
                            e, total_jobs
                        );
                    }
                    break;
                }

                let err_str = e.to_string();
                if err_str.contains("exceeds limit") {
                    // Message too large - log warning and continue reading
                    eprintln!("[RUST] Warning: {}, skipping message", err_str);
                } else {
                    // Other errors (including deserialization) - log and continue
                    eprintln!("[RUST] Error reading request: {}, continuing...", e);
                }
                continue;
            }
        };

        // Put each PR as a separate job in the channel
        for pr in request.prs {
            let job = Job {
                pr,
                max_tokens: request.max_tokens,
            };

            // This blocks if channel is full (backpressure)
            if job_tx.send(job).is_err() {
                // Channel closed: nobody will consume further jobs, so leave the
                // outer loop too instead of reading more requests from stdin.
                eprintln!("[RUST] Reader: job channel closed after {} jobs, shutting down", total_jobs);
                break 'requests;
            }
            total_jobs += 1;
        }
    }

    // Drop sender to signal workers to shutdown
    drop(job_tx);
    eprintln!("[RUST] Reader thread exiting after processing {} total jobs", total_jobs);
    Ok(())
}

// Worker thread: processes jobs using worker pool pattern (Mode 3 from bench_data)
// N threads share one tokenizer for best performance
// Processes jobs as they arrive in streaming fashion
fn worker_thread(
    job_rx: Receiver<Job>,
    result_tx: Sender<TokenizedResult>,
    tokenizer_json: Arc<String>,
    threads_per_tokenizer: usize,
) -> Result<()> {
    let num_workers = rayon::current_num_threads();
    
    // Calculate number of tokenizer groups
    let num_groups = if threads_per_tokenizer == 0 {
        1 // All threads share one tokenizer
    } else if threads_per_tokenizer == 1 {
        num_workers // One tokenizer per thread
    } else {
        (num_workers + threads_per_tokenizer - 1) / threads_per_tokenizer
    };
    
    eprintln!("[RUST] Worker pool: {} workers, {} tokenizer groups", num_workers, num_groups);

    // Pre-create tokenizers for each group
    let tokenizers: Vec<Arc<Tokenizer>> = (0..num_groups)
        .map(|_| {
            Arc::new(
                Tokenizer::from_bytes(tokenizer_json.as_bytes())
                    .expect("Failed to create tokenizer")
            )
        })
        .collect();
    
    let tokenizers = Arc::new(tokenizers);

    // Use crossbeam's scope to spawn worker threads
    crossbeam::scope(|scope| {
        // Spawn worker threads that continuously pull from the channel
        for worker_id in 0..num_workers {
            let job_rx = job_rx.clone();
            let result_tx = result_tx.clone();
            let tokenizers = Arc::clone(&tokenizers);
            
            scope.spawn(move |_| {
                // Determine which tokenizer group this worker belongs to
                let group_id = if threads_per_tokenizer == 0 {
                    0 // All threads share one tokenizer
                } else if threads_per_tokenizer == 1 {
                    worker_id % tokenizers.len() // One tokenizer per thread
                } else {
                    (worker_id / threads_per_tokenizer).min(tokenizers.len() - 1)
                };
                let tok = &tokenizers[group_id];
                
                // Continuously process jobs from the channel
                for job in job_rx.iter() {
                    let result = match tok.encode(job.pr.text.as_str(), false) {
                        Ok(encoding) => {
                            let token_ids: Vec<u32> = encoding.get_ids().to_vec();
                            let token_count = token_ids.len() as i32;
                            let byte_size = job.pr.text.len() as i32;
                            let discarded = token_count > job.max_tokens;

                            TokenizedResult {
                                repo_id: job.pr.repo_id,
                                repo_name: job.pr.repo_name.clone(),
                                pr_id: job.pr.pr_id,
                                token_ids: if discarded {
                                    None
                                } else {
                                    Some(token_ids.iter().map(|&id| id as i32).collect())
                                },
                                token_count,
                                byte_size,
                                discarded,
                            }
                        }
                        Err(e) => {
                            eprintln!("[RUST] Failed to tokenize PR {}/{}: {}", job.pr.repo_name, job.pr.pr_id, e);
                            TokenizedResult {
                                repo_id: job.pr.repo_id,
                                repo_name: job.pr.repo_name.clone(),
                                pr_id: job.pr.pr_id,
                                token_ids: None,
                                token_count: 0,
                                byte_size: job.pr.text.len() as i32,
                                discarded: true,
                            }
                        }
                    };
                    
                    // Send result to writer (break if channel closed)
                    if result_tx.send(result).is_err() {
                        break;
                    }
                }
            });
        }
    }).expect("Worker scope failed");

    // Drop sender to signal writer to shutdown
    drop(result_tx);
    eprintln!("[RUST] Worker thread completed, all workers finished");
    Ok(())
}

/// Writer stage: drains `result_rx` and writes the results to stdout as
/// `TokenizeResponse` frames of up to `batch_size` results each.
///
/// A frame is written as soon as `batch_size` results have accumulated (with
/// `batch_size == 0` every result is written on its own); whatever remains when
/// the channel closes is written as a final, smaller frame.
///
/// # Errors
///
/// Returns the first error from `write_message`. Returning drops `result_rx`,
/// which makes further sends from the workers fail.
fn writer_thread(result_rx: Receiver<TokenizedResult>, batch_size: usize) -> Result<()> {
    let stdout = io::stdout();
    let mut writer = stdout.lock();

    let mut batch = Vec::new();

    for result in result_rx {
        batch.push(result);

        if batch.len() >= batch_size {
            let response = TokenizeResponse {
                status: "success".to_string(),
                // Move the accumulated results into the response and leave an
                // empty batch behind; this avoids deep-copying every result.
                results: std::mem::take(&mut batch),
                error: None,
            };
            write_message(&mut writer, &response)?;
        }
    }

    // Write remaining results
    if !batch.is_empty() {
        let response = TokenizeResponse {
            status: "success".to_string(),
            results: batch,
            error: None,
        };
        write_message(&mut writer, &response)?;
    }

    Ok(())
}