import torch
import numpy as np 

class Task:
	"""Base class for in-context boolean concept-class tasks.

	Adapted from https://github.com/satwik77/incontext-bool.

	Records the input dimensionality (``n_dims``), the batch size (stored as
	``b_size``) and an optional ``pool_dict`` or ``seeds``. At most one of
	``pool_dict`` / ``seeds`` may be given.
	"""

	def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None):
		self.n_dims, self.b_size = n_dims, batch_size
		self.pool_dict, self.seeds = pool_dict, seeds
		# Providing both a pool and explicit seeds is rejected.
		assert not (pool_dict is not None and seeds is not None)
		
class Parity(Task):
	"""Random parity concept class over ``n_dims`` boolean inputs.

	Adapted from https://github.com/satwik77/incontext-bool.

	Each of the ``batch_size`` tasks draws one subset S of the coordinates
	uniformly from all 2**n_dims subsets, so each coordinate is included
	independently with probability 1/2. The target is
	f_S(x) = (sum_{j in S} x_j) mod 2 for x in {0, 1}^n_dims.

	Attributes:
		w_b: float tensor of shape (batch_size, n_dims, 1); w_b[i, j, 0] is 1
			when coordinate j is in task i's subset and 0 otherwise.
	"""

	def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None, scale=1):
		# ``scale`` is accepted for interface compatibility but is unused.
		super(Parity, self).__init__(n_dims, batch_size, pool_dict, seeds)
		# One uniformly random subset index per task. Bit j of funcs[i] marks
		# whether coordinate j belongs to task i's subset (the same encoding
		# generate_subsets uses). Decoding the bits directly avoids building all
		# 2**n_dims subsets, which is exponential in time and memory, while
		# consuming the RNG exactly as before (identical results per seed).
		funcs = np.random.choice(2**n_dims, size=batch_size)
		bits = (funcs[:, None] >> np.arange(n_dims)) & 1
		self.w_b = torch.zeros(size=(batch_size, n_dims, 1))
		self.w_b[:, :, 0] = torch.as_tensor(bits, dtype=self.w_b.dtype)

	def sample_xs(self, n_points, i=None):
		"""Sample ``n_points`` inputs uniformly from {0, 1}^n_dims.

		``i`` is accepted for interface compatibility and ignored.
		Returns a float tensor of shape (n_points, n_dims).
		"""
		xt = torch.randint(0, 2, (n_points, self.n_dims), dtype=torch.float)
		return xt

	def generate_subsets(self, n):
		"""Return all 2**n subsets of range(n) as lists, indexed by bitmask.

		Subset number k contains j exactly when bit j of k is set. Cost is
		O(n * 2**n); no longer needed by __init__ but kept for callers.
		"""
		return [[j for j in range(n) if k & (1 << j)] for k in range(2**n)]

	def evaluate(self, xt, i):
		"""Return parity labels of inputs ``xt`` under task ``i``.

		xt: float tensor whose last dimension is n_dims, expected to hold
			values in {0, 1} (as produced by sample_xs).
		Returns (xt @ w_b[i]) mod 2 with size-1 dimensions squeezed, so labels
		are in {0, 1} for {0, 1} inputs.
		"""
		w_b = self.w_b[i].to(xt.device)
		yt = (xt @ w_b).squeeze() % 2
		return yt
	
class Conjunction(Task):
	"""Random conjunction concept class over ``n_dims`` boolean inputs.

	Adapted from https://github.com/satwik77/incontext-bool.

	Each task draws a weight vector w in {0, 1, -1}^n_dims with probabilities
	(0.7, 0.15, 0.15). +1 marks a positive literal (coordinate must be 1);
	-1 marks a negated literal (coordinate must be 0); 0 marks an irrelevant
	coordinate. The label is 1 exactly when every literal is satisfied.

	Attributes:
		w_b: float tensor of shape (batch_size, n_dims, 1) with literal weights.
		kw: float tensor of shape (batch_size, 1) holding
			(number of literals) - 1, the threshold used by evaluate.
	"""

	def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None, scale=1):
		# ``scale`` is accepted for interface compatibility but is unused.
		super(Conjunction, self).__init__(n_dims, batch_size, pool_dict, seeds)
		self.w_b = torch.tensor(
			np.random.choice([0, 1, -1], size=(self.b_size, self.n_dims, 1), p=[0.7, 0.15, 0.15]),
			dtype=torch.float,
		)
		# L1 norm over the n_dims axis counts the literals; subtracting one
		# gives the threshold between "all literals hold" and "some fail".
		self.kw = torch.norm(self.w_b, p=1, dim=1) - 1

	def sample_xs(self, n_points, i=None):
		"""Sample ``n_points`` inputs in {0, 1}^n_dims for task ``i``.

		Each point starts uniform over {-1, 1}^n_dims. With probability 0.3 it
		is then overwritten on task i's literal coordinates so that it
		satisfies the conjunction (presumably so positive labels are not too
		rare). The result is mapped to {0, 1} and returned as a float tensor
		of shape (n_points, n_dims).

		Raises:
			ValueError: if ``i`` is None; sampling depends on task i's literals.
		"""
		if i is None:
			raise ValueError("Conjunction.sample_xs requires a task index i")
		xs_b = torch.randint(0, 2, (n_points, self.n_dims), dtype=torch.float) * 2 - 1

		wb, k = self.w_b[i], self.kw[i]
		pidx = [j for j in range(self.n_dims) if wb[j] == 1.0]
		nidx = [j for j in range(self.n_dims) if wb[j] == -1.0]
		for j in range(n_points):
			if np.random.choice([0, 1], p=[0.7, 0.3]):
				xs_b[j, pidx] = +1.0
				xs_b[j, nidx] = -1.0
				# Forced rows satisfy every literal, so their score is the
				# maximum kw + 1, which is at least the threshold k.
				assert (xs_b[j, :] @ wb).squeeze() >= k

		xt = (xs_b + 1) / 2
		return xt

	def evaluate(self, xt, i):
		"""Return conjunction labels in {0, 1} for inputs ``xt`` under task ``i``.

		xt: float tensor whose last dimension is n_dims, expected to hold
			values in {0, 1}.
		Inputs are mapped to {-1, 1}. For such inputs the score x @ w equals
		kw + 1 when all literals hold and is at most kw - 1 otherwise, so
		sign(score - kw) is +1 or -1 and (sign + 1) // 2 gives 1 or 0.
		"""
		xs_b = xt * 2 - 1
		w_b = self.w_b[i].to(xs_b.device)
		# kw is built in __init__ without an explicit device. Move it next to
		# xs_b: a non-scalar CPU tensor cannot be combined with a tensor on
		# another device.
		kw = self.kw[i].to(xs_b.device)
		ys_b = (xs_b @ w_b).squeeze() - kw
		return (ys_b.sign() + 1) // 2
	
# Registry: concept-class name -> Task subclass that implements it.
concept_classes = dict(
    parity=Parity,
    conjunction=Conjunction,
)