#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <set>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <phat/boundary_matrix.h>

#include <scc/Scc.h>

// Parser for chain complexes in scc2020 format, with default template arguments.
using Scc = scc::Scc<>;

// Sparse matrix stored column by column: entry j holds the sorted row indices
// of the non-zero entries of column j (see compute_induced_matrix).
using Induced_matrix = std::vector<std::vector<long>>;

/// Column-reduces `boundary_matrix` from left to right.
///
/// For each column, as long as its pivot (largest row index, as reported by
/// get_max_index; -1 means the column is empty) is already the pivot of an
/// earlier column, that earlier column is added to it via add_to (for PHAT
/// matrices presumably addition over Z/2). Afterwards all non-empty columns
/// have pairwise distinct pivots. finalize is called on every column once it
/// is fully reduced.
template <typename Matrix>
void standard_reduction(Matrix& boundary_matrix) {
  using Index = long;

  // Pivot row -> the (already reduced) column that owns this pivot.
  std::unordered_map<Index, Index> pivot_owner;
  const Index column_count = boundary_matrix.get_num_cols();

  for (Index col = 0; col < column_count; ++col) {
    Index pivot = boundary_matrix.get_max_index(col);
    for (;;) {
      if (pivot == -1) {
        break;  // column reduced to zero
      }
      const auto owner = pivot_owner.find(pivot);
      if (owner == pivot_owner.end()) {
        // Fresh pivot: this column now owns it.
        pivot_owner.emplace(pivot, col);
        break;
      }
      boundary_matrix.add_to(owner->second, col);
      pivot = boundary_matrix.get_max_index(col);
    }
    boundary_matrix.finalize(col);
  }
}

template<typename Matrix>
void prune_zero_columns ( Matrix& m ) {

  int pos_of_next=0;
  for(int i=0;i<m.get_num_cols();i++) {
    if(m.is_empty(i)) {
      continue;
    }
    if(i==pos_of_next) {
      pos_of_next++;
      continue;
    }
    phat::column col;
    m.get_col(i,col);
    m.set_col(pos_of_next,col);
    m.set_dim(pos_of_next,m.get_dim(i));
    pos_of_next++;
  }
  m.set_num_cols(pos_of_next);
  
}

/// Scans all columns of `level` in `parser` and records their grade in the
/// primary parameter.
///
/// If `slices == 0`, every grade is inserted into `grades`. Otherwise only the
/// running extrema are kept: `min_gr` and `max_gr` are updated in place, so
/// callers must initialise them before the first call. The parser is reset
/// for `level` when the scan is done.
void _determine_slice_values(Scc& parser,
			     int level,
			     int slices,
			     int primary_parameter,
			     std::set<double>& grades,
			     double& min_gr,
			     double& max_gr) {
  std::vector<double> grade_buffer;
  std::vector<std::pair<int,int>> coefficient_buffer;
  const bool collect_all = (slices == 0);

  while (parser.has_next_column(level)) {
    grade_buffer.clear();
    coefficient_buffer.clear();
    parser.next_column(level,
                       std::back_inserter(grade_buffer),
                       std::back_inserter(coefficient_buffer));
    const double value = grade_buffer[primary_parameter];

    if (collect_all) {
      grades.insert(value);
      continue;
    }
    if (value < min_gr) {
      min_gr = value;
    }
    if (max_gr < value) {
      max_gr = value;
    }
  }

  parser.reset(level);
}

/// Writes the slice values (grades in the primary parameter) to `out`.
///
/// Columns of both `level` and `level+1` are taken into account.
/// - `slices == 0`: every distinct grade that occurs is emitted, in ascending order.
/// - `slices > 0`: `slices` equidistant values from the smallest to the largest
///   occurring grade are emitted. The first value is exactly the minimum and
///   (for slices > 1) the last is exactly the maximum.
///
/// The parser is reset for `level` (and, via _determine_slice_values, for
/// `level+1`) before returning.
template <typename OutputIterator> void determine_slice_values(Scc& parser,
							       int level,
							       int slices,
							       int primary_parameter,
							       OutputIterator out) {
  // Needed only for slices==0
  std::set<double> grades;
  // Needed only for slices>0.
  // numeric_limits<double>::min() is the smallest *positive* double, so it is
  // not a valid start value for a running maximum; lowest() is the most
  // negative finite double.
  double min_gr=std::numeric_limits<double>::max();
  double max_gr=std::numeric_limits<double>::lowest();

  _determine_slice_values(parser,level,slices,primary_parameter,grades,min_gr,max_gr);
  _determine_slice_values(parser,level+1,slices,primary_parameter,grades,min_gr,max_gr);

  if(slices==0) {
    std::copy(grades.begin(),grades.end(),out);
  } else {
    assert(max_gr>=min_gr);
    double delta = max_gr-min_gr;

    for(int i=0;i<slices;i++) {
      if(i==0) {
        // Also guards slices==1 (even though it is a useless case)
        *out++=min_gr;
      } else if(i==slices-1) {
        // Emit the maximum exactly: min_gr+delta may round below max_gr and
        // would then exclude elements graded at the maximum from the last slice.
        *out++=max_gr;
      } else {
        *out++ = min_gr+(double)i/(slices-1)*delta;
      }
    }
  }

  parser.reset(level);
}

/// A generator of the input presentation.
struct Generator {

  double gr[2];    // bigrade (first and second parameter) as read from the input
  int pos_in_scc;  // position of the generator in the input's column order
  int index;       // position after sorting by the secondary parameter; -1 until assigned

  // `index` is set during the algorithm; it is initialised to -1 so that
  // copies made before that (e.g. by std::sort) never read an indeterminate value.
  Generator(double d1, double d2,int pos) : gr{d1, d2}, pos_in_scc(pos), index(-1) {}
};

/// A relation of the input presentation: its bigrade and its boundary, given
/// as generator indices (callers in this file pass them sorted).
struct Relation {

  double gr[2];
  std::vector<long> bd;

  // Takes the boundary by const reference so temporaries and const vectors
  // are accepted; the entries are widened to long.
  Relation(double d1, double d2, const std::vector<int>& bdy)
    : gr{d1, d2}, bd(bdy.begin(), bdy.end()) {}
};

/// Comparator (strict weak ordering) for std::sort: orders objects that expose
/// a `gr[2]` array by the coordinate that is *not* the primary parameter.
struct Sort_by_secondary {

  int secondary_parameter;

  Sort_by_secondary(int primary_parameter) : secondary_parameter(1-primary_parameter) {}

  // const-qualified and taking const references, as comparators must not
  // modify the compared elements.
  template<typename T>
  bool operator() (const T& g1, const T& g2) const {
    return g1.gr[secondary_parameter] < g2.gr[secondary_parameter];
  }
};

/// Reads the generators (columns of `level+1`) and relations (columns of
/// `level`) from `parser`, sorted by their grade in the secondary parameter.
///
/// Each generator gets `index` = its position in the sorted order. Relation
/// boundaries are rewritten in terms of these indices and sorted ascending;
/// the second component of each boundary entry (the coefficient) is not used.
/// Generators are written to `out_gen` and relations to `out_rel`, both in
/// sorted order. The parser is not reset afterwards.
template<typename Generator_output_iterator, typename Relation_output_iterator>
void compute_generators_and_relations_sorted_by_secondary_parameter(Scc& parser,
								    int level,
								    int primary_parameter,
								    Generator_output_iterator out_gen,
								    Relation_output_iterator out_rel) {
  Sort_by_secondary by_secondary(primary_parameter);

  std::vector<double> grade;
  std::vector<std::pair<int,int>> entries;

  // --- Generators: columns of level+1 ---
  std::vector<Generator> gens;
  for (int next_pos = 0; parser.has_next_column(level+1); ++next_pos) {
    grade.clear();
    entries.clear();
    parser.next_column(level+1, std::back_inserter(grade), std::back_inserter(entries));
    gens.push_back(Generator(grade[0], grade[1], next_pos));
  }
  std::sort(gens.begin(), gens.end(), by_secondary);

  // gens_reindexing[p] = sorted position of the generator that was p-th in the input.
  std::vector<int> gens_reindexing(gens.size());
  int sorted_pos = 0;
  for (Generator& g : gens) {
    g.index = sorted_pos;
    gens_reindexing[g.pos_in_scc] = sorted_pos;
    *out_gen++ = g;
    ++sorted_pos;
  }

  // --- Relations: columns of level ---
  std::vector<Relation> rels;
  while (parser.has_next_column(level)) {
    grade.clear();
    entries.clear();
    parser.next_column(level, std::back_inserter(grade), std::back_inserter(entries));

    std::vector<int> boundary;
    boundary.reserve(entries.size());
    for (const std::pair<int,int>& entry : entries) {
      boundary.push_back(gens_reindexing[entry.first]);
    }
    std::sort(boundary.begin(), boundary.end());
    rels.push_back(Relation(grade[0], grade[1], boundary));
  }
  std::sort(rels.begin(), rels.end(), by_secondary);

  std::copy(rels.begin(), rels.end(), out_rel);
}

/// Builds in `m` a basis adapted to the slice where the primary parameter
/// equals `val`.
///
/// The columns of `m` are, in this order:
///   1. The reduced (standard_reduction) boundaries of all relations whose
///      primary grade is <= val, with zero columns removed. The dimension
///      field of each column stores the index of its relation in `rels`.
///   2. One unit column per generator with primary grade <= val that is not
///      the pivot of any column from 1. Its dimension field is -1 ("infinity").
/// All columns are non-empty and have pairwise distinct pivots (checked by
/// assertions).
///
/// @param gens  generators; `gen.index` is used as the row index
/// @param rels  relations; each `bd` must be sorted row indices (matrix column)
/// @param primary_parameter  which grade coordinate (0 or 1) is compared to val
/// @param val   the slice value
/// @param m     output matrix; its columns are replaced
///
/// NOTE(review): relation indices are stored in the Matrix "dimension" field.
/// Upstream PHAT declares phat::dimension as a narrow integer type. If the
/// bundled PHAT has not been widened, large relation indices would be
/// truncated — confirm.
template<typename Matrix>
void kernel_basis_at(std::vector<Generator>& gens,
		     std::vector<Relation>& rels,
		     int primary_parameter,
		     double val,
		     Matrix& m) {

  std::cout << "Now produce boundary matrix at val " << val << std::endl;

  // Count the relations active at val
  int num_rels=0;
  for(Relation& rel : rels) {
    if(rel.gr[primary_parameter]<=val) {
      num_rels++;
    }
  }

  m.set_num_cols(num_rels);
  // Collect all relations active at val and reduce
  int count=0;
  for(int i=0;i<static_cast<int>(rels.size());i++) {
    Relation& rel = rels[i];
    if(rel.gr[primary_parameter]<=val) {
      m.set_col(count,rel.bd);
      // We abuse the dimension field here to get the grade later
      m.set_dim(count,i);
      count++;
    }
  }

  standard_reduction(m);

  prune_zero_columns(m);

  // Every remaining column is non-empty; its pivot is a "paired" generator.
  std::set<int> paired_generators;

  for(int i=0;i<m.get_num_cols();i++) {
    assert(! m.is_empty(i));
    int pivot = m.get_max_index(i);
    assert(paired_generators.count(pivot)==0);
    paired_generators.insert(pivot);
  }

  // Generators active at val that are not the pivot of any reduced relation.
  std::vector<int> unpaired_generators;

  for(Generator& gen : gens) {
    if(gen.gr[primary_parameter]<=val) {
      if(paired_generators.count(gen.index)==0) {
        unpaired_generators.push_back(gen.index);
      }
    } else {
      // A generator not active at val is expected never to be a pivot
      // (presumably relations only involve generators of smaller grade).
      assert(paired_generators.count(gen.index)==0);
    }
  }

  // Now extend the matrix with the unpaired generators to obtain a kernel basis
  int old_size=m.get_num_cols();
  m.set_num_cols(old_size+unpaired_generators.size());

  for(int i = old_size; i < m.get_num_cols(); i++) {
    std::vector<long> curr_bd;
    curr_bd.push_back(unpaired_generators[i-old_size]);
    m.set_col(i,curr_bd);
    // Stands for "infinity"
    m.set_dim(i,-1);
  }
}
			    
/// Expresses every column of `domain` as a sum of columns of `basis`.
///
/// For each domain column, the sorted list of basis-column indices used is
/// appended to `induced_matrix`.
///
/// Requirements: every basis column is non-empty (asserted), and the basis
/// columns have pairwise distinct pivots (as produced by kernel_basis_at).
/// Every domain column must lie in the span of `basis`.
///
/// `basis` is temporarily extended by one scratch column. It is restored to
/// its original number of columns before returning, also when throwing.
///
/// @throws std::runtime_error if a domain column cannot be expressed in `basis`.
template<typename Matrix,typename Induced_matrix>
void compute_induced_matrix(Matrix& domain,
			    Matrix& basis,
			    Induced_matrix& induced_matrix) {

  typedef long index;
  const index nr_columns = basis.get_num_cols();

  // Pivot (largest row index) -> basis column having that pivot.
  std::unordered_map< index,index > lowest_one_lookup;

  for( index cur_col = 0; cur_col < nr_columns; cur_col++ ) {
    index lowest_one = basis.get_max_index( cur_col );
    assert(lowest_one!=-1);
    lowest_one_lookup[ lowest_one ] = cur_col;
  }

  // Column nr_columns of `basis` is used as scratch space for the reduction.
  basis.set_num_cols(nr_columns+1);

  phat::column col;

  for(index i=0;i<domain.get_num_cols();i++) {
    std::vector<index> linear_combination;
    domain.get_col(i,col);
    basis.set_col(nr_columns,col);
    index lowest_one = basis.get_max_index( nr_columns );
    assert(lowest_one!=-1);
    while( lowest_one != -1) {
      // find() instead of operator[]: a missing pivot must not silently insert
      // column 0, which can make this loop run forever in NDEBUG builds.
      const auto owner = lowest_one_lookup.find(lowest_one);
      if(owner == lowest_one_lookup.end()) {
        basis.set_num_cols(nr_columns);
        throw std::runtime_error("compute_induced_matrix: column is not in the span of the basis (are the slice values increasing?)");
      }
      const index new_index = owner->second;
      linear_combination.push_back(new_index);
      basis.add_to( new_index , nr_columns );
      lowest_one = basis.get_max_index( nr_columns );
    }
    std::sort(linear_combination.begin(),linear_combination.end());
    induced_matrix.push_back(linear_combination);
  }
  basis.set_num_cols(nr_columns);
}
			    
/// Prints usage information to std::cout.
///
/// The stated defaults must match the initial values used in main()
/// (primary parameter 1, 10 slices, level 1).
void print_help_message(const char* arg0) {

  std::cout << "Usage: " << arg0 << " [OPTIONS] <INPUT_FILE> <OUTPUT_FILE>\n\n";

  std::cout << "Computes the graphcode for a chain complex in scc2020 format. \n\n";

  std::cout << "Options:\n";
  std::cout << "--primary-parameter d       - Should be 0 or 1 (default is 1)\n";
  std::cout << "--slices n                  - Computes barcodes at n slices, which are taken equidistantly in the range scale of the primary parameter. 0 means that all slices are computed. Default is n=10\n";
  std::cout << "--slice-file filename       - Computes barcodes at slices whose value is determined by the file filename. If used, parameter \"--slices\" is ignored.\n";
  std::cout << "--include-infinite-bars     - takes infinite bars into account in the output\n";
  std::cout << "--keep-disjoint-pairs       - does not filter out edges between bars with disjoint lifetimes\n";
  std::cout << "--relevance-threshold v     - only include bars of persistence at least v in the output. v must be non-negative decimal\n";
  std::cout << "--level i                   - Computes the graphcode on level i of the input chain complex (default is 1)\n";
  std::cout << "-h --help                   - prints this message\n\n";
}

/// Returns true iff `arg` is a non-empty string consisting only of decimal digits.
///
/// The empty string is rejected, because atoi("") would silently yield 0.
bool is_non_negative_integer(std::string arg) {
  if(arg.empty()) {
    return false;
  }
  // Pass unsigned char to isdigit: a negative char value is undefined behaviour.
  return std::all_of(arg.begin(), arg.end(),
                     [](unsigned char c) { return std::isdigit(c) != 0; });
}

/// Returns true iff `arg` is a plain non-negative decimal number.
///
/// That means digits with at most one '.' and at least one digit, e.g.
/// "3", "0.5", ".5" or "5.". Signs, exponents, "", "." and "1.2.3" are rejected.
bool is_non_negative_float(std::string arg) {
  bool seen_digit = false;
  bool seen_point = false;
  for(char ch : arg) {
    // unsigned char avoids undefined behaviour in isdigit for negative chars.
    const unsigned char c = static_cast<unsigned char>(ch);
    if(std::isdigit(c)) {
      seen_digit = true;
    } else if(c == '.' && !seen_point) {
      seen_point = true;
    } else {
      return false;
    }
  }
  return seen_digit;
}



/// Reads whitespace-separated numbers from `filename` and writes them to `out`.
///
/// Reading stops at end of file or at the first entry that is not a number; a
/// warning is printed to std::cerr in the latter case. If the file cannot be
/// opened, an error is printed to std::cerr and nothing is written.
template<typename OutputIterator>
void slice_values_from_file(std::string filename,OutputIterator out) {
  std::ifstream ifstr(filename.c_str());
  if(!ifstr) {
    std::cerr << "Could not open slice file " << filename << std::endl;
    return;
  }
  double next;
  // Test the extraction itself rather than good(): the last value of a file
  // without trailing whitespace sets eofbit but was still read successfully.
  while(ifstr >> next) {
    *out++=next;
  }
  if(!ifstr.eof()) {
    std::cerr << "Warning: stopped reading slice file " << filename
              << " at a non-numeric entry" << std::endl;
  }
}


/// Command-line entry point.
///
/// Parses the options, reads the scc2020 input and computes the barcodes on
/// every slice together with the maps between consecutive slices. It then
/// writes the graphcode to OUTPUT_FILE in this format:
///   - the number of vertices;
///   - one line "birth death slice_number slice_value" per vertex;
///   - the edges, each written in both directions as "source_id target_id".
///
/// Returns 0 on success, and 1 if a file cannot be opened or no slice value is
/// available. Invalid command lines print the help text and exit with 0.
int main(int argc, char** argv) {

  int level=1;

  int slices=10;

  // This is the parameter that determines the slice values
  // That is, the other parameter is the filtration value along which
  // the barcodes are computed. Must be 0 or 1.
  int primary_parameter = 1;

  bool include_infinite_bars=false;

  // -1 stands for not set
  double relevance_threshold=-1;

  bool inputfile_read=false;
  bool outputfile_read=false;

  std::string infile,outfile;

  bool print_help=false;

  bool read_slice_values_from_file=false;
  std::string slice_value_file;

  bool filter_out_disjoint_pairs=true;

  int pos=1;

  while(pos<argc) {

    std::string arg(argv[pos]);

    if(arg=="-h" || arg=="--help") {
      print_help=true;
      break;
    } else if(arg=="--primary-parameter") {
      if(pos+1<argc && is_non_negative_integer(std::string(argv[pos+1]))) {
        primary_parameter=atoi(argv[++pos]);
        // Grades have exactly two coordinates; other values would index gr[] out of bounds.
        if(primary_parameter!=0 && primary_parameter!=1) {
          std::cout << "--primary-parameter must be 0 or 1" << std::endl;
          print_help=true;
          break;
        }
      } else {
        std::cout << "--primary-parameter requires extra integer argument" << std::endl;
        print_help=true;
        break;
      }
    } else if(arg=="--level") {
      if(pos+1<argc && is_non_negative_integer(std::string(argv[pos+1]))) {
        level=atoi(argv[++pos]);
      } else {
        std::cout << "--level requires extra integer argument" << std::endl;
        print_help=true;
        break;
      }
    } else if(arg=="--slices") {
      if(pos+1<argc && is_non_negative_integer(std::string(argv[pos+1]))) {
        slices=atoi(argv[++pos]);
      } else {
        std::cout << "--slices requires extra integer argument" << std::endl;
        print_help=true;
        break;
      }
    } else if(arg=="--slice-file") {
      if(pos+1<argc) {
        read_slice_values_from_file=true;
        slice_value_file=std::string(argv[pos+1]);
        pos++;
      } else {
        std::cout << "--slice-file requires extra argument" << std::endl;
        print_help=true;
        break;
      }
    } else if(arg=="--include-infinite-bars") {
      include_infinite_bars=true;
    } else if(arg=="--keep-disjoint-pairs") {
      filter_out_disjoint_pairs=false;
    } else if(arg=="--relevance-threshold") {
      if(pos+1<argc && is_non_negative_float(std::string(argv[pos+1]))) {
        relevance_threshold=atof(argv[++pos]);
      } else {
        std::cout << "--relevance-threshold requires extra decimal argument" << std::endl;
        print_help=true;
        break;
      }
    } else {
      if(arg[0]=='-') {
        std::cout << "Unrecognized option: " << arg << std::endl;
        print_help=true;
        break;
      }
      if(! inputfile_read) {
        infile=arg;
        inputfile_read=true;
      } else if(! outputfile_read) {
        outfile=arg;
        outputfile_read=true;
      } else {
        std::cerr << "Ignoring argument " << arg << std::endl;
      }
    }
    pos++;
  }

  if(!print_help && !inputfile_read) {
    std::cout << "Input file missing" << std::endl;
    print_help=true;
  }

  if(!print_help && !outputfile_read) {
    std::cout << "Output file missing" << std::endl;
    print_help=true;
  }

  if(print_help) {
    print_help_message(argv[0]);
    std::exit(0);
  }

  std::ifstream istr(infile);
  if(!istr) {
    std::cerr << "Could not open input file " << infile << std::endl;
    return 1;
  }

  Scc parser(istr);

  istr.close();

  std::cout << "Number of parameters: " << parser.number_of_parameters() << std::endl;

  int levels = parser.number_of_levels();
  std::cout << "Number of levels: " << levels << std::endl;

  std::cout << "Relevance threshold: " << relevance_threshold << std::endl;

  std::vector<double> slice_values;

  if(read_slice_values_from_file) {
    slice_values_from_file(slice_value_file,std::back_inserter(slice_values));
    slices = slice_values.size();
  } else {
    determine_slice_values(parser, level, slices, primary_parameter, std::back_inserter(slice_values));
  }

  std::cout << "Considering level  " << level << std::endl;
  std::cout << "Number of slices:  " << slices << std::endl;
  std::cout << "Primary parameter: " << primary_parameter << std::endl;

  std::cout << "Number of slice values: " << slice_values.size() << std::endl;

  // Without any slice, slice_values.size()-1 below would wrap around.
  if(slice_values.empty()) {
    std::cerr << "No slice values available, nothing to compute" << std::endl;
    return 1;
  }

  std::vector<Generator> gens;
  std::vector<Relation> rels;

  compute_generators_and_relations_sorted_by_secondary_parameter(parser,
								 level,
								 primary_parameter,
								 std::back_inserter(gens),
								 std::back_inserter(rels));

  std::cout << "We have " << gens.size() << " generators and " << rels.size() << " relations" << std::endl;

  typedef phat::boundary_matrix<phat::vector_vector> Boundary_matrix;

  std::vector<Boundary_matrix> matrices;

  matrices.resize(slice_values.size());

  for(int i=0;i<slice_values.size();i++) {

    double grade = slice_values[i];

    kernel_basis_at(gens,rels,primary_parameter,grade,matrices[i]);

  }

  std::vector<Induced_matrix> induced_matrices;
  induced_matrices.resize(slice_values.size()-1);

  std::cout << "Compute induced matrices" << std::endl;

  for(int i=1;i<slice_values.size();i++) {
    compute_induced_matrix(matrices[i-1],matrices[i],induced_matrices[i-1]);
  }

  // Now compute the output

  // Every bar on every slice needs a unique id. We compute offsets
  std::vector<long> offsets;
  long akku=0;
  offsets.push_back(akku);
  for(int i=0;i<slice_values.size();i++) {
    akku+=matrices[i].get_num_cols();
    offsets.push_back(akku);
  }
  std::cout << "In total, we have " << offsets.back() << " generators." << std::endl;

  std::ofstream ofstr(outfile);
  if(!ofstr) {
    std::cerr << "Could not open output file " << outfile << std::endl;
    return 1;
  }

  ofstr.precision(std::numeric_limits<double>::max_digits10);

  long running_index=0;
  long no_relevant_pairs=0;
  std::unordered_map<long,long> relevant_indices;

  long last_offset=0;

  std::vector<std::pair<std::pair<double,double>,std::pair<int,double>>> relevant_pairs;
  for(int i=0;i<slice_values.size();i++) {
    Boundary_matrix& m = matrices[i];
    long no_relevant_pairs_on_level=0;

    for(int j=0;j<m.get_num_cols();j++) {
      long curr_lowest = m.get_max_index(j);
      double birth = gens[curr_lowest].gr[1-primary_parameter];
      // Now we abuse the dimension field
      long rel_index = m.get_dim(j);

      bool relevant=true;

      double death;
      if(rel_index>=0) {
        death = rels[rel_index].gr[1-primary_parameter];
      } else {
        if(include_infinite_bars) {
          death = std::numeric_limits<double>::max();
        } else {
          relevant=false;
        }
      }

      // Short-circuit: death is only read when it has been assigned.
      relevant = relevant && birth<death;

      if(relevant && relevance_threshold>0) {
        relevant = (death-birth>=relevance_threshold);
      }

      if(relevant) {
        relevant_indices[running_index]=no_relevant_pairs;
        no_relevant_pairs++;
        no_relevant_pairs_on_level++;
        relevant_pairs.push_back(std::make_pair(std::make_pair(birth,death),std::make_pair(i,slice_values[i])));
      }
      running_index++;
    }
    std::cout << "Number of relevant pairs on slice " << i << " = " << no_relevant_pairs_on_level << ", interval " << last_offset << "; " << last_offset+no_relevant_pairs_on_level-1 << std::endl;
    last_offset += no_relevant_pairs_on_level;
  }

  // '\n' instead of std::endl: no flush per line; the file is flushed on close().
  ofstr << no_relevant_pairs << '\n';
  for(auto& pair : relevant_pairs) {
    ofstr << pair.first.first << " " << pair.first.second << " " << pair.second.first << " " << pair.second.second << '\n';
  }

  // Now the edges.

  int no_of_disjoint_pairs = 0;

  long no_edges = 0;

  for(int i=1;i<slice_values.size();i++) {
    long offset_rows = offsets[i];
    long offset_columns = offsets[i-1];
    Induced_matrix& ind_mat = induced_matrices[i-1];
    for(int j=0;j<ind_mat.size();j++) {
      long source_index = offset_columns+j;
      for(long row_index : ind_mat[j]) {
        long target_index = offset_rows+row_index;
        if(relevant_indices.count(source_index) && relevant_indices.count(target_index)) {
          long source_id = relevant_indices[source_index];
          long target_id = relevant_indices[target_index];
          auto source_pair = relevant_pairs[source_id];
          auto target_pair = relevant_pairs[target_id];
          assert(source_pair.first.first<=source_pair.first.second);
          assert(target_pair.first.first<=target_pair.first.second);

          bool disjoint = (source_pair.first.second<target_pair.first.first || target_pair.first.second < source_pair.first.first);

          if(disjoint) {
            no_of_disjoint_pairs++;
          }

          if(! (disjoint && filter_out_disjoint_pairs) ) {
            ofstr << source_id << " " << target_id << '\n';
            ofstr << target_id << " " << source_id << '\n';
            no_edges++;
          }
        }
      }
    }

  }

  std::cout << "Number of edges between disjoint bars: " << no_of_disjoint_pairs << std::endl;

  std::cout << "Graph has " << relevant_pairs.size() << " vertices and " << no_edges << " edges." << std::endl;

  ofstr.close();
  if(!ofstr) {
    std::cerr << "Error while writing output file " << outfile << std::endl;
    return 1;
  }

  return 0;

}
