import os
import time
import torch

import sys; sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from trainer_utils import *

class TrainerRegressor(Trainer):
	"""Trainer for a SMI-TED model with a prediction head (``model.net``).

	Despite the name, it handles both regression and binary classification,
	selected by ``args.task``. For ``'classification'`` the head outputs are
	treated as logits: targets are cast to float for the loss, and validation
	predictions are passed through a sigmoid. Any other task value is treated
	as regression.

	Each epoch also reports wall-clock time. When ``args.device`` is a usable
	CUDA device, it also reports the peak GPU memory allocated during the epoch.
	"""

	def __init__(self, args, cfg, model, optimizer, loss_fn):
		super().__init__(args, cfg, model, optimizer, loss_fn)
		self.args = args
		self.cfg = cfg

		self.batch_size = cfg.exp.batch_size

	def _profiling_cuda_device(self):
		"""Return ``args.device`` as a CUDA ``torch.device``, or ``None`` when not on CUDA.

		``torch.cuda`` memory and synchronization calls fail on machines without
		CUDA. All profiling is therefore skipped unless the configured device is
		a usable GPU.
		"""
		device = torch.device(self.args.device)
		if device.type != 'cuda' or not torch.cuda.is_available():
			return None
		return device

	def _begin_epoch_profile(self):
		"""Reset peak-memory counters, drain pending GPU work and return a start timestamp."""
		device = self._profiling_cuda_device()
		if device is not None:
			torch.cuda.reset_peak_memory_stats(device)
			# Kernels run asynchronously; wait so earlier work is not timed.
			torch.cuda.synchronize(device)
		return time.perf_counter()

	def _end_epoch_profile(self, start_time):
		"""Return ``(elapsed_seconds, peak_memory_mb)`` measured since ``start_time``.

		``peak_memory_mb`` is NaN when not running on CUDA.
		"""
		device = self._profiling_cuda_device()
		if device is not None:
			# Wait for queued kernels so the elapsed time covers all the work.
			torch.cuda.synchronize(device)
		elapsed = time.perf_counter() - start_time
		if device is None:
			return elapsed, float('nan')
		return elapsed, torch.cuda.max_memory_allocated(device) / (1024 ** 2)

	def _forward_head(self, model, smiles):
		"""Embed a batch of SMILES with ``model`` and return its head outputs flattened to 1-D.

		The output is flattened rather than squeezed. ``squeeze()`` would turn a
		batch of size one into a 0-dim tensor, so predictions and targets would
		no longer share a shape.
		"""
		embeddings = model.extract_embeddings(smiles).to(self.device)
		return model.net(embeddings).reshape(-1)

	def _train_one_epoch(self):
		"""Run one optimization pass over ``self.train_loader``.

		Returns:
			float: mean training loss over the batches of the epoch.

		Raises:
			ValueError: if ``self.train_loader`` yields no batches.
		"""
		is_classification = self.args.task == 'classification'
		running_loss = 0.0
		n_batches = 0

		start_time = self._begin_epoch_profile()
		for smiles, targets in self.train_loader:
			# Flatten so targets match the flattened head outputs.
			targets = targets.detach().to(self.device).reshape(-1)
			if is_classification:
				targets = targets.float()  # BCEWithLogitsLoss expects float

			# Gradients accumulate across backward() calls unless cleared.
			self.optimizer.zero_grad()
			outputs = self._forward_head(self.model, smiles)
			loss = self.loss_fn(outputs, targets)
			loss.backward()
			self.optimizer.step()

			running_loss += loss.item()
			n_batches += 1

		epoch_time, peak_memory = self._end_epoch_profile(start_time)
		print(f"Train epoch time: {epoch_time:.2f} s, Peak GPU Memory: {peak_memory:.2f} MB")

		if n_batches == 0:
			raise ValueError("train_loader yielded no batches")
		return running_loss / n_batches

	def _validate_one_epoch(self, data_loader, model=None, prefix=None):
		"""Evaluate ``model`` on ``data_loader`` without tracking gradients.

		Args:
			data_loader: iterable of ``(smiles, targets)`` batches.
			model: model to evaluate; defaults to ``self.model``. Both the
				embedding extractor and the head are taken from this model.
			prefix: label used in the printed diagnostics (e.g. a split name).

		Returns:
			tuple: ``(preds, tgts, mean_loss)``. ``preds`` and ``tgts`` are
			flattened numpy arrays; ``preds`` holds sigmoid probabilities for
			classification and raw head outputs otherwise. ``mean_loss`` is the
			mean batch loss.

		Raises:
			ValueError: if ``data_loader`` yields no batches.
		"""
		model = self.model if model is None else model
		is_classification = self.args.task == 'classification'
		batch_preds = []
		batch_targets = []
		running_loss = 0.0
		n_batches = 0

		start_time = self._begin_epoch_profile()
		with torch.no_grad():
			for smiles, targets in data_loader:
				targets = targets.detach().to(self.device).reshape(-1)
				outputs = self._forward_head(model, smiles)
				if is_classification:
					# Cast for the loss only; the returned targets keep their dtype.
					loss = self.loss_fn(outputs, targets.float())
					batch_preds.append(torch.sigmoid(outputs))
				else:
					loss = self.loss_fn(outputs, targets)
					batch_preds.append(outputs)
				batch_targets.append(targets)

				running_loss += loss.item()
				n_batches += 1

		if n_batches == 0:
			raise ValueError(f"{prefix or 'validation'} data_loader yielded no batches")

		# Put together predictions and labels from all batches.
		preds = torch.cat(batch_preds, dim=0).cpu().numpy()
		tgts = torch.cat(batch_targets, dim=0).cpu().numpy()
		total_time, peak_memory = self._end_epoch_profile(start_time)

		print(f"{prefix}: {(tgts == 0).sum()} targets equal to 0")
		avg_time_per_sample = total_time / len(preds) if len(preds) else float('nan')
		print(f"Inference Time per sample: {avg_time_per_sample:.4f} s")
		print(f"Peak GPU Memory: {peak_memory:.2f} MB")
		return preds, tgts, running_loss / n_batches
	
	

def define_model(args, cfg):
	"""Load the SMI-TED variant named in the config and return the model.

	Args:
		args: run arguments; only ``args.seed`` is read, and it is forwarded
			to the loader.
		cfg: config; reads ``cfg.model.smi_ted_version`` (``'light'`` or
			``'large'``), ``cfg.model.path`` and ``cfg.model.ckpt_path``.

	Returns:
		The object returned by the variant's ``load_smi_ted``. If that object
		defines ``_init_weights``, it has been applied to ``model.net``.

	Raises:
		ValueError: if ``cfg.model.smi_ted_version`` is neither ``'light'`` nor
			``'large'``.
	"""
	model_cfg = cfg.model
	version = model_cfg.smi_ted_version
	if version not in ('light', 'large'):
		raise ValueError(f"Unsupported smi_ted_version: {version}")

	# Imported inside the branch so only the selected variant's package must be importable.
	if version == 'light':
		from smi_ted_light.load import load_smi_ted
	else:
		from smi_ted_large.load import load_smi_ted

	model = load_smi_ted(
		folder=model_cfg.path,
		ckpt_filename=model_cfg.ckpt_path,
		seed=args.seed,
	)

	# NOTE(review): this presumably re-initializes the prediction head after the
	# checkpoint is loaded, replacing any head weights it held. Confirm against
	# _init_weights.
	if hasattr(model, '_init_weights'):
		model.net.apply(model._init_weights)
	return model
 

