use std::iter::FusedIterator;

use derive_more::Deref;
use rapidhash::{HashMapExt, RapidHashMap};
use thiserror::Error;

use crate::{
    Dictionary, RuleId, TokenId, bpe_with_heap_last_merge, dict::RuleIdVec, typed_vec::TypedVec,
    vocab::TokenIdVec,
};

/// Errors returned by [`NormalizedDict::new`] when a dictionary cannot be normalized.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum NormalizedDictBuildError {
    /// A token would receive two different atomic-token sequences, e.g. because two rules
    /// producing it expand to different sequences, or because it is marked atomic while a
    /// rule also produces it.
    #[error("multiple atomic token sequences for token {token_id} ({seq_a:?} vs {seq_b:?})")]
    MultipleAtomicTokenSeq {
        /// Token whose atomic sequence is ambiguous.
        token_id: TokenId,
        /// Sequence that was already recorded for the token.
        seq_a: Vec<TokenId>,
        /// Conflicting sequence derived afterwards.
        seq_b: Vec<TokenId>,
    },
    /// The two `bpe_with_heap_last_merge` variants disagree when re-encoding the atomic
    /// sequence of a token.
    #[error("improper rules for token {token_id} (proper result: {proper:?})")]
    ImproperDict {
        /// Token whose atomic sequence was being re-encoded.
        token_id: TokenId,
        /// Token sequence returned by the `<false>` variant of `bpe_with_heap_last_merge`.
        proper: Vec<TokenId>,
    },
}

/// A [`Dictionary`] together with a per-token priority table computed by
/// [`NormalizedDict::new`].
///
/// Dereferences to the wrapped [`Dictionary`].
#[derive(Clone, Debug, Deref)]
pub struct NormalizedDict {
    /// The wrapped dictionary (target of `Deref`).
    #[deref]
    dict: Dictionary,
    /// Priority of every token, indexed by token id:
    ///
    /// * `RuleId::MAX` — the token has no priority (it is not canonical);
    /// * `ATOMIC_TOKEN_PRIORITY + token_id` — the token is atomic;
    /// * any other value — the id of the canonical rule that produces the token.
    pub(crate) priorities: TypedVec<TokenId, RuleId>,
    /// `(pre, suc)` -> id of the canonical rule merging that pair. Only stored in test
    /// builds so that tests can inspect which rules were accepted.
    #[cfg(test)]
    pub(crate) canonical_rules: RapidHashMap<(TokenId, TokenId), RuleId>,
}

/// First priority value reserved for atomic tokens: `(RuleId::MAX >> 1) + 1`.
///
/// Priorities below this value are ids of rules; an atomic token `t` is assigned
/// `ATOMIC_TOKEN_PRIORITY + t`.
pub(crate) const ATOMIC_TOKEN_PRIORITY: RuleId = {
    let mut base = RuleId::MAX;
    let half = base.inner() >> 1;
    *base.inner_mut() = half + 1;
    base
};

/// Recovers the id of an atomic token from its priority, i.e. undoes the
/// `ATOMIC_TOKEN_PRIORITY + token_id` encoding.
///
/// `rule_id` must be an atomic priority; this is only checked in debug builds.
#[inline(always)]
fn to_atomic_token_id(rule_id: RuleId) -> TokenId {
    debug_assert!(rule_id >= ATOMIC_TOKEN_PRIORITY);
    let offset = rule_id - ATOMIC_TOKEN_PRIORITY;
    TokenId::new(offset.inner())
}

impl NormalizedDict {
    /// Builds a [`NormalizedDict`] from `dict`, treating as atomic every non-empty token for
    /// which `is_atomic(&dict, token_id, token_bytes)` returns `true`.
    ///
    /// The construction runs in four passes:
    ///
    /// 1. Every atomic token gets the priority `ATOMIC_TOKEN_PRIORITY + token_id` and the
    ///    one-element atomic sequence `[token_id]`.
    /// 2. Atomic sequences are propagated through the rules, visiting tokens in order of
    ///    increasing length, so that a token produced by a rule gets the concatenation of
    ///    the sequences of the rule's operands.
    /// 3. The atomic sequence of every such token is re-encoded with
    ///    `bpe_with_heap_last_merge` to validate the dictionary and to record which rule
    ///    performs the final merge.
    /// 4. Rules are scanned in id order; a rule that passes the conflict check becomes the
    ///    priority of its merged token and is recorded as canonical.
    ///
    /// Tokens that never receive a priority keep `RuleId::MAX`.
    ///
    /// # Errors
    ///
    /// * [`NormalizedDictBuildError::MultipleAtomicTokenSeq`] if a token would get two
    ///   different atomic sequences.
    /// * [`NormalizedDictBuildError::ImproperDict`] if, for a token that the `<true>`
    ///   variant of `bpe_with_heap_last_merge` re-encodes to itself, the `<false>` variant
    ///   returns a different result.
    pub fn new<F: FnMut(&Dictionary, TokenId, &[u8]) -> bool>(
        dict: Dictionary,
        mut is_atomic: F,
    ) -> Result<Self, NormalizedDictBuildError> {
        let capacity = dict.num_of_tokens();
        // `RuleId::MAX` means "no priority yet".
        let mut priorities = TypedVec::new_with(RuleId::MAX, capacity);
        // `(pre, suc)` -> id of the canonical rule merging that pair; filled in pass 4.
        let mut canonical_rules = RapidHashMap::with_capacity(capacity.as_usize());

        // Per token: the atomic tokens it expands to. Empty means "none known".
        let mut atomic_seqs = TypedVec::new_with(TokenIdVec::new(), capacity);

        // Pass 1: seed atomic tokens. Empty tokens are never offered to `is_atomic`.
        for (token_id, priority) in priorities.enumerate_mut() {
            let token = &dict[token_id];
            if token.is_empty() {
                continue;
            }
            if is_atomic(&dict, token_id, token) {
                atomic_seqs[token_id].push(token_id);
                // The encoding below only stays distinguishable from rule ids while token
                // ids are smaller than `ATOMIC_TOKEN_PRIORITY` (checked in debug builds only).
                debug_assert!(token_id.as_usize() < ATOMIC_TOKEN_PRIORITY.as_usize());
                let mut p = ATOMIC_TOKEN_PRIORITY;
                *p.inner_mut() += token_id.inner();
                *priority = p;
            }
        }

        // Pass 2: group rules by the token they produce ...
        let mut token_to_rules = TypedVec::new_with(RuleIdVec::new(), capacity);
        for (rule_id, rule) in dict.rules.enumerate() {
            token_to_rules[rule.merged].push(rule_id);
        }
        // ... and visit tokens from shortest to longest, so that the operands of a rule
        // (which are shorter than its result, assuming non-empty operands) have been
        // handled before the merged token.
        for token_id in {
            let mut order: Vec<_> = dict.tokens.keys().collect();
            order.sort_by_key(|&i| dict[i].len());
            order
        } {
            for &rule_id in &token_to_rules[token_id] {
                let rule = &dict[rule_id];
                // A rule only contributes if both operands expand to atomic tokens.
                if atomic_seqs[rule.pre].is_empty() || atomic_seqs[rule.suc].is_empty() {
                    continue;
                }
                let mut seq = atomic_seqs[rule.pre].clone();
                seq.extend_from_slice(&atomic_seqs[rule.suc]);
                let slot = &mut atomic_seqs[token_id];
                // Every rule producing this token (and the atomic seed, if any) must agree
                // on a single atomic sequence.
                if !slot.is_empty() && *slot != seq {
                    return Err(NormalizedDictBuildError::MultipleAtomicTokenSeq {
                        token_id,
                        seq_a: slot.to_vec(),
                        seq_b: seq.to_vec(),
                    });
                }
                *slot = seq;
            }
        }
        drop(token_to_rules);

        // Pass 3: `validation[rule]` is set when `rule` is reported as the last merge of
        // some token's re-encoding; pass 4 clears the flags again as a cross-check.
        let mut validation = TypedVec::new_with(false, dict.num_of_rules());
        for (token_id, seq) in atomic_seqs.enumerate() {
            if seq.is_empty() {
                continue;
            }
            // NOTE(review): the meaning of the const flag is defined by
            // `bpe_with_heap_last_merge`, outside this file; going by the variable names,
            // `<true>` is presumably the "improper" and `<false>` the "proper" variant.
            let improper = bpe_with_heap_last_merge::<true>(&dict, seq.to_vec());
            // Only tokens that re-encode to exactly themselves are checked further.
            if improper.0 != vec![token_id] {
                continue;
            }
            let proper = bpe_with_heap_last_merge::<false>(&dict, seq.to_vec());
            if proper != improper {
                return Err(NormalizedDictBuildError::ImproperDict {
                    token_id,
                    proper: proper.0,
                });
            }
            if let Some(last_rule_id) = proper.1 {
                validation[last_rule_id] = true;
            }
        }
        drop(atomic_seqs);

        // Pass 4: decide, in rule-id order, which rules are canonical.
        'outer: for (id, rule) in dict.rules.enumerate() {
            let mut left = priorities[rule.pre];
            let mut right = priorities[rule.suc];
            // Skip rules whose result already has a priority (atomic, or produced by an
            // earlier canonical rule) and rules with an operand that has none.
            if priorities[rule.merged] != RuleId::MAX || left == RuleId::MAX || right == RuleId::MAX
            {
                continue;
            }
            // Walk down the boundary between `rule.pre` and `rule.suc`. `left` / `right` are
            // the priorities of the tokens currently adjacent to that boundary; a value below
            // `ATOMIC_TOKEN_PRIORITY` is the id of the rule that produced the token, so the
            // token can still be split. The loop ends once both sides are atomic.
            while left < ATOMIC_TOKEN_PRIORITY || right < ATOMIC_TOKEN_PRIORITY {
                // `(u, v)` is the token pair adjacent to the boundary after splitting the
                // side(s) chosen below.
                let (u, v): (TokenId, TokenId);
                if left == right {
                    // Both sides were produced by the same rule: split both.
                    u = dict[left].suc;
                    v = dict[right].pre;
                } else if left >= ATOMIC_TOKEN_PRIORITY {
                    // Left is atomic, so only the right token can be split.
                    u = to_atomic_token_id(left);
                    v = dict[right].pre;
                    debug_assert_eq!(left, priorities[u]);
                } else if right >= ATOMIC_TOKEN_PRIORITY {
                    // Right is atomic, so only the left token can be split.
                    u = dict[left].suc;
                    v = to_atomic_token_id(right);
                    debug_assert_eq!(right, priorities[v]);
                } else if left > right {
                    // Both were produced by rules: split the one with the larger rule id
                    // (here the left) and keep the other token whole.
                    u = dict[left].suc;
                    v = dict[right].merged;
                    debug_assert_eq!(right, priorities[v]);
                } else {
                    // Same as above with the right token being the one that is split.
                    u = dict[left].merged;
                    v = dict[right].pre;
                    debug_assert_eq!(left, priorities[u]);
                }
                // If a canonical rule `mid` merges `(u, v)` and its id is small enough
                // compared with the rule that was just split, the current rule is rejected.
                // NOTE(review): presumably because `mid` would be applied first during
                // encoding and consume the boundary — confirm. The comparison is strict
                // (`<`) when the left token was split and non-strict (`<=`) when only the
                // right token was split.
                if let Some(&mid) = canonical_rules.get(&(u, v)) {
                    debug_assert!(priorities[u] >= ATOMIC_TOKEN_PRIORITY || mid > priorities[u]);
                    debug_assert!(priorities[v] >= ATOMIC_TOKEN_PRIORITY || mid > priorities[v]);
                    // `right == priorities[v]` holds exactly when `v` is still the whole
                    // right token, i.e. the right side was not split in this step.
                    if left == right || right == priorities[v] {
                        if mid < left {
                            continue 'outer;
                        }
                    } else if mid <= right {
                        continue 'outer;
                    }
                }
                // Descend: each non-atomic side now continues from the token adjacent to
                // the boundary (unchanged for a side that was kept whole).
                if left < ATOMIC_TOKEN_PRIORITY {
                    left = priorities[u];
                }
                if right < ATOMIC_TOKEN_PRIORITY {
                    right = priorities[v];
                }
                debug_assert_ne!(left, RuleId::MAX);
                debug_assert_ne!(right, RuleId::MAX);
            }
            // No conflict found: the rule is canonical for its merged token.
            priorities[rule.merged] = id;
            let res = canonical_rules.insert((rule.pre, rule.suc), id);
            debug_assert!(res.is_none());
            // Cross-check against pass 3: a canonical rule must have been seen as the last
            // merge of a token's re-encoding.
            debug_assert!(validation[id]);
            validation[id] = false;
        }

        // Conversely, every rule flagged in pass 3 must have been accepted above.
        debug_assert!(validation.into_iter().all(|i| !i));

        Ok(Self {
            dict,
            priorities,
            #[cfg(test)]
            canonical_rules,
        })
    }

    /// Builds a [`NormalizedDict`] whose atomic tokens are exactly the single-byte tokens.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`NormalizedDict::new`].
    #[inline]
    pub fn new_in_bytes(dict: Dictionary) -> Result<Self, NormalizedDictBuildError> {
        Self::new(dict, |_, _, bytes| matches!(bytes, [_]))
    }

    /// Builds a [`NormalizedDict`] whose atomic tokens are the tokens that are valid UTF-8
    /// and consist of exactly one `char`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`NormalizedDict::new`].
    #[inline]
    pub fn new_in_utf8(dict: Dictionary) -> Result<Self, NormalizedDictBuildError> {
        Self::new(dict, |_, _, bytes| {
            // Tokens longer than four bytes are rejected before attempting to decode them.
            bytes.len() <= 4
                && match std::str::from_utf8(bytes) {
                    Ok(text) => {
                        let mut chars = text.chars();
                        chars.next().is_some() && chars.next().is_none()
                    }
                    Err(_) => false,
                }
        })
    }

    /// Returns the priority stored for `token_id`, or `RuleId::MAX` when `token_id` is not
    /// covered by the priority table.
    #[inline(always)]
    pub fn priority(&self, token_id: TokenId) -> RuleId {
        match self.priorities.get(token_id) {
            Some(&stored) => stored,
            None => RuleId::MAX,
        }
    }

    /// Returns `true` if `token_id` is canonical and its priority lies in the atomic range
    /// (`>= ATOMIC_TOKEN_PRIORITY`).
    #[inline(always)]
    pub fn is_atomic(&self, token_id: TokenId) -> bool {
        // Rule out the `RuleId::MAX` sentinel (and out-of-range ids) before indexing.
        if !self.is_canonical(token_id) {
            return false;
        }
        self.priorities[token_id] >= ATOMIC_TOKEN_PRIORITY
    }

    /// Returns `true` if `token_id` has a priority, i.e. its entry is not `RuleId::MAX`.
    #[inline(always)]
    pub fn is_canonical(&self, token_id: TokenId) -> bool {
        let assigned = self.priority(token_id);
        assigned != RuleId::MAX
    }

    /// Iterates over all tokens in id order, yielding the token's bytes when the token is
    /// canonical and an empty slice otherwise.
    #[inline(always)]
    pub fn iter_canonical_or_empty_tokens(
        &self,
    ) -> impl DoubleEndedIterator<Item = &[u8]> + ExactSizeIterator + FusedIterator {
        const EMPTY: &[u8] = &[];
        self.tokens.enumerate().map(|(id, raw)| {
            if self.is_canonical(id) {
                raw.as_ref()
            } else {
                EMPTY
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        Dictionary, NormalizedDict, NormalizedDictBuildError, RuleId, Vocab, bpe_with_heap,
        test_utils::{bytes_into_tokens, utf8_into_tokens},
    };

    /// Builds a [`Dictionary`] over a copy of `vocab` from `(pre, suc)` token pairs,
    /// panicking if the dictionary cannot be constructed.
    fn build_dict<T, R>(vocab: &Vocab, rules: R) -> Dictionary
    where
        T: AsRef<[u8]>,
        R: IntoIterator<Item = (T, T)>,
    {
        let owned_vocab = vocab.clone();
        Dictionary::new_from_token_pair(owned_vocab, rules).unwrap()
    }

    /// Normalizes `dict` with single-byte atomic tokens and cross-checks the result.
    ///
    /// Returns `None` when the dictionary is rejected as improper; any other build error
    /// fails the test. For every rule, the merged token must not be atomic, and it must be
    /// canonical exactly when byte-level BPE re-encodes its bytes to that single token.
    fn build_in_bytes(dict: &Dictionary) -> Option<NormalizedDict> {
        let dict = match NormalizedDict::new_in_bytes(dict.clone()) {
            Ok(dict) => dict,
            Err(NormalizedDictBuildError::ImproperDict { .. }) => return None,
            // Carry the error in the panic message instead of a separate `dbg!` print.
            Err(e) => panic!("unexpected error while building a byte-level NormalizedDict: {e:?}"),
        };
        for rule in &dict.rules {
            let token_id = rule.merged;
            assert!(!dict.is_atomic(token_id));
            let seq = &dict[token_id];
            let res = bpe_with_heap::<false>(&dict, bytes_into_tokens(&dict, seq, 0usize));
            assert_eq!(
                dict.is_canonical(token_id),
                res == vec![token_id],
                "canonical flag of token {token_id:?} disagrees with its byte-level re-encoding",
            );
        }
        Some(dict)
    }

    /// Normalizes `dict` with single-`char` atomic tokens and cross-checks the result.
    ///
    /// Returns `None` when the dictionary is rejected as improper; any other build error
    /// fails the test. For every rule, a merged token that is not valid UTF-8 must not be
    /// canonical; otherwise it must not be atomic, and it must be canonical exactly when
    /// char-level BPE re-encodes its text to that single token.
    fn build_in_utf8(dict: &Dictionary) -> Option<NormalizedDict> {
        let dict = match NormalizedDict::new_in_utf8(dict.clone()) {
            Ok(dict) => dict,
            Err(NormalizedDictBuildError::ImproperDict { .. }) => return None,
            // Carry the error in the panic message instead of a separate `dbg!` print.
            Err(e) => panic!("unexpected error while building a UTF-8 NormalizedDict: {e:?}"),
        };
        for rule in &dict.rules {
            let token_id = rule.merged;
            let Ok(seq) = std::str::from_utf8(&dict[token_id]) else {
                assert!(!dict.is_canonical(token_id));
                continue;
            };
            assert!(!dict.is_atomic(token_id));
            let res = bpe_with_heap::<false>(&dict, utf8_into_tokens(&dict, seq, 0usize));
            assert_eq!(
                dict.is_canonical(token_id),
                res == vec![token_id],
                "canonical flag of token {token_id:?} disagrees with its UTF-8 re-encoding",
            );
        }
        Some(dict)
    }

    fn canonical_rules<R: IntoIterator<Item = u32>>(dict: &NormalizedDict, rules: R) {
        let mut rules: Vec<_> = rules.into_iter().map(RuleId::new).collect();
        rules.sort();
        let mut expected: Vec<_> = dict.canonical_rules.values().copied().collect();
        expected.sort();
        assert_eq!(rules, expected);
    }

    /// Checks the canonical rule set of `dict` under byte-level and then UTF-8 atomic
    /// tokens; a normalization rejected as improper is skipped.
    fn build_and_test_rules<R: IntoIterator<Item = u32> + Clone>(dict: &Dictionary, rules: R) {
        if let Some(by_bytes) = build_in_bytes(dict) {
            canonical_rules(&by_bytes, rules.clone());
        }
        let Some(by_utf8) = build_in_utf8(dict) else {
            return;
        };
        canonical_rules(&by_utf8, rules);
    }

    #[test]
    fn test_normalized_dict() {
        // Rules are written as `(pre, suc)` token pairs; the expected values are the ids of
        // the rules that must end up canonical (presumably positions in the rule list).
        type Pairs<'a> = &'a [(&'static str, &'static str)];

        // Expects `canonical` under both byte-level and UTF-8 atomic tokens.
        fn check_both<E>(vocab: &Vocab, rules: Pairs<'_>, canonical: E)
        where
            E: IntoIterator<Item = u32> + Clone,
        {
            let dict = build_dict(vocab, rules.iter().copied());
            build_and_test_rules(&dict, canonical);
        }

        // Expects the UTF-8 normalization to succeed with exactly `canonical`.
        fn check_utf8<E>(vocab: &Vocab, rules: Pairs<'_>, canonical: E)
        where
            E: IntoIterator<Item = u32>,
        {
            let dict = build_dict(vocab, rules.iter().copied());
            let normalized = build_in_utf8(&dict).unwrap();
            canonical_rules(&normalized, canonical);
        }

        // Expects the UTF-8 normalization to be rejected as improper.
        fn rejected_in_utf8(vocab: &Vocab, rules: Pairs<'_>) {
            let dict = build_dict(vocab, rules.iter().copied());
            assert!(build_in_utf8(&dict).is_none());
        }

        let vocab = Vocab::new([
            b"" as &[_],
            b"a",
            b"b",
            b"c",
            b"d",
            b"cd",
            b"bcd",
            b"abcd",
            "你".as_bytes(),
            "好".as_bytes(),
            "呀".as_bytes(),
            "你好".as_bytes(),
            "你好呀".as_bytes(),
            "好你".as_bytes(),
            b"\xe4",
            b"\xbd",
            b"\xa0",
            b"\xbd\xa0",
            b"aa",
            b"aaa",
            b"aaaa",
            b"aaaaa",
        ])
        .unwrap();

        // A simple right-nested chain.
        check_both(&vocab, &[("c", "d"), ("b", "cd"), ("a", "bcd")], [0, 1, 2]);

        // Tokens that are not valid UTF-8 on their own: byte-level only.
        let dict = build_dict(
            &vocab,
            [(b"\xbd" as &[_], b"\xa0" as &[_]), (b"\xe4", b"\xbd\xa0")],
        );
        let normalized = build_in_bytes(&dict).unwrap();
        canonical_rules(&normalized, [0, 1]);

        // Runs of `a` with different rule orders.
        check_both(&vocab, &[("aa", "a"), ("a", "a")], [1]);
        check_both(&vocab, &[("a", "aa"), ("a", "a")], [1]);
        check_both(&vocab, &[("a", "a"), ("aa", "a")], [0, 1]);
        check_both(&vocab, &[("a", "a"), ("a", "aa")], [0]);
        check_both(
            &vocab,
            &[
                ("a", "a"),
                ("aa", "a"),
                ("a", "aa"),
                ("aa", "aa"),
                ("a", "aaa"),
                ("aaa", "a"),
            ],
            [0, 1, 3],
        );
        check_both(&vocab, &[("a", "a"), ("aa", "a"), ("aaa", "a")], [0, 1]);
        check_both(&vocab, &[("a", "a"), ("aa", "a"), ("aa", "aa")], [0, 1, 2]);
        check_both(&vocab, &[("a", "a"), ("aa", "aa"), ("aa", "a")], [0, 1, 2]);
        check_both(
            &vocab,
            &[
                ("a", "a"),
                ("aa", "aa"),
                ("aa", "a"),
                ("aaa", "aa"),
                ("aa", "aaa"),
                ("aaaa", "a"),
            ],
            [0, 1, 2, 5],
        );
        check_both(
            &vocab,
            &[
                ("a", "a"),
                ("aa", "a"),
                ("aa", "aa"),
                ("aaa", "aa"),
                ("aa", "aaa"),
                ("aaaa", "a"),
            ],
            [0, 1, 2, 4],
        );

        // Multi-byte characters: UTF-8 only.
        check_utf8(&vocab, &[("你", "好"), ("你好", "呀")], [0, 1]);
        check_utf8(&vocab, &[("你", "好"), ("你好", "呀"), ("好", "你")], [0, 1, 2]);
        check_utf8(&vocab, &[("你", "好"), ("好", "你"), ("你好", "呀")], [0, 1, 2]);
        check_utf8(&vocab, &[("好", "你"), ("你", "好"), ("你好", "呀")], [0, 1, 2]);
        rejected_in_utf8(&vocab, &[("你好", "呀"), ("你", "好"), ("好", "你")]);
        rejected_in_utf8(&vocab, &[("你好", "呀"), ("好", "你"), ("你", "好")]);
        rejected_in_utf8(&vocab, &[("好", "你"), ("你好", "呀"), ("你", "好")]);

        let vocab = Vocab::new([
            b"" as &[_],
            b"a",
            b"abc",
            b"abcde",
            b"abcdef",
            b"b",
            b"ba",
            b"bc",
            b"bcdef",
            b"c",
            b"cd",
            b"cde",
            b"cdefg",
            b"d",
            b"de",
            b"def",
            b"e",
            b"ef",
            b"efg",
            b"f",
            b"g",
        ])
        .unwrap();

        // Two orderings of the same thirteen rules; all of them must stay canonical.
        check_both(
            &vocab,
            &[
                ("b", "c"),
                ("e", "f"),
                ("d", "e"),
                ("c", "d"),
                ("d", "ef"),
                ("b", "a"),
                ("a", "bc"),
                ("abc", "de"),
                ("abc", "def"),
                ("bc", "def"),
                ("c", "de"),
                ("ef", "g"),
                ("cd", "efg"),
            ],
            0..13,
        );
        check_both(
            &vocab,
            &[
                ("b", "c"),
                ("e", "f"),
                ("d", "e"),
                ("c", "d"),
                ("d", "ef"),
                ("a", "bc"),
                ("b", "a"),
                ("abc", "de"),
                ("abc", "def"),
                ("bc", "def"),
                ("c", "de"),
                ("ef", "g"),
                ("cd", "efg"),
            ],
            0..13,
        );
    }

    #[test]
    fn test_normalized_dict_invalid() {
        let vocab = Vocab::new([b"a" as &[_], b"aa"]).unwrap();
        let dict = Dictionary::new_from_id_pair(vocab, [(0usize, 0usize)]).unwrap();

        // With only single-byte tokens atomic, the dictionary is accepted.
        let byte_level = NormalizedDict::new(dict.clone(), |_, _, bytes| bytes.len() == 1);
        assert!(byte_level.is_ok());

        // Declaring every token atomic must be rejected.
        let everything_atomic = NormalizedDict::new(dict, |_, _, _| true);
        assert!(everything_atomic.is_err());
    }
}
