import os
import wandb
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import Dataset, DataLoader

import numpy as np

from tqdm import tqdm
from .util import *

class SmilesDataset(Dataset):
	"""Map-style dataset that pairs each entry of ``smiles`` with the entry of ``target`` at the same index."""

	def __init__(self, smiles, target):
		# Both containers are stored as given (no copy, no length check).
		self.target = target
		self.smiles = smiles

	def __len__(self):
		# The dataset size is defined by ``smiles`` alone.
		return len(self.smiles)

	def __getitem__(self, idx):
		# Returns a ``(smiles, label)`` tuple for position ``idx``.
		return self.smiles[idx], self.target[idx]
	
	
def build_memory_bank_gnn(args, model, smiles, batch_size):
	"""Embed every entry of ``smiles`` with the model's encoder and stack the results.

	Args:
		args: Run options; only ``args.device`` is read (target device for the embeddings).
		model: Object exposing ``eval()`` and ``encoder.extract_embeddings(batch)``.
		smiles: Inputs to embed; iterated in their original order (no shuffling).
		batch_size: Number of entries handed to the encoder per call.

	Returns:
		A ``(memory_bank, all_smiles)`` tuple: the per-batch embeddings concatenated
		along dim 0, and the inputs in the order they were embedded.

	Note:
		The model is switched to eval mode and is NOT switched back; callers that
		continue training must call ``model.train()`` themselves.
	"""
	model.eval()

	loader = DataLoader(
		smiles,
		batch_size=batch_size,
		shuffle=False,
		num_workers=4,
		pin_memory=True,
	)

	embedding_chunks, ordered_smiles = [], []
	# Gradients are not needed for the bank.
	with torch.no_grad():
		for batch in loader:
			batch_emb = model.encoder.extract_embeddings(batch)
			embedding_chunks.append(batch_emb.to(args.device))
			ordered_smiles.extend(batch)

	# NOTE(review): torch.cat raises if ``smiles`` was empty (no chunks collected).
	return torch.cat(embedding_chunks, dim=0), ordered_smiles



def train_model_gnn_with_memory(args, cfg, smiles, targets, model, transducer):
	"""Train ``model`` with anchor-based prediction, periodically refreshing the transducer's memory bank.

	Each epoch optionally anneals ``transducer.temperature``, optionally rebuilds
	``transducer.train_embs`` from the current encoder, runs one pass over the
	shuffled training pairs, steps the LR scheduler, and applies early stopping
	on the average training loss.

	Args:
		args: Run options; this function reads ``device``, ``seed``, ``task`` and ``wandb_log``.
		cfg: Experiment config; reads ``exp.learning_rate``, ``exp.scheduler.gamma``,
			``exp.batch_size``, ``exp.num_epochs``, ``exp.memory_update_interval``,
			``exp.patience``, ``transducer.sampling_strategy``,
			``transducer.temperature.value`` and ``annealing.*``.
		smiles: Training inputs, passed to ``SmilesDataset`` and ``build_memory_bank_gnn``.
		targets: Training labels, indexed in parallel with ``smiles``.
		model: Network exposing ``parameters()``, ``train()``, ``encoder.extract_embeddings``
			and a forward call taking query and candidate embeddings.
		transducer: Anchor selector exposing ``choose_multiple_anchors``; its
			``temperature`` and ``train_embs`` attributes are overwritten here.

	Returns:
		The same ``model`` object after training.
	"""
	optimizer = optim.Adam(model.parameters(), lr=cfg.exp.learning_rate)
	scheduler = ExponentialLR(optimizer, gamma=cfg.exp.scheduler.gamma)
	print(f"Using Adam optimizer with LR={cfg.exp.learning_rate} and ExponentialLR scheduler with gamma={cfg.exp.scheduler.gamma}")

	best_train_loss = float('inf')
	train_patience_counter = 0

	train_dataset = SmilesDataset(smiles, targets)
	train_loader = DataLoader(
		train_dataset,
		batch_size=cfg.exp.batch_size,
		shuffle=True,
		num_workers=0,
		worker_init_fn=lambda worker_id: np.random.seed(args.seed + worker_id),
	)

	# CUDA timing/memory statistics are only valid on a CUDA device; skipping
	# them elsewhere keeps CPU-only runs from crashing.
	use_cuda = torch.cuda.is_available() and torch.device(args.device).type == 'cuda'

	use_temperature = cfg.transducer.sampling_strategy == 'temperature'
	if use_temperature:
		initial_temperature = cfg.transducer.temperature.value
		anneal_epochs = min(cfg.annealing.anneal_epochs, cfg.exp.num_epochs)
		# Clamp both ends of the schedule to a small positive value.
		final_temperature = max(cfg.annealing.final_temperature, 1e-6)
		initial_temp_calc = max(initial_temperature, 1e-6)

	pbar = tqdm(range(cfg.exp.num_epochs), desc="Training Epochs")
	for epoch in pbar:
		model.train()

		if use_cuda:
			torch.cuda.reset_peak_memory_stats(args.device)
			torch.cuda.synchronize()
		start_time = time.time()

		if use_temperature and cfg.annealing.enabled:
			transducer.temperature = update_temperature(cfg, epoch, initial_temperature, anneal_epochs, final_temperature, initial_temp_calc)

		# Periodically rebuild the memory bank from the current encoder and
		# store it on ``transducer.train_embs``.
		if epoch > 0 and epoch % cfg.exp.memory_update_interval == 0:
			try:
				current_memory_bank, _ = build_memory_bank_gnn(args, model, smiles, batch_size=cfg.exp.batch_size)
				transducer.train_embs = current_memory_bank
			except Exception as build_e:
				# Best effort: keep training with the previous bank.
				print(f"Error updating memory bank: {build_e}")
			finally:
				# build_memory_bank_gnn puts the model in eval mode; switch back
				# so this epoch's optimisation does not silently run in eval mode.
				model.train()

		# Anchor SMILES are only requested on the last scheduled epoch.
		is_last_epoch = (epoch + 1 == cfg.exp.num_epochs)

		running_loss = 0.0
		num_batches = 0
		for smi, target in train_loader:
			optimizer.zero_grad()
			query_embeddings = model.encoder.extract_embeddings(smi)
			candidates_emb, anchor_weights, attn_mask, _, _ = transducer.choose_multiple_anchors(query_embeddings, return_smiles=is_last_epoch)
			preds_y = model(query_embeddings, candidates_emb, anchor_weights=anchor_weights, attention_mask=attn_mask)

			gt_y = target.to(args.device).float()  # BCE requires float targets

			if args.task == 'classification':
				loss = F.binary_cross_entropy_with_logits(preds_y.view(-1), gt_y.view(-1))
			else:
				# NOTE(review): unlike the classification branch, shapes are not
				# flattened here; confirm preds_y and gt_y have matching shapes
				# so l1_loss does not broadcast.
				loss = F.l1_loss(preds_y, gt_y)

			loss.backward()
			optimizer.step()
			running_loss += loss.item()
			num_batches += 1

		if use_cuda:
			torch.cuda.synchronize()
		epoch_time = time.time() - start_time
		peak_memory = torch.cuda.max_memory_allocated(args.device) / (1024 ** 2) if use_cuda else 0.0  # MB

		# An explicit counter avoids relying on a loop variable that is
		# undefined when the loader yields nothing.
		avg_loss = running_loss / num_batches if num_batches else float('nan')

		current_lr = optimizer.param_groups[0]['lr']
		if args.wandb_log:
			# One log call per epoch, with the scalar loss, issued before the
			# early-stopping check so the final epoch is recorded too.
			wandb.log({
				'train_loss': avg_loss,
				'epoch_time_sec': epoch_time,
				'peak_memory_MB': peak_memory,
			})
		pbar.set_postfix(loss=f'{avg_loss:.6f}', lr=f'{current_lr:.6f}')

		scheduler.step()

		if np.isnan(avg_loss):
			print(f"Warning: Skipping early stopping check for epoch {epoch+1} due to NaN average loss.")
			continue

		# Early stopping on the training loss: reset on strict improvement.
		if avg_loss < best_train_loss:
			best_train_loss = avg_loss
			train_patience_counter = 0
		else:
			train_patience_counter += 1

		if train_patience_counter >= cfg.exp.patience:
			break

	print("Training finished.")
	# NOTE: saving a final checkpoint is currently disabled in this function;
	# the trained model is handed back to the caller instead.

	return model





import time
import torch

def test_model(args, cfg, smiles, targets, model, transducer, prefix=None):
	"""Run inference on ``smiles``, report timing/memory, compute metrics and save them.

	Args:
		args: Run options; this function reads ``device``, ``seed``, ``task`` and ``wandb_log``.
		cfg: Experiment config; only ``exp.batch_size`` is read.
		smiles: Inputs to evaluate; also stored under ``'query_smiles'`` in the result.
		targets: Ground-truth labels, indexed in parallel with ``smiles``.
		model: Network exposing ``eval()``, ``encoder.extract_embeddings`` and a forward
			call taking query and candidate embeddings.
		transducer: Anchor selector exposing ``choose_multiple_anchors``.
		prefix: Label used in printed messages, wandb keys, and passed to
			``calculate_metrics`` / ``save_results``.

	Returns:
		Dict with keys ``'preds'`` (numpy array; sigmoid probabilities for
		classification, raw model outputs otherwise), ``'gt'`` (numpy array of
		``targets``), ``'query_smiles'``, ``'anchor_idxs'`` and ``'anchor_smiles'``.
	"""
	model.eval()

	# CUDA timing/memory statistics are only valid on a CUDA device; skipping
	# them elsewhere keeps CPU-only evaluation from crashing.
	use_cuda = torch.cuda.is_available() and torch.device(args.device).type == 'cuda'
	if use_cuda:
		torch.cuda.reset_peak_memory_stats(args.device)
		torch.cuda.synchronize()
	start_time = time.time()

	test_dataset = SmilesDataset(smiles, targets)
	test_loader = DataLoader(
		test_dataset,
		batch_size=cfg.exp.batch_size,
		shuffle=False,
		num_workers=0,
		worker_init_fn=lambda worker_id: np.random.seed(args.seed + worker_id),
	)

	preds = {'preds': [], 'gt': targets, 'query_smiles': smiles, 'anchor_idxs': [], 'anchor_smiles': []}

	with torch.no_grad():
		for smi, target in test_loader:
			query_embeddings = model.encoder.extract_embeddings(smi)
			candidates_emb, anchor_weights, attn_mask, candidates_indices, candidates_smiles = transducer.choose_multiple_anchors(query_embeddings)
			preds_y = model(query_embeddings, candidates_emb, anchor_weights=anchor_weights, attention_mask=attn_mask)

			if args.task == 'classification':
				# The model emits logits; store probabilities.
				preds_prob = torch.sigmoid(preds_y)
				preds['preds'].append(preds_prob)
			else:
				preds['preds'].append(preds_y)

			# NOTE(review): indices are flattened across batches (extend) while
			# SMILES are kept one entry per batch (append); confirm intended.
			preds['anchor_idxs'].extend(candidates_indices)
			preds['anchor_smiles'].append(candidates_smiles)

	if use_cuda:
		torch.cuda.synchronize()
	total_time = time.time() - start_time
	avg_time_per_sample = total_time / len(smiles)
	peak_memory = torch.cuda.max_memory_allocated(args.device) / (1024 ** 2) if use_cuda else 0.0  # MB

	print(f"\n[{prefix}] Inference Time per sample: {avg_time_per_sample:.4f} s")
	print(f"[{prefix}] Peak GPU Memory: {peak_memory:.2f} MB")

	if args.wandb_log:
		wandb.log({
			f"{prefix}_inference_time_per_sample": avg_time_per_sample,
			f"{prefix}_peak_memory_MB": peak_memory
		})

	preds['preds'] = torch.cat(preds['preds'], dim=0).detach().cpu().numpy()
	preds['gt'] = np.array(preds['gt'])
	results = calculate_metrics(preds['gt'], preds['preds'], prefix=prefix, task=args.task)

	if args.wandb_log:
		wandb.summary.update(results)

	save_results(args, results, prefix)

	return preds
