import os
from random import randint
import uuid
import numpy as np
from tqdm import tqdm
import torch
import pdb
from src.data_gen.tasks import get_task_sampler
from src.eval import validate_clf

import wandb


def train_step(model, xs, ys, optimizer, loss_func, max_grad_norm=None):
	"""Perform a single gradient update of ``model`` on one batch.

	Args:
		model: Callable module; invoked as ``model(xs)``.
		xs: Batch of inputs passed to the model.
		ys: Targets passed to ``loss_func`` together with the model output.
		optimizer: Optimizer whose gradients are reset and which is stepped.
		loss_func: Callable ``loss_func(output, ys)`` returning a scalar
			tensor that supports ``backward()``.
		max_grad_norm: If not ``None``, gradients of ``model.parameters()``
			are clipped to this total norm before the optimizer step.

	Returns:
		A ``(loss, output)`` tuple: the loss as a Python number and the
		model output detached from the autograd graph.
	"""
	# Start from clean gradients so this batch does not accumulate onto the last.
	optimizer.zero_grad()

	predictions = model(xs)
	batch_loss = loss_func(predictions, ys)
	batch_loss.backward()

	clipping_requested = max_grad_norm is not None
	if clipping_requested:
		torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

	optimizer.step()

	loss_value = batch_loss.detach().item()
	return loss_value, predictions.detach()




def sample_seeds(total_seeds, count):
	"""Draw ``count`` distinct random integers from ``[0, total_seeds - 1]``.

	Args:
		total_seeds: Size of the seed pool; candidates are ``0`` through
			``total_seeds - 1`` inclusive.
		count: Number of distinct seeds to draw. Values ``<= 0`` yield an
			empty set.

	Returns:
		A ``set`` of ``count`` distinct integers.

	Raises:
		ValueError: If ``count`` exceeds ``total_seeds``. Only
			``total_seeds`` distinct values exist, so the sampling loop
			below could otherwise never terminate.
	"""
	if count > total_seeds:
		raise ValueError(
			f"cannot sample {count} distinct seeds from a pool of {total_seeds}"
		)

	seeds = set()
	# Rejection sampling: duplicates are absorbed by the set, so keep drawing
	# until enough distinct values have been collected.
	while len(seeds) < count:
		seeds.add(randint(0, total_seeds - 1))
	return seeds



def train(model, args):
	"""Train ``model`` on the task named by ``args.task``, resuming if possible.

	If ``<args.out_dir>/state.pt`` exists, the model and optimizer state are
	restored from it and training continues from the step after the one
	recorded in the checkpoint.

	Each iteration draws a task from the training sampler, samples a batch,
	runs one ``train_step`` and then, depending on the step index:

	* every ``args.val_every_steps`` steps, evaluates with ``validate_clf``
	  and stops early once the returned accuracy exceeds ``0.9999``;
	* every ``args.log_every_steps`` steps (when ``args.wandb`` is truthy),
	  logs training metrics to wandb;
	* every ``args.save_every_steps`` steps, overwrites ``state.pt``;
	* every ``args.keep_every_steps`` steps (when positive and ``i > 0``),
	  writes a separate ``model_<i>.pt`` snapshot.

	Args:
		model: Module to train. It is fed tensors moved to
			``cuda:<args.gpu>``, so it is presumably already on that device
			(NOTE(review): confirm against the caller).
		args: Configuration object. Attributes read here: ``task``, ``gpu``,
			``learning_rate``, ``out_dir``, ``length``, ``batch_size``,
			``data_size``, ``task_kwargs``, ``train_steps``,
			``val_every_steps``, ``log_every_steps``, ``save_every_steps``,
			``keep_every_steps`` and ``wandb``.

	Returns:
		None. ``model`` is updated in place.
	"""
	device = torch.device("cuda:{}".format(args.gpu))

	optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

	starting_step = 0
	state_path = os.path.join(args.out_dir, "state.pt")
	if os.path.exists(state_path):
		# map_location keeps the checkpoint tensors on the device selected by
		# args.gpu rather than on whichever device they were saved from.
		state = torch.load(state_path, map_location=device)
		model.load_state_dict(state["model_state_dict"])
		optimizer.load_state_dict(state["optimizer_state_dict"])
		# The checkpoint is written after step `train_step` has been applied,
		# so resume at the following step instead of repeating it.
		starting_step = state["train_step"] + 1

	length = args.length
	bsize = args.batch_size

	task_sampler = get_task_sampler(
		args.task,
		length,
		bsize,
		data_size=args.data_size,
		**args.task_kwargs,
	)

	# Same task configuration as the training sampler, but with data_size=0.
	val_sampler = get_task_sampler(
		args.task,
		length,
		bsize,
		data_size=0,
		**args.task_kwargs,
	)

	if length >= 50 and args.task in ['index', 'string_equality']:
		print('Using Curriculum for length')
		cur_flag = True
		half_steps = args.train_steps // 2
	else:
		cur_flag = False

	print(f"Starting training for {args.task} task")

	pbar = tqdm(range(starting_step, args.train_steps))

	for i in pbar:

		task = task_sampler()

		if cur_flag:
			if i < half_steps:
				# First half of training: sample length is either 14 or
				# 2 * (length // 2), chosen uniformly at random.
				cur_length = np.random.choice([7, length//2])
				cur_length = cur_length * 2
			else:
				cur_length = length

			xs, ys = task.sample_data(length = cur_length)

		else:
			xs, ys = task.sample_data()

		loss_func = task.get_training_metric()

		# Move the targets once; they are needed for both the loss and the metric.
		ys_device = ys.to(device)
		loss, output = train_step(model, xs.to(device), ys_device, optimizer, loss_func)

		task_metric = task.get_metric()
		train_acc = task_metric(output, ys_device).mean().item()
		# Metric obtained by a constant all-ones prediction, logged as a baseline.
		null_pred = torch.zeros_like(ys) + 1
		null_acc = task_metric(null_pred, ys).mean().item()

		if i % args.val_every_steps == 0:
			val_task = val_sampler()
			val_acc = validate_clf(model, val_task, args, val_examples=500)
			print(f"Validation Accuracy after {i} steps: {val_acc}")

			if args.wandb:
				wandb.log(
					{
						"eval/val_acc": val_acc,
					},
					step=i,
				)

			# Early stop. Note this exits before the logging and
			# checkpointing below run for the current step.
			if val_acc > 0.9999:
				break

		if args.wandb:
			if i % args.log_every_steps == 0:
				wandb.log(
					{
						"train_acc": train_acc,
						"null_acc": null_acc,
						"overall_loss": loss,

					},
					step=i,
				)

		pbar.set_description(f"loss {loss}")
		if i % args.save_every_steps == 0:
			# Rolling checkpoint used for resuming; overwritten each time.
			training_state = {
				"model_state_dict": model.state_dict(),
				"optimizer_state_dict": optimizer.state_dict(),
				"train_step": i,
			}
			torch.save(training_state, state_path)

		if (
			args.keep_every_steps > 0
			and i % args.keep_every_steps == 0
			and i > 0
		):
			# Permanent per-step snapshot of the model weights only.
			torch.save(model.state_dict(), os.path.join(args.out_dir, f"model_{i}.pt"))
