import os
import glob
import time
import warnings
import wandb
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from torch_geometric.loader import DataLoader

import numpy as np

from tqdm import tqdm
from .util import *

from baselines.pretrained_gnns.model import GNN_graphpred
from models.blt_graph import MultiAnchorGNNPredictor


def _resolve_encoder_checkpoint(args, root):
	"""Find the encoder checkpoint under ``root`` for the current run settings.

	Globs ``<root>/<dataset_name>/<dataset_split_type>/<prop_type>/<seed>/*/ckpts/final.pt``.
	When several files match, the most recently modified one is returned and a
	``RuntimeWarning`` is emitted.

	Raises:
		SystemExit: if nothing matches the pattern.
	"""
	pattern = os.path.join(
		root,
		args.dataset_name,
		args.dataset_split_type,
		args.prop_type,
		str(args.seed),
		'*',
		'ckpts',
		'final.pt'
	)
	matches = glob.glob(pattern)
	if not matches:
		raise SystemExit(
			"No encoder checkpoint found matching the provided cfg.model.encoder_path pattern."
		)
	# Newest by modification time; the path itself breaks mtime ties deterministically.
	selected_ckpt = max(matches, key=lambda p: (os.path.getmtime(p), p))
	if len(matches) > 1:
		warnings.warn(
			f"Multiple encoder checkpoints found; using the most recent one: {selected_ckpt}",
			RuntimeWarning
		)
	return selected_ckpt


def _load_encoder_weights(encoder, ckpt_path):
	"""Load ``ckpt_path`` into ``encoder``, preferring ``encoder.from_pretrained``.

	If ``from_pretrained`` raises, the checkpoint is read with ``torch.load``
	(onto CPU), unwrapped from a ``'model_state_dict'`` / ``'gnn_model_state_dict'``
	entry if present, stripped of a leading ``'module.'`` key prefix, and loaded
	with ``load_state_dict(strict=False)``.
	"""
	try:
		encoder.from_pretrained(ckpt_path)
	except Exception as load_err:  # not bare: let KeyboardInterrupt/SystemExit propagate
		print(f"'from_pretrained' failed ({load_err!r}); falling back to 'load_state_dict'.")
	else:
		print("Loaded weights using 'from_pretrained' method.")
		return

	state_dict = torch.load(ckpt_path, map_location='cpu')
	# Handle nested state dicts
	if 'model_state_dict' in state_dict:
		state_dict = state_dict['model_state_dict']
	elif 'gnn_model_state_dict' in state_dict:
		state_dict = state_dict['gnn_model_state_dict']
	# Strip only the leading 'module.' prefix added by (Distributed)DataParallel,
	# leaving any 'module.' that appears later in a key untouched.
	prefix = 'module.'
	state_dict = {
		(key[len(prefix):] if key.startswith(prefix) else key): value
		for key, value in state_dict.items()
	}
	# strict=False tolerates key mismatches; report them so a partial load is visible.
	incompatible = encoder.load_state_dict(state_dict, strict=False)
	print("Loaded weights using 'load_state_dict'.")
	print(
		f"load_state_dict: {len(incompatible.missing_keys)} missing keys, "
		f"{len(incompatible.unexpected_keys)} unexpected keys."
	)


def define_model(args, cfg, num_tasks=None):
	"""Instantiate the predictor selected by ``cfg.model.model_type``.

	Only ``'multi_anchor_gnn'`` is supported. A ``GNN_graphpred`` encoder is
	built from the ``cfg.model.gnn_*`` settings and loaded with pretrained
	weights. It is then wrapped in a ``MultiAnchorGNNPredictor``.

	Checkpoint resolution works as follows. If ``cfg.model.encoder_path``
	contains ``'supervised_contextpred'``, it is used as-is. Otherwise it is
	treated as a root directory and searched for
	``<root>/<dataset_name>/<dataset_split_type>/<prop_type>/<seed>/*/ckpts/final.pt``.
	The resolved path is written back into ``cfg.model.encoder_path``.

	Args:
		args: run arguments; ``dataset_name``, ``dataset_split_type``,
			``prop_type`` and ``seed`` build the checkpoint search pattern.
		cfg: experiment config (``cfg.model.*`` and ``cfg.transducer.*``).
		num_tasks: number of prediction targets, forwarded to both the encoder
			and the predictor.

	Returns:
		The ``MultiAnchorGNNPredictor`` instance.

	Raises:
		SystemExit: if no encoder checkpoint matches the search pattern.
		NotImplementedError: if ``cfg.model.model_type`` is unsupported.
	"""
	if cfg.model.model_type != 'multi_anchor_gnn':
		print(cfg.model.model_type)
		raise NotImplementedError('model is not implemented.')

	pretrained_encoder = GNN_graphpred(
		num_layer=cfg.model.gnn_num_layer,
		emb_dim=cfg.model.gnn_emb_dim,
		num_tasks=num_tasks,
		JK=cfg.model.gnn_JK,
		drop_ratio=cfg.model.gnn_drop_ratio,
		graph_pooling=cfg.model.gnn_graph_pooling,
		gnn_type=cfg.model.gnn_type
	)
	print(f"------------------------Loading pretrained encoder from: {cfg.model.encoder_path}------------------------")
	if 'supervised_contextpred' not in cfg.model.encoder_path:
		cfg.model.encoder_path = _resolve_encoder_checkpoint(args, cfg.model.encoder_path)
	print(cfg.model.encoder_path)
	_load_encoder_weights(pretrained_encoder, cfg.model.encoder_path)

	model = MultiAnchorGNNPredictor(
		pretrained_encoder=pretrained_encoder,  # the encoder with pretrained weights loaded
		num_tasks=num_tasks,
		latent_dim=cfg.model.latent_dim,
		num_heads=cfg.model.num_heads,
		use_anchor_weights=cfg.transducer.use_anchor_weights,
		num_candidates=cfg.transducer.num_candidates
	)
	print("Instantiated MultiAnchorGNNPredictor.")
	return model

def build_memory_bank_gnn(args, model, graphs, smiles, batch_size):
	"""Embed every graph with ``model`` and return the stacked embeddings.

	The model is switched to eval mode (and left there), and embeddings are
	computed under ``torch.no_grad()``.

	Args:
		args: run arguments; only ``args.device`` is read.
		model: object exposing ``extract_embedding(x, edge_index, edge_attr, batch)``.
		graphs: graphs accepted by the PyG ``DataLoader``, iterated in order.
		smiles: SMILES strings aligned one-to-one with ``graphs``.
		batch_size: number of graphs embedded per forward pass.

	Returns:
		``(memory_bank, all_smiles)``. ``memory_bank`` is the per-batch
		embeddings concatenated along dim 0. ``all_smiles`` is a list of the
		SMILES in the same order.

	Raises:
		ValueError: if ``graphs`` and ``smiles`` differ in length. Without this
			check, the rows of the bank would silently fall out of step with the
			SMILES.
	"""
	if len(graphs) != len(smiles):
		raise ValueError(
			f"graphs and smiles must be aligned; got {len(graphs)} graphs and {len(smiles)} SMILES."
		)

	model.eval()
	graph_loader = DataLoader(graphs, batch_size=batch_size, shuffle=False)

	all_embeddings = []
	with torch.no_grad():
		for batch in graph_loader:
			batch = batch.to(args.device)
			all_embeddings.append(
				model.extract_embedding(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
			)

	memory_bank = torch.cat(all_embeddings, dim=0)
	# The loader is unshuffled, so row i of memory_bank corresponds to smiles[i].
	return memory_bank, list(smiles)


def train_model_gnn_with_memory(args, cfg, graphs, smiles, model, transducer):
	"""Train ``model`` on ``graphs`` using anchors retrieved by ``transducer``.

	Each step does the following:
	  1. Embeds a batch with ``model.extract_embedding``.
	  2. Asks ``transducer.choose_multiple_anchors`` for candidate anchors.
	  3. Optimises the loss with Adam.

	The loss is BCE-with-logits when ``args.task == 'classification'`` and L1
	loss otherwise. The learning rate decays with ``ExponentialLR``, stepped
	once per epoch.

	Every ``cfg.exp.memory_update_interval`` epochs (skipping epoch 0), the
	transducer's ``train_embs`` memory bank is rebuilt from the current model.
	A failed rebuild is printed, and the previous bank is kept.

	With the ``'temperature'`` sampling strategy and ``cfg.annealing.enabled``,
	the transducer temperature is set each epoch via ``update_temperature``.

	Training stops early once the mean epoch loss has failed to improve for
	``cfg.exp.patience`` epochs. Epochs with a NaN loss skip the check.

	Args:
		args: run arguments (``device``, ``seed``, ``task``, ``wandb_log``).
		cfg: experiment config (``cfg.exp``, ``cfg.transducer``, ``cfg.annealing``).
		graphs: training graphs accepted by the PyG ``DataLoader``.
		smiles: SMILES aligned with ``graphs``, used for memory-bank rebuilds.
		model: the predictor being trained (updated in place).
		transducer: anchor selector holding the memory bank.

	Returns:
		The trained ``model``.

	Raises:
		ValueError: if ``graphs`` yields no training batches.
	"""
	optimizer = optim.Adam(model.parameters(), lr=cfg.exp.learning_rate)
	scheduler = ExponentialLR(optimizer, gamma=cfg.exp.scheduler.gamma)
	print(f"Using Adam optimizer with LR={cfg.exp.learning_rate} and ExponentialLR scheduler with gamma={cfg.exp.scheduler.gamma}")

	# CUDA timing/memory helpers raise on non-CUDA devices, so only call them on CUDA.
	use_cuda = torch.cuda.is_available() and torch.device(args.device).type == 'cuda'

	best_train_loss = float('inf')
	train_patience_counter = 0

	train_loader = DataLoader(
		graphs,
		batch_size=cfg.exp.batch_size,
		shuffle=True,
		num_workers=0,
		worker_init_fn=lambda worker_id: np.random.seed(args.seed + worker_id)
	)
	num_batches = len(train_loader)
	if num_batches == 0:
		raise ValueError("Training graphs are empty; no batches to train on.")

	use_temperature = cfg.transducer.sampling_strategy == 'temperature'
	if use_temperature:
		initial_temperature = cfg.transducer.temperature.value
		anneal_epochs = min(cfg.annealing.anneal_epochs, cfg.exp.num_epochs)
		final_temperature = max(cfg.annealing.final_temperature, 1e-6)  # Ensure positive
		initial_temp_calc = max(initial_temperature, 1e-6)  # Ensure positive

	pbar = tqdm(range(cfg.exp.num_epochs), desc="Training Epochs")
	for epoch in pbar:
		if use_cuda:
			torch.cuda.reset_peak_memory_stats(args.device)
			torch.cuda.synchronize()
		start_time = time.time()

		if use_temperature and cfg.annealing.enabled:
			transducer.temperature = update_temperature(cfg, epoch, initial_temperature, anneal_epochs, final_temperature, initial_temp_calc)

		if epoch > 0 and epoch % cfg.exp.memory_update_interval == 0:
			try:
				current_memory_bank, _ = build_memory_bank_gnn(args, model, graphs, smiles, batch_size=cfg.exp.batch_size)
				transducer.train_embs = current_memory_bank
			except Exception as build_e:
				# Best-effort refresh: keep training against the previous bank.
				print(f"Error updating memory bank: {build_e}")

		# Enter train mode *after* the memory-bank refresh. build_memory_bank_gnn
		# switches the model to eval mode, which would otherwise disable dropout
		# for the rest of this epoch.
		model.train()

		running_loss = 0.0
		for batch in train_loader:
			optimizer.zero_grad()
			batch = batch.to(args.device)
			query_embeddings = model.extract_embedding(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

			candidates_emb, anchor_weights, attn_mask, _, _ = transducer.choose_multiple_anchors(query_embeddings)
			preds_y = model(query_embeddings, candidates_emb, anchor_weights=anchor_weights, attention_mask=attn_mask)

			# as_tensor avoids the copy-construct warning torch.tensor() emits for tensor input.
			gt_y = torch.as_tensor(batch.y, dtype=torch.float32, device=args.device)
			if args.task == 'classification':
				loss = F.binary_cross_entropy_with_logits(preds_y, gt_y)
			else:
				loss = F.l1_loss(preds_y, gt_y)

			loss.backward()
			optimizer.step()
			running_loss += loss.item()

		if use_cuda:
			torch.cuda.synchronize()
		epoch_time = time.time() - start_time

		avg_loss = running_loss / num_batches
		current_lr = optimizer.param_groups[0]['lr']
		pbar.set_postfix(loss=f'{avg_loss:.6f}', lr=f'{current_lr:.6f}')
		scheduler.step()

		if args.wandb_log:
			# One log call per epoch so all epoch metrics share a single wandb step.
			epoch_metrics = {'train_loss': avg_loss, 'epoch_time_sec': epoch_time}
			if use_cuda:
				epoch_metrics['peak_memory_MB'] = torch.cuda.max_memory_allocated(args.device) / (1024 ** 2)
			wandb.log(epoch_metrics)

		if not np.isnan(avg_loss):
			if avg_loss < best_train_loss:
				best_train_loss = avg_loss
				train_patience_counter = 0
			else:
				train_patience_counter += 1

			if train_patience_counter >= cfg.exp.patience:
				break
		else:
			print(f"Warning: Skipping early stopping check for epoch {epoch+1} due to NaN average loss.")

	print("Training finished.")
	#torch.save(model.state_dict(), os.path.join(args.checkpoint_path, 'final.pt'))
	return model


def test_model(args, cfg, graphs, smiles, targets, model, transducer, prefix=None):
	"""Run anchor-based inference on ``graphs``, report its cost, and score it.

	Predictions are sigmoid probabilities when ``args.task == 'classification'``
	and raw model outputs otherwise. Per-sample inference time is printed, as
	is peak GPU memory when running on CUDA. Both are logged to wandb when
	``args.wandb_log`` is set. Metrics from ``calculate_metrics`` are pushed to
	``wandb.summary`` (if logging) and persisted via ``save_results``.

	Args:
		args: run arguments (``device``, ``seed``, ``task``, ``wandb_log``).
		cfg: experiment config; ``cfg.exp.batch_size`` sets the loader batch size.
		graphs: evaluation graphs accepted by the PyG ``DataLoader``.
		smiles: query SMILES aligned with ``graphs``.
		targets: ground-truth labels aligned with ``graphs``.
		model: trained predictor.
		transducer: anchor selector used to pick candidates per query.
		prefix: label used in printed/logged metric names.

	Returns:
		A dict with these keys:
		  - ``'preds'``: numpy array of predictions.
		  - ``'gt'``: the targets as a 2-D numpy array.
		  - ``'query_smiles'``: the input ``smiles``.
		  - ``'anchor_idxs'``: extended with each batch's candidate indices.
		  - ``'anchor_smiles'``: one entry appended per batch.
	"""
	model.eval()

	# CUDA timing/memory helpers raise on non-CUDA devices, so only call them on CUDA.
	use_cuda = torch.cuda.is_available() and torch.device(args.device).type == 'cuda'
	if use_cuda:
		torch.cuda.reset_peak_memory_stats(args.device)  # reset the GPU peak-memory counter
		torch.cuda.synchronize()
	start_time = time.time()  # start timing

	test_loader = DataLoader(
		graphs,
		batch_size=cfg.exp.batch_size,
		shuffle=False,
		num_workers=0,
		worker_init_fn=lambda worker_id: np.random.seed(args.seed + worker_id)
	)

	preds = {'preds': [], 'gt': targets, 'query_smiles': smiles, 'anchor_idxs': [], 'anchor_smiles': []}

	with torch.no_grad():
		for batch in test_loader:
			batch = batch.to(args.device)
			query_embeddings = model.extract_embedding(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

			candidates_emb, anchor_weights, attn_mask, candidates_indices, candidates_smiles = transducer.choose_multiple_anchors(query_embeddings)
			logits = model(query_embeddings, candidates_emb, anchor_weights=anchor_weights, attention_mask=attn_mask)

			preds_y = torch.sigmoid(logits) if args.task == 'classification' else logits

			preds['preds'].append(preds_y)
			preds['anchor_idxs'].extend(candidates_indices)
			# NOTE(review): anchor_idxs is extended per sample, but anchor_smiles is
			# appended per batch, so the two lists are not index-aligned. Confirm
			# which layout downstream consumers expect.
			preds['anchor_smiles'].append(candidates_smiles)

	if use_cuda:
		torch.cuda.synchronize()
	total_time = time.time() - start_time
	avg_time_per_sample = total_time / len(smiles)

	print(f"\n[{prefix}] Inference Time per sample: {avg_time_per_sample:.4f} s")
	cost_metrics = {f"{prefix}_inference_time_per_sample": avg_time_per_sample}
	if use_cuda:
		peak_memory = torch.cuda.max_memory_allocated(args.device) / (1024 ** 2)  # MB
		print(f"[{prefix}] Peak GPU Memory: {peak_memory:.2f} MB")
		cost_metrics[f"{prefix}_peak_memory_MB"] = peak_memory

	if args.wandb_log:
		wandb.log(cost_metrics)

	preds['preds'] = torch.cat(preds['preds'], dim=0).detach().cpu().numpy()
	# NOTE(review): for regression this assumes preds is 2-D (N, num_targets);
	# a 1-D preds array would make shape[1] raise IndexError.
	preds['gt'] = np.array(preds['gt']).reshape(-1, 1 if args.task == 'classification' else preds['preds'].shape[1])

	results = calculate_metrics(preds['gt'], preds['preds'], prefix=prefix, task=args.task)

	if args.wandb_log:
		wandb.summary.update(results)

	save_results(args, results, prefix)

	return preds
