import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import mean_squared_error
import numpy as  np
import os
import wandb
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm


# Define the MLP model
class MLP(nn.Module):
    """Fully connected network with ReLU activations and a single scalar output.

    The network is ``n_layers`` repetitions of ``Linear -> ReLU`` followed by a
    final ``Linear`` layer that maps to one output feature.

    Args:
        input_size: Number of features in the last dimension of the input.
        n_layers: Number of hidden ``Linear -> ReLU`` pairs. With ``0`` the
            model reduces to a single linear map from ``input_size`` to 1.
        hidden_size: Width of every hidden layer.
    """

    def __init__(self, input_size: int, n_layers: int, hidden_size: int):
        super().__init__()
        layers = []
        in_features = input_size
        for _ in range(n_layers):
            layers.append(nn.Linear(in_features, hidden_size))
            layers.append(nn.ReLU())
            in_features = hidden_size
        # Use the running width, not hidden_size: when n_layers == 0 there is
        # no hidden layer and the output layer must consume input_size features.
        layers.append(nn.Linear(in_features, 1))  # Output layer
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and drop the trailing singleton output dimension.

        Only the last dimension is squeezed, so an input of shape
        ``(batch, input_size)`` always yields shape ``(batch,)`` -- including
        ``batch == 1``, where squeezing every dimension would produce a 0-dim
        tensor that cannot be concatenated with other batches.
        """
        return self.network(x).squeeze(-1)


def train_mlp(train_dataloader, test_dataloader, input_size, wandb_run=None, n_layers=3,
              hidden_size=64, n_epochs=20, learning_rate=0.001, device=None):
    """Train an ``MLP`` with AdamW on an MSE loss and report metrics every epoch.

    After each epoch the sample-weighted mean training loss and the MSE on
    every test set are collected into one dict, which is sent to
    ``wandb_run.log`` when a run is supplied and printed otherwise.

    Args:
        train_dataloader: Iterable yielding ``(X_batch, y_batch)`` tensor pairs.
            NOTE(review): assumes ``y_batch`` has the same shape as the model
            output, i.e. ``(batch,)``; a ``(batch, 1)`` target would silently
            broadcast inside ``MSELoss`` -- confirm against the caller.
        test_dataloader: Iterable of indexable items where ``item[0]`` is an
            iterable of ``(X_batch, y_batch)`` pairs and ``item[1]`` is a label
            interpolated into the metric key (presumably an identifier of the
            test split -- confirm with the caller).
        input_size: Number of input features passed to ``MLP``.
        wandb_run: Optional object exposing ``log(dict)``; when falsy the
            metrics are printed to stdout instead.
        n_layers: Number of hidden layers of the ``MLP``.
        hidden_size: Width of each hidden layer.
        n_epochs: Number of passes over ``train_dataloader``.
        learning_rate: AdamW learning rate (weight decay is fixed at ``1e-4``).
        device: Device to train on (string or ``torch.device``). Defaults to
            CUDA when available and CPU otherwise.

    Returns:
        The trained ``MLP`` instance, left in eval mode on ``device``.
    """
    # Fall back to CPU instead of crashing on machines without a GPU.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # Initialize the model, loss function, and optimizer
    model = MLP(input_size, n_layers, hidden_size).to(device)
    print("# params: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)

    # Training loop
    for epoch in tqdm(range(n_epochs), desc="Epochs", dynamic_ncols=True):
        model.train()
        running_loss = 0.0
        n_data = 0
        for X_batch, y_batch in tqdm(train_dataloader, desc="Batches", leave=False,
                                     dynamic_ncols=True):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            batch_size = X_batch.shape[0]
            running_loss += loss.item() * batch_size
            n_data += batch_size

        # An empty training loader yields NaN rather than a ZeroDivisionError.
        train_mse = running_loss / n_data if n_data else float("nan")

        # Calculate MSE on each test set
        model.eval()
        val_dict = {}
        with torch.no_grad():
            for test_entry in test_dataloader:
                loader, cl = test_entry[0], test_entry[1]
                test_targets = []
                test_preds = []
                for X_batch, y_batch in loader:
                    y_pred = model(X_batch.to(device))
                    # Flatten so a 0-dim prediction (single-sample batch with a
                    # fully squeezed output) can still be concatenated.
                    test_preds.append(y_pred.cpu().reshape(-1))
                    test_targets.append(y_batch)
                if test_preds:
                    test_mse = mean_squared_error(
                        torch.cat(test_targets).numpy(), torch.cat(test_preds).numpy()
                    )
                else:
                    # No batches in this test set: nothing to score.
                    test_mse = float("nan")
                val_dict[f'Test mse cl ={cl}'] = test_mse

        val_dict["epoch"] = epoch + 1
        val_dict["train_loss"] = train_mse
        if wandb_run:
            wandb_run.log(val_dict)
        else:
            for key, value in val_dict.items():
                print(f"{key}: {value}")

    return model








