import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Read the training and validation arrays from disk.
# CosmicRayDataset indexes axis 1 of these arrays as
# 0 = image, 1 = ground truth, 2 = ignore mask.
train_data = np.load(file="./input/cosmic_train.npy")
val_data = np.load(file="./input/cosmic_val.npy")


class CosmicRayDataset(Dataset):
    """Per-sample view over a stacked array of image / label / mask planes.

    Samples are read as ``data[idx, c:c + 1, :, :]``: channel 0 is the image,
    channel 1 the ground truth and channel 2 the ignore mask.
    """

    def __init__(self, data):
        """Keep a reference to ``data``; nothing is copied or converted here."""
        self.data = data

    def __len__(self):
        """Return the number of samples (length of the first axis)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, ground_truth, ignore_mask)`` for sample ``idx``.

        Each element is a float32 tensor. Every plane is taken with a
        length-1 channel slice, so the channel axis is kept in the result.
        """
        planes = (
            self.data[idx, channel : channel + 1, :, :].astype(np.float32)
            for channel in range(3)
        )
        image, truth, ignore = (torch.tensor(plane) for plane in planes)
        return image, truth, ignore


# Wrap both splits in datasets and batch them 16 samples at a time.
# Only the training loader shuffles; the validation loader keeps the array
# order so its predictions can be compared against val_data directly.
train_dataset, val_dataset = CosmicRayDataset(train_data), CosmicRayDataset(val_data)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=16)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=16)


class SimpleCNN(nn.Module):
    """Three-layer fully convolutional network producing per-pixel probabilities.

    Every convolution uses a 3x3 kernel with padding 1, so the spatial size
    of the input is preserved; the channel widths go 1 -> 16 -> 32 -> 1.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, padding=1)

    def forward(self, x):
        """Map a ``[B, 1, H, W]`` batch to sigmoid outputs of shape ``[B, 1, H, W]``."""
        hidden = torch.relu(self.conv1(x))
        hidden = torch.relu(self.conv2(hidden))
        logits = self.conv3(hidden)
        # Squash to (0, 1) so the output can be used as a probability map.
        return torch.sigmoid(logits)


# Function to train a model
def train_model(model, num_epochs=5, lr=0.001, loader=None):
    """Train ``model`` in place with Adam on a masked binary cross-entropy loss.

    Args:
        model: Module mapping an image batch to per-pixel probabilities in
            [0, 1]; ``nn.BCELoss`` is applied to its output directly.
        num_epochs: Number of full passes over the loader. Defaults to 5,
            the value that used to be hard-coded.
        lr: Adam learning rate. Defaults to 0.001, the value that used to be
            hard-coded.
        loader: Iterable yielding ``(image, target, ignore_mask)`` batches.
            Defaults to the module-level ``train_loader``.

    Returns:
        None. The model's parameters are updated in place.
    """
    if loader is None:
        loader = train_loader

    # Keep the per-pixel loss so ignored pixels can be zeroed before reduction.
    criterion = nn.BCELoss(reduction="none")
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for _ in range(num_epochs):
        model.train()
        for x, y, m in tqdm(loader):
            optimizer.zero_grad()
            pred = model(x)
            # A mask value of 0 marks a pixel that counts towards the loss.
            valid = (m == 0).float()
            loss_per_pixel = criterion(pred, y)
            # Average over valid pixels only; the clamp avoids 0/0 when an
            # entire batch is masked out.
            loss = (loss_per_pixel * valid).sum() / valid.sum().clamp_min(1.0)
            loss.backward()
            optimizer.step()


# Train multiple models for ensembling.
# All members are constructed first and then trained one after another on the
# shared, shuffled train_loader. No random seed is set in this file, so any
# diversity between members comes from random weight initialisation and
# batch order.
num_models = 3
models = [SimpleCNN() for _ in range(num_models)]
for model in models:
    train_model(model)


# Evaluation
def evaluate(models):
    """Score an ensemble on the validation set and save its predictions.

    The outputs of all ``models`` are averaged for every batch of the
    module-level ``val_loader``. ROC AUC is then computed on the pixels whose
    ignore mask (channel 2 of ``val_data``) equals 0, and the flattened
    predictions for all pixels are written to ``./working/submission.csv``.

    Args:
        models: Sequence of trained modules; their outputs are stacked, so
            they must all produce the same output shape.

    Returns:
        The masked AUROC computed by ``roc_auc_score``, or ``None`` when the
        validation loader yielded no batches.
    """
    for model in models:
        model.eval()

    pred_probs = []
    with torch.no_grad():
        # Only the images are needed here; labels and masks are read from
        # ``val_data`` below.
        for x, _, _ in val_loader:
            preds = [model(x) for model in models]
            avg_pred = torch.mean(torch.stack(preds), dim=0)
            pred_probs.append(avg_pred.cpu().numpy())

    if not pred_probs:
        print("No predictions were collected.")
        return None

    # Relies on ``val_loader`` iterating ``val_data`` in order, so that the
    # flattened predictions line up with the flattened labels and mask.
    pred_probs = np.concatenate(pred_probs, axis=0).ravel()
    y_true = val_data[:, 1].ravel()
    m_mask = val_data[:, 2].ravel() == 0
    masked_auroc = roc_auc_score(y_true[m_mask], pred_probs[m_mask])

    # Save predictions; create the output directory first so a missing
    # folder does not discard the finished inference run.
    os.makedirs("./working", exist_ok=True)
    submission = pred_probs.astype(np.float32)
    np.savetxt("./working/submission.csv", submission, delimiter=",")

    print(f"Masked AUROC: {masked_auroc}")
    return masked_auroc


# Score the trained ensemble on the validation data; this also writes
# ./working/submission.csv and prints the masked AUROC.
evaluate(models)
