import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler


class NNModel(nn.Module):
    """Single-hidden-layer feed-forward classifier.

    Architecture: ``Linear(input_size, 64) -> ReLU -> Dropout(0.5) ->
    Linear(64, num_classes)``. The output is raw, un-normalised class scores
    (no softmax is applied here).
    """

    def __init__(self, input_size, num_classes):
        """Create the layers.

        Args:
            input_size: Number of input features per sample.
            num_classes: Number of output scores per sample.
        """
        super().__init__()
        hidden_units = 64
        self.fc1 = nn.Linear(input_size, hidden_units)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        """Return the class scores for the batch ``x``."""
        hidden = self.dropout(torch.relu(self.fc1(x)))
        return self.fc2(hidden)


class CustomDataset(Dataset):
    """Map-style dataset pairing each feature row with its label.

    ``features`` and ``labels`` are stored as given and indexed in parallel;
    the dataset length is taken from ``features``.
    """

    def __init__(self, features, labels):
        """Keep references to the parallel ``features`` / ``labels`` containers."""
        self.features, self.labels = features, labels

    def __len__(self):
        """Return the number of entries in ``features``."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        sample = self.features[idx]
        label = self.labels[idx]
        return sample, label


class NN:
    """Wrapper around ``NNModel`` for classification on a pandas DataFrame.

    Covers column-type detection, train/validation/test splitting, feature
    preprocessing (label-encoding of categorical columns, standard scaling of
    float columns), training with early stopping, inference and checkpointing.

    NOTE(review): checkpoints written by ``save_model`` contain only the
    network weights, ``input_size``, ``num_classes`` and ``dtypes``. The fitted
    feature encoders, scaler and target encoder are NOT persisted, so on a
    fresh instance restored with ``load_model`` they are fitted lazily on the
    first data that is passed in.
    """

    def __init__(self, target):
        """Create an untrained wrapper.

        Args:
            target: Name of the label column in the DataFrames passed to the
                other methods.
        """
        self.target = target
        self.dtypes = None
        self.num_classes = None
        self.model = None
        self.input_size = None
        self.scaler = StandardScaler()
        self.label_encoders = {}
        # Shared label -> class-index mapping, fitted in load_data() or,
        # failing that, in train_model().
        self.target_encoder = None

    def _detect_dtypes(self, data):
        """Classify every column and return a normalised copy of ``data``.

        Integer columns are cast to ``int64`` (and tagged ``"int8"``; the tag
        is only a label and is kept because it is stored in checkpoints),
        float columns are cast to ``float64`` (tag ``"float"``), everything
        else is cast to pandas ``category`` (tag ``"category"``).

        Args:
            data: Input DataFrame. It is not modified.

        Returns:
            ``(dtypes, converted)`` where ``dtypes`` is a list of
            ``(column, tag)`` pairs for every column except the target and
            ``converted`` is a copy of ``data`` with the casts applied.
        """
        # Work on a copy so the caller's DataFrame keeps its original dtypes.
        data = data.copy()
        dtypes = []

        for col in data.columns:
            if pd.api.types.is_integer_dtype(data[col]):
                tag, cast_to = "int8", "int64"
            elif pd.api.types.is_float_dtype(data[col]):
                tag, cast_to = "float", "float64"
            else:
                tag, cast_to = "category", "category"

            if col != self.target:
                dtypes.append((col, tag))
            data[col] = data[col].astype(cast_to)

        return dtypes, data

    def load_data(self, data):
        """Detect column types and split ``data`` 70/15/15.

        Also records the number of classes and fits the target encoder on the
        full label column, so that all splits share one label -> index mapping.

        Args:
            data: DataFrame containing the feature columns and the target
                column. It is not modified.

        Returns:
            ``(train_data, val_data, test_data)`` DataFrames.

        Raises:
            ValueError: If the target column is missing from ``data``.
        """
        if self.target not in data.columns:
            raise ValueError(f"Target column '{self.target}' not found in the dataset.")

        self.dtypes, data = self._detect_dtypes(data)
        print(f"Detected feature types: {self.dtypes}")
        self.num_classes = data[self.target].nunique()
        print(f"Detected {self.num_classes} classes.")

        # One encoder for every split: encoding each split independently would
        # shift the class indices whenever a split is missing a class.
        self.target_encoder = LabelEncoder().fit(data[self.target])

        train_data, temp_data = train_test_split(data, test_size=0.3, random_state=42)
        val_data, test_data = train_test_split(
            temp_data, test_size=0.5, random_state=42
        )
        return train_data, val_data, test_data

    def _preprocess_features(self, df, fit=False):
        """Label-encode categorical features and scale float features.

        Args:
            df: Feature DataFrame (without the target column). Not modified.
            fit: When true, (re)fit every label encoder and the scaler on
                ``df`` before transforming; used for the training split. When
                false, already-fitted encoders/scaler are only applied, and
                any component that has not been fitted yet is fitted on ``df``.

        Returns:
            A transformed copy of ``df``.
        """
        processed_data = df.copy()

        for col, dtype in self.dtypes:
            if dtype != "category":
                continue
            if fit or col not in self.label_encoders:
                encoder = LabelEncoder()
                processed_data[col] = encoder.fit_transform(processed_data[col])
                # Register only after a successful fit so a failure cannot
                # leave an unfitted encoder behind.
                self.label_encoders[col] = encoder
            else:
                print(f"Using existing label encoder for {col}")
                processed_data[col] = self.label_encoders[col].transform(
                    processed_data[col]
                )

        continuous_cols = [col for col, dtype in self.dtypes if dtype == "float"]
        if continuous_cols:
            # ``scale_`` only exists once the scaler has been fitted.
            if fit or not hasattr(self.scaler, "scale_"):
                processed_data[continuous_cols] = self.scaler.fit_transform(
                    processed_data[continuous_cols]
                )
            else:
                # Reuse the training statistics; refitting here would scale
                # validation/test data with their own mean and variance.
                processed_data[continuous_cols] = self.scaler.transform(
                    processed_data[continuous_cols]
                )

        return processed_data

    def _build_model(self):
        """Instantiate an ``NNModel`` for the recorded input size and class count."""
        return NNModel(self.input_size, self.num_classes)

    def _encode_split(self, data, fit=False):
        """Turn a labelled DataFrame into ``(features, labels)`` tensors."""
        features = self._preprocess_features(
            data.drop(columns=[self.target]), fit=fit
        )
        labels = self.target_encoder.transform(data[self.target])
        return (
            torch.tensor(features.values, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.long),
        )

    def _forward_logits(self, data):
        """Run the model in eval mode on ``data`` and return the raw scores."""
        if self.model is None:
            raise ValueError("Model has not been trained yet.")

        # The target column is optional at inference time.
        features = data.drop(columns=[self.target], errors="ignore")
        features = self._preprocess_features(features)
        inputs = torch.tensor(features.values, dtype=torch.float32)

        self.model.eval()
        with torch.no_grad():
            return self.model(inputs)

    def train_model(
        self,
        train_data,
        val_data,
        model_path,
        epochs=100,
        batch_size=32,
        lr=0.0001,
        patience=10,
    ):
        """Train a fresh model with Adam, gradient clipping and early stopping.

        The feature encoders and scaler are fitted on ``train_data`` only and
        then applied to ``val_data``. Whenever the validation loss improves,
        a checkpoint is written to ``model_path``; when training ends, the
        in-memory model is reset to those best weights.

        Args:
            train_data: Training DataFrame including the target column.
            val_data: Validation DataFrame including the target column.
            model_path: Destination passed to ``save_model``.
            epochs: Maximum number of epochs.
            batch_size: Mini-batch size for both loaders.
            lr: Adam learning rate.
            patience: Number of consecutive non-improving epochs after which
                training stops.

        Raises:
            ValueError: If ``load_data``/``load_model`` has not populated the
                class count and column types, if either split is empty, or
                (from scikit-learn) if a split contains a label or category
                the fitted encoders have not seen.
        """
        if self.num_classes is None or self.dtypes is None:
            raise ValueError(
                "Number of classes or data types have not been detected. Ensure load_data() has been called."
            )
        if len(train_data) == 0 or len(val_data) == 0:
            raise ValueError("train_data and val_data must both be non-empty.")

        if self.target_encoder is None:
            # No encoder from load_data() (e.g. state came from load_model()):
            # build the mapping from every label available here.
            self.target_encoder = LabelEncoder().fit(
                pd.concat([train_data[self.target], val_data[self.target]])
            )

        X_train_tensor, y_train_tensor = self._encode_split(train_data, fit=True)
        X_val_tensor, y_val_tensor = self._encode_split(val_data)

        self.input_size = X_train_tensor.shape[1]
        self.model = self._build_model()

        train_dataset = CustomDataset(X_train_tensor, y_train_tensor)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_dataset = CustomDataset(X_val_tensor, y_val_tensor)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        best_val_loss = float("inf")
        best_state = None
        patience_counter = 0

        for epoch in range(epochs):
            self.model.train()
            train_loss = 0.0
            for inputs, targets in train_loader:
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                optimizer.step()
                train_loss += loss.item()

            self.model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for inputs, targets in val_loader:
                    outputs = self.model(inputs)
                    loss = criterion(outputs, targets)
                    val_loss += loss.item()

            val_loss /= len(val_loader)
            print(
                f"Epoch {epoch+1}, Train Loss: {train_loss / len(train_loader):.4f}, Val Loss: {val_loss:.4f}"
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Snapshot the weights: later epochs keep updating the live
                # parameters in place.
                best_state = {
                    name: tensor.detach().clone()
                    for name, tensor in self.model.state_dict().items()
                }
                self.save_model(model_path)
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print("Early stopping")
                    break

        if best_state is not None:
            # Keep the in-memory model identical to the checkpoint on disk
            # rather than to the last (possibly worse) epoch.
            self.model.load_state_dict(best_state)

    def predict(self, test_data):
        """Return the predicted class index for every row of ``test_data``.

        Args:
            test_data: DataFrame with the feature columns; the target column
                is dropped if present.

        Returns:
            1-D NumPy array of class indices.

        Raises:
            ValueError: If no model has been trained or loaded.
        """
        logits = self._forward_logits(test_data)
        return logits.argmax(dim=1).numpy()

    def evaluate_model(self, test_data):
        """Return the accuracy of the model on a labelled DataFrame.

        Labels are encoded with the shared target encoder when one is
        available; otherwise a new encoder is fitted on the labels in
        ``test_data``.

        Raises:
            ValueError: If no model is available, or (from scikit-learn) if
                ``test_data`` contains a label the target encoder has not seen.
        """
        y_true = test_data[self.target].values
        if self.target_encoder is not None:
            y_test = self.target_encoder.transform(y_true)
        else:
            y_test = LabelEncoder().fit_transform(y_true)
        print(f"Ground truth: {y_test}")
        y_pred = self.predict(test_data)
        print(f"Predictions: {y_pred}")
        accuracy = (y_pred == y_test).mean()
        return accuracy

    def predict_proba(self, test_data):
        """Return class probabilities for every row of ``test_data``.

        Args:
            test_data: DataFrame with the feature columns; the target column
                is dropped if present.

        Returns:
            2-D NumPy array of shape ``(rows, num_classes)``; each row is the
            softmax of the model's scores and therefore sums to 1.

        Raises:
            ValueError: If no model has been trained or loaded.
        """
        logits = self._forward_logits(test_data)
        return torch.softmax(logits, dim=1).numpy()

    def save_model(self, file_path):
        """Write the model weights and the metadata needed to rebuild it.

        Raises:
            ValueError: If there is no model to save.
        """
        if self.model is None:
            raise ValueError("No model available to save. Train the model first.")

        checkpoint = {
            "model_state_dict": self.model.state_dict(),
            "input_size": self.input_size,
            "num_classes": self.num_classes,
            "dtypes": self.dtypes,
        }
        torch.save(checkpoint, file_path)
        print(f"Model and input size saved to {file_path}")

    def load_model(self, file_path):
        """Rebuild the model from a checkpoint written by ``save_model``.

        Raises:
            ValueError: If the checkpoint lacks ``input_size``,
                ``num_classes`` or ``dtypes``.
        """
        checkpoint = torch.load(file_path)

        required_keys = ("input_size", "num_classes", "dtypes")
        if any(key not in checkpoint for key in required_keys):
            raise ValueError(
                "Checkpoint does not contain 'input_size' or 'num_classes' or 'dtypes'."
            )

        self.input_size = checkpoint["input_size"]
        self.num_classes = checkpoint["num_classes"]
        self.dtypes = checkpoint["dtypes"]

        self.model = self._build_model()

        self.model.load_state_dict(checkpoint["model_state_dict"])
        print(
            f"Model loaded from {file_path} with input size {self.input_size} and {self.num_classes} classes."
        )
        print(f"Detected feature types: {self.dtypes}")
