import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold

# Load data
def _load_npz_array(path):
    """Return the array stored under key "arr_0" in the .npz archive at *path*.

    ``np.load`` on an .npz file returns an ``NpzFile`` that keeps the file
    open; indexing it reads the array into memory, so the archive can be
    closed immediately afterwards via the context manager.
    """
    with np.load(path) as archive:
        return archive["arr_0"]


x_train = _load_npz_array("./input/x_train.npz")
y_train = _load_npz_array("./input/y_train.npz")
x_val = _load_npz_array("./input/x_val.npz")
y_val = _load_npz_array("./input/y_val.npz")


# Create a simple feedforward neural network with Batch Normalization
class ChromatinPredictor(nn.Module):
    """Feed-forward multi-label classifier with batch normalisation.

    Each sample is flattened to ``seq_len * n_channels`` features and passed
    through two Linear -> BatchNorm1d -> ReLU stages (512 and 256 units),
    then a Linear layer followed by a sigmoid. The output is one independent
    probability in [0, 1] per label.

    NOTE(review): the defaults (1000 x 4 input, 36 outputs) look like a
    one-hot encoded DNA sequence with 36 chromatin targets; confirm against
    the data.
    """

    def __init__(
        self, seq_len: int = 1000, n_channels: int = 4, n_outputs: int = 36
    ) -> None:
        """Build the network.

        Args:
            seq_len: First factor of the per-sample feature count.
            n_channels: Second factor of the per-sample feature count. Only
                the product ``seq_len * n_channels`` matters, because
                ``forward`` flattens every sample.
            n_outputs: Number of independent output labels.
        """
        super().__init__()
        self.fc1 = nn.Linear(seq_len * n_channels, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, n_outputs)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of shape (batch, ...) to probabilities (batch, n_outputs).

        Each sample must contain exactly ``seq_len * n_channels`` elements.
        """
        x = x.view(x.size(0), -1)  # Flatten everything except the batch axis.
        x = self.relu(self.bn1(self.fc1(x)))
        x = self.relu(self.bn2(self.fc2(x)))
        # Sigmoid, not softmax: the labels are scored independently.
        return self.sigmoid(self.fc3(x))


# Hyper-parameters for cross-validation and training.
N_SPLITS = 5
EPOCHS = 200  # Full-batch Adam steps per fit.
LEARNING_RATE = 0.001
SEED = 42

# Seed weight initialisation so repeated runs give the same CV scores.
torch.manual_seed(SEED)

# Prepare the data for PyTorch
x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
x_val_tensor = torch.tensor(x_val, dtype=torch.float32)

criterion = nn.BCELoss()


def _train_model(x, y, epochs=EPOCHS, lr=LEARNING_RATE):
    """Fit a freshly initialised ChromatinPredictor on (x, y) with full-batch Adam."""
    model = ChromatinPredictor()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    return model


def _predict(model, x):
    """Return the model's probabilities for x as a NumPy array (eval mode, no grad)."""
    model.eval()
    with torch.no_grad():
        return model(x).numpy()


def _macro_auc(y_true, y_score):
    """Unweighted mean ROC AUC over label columns that contain both classes.

    ``roc_auc_score(average="macro")`` raises when any column has a single
    class, which can happen in small folds. Such columns are skipped here.
    Returns NaN when no column is scorable.
    """
    y_true = np.asarray(y_true).reshape(len(y_true), -1)
    y_score = np.asarray(y_score).reshape(len(y_score), -1)
    per_label = [
        roc_auc_score(y_true[:, j], y_score[:, j])
        for j in range(y_true.shape[1])
        if np.unique(y_true[:, j]).size > 1
    ]
    return float(np.mean(per_label)) if per_label else float("nan")


# K-fold cross-validation. Each fold builds a fresh model and optimizer.
# Reusing one model across folds would mean it had already trained on a
# fold's validation rows during earlier folds, which inflates the CV score.
# Shuffle in case the rows are stored in some sorted order.
kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
auc_scores = []

for train_index, val_index in kf.split(x_train_tensor):
    fold_model = _train_model(
        x_train_tensor[train_index], y_train_tensor[train_index]
    )
    fold_outputs = _predict(fold_model, x_train_tensor[val_index])
    auc_scores.append(_macro_auc(y_train_tensor[val_index].numpy(), fold_outputs))

# Print the average validation AUC (folds with no scorable label are ignored)
print("Average Validation AUC:", np.nanmean(auc_scores))

# Final model: train on the whole training set, then score the held-out set.
model = _train_model(x_train_tensor, y_train_tensor)
val_outputs = _predict(model, x_val_tensor)
print("Hold-out Validation AUC:", _macro_auc(y_val, val_outputs))

# Save predictions for submission
submission = pd.DataFrame(
    val_outputs, columns=[f"chromatin_{i}" for i in range(val_outputs.shape[1])]
)
os.makedirs("./working", exist_ok=True)
submission.to_csv("./working/submission.csv", index=False)
