import os

import torch
from torch.utils.tensorboard import SummaryWriter
import inspect

from ..utils.reproducibility_utils import seed_everything
from ..utils.data_loading import load_data
from ..utils.transformer_utils import configure_optimizers, WarmupCosineSchedule, loss_tracker
from ..utils.masking import bias_mask, mean_mask, performance_degradation_mask, tst_noise_masking, MaskedMSELoss
from ..utils import architectures

def pretrain(args):
    """Pretrain a backbone + head by reconstructing masked inputs.

    The function seeds all RNGs, creates the experiment folders, builds the
    data loaders, the model (backbone with the head attached as
    ``model.output_net``), the optimizer and the LR scheduler, and then runs
    the train/validation loop. Every batch is split into one chunk per
    configured mask; each chunk is corrupted with its mask, reconstructed by
    the model, and scored with ``MaskedMSELoss`` on the masked positions. The
    per-mask losses are combined with ``multi_task_weights``.

    Args:
        args: Nested configuration mapping. Keys read here:
            ``experiment``: ``seed``, ``folder_path``, ``name``, ``device``,
                ``tensorboard_logging``, ``model_saving_mod`` and, optionally,
                ``checkpoint_path`` (checkpoint to resume from; when absent or
                empty, training starts from scratch).
            ``data``: forwarded to ``load_data`` (``train_path``,
                ``train_ratio``, ``batchsize``, ``columns_to_standardize``,
                ``columns_to_drop``, ``sequence_length``, ``threshold``,
                ``threshold_value``, ``threshold_column``, ``target``,
                ``scaling``).
            ``backbone`` / ``head``: ``model_name`` (looked up in
                ``architectures``) and ``model_params``.
            ``optimization``: ``epochs``, ``peak_lr``, ``weight_decay``,
                ``warmup_div``, ``initial_lr``, ``min_lr``, ``grad_clipping``.
            ``masking``: ``bias_bounds``, ``noise_level``, ``masks``,
                ``multi_task_weights``, ``masking_ratio``,
                ``mean_mask_length``.

    Returns:
        None. Checkpoints are written to
        ``<cwd>/<folder_path>/<name>/model_checkpoints`` and, if enabled,
        TensorBoard logs to ``<cwd>/<folder_path>/<name>/training_logs``.

    Raises:
        ValueError: If none of the names in ``args["masking"]["masks"]``
            matches an available mask ("bias", "mean", "noise").
    """
    experiment_cfg = args["experiment"]
    data_cfg = args["data"]
    optim_cfg = args["optimization"]
    masking_cfg = args["masking"]

    # seeding
    seed = experiment_cfg["seed"]
    torch.set_printoptions(sci_mode=False, precision=2)
    seed_everything(seed)

    # experiment folders; exist_ok so that re-running or resuming an
    # experiment does not fail with FileExistsError
    path_to_experiment = os.path.join(os.getcwd(), experiment_cfg["folder_path"], experiment_cfg["name"])
    checkpoint_dir = os.path.join(path_to_experiment, "model_checkpoints")
    log_dir = os.path.join(path_to_experiment, "training_logs")
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # set device
    device = experiment_cfg["device"]

    train_dataloader, validation_dataloader, _, _ = load_data(
        train_path=data_cfg["train_path"],
        train_ratio=data_cfg["train_ratio"],
        batchsize=data_cfg["batchsize"],
        columns_to_standardize=data_cfg["columns_to_standardize"],
        columns_to_drop=data_cfg["columns_to_drop"],
        sequence_length=data_cfg["sequence_length"],
        threshold=data_cfg["threshold"],
        threshold_value=data_cfg["threshold_value"],
        threshold_column=data_cfg["threshold_column"],
        target=data_cfg["target"],
        scaling=data_cfg["scaling"],
        seed=seed,
    )

    # build on a copy so the caller's config is not mutated by adding the seed
    model_params = {**args["backbone"]["model_params"], "seed": seed}
    model = architectures.__dict__[args["backbone"]["model_name"]](**model_params).to(device)
    model.output_net = architectures.__dict__[args["head"]["model_name"]](**args["head"]["model_params"]).to(device)

    num_epochs = optim_cfg["epochs"]
    total_training_steps = len(train_dataloader) * num_epochs

    optimizer = configure_optimizers(
        model,
        learning_rate=optim_cfg["peak_lr"],
        weight_decay=optim_cfg["weight_decay"],
        pe_available=True,
    )
    scheduler = WarmupCosineSchedule(
        optimizer,
        total_training_steps,
        total_training_steps // optim_cfg["warmup_div"],
        optim_cfg["initial_lr"],
        optim_cfg["min_lr"],
        optim_cfg["peak_lr"],
    )

    # optional checkpoint loading; must happen after the optimizer exists
    # because its state is restored as well
    last_epoch = 0
    checkpoint_path = experiment_cfg.get("checkpoint_path")
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        last_epoch = checkpoint["epoch"]
        # NOTE(review): the scheduler state is not part of the checkpoint, so
        # the LR schedule restarts from its first step when resuming.

    available_masks = {
        "bias": bias_mask(device, lower_bound=masking_cfg["bias_bounds"][0], upper_bound=masking_cfg["bias_bounds"][1]),
        "mean": mean_mask(device),
        "noise": performance_degradation_mask(device, noise_level=masking_cfg["noise_level"]),
    }
    # keep only the requested masks, in the order given by the config;
    # multi_task_weights is indexed by position in this filtered order
    masks = {k: available_masks[k] for k in masking_cfg["masks"] if k in available_masks}
    if not masks:
        raise ValueError(
            f"No valid mask selected: got {list(masking_cfg['masks'])}, "
            f"available: {list(available_masks)}"
        )

    train_loss_tracker = loss_tracker(masks)
    val_loss_tracker = loss_tracker(masks)

    multi_task_weights = masking_cfg["multi_task_weights"]

    criterion = MaskedMSELoss()

    def infere(model, masks, X, tracker=None):
        """Return the weighted multi-task reconstruction loss for one batch.

        The batch is split along dim 0 into ``len(masks)`` chunks and chunk
        ``i`` is corrupted with the i-th mask. If ``tracker`` is given, each
        per-mask loss and the combined loss are recorded on it.
        """
        loss = 0
        X_split = torch.tensor_split(X, len(masks), dim=0)

        for idx, (mask_name, mask) in enumerate(masks.items()):
            X_m = X_split[idx].to(device)

            # define what to mask
            binary_mask = tst_noise_masking(
                X_m,
                device=device,
                r=masking_cfg["masking_ratio"],
                lm=masking_cfg["mean_mask_length"],
            )
            # apply mask
            X_masked = mask.apply_mask(X_m, binary_mask)
            pred = model(X_masked)[0]
            pred = pred.reshape((X_m.size(0), data_cfg["sequence_length"], X.size(2)))
            mask_loss = criterion(pred, X_m, binary_mask.bool())

            if tracker is not None:
                tracker.track(mask_name, mask_loss.detach().item())

            loss = loss + multi_task_weights[idx] * mask_loss

        if tracker is not None:
            tracker.track("combined", loss.detach().item())
        return loss

    # setup tensorboard writer right before the loop and close it in `finally`
    # so that it is released even if training raises
    writer = SummaryWriter(log_dir) if experiment_cfg["tensorboard_logging"] else None
    train_loss = None

    try:
        # train loop; epoch numbering continues after a resumed checkpoint
        for epoch in range(last_epoch + 1, last_epoch + num_epochs + 1):
            print("Epoch: ", epoch)
            print("Training", end='   | ')

            model.train()
            for X in train_dataloader:
                optimizer.zero_grad()
                scheduler.step()

                train_loss = infere(model, masks, X, train_loss_tracker)
                train_loss.backward()

                # gradients are only clipped once the warmup phase is over
                if scheduler._step > scheduler.warmup_steps:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=optim_cfg["grad_clipping"])

                optimizer.step()
                train_loss_tracker.step()

            train_loss_tracker.update()

            # validation
            print("Validation", end=' | ')
            model.eval()
            with torch.no_grad():
                for X in validation_dataloader:
                    infere(model, masks, X, val_loss_tracker)
                    val_loss_tracker.step()

            val_loss_tracker.update()

            if writer is not None:
                train_losses = train_loss_tracker.get_losses()
                val_losses = val_loss_tracker.get_losses()
                for key in train_losses:
                    writer.add_scalars(
                        f'Loss_{key}',
                        {"train": train_losses[key][-1], "val": val_losses[key][-1]},
                        epoch,
                    )
                writer.flush()

            if epoch % experiment_cfg["model_saving_mod"] == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # store a detached CPU copy of the last batch loss rather
                    # than the live tensor; None if no batch was processed
                    'loss': None if train_loss is None else train_loss.detach().cpu(),
                }, os.path.join(checkpoint_dir, f"pretrain_{epoch}_epochs"))
    finally:
        if writer is not None:
            writer.close()
