package bittwiddled;

import static bittwiddled.BankerOrder.INTEGER.lessThan;
import static bittwiddled.BankerOrder.INTEGER.moreThan;

import java.util.Iterator;
import java.util.Map;
import java.util.NavigableMap;
import java.util.NavigableSet;
import java.util.Set;
import java.util.Map.Entry;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
/**
 * The static calculations used when inverting a {@link Dynamic}.
 * <p>
 * Most of the methods come in several overloads, so that they can be called with common data
 * structures without the caller having to provide an adaptor.
 * <p>
 * Many of the methods also have side-effects on their arguments.  These side-effects are noted in
 * individual method documentation, but can be summarized: most of the methods that receive "reducible"
 * arguments will reduce those arguments as a side-effect.
 * <p>
 * States and terms are packed into {@code int}s, one bit per edge.  The {@link NavigableSet} and
 * {@link NavigableMap} arguments are expected to be sorted by bit count, smallest first; the
 * {@code lessThan}/{@code moreThan} bounds used to slice them rely on that ordering.
 * @author Carl A. Pearson
 *
 */
public final class InversionTools {

	private InversionTools() {}

	/**
	 * Calculates the non-inhibiting edges from a series of states that come before an active state,
	 * starting from the edges already known to be non-inhibiting.
	 * <p>
	 * Every edge set in any of the states is reported as non-inhibiting, so the result is the
	 * bitwise OR of {@code knownNonInhibition} and all of {@code states}.
	 * @param knownNonInhibition edges already known to be non-inhibiting
	 * @param states the series of states
	 * @return the 1 bits are the edges that cannot be inhibiting
	 */
	public static int calcNonInhibition(int knownNonInhibition, Iterable<Integer> states) {
		for (int state : states) knownNonInhibition |= state;
		return knownNonInhibition;
	}

	/**
	 * @see #calcNonInhibition(int, Iterable)
	 */
	public static int calcNonInhibition(int knownNonInhibition, int...states) {
		for (int state : states) knownNonInhibition |= state;
		return knownNonInhibition;
	}

	/**
	 * Assumes there are no known non-inhibitors.
	 * @see #calcNonInhibition(int, Iterable)
	 */
	public static int calcNonInhibition(Iterable<Integer> states) { return calcNonInhibition(0,states); }

	/**
	 * Assumes there are no known non-inhibitors.
	 * @see #calcNonInhibition(int, Iterable)
	 */
	public static int calcNonInhibition(int...states) { return calcNonInhibition(0,states); }

	/**
	 * Assumes there are no known non-inhibitors.  Gets the states from {@code b.beforeON()}.
	 * @see #calcNonInhibition(int, Iterable)
	 */
	public static int calcNonInhibition(BeforeStates b) { return calcNonInhibition(b.beforeON()); }

	/**
	 * Gets the states from {@code b.beforeON()}.
	 * @see #calcNonInhibition(int, Iterable)
	 */
	public static int calcNonInhibition(int knownNonInhibition, BeforeStates b) { return calcNonInhibition(knownNonInhibition, b.beforeON()); }

	/**
	 * Calculates the non-activating edges from a series of states that come before an inhibited state.
	 * <p>
	 * {@code states} is initially an identity map.  As a side-effect the {@code nonInhibition} edges are
	 * stripped from every key:
	 * <ul>
	 * <li>a key that becomes 0 is removed, and all of its bits are reported as non-activating;</li>
	 * <li>a key left with exactly one bit has that bit reported as non-activating;</li>
	 * <li>a key that changed is re-inserted under its stripped form, with the original key as its value,
	 * OR-ed with any value already stored under the stripped key.</li>
	 * </ul>
	 * @param states the series of states, as a map (initially an identity map); modified in place
	 * @param nonInhibition the non-inhibiting edges
	 * @return the 1 bits are the edges that cannot be activating
	 */
	public static int calcNonActivation(Map<Integer,Integer> states, int nonInhibition) {
		int nonActivation = 0;
		int candidateInhibitors = ~nonInhibition;
		// re-keyed entries; they cannot be inserted while iterating the key set
		Map<Integer,Integer> replace = Maps.newHashMap();

		Iterator<Integer> inhibitionTerms = states.keySet().iterator();
		while (inhibitionTerms.hasNext()) {
			int state = inhibitionTerms.next();
			int blend = state & candidateInhibitors;

			if (blend == 0) {
				// no candidate inhibitor remains: every edge of this key is reported non-activating
				nonActivation |= state;
				inhibitionTerms.remove();
				continue;
			}

			// a single remaining candidate inhibitor is reported non-activating
			if (Integer.bitCount(blend) == 1) nonActivation |= blend;

			if (blend != state) {
				inhibitionTerms.remove();
				Integer put = replace.get(blend);
				replace.put(blend, put == null ? state : put | state);
			}
		}

		for (Entry<Integer,Integer> term : replace.entrySet()) {
			Integer put = states.get(term.getKey());
			states.put(term.getKey(), put == null ? term.getValue() : put | term.getValue());
		}

		return nonActivation;
	}

	/**
	 * Gets {@code states} from {@code b.beforeOFF()}.
	 * @see #calcNonActivation(Map, int)
	 */
	public static int calcNonActivation(BeforeStates b, int nonInhibition) { return calcNonActivation(b.beforeOFF(),nonInhibition); }

	/**
	 * Calculates the activating edges, reducing {@code activationTerms} as a side-effect.
	 * <p>
	 * The {@code nonActivation} edges are stripped from every term.  A term left with exactly one edge
	 * contributes that edge to the result and is removed.  Any other term that changed is replaced by its
	 * stripped form, unless the stripped form shares an edge with the result.  The remaining terms are then
	 * reduced with {@link #reduceTerms(NavigableSet, int)} using the result as the fixed edges.
	 * @param activationTerms the activation terms; modified in place
	 * @param nonActivation the edges known to be non-activating
	 * @return the 1 bits are the edges identified as activating
	 */
	public static int calcActivation(NavigableSet<Integer> activationTerms, int nonActivation) {
		int activation = 0;
		if (activationTerms.isEmpty()) return activation;

		int candidateActivators = ~nonActivation;
		// stripped terms; they cannot be added while iterating the set
		Set<Integer> replace = Sets.newHashSetWithExpectedSize(activationTerms.size());
		Iterator<Integer> terms = activationTerms.iterator();
		while (terms.hasNext()) {
			int term = terms.next();
			int blend = term & candidateActivators;
			if (Integer.bitCount(blend) == 1) {
				activation |= blend;
				terms.remove();
			} else if (blend != term) {
				replace.add(blend);
				terms.remove();
			}
		}

		// stripped terms sharing an edge with the result are dropped (a no-op filter when activation == 0)
		for (int term : replace) {
			if ((term & activation) == 0) activationTerms.add(term);
		}

		reduceTerms(activationTerms, activation);

		return activation;
	}

	/**
	 * Gets {@code activationTerms} from {@code b.beforeON()}.
	 * @see #calcActivation(NavigableSet, int)
	 */
	public static int calcActivation(BeforeStates b, int nonActivation) {
		return calcActivation(b.beforeON(),nonActivation);
	}

	/**
	 * Removes every term that is a superset of another, smaller term in {@code terms}.
	 * <p>
	 * Walks the terms from smallest to largest bit count.  Each term removes its supersets among the terms
	 * with more bits.  Terms in the largest bit-count group never need to act as reducers.
	 * @param terms the terms to reduce, sorted by bit count; modified in place
	 */
	public static void reduceTerms(NavigableSet<Integer> terms) {
		// TODO might be worthwhile to do this as a reverse iteration - performance test?
		if (terms.isEmpty()) return;

		Integer term = terms.first();
		int count = Integer.bitCount(term);

		while (count < Integer.bitCount(terms.last())) {
			// only terms with more bits than this one can be strict supersets of it
			Iterator<Integer> tail = terms.tailSet(moreThan(count)).iterator();
			while (tail.hasNext()) if ((tail.next() & term) == term) tail.remove();

			if ((term = terms.higher(term)) == null) break;

			count = Integer.bitCount(term);
		}
	}

	/**
	 * Removes every term that shares a bit with {@code fixed}, then removes supersets among the survivors
	 * with {@link #reduceTerms(NavigableSet)}.
	 * @param terms the terms to reduce, sorted by bit count; modified in place
	 * @param fixed bits whose presence makes a term removable; 0 means superset reduction only
	 */
	public static void reduceTerms(NavigableSet<Integer> terms, int fixed) {
		if (fixed != 0) {
			// a separate full pass, so terms in the largest bit-count group are checked as well
			Iterator<Integer> it = terms.iterator();
			while (it.hasNext()) if ((it.next() & fixed) != 0) it.remove();
		}
		reduceTerms(terms);
	}

	/**
	 * Returns the first term that shares no bit with {@code fixed}, removing every term before it that does.
	 * @param terms the terms to scan; terms sharing a bit with {@code fixed} are removed
	 * @param fixed the bits that disqualify a term
	 * @return the first qualifying term, or 0 when none remains (indistinguishable from a genuine 0 term)
	 */
	public static int nextTerm(Iterable<Integer> terms, int fixed) {
		return nextTerm(terms.iterator(), fixed);
	}

	/**
	 * Iterator form of {@link #nextTerm(Iterable, int)}.  The iterator is left positioned just after the
	 * returned term, so {@code terms.remove()} removes that term.
	 * @param terms the iterator to advance; terms sharing a bit with {@code fixed} are removed
	 * @param fixed the bits that disqualify a term
	 * @return the first qualifying term, or 0 when none remains (indistinguishable from a genuine 0 term)
	 */
	public static int nextTerm(Iterator<Integer> terms, int fixed) {
		while (terms.hasNext()) {
			int term = terms.next();
			if ((term & fixed) == 0) return term;
			terms.remove();
		}
		// never hand back an already-removed term: a later terms.remove() would throw
		return 0;
	}

	/**
	 * Gets the arguments from {@code b.beforeOFF()}, {@code b.requiredBeforeOFF()} and {@code b.beforeON()}.
	 * @see #calcInhibition(NavigableMap, NavigableSet, NavigableSet, int, int)
	 */
	public static int calcInhibition(BeforeStates b, int knownActivation, int knownNonActivation) {
		return calcInhibition(b.beforeOFF(), b.requiredBeforeOFF(), b.beforeON(), knownActivation, knownNonActivation);
	}

	/**
	 * Calculates the inhibiting edges, moving entries out of {@code states} as a side-effect.
	 * <p>
	 * An entry is selected when its value contains a known activator, or equals or contains one of the
	 * {@code reducedActivation} terms.  Each selected entry is removed from {@code states}.  Its key then
	 * either joins the result (single-bit key) or is added to {@code reducedInhibition}.  Afterwards
	 * {@code reducedInhibition} is reduced against the result, and {@code states} is reduced with
	 * {@link #reduceMappings(NavigableMap, NavigableSet, int, int)}.
	 * @param states map from inhibition terms to activation terms; modified in place
	 * @param reducedInhibition required inhibition terms; added to and reduced in place
	 * @param reducedActivation required activation terms; only read
	 * @param knownActivation edges known to be activating
	 * @param knownNonActivation edges known to be non-activating
	 * @return the 1 bits are the edges identified as inhibiting
	 */
	public static int calcInhibition(
		NavigableMap<Integer,Integer> states,
		NavigableSet<Integer> reducedInhibition,
		NavigableSet<Integer> reducedActivation,
		int knownActivation,
		int knownNonActivation) {

		int knownInhibition = 0;

		Iterator<Entry<Integer,Integer>> inhibitionTerms = states.entrySet().iterator();
		while (inhibitionTerms.hasNext()) {
			Entry<Integer,Integer> term = inhibitionTerms.next();
			if (!coversActivation(term.getValue(), knownActivation, reducedActivation)) continue;

			int rTerm = term.getKey();
			if (Integer.bitCount(rTerm) == 1) {
				knownInhibition |= rTerm;
			} else reducedInhibition.add(rTerm);
			inhibitionTerms.remove();
		}

		reduceTerms(reducedInhibition,knownInhibition);

		reduceMappings(states,reducedInhibition,knownInhibition,knownNonActivation);

		return knownInhibition;
	}

	/** Tests whether {@code gNegTerm} contains a known activator, or equals or contains a {@code reducedActivation} term. */
	private static boolean coversActivation(int gNegTerm, int knownActivation, NavigableSet<Integer> reducedActivation) {
		if ((gNegTerm & knownActivation) != 0 || reducedActivation.contains(gNegTerm)) return true;
		// terms with fewer bits than gNegTerm are the only candidates for strict subsets
		for (int gTerm : reducedActivation.headSet(lessThan(Integer.bitCount(gNegTerm)), true)) {
			if ((gNegTerm & gTerm) == gTerm) return true;
		}
		return false;
	}

	/**
	 * Removes the keys of {@code states} that are already satisfied, then self-reduces the mapping with
	 * {@link #mappingSelfReduce(NavigableMap, int)}.
	 * @see #reduceMappings(NavigableMap, NavigableSet, int)
	 */
	public static void reduceMappings(
			NavigableMap<Integer, Integer> states,
			NavigableSet<Integer> reducedInhibition,
			int knownInhibition,
			int knownNonActivation) {

		//TODO two passes (key removal, then self-reduction) - would a single pass avoid duplicate iterator()/next() costs?

		// the 3-arg form handles knownInhibition == 0 as well; previously both forms ran in that case
		reduceMappings(states,reducedInhibition,knownInhibition);

		mappingSelfReduce(states, knownNonActivation);
	}

	/**
	 * Equivalent to {@link #reduceMappings(NavigableMap, NavigableSet, int)} with no known inhibitors.
	 */
	public static void reduceMappings(
			NavigableMap<Integer, Integer> states,
			NavigableSet<Integer> reducedInhibition) {
		reduceMappings(states, reducedInhibition, 0);
	}

	/**
	 * Removes every key of {@code states} that shares a bit with {@code knownInhibition}, equals a
	 * {@code reducedInhibition} term, or is a superset of one.
	 * <p>
	 * Every key is examined, including the smallest one.
	 * @param states map from inhibition terms to activation terms, sorted by key bit count; modified in place
	 * @param reducedInhibition required inhibition terms, sorted by bit count; only read
	 * @param knownInhibition edges known to be inhibiting
	 */
	public static void reduceMappings(
			NavigableMap<Integer, Integer> states,
			NavigableSet<Integer> reducedInhibition,
			int knownInhibition) {

		if (states.isEmpty() || (knownInhibition == 0 && reducedInhibition.isEmpty())) return;

		Iterator<Integer> terms = states.descendingKeySet().iterator(); //in reverse order
		int count = Integer.SIZE + 1; // above any real bit count, so the first non-zero key sets the slice
		NavigableSet<Integer> slice = reducedInhibition;

		while (terms.hasNext()) {
			int rTerm = terms.next();

			if ((rTerm & knownInhibition) != 0 || reducedInhibition.contains(rTerm)) {
				terms.remove();
				continue;
			}

			int bits = Integer.bitCount(rTerm);
			if (bits == 0) continue; // only 0 is a subset of 0, and contains() has already covered it

			if (bits < count) {
				count = bits;
				// required terms with fewer bits than rTerm: the only candidates for strict subsets
				slice = reducedInhibition.headSet(lessThan(count), true);
			}

			for (int reqRTerm : slice) if ((reqRTerm & rTerm) == reqRTerm) {
				terms.remove();
				break;
			}
		}
	}

	/**
	 * Takes a {@link NavigableMap} of inhibition terms to activation terms and reduces it.
	 * <p>
	 * The value of {@code knownNonActivation} is removed from each value.  Each key removes the contents of
	 * its value from all super-keys' values.  If a key's value is emptied, that key is removed; this also
	 * applies when the map holds a single entry.
	 * <p>
	 * When the remaining value shares bits with its own key, those bits are moved.  A single shared-bit value
	 * re-keys the entry to the key without that bit (merged into an existing entry, or added once iteration
	 * ends).  Otherwise, each shared bit whose "key without that bit" entry already exists is moved onto that
	 * entry's value.
	 * <p>
	 * Algorithmically, this occurs by traversing the Map keys in reverse (since it should be sorted from smallest
	 * key to largest key) and getting the keys smaller than the current key (i.e., the headMap for keys with counts
	 * less than the current key's count).
	 * <p>
	 * Having a key == 0 in the Map has the identical effect to specifying that key's value as {@code knownNonActivation}.
	 * @param mixedTerms a map linking inhibition terms to activation terms, navigable by inhibition term size, from smallest to largest size
	 * @param knownNonActivation the known non-activation terms
	 */
	public static void mappingSelfReduce(NavigableMap<Integer,Integer> mixedTerms, int knownNonActivation) {
		// no size-1 shortcut: a lone entry must be emptied/re-keyed exactly like any other entry
		if (mixedTerms.isEmpty()) return;

		// re-keyed entries whose new key does not exist; added after iteration
		Map<Integer,Integer> backfill = Maps.newHashMap();

		Iterator<Integer> rTerms = mixedTerms.descendingKeySet().iterator();
		int count = Integer.SIZE + 1; // above any real bit count, so even a 32-bit key narrows headTerms
		NavigableSet<Integer> headTerms = mixedTerms.navigableKeySet();
		int mayActivate = ~knownNonActivation;

		while (rTerms.hasNext()) {
			int rTerm = rTerms.next();
			int gNeg = mixedTerms.get(rTerm) & mayActivate;

			if (gNeg == 0) {
				rTerms.remove();
				continue;
			}

			if (Integer.bitCount(rTerm) < count) {
				count = Integer.bitCount(rTerm);
				// keys with fewer bits than rTerm: the only candidates for its sub-keys
				headTerms = headTerms.headSet(lessThan(count), true);
			}

			// each sub-key's value is removed from this entry's value
			for (int subRTerm : headTerms) if ((subRTerm & rTerm) == subRTerm) {
				gNeg &= ~mixedTerms.get(subRTerm);
				if (gNeg == 0) {
					rTerms.remove();
					break;
				}
			}

			if ((gNeg & rTerm) != 0) {
				if (Integer.bitCount(gNeg) == 1) {
					// single-bit value also present in the key: re-key the entry without that bit
					rTerms.remove();
					rTerm = rTerm ^ gNeg;
					if (headTerms.contains(rTerm)) {
						mixedTerms.put(rTerm, mixedTerms.get(rTerm) | gNeg);
					} else {
						Integer put = backfill.get(rTerm);
						backfill.put(rTerm, put == null ? gNeg : gNeg | put);
					}
				} else {
					// move each shared bit onto the entry keyed by rTerm without that bit, if it exists
					int ref = gNeg & rTerm;
					while (ref != 0) {
						int mask = BankerOrder.INTEGER.bitmask(Integer.numberOfTrailingZeros(ref));
						ref &= ~mask;
						if (mixedTerms.get(rTerm ^ mask) != null) {
							gNeg ^= mask;
							mixedTerms.put(rTerm ^ mask, mixedTerms.get(rTerm ^ mask) | mask);
						} else if (backfill.get(rTerm ^ mask) != null) {
							gNeg ^= mask;
							backfill.put(rTerm ^ mask, backfill.get(rTerm ^ mask) | mask);
						}
					}
					if (gNeg == 0) {
						rTerms.remove();
					} else mixedTerms.put(rTerm, gNeg);
				}
			} else if (gNeg != 0) {
				mixedTerms.put(rTerm, gNeg);
			}
		}

		mixedTerms.putAll(backfill);
	}

}
