import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

from tqdm import trange
from sklearn.model_selection import train_test_split

# Residual block for ResNet architecture
class ResidualBlock(nn.Module):
    """Two-layer bottleneck MLP whose output is added back onto its input.

    The inner path is ``Linear(dim -> h) -> [LayerNorm(h)] -> activation()
    -> Linear(h -> dim)``, followed by dropout; ``forward`` returns
    ``x + dropout(inner(x))``.

    Args:
        dim: Size of the last dimension of the input and of the output.
        hidden: Width ``h`` of the inner layer; any falsy value (``None``,
            ``0``) falls back to ``dim``.
        activation: Zero-argument callable returning the activation module.
        norm: Insert a ``LayerNorm`` after the first linear layer when truthy.
        dropout: Dropout probability on the inner path; values ``<= 0``
            install an ``nn.Identity`` instead of ``nn.Dropout``.
    """

    def __init__(self, dim, hidden=64, activation=nn.ReLU, norm=True, dropout=0.0):
        super().__init__()
        inner_width = hidden if hidden else dim

        # Modules are created in the order they appear in the Sequential so
        # that parameter initialisation consumes the RNG in a fixed order.
        expand = nn.Linear(dim, inner_width)
        stages = [expand, nn.LayerNorm(inner_width)] if norm else [expand]
        stages += [activation(), nn.Linear(inner_width, dim)]
        self.net = nn.Sequential(*stages)

        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x):
        """Return ``x`` plus the (dropout-regularised) inner transformation of ``x``."""
        return x + self.dropout(self.net(x))

# Full ResNet MLP
class ResNetMLP(nn.Module):
    """MLP made of an input projection, a stack of residual blocks and a normalised head.

    ``forward`` computes
    ``out_layer(out_norm(blocks(input_activation(in_layer(x)))))``.

    Args:
        input_dim: Number of input features.
        output_dim: Number of outputs.
        width: Feature width used by the input projection and every block.
        depth: Number of ``ResidualBlock`` modules; ``0`` yields an empty stack.
        activation: Zero-argument callable producing the activation used
            inside each residual block.
        dropout: Dropout probability forwarded to each residual block.
        input_activation: Zero-argument callable producing the activation
            applied right after the input projection.  Defaults to
            ``nn.ReLU``, which is what this layer previously hard-coded, so
            existing callers and checkpoints behave exactly as before.
    """

    def __init__(self, input_dim, output_dim,
                 width=128, depth=4,
                 activation=nn.ReLU, dropout=0.0,
                 input_activation=nn.ReLU):
        super().__init__()
        self.in_layer = nn.Linear(input_dim, width)
        # Configurable instead of a fixed F.relu so the whole network can use
        # one activation family when the caller asks for it.
        self.in_act = input_activation()

        blocks = [
            ResidualBlock(width, activation=activation, norm=True, dropout=dropout)
            for _ in range(depth)
        ]
        self.blocks = nn.Sequential(*blocks)

        self.out_norm = nn.LayerNorm(width)
        self.out_layer = nn.Linear(width, output_dim)

    def forward(self, x):
        """Map ``x`` (last dimension ``input_dim``) to outputs of last dimension ``output_dim``."""
        x = self.in_act(self.in_layer(x))
        x = self.blocks(x)
        x = self.out_norm(x)
        return self.out_layer(x)

# Generic MFG operator class
class MFGOperator:
    """Build, train, evaluate and persist a regression network.

    Depending on ``resnet`` the wrapped model is either a ``ResNetMLP`` or a
    plain feed-forward ``nn.Sequential``.  The data loaders are created by
    :meth:`split_data`, which has to be called before :meth:`train` or
    :meth:`evaluate_test`.

    Args:
        input_dim: Number of input features.
        architecture: For ``resnet=False`` an iterable of hidden-layer sizes.
            For ``resnet=True`` a mapping read with ``.get`` for the optional
            keys ``'width'`` (default 128), ``'depth'`` (default 4) and
            ``'dropout'`` (default 0.0).
        output_dim: Number of outputs.
        resnet: Use the residual architecture when truthy.
        loss_function: Callable ``(outputs, targets) -> scalar loss tensor``.
        activation: Activation module instance used by the plain MLP; the
            same instance is inserted after every hidden layer.  The residual
            model is always built with ``nn.ReLU``.
        optimizer: Factory called as ``optimizer(params, lr=learning_rate)``.
        scheduler: Factory called as ``scheduler(optimizer, T_max=epochs)``
            at the start of every :meth:`train` call, or ``None`` to train
            with a constant learning rate.
        learning_rate: Learning rate handed to the optimizer.

    Attributes:
        scheduler: The factory until :meth:`train` has run, afterwards the
            scheduler instance created for the most recent run.
    """

    def __init__(
            self,
            input_dim, 
            architecture, 
            output_dim,
            resnet=False,
            loss_function=nn.MSELoss(reduction='sum'),
            activation=nn.ReLU(),
            optimizer=optim.Adam,
            scheduler=CosineAnnealingLR,
            learning_rate=1e-4
        ):
        self.input_dim = input_dim
        self.resnet = resnet
        self.architecture = architecture
        self.output_dim = output_dim
        self.learning_rate = learning_rate
        self.device = self._select_device()
        self.criterion = loss_function
        self.activation = activation
        if self.resnet:
            self.model = ResNetMLP(
                input_dim=self.input_dim,
                output_dim=self.output_dim,
                width=architecture.get('width', 128),
                depth=architecture.get('depth', 4),
                activation=nn.ReLU,
                dropout=architecture.get('dropout', 0.0)
            ).to(self.device)
        else:
            self.model = self._build_model().to(self.device)

        self.optimizer = optimizer(self.model.parameters(), lr=self.learning_rate)
        # `scheduler` stays publicly visible under its original name; the
        # private pair below lets train() rebuild a scheduler on every call
        # instead of trying to "call" the instance left by a previous run.
        self.scheduler = scheduler
        self._scheduler_factory = scheduler
        self._scheduler_instance = None

    @staticmethod
    def _select_device():
        """Pick CUDA when available, then Apple MPS, otherwise the CPU."""
        if torch.cuda.is_available():
            return torch.device('cuda')
        # Older torch builds have no `torch.backends.mps` attribute at all.
        mps = getattr(torch.backends, 'mps', None)
        if mps is not None and mps.is_available():
            return torch.device('mps')
        return torch.device('cpu')

    @staticmethod
    def _to_float_tensor(data):
        """Convert array-likes or tensors to a detached float32 tensor.

        ``torch.as_tensor`` avoids the copy (and the copy-construct warning)
        that ``torch.tensor`` incurs when ``data`` is already a tensor.
        """
        return torch.as_tensor(data, dtype=torch.float32).detach()

    # Build standard feedforward MLP with ReLU activation if not using ResNet
    def _build_model(self):
        """Return ``Linear -> activation`` per hidden size plus a final ``Linear``.

        All linear weights are Xavier-uniform initialised; biases keep the
        ``nn.Linear`` default initialisation.
        """
        layers = []
        input_dim = self.input_dim
        for hidden_units in self.architecture:
            linear = nn.Linear(input_dim, hidden_units)
            nn.init.xavier_uniform_(linear.weight)
            layers.append(linear)
            layers.append(self.activation)
            input_dim = hidden_units
        final_linear = nn.Linear(input_dim, self.output_dim)
        nn.init.xavier_uniform_(final_linear.weight)
        layers.append(final_linear)
        return nn.Sequential(*layers)

    # Split data into training and testing sets, with 20% for testing by default
    def split_data(self, X, y, test_size=0.2, batch_size=32, random_state=42):
        """Split ``X``/``y`` and store ``train_loader`` / ``test_loader`` on the instance.

        Args:
            X: Input samples accepted by ``train_test_split``.
            y: Targets aligned with ``X``.
            test_size: Forwarded to ``train_test_split``.
            batch_size: Batch size of both loaders.
            random_state: Seed forwarded to ``train_test_split``.

        The training loader shuffles every epoch; the test loader does not.
        """
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state
        )

        train_set = torch.utils.data.TensorDataset(
            self._to_float_tensor(X_train), self._to_float_tensor(y_train)
        )
        test_set = torch.utils.data.TensorDataset(
            self._to_float_tensor(X_test), self._to_float_tensor(y_test)
        )

        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=batch_size,
            shuffle=True
        )
        self.test_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=batch_size,
            shuffle=False
        )

    def _start_scheduler(self, epochs):
        """Create a fresh scheduler for a training run of ``epochs`` epochs.

        Does nothing when ``self.scheduler`` is ``None``.  Otherwise
        ``self.scheduler`` ends up holding the new scheduler instance.
        """
        current = self.scheduler
        if current is None:
            return
        if current is not self._scheduler_instance:
            # The attribute holds a factory: the one from __init__, or one the
            # caller assigned afterwards.
            self._scheduler_factory = current
        if self._scheduler_instance is not None:
            # A previous run annealed the learning rate; start the new
            # schedule from the base rate the earlier scheduler recorded.
            for group in self.optimizer.param_groups:
                group['lr'] = group.get('initial_lr', group['lr'])
        self._scheduler_instance = self._scheduler_factory(self.optimizer, T_max=epochs)
        self.scheduler = self._scheduler_instance

    # Training loop, using specified optimizer and scheduler
    def train(self, epochs):
        """Train on ``self.train_loader`` for ``epochs`` epochs.

        May be called repeatedly; every call builds a new scheduler (if one
        is configured) spanning its own ``epochs``.  The progress bar shows
        the accumulated batch losses divided by the number of batches.
        """
        self._start_scheduler(epochs)
        self.model.train()
        t = trange(epochs, leave=True)
        for epoch in t:
            epoch_loss = 0.0
            for batch_X, batch_y in self.train_loader:
                batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(batch_X)
                loss = self.criterion(outputs, batch_y)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()

            if self.scheduler is not None:
                self.scheduler.step()  # Update the learning rate using the scheduler
            t.set_description(f'Epoch {epoch + 1} / {epochs}, Train Loss: {epoch_loss / len(self.train_loader):.4f}', refresh=True)

    # Evaluate the model on the test set
    def evaluate_test(self):
        """Return (and print) the mean per-batch loss over ``self.test_loader``.

        The value is the sum of the batch losses divided by the number of
        batches, so with a sum-reducing criterion it scales with batch size.
        The model is left in eval mode.
        """
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for batch_X, batch_y in self.test_loader:
                batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)

                outputs = self.model(batch_X)
                loss = self.criterion(outputs, batch_y)
                total_loss += loss.item()

        avg_loss = total_loss / len(self.test_loader)
        print(f'Test Loss: {avg_loss:.4f}')
        return avg_loss

    # Predict on a new sample
    def predict(self, X):
        """Run the model on ``X`` in eval mode and return a NumPy array.

        ``X`` may be any array-like or tensor convertible to float32.
        """
        self.model.eval()
        with torch.no_grad():
            X_tensor = self._to_float_tensor(X).to(self.device)
            predictions = self.model(X_tensor)
            return predictions.cpu().numpy()

    # Save and load model state
    def save_model(self, path):
        """Write the model's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)
        print(f'Model saved to {path}')

    def load_model(self, path):
        """Load a ``state_dict`` from ``path`` into the existing model."""
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources, or pass weights_only=True on torch versions
        # that support it.
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        print(f'Model loaded from {path}')
    