use anyhow::{Context, Result};
use csv::ReaderBuilder;
use nalgebra::DMatrix;
use rayon::prelude::*;
use std::env;
use std::fs::File;
use std::process;

fn main() -> Result<()> {
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        eprintln!("Usage: {} <filename> <lambda>", args[0]);
        process::exit(1);
    }

    let filename = &args[1];
    let data = read_data(filename)?;
    let lambda: f64 = args[2].parse().unwrap();

    let g = run(&data, lambda);

    g.output();

    Ok(())
}

fn read_data(data_file: &String) -> Result<DMatrix<f64>> {
    let file = File::open(data_file).context("failed to open data file")?;
    let mut rdr = ReaderBuilder::new().from_reader(file);
    let mut data_slice: Vec<f64> = Vec::new();
    let mut rows = 0;
    for result in rdr.records() {
        let record = result?;
        let mut row: Vec<f64> = record
            .iter()
            .map(|s| s.parse::<f64>())
            .collect::<Result<Vec<_>, _>>()?;
        data_slice.append(&mut row);
        rows += 1;
    }
    Ok(DMatrix::from_row_slice(
        rows,
        data_slice.len() / rows,
        &data_slice,
    ))
}

/// Runs the full search over the columns of `data` and returns the resulting
/// graph.
///
/// `penalty` is forwarded to [`Bic::new`]. The work is done in five stages,
/// each feeding the next: local scores, best parent sets, best sinks, vertex
/// ordering, and finally the network itself.
fn run(data: &DMatrix<f64>, penalty: f64) -> Dag {
    let num_vars = data.ncols();
    let bic = Bic::new(data, penalty);

    // Stage 1: score every candidate parent set of every vertex (in parallel).
    let local_scores: Vec<Vec<f64>> = (0..num_vars)
        .into_par_iter()
        .map(|vertex| compute_local_scores(vertex, num_vars, &bic))
        .collect();

    // Stage 2: for each vertex, the best parent set within every candidate set.
    let best_parents: Vec<Vec<usize>> = local_scores
        .iter()
        .map(|scores| compute_best_parents(scores, num_vars))
        .collect();

    // Stage 3: the best sink of every vertex subset.
    let best_sinks = compute_best_sinks(&best_parents, &local_scores, num_vars);

    // Stage 4: peel sinks off the full set to obtain a vertex ordering.
    let best_ordering = compute_best_ordering(&best_sinks, num_vars);

    // Stage 5: attach each vertex's best parents among its predecessors.
    compute_best_network(&best_ordering, &best_parents, num_vars)
}

/// Scores vertex `v` against every one of the `2^(p-1)` candidate parent
/// sets, indexed by their bitmask as decoded by [`get_parents`].
fn compute_local_scores(v: usize, p: usize, score: &Bic) -> Vec<f64> {
    let num_sets = pow2(p - 1);
    let mut scores = Vec::with_capacity(num_sets);
    for mask in 0..num_sets {
        let parent_set = get_parents(v, mask, p);
        scores.push(score.local_score(v, parent_set));
    }
    scores
}

// Get the parent set for a vertex `i` and subset `j`
fn get_parents(i: usize, j: usize, p: usize) -> Vec<usize> {
    (0..p)
        .filter(|&k| (k < i && pow2(k) & j != 0) || (k > i && pow2(k - 1) & j != 0))
        .collect()
}

/// Returns `2^x`, i.e. the bitmask with only bit `x` set.
fn pow2(x: usize) -> usize {
    let one: usize = 1;
    one << x
}

// Compute the best parent sets
fn compute_best_parents(local_scores: &[f64], p: usize) -> Vec<usize> {
    let mut best_scores = local_scores.to_vec();
    let mut best_parents: Vec<usize> = (0..pow2(p - 1)).collect();
    for j in 0..pow2(p - 1) {
        for k in 0..p - 1 {
            let subj = j & !(1 << k);
            if subj == j {
                continue;
            }
            if best_scores[subj] > best_scores[j] {
                best_scores[j] = best_scores[subj];
                best_parents[j] = best_parents[subj];
            }
        }
    }
    best_parents
}

// Compute the best sinks
fn compute_best_sinks(
    best_parents: &[Vec<usize>],
    local_scores: &[Vec<f64>],
    p: usize,
) -> Vec<usize> {
    let mut best_scores = vec![0.0; pow2(p)];
    let mut best_sinks = vec![usize::MAX; pow2(p)];

    for j in 0..pow2(p) {
        for k in 0..p {
            if (1 << k) & j == 0 {
                continue;
            }
            let subj = j & !(1 << k);
            let before = subj & ((1 << k) - 1);
            let after = subj - before;
            let psubj = before + (after >> 1);
            let subscore = best_scores[subj] + local_scores[k][best_parents[k][psubj]];
            if best_sinks[j] == usize::MAX || subscore > best_scores[j] {
                best_scores[j] = subscore;
                best_sinks[j] = k;
            }
        }
    }
    best_sinks
}

/// Builds a vertex ordering by repeatedly removing the best sink.
///
/// Starting from the full set of `p` vertices, the best sink of the current
/// set is placed at the last free position and removed from the set, so the
/// returned ordering lists sinks last.
fn compute_best_ordering(bestsinks: &[usize], p: usize) -> Vec<usize> {
    let mut unplaced = pow2(p) - 1;
    let mut back_to_front = Vec::with_capacity(p);

    for _ in 0..p {
        let sink = bestsinks[unplaced];
        unplaced -= 1 << sink;
        back_to_front.push(sink);
    }

    back_to_front.reverse();
    back_to_front
}

/// Builds the graph on `p` vertices implied by a vertex ordering.
///
/// Vertices are visited in `ordering`; each one receives edges from its best
/// parent set among the vertices visited before it, as looked up in
/// `bestparents`.
fn compute_best_network(ordering: &[usize], bestparents: &[Vec<usize>], p: usize) -> Dag {
    let mut dag = Dag::new(p);
    let mut placed: usize = 0;
    for vertex in (0..p).map(|i| ordering[i]) {
        let bit = 1usize << vertex;
        // Re-index the predecessor mask into the vertex's parent-mask space,
        // which has no bit for the vertex itself.
        let low = placed & (bit - 1);
        let high = placed - low;
        let idx = low + (high >> 1);

        for parent in get_parents(vertex, bestparents[vertex][idx], p) {
            dag.add_edge(parent, vertex);
        }

        placed += bit;
    }
    dag
}

/// Penalised local scoring based on a covariance matrix.
#[derive(Debug)]
struct Bic {
    /// Number of rows (observations) in the data.
    n: usize,
    /// Multiplier applied to the complexity penalty.
    lambda: f64,
    /// Covariance matrix of the data columns, as computed by `cov_matrix`.
    cov: DMatrix<f64>,
}

impl Bic {
    /// Creates a scorer for `data` with penalty multiplier `penalty`.
    fn new(data: &DMatrix<f64>, penalty: f64) -> Self {
        let cov = cov_matrix(data);
        Self {
            n: data.nrows(),
            lambda: penalty,
            cov,
        }
    }

    /// Scores vertex `v` with the given parent set.
    ///
    /// Factorises the covariance submatrix over `parents` followed by `v` and
    /// reads the last diagonal entry of the Cholesky factor, which for a
    /// covariance matrix is the square root of `v`'s residual variance given
    /// the parents.
    ///
    /// # Panics
    ///
    /// Panics if the Cholesky factorisation of the submatrix fails.
    fn local_score(&self, v: usize, parents: Vec<usize>) -> f64 {
        let parent_count = parents.len();
        let mut family = parents;
        family.push(v);
        let block = submatrix(&self.cov, &family, &family);
        let factor = block.cholesky().unwrap();
        let residual_sd = factor.l_dirty()[(parent_count, parent_count)];
        self.compute_local_bic(parent_count, residual_sd)
    }

    /// Combines the fit term `-2 n ln(std_var)` with the complexity penalty
    /// `lambda * num_parents * ln(n)`.
    ///
    /// `std_var` is clamped to the smallest positive normal `f64` so the
    /// logarithm stays finite.
    fn compute_local_bic(&self, num_parents: usize, std_var: f64) -> f64 {
        let clamped = std_var.max(f64::MIN_POSITIVE);
        let fit = -2.0 * self.n as f64 * clamped.ln();
        let complexity = self.lambda * num_parents as f64 * (self.n as f64).ln();
        fit - complexity
    }
}

/// Computes the covariance matrix of the columns of `data`.
///
/// Each row is centred by subtracting the vector of column means, and the
/// cross-product of the centred data is divided by the number of rows `n`
/// (not `n - 1`).
fn cov_matrix(data: &DMatrix<f64>) -> DMatrix<f64> {
    let n = data.nrows();
    let mean_vector = data.row_mean();
    let mut centered_data = data.clone();
    for mut row in centered_data.row_iter_mut() {
        // Subtract by reference: no need to clone the mean vector per row.
        row -= &mean_vector;
    }
    (centered_data.transpose() * centered_data) / n as f64
}

/// Extracts the submatrix of `matrix` at the given row and column indices.
///
/// Entry `(i, j)` of the result is `matrix[(rows[i], cols[j])]`, so indices
/// may appear in any order.
fn submatrix(matrix: &DMatrix<f64>, rows: &[usize], cols: &[usize]) -> DMatrix<f64> {
    let pick = |i: usize, j: usize| matrix[(rows[i], cols[j])];
    DMatrix::from_fn(rows.len(), cols.len(), pick)
}

/// A directed graph stored as a vertex count plus an edge list.
#[derive(Debug)]
struct Dag {
    /// Number of vertices.
    p: usize,
    /// Directed edges as `(from, to)` pairs, in insertion order.
    edges: Vec<(usize, usize)>,
}

impl Dag {
    /// Creates a graph on `p` vertices with no edges.
    fn new(p: usize) -> Self {
        Self { p, edges: vec![] }
    }

    /// Appends the directed edge `u -> v`.
    fn add_edge(&mut self, u: usize, v: usize) {
        let edge = (u, v);
        self.edges.push(edge);
    }

    /// Prints the graph to stdout: a header line with the vertex and edge
    /// counts, followed by one line per edge.
    fn output(&self) {
        let (p, m) = (self.p, self.edges.len());
        println!("{p} {m} dag");
        self.edges
            .iter()
            .for_each(|(tail, head)| println!("{} {} directed", tail, head));
    }
}
