#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>  // for std::distance
#include <limits>
#include <numeric>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>  // std::pair
#include <vector>
#include "Simplex_tree_multi_interface.h"
#include <gudhi/Simplex_tree/multi_filtrations/Finitely_critical_filtrations.h>
#include <oneapi/tbb/parallel_for.h>
#include <oneapi/tbb/enumerable_thread_specific.h>
#include <oneapi/tbb/task_arena.h>
#include "tensor/tensor.h"
#include "multi_parameter_rank_invariant/persistence_slices.h"


namespace Gudhi::multiparameter::hilbert_function{

// Filtration value of a simplex on one horizontal line of a 2D slice of the grid.
//
// The slice is spanned by parameters i (horizontal) and j (vertical); every other parameter k
// is pinned to fixed_values[k]. The line is the part of the slice at position `height` along j.
//
// @param x            multi-filtration value of the simplex (one entry per parameter).
// @param height       position of the line along parameter j.
// @param i, j         the two free parameters of the slice.
// @param fixed_values pinned value of every parameter; must have at least x.size() entries,
//                     entries i and j are ignored.
// @return x[i] when x[k] <= fixed_values[k] for every k outside {i, j} and x[j] <= height;
//         +infinity of value_type otherwise (the simplex never appears on this line).
//
// TODO : this function is ugly
template<typename fixed_values_type, typename indices_type>
inline value_type horizontal_line_filtration2(const std::vector<value_type> &x, value_type height, indices_type i, indices_type j, const std::vector<fixed_values_type>& fixed_values){
	// "Never appears" is encoded as +infinity, expressed in the function's own return type.
	static_assert(std::numeric_limits<value_type>::has_infinity, "horizontal_line_filtration2 needs a value_type with an infinity.");
	constexpr value_type never_appears = std::numeric_limits<value_type>::infinity();

	for (indices_type k = 0u; k < static_cast<indices_type>(x.size()); k++){
		if (k == i || k == j) continue; // coordinate in the plane
		if (x[k] > fixed_values[k]) // simplex appears after the slice
			return never_appears;
	}
	// The simplex lies in the slice; it is on the line only if x[j] does not exceed `height`.
	return x[j] <= height ? x[i] : never_appears;
}


// Per-thread scratch state of the Hilbert function computation:
//  - first : a 1-parameter simplex tree whose filtration values are reassigned for every line;
//  - second: a coordinate buffer into the output tensor (entry 0 = degree index,
//            entry k+1 = grid coordinate of parameter k).
template<typename index_type>
using hilbert_thread_data = tbb::enumerable_thread_specific<
	std::pair<Simplex_tree_std, std::vector<index_type>>
>;
/**
 * Computes the Hilbert function of `st_multi` on the 2D slice of the grid spanned by
 * parameters `i` (horizontal axis) and `j` (vertical axis), every other parameter k being
 * pinned to `fixed_values[k]`, and writes it into `out`.
 *
 * Each height h in [0, grid_shape[j+1]) is processed in parallel (tbb::parallel_for). The
 * calling thread's 1-parameter simplex tree receives, for every simplex, the value returned by
 * horizontal_line_filtration2. compute_dgms then yields one barcode per entry of `degrees`.
 * Each bar [birth, death) with birth <= I = grid_shape[i+1] is written on the line as follows:
 *  - mobius_inverion == false: +1 on the cells birth, ..., min(death, I) - 1;
 *  - mobius_inverion == true : +1 at birth, and -1 at death if death < I, or -1 at I-1 when
 *    zero_pad is set (nothing otherwise). The line is thus already Möbius-inverted along i.
 *
 * Every task writes only cells whose coordinate j+1 is its own height. Provided i != j,
 * different heights therefore never touch the same cell of `out`.
 *
 * @param st_multi            multiparameter simplex tree (only read).
 * @param thread_simplex_tree per-thread (simplex tree, coordinate buffer). The simplex tree
 *                            must enumerate the same simplices as st_multi in the same order
 *                            (both are walked in lockstep), and the buffer needs
 *                            fixed_values.size()+1 entries.
 * @param out                 output tensor, assumed zero-initialized; axis 0 indexes `degrees`,
 *                            axis k+1 indexes parameter k.
 * @param grid_shape          grid sizes, with the same axis convention as `out`.
 * @param degrees             homological degrees, forwarded to compute_dgms.
 * @param i, j                the two free parameters (0-based, degree axis excluded).
 * @param fixed_values        pinned value of every parameter; entries i and j are ignored.
 * @param mobius_inverion     see above.
 * @param zero_pad            see above; only used when mobius_inverion is true.
 * @param expand_collapse_dim forwarded to compute_dgms.
 */
template<typename dtype, typename index_type>
inline void compute_2d_hilbert_surface(
	Simplex_tree_multi &st_multi,
	hilbert_thread_data<index_type> &thread_simplex_tree,
	const tensor::static_tensor_view<dtype, index_type>& out, // assumes its a zero tensor
	const std::vector<index_type>& grid_shape,
	const std::vector<index_type>& degrees,
	index_type i, index_type j,
	const std::vector<index_type>& fixed_values,
	bool mobius_inverion,
	bool zero_pad,
	int expand_collapse_dim=0
	){
	constexpr bool verbose = false;
	// +1 on every axis index: axis 0 of grid_shape / out is the degree.
	const index_type I = grid_shape[i+1], J = grid_shape[j+1];
	if constexpr(verbose) std::cout << "Grid shape : " << I << " " << J << std::endl;
	// Distance (in elements of `out`) between two consecutive cells along axis i+1.
	const index_type shift_value = out.get_cum_resolution()[i+1];

	// The start is spelled as index_type so that tbb::parallel_for deduces a single Index type
	// (a bare `0` only matches when index_type is int).
	tbb::parallel_for(static_cast<index_type>(0), J, [&](index_type height){
		// Thread-local scratch: 1-parameter simplex tree + coordinate buffer into `out`.
		auto& [st_std, coordinates_container] = thread_simplex_tree.local();

		// Point the buffer at the line {parameter j == height} of the slice.
		std::copy(fixed_values.begin(), fixed_values.end(), coordinates_container.begin() + 1);
		coordinates_container[j+1] = height;

		// Restrict the multi-filtration of every simplex to this line. Both trees are walked
		// in lockstep, so they must enumerate the same simplices in the same order.
		Simplex_tree_multi::Filtration_value multi_filtration(st_multi.get_number_of_parameters());
		auto sh_standard = st_std.complex_simplex_range().begin();
		auto std_end = st_std.complex_simplex_range().end();
		auto sh_multi = st_multi.complex_simplex_range().begin();
		for (; sh_standard != std_end; ++sh_multi, ++sh_standard){
			multi_filtration = st_multi.filtration(*sh_multi);
			const value_type horizontal_filtration = horizontal_line_filtration2(multi_filtration, height, i, j, fixed_values);
			st_std.assign_filtration(*sh_standard, horizontal_filtration);

			if constexpr (verbose){
				Simplex_tree_multi::Filtration_value splx;
				for (auto vertex : st_multi.simplex_vertex_range(*sh_multi))	splx.push_back(vertex);
				std::cout << "Simplex " << splx << "/"<< st_std.num_simplices() << " Filtration multi " << st_multi.filtration(*sh_multi) << " Filtration 1d " <<  st_std.filtration(*sh_standard) << "\n";
			}
		}

		if constexpr(verbose) {
			std::cout << "Coords : "  << height << " [";
			for (auto stuff : fixed_values)
				std::cout << stuff << " ";
			std::cout  << "]" << std::endl;
		}

		const std::vector<Barcode> barcodes = compute_dgms(st_std, degrees, expand_collapse_dim);
		index_type degree_index = 0;
		for (const auto& barcode : barcodes){ // TODO range view cartesian product
			coordinates_container[0] = degree_index;
			for (const auto& bar : barcode){
				const auto birth = bar.first;
				const auto death = bar.second;
				if (birth > I) // some birth can be infinite
					continue;
				coordinates_container[i+1] = static_cast<index_type>(birth);

				if (!mobius_inverion){
					// Clamp in floating point first: converting an infinite death to index_type is
					// undefined (it was observed to give a negative value on linux, so std::min
					// on the cast value does not work).
					const index_type stop_value = death > static_cast<value_type>(I) ? I : static_cast<index_type>(death);
					dtype* ptr = &out[coordinates_container];
					if constexpr (verbose) {
						std::cout << "Adding : (";
						for (auto stuff : coordinates_container) std::cout << stuff << ", ";
						std::cout << ") With death " << death << " casted at "<< static_cast<index_type>(death) << "with threshold at" << stop_value << " with "<< I <<std::endl;
					}
					// +1 on every cell of the segment [birth, stop_value) along axis i+1.
					for (index_type b = static_cast<index_type>(birth); b < stop_value; b++){
						(*ptr)++;
						ptr += shift_value;
					}
				}
				else{
					// +1 at birth and -1 at death: the line is Möbius-inverted along axis i+1 directly.
					out[coordinates_container]++;

					if constexpr (verbose){
						std::cout << "Coordinate : ";
						for (auto c : coordinates_container) std::cout << c << ", ";
						std::cout << std::endl;
						std::cout << "axis, death, resolution : " << i+1 << ", " << std::to_string(death) << ", " << out.get_resolution()[i+1];
						std::cout << std::endl;
					}

					if (death < I){
						coordinates_container[i+1] = static_cast<index_type>(death);
						out[coordinates_container]--;
					}
					else if (zero_pad){
						// The bar outlives the computed grid: close it on its last cell.
						coordinates_container[i+1] = I-1;
						out[coordinates_container]--;
					}
				}
			}
			degree_index++;
		}
	});
}



/**
 * Recursive driver of the Hilbert function computation over the parameters listed in
 * `coordinates_to_compute`, all other parameters being pinned to `fixed_values`.
 *
 * With exactly two free parameters, this delegates to compute_2d_hilbert_surface, using
 * coordinates_to_compute[0] as the horizontal axis and coordinates_to_compute[1] as the vertical
 * axis. Otherwise the last free parameter is removed from the list and iterated (with
 * tbb::parallel_for) over [0, grid_shape[parameter+1]). Each value is written into a copy of
 * `fixed_values` before recursing.
 *
 * The first free parameter is never popped, so it is always the horizontal axis of the 2D
 * slices, where the Möbius inversion is done on the fly. With mobius_inverion, only the other
 * axes still need to be differentiated afterwards.
 *
 * Precondition: coordinates_to_compute.size() >= 2 (checked by get_hilbert_surface).
 * The remaining parameters have the same meaning as in compute_2d_hilbert_surface;
 * `thread_simplex_tree` is shared by every recursive call.
 */
template<typename dtype, typename index_type>
void _rec_get_hilbert_surface(
	Simplex_tree_multi &st_multi,
	hilbert_thread_data<index_type> &thread_simplex_tree,
	const tensor::static_tensor_view<dtype, index_type>& out, // assumes its a zero tensor
	const std::vector<index_type>& grid_shape,
	const std::vector<index_type>& degrees,
	std::vector<index_type> coordinates_to_compute,
	const std::vector<index_type>& fixed_values,
	bool mobius_inverion = true,
	bool zero_pad=false,
	int expand_collapse_dim=0
	){
	constexpr bool verbose = false;

	if constexpr (verbose) {
		std::cout << "Computing coordinates (";
		for (auto c : coordinates_to_compute) std::cout << c << ", ";
		std::cout << "). with fixed values (";
		for (auto c : fixed_values) {std::cout << c << ", ";}
		std::cout << ")." <<std::endl;
	}

	if (coordinates_to_compute.size() == 2){
		compute_2d_hilbert_surface(
			st_multi,
			thread_simplex_tree,
			out, // assumes its a zero tensor
			grid_shape,
			degrees,
			coordinates_to_compute[0], coordinates_to_compute[1],
			fixed_values,
			mobius_inverion,
			zero_pad,
			expand_collapse_dim
		);
		return;
	}

	// Peel off the last free parameter and sweep it over its whole grid range.
	const index_type coordinate_to_iterate = coordinates_to_compute.back();
	coordinates_to_compute.pop_back();
	// index_type start: keeps tbb::parallel_for's Index deduction unambiguous for non-int index types.
	tbb::parallel_for(static_cast<index_type>(0), grid_shape[coordinate_to_iterate+1], [&](index_type z){
		// Each task needs its own copy: it pins a different value of the peeled parameter.
		std::vector<index_type> _fixed_values = fixed_values;
		_fixed_values[coordinate_to_iterate] = z;
		_rec_get_hilbert_surface(st_multi, thread_simplex_tree, out, grid_shape, degrees, coordinates_to_compute, _fixed_values, mobius_inverion, zero_pad, expand_collapse_dim);
	});
}


/**
 * Computes the Hilbert function of `st_multi` (Möbius-inverted along the first free parameter
 * when mobius_inverion is true) into `out`.
 *
 * This is the entry point of the recursion. It flattens st_multi into a 1-parameter simplex
 * tree, then builds from it the per-thread (simplex tree, coordinate buffer) pairs used by the
 * workers.
 *
 * @param st_multi               multiparameter simplex tree.
 * @param out                    output tensor, assumed zero-initialized; axis 0 = degree index,
 *                               axis k+1 = parameter k.
 * @param grid_shape             grid sizes used for the computation (same axis convention).
 * @param degrees                homological degrees; nothing is done when empty.
 * @param coordinates_to_compute free parameters (at least two).
 * @param fixed_values           pinned value of every parameter (one entry per parameter;
 *                               entries of free parameters are overwritten).
 * @param mobius_inverion        see compute_2d_hilbert_surface.
 * @param zero_pad               see compute_2d_hilbert_surface.
 * @param expand_collapse        if true, compute_dgms receives max(degrees)+1, otherwise 0.
 * @throws std::logic_error      if fewer than two parameters are to be computed.
 * @throws std::invalid_argument if fixed_values does not have one entry per parameter.
 */
template<typename dtype, typename index_type>
void get_hilbert_surface(
	Simplex_tree_multi &st_multi,
	const tensor::static_tensor_view<dtype, index_type>& out, // assumes its a zero tensor
	const std::vector<index_type>& grid_shape,
	const std::vector<index_type>& degrees,
	std::vector<index_type> coordinates_to_compute,
	const std::vector<index_type>& fixed_values,
	bool mobius_inverion = true,
	bool zero_pad=false,
	bool expand_collapse=false
	){
	if (degrees.size() == 0) return;
	if (coordinates_to_compute.size() < 2)
		throw std::logic_error("Not implemented for "+ std::to_string(coordinates_to_compute.size()) +  "<2 parameters.");

	const auto num_parameters = st_multi.get_number_of_parameters();
	// fixed_values is indexed by parameter both when filtering simplices and when filling the
	// coordinate buffer; any other size would read or write out of bounds.
	if (fixed_values.size() != static_cast<std::size_t>(num_parameters))
		throw std::invalid_argument("fixed_values must have one entry per parameter (expected "
			+ std::to_string(num_parameters) + ", got " + std::to_string(fixed_values.size()) + ").");

	Simplex_tree_std _st;
	flatten(_st, st_multi,-1); // copies the st_multi to a standard 1-pers simplextree
	std::vector<index_type> coordinates_container(num_parameters+1); // +1 for degree
	const int max_dim = expand_collapse ? *std::max_element(degrees.begin(), degrees.end()) +1 : 0;

	// The flattened tree is moved (not copied) into the exemplar; each thread copies the exemplar once.
	std::pair<Simplex_tree_std, std::vector<index_type>> thread_data_initialization{std::move(_st), std::move(coordinates_container)};
	hilbert_thread_data<index_type> thread_simplex_tree(std::move(thread_data_initialization));
	_rec_get_hilbert_surface(st_multi, thread_simplex_tree, out, grid_shape, degrees, std::move(coordinates_to_compute), fixed_values, mobius_inverion, zero_pad, max_dim);
}




/**
 * Computes the Hilbert signed measure of the multiparameter simplex tree at `simplextree_ptr`.
 *
 * The Möbius-inverted Hilbert function is first written densely into the caller-owned buffer
 * `data_ptr`, viewed as a tensor of shape `grid_shape` and assumed zero-filled. The first
 * parameter axis is inverted during the computation; the other ones are inverted afterwards
 * with container.differentiate. The result of container.sparsify() is then returned.
 *
 * @param simplextree_ptr address of the multiparameter simplex tree.
 * @param data_ptr        buffer of the dense tensor (left holding the inverted function).
 * @param grid_shape      tensor shape: axis 0 = degrees, axis k+1 = parameter k. With zero_pad,
 *                        the computation runs on one cell less per parameter axis.
 * @param degrees         homological degrees; an empty pair is returned when empty.
 * @param zero_pad        forwarded to get_hilbert_surface.
 * @param n_jobs          passed to tbb::task_arena to limit the number of threads.
 * @param verbose         print progress messages on std::cout.
 * @param expand_collapse forwarded to get_hilbert_surface.
 * @return container.sparsify(); presumably (coordinates of non-zero cells, their weights) —
 *         TODO confirm against tensor.h.
 */
template<typename dtype=int, typename indices_type>
std::pair<std::vector<std::vector<indices_type>>, std::vector<dtype>> get_hilbert_signed_measure(
	const intptr_t simplextree_ptr,
	dtype* data_ptr,
	std::vector<indices_type> grid_shape,
	const std::vector<indices_type> degrees,
	bool zero_pad=false,
	indices_type n_jobs=0,
	const bool verbose = false,
	const bool expand_collapse=false
	){
	if (degrees.empty())
		return {{},{}};

	auto& st_multi = get_simplextree_from_pointer<interface_multi>(simplextree_ptr);
	// The view keeps the shape as given; only the local grid_shape is shrunk for zero_pad below.
	tensor::static_tensor_view<dtype, indices_type> container(data_ptr, grid_shape); // assumes its a zero tensor
	const auto num_parameters = st_multi.get_number_of_parameters();

	// Every parameter is free (0, 1, ..., num_parameters-1); none is pinned.
	std::vector<indices_type> coordinates_to_compute(num_parameters);
	std::iota(coordinates_to_compute.begin(), coordinates_to_compute.end(), indices_type{0});
	const std::vector<indices_type> fixed_values(num_parameters);

	if (verbose){
		std::cout << "Container shape : ";
		for (const auto resolution : container.get_resolution()) std::cout << resolution << ", ";
		std::cout << "\nContainer size : " << container.size()
		          << "\nComputing hilbert invariant ...";
	}

	// Axis 0 holds the degrees, so the parameter axes are 1..num_parameters.
	if (zero_pad)
		for (auto axis = 1; axis <= num_parameters; ++axis)
			--grid_shape[axis];

	oneapi::tbb::task_arena arena(n_jobs); // limits the number of threads
	arena.execute([&]{
		get_hilbert_surface(st_multi, container, grid_shape, degrees, coordinates_to_compute, fixed_values, true, zero_pad, expand_collapse);
	});

	if (verbose){
		std::cout << "Done." << std::endl;
		std::cout << "Computing mobius inversion ...";
	}
	// Axis 1 was already inverted line by line; differentiate the remaining parameter axes.
	for (indices_type axis = 2; axis <= num_parameters; ++axis)
		container.differentiate(axis);

	if (verbose){
		std::cout << "Done." << std::endl;
		std::cout << "Sparsifying the measure ...";
	}
	auto signed_measure = container.sparsify();
	if (verbose)
		std::cout << "Done." << std::endl;
	return signed_measure;
}




/**
 * Python-facing entry point: computes the Hilbert function of the multiparameter simplex tree
 * at `simplextree_ptr` into the caller-owned buffer `data_ptr`, optionally Möbius-inverted.
 *
 * @param simplextree_ptr  address of the multiparameter simplex tree.
 * @param data_ptr         buffer viewed as a tensor of shape `grid_shape`; assumed zero-filled.
 * @param grid_shape       tensor shape: axis 0 = degrees, axis k+1 = parameter k. With zero_pad,
 *                         the computation runs on one cell less per parameter axis.
 * @param degrees          homological degrees; nothing is done when empty.
 * @param mobius_inversion if true, the first parameter axis is inverted during the computation
 *                         and the other parameter axes with container.differentiate afterwards.
 * @param zero_pad         forwarded to get_hilbert_surface.
 * @param n_jobs           passed to tbb::task_arena to limit the number of threads.
 * @param expand_collapse  forwarded to get_hilbert_surface.
 */
template<typename dtype, typename indices_type, typename ... Args>
void get_hilbert_surface_python(
	const intptr_t simplextree_ptr,
	dtype* data_ptr,
	std::vector<indices_type> grid_shape,
	const std::vector<indices_type> degrees,
	const bool mobius_inversion,
	const bool zero_pad,
	indices_type n_jobs,
	bool expand_collapse){
	constexpr bool verbose = false;
	if (degrees.empty())
		return;

	auto& st_multi = get_simplextree_from_pointer<interface_multi>(simplextree_ptr);
	// The view keeps the shape as given; only the local grid_shape is shrunk for zero_pad below.
	tensor::static_tensor_view<dtype, indices_type> container(data_ptr, grid_shape); // assumes its a zero tensor
	const auto num_parameters = st_multi.get_number_of_parameters();

	// Every parameter is free (0, 1, ..., num_parameters-1); none is pinned.
	std::vector<indices_type> coordinates_to_compute(num_parameters);
	std::iota(coordinates_to_compute.begin(), coordinates_to_compute.end(), indices_type{0});
	const std::vector<indices_type> fixed_values(num_parameters);

	if (verbose){
		std::cout << "Container shape : ";
		for (const auto resolution : container.get_resolution()) std::cout << resolution << ", ";
		std::cout << "\nContainer size : " << container.size()
		          << "\nComputing hilbert invariant ...";
	}

	// Axis 0 holds the degrees, so the parameter axes are 1..num_parameters.
	if (zero_pad)
		for (auto axis = 1; axis <= num_parameters; ++axis)
			--grid_shape[axis];

	oneapi::tbb::task_arena arena(n_jobs); // limits the number of threads
	arena.execute([&]{
		get_hilbert_surface(st_multi, container, grid_shape, degrees, coordinates_to_compute, fixed_values, mobius_inversion, zero_pad, expand_collapse);
	});

	if (!mobius_inversion)
		return;
	// Axis 1 was already inverted line by line; differentiate the remaining parameter axes.
	for (indices_type axis = 2; axis <= num_parameters; ++axis)
		container.differentiate(axis);
}








} // namespace Gudhi::multiparameter::hilbert_function
