import os
import wandb
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from torch_geometric.loader import DataLoader
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import time
from tqdm import tqdm
from .util import *

from models.blt_feature import MultiAnchorFeaturePredictor


def define_model(cfg, input_dim, num_tasks=1):
	"""Instantiate the predictor selected by ``cfg.model.model_type``.

	Args:
		cfg: Experiment configuration. ``cfg.model.model_type`` selects the
			architecture; the whole object is forwarded to the model
			constructor.
		input_dim: Forwarded unchanged to the model constructor.
		num_tasks: Forwarded unchanged to the model constructor.

	Returns:
		The constructed model instance.

	Raises:
		ValueError: If ``cfg.model.model_type`` is not a supported type.
	"""
	model_type = cfg.model.model_type
	if model_type == 'multi_anchor_feature':
		return MultiAnchorFeaturePredictor(cfg, input_dim, num_tasks)
	# Fail with a clear message instead of an UnboundLocalError on a typo'd
	# or not-yet-supported model type.
	raise ValueError(
		f"Unsupported model_type {model_type!r}; expected 'multi_anchor_feature'."
	)

def build_memory_bank_feature(args, model, features, smiles, batch_size):
	"""Embed all ``features`` with ``model`` and stack the embeddings.

	The model is switched to eval mode (and left in eval mode on return);
	embeddings are computed under ``torch.no_grad()``.

	Args:
		args: Namespace providing ``device``, the device batches are moved to.
		model: Object exposing ``eval()`` and ``extract_embedding(batch)``.
		features: Indexable collection of feature rows, batched in order.
		smiles: Indexable collection batched in lockstep with ``features``;
			its items are returned flattened in the original order.
		batch_size: Number of rows per batch.

	Returns:
		Tuple ``(memory_bank, all_smiles)``: the per-batch embeddings
		concatenated along dim 0, and the flat list of collected ``smiles``
		items.
	"""
	model.eval()

	feature_batches = DataLoader(features, batch_size=batch_size, shuffle=False)
	label_batches = DataLoader(smiles, batch_size=batch_size, shuffle=False)

	embedding_chunks = []
	collected_smiles = []
	with torch.no_grad():
		# Both loaders are unshuffled with the same batch size, so zipping
		# them keeps each feature batch aligned with its labels.
		for feats, labels in zip(feature_batches, label_batches):
			feats = feats.to(torch.float32).to(args.device)
			embedding_chunks.append(model.extract_embedding(feats))
			collected_smiles.extend(labels)

	return torch.cat(embedding_chunks, dim=0), collected_smiles


def train_model_gnn_with_memory(args, cfg, dataset, model, transducer):
	"""Train ``model`` on the ``'train'`` split with anchor-based prediction.

	Each batch is embedded with ``model.extract_embedding``; ``transducer``
	picks anchors for the query embeddings and ``model`` predicts from the
	query/anchor pair. Optimisation uses Adam with an ``ExponentialLR``
	schedule stepped once per epoch. Training stops early when the average
	training loss has not improved for ``cfg.exp.patience`` epochs.

	Args:
		args: Namespace providing ``device``, ``task`` (``'classification'``
			selects BCE-with-logits, anything else L1 loss) and ``wandb_log``.
		cfg: Configuration providing ``exp.*`` (learning rate, scheduler
			gamma, batch size, epochs, patience, memory update interval),
			``transducer.sampling_strategy`` and, for the ``'temperature'``
			strategy, ``transducer.temperature.value`` and ``annealing.*``.
		dataset: Mapping whose ``['train']`` entry holds ``'reps'`` and
			``'targets'``.
		model: Module exposing ``extract_embedding`` and a forward pass
			taking ``(query, candidates, anchor_weights=, attention_mask=)``.
		transducer: Object exposing ``choose_multiple_anchors``; its
			``train_embs`` attribute is overwritten on memory bank refreshes
			and ``temperature`` is overwritten when annealing is enabled.

	Returns:
		The trained ``model`` (the same object that was passed in).
	"""
	device = args.device
	# CUDA profiling helpers raise on CPU devices, so only call them when
	# training actually runs on an available CUDA device.
	torch_device = torch.device(device)
	track_cuda = torch_device.type == 'cuda' and torch.cuda.is_available()

	optimizer = optim.Adam(model.parameters(), lr=cfg.exp.learning_rate)
	scheduler = ExponentialLR(optimizer, gamma=cfg.exp.scheduler.gamma)
	print(f"Using Adam optimizer with LR={cfg.exp.learning_rate} and ExponentialLR scheduler with gamma={cfg.exp.scheduler.gamma}")

	best_train_loss = float('inf')
	train_patience_counter = 0

	features = torch.tensor(dataset['train']['reps'], dtype=torch.float32, device=device)
	targets = torch.tensor(dataset['train']['targets'], dtype=torch.float32, device=device)

	train_loader = DataLoader(
		TensorDataset(features, targets),
		batch_size=cfg.exp.batch_size,
		shuffle=True,
		num_workers=0,
	)

	uses_temperature = cfg.transducer.sampling_strategy == 'temperature'
	if uses_temperature:
		initial_temperature = cfg.transducer.temperature.value
		anneal_epochs = min(cfg.annealing.anneal_epochs, cfg.exp.num_epochs)
		final_temperature = max(cfg.annealing.final_temperature, 1e-6)  # Ensure positive
		initial_temp_calc = max(initial_temperature, 1e-6)  # Ensure positive

	pbar = tqdm(range(cfg.exp.num_epochs), desc="Training Epochs")
	for epoch in pbar:
		model.train()
		if track_cuda:
			torch.cuda.reset_peak_memory_stats(torch_device)
			torch.cuda.synchronize(torch_device)
		start_time = time.time()

		if uses_temperature and cfg.annealing.enabled == True:  # noqa: E712 - keep original comparison semantics
			transducer.temperature = update_temperature(cfg, epoch, initial_temperature, anneal_epochs, final_temperature, initial_temp_calc)

		# Periodically re-embed the training set so the transducer's anchor
		# embeddings (``transducer.train_embs``) follow the current weights.
		if epoch > 0 and epoch % cfg.exp.memory_update_interval == 0:
			try:
				# ``targets`` only fills the helper's label slot; the
				# returned labels are discarded.
				current_memory_bank, _ = build_memory_bank_feature(args, model, features, targets, batch_size=cfg.exp.batch_size)
				transducer.train_embs = current_memory_bank
			except Exception as build_e:
				# Best effort: keep training with the previous memory bank.
				print(f"Error updating memory bank: {build_e}")
			finally:
				# The helper switches the model to eval mode; restore train
				# mode so this epoch is not silently trained in eval mode.
				model.train()

		running_loss = 0.0
		num_batches = 0
		for features_batch, target_batch in train_loader:
			optimizer.zero_grad()
			features_batch = features_batch.to(device)
			target_batch = target_batch.to(device).float()

			query_embeddings = model.extract_embedding(features_batch)
			candidates_emb, anchor_weights, attn_mask, _, _ = transducer.choose_multiple_anchors(query_embeddings)
			# Single forward pass; the loss branches below reuse its output.
			preds_y = model(query_embeddings, candidates_emb, anchor_weights=anchor_weights, attention_mask=attn_mask)

			if args.task == 'classification':
				loss = F.binary_cross_entropy_with_logits(preds_y.view(-1), target_batch.view(-1))
			else:
				loss = F.l1_loss(preds_y, target_batch)

			loss.backward()
			optimizer.step()
			running_loss += loss.item()
			num_batches += 1

		# An empty loader yields NaN, which is reported by the check below
		# instead of crashing on an undefined loop variable.
		avg_loss = running_loss / num_batches if num_batches else float('nan')

		if track_cuda:
			torch.cuda.synchronize(torch_device)
		epoch_time = time.time() - start_time
		if track_cuda:
			peak_memory = torch.cuda.max_memory_allocated(torch_device) / (1024 ** 2)  # MB
		else:
			peak_memory = 0.0

		current_lr = optimizer.param_groups[0]['lr']
		if args.wandb_log:
			# One log call per epoch, issued before the early-stopping check
			# so the final epoch is recorded too.
			wandb.log({
				'train_loss': avg_loss,
				'epoch_time_sec': epoch_time,
				'peak_memory_MB': peak_memory
			})
		pbar.set_postfix(loss=f'{avg_loss:.6f}', lr=f'{current_lr:.6f}')

		scheduler.step()

		if np.isnan(avg_loss):
			print(f"Warning: Skipping early stopping check for epoch {epoch+1} due to NaN average loss.")
			continue

		if avg_loss < best_train_loss:
			best_train_loss = avg_loss
			train_patience_counter = 0
		else:
			train_patience_counter += 1

		if train_patience_counter >= cfg.exp.patience:
			break

	print("Training finished.")
	# NOTE: saving the final checkpoint is currently disabled:
	# torch.save(model.state_dict(), os.path.join(args.checkpoint_path, 'final.pt'))

	return model



def test_model(args, cfg, dataset, model, transducer, prefix=None):
	"""Run inference on the ``prefix`` split, then report and save metrics.

	Prints the inference time per sample and peak GPU memory, optionally
	logs them to wandb, computes metrics with ``calculate_metrics`` and
	persists them with ``save_results``.

	Args:
		args: Namespace providing ``device``, ``task`` and ``wandb_log``.
		cfg: Configuration providing ``exp.batch_size``.
		dataset: Mapping whose ``[prefix]`` entry holds ``'reps'``,
			``'targets'`` and ``'smiles'``.
		model: Module exposing ``extract_embedding`` and a forward pass
			taking ``(query, candidates, anchor_weights=, attention_mask=)``.
		transducer: Object exposing ``choose_multiple_anchors``.
		prefix: Key of the split in ``dataset``; also used to prefix the
			printed and logged metric names. It must be a key of ``dataset``.

	Returns:
		Dict with ``'preds'`` (NumPy array of all predictions), ``'gt'``
		(NumPy array of targets), ``'query_smiles'``, ``'anchor_idxs'``
		(extended per batch, i.e. flat across batches) and
		``'anchor_smiles'`` (one entry appended per batch).
	"""
	model.eval()
	device = args.device
	# CUDA profiling helpers raise on CPU devices, so only call them when
	# inference actually runs on an available CUDA device.
	torch_device = torch.device(device)
	track_cuda = torch_device.type == 'cuda' and torch.cuda.is_available()

	if track_cuda:
		torch.cuda.reset_peak_memory_stats(torch_device)
		torch.cuda.synchronize(torch_device)
	start_time = time.time()

	features, targets = dataset[prefix]['reps'], dataset[prefix]['targets']
	smiles = dataset[prefix]['smiles']

	features = torch.tensor(features, dtype=torch.float32, device=device)
	# Targets stay on the CPU: they are only used as ground truth for metrics.
	targets = torch.tensor(targets, dtype=torch.float32)

	test_loader = DataLoader(
		TensorDataset(features, targets),
		batch_size=cfg.exp.batch_size,
		shuffle=False,
		num_workers=0,
	)

	preds = {'preds': [], 'gt': targets, 'query_smiles': smiles, 'anchor_idxs': [], 'anchor_smiles': []}

	with torch.no_grad():
		for features_batch, _ in test_loader:
			features_batch = features_batch.to(device)
			query_embeddings = model.extract_embedding(features_batch)

			candidates_emb, anchor_weights, attn_mask, candidates_indices, candidates_smiles = transducer.choose_multiple_anchors(query_embeddings)
			preds_y = model(query_embeddings, candidates_emb, anchor_weights=anchor_weights, attention_mask=attn_mask)

			preds['preds'].append(preds_y)
			preds['anchor_idxs'].extend(candidates_indices)
			preds['anchor_smiles'].append(candidates_smiles)

	if track_cuda:
		torch.cuda.synchronize(torch_device)
	total_time = time.time() - start_time
	avg_time_per_sample = total_time / len(smiles)
	if track_cuda:
		peak_memory = torch.cuda.max_memory_allocated(torch_device) / (1024 ** 2)
	else:
		peak_memory = 0.0

	print(f"\n[{prefix}] Inference Time per sample: {avg_time_per_sample:.4f} s")
	print(f"[{prefix}] Peak GPU Memory: {peak_memory:.2f} MB")

	if args.wandb_log:
		wandb.log({f"{prefix}_inference_time_per_sample": avg_time_per_sample,
				   f"{prefix}_peak_memory_MB": peak_memory})

	# Compute overall statistics
	preds['preds'] = torch.cat(preds['preds'], dim=0).detach().cpu().numpy()
	preds['gt'] = np.array(preds['gt'])
	results = calculate_metrics(preds['gt'], preds['preds'], prefix=prefix, task=args.task)

	if args.wandb_log:
		wandb.summary.update(results)

	# Save results
	save_results(args, results, prefix)

	return preds
