import numpy as np
from random import shuffle
from math import comb
from games.oh_hell import Phase, Card
import algorithms.util as util

def try_build_new_state(state, deal):
	"""Replay `state`'s history on a fresh initial state, substituting `deal`.

	The replayed history is the SELECT_NUM_TRICKS actions of `state`, then the
	first `len(DEAL actions of state)` entries of `deal`, then the BID and
	CARDPLAY actions of `state`.

	Returns the rebuilt state, or None as soon as an action is not among the
	legal actions of the player to move or `play` raises ValueError.
	"""
	rebuilt = state.get_game().new_initial_state()
	tricks_part = state.get_phase_actions(Phase.SELECT_NUM_TRICKS)
	# only the length of the original deal is needed: it tells us how much of
	# `deal` replaces the DEAL phase
	deal_length = len(state.get_phase_actions(Phase.DEAL))
	bid_part = state.get_phase_actions(Phase.BID)
	play_part = state.get_phase_actions(Phase.CARDPLAY)
	replay = tricks_part + deal[0:deal_length] + bid_part + play_part

	for step in replay:
		mover = rebuilt.get_player_to_move()
		if step not in rebuilt.get_legal_actions(mover):
			return None
		try:
			rebuilt.play(mover, step)
		except ValueError:
			return None
	return rebuilt

def generate_neighbor(state):
	"""Return a state whose deal differs from `state`'s by a single position swap.

	The deal is laid out as the DEAL-phase actions of `state` (position 0 is
	skipped when swapping; position k >= 1 is attributed to player
	`(k - 1) % num_players`) followed by every card id that was not dealt.
	Candidate swaps pair a dealt position with any later position. A candidate
	is kept when it stays inside one player's hand, or when it respects the
	players' played cards and voids. Candidates are tried in random order and
	the first one that `try_build_new_state` can replay is returned. The
	identity swap (1, 1) is always a candidate.

	Raises:
		ValueError: if no candidate can be replayed into a state.
	"""
	game = state.get_game()
	num_players = game.num_players
	num_suits = game.num_suits
	num_ranks = game.num_ranks
	played_cards = [state.get_played_cards(p) for p in range(num_players)]
	voids = [state.get_voids(p) for p in range(num_players)]
	initial_deal = state.get_phase_actions(Phase.DEAL)
	num_dealt = len(initial_deal)

	# Extend a copy: appending to the list handed back by the state could
	# alter the state's own history if that list is not a fresh one.
	dealt = set(initial_deal)
	deal = list(initial_deal)
	deal.extend(card for card in range(game.num_cards()) if card not in dealt)
	suits = [Card.from_int(card, num_suits, num_ranks).suit() for card in deal]

	swaps = [(1, 1)]
	for i in range(1, num_dealt):
		player_i = (i - 1) % num_players
		i_unplayed = deal[i] not in played_cards[player_i]
		for j in range(i + 1, len(deal)):
			if j >= num_dealt:
				# Undealt slot (the first one is index num_dealt): no player
				# owns it, so only player_i's constraints apply.
				if i_unplayed and suits[j] not in voids[player_i]:
					swaps.append((i, j))
				continue
			player_j = (j - 1) % num_players
			if player_i == player_j:
				# reordering within one hand
				swaps.append((i, j))
				continue
			played_consistent = i_unplayed and deal[j] not in played_cards[player_j]
			void_consistent = suits[j] not in voids[player_i] and suits[i] not in voids[player_j]
			if played_consistent and void_consistent:
				swaps.append((i, j))

	shuffle(swaps)
	for a, b in swaps:
		candidate = deal.copy()
		candidate[a], candidate[b] = deal[b], deal[a]
		neighbor = try_build_new_state(state, candidate)
		# compare against None explicitly so a state object that happens to
		# be falsy is still accepted
		if neighbor is not None:
			return neighbor
	raise ValueError("No neighbors to transition to.")


def count_histories(suit_config, suit_sums):
	"""Count the card assignments that realise a suit-length configuration.

	`suit_config` is a flat sequence of per-suit counts, `num_suits` entries
	per row; the final row is not multiplied in. `suit_sums` is indexed as
	`suit_sums[0, suit]` and gives the number of cards available per suit.
	For each counted entry the result is multiplied by
	comb(cards of that suit still unassigned, entry).

	Returns the product as a float.
	"""
	num_suits = suit_sums.shape[1]
	# entries at or beyond this position belong to the last row
	cutoff = ((len(suit_config) / num_suits) - 1) * num_suits
	already_taken = [0.] * num_suits
	total = 1.
	for position, wanted in enumerate(suit_config):
		if position >= cutoff:
			return total
		_, suit = divmod(position, num_suits)
		available = int(suit_sums[0, suit] - already_taken[suit])
		total *= comb(available, int(wanted))
		already_taken[suit] += wanted
	return total


def sample_history(suit_config, orig_state):
	"""Sample a deal matching `suit_config` and replay `orig_state`'s actions on it.

	Args:
		suit_config: flat sequence where entry `p * num_suits + s` is the number
			of additional cards of suit `s` drawn for player `p`. Entries past
			the first `num_players * num_suits` are not read.
		orig_state: state whose played cards, trump and non-deal actions are
			reused.

	Each player's hand starts from the cards that player has already played
	and is topped up with cards drawn without replacement from the cards that
	are neither the trump card nor a CARDPLAY action. The deal is then the
	trump card followed by the hands interleaved one card per player.

	Returns:
		The result of `try_build_new_state(orig_state, deal)`, which is None
		when the replay is rejected.
	"""
	game = orig_state.get_game()
	num_players = game.num_players
	num_suits = game.num_suits
	num_ranks = game.num_ranks
	# Copy each list: the hands are extended and shuffled below, and whether
	# get_played_cards hands out a fresh list is not visible from here.
	player_deal = [list(orig_state.get_played_cards(p)) for p in range(num_players)]
	unavailable = set(orig_state.get_phase_actions(Phase.CARDPLAY))
	trump = orig_state.trump().to_int(num_ranks)
	unavailable.add(trump)

	rem_cards = [[] for _ in range(num_suits)]
	for card in range(game.num_cards()):
		if card not in unavailable:
			rem_cards[Card.from_int(card, num_suits, num_ranks).suit()].append(card)

	for p in range(num_players):
		for s in range(num_suits):
			count = int(suit_config[p * num_suits + s])
			# choose "count" cards of suit s to put in p's deal
			if count > 0:
				# tolist() yields plain Python ints, so the rebuilt history
				# does not carry numpy scalar types
				drawn = np.random.choice(rem_cards[s], size=count, replace=False).tolist()
				taken = set(drawn)
				rem_cards[s] = [x for x in rem_cards[s] if x not in taken]
				player_deal[p].extend(drawn)

	# make sure to put the deal in random order so any history can be generated
	for hand in player_deal:
		shuffle(hand)
	deal = [trump]
	# the first player's hand length drives the interleaving, one card per
	# player per round
	for i in range(len(player_deal[0])):
		deal.extend(hand[i] for hand in player_deal)

	# create a new state and set the deal and played cards to it
	return try_build_new_state(orig_state, deal)


# returns a sampled neighbor and the total number of neighbors to the current history
def generate_neighbor_ring_swap(state):
	"""Sample a neighboring history via ring swaps on the suit-length table.

	Builds a table of unplayed-card counts per suit with one row per player
	plus a final row for the cards that were not dealt, enumerates the tables
	collected by `ring_swap` (the current table is always included), weights
	each table with `count_histories`, draws one table in proportion to its
	weight and asks `sample_history` for a history matching it.

	Returns a pair: the result of `sample_history`, and the sum of the
	`count_histories` weights over all collected tables.
	"""
	game = state.get_game()
	num_players = game.num_players
	num_suits = game.num_suits
	num_ranks = game.num_ranks

	def suit_of(card):
		return Card.from_int(card, num_suits, num_ranks).suit()

	played_cards = [state.get_played_cards(p) for p in range(num_players)]
	# extra empty entry: the undealt row has no void restrictions
	voids = [state.get_voids(p) for p in range(num_players)] + [[]]
	deal = state.get_phase_actions(Phase.DEAL)
	undealt = [card for card in range(game.num_cards()) if card not in deal]

	# suit-length table: rows 0..num_players-1 are the players, last row is the undealt pile
	counts = [np.zeros(num_suits) for _ in range(num_players + 1)]
	for card in undealt:
		counts[-1][suit_of(card)] += 1.
	for position, card in enumerate(deal):
		owner = (position - 1) % num_players
		# position 0 is skipped, and so are cards the owner has already played
		if position > 0 and card not in played_cards[owner]:
			counts[owner][suit_of(card)] += 1.

	# collect every table reachable by a ring swap; each search starts from a
	# fresh copy of the table with one unit moved between two columns of a row
	base = np.matrix(counts)
	sums = np.sum(base, axis=0)
	neighbors = set()
	neighbors.add(tuple(base.A1))
	width = len(counts[0])
	for row in range(len(counts)):
		blocked = voids[row]
		for gain_col in range(width):
			for loss_col in range(gain_col + 1, width):
				if gain_col in blocked or loss_col in blocked:
					continue
				if not counts[row][loss_col] > 0:
					continue
				start = np.matrix(counts)
				start[row, gain_col] += 1
				start[row, loss_col] -= 1
				ring_swap(start, voids, num_suits, sums, neighbors)

	# weight each table by the number of histories it stands for
	configs = list(neighbors)
	weights = [count_histories(config, sums) for config in configs]
	hist_sum = 0.
	for weight in weights:
		hist_sum += weight
	p = [weight / hist_sum for weight in weights]

	chosen = configs[np.random.choice(range(len(configs)), p=p)]
	# sample a history uniformly from the chosen table
	return sample_history(chosen, state), hist_sum



def get_imbalanced_cols(m, c_sums):
	"""Return the column indices of `m` whose sum exceeds the target in `c_sums`.

	Both the column sums of `m` and `c_sums` are indexed as `[0, col]`.
	"""
	totals = np.sum(m, axis=0)[0]
	over_target = []
	for col in range(totals.shape[-1]):
		if totals[0, col] > c_sums[0, col]:
			over_target.append(col)
	return over_target


# needs to be called with initial swap already made
def ring_swap(m, voids, max_depth, c_sums, neighbors):
	"""Recursively move units out of the over-full column of `m`.

	`m` must have exactly one column whose sum exceeds its target in `c_sums`.
	Every move shifts one unit within a single row from that column to another
	column that is not listed in `voids[row]`. When a move leaves no over-full
	column, the flattened table is added to `neighbors`; otherwise the search
	recurses with `max_depth - 1`. `m` is modified in place and each move is
	undone before the next one is tried.

	Returns None. Does nothing when `max_depth <= 0`.

	Raises:
		ValueError: if the number of over-full columns is not exactly one.
	"""
	if max_depth <= 0:
		return None
	overfull = get_imbalanced_cols(m, c_sums)
	# should only have a single column with too many units at any time
	if len(overfull) != 1:
		raise ValueError("Matrix too imbalanced")
	surplus_col = overfull[0]

	# every (row, target column) that can take a unit from the surplus column
	num_rows, num_cols = m.shape[0], m.shape[1]
	moves = [
		(row, col)
		for row in range(num_rows)
		for col in range(num_cols)
		if col not in voids[row] and col != surplus_col and m[row, surplus_col] > 0
	]

	for row, col in moves:
		# apply the move
		m[row, col] += 1
		m[row, surplus_col] -= 1
		if get_imbalanced_cols(m, c_sums):
			# still unbalanced: keep searching one level deeper
			ring_swap(m, voids, max_depth - 1, c_sums, neighbors)
		else:
			neighbors.add(tuple(m.copy().A1))
		# revert the move
		m[row, col] -= 1
		m[row, surplus_col] += 1


class GibbsSampler(object):
	"""Markov-chain sampler over game states built on `generate_neighbor_ring_swap`.

	Each step proposes a neighbor of the current state and accepts it with
	probability min(1, reach ratio * neighbor ratio), where the reach ratio
	compares the policy reach probabilities of the proposal and the current
	state, and the neighbor ratio compares the totals returned by
	`generate_neighbor_ring_swap` for the current state and for the proposal.
	"""

	def __init__(self, initial_state, policy):
		self._current_state = initial_state
		self._policy = policy
		self.name = "Gibbs"
		# start empty so mc_estimate can raise its "no samples yet" ValueError
		# instead of failing on a missing attribute
		self._samples = []

	def sample(self, num_burn=0):
		"""Run `num_burn + 1` accept/reject steps and return the current state."""
		for _ in range(num_burn + 1):
			candidate, num_neighbors = generate_neighbor_ring_swap(self._current_state)
			# only the neighbor total of the proposal is needed here
			_, candidate_neighbors = generate_neighbor_ring_swap(candidate)
			reach_ratio = candidate.get_reach_probability(self._policy) / self._current_state.get_reach_probability(self._policy)
			neighbor_ratio = num_neighbors / candidate_neighbors
			transition_probability = min(1., reach_ratio * neighbor_ratio)
			if np.random.rand() < transition_probability:
				self._current_state = candidate
		return self._current_state

	def generate_samples(self, num_samples, player=0):
		"""Advance the chain and record `player`'s expected value after each step.

		Replaces any previously stored samples and returns the new list.
		"""
		self._samples = []
		# NOTE(review): this records num_samples + 1 values, not num_samples —
		# confirm whether the extra sample is intended.
		for _ in range(num_samples + 1):
			self._samples.append(util.expected_value(self.sample(), self._policy)[player])
		return self._samples

	def mc_estimate(self, num_burn, eval_every):
		"""Return running means over the stored samples.

		Sample `i` is included in the running set when `i % num_burn == 0`;
		the mean of the running set is recorded whenever `i % eval_every == 0`.

		Raises:
			ValueError: if no samples are stored, or if `num_burn` or
				`eval_every` is zero.
		"""
		if len(self._samples) <= 0:
			raise ValueError("Must call generate_samples first.")
		if num_burn == 0 or eval_every == 0:
			# both are used as moduli below
			raise ValueError("num_burn and eval_every must be non-zero.")
		samples_used = []
		estimates = []
		for i, value in enumerate(self._samples):
			if i % num_burn == 0:
				samples_used.append(value)
			if i % eval_every == 0:
				estimates.append(np.mean(samples_used))
		return estimates