import argparse
import json

# Task names accepted by the ``-task`` option of build_parser().
TASK_LIST = [
	'equality',
	'equality_hard',
	'string_equality',
	'dyck2',
	'index',
]


# Model families accepted by the ``-model`` option of build_parser().
MODELS = [
	'san',
	'lstm',
	'mysan',
	'retnet',
	'dss',
	'mamba',
]

def build_parser():
	"""Create the command-line parser for a run.

	All flags use a single leading dash (e.g. ``-model``). Boolean options
	come in ``-x`` / ``-no-x`` pairs that write to the same destination.

	Options are grouped (in help order) as: miscellaneous, model, task,
	training and wandb settings.

	Returns:
		argparse.ArgumentParser: the configured, not-yet-invoked parser.
	"""
	parser = argparse.ArgumentParser(description='Run')

	def toggle(name, on_help, off_help, default):
		# Register '-<name>' (store_true) and '-no-<name>' (store_false) on the
		# same dest, then pin the default explicitly via set_defaults.
		parser.add_argument('-' + name, dest=name, action='store_true', help=on_help)
		parser.add_argument('-no-' + name, dest=name, action='store_false', help=off_help)
		parser.set_defaults(**{name: default})

	def options(*specs):
		# Each spec is (flag, type, default, help) for a plain valued option;
		# options are registered in the order given.
		for flag, kind, fallback, text in specs:
			parser.add_argument(flag, type=kind, default=fallback, help=text)

	# Miscellaneous
	toggle('wandb', 'Store wandb', 'Do not store wandb', False)
	options(
		('-gpu', int, 0, 'Specify the gpu to use'),
		('-seed', int, 1729, 'Default seed to set'),
		('-out_dir', str, './models/', 'outputs directory'),
	)
	toggle('delete', 'delete model after run', 'Do not delete model after run', True)

	# Model
	parser.add_argument('-model', type=str, default='san', choices=MODELS, help='Model Family')
	options(
		('-n_positions', int, 150, 'Maximum context length'),
		('-n_embd', int, 128, 'embedding dimension'),
		('-n_layer', int, 3, 'number of layers'),
		('-n_head', int, 8, 'number of heads'),
		# NOTE(review): help mentions Hyena, which is not among MODELS — confirm use.
		('-order', int, 3, 'Order: For Hyena'),
	)
	toggle('pos', 'pos encodings in the model', 'No pos encodings in the model', False)

	# Task
	parser.add_argument('-task', type=str, default='equality', choices=TASK_LIST, help='Task')
	options(
		('-length', int, 10, 'Length of inputs'),
		# argparse runs `type` on string defaults too, so an omitted
		# -task_kwargs parses to {} rather than the string '{}'.
		('-task_kwargs', json.loads, '{}', 'Task arguments'),
		('-data_size', int, 0, 'size of training data'),
		('-val_every_steps', int, 2500, 'Validation every how many steps'),
	)

	# Training
	options(
		('-batch_size', int, 64, 'Batch size'),
		('-learning_rate', float, 0.0001, 'Learning rate'),
		('-train_steps', int, 1000, 'number of train steps'),
	)
	toggle('analyze', 'analyze', 'Do not  analyze', False)
	options(
		('-save_every_steps', int, 1000, 'how often to checkpoint'),
		('-keep_every_steps', int, 100000, 'permanent checkpoints'),
	)

	# Wandb
	options(
		('-project', str, 'Eq-arch', 'wandb project name'),
		('-entity', str, 'entity_name', 'wandb entity name'),
		('-notes', str, '', 'wandb notes'),
		('-name', str, 'eq_test', 'run name'),
		('-log_every_steps', int, 100, 'wandb log every how many steps'),
	)

	return parser