import os
import wandb
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time
from tqdm import tqdm
from .util import *

from rdkit import Chem
from chemprop.features.featurization import MolGraph, BatchMolGraph


class GraphDataset(torch.utils.data.Dataset):
	"""Dataset yielding ``(MolGraph, SMILES, target)`` triples.

	Every SMILES string is parsed with RDKit and wrapped in a chemprop
	``MolGraph`` once, at construction time. ``MolGraph`` is the per-molecule
	unit that ``BatchMolGraph`` consumes, so batching later only has to merge
	graphs that are already built.
	"""

	def __init__(self, graphs, smiles, targets):
		"""Featurize all molecules up front.

		Args:
			graphs: Ignored. The graphs are always rebuilt from ``smiles``;
				the parameter is kept so existing call sites keep working.
			smiles: Iterable of SMILES strings, one per sample.
			targets: Array-like of targets; converted to a float32 tensor.

		Raises:
			ValueError: If RDKit cannot parse one of the SMILES strings.
		"""
		mol_graphs = []
		for idx, smi in enumerate(smiles):
			mol = Chem.MolFromSmiles(smi)
			# RDKit signals a parse failure by returning None instead of raising;
			# fail here with the offending input rather than deep inside MolGraph.
			if mol is None:
				raise ValueError(f"Could not parse SMILES at index {idx}: {smi!r}")
			mol_graphs.append(MolGraph(mol))

		self.graphs = mol_graphs
		self.smiles = smiles
		self.targets = torch.from_numpy(np.array(targets, dtype=np.float32))

	def __len__(self):
		"""Return the number of featurized molecules."""
		return len(self.graphs)

	def __getitem__(self, idx):
		"""Return the ``(graph, smiles, target)`` triple stored at ``idx``."""
		return self.graphs[idx], self.smiles[idx], self.targets[idx]


def graph_collate_fn(batch):
	"""Collate a list of ``(graph, smiles, target)`` samples into one batch.

	Args:
		batch: Sequence of 3-tuples, presumably as produced by
			``GraphDataset.__getitem__``.

	Returns:
		A tuple ``(BatchMolGraph, list of SMILES, stacked target tensor)``.
	"""
	# Transpose the per-sample triples into three per-field tuples.
	mol_graphs, smiles_strings, target_tensors = zip(*batch)
	# The graphs arrive already built, so this only has to merge them.
	merged_graph = BatchMolGraph(mol_graphs)
	stacked_targets = torch.stack(target_tensors)
	return merged_graph, list(smiles_strings), stacked_targets

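# NOTE: Legacy draft kept for reference only; superseded by build_memory_bank_gnn below.
# It is stale: it calls GraphDataset(graphs, targets) and unpacks 2-tuple batches,
# whereas GraphDataset takes (graphs, smiles, targets) and the collate function yields 3-tuples.
# def build_memory_bank_gnn(model, graphs, smiles, targets, batch_size):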
# 	model.eval()
# 	all_embeddings = []
# 	graph_dataset = GraphDataset(graphs, targets)
# 	graph_loader = DataLoader(graph_dataset, batch_size=batch_size, shuffle=False, collate_fn=graph_collate_fn)
# 	#smiles_loader = DataLoader(smiles_dataset, batch_size=batch_size, shuffle=False)
# 	model.eval()
# 	with torch.no_grad():
# 		for batch in graph_loader:
# 			graph_batch, label_batch = batch
# 			emb = model.extract_embedding(graph_batch)
# 			all_embeddings.append(emb)
# 	memory_bank = torch.cat(all_embeddings, dim=0)
# 	return memory_bank, smiles

def build_memory_bank_gnn(model, graph_loader, batch_size=None):
	"""Embed every batch of ``graph_loader`` and concatenate the results.

	The model (and ``model.encoder``) is switched to eval mode and is left in
	eval mode on return; a caller that continues training must call
	``model.train()`` afterwards.

	Args:
		model: Object providing ``eval()``, an ``encoder`` attribute with
			``eval()``, and ``extract_embedding(graph_batch)``.
		graph_loader: Iterable of ``(graph_batch, smiles_batch, label_batch)``
			triples. Rows of the returned bank follow the iteration order, so
			pass a non-shuffled loader when row ``i`` must match sample ``i``.
		batch_size: Unused; retained (now optional) so existing callers that
			pass it keep working.

	Returns:
		``(memory_bank, all_smiles)``: the embeddings concatenated along
		dim 0 and the SMILES collected in the same order.
	"""
	model.eval()
	model.encoder.eval()

	all_embeddings = []
	all_smiles = []

	# Embeddings are only stored, never back-propagated through.
	with torch.no_grad():
		for graph_batch, smiles_batch, _ in graph_loader:
			all_embeddings.append(model.extract_embedding(graph_batch))
			all_smiles.extend(smiles_batch)

	memory_bank = torch.cat(all_embeddings, dim=0)
	return memory_bank, all_smiles



def train_model_gnn_with_memory(args, cfg, graphs, smiles, targets, model, transducer):
	"""Train ``model`` against anchors chosen by ``transducer``.

	Each epoch optionally refreshes the transducer's memory bank
	(``transducer.train_embs``) with embeddings from the current model, then
	runs one pass over the shuffled training data. Training stops early once
	the average training loss has not improved for ``cfg.exp.patience``
	consecutive epochs. The final weights are saved to
	``<args.checkpoint_path>/final.pt``.

	Args:
		args: Run options; reads ``seed``, ``device``, ``task``,
			``wandb_log`` and ``checkpoint_path``.
		cfg: Config; reads ``cfg.exp.learning_rate``, ``scheduler.gamma``,
			``batch_size``, ``num_epochs``, ``memory_update_interval`` and
			``patience``. A falsy ``memory_update_interval`` (0 / None)
			disables memory-bank refreshes.
		graphs: Forwarded to ``GraphDataset``.
		smiles: SMILES strings of the training molecules.
		targets: Training targets.
		model: Model providing ``extract_embedding`` and a forward pass taking
			query and candidate embeddings.
		transducer: Anchor selector providing ``choose_multiple_anchors`` and a
			writable ``train_embs`` attribute.

	Returns:
		The trained ``model`` (same object that was passed in).

	Raises:
		ValueError: If the training set is empty.
	"""
	optimizer = optim.Adam(model.parameters(), lr=cfg.exp.learning_rate)
	scheduler = ExponentialLR(optimizer, gamma=cfg.exp.scheduler.gamma)
	print(f"Using Adam optimizer with LR={cfg.exp.learning_rate} and ExponentialLR scheduler with gamma={cfg.exp.scheduler.gamma}")
	best_train_loss = float('inf')
	train_patience_counter = 0

	graph_dataset = GraphDataset(graphs, smiles, targets)
	if len(graph_dataset) == 0:
		raise ValueError("Cannot train on an empty dataset.")

	train_loader = DataLoader(
		graph_dataset,
		batch_size=cfg.exp.batch_size,
		shuffle=True,
		num_workers=0,
		worker_init_fn=lambda worker_id: np.random.seed(args.seed + worker_id),
		collate_fn=graph_collate_fn,
	)
	# Memory-bank refreshes walk the dataset in stored order so that row i of
	# the bank is the embedding of sample i. Rebuilding from the shuffled
	# train_loader would permute the rows differently on every refresh.
	# NOTE(review): assumes the transducer indexes train_embs in dataset order - confirm.
	memory_loader = DataLoader(
		graph_dataset,
		batch_size=cfg.exp.batch_size,
		shuffle=False,
		num_workers=0,
		collate_fn=graph_collate_fn,
	)

	refresh_interval = cfg.exp.memory_update_interval
	is_classification = args.task == 'classification'

	pbar = tqdm(range(cfg.exp.num_epochs), desc="Training Epochs")
	for epoch in pbar:
		# A falsy interval means "never refresh" (and avoids a modulo-by-zero).
		if refresh_interval and epoch > 0 and epoch % refresh_interval == 0:
			# Best effort: a failed refresh keeps the previous memory bank.
			try:
				current_memory_bank, _ = build_memory_bank_gnn(model, memory_loader, batch_size=cfg.exp.batch_size)
				transducer.train_embs = current_memory_bank
			except Exception as build_e:
				print(f"Error updating memory bank: {build_e}")

		# Enter train mode only after the refresh, which switches the model to eval.
		model.train()

		# Anchor SMILES are only requested during the last scheduled epoch.
		return_smiles = (epoch + 1 == cfg.exp.num_epochs)

		running_loss = 0.0
		num_batches = 0
		for num_batches, (graph_batch, _, target_batch) in enumerate(train_loader, start=1):
			optimizer.zero_grad()
			query_embeddings = model.extract_embedding(graph_batch)
			candidates_emb, anchor_weights, attn_mask, _, _ = transducer.choose_multiple_anchors(
				query_embeddings, return_smiles=return_smiles)
			preds_y = model(query_embeddings, candidates_emb, anchor_weights=anchor_weights, attention_mask=attn_mask)

			gt_y = target_batch.float().to(args.device)
			if gt_y.shape != preds_y.shape:
				gt_y = gt_y.view_as(preds_y)

			if is_classification:
				loss = F.binary_cross_entropy_with_logits(preds_y.view(-1), gt_y.view(-1))
			else:
				loss = F.l1_loss(preds_y, gt_y)

			loss.backward()
			optimizer.step()
			running_loss += loss.item()

		avg_loss = running_loss / num_batches

		if args.wandb_log:
			wandb.log({'train_loss': avg_loss})
		current_lr = optimizer.param_groups[0]['lr']
		pbar.set_postfix(loss=f'{avg_loss:.6f}', lr=f'{current_lr:.6f}')

		scheduler.step()

		if np.isnan(avg_loss):
			print(f"Warning: Skipping early stopping check for epoch {epoch+1} due to NaN average loss.")
			continue

		# Early stopping on the training loss.
		if avg_loss < best_train_loss:
			best_train_loss = avg_loss
			train_patience_counter = 0
		else:
			train_patience_counter += 1

		if train_patience_counter >= cfg.exp.patience:
			break

	print("Training finished.")
	torch.save(model.state_dict(), os.path.join(args.checkpoint_path, 'final.pt'))

	return model


import time
import torch

def test_model(args, cfg, graphs, smiles, targets, model, transducer, prefix=None):
	"""Run inference on a dataset, compute metrics and persist the results.

	Args:
		args: Run options; reads ``device``, ``seed``, ``task`` and
			``wandb_log``, and is forwarded to ``save_results``.
		cfg: Config; reads ``cfg.exp.batch_size``.
		graphs: Forwarded to ``GraphDataset``.
		smiles: SMILES strings of the query molecules.
		targets: Ground-truth targets.
		model: Model providing ``extract_embedding`` and a forward pass taking
			query and candidate embeddings.
		transducer: Anchor selector providing ``choose_multiple_anchors``.
		prefix: Optional label forwarded to ``calculate_metrics`` and
			``save_results``.

	Returns:
		dict with keys ``'preds'`` (numpy array; sigmoid-activated for
		classification), ``'gt'`` (numpy array of ``targets``),
		``'query_smiles'`` (``smiles`` as passed in), ``'anchor_idxs'`` and
		``'anchor_smiles'``.
	"""
	model.eval()

	# Peak-memory bookkeeping only exists for CUDA; skip it on CPU-only runs
	# instead of crashing before any inference happens.
	device = args.device
	if torch.cuda.is_available() and (device is None or torch.device(device).type == 'cuda'):
		torch.cuda.reset_peak_memory_stats(device)
		torch.cuda.synchronize()

	graph_dataset = GraphDataset(graphs, smiles, targets)
	test_loader = DataLoader(
		graph_dataset,
		batch_size=cfg.exp.batch_size,
		shuffle=False,
		num_workers=0,
		worker_init_fn=lambda worker_id: np.random.seed(args.seed + worker_id),
		collate_fn=graph_collate_fn,
	)

	preds = {'preds': [], 'gt': targets, 'query_smiles': smiles, 'anchor_idxs': [], 'anchor_smiles': []}
	is_classification = args.task == 'classification'

	with torch.no_grad():
		for graph_batch, _, _ in test_loader:
			query_embeddings = model.extract_embedding(graph_batch)
			candidates_emb, anchor_weights, attn_mask, candidates_indices, candidates_smiles = transducer.choose_multiple_anchors(
				query_embeddings, return_smiles=True)
			preds_y = model(query_embeddings, candidates_emb, anchor_weights=anchor_weights, attention_mask=attn_mask)

			# Classification outputs are logits; report probabilities.
			preds['preds'].append(torch.sigmoid(preds_y) if is_classification else preds_y)

			# NOTE(review): indices are flattened across batches (extend) while
			# SMILES are kept as one entry per batch (append) - confirm this
			# asymmetry is intended by downstream consumers.
			preds['anchor_idxs'].extend(candidates_indices)
			preds['anchor_smiles'].append(candidates_smiles)

	preds['preds'] = torch.cat(preds['preds'], dim=0).detach().cpu().numpy()
	preds['gt'] = np.array(preds['gt'])

	results = calculate_metrics(preds['gt'], preds['preds'], prefix=prefix, task=args.task)

	if args.wandb_log:
		wandb.summary.update(results)

	save_results(args, results, prefix)

	return preds

