# Example run command: python -m src.main -model mamba -task index -train_steps 1000 -gpu 0 -n_layer 2 -n_embd 1024


import os
from random import randint
import uuid
import numpy as np
from tqdm import tqdm
import torch
import yaml


from src.args import build_parser
from src.remove_pt import delete_pt_files
from src.clf_models import build_clf
from src.data_gen.tasks import get_task_sampler
from src.trainer import train
from src.eval import validate_clf


import wandb


# Let cuDNN benchmark candidate kernels and cache the fastest one per input
# configuration. Per the PyTorch docs this speeds up runs with fixed input
# shapes, at the cost of run-to-run determinism.
torch.backends.cudnn.benchmark = True



def main(args):
	"""Train a classifier on the configured task, then report validation accuracy.

	Steps:
	  1. Optionally start a wandb run (when ``args.wandb`` is truthy).
	  2. Build a task sampler (``data_size=0``) and draw one task from it.
	     That task provides ``n_out``/``n_words`` for the model and is also
	     the validation task.
	  3. Build the classifier, move it to ``cuda:<args.gpu>``, and train it.
	  4. Validate on 5000 examples, print the accuracy, and log it to wandb
	     when enabled.
	  5. Optionally delete the ``.pt`` files under ``args.out_dir``.

	Args:
		args: parsed CLI namespace. Mutated in place: ``n_out`` and
			``n_words`` are set from the sampled task.
	"""
	if args.wandb:
		wandb.init(
			dir=args.out_dir,
			project=args.project,
			group=str(args.task),
			entity=args.entity,
			config=args.__dict__,
			notes=args.notes,
			name=args.name,
			resume=True,
		)

	gpu_device = torch.device(f"cuda:{args.gpu}")

	sampler = get_task_sampler(
		args.task,
		args.length,
		args.batch_size,
		data_size=0,
		**args.task_kwargs,
	)
	held_out_task = sampler()

	# Copy the task dimensions onto args; presumably build_clf reads them
	# to size the model -- NOTE(review): confirm against build_clf.
	args.n_out = held_out_task.n_out
	args.n_words = held_out_task.n_words

	clf = build_clf(args)
	clf.to(gpu_device)
	clf.device = gpu_device
	clf.train()

	train(clf, args)

	# Score the trained model on the task sampled above.
	final_acc = validate_clf(clf, held_out_task, args, val_examples=5000)
	print(f"Validation Accuracy: {final_acc}")

	if args.wandb:
		wandb.log({"eval/final_val_acc": final_acc})

	if args.delete:
		print('Deleting model (pt) files...')
		delete_pt_files(args.out_dir)

	




if __name__ == "__main__":
	parser = build_parser()
	args = parser.parse_args()

	print(f"Running with: {args}")

	run_id = str(uuid.uuid4())[:20]
	print(f"Run ID: {run_id}")

	# Make the run name distinct per model and per invocation.
	args.name += '_' + args.model
	args.name += '_' + run_id[:8]

	# Runs are laid out as <out_dir>/<task>/<name>. The model is already
	# encoded in args.name, so it is not appended to the directory again.
	# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
	args.out_dir = os.path.join(args.out_dir, args.task, args.name)
	os.makedirs(args.out_dir, exist_ok=True)

	# Persist the full resolved configuration next to the run's outputs.
	with open(os.path.join(args.out_dir, "config.yaml"), "w", encoding="utf-8") as yaml_file:
		yaml.dump(args.__dict__, yaml_file, default_flow_style=False)

	main(args)
