import os
import yaml

import torch
from torch.utils.tensorboard import SummaryWriter
import inspect

from ..utils.reproducibility_utils import seed_everything
from ..utils.data_loading import load_data
from ..utils.transformer_utils import configure_optimizers, WarmupCosineSchedule, loss_tracker
from ..utils import architectures

# Fine-tuning of pretrained models on a downstream task.
def train(args):
	"""Fine-tune a pretrained backbone with a new prediction head.

	All settings come from the nested ``args`` dict (sections ``experiment``,
	``data``, ``pretrained_backbone``, ``head`` and ``optimization``).

	Steps:
	  * seeds everything and creates ``model_checkpoints`` and ``training_logs``
	    under ``<cwd>/<experiment.folder_path>/<experiment.name>``;
	  * rebuilds the backbone (and its pretraining head) from the pretraining
	    YAML config, optionally restores weights from
	    ``pretrained_backbone.pretrained_model_path``, then replaces
	    ``model.output_net`` with the head described in ``args["head"]``;
	  * if ``optimization.freeze`` is truthy, only ``output_net`` is trained;
	  * minimises the MSE between ``model(x)[:, -1]`` and the target;
	  * logs per-epoch train/validation losses to TensorBoard when
	    ``experiment.tensorboard_logging`` is set;
	  * saves a checkpoint every ``experiment.model_saving_mod`` epochs.

	Raises:
		FileExistsError: if the checkpoint/log sub-folders already exist
			(``os.makedirs`` is called without ``exist_ok``).
	"""
	seed = args["experiment"]["seed"]
	torch.set_printoptions(sci_mode=False, precision=2)
	seed_everything(seed)

	path_to_repo = os.getcwd()
	path_to_experiment = os.path.join(path_to_repo, args["experiment"]["folder_path"], args["experiment"]["name"])
	# No exist_ok: these raise if the sub-folders already exist, which keeps an
	# earlier run's checkpoints/logs from being overwritten.
	os.makedirs(os.path.join(path_to_experiment, "model_checkpoints"))
	os.makedirs(os.path.join(path_to_experiment, "training_logs"))

	writer = None
	if args["experiment"]["tensorboard_logging"]:
		writer = SummaryWriter(os.path.join(path_to_experiment, "training_logs"))

	device = args["experiment"]["device"]
	target = args["data"]["target"]

	# Load and scale the CSV data. Optional thresholding (data.threshold,
	# threshold_value, threshold_column) is applied inside load_data; an earlier
	# note suggested 0.83 because readings below ~1 km/h have sensitivity issues.
	train_dataloader, validation_dataloader, _, _ = load_data(train_path = args["data"]["train_path"],
														train_ratio = args["data"]["train_ratio"],
														batchsize = args["data"]["batchsize"],
														columns_to_standardize = args["data"]["columns_to_standardize"],
														columns_to_drop = args["data"]["columns_to_drop"],
														sequence_length = args["data"]["sequence_length"],
														threshold = args["data"]["threshold"],
														threshold_value = args["data"]["threshold_value"],
														threshold_column = args["data"]["threshold_column"],
														target = target,
														scaling = args["data"]["scaling"],
														seed = seed)

	path_to_pretrain_conf = os.path.join(path_to_repo, args["pretrained_backbone"]["pretrain_config_path"])
	# FullLoader is kept for compatibility; the config is assumed to be a trusted local file.
	with open(path_to_pretrain_conf, "r") as conf_file:
		pt_config = yaml.load(conf_file, Loader=yaml.FullLoader)

	model_params = pt_config["backbone"]["model_params"]
	model_params["seed"] = seed
	model = architectures.__dict__[pt_config["backbone"]["model_name"]](**model_params).to(device)
	# Attach the pretraining head first so the module layout matches the
	# pretraining model when the checkpoint is loaded (load_state_dict is strict by default).
	model.output_net = architectures.__dict__[pt_config["head"]["model_name"]](**pt_config["head"]["model_params"]).to(device)

	pretrained_model_path = args["pretrained_backbone"]["pretrained_model_path"]
	if pretrained_model_path is not None:
		checkpoint = torch.load(pretrained_model_path, map_location=torch.device(device))
		model.load_state_dict(checkpoint['model_state_dict'])

	# Replace the pretraining head with a freshly constructed downstream head.
	model.output_net = architectures.__dict__[args["head"]["model_name"]](**args["head"]["model_params"]).to(device)

	num_epochs = args["optimization"]["epochs"]
	total_training_steps = len(train_dataloader) * num_epochs
	# The optimizer is built before freezing (as before), so frozen parameters may be
	# registered with it; they never receive a .grad, so torch optimizers skip them.
	optimizer = configure_optimizers(model, learning_rate=args["optimization"]["peak_lr"], weight_decay=args["optimization"]["weight_decay"], betas = (args["optimization"]["beta_1"], args["optimization"]["beta_2"]))
	scheduler = WarmupCosineSchedule(optimizer, total_training_steps, total_training_steps // args["optimization"]["warmup_div"], args["optimization"]["initial_lr"], args["optimization"]["min_lr"], args["optimization"]["peak_lr"])

	train_loss_tracker = loss_tracker({target: 0}, ft = True)
	val_loss_tracker = loss_tracker({target: 0}, ft = True)

	criterion = torch.nn.MSELoss()

	freeze_encoder = args["optimization"]["freeze"]
	_set_trainable(model, freeze_encoder)

	# Counted after freezing so the number reflects what is actually trained.
	pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
	print("trainable model params", pytorch_total_params)

	# With a frozen encoder the backbone stays in eval mode (e.g. dropout off);
	# only output_net is switched between train/eval inside the epoch loop.
	if freeze_encoder:
		model.eval()
	else:
		model.train()

	try:
		for epoch in range(1, num_epochs + 1):
			print("Epoch: ", epoch)
			print("Training", end='   | ')

			if freeze_encoder:
				model.output_net.train()
			else:
				model.train()

			for x, y in train_dataloader:
				x, y = x.to(device), y.to(device)

				optimizer.zero_grad()
				# Stepped before optimizer.step(); presumably sets this update's LR
				# (see WarmupCosineSchedule in transformer_utils — confirm).
				scheduler.step()

				pred = model(x)[:, -1]
				train_loss = criterion(pred, y)
				train_loss_tracker.track(target, train_loss.item())

				train_loss.backward()

				# Gradient clipping is only applied once warm-up is over.
				if scheduler._step > scheduler.warmup_steps:
					torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args["optimization"]["grad_clipping"])

				optimizer.step()
				train_loss_tracker.step()

			train_loss_tracker.update()

			with torch.no_grad():
				print("Validation", end=' | ')
				if freeze_encoder:
					model.output_net.eval()
				else:
					model.eval()

				for x, y in validation_dataloader:
					x, y = x.to(device), y.to(device)
					val_loss = criterion(model(x)[:, -1], y)
					# Track a Python float, consistent with the training loss.
					val_loss_tracker.track(target, val_loss.item())
					val_loss_tracker.step()

				val_loss_tracker.update()

			if writer is not None:
				for key in train_loss_tracker.get_losses().keys():
					writer.add_scalars(f'Loss_{key}', {"train": train_loss_tracker.get_losses()[key][-1],
										"val": val_loss_tracker.get_losses()[key][-1]},
										epoch)
				writer.flush()

			if epoch % args["experiment"]["model_saving_mod"] == 0:
				# File name prefix "pretrain_" kept unchanged for compatibility with
				# existing tooling, even though this is a fine-tuning run.
				torch.save({
						'epoch': epoch,
						'model_state_dict': model.state_dict(),
						'optimizer_state_dict': optimizer.state_dict(),
						# Loss of the last training batch, stored without autograd history.
						'loss': train_loss.detach()
						}, os.path.join(path_to_experiment, "model_checkpoints", f"pretrain_{epoch}_epochs"))
	finally:
		# Close the writer even if training raises.
		if writer is not None:
			writer.close()


def _set_trainable(model, freeze_encoder):
	"""Set ``requires_grad`` for every parameter of ``model``.

	If ``freeze_encoder`` is truthy, everything except ``model.output_net`` is
	frozen. Otherwise all parameters are trainable. Unlike iterating over
	``named_children()``, this also covers parameters registered directly on
	``model`` rather than inside a child module.
	"""
	model.requires_grad_(not freeze_encoder)
	if freeze_encoder:
		model.output_net.requires_grad_(True)

