# networks_smiles.py (Adapted for new anchoring inputs)
import glob
import os
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_mean_pool

from baselines.smi_ted.smi_ted_light.load import *

def define_model(args, cfg, num_tasks):
	"""Build the prediction model described by ``cfg.model``.

	Only ``cfg.model.model_type == 'multi_anchor_smi'`` is supported: a pretrained
	SMI-TED encoder is loaded with ``load_smi_ted`` and wrapped in a
	``MultiAnchorSMILESPredictor``.

	If loading ``cfg.model.ckpt_path`` from ``cfg.model.path`` fails, a single
	checkpoint matching ``<path>/<prop_type>/<seed>/*/ckpts/final.pt`` is searched
	for and loaded instead.

	Args:
		args: Run arguments; ``args.seed`` is forwarded to ``load_smi_ted`` and,
			together with ``args.prop_type``, selects the fallback checkpoint.
		cfg: Config providing ``cfg.model.{model_type, latent_dim, num_heads, path,
			ckpt_path}`` and ``cfg.transducer.use_anchor_weights``.
		num_tasks (int): Number of prediction targets.

	Returns:
		MultiAnchorSMILESPredictor: The assembled model; its encoder is put in train mode.

	Raises:
		ValueError: If ``cfg.model.model_type`` is unsupported, or the fallback
			search matches more than one checkpoint.
		FileNotFoundError: If the configured checkpoint cannot be loaded and the
			fallback search matches no checkpoint.
	"""
	model_type = cfg.model.model_type
	if model_type != 'multi_anchor_smi':
		# Previously fell through to `return model` with `model` unbound.
		raise ValueError(f"Unsupported model_type: {model_type!r}")

	# --- Parameters specific to MultiAnchorSMILESPredictor ---
	latent_dim = cfg.model.latent_dim
	num_heads = cfg.model.num_heads
	encoder_folder = cfg.model.path
	encoder_ckpt = cfg.model.ckpt_path

	print(f"Loading SMI-TED encoder from folder: {encoder_folder}, ckpt: {encoder_ckpt}")
	try:
		pretrained_encoder = load_smi_ted(folder=encoder_folder, ckpt_filename=encoder_ckpt, seed=args.seed)
	except Exception as load_err:
		# Fallback: a per-property, per-seed checkpoint under the encoder folder.
		pattern = os.path.join(encoder_folder, args.prop_type, str(args.seed), '*', 'ckpts', 'final.pt')
		matches = glob.glob(pattern)
		if not matches:
			raise FileNotFoundError(
				f"Could not load {encoder_ckpt!r} and no fallback checkpoint matches {pattern!r}"
			) from load_err
		if len(matches) > 1:
			raise ValueError(
				f"Ambiguous fallback checkpoint: {len(matches)} files match {pattern!r}: {matches}"
			) from load_err
		print(f"Falling back to checkpoint: {matches[0]} (initial load failed: {load_err})")
		# Pass the checkpoint relative to `encoder_folder`, mirroring how folder and
		# ckpt_filename are passed separately above (presumably joined inside
		# load_smi_ted). relpath is robust to a trailing separator on the folder.
		encoder_ckpt = os.path.relpath(matches[0], encoder_folder)
		pretrained_encoder = load_smi_ted(folder=encoder_folder, ckpt_filename=encoder_ckpt, seed=args.seed)

	# Train mode so train-time behaviour (e.g. dropout) is active in the encoder.
	pretrained_encoder.train()

	# NOTE(review): num_candidates is not passed, so the predictor uses its default (5).
	# If use_anchor_weights is enabled with a different k, forward() will raise.
	# Confirm against the transducer config.
	model = MultiAnchorSMILESPredictor(
		latent_dim=latent_dim,
		num_tasks=num_tasks,
		num_heads=num_heads,
		pretrained_encoder=pretrained_encoder,
		use_anchor_weights=cfg.transducer.use_anchor_weights,
	)
	return model



class MultiAnchorSMILESPredictor(nn.Module):
	"""
	Multi-anchor predictor on top of a pretrained SMILES encoder (e.g. SMI-TED).

	A query embedding attends over ``k`` candidate (anchor) embeddings with one
	cross-attention layer. The query and the attended vector (optionally plus the
	per-anchor weights) are concatenated, passed through a ReLU fusion layer and a
	linear output head.

	Args:
		latent_dim (int): Size of query/candidate embeddings; also the attention
			embed dim and the fusion hidden size.
		num_tasks (int): Number of outputs per sample.
		num_heads (int): Attention heads; must divide ``latent_dim``.
		pretrained_encoder: Encoder exposing ``extract_embeddings``; required.
		use_anchor_weights (bool): If True, ``forward`` requires ``anchor_weights``
			and appends them to the fusion input.
		num_candidates (int): Number of anchors ``k`` that the fusion layer is sized
			for when ``use_anchor_weights`` is True.

	Raises:
		ValueError: If ``pretrained_encoder`` is None.
	"""
	def __init__(self, latent_dim, num_tasks, num_heads=4, pretrained_encoder=None,
				 use_anchor_weights=False, num_candidates=5):
		super().__init__()
		if pretrained_encoder is None:
			raise ValueError("A pretrained SMILES encoder must be provided.")

		self.latent_dim = latent_dim
		self.num_tasks = num_tasks
		self.use_anchor_weights = use_anchor_weights
		self.num_candidates = num_candidates  # k; fixes the fusion input size
		self.encoder = pretrained_encoder

		# Query (batch, 1, d) attends over candidates (batch, k, d).
		self.attention = nn.MultiheadAttention(embed_dim=latent_dim, num_heads=num_heads, batch_first=True)

		# Fusion input: query embedding + attention output [+ k anchor weights].
		fusion_input_dim = latent_dim * 2
		if self.use_anchor_weights:
			fusion_input_dim += self.num_candidates
		self.fusion = nn.Linear(fusion_input_dim, latent_dim)
		self.output_layer = nn.Linear(latent_dim, num_tasks)

		# Expose the embedding extractor under the name other components use.
		# Note: the input signature here is specific to SMILES.
		self.obs_trunk = self.extract_embedding_wrapper

	def extract_embedding_wrapper(self, smiles_batch, *args, **kwargs):
		"""
		Return molecule-level embeddings for ``smiles_batch`` from the wrapped encoder.

		Extra positional/keyword arguments are accepted and ignored, so this can be
		called wherever a trunk with a different signature is expected.

		The embedding is element ``[2]`` of ``self.encoder.extract_embeddings(...)``.
		Adjust the index if your encoder returns embeddings differently.

		Raises:
			AttributeError: If the encoder has no ``extract_embeddings`` method.
			RuntimeError: If extraction fails or its output cannot be indexed at 2.
		"""
		extract = getattr(self.encoder, 'extract_embeddings', None)
		if extract is None:
			raise AttributeError("Provided encoder object does not have 'extract_embeddings' method.")

		try:
			output = extract(smiles_batch)
		except Exception as e:
			print(f"Error in extract_embedding_wrapper during embedding extraction: {e}")
			raise RuntimeError("Failed to extract embeddings. Check encoder structure and output.") from e

		try:
			return output[2]
		except Exception as e:
			# Inspect the output we already have instead of re-running the encoder.
			print(f"Error in extract_embedding_wrapper during embedding extraction: {e}")
			print(f"Encoder output type: {type(output)}, value (partial): {str(output)[:100]}...")
			if isinstance(output, (list, tuple)) and len(output) > 2:
				item = output[2]
				print(f"Accessing index 2, shape: {item.shape if hasattr(item, 'shape') else 'N/A'}")
			else:
				print("Encoder output is not a sequence of length > 2.")
			raise RuntimeError("Failed to extract embeddings. Check encoder structure and output.") from e

	def forward(self, query_embedding, candidate_embeddings, anchor_weights=None, attention_mask=None):
		"""
		Fuse the query with its candidate anchors via cross-attention and predict.

		Args:
			query_embedding (torch.Tensor): Shape (batch, latent_dim).
			candidate_embeddings (torch.Tensor): Shape (batch, k, latent_dim).
			anchor_weights (torch.Tensor, optional): Shape (batch, k).
				- Required when ``use_anchor_weights`` is True, and k must then equal
				  ``num_candidates``; ignored otherwise.
				- Moved to the query's dtype/device before fusion.
			attention_mask (torch.Tensor, optional): Shape (batch, k); passed as
				``key_padding_mask``.
				- For a bool mask, True means "ignore this candidate".
				- Rows where every candidate is masked get a zero attention output
				  instead of NaN.

		Returns:
			torch.Tensor: Predictions of shape (batch, num_tasks).

		Raises:
			ValueError: If anchor weights are required but missing, or their count
				does not match the candidates or ``num_candidates``.
		"""
		key_padding_mask = attention_mask
		fully_masked = None
		if attention_mask is not None and attention_mask.dtype == torch.bool:
			# A row with all keys masked makes the attention softmax return NaN, which
			# also corrupts gradients. Let such rows attend to every key so the
			# computation stays finite, then zero their output below so they carry no
			# anchor information.
			fully_masked = attention_mask.all(dim=1, keepdim=True)  # (batch, 1)
			key_padding_mask = attention_mask & ~fully_masked

		query_attn = query_embedding.unsqueeze(1)  # (batch, 1, latent_dim)
		# The attention weights are unused, so skip computing them.
		attn_output, _ = self.attention(query=query_attn,
										key=candidate_embeddings,
										value=candidate_embeddings,
										key_padding_mask=key_padding_mask,
										need_weights=False)
		attn_output = attn_output.squeeze(1)  # (batch, latent_dim)
		if fully_masked is not None:
			attn_output = attn_output.masked_fill(fully_masked, 0.0)

		fused_input = [query_embedding, attn_output]

		if self.use_anchor_weights:
			if anchor_weights is None:
				raise ValueError("Anchor weights must be provided when use_anchor_weights is True")
			num_weights = anchor_weights.shape[1]
			num_keys = candidate_embeddings.shape[1]
			if num_weights != num_keys:
				raise ValueError(f"Shape mismatch: anchor_weights ({num_weights}) != candidates ({num_keys})")
			if num_weights != self.num_candidates:
				# The fusion layer's input size is fixed at construction time, so any
				# other k cannot be fused; fail clearly instead of inside nn.Linear.
				raise ValueError(
					f"Received {num_weights} anchor weights, but the fusion layer was built "
					f"for num_candidates={self.num_candidates}."
				)
			fused_input.append(anchor_weights.to(device=query_embedding.device, dtype=query_embedding.dtype))

		fused = torch.cat(fused_input, dim=-1)  # (batch, latent_dim*2 [+ k])
		fused_hidden = F.relu(self.fusion(fused))  # (batch, latent_dim)
		return self.output_layer(fused_hidden)  # (batch, num_tasks)