import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import time
import copy

class FeatureNet(nn.Module):
    """Small MLP that maps one feature group to a single scalar per sample."""

    def __init__(self, input_dim, hidden_dims):
        """
        Assemble a Linear/ReLU stack that ends in a one-unit linear layer.
        Args:
            input_dim: width of the input fed to this subnetwork
            hidden_dims: iterable of hidden widths, in order
                         (e.g., [6, 3] builds input->6->3->1)
        """
        super().__init__()

        # Consecutive entries of `widths` are the (in, out) sizes of each hidden layer.
        widths = [input_dim, *hidden_dims]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())

        # Scalar head, with no activation after it.
        stack.append(nn.Linear(widths[-1], 1))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run `x` through the sequential stack and return its output."""
        scalar_out = self.net(x)
        return scalar_out


class AdditiveModel(nn.Module):
    """Sum-of-subnetworks model: one FeatureNet per feature group, mixed by a bias-free linear layer."""

    def __init__(self, index_list, hidden_dims, output_dim=1):
        """
        Create the per-group subnetworks and the linear combiner.
        Args:
            index_list: list of column-index lists, one per subnetwork,
                        e.g. [[0], [1], [2], [2,3]] (single features and an interaction)
            hidden_dims: hidden layer sizes passed to every subnetwork
            output_dim: number of outputs produced by the combiner
        """
        super().__init__()
        self.index_list = index_list

        # Each group gets its own FeatureNet whose input width is the group size.
        subnets = []
        for group in index_list:
            subnets.append(FeatureNet(len(group), hidden_dims))
        self.feature_nets = nn.ModuleList(subnets)

        # Bias-free linear layer that weights and adds the subnetwork outputs.
        self.combiner = nn.Linear(len(index_list), output_dim, bias=False)

        # Holds the most recent per-subnetwork outputs under the key 'acomp'.
        self.hook = {}

    def forward(self, X):
        """
        Evaluate every subnetwork on its own columns and combine the results.
        Args:
            X: input tensor of shape [batch_size, num_features]
        """
        # Indexing with a list keeps each slice 2D: [batch_size, len(group)].
        contributions = torch.cat(
            [net(X[:, group]) for group, net in zip(self.index_list, self.feature_nets)],
            dim=1,
        )  # [batch_size, num_subnets]

        self.hook['acomp'] = contributions
        return self.combiner(contributions)
    
## Fair Comparison
class DNNBaseline(nn.Module):
    """Plain fully connected network over all input features, ending in a single output unit."""

    def __init__(self, input_dim, hidden_dims=(64, 32)):
        """
        Build a Linear/ReLU stack followed by a one-unit linear output layer.
        Args:
            input_dim: number of input features
            hidden_dims: sequence of hidden layer widths, in order; the default
                         is an immutable tuple so it cannot be shared and mutated
                         across instances
        """
        super().__init__()
        layers = []
        in_dim = input_dim  # running width; keeps the `input_dim` argument untouched

        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(in_dim, h_dim),
                nn.ReLU()
            ])
            in_dim = h_dim

        # Final output layer (scalar output, no activation)
        layers.append(nn.Linear(in_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for input `x`."""
        return self.net(x)
    
class EarlyStopper:
    """Track the lowest validation loss seen and signal when it has stopped improving.

    A call counts as a "bad" step only when the loss exceeds the best loss by more
    than `min_delta`. Any strictly lower loss resets the bad-step counter.
    """

    def __init__(self, patience=1, min_delta=0):
        """
        Args:
            patience: number of counted bad steps after which stopping is signalled
            min_delta: margin above the best loss that is tolerated without counting
        """
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record `validation_loss` and return True when training should stop."""
        # New best: remember it and forget any accumulated bad steps.
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Within the tolerated margin (or not comparable): neither counts nor resets.
        worsened = validation_loss > (self.min_validation_loss + self.min_delta)
        if not worsened:
            return False

        self.counter += 1
        return bool(self.counter >= self.patience)
    
def train(model, X_train, y_train, X_val, y_val, file_path, n_epochs=10000, batch_size=None, lr=1e-2, pt=50):
    """
    Fit `model` with Adam on an MSE objective and save the best validation weights.

    Every epoch walks the training tensors in sequential (unshuffled) batches,
    then computes the MSE on the validation tensors. That validation MSE drives
    a `ReduceLROnPlateau` scheduler (library default settings), an
    `EarlyStopper` with patience `pt`, and the choice of which weights to keep.
    The model object itself is left with the weights of the last epoch run;
    only the file at `file_path` receives the best ones.

    Args:
        model: the nn.Module to optimise (modified in place)
        X_train, y_train: training inputs and targets
            (assumed to be tensors whose first dimension indexes samples)
        X_val, y_val: validation inputs and targets
        file_path: destination handed to `torch.save` for the best state dict
        n_epochs: maximum number of epochs
        batch_size: samples per optimisation step; None (default) uses the
                    whole training set as a single batch
        lr: initial Adam learning rate
        pt: early-stopping patience, in epochs

    Returns:
        Elapsed training time in seconds (float), measured before the
        checkpoint is written.

    Raises:
        ValueError: if the resolved batch size is not positive (this includes
                    an empty training set with `batch_size=None`).
    """
    n_samples = len(X_train)
    if batch_size is None:
        batch_size = n_samples
    if batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}"
        )

    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, 'min')
    early_stopping = EarlyStopper(patience=pt)

    best_mse = float('inf')
    best_weights = None

    # perf_counter is monotonic, so the measured duration is immune to clock adjustments.
    start_time = time.perf_counter()

    for epoch in range(n_epochs):
        model.train()

        for start in range(0, n_samples, batch_size):
            X_batch = X_train[start:start + batch_size]
            y_batch = y_train[start:start + batch_size]

            # Forward pass
            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate on the validation set at the end of each epoch
        model.eval()
        with torch.no_grad():
            val_loss = float(loss_fn(model(X_val), y_val))

        scheduler.step(val_loss)

        if early_stopping.early_stop(val_loss):
            print(f"Early Stop at Epoch {epoch}")
            break

        if val_loss < best_mse:
            best_mse = val_loss
            # Deep copy: state_dict() returns references to the live parameters.
            best_weights = copy.deepcopy(model.state_dict())

        if (epoch + 1) % 1000 == 0:
            print(f"Epoch {epoch+1}, MSE: {val_loss}")

    elapsed = time.perf_counter() - start_time

    # No epoch produced a comparable loss (zero epochs, or the loss was never
    # below inf, e.g. NaN): save the current weights instead of `None`, so the
    # file can still be loaded with load_state_dict.
    if best_weights is None:
        best_weights = copy.deepcopy(model.state_dict())
    torch.save(best_weights, file_path)

    return elapsed

def eval_model(model, path, X, y):
    """
    Restore saved weights into `model` and score it with mean squared error.

    Args:
        model: nn.Module whose architecture matches the saved state dict
        path: location of a state dict readable by `torch.load`
              (loaded with `weights_only=True`)
        X: inputs passed to the model
        y: targets compared against the model output

    Returns:
        The value produced by `nn.MSELoss` for `y` versus `model(X)`,
        computed with gradients disabled. The model is left in eval mode.
    """
    criterion = nn.MSELoss()

    saved_state = torch.load(path, weights_only=True)
    model.load_state_dict(saved_state)
    model.eval()

    with torch.no_grad():
        return criterion(y, model(X))