#pragma once
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <span>
#include <type_traits>
#include <vector>

namespace combinatorics{

    namespace assist{
        /// Worker for for_multi_combinations_recursive. It takes f by reference, so one callable
        /// object observes every call (as in for_combinations) instead of a fresh copy per level.
        ///
        /// Enumerates every way of drawing the groups whose sizes are [r_begin, r_end) from
        /// [begin, end), one group after another, and calls f(header, p). Here p is the position
        /// right after the last drawn element, so [header, p) holds all chosen groups back to back.
        ///
        /// The range is permuted in place by swaps. idx_begin[i] is the offset (relative to
        /// begin) of the element currently swapped into begin + i for the current group.
        /// Entries for a group must start out as 0, 1, ..., r-1. They are reset to that state
        /// when the group's enumeration finishes, and the group's sub-range is then restored.
        template<typename Function_t, typename BidirIter, typename IdxIter, typename RIter>
        void multi_combinations_step(Function_t& f, BidirIter header, BidirIter begin, BidirIter end,
                                     IdxIter idx_begin, const RIter r_begin, const RIter r_end) {
            const auto r_remain = r_end - r_begin;
            if (r_remain < 0) {
                assert(false && "group-size range is reversed");
                return;
            }
            if (r_remain == 0) {
                // Every group has been filled: hand the whole selection to the callback.
                f(header, begin);
                return;
            }

            const int r = static_cast<int>(*r_begin);          // size of the current group
            const int upper = static_cast<int>(end - begin);   // elements still available
            if (r < 0 || r > upper) return;                    // this group cannot be filled
            if (r == 0) {
                // An empty group draws nothing and owns no index slots; go to the next group
                // without touching idx_begin (which already belongs to that next group).
                multi_combinations_step(f, header, begin, end, idx_begin, r_begin + 1, r_end);
                return;
            }

            for (;;) {
                // Enumerate the later groups over the elements this group left unused.
                multi_combinations_step(f, header, begin + r, end, idx_begin + r, r_begin + 1, r_end);

                // Rightmost slot not yet at its maximum value (upper - r + pos).
                int pos = r - 1;
                while (pos >= 0 && idx_begin[pos] == upper - r + pos) --pos;

                // Undo, in reverse order, the swaps that placed slots [max(pos, 0), r).
                for (int i = r - 1; i >= 0 && i >= pos; --i) {
                    std::swap(*(begin + i), *(begin + idx_begin[i]));
                }

                if (pos < 0) {
                    // Last combination of this group: the sub-range is restored; reset the
                    // indices so the group can be enumerated again on the next outer iteration.
                    std::iota(idx_begin, idx_begin + r, 0);
                    return;
                }

                // Advance slot pos and refill the following slots with the smallest offsets.
                ++idx_begin[pos];
                std::swap(*(begin + pos), *(begin + idx_begin[pos]));
                for (int i = pos + 1; i < r; ++i) {
                    idx_begin[i] = idx_begin[i - 1] + 1;
                    std::swap(*(begin + i), *(begin + idx_begin[i]));
                }
            }
        }

        /// Recursive driver behind for_multi_combinations.
        ///
        /// @param f        called as f(header, p) for every selection; [header, p) holds the
        ///                 chosen groups back to back.
        /// @param header   start of the whole range (first element of the first group).
        /// @param begin    start of the part still available to the current group.
        /// @param end      end of the whole range.
        /// @param idx_begin per-group selection indices, laid out back to back, each group's
        ///                 entries initialised to 0, 1, ..., r-1.
        /// @param r_begin, r_end  the group sizes still to be drawn.
        /// A negative or too-large group size yields no calls; a zero-size group is skipped.
        template<typename Function_t, typename BidirIter, typename IdxIter, typename RIter>
        void for_multi_combinations_recursive(Function_t f, BidirIter header, BidirIter begin, BidirIter end, IdxIter idx_begin, const RIter r_begin, const RIter r_end) {
            multi_combinations_step(f, header, begin, end, idx_begin, r_begin, r_end);
        }
    }

    /// Calls f(first, last) once for every way of drawing disjoint, distinguishable groups of
    /// sizes rs[0], rs[1], ... from [begin, end). Elements that are not drawn are simply left over.
    ///
    /// Each call sees the chosen groups laid out back to back in [begin, begin + sum(rs)):
    /// first the rs[0] elements of group 0, then the rs[1] elements of group 1, and so on.
    /// The range is permuted in place by swaps and is back in its original order after a normal
    /// return. If f throws, the range is left permuted.
    ///
    /// @param f      callable invoked as f(first, last); it must not reorder the range.
    /// @param begin, end  the pool of elements. `end - begin` and `begin + k` are used, so the
    ///               iterators must in practice be random access.
    /// @param rs     group sizes. Zero-size groups are ignored. If any size is negative or the
    ///               sizes add up to more than end - begin, f is never called. An empty rs (or
    ///               all-zero sizes) yields exactly one call with an empty range.
    /// @note R is deduced from std::span<const R>, so pass a span explicitly
    ///       (e.g. std::span<const int>(sizes)); other containers are not deduced.
    template<typename Function_t, typename BidirIter, typename R>
    void for_multi_combinations(Function_t f, BidirIter begin, BidirIter end, std::span<const R> rs) {
        // Keep only the non-empty groups. An empty group adds nothing to the range handed to f,
        // so dropping it changes no call, and the enumeration never sees a zero-size group.
        std::vector<R> groups;
        groups.reserve(rs.size());
        long long total = 0;                       // number of elements drawn in every call
        for (const R r : rs) {
            if (r < 0) return;                     // a negative group size admits no selection
            if (r == 0) continue;
            groups.push_back(r);
            total += static_cast<long long>(r);
        }
        if (total > static_cast<long long>(end - begin)) return;  // not enough elements

        // Per-group selection indices, laid out back to back, each group starting at 0..r-1:
        // [0..groups[0]-1; 0..groups[1]-1; ...; 0..groups.back()-1]. Owned by a vector so the
        // buffer is released even if f throws.
        std::vector<int> indices(static_cast<std::size_t>(total));
        int* p_idx = indices.data();
        for (const R r : groups) {
            std::iota(p_idx, p_idx + r, 0);
            p_idx += r;
        }

        const R* const r_first = groups.data();
        assist::for_multi_combinations_recursive(f, begin, begin, end, indices.data(), r_first, r_first + groups.size());
    }

    /// Calls f(begin, begin + r) once for every r-element combination of [begin, end).
    ///
    /// The range is permuted in place (by swaps) so that the current combination occupies
    /// [begin, begin + r). The selected elements appear in their original relative order.
    /// Combinations are produced in lexicographic order of their original positions. After a
    /// normal return, [begin, end) is back in its original order. If f throws, the range is
    /// left permuted.
    ///
    /// @param f      callable invoked as f(first, last); it must not reorder the range.
    /// @param begin, end  the pool of elements. `end - begin` and `begin + k` are used, so the
    ///               iterators must in practice be random access.
    /// @param r      combination size. r == 0 yields exactly one (empty) combination; r < 0 or
    ///               r > end - begin yields none.
    template<typename Function_t, typename BidirIter>
    void for_combinations(Function_t f, BidirIter begin, BidirIter end, const int r) {
        const int upper = static_cast<int>(end - begin); // total number of elements in the pool
        if (r < 0 || r > upper) return;                  // C(upper, r) == 0: nothing to visit

        // indices[i] is the original position of the element currently placed at begin + i.
        std::vector<int> indices(static_cast<std::size_t>(r));
        std::iota(indices.begin(), indices.end(), 0);

        for (;;) {
            f(begin, begin + r);

            // Rightmost slot that has not reached its maximum (upper - r + pos). The slots after
            // it hold the consecutive tail of positions and must be reset.
            int pos = r - 1;
            while (pos >= 0 && indices[pos] == upper - r + pos) --pos;

            // Undo, in reverse order, the swaps that placed slots [max(pos, 0), r).
            for (int i = r - 1; i >= 0 && i >= pos; --i) {
                std::swap(*(begin + i), *(begin + indices[i]));
            }

            // Every slot is at its maximum: that was the last combination, and the undo loop
            // above has just restored the original order of the whole range.
            if (pos < 0) return;

            // Advance slot pos, then refill the following slots with the smallest positions.
            ++indices[pos];
            std::swap(*(begin + pos), *(begin + indices[pos]));
            for (int i = pos + 1; i < r; ++i) {
                indices[i] = indices[i - 1] + 1;
                std::swap(*(begin + i), *(begin + indices[i]));
            }
        }
    }

    namespace assist{
        /// Binomial coefficient C(m, n) evaluated as a constant expression.
        /// C(m, 0) == C(m, m) == 1 for every m; the result is 0 when n < 0 or n > m.
        /// Each step multiplies by (m - n + i) and divides by i after cancelling their common
        /// factor. Every intermediate value is a smaller binomial C(m - n + i, i), so the
        /// computation overflows only if the final result itself does not fit in int64_t.
        constexpr int64_t binomial_coefficient(int64_t m, int64_t n) {
            if (n == 0 || n == m) return 1;
            if (n < 0 || n > m) return 0;
            if (n > m - n) n = m - n;   // C(m, n) == C(m, m - n): fewer iterations
            int64_t result = 1;         // after step i: result == C(m - n + i, i)
            for (int64_t i = 1; i <= n; ++i) {
                // i divides result * (m - n + i); after removing g = gcd(result, i),
                // i / g is coprime with result / g and therefore divides (m - n + i).
                const int64_t g = std::gcd(result, i);
                result = (result / g) * ((m - n + i) / (i / g));
            }
            return result;
        }
    }

    /// Compile-time binomial coefficient: COMB<M, N>::value == C(M, N).
    /// Well-defined for every pair, including COMB<0, 0> (== 1) and N > M or N < 0 (== 0).
    /// The previous pair of partial specializations was ambiguous for <0, 0>, and the previous
    /// Pascal recursion never terminated for N > M.
    template< int M, int N >
    struct COMB
    {
        constexpr static int64_t value = assist::binomial_coefficient(M, N);
    };

    // template<int M, int N, int... Ns>
    // struct MULTI_COMB{
    //     constexpr static int64_t value = COMB<M,N>::value * MULTI_COMB<M-N, Ns...>::value ;
    // };
    // template<int M, int N>
    // struct MULTI_COMB< M, N>
    // {
    //     constexpr static int64_t value = COMB<M, N>::value;
    // };
// ////////////////////////////int
//     namespace assist{
//         template<int M, int R, const int Ns[R], int I>
//         struct MULTI_COMB_
//         {
//             constexpr static int64_t value = COMB<M, Ns[I-1]>::value *  MULTI_COMB_<M-Ns[I-1], R, Ns, I-1>::value ;
//         };
//         template<int M, int R, const int Ns[R]>
//         struct MULTI_COMB_<M, R, Ns, 0>
//         {
//             constexpr static int64_t value = 1;
//         };
//     }
// ////////////////////////////array

//     template<int M, int R, const int Ns[R]>
//     struct MULTI_COMB
//     {
//         constexpr static int64_t value = assist::MULTI_COMB_<M, R, Ns, R>::value ;
//     };

////////////////////////////int
    // template<int M, std::array Ns, int I>
    // struct MULTI_COMB
    // {
    //     static_assert(Ns.size()>=I);
    //     constexpr static int64_t value = COMB<M, Ns[I-1]>::value *  MULTI_COMB<M-Ns[I-1], Ns, I-1>::value ;
    // };
    // template<int M, std::array Ns>
    // struct MULTI_COMB<M, Ns, 0>
    // {
    //     constexpr static int64_t value = 1;
    // };
////////////////////////////array

    template<int M, std::array Ns, int B = 0, int E = Ns.size()>
    struct MULTI_COMB
    {
        static_assert(Ns.size()>=E && B<=E);
        constexpr static int64_t value = COMB<M, Ns[E-1]>::value *  MULTI_COMB<M-Ns[E-1], Ns, B, E-1>::value ;
    };
    template<int M, std::array Ns, int B>
    struct MULTI_COMB<M, Ns, B, B>
    {
        constexpr static int64_t value = 1;
    };

}


//////////////////////////////////////////test
// #include <iostream>
// #include <functional>
// #include <vector>
// #include <array>
// #include <chrono>
// int main(){
//     const int card_num = 52;
//     int deck[card_num];
//     const int r = 2;
//     std::iota(deck, deck + card_num, 0);
//     int cnt = 0;
//     std::vector<std::array<int,7>> vecs;

//     auto f1 = [&i = cnt](auto first, auto second){
//                                                  ++i;
//                                                  std::cout<<i<<" : ";
//                                                  for(auto it = first; it!=second; ++it)
//                                                      std::cout<<*it<<' ';
//                                                  std::cout<<'\n';
//                                              };
//     auto f2 = [&vecs](auto first, auto second){
//                                                  std::array<int,7> temp;
//                                                  std::copy(first, second, temp.begin());
//                                                  vecs.push_back(temp);
//                                              };
//     auto f3 = [&i = cnt](auto first, auto second){
//                                                  ++i;
//                                              };

//     auto start_t = std::chrono::high_resolution_clock::now();
//     combinatorics::for_multi_combinations( f3
//                                          , deck
//                                          , deck + card_num
//                                          , {2,2,3}
//                                          );

//     auto end_t = std::chrono::high_resolution_clock::now();
//     auto duration_t = std::chrono::duration_cast<std::chrono::milliseconds>(end_t - start_t).count();
//     std::cout << "\nParallel Duration: " << duration_t << " ms" << "\n\n";

//     start_t = std::chrono::high_resolution_clock::now();
//     for(int j = cnt; j>0; --j){
//         std::swap(deck[0], deck[1]);
//     }
//     end_t = std::chrono::high_resolution_clock::now();
//     duration_t = std::chrono::duration_cast<std::chrono::milliseconds>(end_t - start_t).count();
//     std::cout << "\nParallel Duration: " << duration_t << " ms" << "\n\n";

// }