use anyhow::{Context, Result};
use arrow::array::{Array, Int64Array, StringArray};
use arrow::record_batch::RecordBatch;
use clap::Parser;
use indicatif::{ProgressBar, ProgressStyle};
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use rayon::prelude::*;
use std::fs::File;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokenizers::Tokenizer;

// Command-line options for the benchmark, parsed with clap's derive API.
//
// NOTE: clap turns the `///` comments on the fields below into the per-option
// `--help` text, so they are part of the program's output. Maintainer notes
// are therefore written as plain `//` comments.
#[derive(Parser, Debug)]
#[command(author, version, about = "Tokenization benchmark for PR text data", long_about = None)]
struct Args {
    /// Path to the task4 output directory containing parquet files
    // Optional at parse time; `main` rejects a missing value unless
    // `--synthetic` is set.
    #[arg(short, long)]
    input_dir: Option<PathBuf>,

    /// Number of parquet files to read (0 = all)
    // Applied by `load_data` after the file list has been sorted.
    #[arg(short = 'n', long, default_value = "0")]
    num_files: usize,

    /// Number of worker threads (0 = num_cpus)
    // `main` substitutes `num_cpus::get()` when this is 0.
    #[arg(short = 'w', long, default_value = "0")]
    workers: usize,

    /// Tokenizer model path or HuggingFace model ID
    // `main` first checks whether this names an existing path; otherwise it
    // is treated as a model ID and looked up in the local HuggingFace cache.
    #[arg(short = 't', long, default_value = "Qwen/Qwen2.5-Coder-32B-Instruct")]
    tokenizer_model: String,

    /// Batch size for batch tokenization test
    // Printed at startup, but only consumed by the batch benchmark, which
    // runs when `--compare-batch` is given.
    #[arg(short = 'b', long, default_value = "32")]
    batch_size: usize,

    /// Enable batch tokenization comparison
    #[arg(long)]
    compare_batch: bool,

    /// Number of threads sharing one tokenizer instance (0 = all share one, 1 = one per thread)
    // Values above 1 make `benchmark_worker_pool` create one tokenizer copy
    // per group of that many threads.
    #[arg(long, default_value = "1")]
    threads_per_tokenizer: usize,

    /// Use synthetic data instead of parquet files
    #[arg(long)]
    synthetic: bool,

    /// Number of synthetic strings to generate
    // Only read when `--synthetic` is set.
    #[arg(long, default_value = "10000")]
    synthetic_count: usize,

    /// Length of each synthetic string in bytes
    // Only read when `--synthetic` is set.
    #[arg(long, default_value = "204800")]
    synthetic_length: usize,
}

/// One pull-request text record, either decoded from a parquet row or
/// generated synthetically.
// The benchmarks in this file only read `text`; the identifier fields are
// populated but never read, which is why `dead_code` is allowed here.
#[allow(dead_code)]
struct RenderedPRText {
    /// Value of the `repo_id` column (the item index for synthetic data).
    repo_id: i64,
    /// Value of the `repo_name` column (`synthetic/repo-<index>` for synthetic data).
    repo_name: String,
    /// Value of the `pr_id` column (the item index for synthetic data).
    pr_id: i64,
    /// The text that gets tokenized.
    text: String,
}

/// Running totals for one benchmark run.
///
/// All counters are atomics so worker threads can update them through a
/// shared reference.
struct BenchmarkStats {
    /// Number of texts tokenized so far.
    prs_processed: AtomicUsize,
    /// Total number of token ids produced so far.
    tokens_generated: AtomicU64,
    /// Total number of input bytes tokenized so far.
    bytes_processed: AtomicU64,
    /// Moment this value was created; throughput is measured from here.
    start_time: Instant,
}

impl BenchmarkStats {
    /// Creates zeroed counters and starts the clock.
    fn new() -> Self {
        Self {
            prs_processed: AtomicUsize::default(),
            tokens_generated: AtomicU64::default(),
            bytes_processed: AtomicU64::default(),
            start_time: Instant::now(),
        }
    }

    /// Records one tokenized text that produced `token_count` tokens from
    /// `byte_count` bytes of input.
    fn add_pr(&self, token_count: usize, byte_count: usize) {
        let tokens = token_count as u64;
        let bytes = byte_count as u64;

        self.prs_processed.fetch_add(1, Ordering::Relaxed);
        self.tokens_generated.fetch_add(tokens, Ordering::Relaxed);
        self.bytes_processed.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Prints the totals and the per-second rates since `start_time` to stdout.
    fn report(&self) {
        const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;
        const BYTES_PER_GIB: f64 = 1024.0 * 1024.0 * 1024.0;

        let secs = self.start_time.elapsed().as_secs_f64();
        let per_second = |amount: f64| amount / secs;

        let pr_total = self.prs_processed.load(Ordering::Relaxed);
        let token_total = self.tokens_generated.load(Ordering::Relaxed);
        let byte_total = self.bytes_processed.load(Ordering::Relaxed) as f64;

        println!(
            "[THROUGHPUT] PRs: {} ({:.1}/s) | Tokens: {} ({:.1}/s) | Data: {:.2} GB ({:.2} MB/s) | Elapsed: {:.1}s",
            pr_total,
            per_second(pr_total as f64),
            token_total,
            per_second(token_total as f64),
            byte_total / BYTES_PER_GIB,
            per_second(byte_total / BYTES_PER_MIB),
            secs
        );
    }
}

fn read_parquet_file(path: &PathBuf) -> Result<Vec<RenderedPRText>> {
    let file = File::open(path).context("Failed to open parquet file")?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)
        .context("Failed to create parquet reader")?;
    let mut reader = builder.build().context("Failed to build reader")?;

    let mut results = Vec::new();

    while let Some(batch) = reader.next() {
        let batch = batch.context("Failed to read batch")?;
        results.extend(parse_batch(&batch)?);
    }

    Ok(results)
}

/// Converts one Arrow record batch into owned [`RenderedPRText`] rows.
///
/// The batch must contain `repo_id` and `pr_id` as `Int64Array`, `repo_name`
/// as `StringArray`, and `text` as either `StringArray` or `BinaryArray`.
/// Binary text is decoded with lossy UTF-8 conversion. Rows whose `text` is
/// null are skipped.
///
/// # Errors
///
/// Returns an error if any of the four columns is missing or has an
/// unexpected array type.
fn parse_batch(batch: &RecordBatch) -> Result<Vec<RenderedPRText>> {
    use arrow::array::BinaryArray;

    let repo_id = batch
        .column_by_name("repo_id")
        .context("Missing repo_id column")?
        .as_any()
        .downcast_ref::<Int64Array>()
        .context("repo_id is not Int64Array")?;

    let repo_name = batch
        .column_by_name("repo_name")
        .context("Missing repo_name column")?
        .as_any()
        .downcast_ref::<StringArray>()
        .context("repo_name is not StringArray")?;

    let pr_id = batch
        .column_by_name("pr_id")
        .context("Missing pr_id column")?
        .as_any()
        .downcast_ref::<Int64Array>()
        .context("pr_id is not Int64Array")?;

    let text_column = batch
        .column_by_name("text")
        .context("Missing text column")?;

    // Shared row constructor so the string and binary paths cannot drift apart.
    // NOTE(review): only `text` is null-checked; `value(i)` is called on the
    // other three columns regardless of their validity bitmap.
    let build_row = |i: usize, text: String| RenderedPRText {
        repo_id: repo_id.value(i),
        repo_name: repo_name.value(i).to_string(),
        pr_id: pr_id.value(i),
        text,
    };

    let text_any = text_column.as_any();
    let mut results = Vec::with_capacity(batch.num_rows());

    // Dispatch on the column's Arrow type: a string column is used as-is,
    // while a binary column is decoded lossily.
    if let Some(text) = text_any.downcast_ref::<StringArray>() {
        results.extend(
            (0..batch.num_rows())
                .filter(|&i| !text.is_null(i))
                .map(|i| build_row(i, text.value(i).to_string())),
        );
    } else if let Some(text) = text_any.downcast_ref::<BinaryArray>() {
        results.extend(
            (0..batch.num_rows())
                .filter(|&i| !text.is_null(i))
                // `into_owned` reuses the buffer when the lossy conversion
                // already had to allocate, instead of copying it again.
                .map(|i| build_row(i, String::from_utf8_lossy(text.value(i)).into_owned())),
        );
    } else {
        anyhow::bail!("text column is neither StringArray nor BinaryArray");
    }

    Ok(results)
}

/// Loads the rows of the `.parquet` files found directly inside `input_dir`.
///
/// Files are processed in sorted path order. When `num_files` is non-zero,
/// only the first `num_files` of them are read. Directory entries that cannot
/// be read are skipped. Progress is shown on a progress bar, and a summary is
/// printed to stdout.
///
/// # Errors
///
/// Returns an error if the directory cannot be listed or any selected file
/// fails to load.
fn load_data(input_dir: &PathBuf, num_files: usize) -> Result<Vec<RenderedPRText>> {
    let mut parquet_files: Vec<PathBuf> = std::fs::read_dir(input_dir)
        .context("Failed to read input directory")?
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| path.extension().map_or(false, |ext| ext == "parquet"))
        .collect();

    parquet_files.sort();

    // `truncate` is a no-op when the list is already short enough.
    if num_files > 0 {
        parquet_files.truncate(num_files);
    }

    let file_count = parquet_files.len();
    println!(
        "[INFO] Loading {} parquet files from {}",
        file_count,
        input_dir.display()
    );

    let bar_style = ProgressStyle::default_bar()
        .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} files")
        .unwrap()
        .progress_chars("=>-");
    let pb = ProgressBar::new(file_count as u64);
    pb.set_style(bar_style);

    let mut all_data = Vec::new();
    for file in &parquet_files {
        all_data.extend(read_parquet_file(file)?);
        pb.inc(1);
    }
    pb.finish_with_message("Loading complete");

    println!("[INFO] Loaded {} PR texts", all_data.len());
    Ok(all_data)
}

/// Builds `count` synthetic records whose text is `length` random ASCII
/// letters (`a`-`z`, `A`-`Z`).
///
/// The item index is used for both `repo_id` and `pr_id`, and the repo name
/// is `synthetic/repo-<index>`. Progress is shown on a progress bar, and the
/// total generated size is printed to stdout.
fn generate_synthetic_data(count: usize, length: usize) -> Vec<RenderedPRText> {
    use rand::Rng;

    println!(
        "[INFO] Generating {} synthetic strings of {} bytes each",
        count, length
    );

    let bar_style = ProgressStyle::default_bar()
        .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} strings")
        .unwrap()
        .progress_chars("=>-");
    let pb = ProgressBar::new(count as u64);
    pb.set_style(bar_style);

    let mut data: Vec<RenderedPRText> = Vec::with_capacity(count);
    for i in 0..count {
        let mut rng = rand::thread_rng();

        let mut text = String::with_capacity(length);
        for _ in 0..length {
            // 0..=25 selects a lowercase letter, 26..=51 an uppercase one.
            let idx: u8 = rng.gen_range(0..52);
            let letter = match idx {
                0..=25 => b'a' + idx,
                _ => b'A' + (idx - 26),
            };
            text.push(letter as char);
        }

        pb.inc(1);

        data.push(RenderedPRText {
            repo_id: i as i64,
            repo_name: format!("synthetic/repo-{}", i),
            pr_id: i as i64,
            text,
        });
    }

    pb.finish_with_message("Generation complete");

    let byte_total: usize = data.iter().map(|item| item.text.len()).sum();
    println!(
        "[INFO] Generated {} strings, total size: {:.2} MB",
        data.len(),
        byte_total as f64 / (1024.0 * 1024.0)
    );

    data
}

/// Tokenizes every text in `data` sequentially on the calling thread,
/// recording token and byte counts in `stats`, then prints a throughput
/// report.
///
/// # Errors
///
/// Stops at, and returns, the first encoding failure.
#[allow(dead_code)]
fn benchmark_single_threaded(
    tokenizer: &Tokenizer,
    data: &[RenderedPRText],
    stats: &BenchmarkStats,
) -> Result<()> {
    println!("\n=== Single-threaded Tokenization ===");

    let bar_style = ProgressStyle::default_bar()
        .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} PRs ({per_sec})")
        .unwrap()
        .progress_chars("=>-");
    let pb = ProgressBar::new(data.len() as u64);
    pb.set_style(bar_style);

    data.iter().try_for_each(|pr| -> Result<()> {
        let text = pr.text.as_str();
        let encoding = tokenizer
            .encode(text, false)
            .map_err(|e| anyhow::anyhow!("Failed to encode text: {}", e))?;

        stats.add_pr(encoding.get_ids().len(), text.len());
        pb.inc(1);
        Ok(())
    })?;

    pb.finish_with_message("Complete");
    stats.report();

    Ok(())
}

fn benchmark_worker_pool(
    tokenizer: &Tokenizer,
    data: &[RenderedPRText],
    stats: &BenchmarkStats,
    num_workers: usize,
    threads_per_tokenizer: usize,
) -> Result<()> {
    let sharing_mode = if threads_per_tokenizer == 0 {
        "all share one".to_string()
    } else if threads_per_tokenizer == 1 {
        "one per thread".to_string()
    } else {
        format!("{} threads per tokenizer", threads_per_tokenizer)
    };
    
    println!(
        "\n=== Worker Pool Tokenization ({} workers, {}) ===",
        num_workers, sharing_mode
    );

    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(num_workers)
        .build()
        .context("Failed to build thread pool")?;

    let pb = ProgressBar::new(data.len() as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} PRs ({per_sec})")
            .unwrap()
            .progress_chars("=>-"),
    );

    if threads_per_tokenizer == 0 {
        // Mode 1: All threads share one tokenizer
        pool.install(|| {
            data.par_iter().for_each(|pr| {
                if let Ok(encoding) = tokenizer.encode(pr.text.as_str(), false) {
                    let token_count = encoding.get_ids().len();
                    let byte_count = pr.text.len();
                    stats.add_pr(token_count, byte_count);
                    pb.inc(1);
                }
            });
        });
    } else if threads_per_tokenizer == 1 {
        // Mode 2: One tokenizer per thread
        let tokenizer_json = tokenizer.to_string(false)
            .map_err(|e| anyhow::anyhow!("Failed to serialize tokenizer: {}", e))?;
        let tokenizer_json = Arc::new(tokenizer_json);

        pool.install(|| {
            data.par_iter().for_each(|pr| {
                thread_local! {
                    static TOKENIZER: std::cell::RefCell<Option<Tokenizer>> = std::cell::RefCell::new(None);
                }
                
                TOKENIZER.with(|tok_cell| {
                    let mut tok_opt = tok_cell.borrow_mut();
                    if tok_opt.is_none() {
                        match Tokenizer::from_bytes(tokenizer_json.as_bytes()) {
                            Ok(tok) => *tok_opt = Some(tok),
                            Err(e) => {
                                eprintln!("Failed to create tokenizer for thread: {}", e);
                                return;
                            }
                        }
                    }
                    
                    if let Some(tok) = tok_opt.as_ref() {
                        if let Ok(encoding) = tok.encode(pr.text.as_str(), false) {
                            let token_count = encoding.get_ids().len();
                            let byte_count = pr.text.len();
                            stats.add_pr(token_count, byte_count);
                            pb.inc(1);
                        }
                    }
                });
            });
        });
    } else {
        // Mode 3: N threads share one tokenizer (best performance)
        let tokenizer_json = tokenizer.to_string(false)
            .map_err(|e| anyhow::anyhow!("Failed to serialize tokenizer: {}", e))?;
        let tokenizer_json = Arc::new(tokenizer_json);
        
        // Calculate number of tokenizer groups
        let num_groups = (num_workers + threads_per_tokenizer - 1) / threads_per_tokenizer;
        println!("[INFO] Creating {} tokenizer groups", num_groups);
        
        // Pre-create tokenizers for each group
        let tokenizers: Vec<Arc<Tokenizer>> = (0..num_groups)
            .map(|_| {
                Arc::new(
                    Tokenizer::from_bytes(tokenizer_json.as_bytes())
                        .expect("Failed to create tokenizer")
                )
            })
            .collect();
        
        let tokenizers = Arc::new(tokenizers);
        let counter = Arc::new(AtomicUsize::new(0));

        pool.install(|| {
            data.par_iter().for_each(|pr| {
                thread_local! {
                    static THREAD_ID: std::cell::RefCell<Option<usize>> = std::cell::RefCell::new(None);
                }
                
                THREAD_ID.with(|id_cell| {
                    let mut id_opt = id_cell.borrow_mut();
                    if id_opt.is_none() {
                        // Assign this thread to a group
                        let thread_id = counter.fetch_add(1, Ordering::Relaxed);
                        *id_opt = Some(thread_id);
                    }
                    
                    let thread_id = id_opt.unwrap();
                    let group_id = thread_id / threads_per_tokenizer;
                    let tok = &tokenizers[group_id.min(tokenizers.len() - 1)];
                    
                    if let Ok(encoding) = tok.encode(pr.text.as_str(), false) {
                        let token_count = encoding.get_ids().len();
                        let byte_count = pr.text.len();
                        stats.add_pr(token_count, byte_count);
                        pb.inc(1);
                    }
                });
            });
        });
    }

    pb.finish_with_message("Complete");
    stats.report();

    Ok(())
}

/// Tokenizes `data` in consecutive chunks of `batch_size` texts using the
/// tokenizer's batch API, recording token and byte counts in `stats`, then
/// prints a throughput report.
///
/// # Errors
///
/// Returns an error if `batch_size` is 0 or if any batch fails to encode.
fn benchmark_batch_tokenization(
    tokenizer: &Tokenizer,
    data: &[RenderedPRText],
    stats: &BenchmarkStats,
    batch_size: usize,
) -> Result<()> {
    // `slice::chunks` panics on a chunk size of 0, and the value comes
    // straight from the command line, so reject it with a proper error.
    if batch_size == 0 {
        anyhow::bail!("Batch size must be greater than 0");
    }

    println!(
        "\n=== Batch Tokenization (batch_size={}) ===",
        batch_size
    );

    let pb = ProgressBar::new(data.len() as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} PRs ({per_sec})")
            .unwrap()
            .progress_chars("=>-"),
    );

    for chunk in data.chunks(batch_size) {
        let texts: Vec<&str> = chunk.iter().map(|pr| pr.text.as_str()).collect();

        let encodings = tokenizer
            .encode_batch(texts, false)
            .map_err(|e| anyhow::anyhow!("Failed to encode batch: {}", e))?;

        // Pair each input with its encoding instead of indexing back into the chunk.
        for (pr, encoding) in chunk.iter().zip(&encodings) {
            stats.add_pr(encoding.get_ids().len(), pr.text.len());
            pb.inc(1);
        }
    }

    pb.finish_with_message("Complete");
    stats.report();

    Ok(())
}

fn main() -> Result<()> {
    // Disable tokenizer's internal parallelism to avoid nested parallelism
    // The tokenizer library uses Rayon internally, which conflicts with our worker pool
    // This line is commented out because it doesn't have an obvious effect in my experiments
    // std::env::set_var("TOKENIZERS_PARALLELISM", "false");
    
    let args = Args::parse();

    // Set number of workers
    let num_workers = if args.workers == 0 {
        num_cpus::get()
    } else {
        args.workers
    };

    println!("[INFO] Tokenization Benchmark");
    println!("[INFO] Workers: {}", num_workers);
    println!("[INFO] Batch size: {}", args.batch_size);

    // Load tokenizer
    println!("[INFO] Loading tokenizer: {}", args.tokenizer_model);
    let tokenizer = if std::path::Path::new(&args.tokenizer_model).exists() {
        Tokenizer::from_file(&args.tokenizer_model)
            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer from file: {}", e))?
    } else {
        // Try to load from HuggingFace cache
        let cache_dir = std::env::var("HF_HOME")
            .or_else(|_| std::env::var("HOME").map(|h| format!("{}/.cache/huggingface", h)))
            .unwrap_or_else(|_| ".cache/huggingface".to_string());
        
        let model_path = args.tokenizer_model.replace("/", "--");
        let tokenizer_path = format!("{}/hub/models--{}/snapshots", cache_dir, model_path);
        
        // Find the tokenizer.json in snapshots
        let mut found_tokenizer = None;
        if let Ok(entries) = std::fs::read_dir(&tokenizer_path) {
            for entry in entries.flatten() {
                let snapshot_path = entry.path().join("tokenizer.json");
                if snapshot_path.exists() {
                    found_tokenizer = Some(snapshot_path);
                    break;
                }
            }
        }
        
        if let Some(path) = found_tokenizer {
            Tokenizer::from_file(path)
                .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?
        } else {
            anyhow::bail!("Tokenizer not found. Please download it first using Python or provide a path to tokenizer.json");
        }
    };
    println!("[INFO] Tokenizer loaded successfully");

    // Load data
    let data = if args.synthetic {
        generate_synthetic_data(args.synthetic_count, args.synthetic_length)
    } else {
        let input_dir = args.input_dir.ok_or_else(|| {
            anyhow::anyhow!("--input-dir is required when not using --synthetic mode")
        })?;
        load_data(&input_dir, args.num_files)?
    };

    if data.is_empty() {
        anyhow::bail!("No data loaded");
    }

    // Benchmark 1: Worker pool (main benchmark)
    let stats = BenchmarkStats::new();
    benchmark_worker_pool(&tokenizer, &data, &stats, num_workers, args.threads_per_tokenizer)?;

    // std::env::set_var("TOKENIZERS_PARALLELISM", "true");

    // Benchmark 2: Batch tokenization (if requested)
    if args.compare_batch {
        let worker_elapsed = stats.start_time.elapsed().as_secs_f64();
        let stats_batch = BenchmarkStats::new();
        benchmark_batch_tokenization(&tokenizer, &data, &stats_batch, args.batch_size)?;

        // Compare results
        println!("\n=== Performance Comparison ===");
        let batch_elapsed = stats_batch.start_time.elapsed().as_secs_f64();
        let speedup = worker_elapsed / batch_elapsed;

        println!(
            "Worker Pool: {:.1}s | Batch: {:.1}s | Speedup: {:.2}x",
            worker_elapsed, batch_elapsed, speedup
        );
    }

    println!("\n[INFO] Benchmark complete");
    Ok(())
}
