import anndata
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
import scanpy as sc
import scprep
import sklearn.metrics

# Load the processed AnnData file and densify its train/test matrices
# (both obsm entries are converted with .toarray(), train first, then test)
adata = anndata.read_h5ad("./input/1k_pbmc_processed.h5ad")
train_data, test_data = (
    adata.obsm[split].toarray() for split in ("train", "test")
)


# Define Autoencoder with Dropout
class Autoencoder(nn.Module):
    """Symmetric fully connected autoencoder regularised with dropout.

    The encoder maps ``input_dim -> hidden_dim -> latent_dim`` and the decoder
    mirrors it back to ``input_dim``. Each hidden layer is followed by ReLU
    and dropout; the output layer ends in a sigmoid, so every reconstructed
    value lies between 0 and 1.

    NOTE(review): because of the final sigmoid, targets outside [0, 1] cannot
    be reproduced exactly -- confirm the training matrix is scaled to match.

    Args:
        input_dim: Number of features per sample.
        hidden_dim: Width of the outer hidden layers.
        latent_dim: Width of the bottleneck layer.
        dropout: Dropout probability applied after each hidden ReLU.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 128,
        latent_dim: int = 64,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        # Layer positions inside each Sequential are kept stable so that
        # state_dict keys (encoder.0, encoder.3, decoder.0, decoder.3) do not
        # change for checkpoints written with the default configuration.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return its reconstruction (same trailing size)."""
        return self.decoder(self.encoder(x))


# Prepare data: the dense matrices become float32 tensors for PyTorch
X_train = torch.tensor(train_data, dtype=torch.float32)
X_test = torch.tensor(test_data, dtype=torch.float32)

# Train the autoencoder
model = Autoencoder(input_dim=X_train.shape[1])
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler: lowers the LR when the training loss plateaus
# (patience of 5 epochs). The scheduler's `verbose` flag is deprecated in
# recent PyTorch releases, so it is not passed here; learning-rate reductions
# are reported explicitly inside the training loop instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5)

num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    # Full-batch step: the whole training matrix goes through in one pass
    reconstructed = model(X_train)
    loss = criterion(reconstructed, X_train)
    loss.backward()
    optimizer.step()

    # Step the scheduler on a plain Python float and report any LR reduction
    previous_lr = optimizer.param_groups[0]["lr"]
    scheduler.step(loss.item())
    current_lr = optimizer.param_groups[0]["lr"]
    if current_lr < previous_lr:
        print(f"Epoch {epoch + 1:05d}: reducing learning rate to {current_lr:.4e}.")

# Denoise the test data (eval mode disables dropout; no gradients are needed)
model.eval()
with torch.no_grad():
    denoised = model(X_test).numpy()

# Store the denoised data
adata.obsm["denoised"] = denoised

# Save predictions for evaluation
submission = pd.DataFrame(denoised)
submission.to_csv("./working/submission.csv", index=False)


# Evaluate the performance
def mse(adata):
    """Return the mean squared error between held-out and denoised data.

    ``adata.obsm["test"]`` and ``adata.obsm["denoised"]`` are each wrapped in
    a temporary AnnData object, normalised to a total of 10,000 per
    observation, log1p-transformed, and then compared.

    Args:
        adata: AnnData object providing "test" and "denoised" entries in
            ``obsm`` together with the ``obs`` and ``var`` annotations.

    Returns:
        The value computed by ``sklearn.metrics.mean_squared_error``.
    """
    # Wrap copies of the stored matrices: the scanpy preprocessing calls below
    # operate in place, and without the copies they could rewrite the arrays
    # held in adata.obsm, so a second call would score transformed data.
    test_data = anndata.AnnData(
        X=adata.obsm["test"].copy(), obs=adata.obs, var=adata.var
    )
    denoised_data = anndata.AnnData(
        X=adata.obsm["denoised"].copy(), obs=adata.obs, var=adata.var
    )

    # Identical preprocessing for both matrices so they are on the same scale
    target_sum = 10000
    for data in (test_data, denoised_data):
        sc.pp.normalize_total(data, target_sum=target_sum)
        sc.pp.log1p(data)

    error = sklearn.metrics.mean_squared_error(
        scprep.utils.toarray(test_data.X), denoised_data.X
    )
    return error


# Print the evaluation metric: the MSE that mse() computes between
# adata.obsm["test"] and adata.obsm["denoised"]
error_metric = mse(adata)
print(f"Mean Squared Error: {error_metric}")
