use crate::{
    Dictionary, RuleId, TokenId,
    sp_impl::heap::AdjustableHeap,
    typed_vec::{TypedVec, typed_vec_index},
};

// Index type used for positions in the input token sequence. The macro is invoked with
// `u32` (presumably the backing integer type — confirm against `typed_vec_index!`).
typed_vec_index!(pub(super) InputTextPos, u32);

/// Heap of pending merges used by the functions in this file: `pop` yields an
/// `(InputTextPos, RuleId)` pair, and entries are set or removed by position.
type Heap = AdjustableHeap<InputTextPos, RuleId>;

/// Applies the merge rules of `dict` to `seq` and returns the resulting token sequence.
///
/// This is [`bpe_with_heap_last_merge`] with the "last applied rule" part of its result
/// discarded; `ALLOW_IMPROPER_RULES` is forwarded unchanged.
pub fn bpe_with_heap<const ALLOW_IMPROPER_RULES: bool>(
    dict: &Dictionary,
    seq: impl Into<Vec<TokenId>>,
) -> Vec<TokenId> {
    let (tokens, _last_merge) = bpe_with_heap_last_merge::<ALLOW_IMPROPER_RULES>(dict, seq);
    tokens
}

/// Applies the merge rules of `dict` to `seq` using a heap of pending merges, returning the
/// merged token sequence together with the last rule that was applied (if any).
///
/// Inputs with at most one token are returned unchanged with `None`.
///
/// Each heap entry is keyed by the position of the *right* token of an adjacent pair, so
/// popping `(p, rule)` merges the token preceding `p` with the token at `p`. Merges are
/// applied in whatever order the heap pops them. After a merge, the pairs formed with the
/// left and right neighbours of the merged token are looked up again.
///
/// `ALLOW_IMPROPER_RULES` controls that re-lookup: when `true`, any matching rule is queued;
/// when `false`, a matching rule is queued only if its id is greater than the id of the rule
/// just applied. In every other case the heap entry for that position is removed.
pub fn bpe_with_heap_last_merge<const ALLOW_IMPROPER_RULES: bool>(
    dict: &Dictionary,
    seq: impl Into<Vec<TokenId>>,
) -> (Vec<TokenId>, Option<RuleId>) {
    let input: Vec<TokenId> = seq.into();
    let token_count = input.len();
    if token_count <= 1 {
        // Fewer than two tokens: there is no pair to merge.
        return (input, None);
    }

    // `next_of` / `prev_of` link the positions that still start a token. Initially every
    // position links to its direct neighbours; `token_count` acts as the end marker.
    let mut next_of: TypedVec<InputTextPos, _> =
        (1..=token_count).map(InputTextPos::from).collect();
    debug_assert!(next_of.enumerate_copied().all(|(at, following)| at < following));

    let mut prev_of =
        TypedVec::new_with(InputTextPos::ZERO, InputTextPos::from(token_count + 1));
    let prev_len = prev_of.len();
    for (at, following) in next_of
        .enumerate_copied()
        .filter(|&(_, following)| following < prev_len)
    {
        prev_of[following] = at;
    }

    let mut tokens: TypedVec<InputTextPos, TokenId> = TypedVec::from(input);

    // Seed the heap with every adjacent pair that has a rule, keyed by the right position.
    let initial_merges = tokens
        .as_slice()
        .windows(2)
        .enumerate()
        .filter_map(|(idx, pair)| {
            dict.find_rule(pair[0], pair[1])
                .map(|rule_id| (InputTextPos::from(idx + 1), rule_id))
        });
    let mut heap = Heap::new(tokens.len(), initial_merges);

    let mut last_merge = None;
    while let Some((right, applied)) = heap.pop() {
        last_merge = Some(applied);
        let rule = &dict[applied];

        debug_assert!(
            InputTextPos::ZERO < right && right < prev_of.len() && right < next_of.len()
        );

        let left = prev_of[right];
        let after = next_of[right];

        debug_assert!(left < right && right < after);

        debug_assert_eq!(tokens[left], rule.pre);
        debug_assert_eq!(tokens[right], rule.suc);

        // The merged token takes the place of the left token of the pair.
        let merged = rule.merged;
        tokens[left] = merged;

        debug_assert!(left < next_of.len());
        debug_assert!(InputTextPos::ZERO < after && after < prev_of.len());

        // Unlink `right` so that `left` and `after` become neighbours.
        next_of[left] = after;
        prev_of[after] = left;

        // Pair (token before `left`, merged token), keyed by `left`.
        if left > InputTextPos::ZERO {
            match dict.find_rule(tokens[prev_of[left]], merged) {
                Some(candidate) if ALLOW_IMPROPER_RULES || candidate > applied => {
                    heap.set(left, candidate);
                }
                _ => {
                    heap.remove(left);
                }
            }
        }

        // Pair (merged token, token at `after`), keyed by `after`.
        if after < tokens.len() {
            match dict.find_rule(merged, tokens[after]) {
                Some(candidate) if ALLOW_IMPROPER_RULES || candidate > applied => {
                    heap.set(after, candidate);
                }
                _ => {
                    heap.remove(after);
                }
            }
        }
    }

    // Collect the surviving tokens by following the links from position zero until the
    // end marker is reached.
    let end = next_of.len();
    let mut res = Vec::with_capacity(tokens.len().as_usize());
    res.extend(
        std::iter::successors(Some(InputTextPos::ZERO), |&at| {
            let following = next_of[at];
            (following < end).then_some(following)
        })
        .map(|at| tokens[at]),
    );
    (res, last_merge)
}

#[cfg(test)]
mod tests {
    use crate::{
        Dictionary, TokenId, Vocab, bpe_with_heap,
        test_utils::{bytes_into_tokens, utf8_into_tokens},
    };

    /// Builds a [`Dictionary`] from a clone of `vocab` and the given `rules`, each rule being
    /// a pair of byte strings. Panics if `Dictionary::new_from_token_pair` fails.
    fn build_dict<T, R>(vocab: &Vocab, rules: R) -> Dictionary
    where
        T: AsRef<[u8]>,
        R: IntoIterator<Item = (T, T)>,
    {
        let owned_vocab = vocab.clone();
        Dictionary::new_from_token_pair(owned_vocab, rules).unwrap()
    }

    /// Tokenizes `seq` with `bytes_into_tokens` and asserts that `bpe_with_heap::<true>`
    /// produces `tokens`, then repeats the check for the `<false>` variant through
    /// `check_properly_in_bytes`.
    fn check_in_bytes<S, I, T>(dict: &Dictionary, seq: S, tokens: T)
    where
        S: AsRef<[u8]>,
        I: Into<TokenId>,
        T: IntoIterator<Item = I>,
    {
        let bytes = seq.as_ref();
        let expected: Vec<TokenId> = tokens.into_iter().map(Into::into).collect();
        let actual = bpe_with_heap::<true>(dict, bytes_into_tokens(dict, bytes, 0usize));
        assert_eq!(actual, expected);
        check_properly_in_bytes(dict, bytes, expected);
    }

    /// Tokenizes `seq` with `utf8_into_tokens` and asserts that `bpe_with_heap::<true>`
    /// produces `tokens`, then repeats the check for the `<false>` variant through
    /// `check_properly_in_utf8`.
    fn check_in_utf8<S, I, T>(dict: &Dictionary, seq: S, tokens: T)
    where
        S: AsRef<str>,
        I: Into<TokenId>,
        T: IntoIterator<Item = I>,
    {
        let text = seq.as_ref();
        let expected: Vec<TokenId> = tokens.into_iter().map(Into::into).collect();
        let actual = bpe_with_heap::<true>(dict, utf8_into_tokens(dict, text, 0usize));
        assert_eq!(actual, expected);
        check_properly_in_utf8(dict, text, expected);
    }

    /// Tokenizes `seq` with `bytes_into_tokens` and asserts that `bpe_with_heap::<false>`
    /// produces `tokens`.
    fn check_properly_in_bytes<S, I, T>(dict: &Dictionary, seq: S, tokens: T)
    where
        S: AsRef<[u8]>,
        I: Into<TokenId>,
        T: IntoIterator<Item = I>,
    {
        let expected: Vec<TokenId> = tokens.into_iter().map(Into::into).collect();
        let actual = bpe_with_heap::<false>(dict, bytes_into_tokens(dict, seq, 0usize));
        assert_eq!(actual, expected);
    }

    /// Tokenizes `seq` with `utf8_into_tokens` and asserts that `bpe_with_heap::<false>`
    /// produces `tokens`.
    fn check_properly_in_utf8<S, I, T>(dict: &Dictionary, seq: S, tokens: T)
    where
        S: AsRef<str>,
        I: Into<TokenId>,
        T: IntoIterator<Item = I>,
    {
        let expected: Vec<TokenId> = tokens.into_iter().map(Into::into).collect();
        let actual = bpe_with_heap::<false>(dict, utf8_into_tokens(dict, seq, 0usize));
        assert_eq!(actual, expected);
    }

    /// Exercises `bpe_with_heap` on byte and UTF-8 inputs with a range of rule sets,
    /// covering both settings of `ALLOW_IMPROPER_RULES`.
    #[test]
    fn test_bpe_with_heap() {
        // Shared vocabulary. The expected ids in the assertions below line up with the
        // 0-based position of each entry in this list (e.g. `d` -> 4, `cd` -> 5,
        // `你好` -> 11, `aa` -> 18).
        let vocab = Vocab::new([
            b"" as &[_],
            b"a",
            b"b",
            b"c",
            b"d",
            b"cd",
            b"bcd",
            b"abcd",
            "你".as_bytes(),
            "好".as_bytes(),
            "呀".as_bytes(),
            "你好".as_bytes(),
            "你好呀".as_bytes(),
            "好你".as_bytes(),
            b"\xe4",
            b"\xbd",
            b"\xa0",
            b"\xbd\xa0",
            b"aa",
            b"aaa",
        ])
        .unwrap();

        // ASCII rules that build `abcd` from the right: c+d, b+cd, a+bcd.
        let dict = build_dict(&vocab, [("c", "d"), ("b", "cd"), ("a", "bcd")]);
        check_in_bytes(&dict, "dcdbcdabcdab", [4u32, 5, 6, 7, 1, 2]);
        check_in_utf8(&dict, "dcdbcdabcdab", [4u32, 5, 6, 7, 1, 2]);

        // Rules over raw bytes: `\xe4\xbd\xa0` is the UTF-8 encoding of `你`.
        let dict = build_dict(
            &vocab,
            [(b"\xbd" as &[_], b"\xa0" as &[_]), (b"\xe4", b"\xbd\xa0")],
        );
        // Byte-level input: two bytes of `好` come out as id 0 — presumably the fallback
        // passed as `0usize` to `bytes_into_tokens` for bytes missing from the vocabulary
        // (confirm against `test_utils`).
        check_in_bytes(&dict, "你好", [8u32, 0, 0, 15]);
        // UTF-8 input already tokenizes to whole characters, so these byte rules do not apply.
        check_properly_in_utf8(&dict, "你好", [8u32, 9]);
        let output = utf8_into_tokens(&dict, "你好", 0usize);
        assert_eq!(output, [8, 9].map(TokenId::new));

        // A single character-level rule.
        let dict = build_dict(&vocab, [("你", "好")]);
        check_in_utf8(&dict, "你好", [11u32]);
        check_in_utf8(&dict, "你好呀", [11u32, 10]);
        check_in_utf8(
            &dict,
            "你好你好你好呀你好你好你",
            [11u32, 11, 11, 10, 11, 11, 8],
        );

        // Adding a rule that extends the merged token `你好` with `呀`.
        let dict = build_dict(&vocab, [("你", "好"), ("你好", "呀")]);
        check_in_utf8(&dict, "你好", [11u32]);
        check_in_utf8(&dict, "你好呀", [12u32]);
        check_in_utf8(
            &dict,
            "你好你好你好呀你好你好你",
            [11u32, 11, 12, 11, 11, 8],
        );

        // Builds a dictionary from `rules` and checks `seq` through `check_in_utf8`, which
        // asserts the result for both the `<true>` and `<false>` variants.
        let verify = |rules: &[(&str, &str)], seq: &str, tokens: &[u32]| {
            let dict = build_dict(&vocab, rules.iter().copied());
            check_in_utf8(&dict, seq, tokens.iter().copied());
        };

        // Empty and single-token inputs.
        verify(&[("你", "好"), ("你好", "呀")], "", &[]);
        verify(&[("你", "好"), ("你好", "呀")], "你", &[8]);

        // The same three rules in different orders: the expected output depends on the
        // order in which the rules are listed.
        let long_case = "好你好你好呀你好你好你";
        verify(
            &[("你", "好"), ("你好", "呀"), ("好", "你")],
            long_case,
            &[9, 11, 12, 11, 11, 8],
        );
        verify(
            &[("你", "好"), ("好", "你"), ("你好", "呀")],
            long_case,
            &[9, 11, 12, 11, 11, 8],
        );
        verify(
            &[("好", "你"), ("你", "好"), ("你好", "呀")],
            long_case,
            &[13, 13, 9, 10, 8, 13, 13],
        );

        let long_case = "你好你好你好呀你好你好你";
        verify(
            &[("你", "好"), ("你好", "呀"), ("好", "你")],
            long_case,
            &[11, 11, 12, 11, 11, 8],
        );
        verify(
            &[("你", "好"), ("好", "你"), ("你好", "呀")],
            long_case,
            &[11, 11, 12, 11, 11, 8],
        );
        verify(
            &[("好", "你"), ("你", "好"), ("你好", "呀")],
            long_case,
            &[8, 13, 13, 9, 10, 8, 13, 13],
        );

        // Runs of a repeated token with overlapping candidate pairs.
        verify(&[("a", "a")], "aaaaa", &[18, 18, 1]);
        verify(&[("a", "a")], "aaaaaa", &[18, 18, 18]);

        verify(&[("a", "a"), ("aa", "a")], "aaaaa", &[18, 19]);
        verify(&[("a", "a"), ("aa", "a")], "aaaaaa", &[18, 18, 18]);
        verify(&[("a", "a"), ("aa", "a")], "aaaaaaa", &[18, 18, 19]);

        verify(&[("a", "a"), ("a", "aa")], "aaaaa", &[18, 18, 1]);
        verify(&[("a", "a"), ("a", "aa")], "aaaaaa", &[18, 18, 18]);
        verify(&[("a", "a"), ("a", "aa")], "aaaaaaa", &[18, 18, 18, 1]);

        // Like `verify`, but only the `<false>` variant is checked (via
        // `check_properly_in_utf8`). Every rule list below places a rule that consumes a
        // merged token (`你好` or `aa`) before the rule that produces it.
        let check_properly = |rules: &[(&str, &str)], seq: &str, tokens: &[u32]| {
            let dict = build_dict(&vocab, rules.iter().copied());
            check_properly_in_utf8(&dict, seq, tokens.iter().copied());
        };

        let long_case = "好你好你好呀你好你好你";
        check_properly(
            &[("你好", "呀"), ("你", "好"), ("好", "你")],
            long_case,
            &[9, 11, 11, 10, 11, 11, 8],
        );
        check_properly(
            &[("你好", "呀"), ("好", "你"), ("你", "好")],
            long_case,
            &[13, 13, 9, 10, 8, 13, 13],
        );
        check_properly(
            &[("好", "你"), ("你好", "呀"), ("你", "好")],
            long_case,
            &[13, 13, 9, 10, 8, 13, 13],
        );

        let long_case = "你好你好你好呀你好你好你";
        check_properly(
            &[("你好", "呀"), ("你", "好"), ("好", "你")],
            long_case,
            &[11, 11, 11, 10, 11, 11, 8],
        );
        check_properly(
            &[("你好", "呀"), ("好", "你"), ("你", "好")],
            long_case,
            &[8, 13, 13, 9, 10, 8, 13, 13],
        );
        check_properly(
            &[("好", "你"), ("你好", "呀"), ("你", "好")],
            long_case,
            &[8, 13, 13, 9, 10, 8, 13, 13],
        );

        check_properly(&[("aa", "a"), ("a", "a")], "aaaaa", &[18, 18, 1]);
        check_properly(&[("aa", "a"), ("a", "a")], "aaaaaa", &[18, 18, 18]);
        check_properly(&[("aa", "a"), ("a", "a")], "aaaaaaa", &[18, 18, 18, 1]);
    }
}
