import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F

import warnings

from chemprop.models import MoleculeModel
from chemprop.args import TrainArgs


def define_model(args, cfg, num_tasks):
	"""Build the model selected by ``cfg.model.model_type``.

	Only ``'multi_anchor_chemprop'`` is supported: a chemprop ``MoleculeModel``
	is either restored from a checkpoint (when ``cfg.model.path`` is set) or
	freshly initialised from ``TrainArgs``, and is then wrapped in a
	``MultiAnchorChempropPredictor``.

	Args:
		args: Run arguments. Checkpoint lookup reads ``dataset_name``,
			``dataset_split_type``, ``prop_type`` and ``seed``; a fresh model
			reads ``data_dir``, ``dataset_name`` and ``prop_type``.
		cfg: Config object. Reads ``cfg.model.model_type``, ``cfg.model.path``,
			``cfg.model.latent_dim``, ``cfg.model.num_heads`` and
			``cfg.transducer.use_anchor_weights``. When a checkpoint is used,
			``cfg.model.path`` is overwritten in place with the resolved file path.
		num_tasks: Number of prediction targets of the output head.

	Returns:
		MultiAnchorChempropPredictor: The assembled model.

	Raises:
		ValueError: If ``cfg.model.model_type`` is not supported.
		SystemExit: If ``cfg.model.path`` is set but no checkpoint matches
			``<path>/<dataset>/<split>/<prop>/<seed>/*/ckpts/fold_0/model_0/model.pt``.
	"""
	if cfg.model.model_type != 'multi_anchor_chemprop':
		# Fail loudly: without this guard there is no model to return and the
		# function would die with an unhelpful UnboundLocalError.
		raise ValueError(f"Unsupported model_type: {cfg.model.model_type!r}")

	if cfg.model.path is not None:
		matches = glob.glob(os.path.join(
			cfg.model.path,
			args.dataset_name,
			args.dataset_split_type,
			args.prop_type,
			str(args.seed),
			'*',
			'ckpts/fold_0/model_0/model.pt'
		))
		if not matches:
			raise SystemExit(
				"No encoder checkpoint found matching the provided cfg.model.path pattern."
			)
		if len(matches) > 1:
			# Newest by modification time; the path breaks ties deterministically.
			selected_ckpt = max(matches, key=lambda p: (os.path.getmtime(p), p))
			warnings.warn(
				f"Multiple encoder checkpoints found; using the most recent one: {selected_ckpt}",
				RuntimeWarning
			)
		else:
			selected_ckpt = matches[0]
		cfg.model.path = selected_ckpt
		print(cfg.model.path)
		# NOTE(review): torch.load unpickles arbitrary Python objects (the
		# checkpoint carries an ``args`` object), so only load trusted files.
		# Recent torch releases default to ``weights_only=True``, which may
		# reject that object — confirm against the installed torch version.
		checkpoint = torch.load(selected_ckpt, map_location='cpu')
		checkpoint_args = checkpoint['args']
		pretrained_encoder = MoleculeModel(checkpoint_args)
		pretrained_encoder.args = checkpoint_args
		pretrained_encoder.load_state_dict(checkpoint['state_dict'])
	else:
		train_path = os.path.join(args.data_dir, args.dataset_name, args.prop_type, 'train_featurized.csv')
		trainargs = TrainArgs().parse_args([
			'--dataset_type', 'regression',
			'--data_path', train_path,
		])
		pretrained_encoder = MoleculeModel(trainargs)
		pretrained_encoder.args = trainargs

	# NOTE(review): ``num_candidates`` is not forwarded, so the predictor uses
	# its default; with ``use_anchor_weights`` enabled, the number of anchors
	# fed to ``forward`` must match that default — confirm against the config.
	return MultiAnchorChempropPredictor(
		latent_dim=cfg.model.latent_dim,
		num_tasks=num_tasks,
		num_heads=cfg.model.num_heads,
		pretrained_encoder=pretrained_encoder,
		use_anchor_weights=cfg.transducer.use_anchor_weights
	)



class MultiAnchorChempropPredictor(nn.Module):
	"""
	Multi-anchor property predictor built around a pretrained chemprop encoder.

	A query embedding attends over ``k`` candidate (anchor) embeddings with
	multi-head cross-attention. The query embedding and the attended summary
	(optionally concatenated with per-anchor weights) pass through a linear
	fusion layer + ReLU and are projected to ``num_tasks`` outputs.

	Args:
		latent_dim: Size of the query/candidate embeddings; must be divisible
			by ``num_heads`` (an ``nn.MultiheadAttention`` requirement).
		num_tasks: Number of output targets.
		num_heads: Number of attention heads.
		pretrained_encoder: Object whose ``.encoder`` attribute is called by
			``extract_embedding`` (``define_model`` passes a chemprop
			``MoleculeModel``). May be ``None`` if embeddings are computed elsewhere.
		use_anchor_weights: If True, ``forward`` requires ``anchor_weights`` and
			concatenates them to the fusion input.
		num_candidates: Number of anchors ``k``; fixes the fusion layer's input
			size when ``use_anchor_weights`` is True.
	"""
	def __init__(self, latent_dim, num_tasks, num_heads=4, pretrained_encoder=None,
			use_anchor_weights=False, num_candidates=5):
		super().__init__()
		self.latent_dim = latent_dim
		self.num_tasks = num_tasks
		self.use_anchor_weights = use_anchor_weights
		self.num_candidates = num_candidates

		self.encoder = pretrained_encoder

		# Cross-attention: one query token attends over the k candidate tokens.
		self.attention = nn.MultiheadAttention(embed_dim=latent_dim, num_heads=num_heads, batch_first=True)

		# Query embedding + attention output, plus one slot per anchor weight when enabled.
		fusion_input_dim = latent_dim * 2
		if self.use_anchor_weights:
			fusion_input_dim += self.num_candidates

		self.fusion = nn.Linear(fusion_input_dim, latent_dim)
		self.output_layer = nn.Linear(latent_dim, num_tasks)

		# Alias so callers can use ``model.obs_trunk(...)`` as the embedding extractor.
		self.obs_trunk = self.extract_embedding

	def extract_embedding(self, batch_mol_graph):
		"""Embed a molecule batch with the wrapped encoder.

		Args:
			batch_mol_graph: A batch accepted by ``self.encoder.encoder``
				(presumably a chemprop ``BatchMolGraph`` — TODO confirm); it is
				wrapped in a one-element list before the call.

		Returns:
			Whatever ``self.encoder.encoder`` returns (presumably a
			(batch, hidden_size) tensor — TODO confirm).

		Raises:
			RuntimeError: If no ``pretrained_encoder`` was supplied.
		"""
		if self.encoder is None:
			raise RuntimeError("extract_embedding requires a pretrained_encoder, but none was provided.")
		return self.encoder.encoder([batch_mol_graph])

	def _prepare_anchor_weights(self, anchor_weights, query_embedding, candidate_embeddings):
		"""Validate anchor weights and cast them to the query's dtype/device.

		The fusion layer was sized for exactly ``num_candidates`` weights, so any
		other count would otherwise fail inside ``self.fusion`` with an opaque
		matmul shape error.
		"""
		if anchor_weights is None:
			raise ValueError("Anchor weights must be provided when use_anchor_weights is True")
		if anchor_weights.dim() != 2:
			raise ValueError(f"anchor_weights must be 2-D (batch, k); got shape {tuple(anchor_weights.shape)}")
		num_weights = anchor_weights.shape[1]
		if num_weights != self.num_candidates:
			raise ValueError(
				f"Received {num_weights} anchor weights, expected {self.num_candidates} "
				f"(the fusion layer is sized for num_candidates)."
			)
		if num_weights != candidate_embeddings.shape[1]:
			raise ValueError(f"Shape mismatch: anchor_weights ({num_weights}) != candidates ({candidate_embeddings.shape[1]})")
		# torch.cat / nn.Linear need matching dtype and device (e.g. float64 weights).
		return anchor_weights.to(query_embedding)

	def forward(self, query_embedding, candidate_embeddings, anchor_weights=None, attention_mask=None):
		"""
		Performs multi-anchor fusion via cross-attention.

		Args:
			query_embedding (torch.Tensor): Shape (batch, latent_dim)
			candidate_embeddings (torch.Tensor): Shape (batch, k, latent_dim)
			anchor_weights (torch.Tensor, optional): Shape (batch, k); required
				when ``use_anchor_weights`` is True, where k must equal
				``num_candidates``. Defaults to None.
			attention_mask (torch.Tensor, optional): Bool tensor, shape (batch, k),
				passed as ``key_padding_mask``; True means the key is ignored.
				A row with every key masked may yield NaN. Defaults to None.

		Returns:
			torch.Tensor: Predictions of shape (batch, num_tasks)

		Raises:
			ValueError: If anchor weights are required but missing or mis-shaped.
		"""
		# Validate before doing any attention work.
		extra = None
		if self.use_anchor_weights:
			extra = self._prepare_anchor_weights(anchor_weights, query_embedding, candidate_embeddings)

		# The query becomes a length-1 sequence: (batch, 1, latent_dim).
		attended, _ = self.attention(
			query=query_embedding.unsqueeze(1),
			key=candidate_embeddings,
			value=candidate_embeddings,
			key_padding_mask=attention_mask,
		)
		attended = attended.squeeze(1)  # (batch, latent_dim)

		parts = [query_embedding, attended]
		if extra is not None:
			parts.append(extra)
		fused = torch.cat(parts, dim=-1)  # (batch, latent_dim*2 [+ k])

		hidden = F.relu(self.fusion(fused))  # (batch, latent_dim)
		return self.output_layer(hidden)  # (batch, num_tasks)
