import pandas as pd
import pickle 
import os
import numpy as np
import torch
from time import time
import pdb

def squared_error(ys_pred, ys):
	'''
	Element-wise squared difference between the targets `ys` and the predictions `ys_pred`.
	No reduction is applied; the result has the broadcast shape of the two inputs.
	'''
	return torch.square(ys - ys_pred)


def mean_squared_error(ys_pred, ys):
	'''
	Squared difference between `ys` and `ys_pred`, averaged over every element.
	'''
	residual = ys - ys_pred
	return torch.mean(torch.square(residual))


def accuracy(ys_pred, ys):
	'''
	Per-element correctness of sign predictions.

	Returns a float tensor holding 1.0 where `ys` equals the sign of `ys_pred`
	and 0.0 elsewhere (no averaging is done here).
	'''
	predicted_sign = torch.sign(ys_pred)
	is_correct = ys == predicted_sign
	return is_correct.float()


def multiclass_accuracy(ys_pred, ys):
	'''
	Fraction of predictions that equal the targets `ys`.

	A 2-D `ys_pred` is treated as per-class scores and reduced with argmax over
	dim 1; any other rank is compared against `ys` as-is.
	'''
	if ys_pred.dim() != 2:
		predicted_classes = ys_pred
	else:
		predicted_classes = ys_pred.argmax(dim=1)

	# Boolean hits -> float so that the mean is the share of correct predictions
	hits = predicted_classes == ys
	return hits.float().mean()


# torch.nn module instances built once at import time with default arguments.
sigmoid = torch.nn.Sigmoid()
bce_loss = torch.nn.BCELoss()

# Multiclass loss; this is the object returned by IndexTask.get_training_metric.
cross_entropy_loss = torch.nn.CrossEntropyLoss()

def cross_entropy(ys_pred, ys):
	'''
	Binary cross-entropy between raw scores and {-1, 1} labels.

	ys_pred: logits in [-inf, inf]
	ys: labels in {-1, 1}, same shape as ys_pred

	Returns the loss averaged over all elements (scalar tensor).
	'''
	# Map the labels {-1, 1} onto the {0, 1} targets the loss expects
	target = (ys + 1) / 2
	# Work directly on the logits: applying a sigmoid first and then BCELoss
	# saturates for large |ys_pred| and caps the loss instead of letting it
	# grow with the size of the error.
	return torch.nn.functional.binary_cross_entropy_with_logits(ys_pred, target)


class CLFTask:
	'''
	Common base for the classification tasks in this file.

	Stores the constructor arguments and declares the hooks (evaluate,
	generate_pool_dict, get_metric, get_training_metric); each hook raises
	NotImplementedError until a subclass overrides it.
	'''

	def __init__(self, length, batch_size, pool_dict=None, seeds=None):
		'''
		length: stored as self.length
		batch_size: stored as self.b_size
		pool_dict: optional pre-generated pool, stored as self.pool_dict
		seeds: optional seeds, stored as self.seeds; mutually exclusive with pool_dict
		'''
		self.length, self.b_size = length, batch_size
		self.pool_dict, self.seeds = pool_dict, seeds
		self.n_out = 1  # default; subclasses may overwrite it after calling this constructor

		# pool_dict and seeds may not both be supplied
		both_given = pool_dict is not None and seeds is not None
		assert not both_given

	def evaluate(self, xs):
		'''Hook for subclasses.'''
		raise NotImplementedError()

	@staticmethod
	def generate_pool_dict(n_dims, data_size):
		'''Hook for subclasses.'''
		raise NotImplementedError()

	@staticmethod
	def get_metric():
		'''Hook for subclasses.'''
		raise NotImplementedError()

	@staticmethod
	def get_training_metric():
		'''Hook for subclasses.'''
		raise NotImplementedError()




def get_task_sampler(
	task_name, length, batch_size, pool_dict=None, data_size=0, **kwargs
):
	'''
	Build a factory that creates task instances for `task_name`.

	task_name: one of 'equality', 'equality_hard', 'string_equality', 'dyck2', 'index'
	length, batch_size: forwarded positionally to the task constructor
	pool_dict: optional pre-generated pool forwarded to the task constructor
	data_size: when > 0, a pool of that many samples is generated once through the
		task class's generate_pool_dict and reused by every instance the factory creates
	kwargs: forwarded to generate_pool_dict and to every constructor call

	Returns a callable taking keyword arguments (passed on together with kwargs)
	that returns a new task instance.

	Raises ValueError when both pool_dict and a positive data_size are given, and
	NotImplementedError when task_name is not registered.
	'''
	task_names_to_classes = {
		'equality': EqualityTask,
		'equality_hard': EqualityHard,
		'string_equality': StringEquality,
		'dyck2': Dyck2Task,
		'index': IndexTask,
	}

	if task_name not in task_names_to_classes:
		print("Unknown task")
		# Carry the offending name and the valid choices in the exception itself so
		# the failure is diagnosable even when stdout is not captured.
		raise NotImplementedError(
			"Unknown task: {!r}. Available tasks: {}".format(
				task_name, sorted(task_names_to_classes)
			)
		)

	task_cls = task_names_to_classes[task_name]
	if data_size > 0:
		if pool_dict is not None:
			raise ValueError("Either pool_dict or data_size should be None.")
		pool_dict = task_cls.generate_pool_dict(length, data_size, **kwargs)

	def make_task(**args):
		# The pool (if any) was built above and is shared across instances.
		return task_cls(length, batch_size, pool_dict, **args, **kwargs)

	return make_task




class EqualityTask(CLFTask):
	'''
	Binary classification of sequences `u 2 v`, where u and v are bit strings of
	length `length // 2` and 2 is the separator token. The label is +1 when
	u == v and -1 otherwise.
	'''

	def __init__(self, length, batch_size, pool_dict=None, seeds=None):
		super().__init__(length, batch_size, pool_dict=pool_dict, seeds=seeds)
		self.task = "Equality"
		self.n_words = 3

		if pool_dict is None:
			# No pool: sample_data generates a fresh batch on every call
			self.xs = None
			self.ys = None
		else:
			assert 'data' in pool_dict
			# Pick a random subset of at most batch_size pool rows, fixed for this instance
			picked = torch.randperm(len(pool_dict["data"]))[:batch_size]
			self.xs = pool_dict["data"][picked]
			self.ys = pool_dict["labels"][picked]

	def sample_data(self):
		'''
		Return (xs, ys): the batch drawn from the pool at construction time if
		there is one, otherwise a newly generated batch of self.b_size samples.
		'''
		if self.xs is not None:
			return self.xs, self.ys

		batch_xs, batch_ys = self.data_generator(self.b_size, self.length)
		return batch_xs, batch_ys

	@staticmethod
	def generate_pool_dict(length, data_size, **kwargs):
		'''
		Generate `data_size` samples and return them as {"data": xs, "labels": ys}.
		Prints the time the generation took. Extra keyword arguments are ignored.
		'''
		started = time()
		xs, ys = EqualityTask.data_generator(data_size, length)
		minutes, seconds = divmod(time() - started, 60)
		print('Time to generate pool dict: {:.2f} mins {:.2f} secs'.format(minutes, seconds))

		return {"data": xs, "labels": ys}

	@staticmethod
	def data_generator(data_size, length):
		'''
			Generate data for the task.
			Output: xs (torch.long tensor, data_size x (length + 1)),
			        ys (torch.float tensor of data_size labels in {-1, 1})
		'''
		assert length % 2 == 0
		half_len = length // 2

		xs = np.zeros((data_size, length+1), dtype=np.int32)
		ys = np.zeros((data_size), dtype=np.int32)

		for row in range(data_size):
			force_equal = np.random.rand() < 0.5
			left = np.random.randint(0, 2, half_len)
			# Half of the time the right side copies the left side; otherwise it is
			# drawn independently (and may still coincide with the left side).
			right = left if force_equal else np.random.randint(0, 2, half_len)

			xs[row] = np.concatenate((left, [2], right))
			ys[row] = int(np.all(left == right))

		# The loss works with labels in {-1, 1}, so remap {0, 1}
		ys = 2 * ys - 1
		assert np.all(np.isin(ys, [-1, 1]))

		return torch.tensor(xs, dtype=torch.long), torch.tensor(ys, dtype=torch.float)

	@staticmethod
	def get_metric():
		'''Evaluation metric: the module-level `accuracy` function.'''
		return accuracy

	@staticmethod
	def get_training_metric():
		'''Training loss: the module-level `cross_entropy` function.'''
		return cross_entropy





class EqualityHard(CLFTask):
	'''
	Harder variant of the equality task on sequences `u 2 v` (2 is the separator).

	Positives have v == u. Negatives are built from u by clearing one 1 bit and
	setting one 0 bit, so v keeps the same number of ones and differs from u in
	exactly two positions. Labels are +1 / -1.
	'''

	def __init__(self, length, batch_size, pool_dict=None, seeds=None):
		super().__init__(length, batch_size, pool_dict=pool_dict, seeds=seeds)
		# NOTE(review): same task label as EqualityTask -- confirm this is intended.
		self.task = "Equality"
		self.n_words = 3

		if pool_dict is not None:
			assert 'data' in pool_dict
			# Pick a random subset of at most batch_size pool rows, fixed for this instance
			indices = torch.randperm(len(pool_dict["data"]))[:batch_size]
			self.xs = pool_dict["data"][indices]
			self.ys = pool_dict["labels"][indices]
		else:
			# No pool: sample_data generates a fresh batch on every call
			self.xs = None
			self.ys = None

	def sample_data(self):
		'''
		Return (xs, ys): the batch drawn from the pool at construction time if
		there is one, otherwise a newly generated batch of self.b_size samples.
		'''
		if self.xs is None:
			xs, ys = self.data_generator(self.b_size, self.length)
			return xs, ys

		return self.xs, self.ys

	@staticmethod
	def generate_pool_dict(length, data_size, **kwargs):
		'''
		Generate `data_size` samples and return them as {"data": xs, "labels": ys}.
		Prints the time the generation took. Extra keyword arguments are ignored.
		'''
		start_time = time()

		xs, ys = EqualityHard.data_generator(data_size, length)

		end_time = time()
		print('Time to generate pool dict: {:.2f} mins {:.2f} secs'.format((end_time-start_time)//60, (end_time-start_time)%60))

		return {"data": xs, "labels": ys}

	@staticmethod
	def data_generator(data_size, length):
		'''
			Generate data for the task.
			Output: xs (torch.long tensor, data_size x (length + 1)),
			        ys (torch.float tensor of data_size labels in {-1, 1})
			Raises ValueError when length < 4.
		'''

		assert length % 2 == 0  # Ensure the length is even.
		if length < 4:
			# A negative example needs a half that contains both a 0 and a 1.
			# With fewer than two bits per half no such half exists, and the
			# rejection loop below would spin forever.
			raise ValueError(
				"EqualityHard needs length >= 4 to build negative examples, got {}".format(length)
			)

		half_len = length // 2
		xs = np.zeros((data_size, length+1), dtype=np.int32)
		ys = np.zeros((data_size), dtype=np.int32)

		for i in range(data_size):
			if np.random.rand() < 0.5:
				# Positive example: both halves identical
				half = np.random.randint(0, 2, half_len)
				xs[i] = np.concatenate((half, [2], half))
				ys[i] = 1
				continue

			# Negative example: redraw until the half holds at least one 0 and one 1
			while True:
				half = np.random.randint(0, 2, half_len)
				ones_idx = np.where(half == 1)[0]
				zeros_idx = np.where(half == 0)[0]
				if len(ones_idx) > 0 and len(zeros_idx) > 0:
					break

			# Clear one 1 bit and set one 0 bit in the copy
			half_copy = half.copy()
			half_copy[np.random.choice(ones_idx)] = 0
			half_copy[np.random.choice(zeros_idx)] = 1

			xs[i] = np.concatenate((half, [2], half_copy))
			ys[i] = 0

		# Loss requires ys to be in {-1, 1}
		ys = 2 * ys - 1

		# Check if the ys are in {-1, 1}
		assert np.all(np.isin(ys, [-1, 1]))

		# Convert numpy arrays to torch tensors
		xs = torch.tensor(xs, dtype=torch.long)
		ys = torch.tensor(ys, dtype=torch.float)

		return xs, ys

	@staticmethod
	def get_metric():
		'''Evaluation metric: the module-level `accuracy` function.'''
		return accuracy

	@staticmethod
	def get_training_metric():
		'''Training loss: the module-level `cross_entropy` function.'''
		return cross_entropy
	




class Dyck2Task(CLFTask):
	'''
	Binary classification of bracket strings over two bracket types.

	Token ids: 0 = '(', 1 = ')', 2 = '[', 3 = ']'. Positives are balanced strings
	with nesting depth at most 2. Negatives are such strings with some brackets
	replaced by brackets of the other type, re-drawn until the result is no
	longer balanced. Labels are +1 / -1.
	'''

	def __init__(self, length, batch_size, pool_dict=None, seeds=None):
		super().__init__(length, batch_size, pool_dict=pool_dict, seeds=seeds)
		self.task = "Dyck2_depth2"
		# NOTE(review): the generators below only emit ids 0-3; the fifth id is
		# presumably reserved (e.g. padding) -- confirm against the model code.
		self.n_words = 5  # For '(', ')', '[', ']'

		if pool_dict is not None:
			assert 'data' in pool_dict
			# Pick a random subset of at most batch_size pool rows, fixed for this instance
			indices = torch.randperm(len(pool_dict["data"]))[:batch_size]
			self.xs = pool_dict["data"][indices]
			self.ys = pool_dict["labels"][indices]

		else:
			# No pool: sample_data generates a fresh batch on every call
			self.xs = None
			self.ys = None


	def sample_data(self):
		'''
		Samples data from the task.

		Returns:
		- xs (torch.Tensor): The input data.
		- ys (torch.Tensor): The target labels.
		'''
		if self.xs is None:
			# Create data using data generator function
			xs, ys = self.data_generator(self.b_size, self.length)

			return xs, ys

		else:
			return self.xs, self.ys

	@staticmethod
	def generate_pool_dict(length, data_size, **kwargs):
		'''
		Generates a pool dictionary containing data and labels.

		Args:
		- length (int): The length of the data sequence.
		- data_size (int): The size of the data pool.

		Returns:
		- pool_dict (dict): A dictionary containing data and labels.
		'''
		start_time = time()

		# Create data using data generator function
		xs, ys = Dyck2Task.data_generator(data_size, length)

		end_time = time()
		print('Time to generate pool dict: {:.2f} mins {:.2f} secs'.format((end_time-start_time)//60, (end_time-start_time)%60))

		return {"data": xs, "labels": ys}


	@staticmethod
	def data_generator(data_size, length):
		'''
		Generate `data_size` sequences of `length` tokens.

		Returns:
		- xs (torch.Tensor): long tensor of shape (data_size, length).
		- ys (torch.Tensor): float tensor of data_size labels in {-1, 1}.
		'''
		xs = np.zeros((data_size, length), dtype=np.int32)
		ys = np.zeros((data_size), dtype=np.int32)

		for i in range(data_size):
			is_positive = np.random.choice(2)
			# Every sample starts out as a balanced string of depth at most 2
			xs[i] = Dyck2Task.generate_balanced_parentheses(length)
			if is_positive:
				ys[i] = 1  # Positive example
			else:
				# Corrupt the balanced string to obtain a negative example
				xs[i] = Dyck2Task.corrupt_parentheses(xs[i])
				ys[i] = 0  # Negative example

		# Convert numpy arrays to torch tensors
		xs = torch.tensor(xs, dtype=torch.long)
		ys = torch.tensor(ys, dtype=torch.float)

		ys = 2 * ys - 1  # Convert ys to {-1, 1}

		return xs, ys


	@staticmethod
	def generate_balanced_parentheses(length):
		""" Generate balanced parentheses strings with depth at most 2. """
		assert length % 2 == 0
		# `stack` holds the closing tokens still owed, innermost last
		seq, stack, current_depth = [], [], 0
		gen_len = length - 2
		for _ in range(gen_len):
			if current_depth < 2:
				if stack and np.random.choice(2):
					# Close the innermost open bracket
					seq.append(stack.pop())
					current_depth -= 1
				else:
					# Open a new bracket of a random type
					if np.random.choice(2):
						seq.append(0)
						stack.append(1)
					else:
						seq.append(2)
						stack.append(3)

					current_depth += 1
			else:
				# Depth limit reached: the only legal move is to close
				seq.append(stack.pop())
				current_depth -= 1

		# Close whatever is still open, innermost first
		seq.extend(reversed(stack))

		# If nothing was left open the string is two tokens short: append one pair
		if len(seq) == gen_len:
			if np.random.choice(2):
				seq.extend([0, 1])
			else:
				seq.extend([2, 3])

		assert len(seq) == length

		return np.array(seq)


	@staticmethod
	def _is_balanced(seq):
		""" Return True when seq (token ids 0-3) is a well-nested bracket string. """
		closer_of = {0: 1, 2: 3}
		owed = []
		for token in seq:
			token = int(token)
			if token in closer_of:
				owed.append(closer_of[token])
			elif not owed or owed.pop() != token:
				return False
		return not owed


	@staticmethod
	def corrupt_parentheses(seq):
		"""
		Corrupt a sequence of balanced parentheses by replacing brackets with
		brackets of the other type, guaranteeing that the result is unbalanced.

		Several independent replacements can cancel out (for example a matched
		'(' ... ')' pair turning into '[' ... ']'), which would leave a balanced
		string carrying a negative label. The corruption is therefore re-drawn
		from the original sequence until the result is unbalanced.

		Args:
		seq (numpy array): Original sequence of balanced parentheses. It is
			overwritten in place with the corrupted sequence.

		Returns:
		numpy array: The same `seq` object, now corrupted.
		"""
		pristine = seq.copy()

		while True:
			candidate = pristine.copy()

			# randint's upper bound is exclusive: at least 1 and at most
			# len(seq) // 10 - 1 positions are changed (always 1 for short sequences)
			num_corruptions = np.random.randint(1, max(2, len(candidate) // 10))
			corruption_indices = np.random.choice(len(candidate), size=num_corruptions, replace=False)

			for idx in corruption_indices:
				if candidate[idx] in [0, 1]:  # '(', ')'
					# Round bracket -> random square bracket
					candidate[idx] = np.random.choice([2, 3])
				else:  # '[', ']'
					# Square bracket -> random round bracket
					candidate[idx] = np.random.choice([0, 1])

			if not Dyck2Task._is_balanced(candidate):
				break

		# Keep the original contract: mutate the caller's array and return it
		seq[:] = candidate
		return seq


	@staticmethod
	def get_metric():
		""" Evaluation metric: the module-level `accuracy` function. """
		return accuracy


	@staticmethod
	def get_training_metric():
		""" Training loss: the module-level `cross_entropy` function. """
		return cross_entropy







class StringEquality(CLFTask):
	'''
	Equality task over a large vocabulary: sequences `u sep v eos`, where u and v
	hold `length // 2` tokens drawn from ids 3 .. n_words - 1. The label is +1
	when v is a copy of u and -1 when some positions of v were replaced.
	'''
	n_words = 3 + 1024
	word_dict = {'bos': 0, 'eos': 1, 'sep': 2}


	def __init__(self, length, batch_size, pool_dict=None, seeds=None):
		super().__init__(length, batch_size, pool_dict=pool_dict, seeds=seeds)
		self.task = "String_Equality"

		if pool_dict is None:
			# No pool: sample_data generates a fresh batch on every call
			self.xs = None
			self.ys = None
		else:
			assert 'data' in pool_dict
			# Pick a random subset of at most batch_size pool rows, fixed for this instance
			picked = torch.randperm(len(pool_dict["data"]))[:batch_size]
			self.xs = pool_dict["data"][picked]
			self.ys = pool_dict["labels"][picked]

	def sample_data(self, length=None):
		'''
		Return (xs, ys): the batch drawn from the pool at construction time if
		there is one, otherwise a newly generated batch of self.b_size samples.
		`length` overrides self.length for generated batches only.
		'''
		if self.xs is not None:
			return self.xs, self.ys

		effective_length = self.length if length is None else length
		batch_xs, batch_ys = self.data_generator(self.b_size, effective_length)
		return batch_xs, batch_ys

	@staticmethod
	def generate_pool_dict(length, data_size, **kwargs):
		'''
		Generate `data_size` samples and return them as {"data": xs, "labels": ys}.
		Prints the time the generation took. Extra keyword arguments are ignored.
		'''
		started = time()
		xs, ys = StringEquality.data_generator(data_size, length)
		minutes, seconds = divmod(time() - started, 60)
		print('Time to generate pool dict: {:.2f} mins {:.2f} secs'.format(minutes, seconds))

		return {"data": xs, "labels": ys}

	@staticmethod
	def data_generator(data_size, length):
		'''
			Generate data for the task.
			Output: xs (torch.long tensor, data_size x (length + 2)),
			        ys (torch.float tensor of data_size labels in {-1, 1})
		'''
		assert length % 2 == 0  # Ensure the length is even.

		vocab_end = StringEquality.n_words
		sep = [StringEquality.word_dict['sep']]
		eos = [StringEquality.word_dict['eos']]
		half_len = length // 2

		xs = np.zeros((data_size, length+2), dtype=np.int32)
		ys = np.zeros((data_size), dtype=np.int32)

		for row in range(data_size):
			is_match = np.random.rand() < 0.5
			first = np.random.randint(3, vocab_end, half_len)

			if is_match:
				second = first
			else:
				# randint's upper bound is exclusive: between 1 and
				# max(1, half_len // 4 - 1) positions are replaced
				n_changes = np.random.randint(1, max(2, half_len // 4))
				change_at = np.random.choice(half_len, size=n_changes, replace=False)
				second = first.copy()

				# Redraw the whole replacement vector until every replaced
				# position really differs from the original token
				replacements = np.random.randint(3, vocab_end, n_changes)
				while np.any(replacements == second[change_at]):
					replacements = np.random.randint(3, vocab_end, n_changes)

				second[change_at] = replacements

			xs[row] = np.concatenate((first, sep, second, eos))
			ys[row] = int(is_match)

		# The loss works with labels in {-1, 1}, so remap {0, 1}
		ys = 2 * ys - 1
		assert np.all(np.isin(ys, [-1, 1]))

		return torch.tensor(xs, dtype=torch.long), torch.tensor(ys, dtype=torch.float)

	@staticmethod
	def get_metric():
		'''Evaluation metric: the module-level `accuracy` function.'''
		return accuracy

	@staticmethod
	def get_training_metric():
		'''Training loss: the module-level `cross_entropy` function.'''
		return cross_entropy
	






class IndexTask(CLFTask):
	'''
	Multiclass lookup task: given `length` tokens, a separator and an index
	token, the label is the token stored at the indexed position.

	Token layout (n_out = 64): 0 .. n_out - 1 are content tokens, n_out is the
	separator, and n_out + 1 + i is the index token for position i.
	'''

	def __init__(self, length, batch_size, pool_dict=None, seeds=None):
		super().__init__(length, batch_size, pool_dict=pool_dict, seeds=seeds)
		self.task = "Index"
		self.n_out = 64
		# Content tokens + one separator + one index token per position
		self.n_words = 1 + self.n_out + length
		self.word_dict = {'sep': self.n_out}

		# This task never draws from pool_dict: batches are always generated
		self.xs = None
		self.ys = None

	def sample_data(self, length=None):
		'''
		Return (xs, ys) for a batch of self.b_size samples. `length` overrides
		self.length for the generated batch.
		NOTE(review): a `length` above self.length yields index tokens >= self.n_words
		-- confirm callers never do that.
		'''
		if self.xs is not None:
			return self.xs, self.ys

		effective_length = self.length if length is None else length
		batch_xs, batch_ys = self.data_generator(self.b_size, effective_length)
		return batch_xs, batch_ys

	def data_generator(self, data_size, length):
		'''
			Generate data for the task.
			Output: xs (torch.long tensor, data_size x (length + 2)),
			        ys (torch.long tensor of data_size class ids in 0 .. n_out - 1)
		'''
		separator = [self.word_dict['sep']]
		n_classes = self.n_out

		xs = np.zeros((data_size, length+2), dtype=np.int32)
		ys = np.zeros((data_size), dtype=np.int32)

		for row in range(data_size):
			content = np.random.randint(0, n_classes, length)
			position = np.random.randint(0, length)
			# Index tokens live right after the separator id
			query = [position + n_classes + 1]

			xs[row] = np.concatenate((content, separator, query))
			ys[row] = content[position]

		return torch.tensor(xs, dtype=torch.long), torch.tensor(ys, dtype=torch.long)

	@staticmethod
	def get_metric():
		'''Evaluation metric: the module-level `multiclass_accuracy` function.'''
		return multiclass_accuracy

	@staticmethod
	def get_training_metric():
		'''Training loss: the module-level `cross_entropy_loss` instance.'''
		return cross_entropy_loss
	

