import os
import time
import time

import torch
import torch.nn as nn
import wandb
from torch_geometric.data import Dataset, DataLoader
from tqdm import tqdm

from model import *
from util import compute_leaderboard_score


def define_model(cfg, num_tasks):
    """Build a ``GNN_graphpred`` from the ``cfg.model`` config section.

    Args:
        cfg: config object; reads ``num_layer``, ``emb_dim``, ``JK``,
            ``dropout_ratio``, ``graph_pooling`` and ``gnn_type`` from
            ``cfg.model``.
        num_tasks: number of prediction targets passed to the model.

    Returns:
        The freshly constructed ``GNN_graphpred`` instance.
    """
    model_cfg = cfg.model
    return GNN_graphpred(
        num_layer=model_cfg.num_layer,
        emb_dim=model_cfg.emb_dim,
        num_tasks=num_tasks,
        JK=model_cfg.JK,
        # Note the config key is ``dropout_ratio`` while the model kwarg is ``drop_ratio``.
        drop_ratio=model_cfg.dropout_ratio,
        graph_pooling=model_cfg.graph_pooling,
        gnn_type=model_cfg.gnn_type,
    )

class GraphDataset(Dataset):
    """Minimal in-memory dataset over a pre-built sequence of graph objects.

    Indexing returns ``graph[idx]`` unchanged. Labels are not attached here:
    the training/evaluation code in this file reads them from each batch's
    ``y`` attribute, so ``target`` is only stored for reference.

    Args:
        graph: indexable sequence of graph objects.
        target: targets aligned with ``graph``; kept as ``self.target``.
    """

    def __init__(self, graph, target):
        # Initialise the torch_geometric ``Dataset`` base so the attributes it
        # defines exist on this instance. With no root and no download()/process()
        # overrides it presumably performs no I/O -- confirm against the
        # installed torch_geometric version.
        super().__init__()
        self.graph = graph
        # Previously discarded; kept so callers can inspect the targets.
        self.target = target

    def __len__(self):
        return len(self.graph)

    def __getitem__(self, idx):
        return self.graph[idx]
	
	
def _masked_mse_loss(pred, y):
    """Return the mean squared error over entries of ``y`` whose square is > 0.

    Entries with ``y == 0`` are excluded (presumably 0 encodes a missing label --
    TODO confirm this is intended for these regression targets). NaN targets are
    excluded too, because ``nan ** 2 > 0`` is False. Returns ``None`` when no
    entry is valid, so the caller can skip the batch instead of
    back-propagating a 0/0 = NaN loss.
    """
    is_valid = y ** 2 > 0
    num_valid = int(is_valid.sum())
    if num_valid == 0:
        return None
    # Select valid entries *before* computing the loss: masking the loss
    # afterwards with torch.where still back-propagates NaN gradients from the
    # masked-out entries when a target is NaN.
    sq_err_sum = nn.functional.mse_loss(
        pred.double()[is_valid], y[is_valid], reduction="sum"
    )
    return sq_err_sum / num_valid


def train_model(args, cfg, train_dataset, model, optimizer):
    """Train ``model`` with a masked MSE loss and save the final weights.

    Args:
        args: run options; reads ``device``, ``wandb_log`` and
            ``checkpoint_path``.
        cfg: config; reads ``cfg.exp.batch_size`` and ``cfg.exp.num_epochs``.
        train_dataset: mapping with ``'graphs'`` and ``'targets'`` entries.
            Labels are taken from each batch's ``y`` attribute.
        model: called as ``model(x, edge_index, edge_attr, batch)``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        The same ``model`` object, trained in place. Its state dict is also
        written to ``<args.checkpoint_path>/final.pt`` (the directory is
        created if missing).
    """
    train_dataset = GraphDataset(train_dataset['graphs'], train_dataset['targets'])
    train_loader = DataLoader(train_dataset, batch_size=cfg.exp.batch_size, shuffle=True)

    model.train()

    pbar = tqdm(range(cfg.exp.num_epochs), desc="Training...")
    for _ in pbar:
        running_loss = 0.0
        num_batches = 0
        for batch in train_loader:
            batch = batch.to(args.device)
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            # as_tensor avoids the copy (and UserWarning) torch.tensor makes
            # when batch.y is already a tensor.
            y = torch.as_tensor(batch.y, dtype=torch.float64, device=args.device).view(pred.shape)
            loss = _masked_mse_loss(pred, y)
            if loss is None:
                # No labelled entries in this batch: nothing to learn from.
                continue
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Guard against an empty loader / an epoch with no labelled batches.
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        if args.wandb_log:
            wandb.log({'train_loss': epoch_loss})
        pbar.set_postfix(loss=f'{epoch_loss:.6f}')

    os.makedirs(args.checkpoint_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.checkpoint_path, 'final.pt'))
    return model



def test_model(cfg, model, device, test_dataset):
    """Evaluation function for the GNN model.

    Runs ``model`` over ``test_dataset`` in loader order (no shuffling) and
    collects targets and predictions.

    Args:
        cfg: config; reads ``cfg.exp.batch_size``.
        model: called as ``model(x, edge_index, edge_attr, batch)``; it is put
            in eval mode and left there.
        device: device each batch is moved to before the forward pass.
        test_dataset: mapping with ``'graphs'`` and ``'targets'`` entries;
            labels are read from each batch's ``y`` attribute.

    Returns:
        ``(y_true, y_preds)`` as CPU NumPy arrays. Each batch's targets are
        reshaped to that batch's prediction shape before concatenation.

    Raises:
        ``torch.cat`` raises if the loader yields no batches.
    """
    test_dataset = GraphDataset(test_dataset['graphs'], test_dataset['targets'])
    test_loader = DataLoader(test_dataset, batch_size=cfg.exp.batch_size, shuffle=False)
    model.eval()
    y_true = []
    y_scores = []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

            # Ensure the shapes of true and predicted values match. as_tensor
            # avoids torch.tensor's redundant copy/warning on an existing tensor.
            y_true.append(torch.as_tensor(batch.y).view(pred.shape).cpu())
            # Move to CPU per batch so predictions don't pile up on the device.
            y_scores.append(pred.cpu())

    # Concatenate all batch outputs into single arrays
    y_true = torch.cat(y_true, dim=0).numpy()
    y_preds = torch.cat(y_scores, dim=0).numpy()

    return y_true, y_preds


