#include "algebra.h"

#include <cassert>
#include <map>

std::ostream& operator<< (std::ostream &os, const Variable &var) {
	return var.print(os);
}

Monomial::Monomial () : variables(0), coefficient(1) {}

void Monomial::setCoefficient (int coefficient) {
	this->coefficient = coefficient;
}

void Monomial::addFactor (const std::shared_ptr<Variable> &variable, unsigned exponent) {
	this->variables.emplace_back(variable, exponent);
}

/**
 * Removes the most recently added factor (the last one appended by addFactor).
 * Precondition: the monomial has at least one factor. Calling pop_back() on an
 * empty container is undefined behaviour, so debug builds assert on it.
 */
void Monomial::popFactor () {
	assert(!this->variables.empty() && "popFactor called on a monomial with no factors");
	this->variables.pop_back();
}

/**
 * Prints a monomial as coefficient followed by its factors, e.g. "3λ_1,2^2ω_2,3".
 * - A negative coefficient is parenthesised: "(-2)λ_1,2".
 * - A coefficient of 1 is omitted when there is at least one factor. A
 *   constant monomial (no factors) always prints its coefficient, so the
 *   monomial 1 renders as "1" rather than as an empty string.
 * - An exponent of 1 is omitted; any other exponent is written as "^e".
 */
std::ostream& operator<< (std::ostream &os, const Monomial &monomial) {
	const bool isConstant = monomial.variables.empty();
	if (monomial.coefficient < 0) {
		os << "(" << monomial.coefficient << ")";
	} else if (monomial.coefficient != 1 || isConstant) {
		os << monomial.coefficient;
	}
	for (const auto &[base, exponent] : monomial.variables) {
		os << *base;
		if (exponent != 1) os << "^" << exponent;
	}
	return os;
}

/**
 * Returns (a * b) mod prime as a canonical residue in [0, prime).
 *
 * The product is formed in 128-bit arithmetic, so it cannot overflow for any
 * long long operands. `__int128` is a GCC/Clang extension; use something else
 * for other compilers.
 *
 * C++'s built-in '%' truncates toward zero, so a negative operand would give
 * a negative remainder. Callers such as getA/getB assume every term is already
 * in [0, prime), for example in `(x - y + prime) % prime`, so the result is
 * shifted back into that range.
 *
 * @param prime modulus; assumed to be positive
 */
long long multiply (long long a, long long b, long long prime) {
	const __int128 product = static_cast<__int128>(a) * b;
	long long remainder = static_cast<long long>(product % prime);
	if (remainder < 0) {
		remainder += prime;
	}
	return remainder;
}

/**
 * Computes a^b (mod p) by recursive square-and-multiply. Every product goes
 * through multiply(), so intermediates are formed in 128-bit arithmetic.
 *
 * @param a base; it is not reduced up front, since each multiply() call reduces its product
 * @param b exponent; expected to be non-negative. For a negative b, the
 *          truncating b / 2 and b % 2 make the result 1 (when p > 1).
 * @param p modulus
 * @return a^b mod p. The b == 0 case returns 1 without reducing it mod p.
 */
long long fastExp (long long a, long long b, long long p) {
	if (b == 0ll) return 1ll;
	const long long half = fastExp(a, b / 2, p);
	const long long squared = multiply(half, half, p);
	const bool odd = (b % 2 == 1);
	return odd ? multiply(squared, a, p) : squared;
}

/**
 * Modular inverse via Fermat's little theorem: a^(p-2) ≡ a^(-1) (mod p).
 * Valid only when p is prime and a is not a multiple of p. When a ≡ 0 (mod p)
 * no inverse exists; the function then returns 0 (for p > 2) rather than
 * signalling an error, so callers must check for that themselves.
 */
long long modInv (long long a, long long p) {
	const long long fermatExponent = p - 2;
	return fastExp(a, fermatExponent, p);
}

Polynomial::Polynomial () : monomials(0) {}

void Polynomial::addMonomial (const Monomial &monomial) {
	this->monomials.push_back(monomial);
}

std::ostream& operator<< (std::ostream &os, const Polynomial &polynomial) {
	bool first = true;
	for (const Monomial &monomial : polynomial.monomials) {
		if (!first) os << " + ";
		else first = false;
		os << monomial;
	}
	if (first) {
		os << "0";
	}
	return os;
}

/**
 * Default constructor. Zero-initialises every component, so reading a
 * default-constructed FASTP (through the getters or operator<<) is
 * well-defined instead of reading indeterminate values.
 * NOTE(review): if algebra.h gives these members non-zero in-class defaults,
 * this overrides them. Confirm against the header.
 */
FASTP::FASTP () : p(0), q(0), r(0), s(0), t(0) {}

/**
 * Builds the value (numConst + numCoef·√radicand) / (denConst + denCoef·√radicand),
 * which is the layout operator<< prints. The arguments map to members p, q, r, s, t in order.
 */
FASTP::FASTP (long long numConst, long long numCoef, long long denConst, long long radicand, long long denCoef)
	: p(numConst), q(numCoef), r(denConst), s(radicand), t(denCoef) {}

long long FASTP::getP () const {
	return this->p;
}

long long FASTP::getQ () const {
	return this->q;
}

long long FASTP::getR () const {
	return this->r;
}

long long FASTP::getS () const {
	return this->s;
}

long long FASTP::getT () const {
	return this->t;
}

void FASTP::setP (long long p) {
	this->p = p;
}

void FASTP::setQ (long long q) {
	this->q = q;
}

void FASTP::setR (long long r) {
	this->r = r;
}

void FASTP::setS (long long s) {
	this->s = s;
}

void FASTP::setT (long long t) {
	this->t = t;
}

/** Prints the value as "(p + q sqrt(s))/(r + t sqrt(s))". */
std::ostream& operator<< (std::ostream &os, const FASTP &fastp) {
	return os << "(" << fastp.p << " + " << fastp.q << " sqrt(" << fastp.s << "))"
		<< "/(" << fastp.r << " + " << fastp.t << " sqrt(" << fastp.s << "))";
}

/**
 * With i = (p + q√s)/(r + t√s) and j = (P + Q√s)/(R + T√s), returns the
 * rational (non-√s) part of
 *   (p+q√s)(P+Q√s)·σ_pq − (p+q√s)(R+T√s)·σ_pj − (r+t√s)(P+Q√s)·σ_iq + (r+t√s)(R+T√s)·σ_ij
 * reduced mod `prime`. Expanded, that is
 *   (pP + qQs)σ_pq − (pR + qTs)σ_pj − (rP + tQs)σ_iq + (rR + tTs)σ_ij.
 * Only i's radicand s is read, so i and j are presumably expected to share
 * the same s (TODO confirm with callers).
 * The `+ prime` in each subtraction keeps the running sum non-negative when
 * every term is already in [0, prime).
 */
long long getA (const FASTP &i, const FASTP &j, long long prime,
		long long sigma_pq, long long sigma_iq, long long sigma_pj, long long sigma_ij) {

	const long long p = i.getP(), q = i.getQ(), r = i.getR(), s = i.getS(), t = i.getT();
	const long long P = j.getP(), Q = j.getQ(), R = j.getR(), T = j.getT();

	// Shorthands that perform exactly the same arithmetic as the expanded form.
	const auto mul = [prime](long long x, long long y) { return multiply(x, y, prime); };
	const auto addMod = [prime](long long acc, long long term) { return (acc + term) % prime; };
	const auto subMod = [prime](long long acc, long long term) { return (acc - term + prime) % prime; };

	long long acc = mul(mul(p, P), sigma_pq);
	acc = subMod(acc, mul(mul(p, R), sigma_pj));
	acc = subMod(acc, mul(mul(r, P), sigma_iq));
	acc = addMod(acc, mul(mul(r, R), sigma_ij));
	acc = addMod(acc, mul(mul(mul(q, Q), s), sigma_pq));
	acc = subMod(acc, mul(mul(mul(q, T), s), sigma_pj));
	acc = subMod(acc, mul(mul(mul(t, Q), s), sigma_iq));
	return addMod(acc, mul(mul(mul(t, T), s), sigma_ij));
}

/**
 * Companion of getA. With i = (p + q√s)/(r + t√s) and j = (P + Q√s)/(R + T√s),
 * returns the coefficient of √s in
 *   (p+q√s)(P+Q√s)·σ_pq − (p+q√s)(R+T√s)·σ_pj − (r+t√s)(P+Q√s)·σ_iq + (r+t√s)(R+T√s)·σ_ij
 * reduced mod `prime`. Expanded, that is
 *   (pQ + qP)σ_pq − (pT + qR)σ_pj − (tP + rQ)σ_iq + (rT + tR)σ_ij.
 * The radicand is not needed here, so neither s is read.
 */
long long getB (const FASTP &i, const FASTP &j, long long prime,
		long long sigma_pq, long long sigma_iq, long long sigma_pj, long long sigma_ij) {

	const long long p = i.getP(), q = i.getQ(), r = i.getR(), t = i.getT();
	const long long P = j.getP(), Q = j.getQ(), R = j.getR(), T = j.getT();

	// Shorthands that perform exactly the same arithmetic as the expanded form.
	const auto mul = [prime](long long x, long long y) { return multiply(x, y, prime); };
	const auto addMod = [prime](long long acc, long long term) { return (acc + term) % prime; };
	const auto subMod = [prime](long long acc, long long term) { return (acc - term + prime) % prime; };

	long long acc = mul(mul(p, Q), sigma_pq);
	acc = addMod(acc, mul(mul(P, q), sigma_pq));
	acc = subMod(acc, mul(mul(p, T), sigma_pj));
	acc = subMod(acc, mul(mul(q, R), sigma_pj));
	acc = subMod(acc, mul(mul(t, P), sigma_iq));
	acc = subMod(acc, mul(mul(r, Q), sigma_iq));
	acc = addMod(acc, mul(mul(r, T), sigma_ij));
	return addMod(acc, mul(mul(R, t), sigma_ij));
}

/** Default constructor: numerator and denominator are default-constructed Polynomials. */
Fraction::Fraction () = default;

/** Mutable access to the numerator polynomial p (printed as "(p)/(q)"). */
Polynomial& Fraction::getP () {
	return p;
}

/** Mutable access to the denominator polynomial q (printed as "(p)/(q)"). */
Polynomial& Fraction::getQ () {
	return q;
}

void Fraction::setP (const Polynomial &p) {
	this->p = p;
}

void Fraction::setQ (const Polynomial &q) {
	this->q = q;
}

/** Prints the fraction as "(numerator)/(denominator)". */
std::ostream& operator<< (std::ostream &os, const Fraction &fraction) {
	os << "(" << fraction.p;
	os << ")/(" << fraction.q;
	os << ")";
	return os;
}

/**
 * Builds a cycle over the given node sequence. A new cycle starts out
 * two-identifiable. The reason indices are zeroed so they hold a defined
 * value until a set*() call records a reason; copying a Cycle would
 * otherwise copy indeterminate values.
 * NOTE(review): redundant if algebra.h already default-initialises reasonI/reasonJ.
 */
Cycle::Cycle (const std::vector<size_t> &nodes)
	: nodes(nodes), identifiability(twoIdentifiable), reasonI(0), reasonJ(0) {}

void Cycle::setTwoIdentifiable () {
	this->identifiability = this->twoIdentifiable;
}

void Cycle::setOneIdentifiableAZero (size_t reasonI, size_t reasonJ) {
	this->identifiability = this->oneIdentifiableAZero;
	this->reasonI = reasonI;
	this->reasonJ = reasonJ;
}

void Cycle::setOneIdentifiableDiscriminantZero () {
	this->identifiability = this->oneIdentifiableDiscriminantZero;
}

void Cycle::setOneIdentifiableOneOption (size_t reasonI, size_t reasonJ) {
	this->identifiability = this->oneIdentifiableOneOption;
	this->reasonI = reasonI;
	this->reasonJ = reasonJ;
}

bool Cycle::isTwoIdentifiable () const {
	return this->identifiability == this->twoIdentifiable;
}

std::ostream& operator<< (std::ostream &os, const Cycle &cycle) {
	for (size_t v : cycle.nodes) {
		os << v << "-";
	}
	os << cycle.nodes[0];
	if (!cycle.isTwoIdentifiable()) {
		os << ". Reason: ";
		switch(cycle.identifiability) {
			case cycle.oneIdentifiableAZero:
				os << "a = 0 for λ_" << cycle.reasonI << "," << cycle.reasonJ;
				break;
			case cycle.oneIdentifiableDiscriminantZero:
				os << "Δ = 0";
				break;
			case cycle.oneIdentifiableOneOption:
				os << "Equation for missing edge {" << cycle.reasonI << ", " << cycle.reasonJ
					<< "} only satisfied by one of the options";
				break;
			default:
				assert(false);
		}
	}
	return os;
}

/** Builds a path from a copy of the given node sequence. */
Path::Path (const std::vector<size_t> &sequence) : nodes(sequence) {}

/**
 * Returns the last node of the path.
 * Precondition: the path is non-empty. back() on an empty container is
 * undefined behaviour, so debug builds assert on it.
 */
size_t Path::getBack () const {
	assert(!this->nodes.empty() && "getBack called on an empty path");
	return this->nodes.back();
}

/** Prints every node preceded by "->", including the first (e.g. "->1->2->3"). */
std::ostream& operator<< (std::ostream &os, const Path &path) {
	for (auto node = path.nodes.begin(); node != path.nodes.end(); ++node) {
		os << "->" << *node;
	}
	return os;
}

/** Edge parameter λ_{x,y}; its value is drawn once from randomTool->getRandomSmallerPrime(). */
Lambda::Lambda (const std::shared_ptr<RandomTool> &tool, size_t from, size_t to)
	: x(from), y(to), value(tool->getRandomSmallerPrime()) {}

/** Parameter ω_{x,y}; its value is drawn once from randomTool->getRandomSmallerPrime(). */
Omega::Omega (const std::shared_ptr<RandomTool> &tool, size_t from, size_t to)
	: x(from), y(to), value(tool->getRandomSmallerPrime()) {}

Sigma::Sigma (const std::shared_ptr<RandomTool> &randomTool, size_t x, size_t y) : randomTool(randomTool), x(x), y(y), inTermsOfLambdaAndOmega(0) {}

std::ostream& Lambda::print (std::ostream &os) const {
	return os << "λ_" << this->x << "," << this->y;
}

std::ostream& Omega::print (std::ostream &os) const {
	return os << "ω_" << this->x << "," << this->y;
}

std::ostream& Sigma::print (std::ostream &os) const {
	return os << "σ_" << this->x << "," << this->y;
}

long long Lambda::getValue () const {
	return this->value;
}

long long Omega::getValue () const {
	return this->value;
}

void Sigma::setInTermsOfLambdaAndOmega (long long value) {
	this->inTermsOfLambdaAndOmega = value;
}

long long Sigma::getInTermsOfLambdaAndOmega () const {
	return this->inTermsOfLambdaAndOmega;
}
