
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
from l2_weights import l2_weights
from metrics import evaluate


def activation_fn(x, mode):
    """Apply the activation selected by the integer code ``mode`` to ``x``.

    Codes: ``1`` -> sigmoid, ``2`` -> ReLU, ``3`` -> SELU.

    Args:
        x: Tensor passed to the selected ``torch`` activation.
        mode: Activation code; compared with ``==`` against 1, 2, 3 in order.

    Returns:
        The result of the selected activation applied to ``x``.

    Raises:
        ValueError: If ``mode`` matches none of the known codes.
    """
    known_activations = (
        (1, torch.sigmoid),
        (2, torch.relu),
        (3, torch.selu),
    )
    for code, fn in known_activations:
        if mode == code:
            return fn(x)
    raise ValueError("Invalid activation function.")


class WarmupLayer(nn.Module):
    """A single fully connected layer followed by a selectable activation.

    The affine map is stored as ``self.linear`` and the activation code
    (interpreted by ``activation_fn``) as ``self.activation_mode``.
    """

    def __init__(self, input_dim, output_dim, activation):
        """Build the layer.

        Args:
            input_dim: Number of input features of the linear map.
            output_dim: Number of output features of the linear map.
            activation: Activation code forwarded to ``activation_fn``.
        """
        super().__init__()
        self.activation_mode = activation
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``activation(linear(x))`` for the configured activation."""
        pre_activation = self.linear(x)
        return activation_fn(pre_activation, self.activation_mode)


def _build_warmup_layers(n_features, width, depth, activation, scale, lr):
    """Create ``depth`` scaled ``WarmupLayer`` objects, each paired with its own Adam optimizer."""
    layers, optimizers = [], []
    for idx in range(depth):
        # Only the first layer sees the raw features; later ones see `width` units.
        in_dim = n_features if idx == 0 else width
        layer = WarmupLayer(in_dim, width, activation)
        with torch.no_grad():
            layer.linear.weight.data *= scale
            layer.linear.bias.data *= scale
        layers.append(layer)
        optimizers.append(optim.Adam(layer.parameters(), lr=lr))
    return layers, optimizers


def _run_warmup(layers, optimizers, inputs, epochs):
    """Take ``epochs`` joint Adam steps minimising the mean squared output of the last layer."""
    if not layers:
        # Without layers the loss is built from the input alone and has no
        # autograd graph, so backward() would raise; there is nothing to train.
        return
    for _ in range(epochs):
        act = inputs
        for layer in layers:
            act = layer(act)
        loss = torch.mean(act ** 2)  # only the final layer's activation is used

        for opt in optimizers:
            opt.zero_grad()
        loss.backward()
        for opt in optimizers:
            opt.step()


def _stacked_features(layers, inputs, normalize, stats=None):
    """Concatenate the input, every layer's (optionally standardised) output and a ones column.

    With ``stats=None`` the per-layer mean/std are computed from ``inputs`` and
    returned; otherwise the supplied ``stats`` are reused (test-time path).
    """
    fit_stats = stats is None
    if fit_stats:
        stats = []
    pieces = [inputs]
    act = inputs
    for idx, layer in enumerate(layers):
        act = layer(act)
        if normalize:
            if fit_stats:
                # The epsilon keeps the division finite for constant columns.
                stats.append((act.mean(dim=0), act.std(dim=0) + 1e-8))
            mean, std = stats[idx]
            act = (act - mean) / std
        pieces.append(act)
    pieces.append(torch.ones((inputs.shape[0], 1)))  # bias column
    return torch.cat(pieces, dim=1), stats


def _predict_labels(features, beta):
    """Apply the linear readout ``beta`` and return the arg-max class index per row (NumPy)."""
    scores = features @ beta
    probs = torch.softmax(scores, dim=1)
    return torch.argmax(probs, dim=1).numpy()


def train_and_evaluate_FLAiR_drvfl(trainX, trainY, testX, testY, option, warmup_epochs=5, lr=0.001):
    """Warm up a stack of hidden layers, fit a linear readout and evaluate it.

    Steps:
      1. Seed torch and NumPy with 0 (this resets the *global* RNG state).
      2. Build ``option['L']`` ``WarmupLayer`` objects of width ``option['N']``
         and multiply their initial weights and biases by ``option['scale']``.
      3. Run ``warmup_epochs`` Adam steps that minimise the mean squared
         activation of the last layer, then freeze all parameters.
      4. Concatenate the input, all hidden outputs and a ones column, and
         obtain the readout weights from ``l2_weights(features, trainY, C)``
         (presumably an L2-regularised least-squares solve - confirm against
         ``l2_weights``).
      5. Predict train and test labels and score them with ``evaluate``.

    Args:
        trainX: 2-D array-like of training features (its ``.shape`` is unpacked
            into samples x features).
        trainY: Training targets; labels are recovered with ``argmax`` along
            axis 1 (e.g. one-hot rows).
        testX: Test features, same layout as ``trainX``.
        testY: Test targets, same layout as ``trainY``.
        option: Mapping with keys ``'N'``, ``'L'``, ``'C'``, ``'activation'``,
            ``'scale'``, ``'renormal'`` and ``'normal_type'``. Hidden outputs
            are standardised only when ``renormal`` is truthy and
            ``normal_type == 0``; statistics come from the training data and
            are reused at test time.
        warmup_epochs: Number of warm-up optimisation steps.
        lr: Learning rate of the per-layer Adam optimizers.

    Returns:
        Tuple ``(None, EVAL_Train, EVAL_Test, training_time, testing_time)``.
        The first slot is always ``None``; the evaluations are whatever
        ``evaluate`` returns; the times are elapsed seconds.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    # perf_counter is monotonic, so the measured durations cannot be skewed
    # by wall-clock adjustments.
    start_train = time.perf_counter()
    N, L, C = option['N'], option['L'], option['C']
    activation = option['activation']
    scale = option['scale']
    renormal = option['renormal']
    norm_type = option['normal_type']
    normalize = bool(renormal) and norm_type == 0

    _, n_features = trainX.shape
    trainX_torch = torch.tensor(trainX, dtype=torch.float32)

    layers, optimizers = _build_warmup_layers(n_features, N, L, activation, scale, lr)

    # Warm-up driven by the final hidden layer's activation only.
    _run_warmup(layers, optimizers, trainX_torch, warmup_epochs)

    # Freeze all layers after warm-up so later forward passes build no graph
    # and their outputs can be converted to NumPy directly.
    for layer in layers:
        for param in layer.parameters():
            param.requires_grad = False

    # Training features, readout fit and training-set evaluation.
    A_final, norm_stats = _stacked_features(layers, trainX_torch, normalize)
    beta = l2_weights(A_final.numpy(), trainY, C)
    beta_torch = torch.tensor(beta, dtype=torch.float32)
    pred_labels = _predict_labels(A_final, beta_torch)
    true_labels = np.argmax(trainY, axis=1)
    EVAL_Train = evaluate(true_labels, pred_labels)
    end_train = time.perf_counter()

    # Test features reuse the training normalisation statistics.
    start_test = time.perf_counter()
    testX_torch = torch.tensor(testX, dtype=torch.float32)
    A_final_test, _ = _stacked_features(layers, testX_torch, normalize, norm_stats)
    pred_labels_test = _predict_labels(A_final_test, beta_torch)
    true_labels_test = np.argmax(testY, axis=1)
    EVAL_Test = evaluate(true_labels_test, pred_labels_test)
    end_test = time.perf_counter()

    training_time = end_train - start_train
    testing_time = end_test - start_test

    return None, EVAL_Train, EVAL_Test, training_time, testing_time

