import torch
import torch.nn as nn
import torch.nn.functional as F




class Linear(nn.Module):
    """Bias-free two-layer *linear* network (no nonlinearity between the layers).

    Computes ``h(fc1(flatten(x)))`` and returns only the logits.

    Args:
        input_size: number of features per sample after flattening
            (784 for a 1x28x28 MNIST image).
        hidden_size: output width of ``fc1``.
        num_classes: output width of ``h``.
        init_method, seed: accepted but currently unused -- the
            ``initialize_weights`` call below is commented out, so the layers
            keep the initialization ``nn.Linear`` gives them at construction.
    """
    def __init__(self, input_size=784, hidden_size=100, num_classes=10, init_method='xavier', seed=42):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.h = nn.Linear(hidden_size, num_classes, bias=False)

        # Initialize weights
        # initialize_weights(self, seed=seed, init_method=init_method)

    def forward(self, x):
        # Flatten everything except the batch dimension (MNIST: 1x28x28 -> 784).
        # reshape instead of view: view raises on non-contiguous inputs
        # (permuted / channels_last batches); reshape only copies when needed.
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.h(x)
        return x



class DoubleHeadLinear(nn.Module):
    """Bias-free linear network with a shared ``fc1`` and two output heads.

    ``forward`` returns the logits of the head selected by ``selected_head``
    (0 -> ``h0``, 1 -> ``h1``); change it with ``update_head``. No
    nonlinearity is applied anywhere.

    Args:
        input_size: number of features per sample after flattening.
        hidden_size: output width of the shared ``fc1``.
        num_classes_per_head: output width of each head.
        init_method, seed: accepted but currently unused -- the
            ``initialize_weights`` call below is commented out.
    """
    def __init__(self, input_size=784, hidden_size=100, num_classes_per_head=5, init_method='xavier', seed=42):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.h0 = nn.Linear(hidden_size, num_classes_per_head, bias=False)
        self.h1 = nn.Linear(hidden_size, num_classes_per_head, bias=False)
        self.selected_head = 0

        # Initialize weights
        # initialize_weights(self, seed=seed, init_method=init_method)

    def _trunk(self, x):
        """Shared part of both heads: flatten each sample, then apply ``fc1``."""
        # reshape instead of view so non-contiguous inputs work instead of raising.
        return self.fc1(x.reshape(x.size(0), -1))

    def forward_h0(self, x):
        """Return head-0 logits, independent of ``selected_head``."""
        return self.h0(self._trunk(x))

    def forward_h1(self, x):
        """Return head-1 logits, independent of ``selected_head``."""
        return self.h1(self._trunk(x))

    def forward(self, x):
        """Dispatch to the head named by ``selected_head``.

        Raises:
            ValueError: if ``selected_head`` is neither 0 nor 1 (for example
                after a direct assignment), instead of silently returning None.
        """
        if self.selected_head == 0:
            return self.forward_h0(x)
        if self.selected_head == 1:
            return self.forward_h1(x)
        raise ValueError(f"selected_head must be 0 or 1, got {self.selected_head!r}")

    def update_head(self, new_head: int) -> None:
        """Select the head used by ``forward`` and print the new selection.

        The check is an ``assert`` (skipped under ``python -O``); ``forward``
        validates ``selected_head`` again in any case.
        """
        assert new_head in [0,1]
        self.selected_head = new_head
        print(f"model head set to {self.selected_head}")










class SimpleCNN(nn.Module):
    """Two-conv-layer CNN classifier that also returns ReLU activity masks.

    Pipeline: conv1 (3x3, no bias) -> ReLU -> conv2 (3x3, with bias) -> ReLU
    -> AdaptiveAvgPool2d to 8x8 -> flatten -> fc1 (with bias).

    Args:
        num_channels: number of input channels (e.g. 1 for MNIST, 3 for CIFAR-10).
        num_classes: output width of ``fc1``.
        init_method, seed: accepted but currently unused -- the
            ``initialize_weights`` call below is commented out.
    """

    def __init__(self, num_channels, num_classes=10, init_method='xavier', seed=42):
        super(SimpleCNN, self).__init__()

        # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(num_channels, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)

        # NOTE: self.pool is not used in forward(); it has no parameters and is
        # kept only so existing references to the attribute keep working.
        self.pool = nn.MaxPool2d(2, 2)
        # Pools any spatial size down to 8x8, so fc1's input width is fixed.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 8))
        self.fc1 = nn.Linear(64 * 8 * 8, num_classes)  # 64 * 8 * 8 -> num_classes

        # Initialize weights
        # initialize_weights(self, seed=seed, init_method=init_method)

    def forward(self, x):
        """Return ``(logits, pre_activations)``.

        ``pre_activations['layer_1']`` / ``['layer_2']`` are boolean masks of
        where the conv1 / conv2 outputs (before ReLU) are positive.
        """
        z1 = self.conv1(x)
        x = F.relu(z1)
        z2 = self.conv2(x)
        x = F.relu(z2)
        x = self.adaptive_pool(x)
        # reshape instead of view: the pooled tensor can be non-contiguous
        # (e.g. channels_last inputs), where view would raise.
        x = x.reshape(-1, 64 * 8 * 8)
        x = self.fc1(x)

        pre_activations = {
            'layer_1': z1 > 0,
            'layer_2': z2 > 0
        }

        return x, pre_activations



class DoubleHeadCNN(nn.Module):
    """Two-conv-layer CNN with a shared conv trunk and two selectable heads.

    Trunk: conv1 (3x3, no bias) -> ReLU -> conv2 (3x3, with bias) -> ReLU ->
    AdaptiveAvgPool2d to 8x8 -> flatten. Heads ``h0`` and ``h1`` are separate
    ``nn.Linear(64 * 8 * 8, num_classes)`` layers (with bias). ``forward``
    returns ``(logits, pre_activations)`` from the head named by
    ``selected_head``.

    Args:
        num_channels: number of input channels.
        num_classes: output width of each head.
        init_method, seed: accepted but currently unused -- the
            ``initialize_weights`` call below is commented out.
    """
    def __init__(self, num_channels, num_classes=10, init_method='xavier', seed=42):
        super(DoubleHeadCNN, self).__init__()
        # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(num_channels, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # NOTE: self.pool is not used in the forward passes; it has no
        # parameters and is kept only for existing attribute references.
        self.pool = nn.MaxPool2d(2, 2)
        # Pools any spatial size down to 8x8, so the heads' input width is fixed.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 8))
        self.h0 = nn.Linear(64 * 8 * 8, num_classes)  # 64 * 8 * 8 -> num_classes
        self.h1 = nn.Linear(64 * 8 * 8, num_classes)
        self.selected_head = 0

        # Initialize weights
        # initialize_weights(self, seed=seed, init_method=init_method)

    def _trunk(self, x):
        """Run the shared conv trunk; return ``(flat_features, pre_activations)``.

        ``pre_activations`` holds boolean masks of where the conv1 / conv2
        outputs (before ReLU) are positive.
        """
        z1 = self.conv1(x)
        z2 = self.conv2(F.relu(z1))
        features = self.adaptive_pool(F.relu(z2))
        # reshape instead of view: the pooled tensor can be non-contiguous
        # (e.g. channels_last inputs), where view would raise.
        features = features.reshape(-1, 64 * 8 * 8)
        pre_activations = {
            'layer_1': z1 > 0,
            'layer_2': z2 > 0
        }
        return features, pre_activations

    def forward_h0(self, x):
        """Return ``(head-0 logits, pre_activations)``, ignoring ``selected_head``."""
        features, pre_activations = self._trunk(x)
        return self.h0(features), pre_activations

    def forward_h1(self, x):
        """Return ``(head-1 logits, pre_activations)``, ignoring ``selected_head``."""
        features, pre_activations = self._trunk(x)
        return self.h1(features), pre_activations

    def forward(self, x):
        """Dispatch to the head named by ``selected_head``.

        Raises:
            ValueError: if ``selected_head`` is neither 0 nor 1, instead of
                silently returning None.
        """
        if self.selected_head == 0:
            return self.forward_h0(x)
        if self.selected_head == 1:
            return self.forward_h1(x)
        raise ValueError(f"selected_head must be 0 or 1, got {self.selected_head!r}")

    def update_head(self, new_head: int) -> None:
        """Select the head used by ``forward`` and print the new selection.

        The check is an ``assert`` (skipped under ``python -O``); ``forward``
        validates ``selected_head`` again in any case.
        """
        assert new_head in [0,1]
        self.selected_head = new_head
        print(f"model head set to {self.selected_head}")



class SimpleMLP(nn.Module):
    """One-hidden-layer, bias-free MLP that also returns an activity mask.

    ``forward`` returns ``(logits, pre_activations)``, where
    ``pre_activations['layer_1']`` is the boolean mask ``fc1(x) > 0``. With
    ``activation='relu'`` this is exactly the set of active hidden units; for
    the other activations it only records the sign of the pre-activation.

    Args:
        input_size: number of features per sample after flattening.
        hidden_size: width of the hidden layer.
        num_classes: output width of ``h``.
        init_method, seed: accepted but currently unused -- the
            ``initialize_weights`` call below is commented out.
        activation: one of 'relu', 'gelu', 'elu', 'leaky_relu'.

    Raises:
        ValueError: for any other ``activation`` (raised after it is printed).
    """
    def __init__(self, input_size=784, hidden_size=100, num_classes=10, init_method='xavier', seed=42, activation='relu'):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.h = nn.Linear(hidden_size, num_classes, bias=False)

        self.activation = activation
        print(f"activation function: {self.activation}")
        if self.activation == 'relu':
            self.activation_fn = torch.relu
        elif self.activation == 'gelu':
            self.activation_fn = torch.nn.functional.gelu
        elif self.activation == 'elu':
            self.activation_fn = torch.nn.functional.elu
        elif self.activation == 'leaky_relu':
            self.activation_fn = torch.nn.functional.leaky_relu
        else:
            raise ValueError(f"Unsupported activation function: {self.activation}")
        # Initialize weights
        # initialize_weights(self, seed=seed, init_method=init_method)

    def forward(self, x):
        # Flatten everything except the batch dimension (MNIST: 1x28x28 -> 784).
        # reshape instead of view so non-contiguous inputs work instead of raising.
        x = x.reshape(x.size(0), -1)
        z1 = self.fc1(x)
        x = self.activation_fn(z1)
        x = self.h(x)

        pre_activations = {
            'layer_1': z1 > 0
        }

        return x, pre_activations



class DoubleHeadMLP(nn.Module):
    """One-hidden-layer, bias-free MLP with a shared trunk and two heads.

    ``forward`` returns ``(logits, pre_activations)`` from the head named by
    ``selected_head`` (0 -> ``h0``, 1 -> ``h1``). ``pre_activations['layer_1']``
    is the boolean mask ``fc1(x) > 0`` (the active units when
    ``activation='relu'``; only the pre-activation sign otherwise).

    Args:
        input_size: number of features per sample after flattening.
        hidden_size: width of the shared hidden layer.
        num_classes_per_head: output width of each head.
        init_method, seed: accepted but currently unused -- the
            ``initialize_weights`` call below is commented out.
        activation: one of 'relu', 'gelu', 'elu', 'leaky_relu'.

    Raises:
        ValueError: for any other ``activation`` (raised after it is printed).
    """
    def __init__(self, input_size=784, hidden_size=100, num_classes_per_head=5, init_method='xavier', seed=42, activation='relu'):
        super(DoubleHeadMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.h0 = nn.Linear(hidden_size, num_classes_per_head, bias=False)
        self.h1 = nn.Linear(hidden_size, num_classes_per_head, bias=False)
        self.selected_head = 0

        self.activation = activation
        print(f"activation function: {self.activation}")
        if self.activation == 'relu':
            self.activation_fn = torch.relu
        elif self.activation == 'gelu':
            self.activation_fn = torch.nn.functional.gelu
        elif self.activation == 'elu':
            self.activation_fn = torch.nn.functional.elu
        elif self.activation == 'leaky_relu':
            self.activation_fn = torch.nn.functional.leaky_relu
        else:
            raise ValueError(f"Unsupported activation function: {self.activation}")

        # Initialize weights
        # initialize_weights(self, seed=seed, init_method=init_method)

    def _trunk(self, x):
        """Shared hidden layer: return ``(hidden_activations, pre_activations)``."""
        # reshape instead of view so non-contiguous inputs work instead of raising.
        x = x.reshape(x.size(0), -1)
        z1 = self.fc1(x)
        hidden = self.activation_fn(z1)
        pre_activations = {
            'layer_1': z1 > 0
        }
        return hidden, pre_activations

    def forward_h0(self, x):
        """Return ``(head-0 logits, pre_activations)``, ignoring ``selected_head``."""
        hidden, pre_activations = self._trunk(x)
        return self.h0(hidden), pre_activations

    def forward_h1(self, x):
        """Return ``(head-1 logits, pre_activations)``, ignoring ``selected_head``."""
        hidden, pre_activations = self._trunk(x)
        return self.h1(hidden), pre_activations

    def forward(self, x):
        """Dispatch to the head named by ``selected_head``.

        Raises:
            ValueError: if ``selected_head`` is neither 0 nor 1, instead of
                silently returning None.
        """
        if self.selected_head == 0:
            return self.forward_h0(x)
        if self.selected_head == 1:
            return self.forward_h1(x)
        raise ValueError(f"selected_head must be 0 or 1, got {self.selected_head!r}")

    def update_head(self, new_head: int) -> None:
        """Select the head used by ``forward`` and print the new selection.

        The check is an ``assert`` (skipped under ``python -O``); ``forward``
        validates ``selected_head`` again in any case.
        """
        assert new_head in [0,1]
        self.selected_head = new_head
        print(f"model head set to {self.selected_head}")
        
        
        

class TwoLayerMLP(nn.Module):
    """Two-hidden-layer, bias-free MLP that also returns activity masks.

    ``forward`` returns ``(logits, pre_activations)`` with boolean masks
    ``'layer_1': fc1(...) > 0`` and ``'layer_2': fc2(...) > 0``. With
    ``activation='relu'`` these are exactly the active units; for the other
    activations they only record the sign of each pre-activation.

    Args:
        input_size: number of features per sample after flattening.
        hidden_size: width of both hidden layers.
        num_classes: output width of ``h``.
        init_method, seed: accepted but currently unused -- the
            ``initialize_weights`` call below is commented out.
        activation: one of 'relu', 'gelu', 'elu', 'leaky_relu'.

    Raises:
        ValueError: for any other ``activation`` (raised after it is printed).
    """
    def __init__(self, input_size=784, hidden_size=100, num_classes=5, init_method='xavier', seed=42, activation='relu'):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.fc2 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.h = nn.Linear(hidden_size, num_classes, bias=False)

        self.activation = activation
        print(f"activation function: {self.activation}")
        if self.activation == 'relu':
            self.activation_fn = torch.relu
        elif self.activation == 'gelu':
            self.activation_fn = torch.nn.functional.gelu
        elif self.activation == 'elu':
            self.activation_fn = torch.nn.functional.elu
        elif self.activation == 'leaky_relu':
            self.activation_fn = torch.nn.functional.leaky_relu
        else:
            raise ValueError(f"Unsupported activation function: {self.activation}")

        # Initialize weights
        # initialize_weights(self, seed=seed, init_method=init_method)

    def forward(self, x):
        # Flatten everything except the batch dimension (MNIST: 1x28x28 -> 784).
        # reshape instead of view so non-contiguous inputs work instead of raising.
        x = x.reshape(x.size(0), -1)
        z1 = self.fc1(x)
        x = self.activation_fn(z1)
        z2 = self.fc2(x)
        x = self.activation_fn(z2)
        x = self.h(x)

        pre_activations = {
            'layer_1': z1 > 0,
            'layer_2': z2 > 0
        }

        return x, pre_activations



class DoubleHeadTwoLayerMLP(nn.Module):
    """Two-hidden-layer, bias-free MLP with a shared trunk and two heads.

    ``forward`` returns ``(logits, pre_activations)`` from the head named by
    ``selected_head`` (0 -> ``h0``, 1 -> ``h1``). ``pre_activations`` holds the
    boolean masks ``'layer_1': fc1(...) > 0`` and ``'layer_2': fc2(...) > 0``
    (active units for ReLU; only the pre-activation sign otherwise).

    Args:
        input_size: number of features per sample after flattening.
        hidden_size: width of both shared hidden layers.
        num_classes_per_head: output width of each head.
        init_method, seed: accepted but currently unused -- the
            ``initialize_weights`` call below is commented out.
        activation: one of 'relu', 'gelu', 'elu', 'leaky_relu'.

    Raises:
        ValueError: for any other ``activation`` (raised after it is printed).
    """
    def __init__(self, input_size=784, hidden_size=100, num_classes_per_head=5, init_method='xavier', seed=42, activation='relu'):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.fc2 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.h0 = nn.Linear(hidden_size, num_classes_per_head, bias=False)
        self.h1 = nn.Linear(hidden_size, num_classes_per_head, bias=False)
        self.selected_head = 0

        self.activation = activation
        print(f"activation function: {self.activation}")
        if self.activation == 'relu':
            self.activation_fn = torch.relu
        elif self.activation == 'gelu':
            self.activation_fn = torch.nn.functional.gelu
        elif self.activation == 'elu':
            self.activation_fn = torch.nn.functional.elu
        elif self.activation == 'leaky_relu':
            self.activation_fn = torch.nn.functional.leaky_relu
        else:
            raise ValueError(f"Unsupported activation function: {self.activation}")

        # Initialize weights
        # initialize_weights(self, seed=seed, init_method=init_method)

    def _trunk(self, x):
        """Shared hidden layers: return ``(hidden_activations, pre_activations)``."""
        # reshape instead of view so non-contiguous inputs work instead of raising.
        x = x.reshape(x.size(0), -1)
        z1 = self.fc1(x)
        z2 = self.fc2(self.activation_fn(z1))
        hidden = self.activation_fn(z2)
        pre_activations = {
            'layer_1': z1 > 0,
            'layer_2': z2 > 0
        }
        return hidden, pre_activations

    def forward_h0(self, x):
        """Return ``(head-0 logits, pre_activations)``, ignoring ``selected_head``."""
        hidden, pre_activations = self._trunk(x)
        return self.h0(hidden), pre_activations

    def forward_h1(self, x):
        """Return ``(head-1 logits, pre_activations)``, ignoring ``selected_head``."""
        hidden, pre_activations = self._trunk(x)
        return self.h1(hidden), pre_activations

    def forward(self, x):
        """Dispatch to the head named by ``selected_head``.

        Raises:
            ValueError: if ``selected_head`` is neither 0 nor 1, instead of
                silently returning None.
        """
        if self.selected_head == 0:
            return self.forward_h0(x)
        if self.selected_head == 1:
            return self.forward_h1(x)
        raise ValueError(f"selected_head must be 0 or 1, got {self.selected_head!r}")

    def update_head(self, new_head: int) -> None:
        """Select the head used by ``forward`` and print the new selection.

        The check is an ``assert`` (skipped under ``python -O``); ``forward``
        validates ``selected_head`` again in any case.
        """
        assert new_head in [0,1]
        self.selected_head = new_head
        print(f"model head set to {self.selected_head}")









class ThreeLayerMLP(nn.Module):
    """Three-hidden-layer, bias-free MLP that also returns activity masks.

    Widths: input_size -> hidden_size -> hidden_size_2 -> hidden_size ->
    num_classes (``fc3`` maps back to ``hidden_size``). ``forward`` returns
    ``(logits, pre_activations)`` with boolean masks ``'layer_k': z_k > 0`` for
    k = 1, 2, 3 (active units for ReLU; only the pre-activation sign otherwise).

    Args:
        input_size: number of features per sample after flattening.
        hidden_size: width of the first and third hidden layers.
        hidden_size_2: width of the second hidden layer.
        num_classes: output width of ``h``.
        init_method, seed: accepted but currently unused -- the
            ``initialize_weights`` call below is commented out.
        activation: one of 'relu', 'gelu', 'elu', 'leaky_relu'.

    Raises:
        ValueError: for any other ``activation`` (raised after it is printed).
    """
    def __init__(self, input_size=784, hidden_size=100, hidden_size_2=100, num_classes=5, init_method='xavier', seed=42, activation='relu'):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.fc2 = nn.Linear(hidden_size, hidden_size_2, bias=False)
        self.fc3 = nn.Linear(hidden_size_2, hidden_size, bias=False)
        self.h = nn.Linear(hidden_size, num_classes, bias=False)

        self.activation = activation
        print(f"activation function: {self.activation}")
        if self.activation == 'relu':
            self.activation_fn = torch.relu
        elif self.activation == 'gelu':
            self.activation_fn = torch.nn.functional.gelu
        elif self.activation == 'elu':
            self.activation_fn = torch.nn.functional.elu
        elif self.activation == 'leaky_relu':
            self.activation_fn = torch.nn.functional.leaky_relu
        else:
            raise ValueError(f"Unsupported activation function: {self.activation}")

        # Initialize weights
        # initialize_weights(self, seed=seed, init_method=init_method)

    def forward(self, x):
        # Flatten everything except the batch dimension (MNIST: 1x28x28 -> 784).
        # reshape instead of view so non-contiguous inputs work instead of raising.
        x = x.reshape(x.size(0), -1)
        z1 = self.fc1(x)
        x = self.activation_fn(z1)
        z2 = self.fc2(x)
        x = self.activation_fn(z2)
        z3 = self.fc3(x)
        x = self.activation_fn(z3)
        x = self.h(x)

        pre_activations = {
            'layer_1': z1 > 0,
            'layer_2': z2 > 0,
            'layer_3': z3 > 0
        }

        return x, pre_activations



class DoubleHeadThreeLayerMLP(nn.Module):
    """Three-hidden-layer, bias-free MLP with a shared trunk and two heads.

    Trunk widths: input_size -> hidden_size -> hidden_size_2 -> hidden_size.
    ``forward`` returns ``(logits, pre_activations)`` from the head named by
    ``selected_head`` (0 -> ``h0``, 1 -> ``h1``); ``pre_activations`` holds
    boolean masks ``'layer_k': z_k > 0`` for k = 1, 2, 3.

    Args:
        input_size: number of features per sample after flattening.
        hidden_size: width of the first and third hidden layers.
        hidden_size_2: width of the second hidden layer.
        num_classes_per_head: output width of each head.
        init_method, seed: accepted but currently unused -- the
            ``initialize_weights`` call below is commented out.
        activation: one of 'relu', 'gelu', 'elu', 'leaky_relu'.

    Raises:
        ValueError: for any other ``activation`` (raised after it is printed).
    """
    def __init__(self, input_size=784, hidden_size=100, hidden_size_2=100, num_classes_per_head=5, init_method='xavier', seed=42, activation='relu'):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.fc2 = nn.Linear(hidden_size, hidden_size_2, bias=False)
        self.fc3 = nn.Linear(hidden_size_2, hidden_size, bias=False)
        self.h0 = nn.Linear(hidden_size, num_classes_per_head, bias=False)
        self.h1 = nn.Linear(hidden_size, num_classes_per_head, bias=False)
        self.selected_head = 0

        self.activation = activation
        print(f"activation function: {self.activation}")
        if self.activation == 'relu':
            self.activation_fn = torch.relu
        elif self.activation == 'gelu':
            self.activation_fn = torch.nn.functional.gelu
        elif self.activation == 'elu':
            self.activation_fn = torch.nn.functional.elu
        elif self.activation == 'leaky_relu':
            self.activation_fn = torch.nn.functional.leaky_relu
        else:
            raise ValueError(f"Unsupported activation function: {self.activation}")

        # Initialize weights
        # initialize_weights(self, seed=seed, init_method=init_method)

    def _trunk(self, x):
        """Shared hidden layers: return ``(hidden_activations, pre_activations)``."""
        # reshape instead of view so non-contiguous inputs work instead of raising.
        x = x.reshape(x.size(0), -1)
        z1 = self.fc1(x)
        z2 = self.fc2(self.activation_fn(z1))
        z3 = self.fc3(self.activation_fn(z2))
        hidden = self.activation_fn(z3)
        pre_activations = {
            'layer_1': z1 > 0,
            'layer_2': z2 > 0,
            'layer_3': z3 > 0
        }
        return hidden, pre_activations

    def forward_h0(self, x):
        """Return ``(head-0 logits, pre_activations)``, ignoring ``selected_head``."""
        hidden, pre_activations = self._trunk(x)
        return self.h0(hidden), pre_activations

    def forward_h1(self, x):
        """Return ``(head-1 logits, pre_activations)``, ignoring ``selected_head``."""
        hidden, pre_activations = self._trunk(x)
        return self.h1(hidden), pre_activations

    def forward(self, x):
        """Dispatch to the head named by ``selected_head``.

        Raises:
            ValueError: if ``selected_head`` is neither 0 nor 1, instead of
                silently returning None.
        """
        if self.selected_head == 0:
            return self.forward_h0(x)
        if self.selected_head == 1:
            return self.forward_h1(x)
        raise ValueError(f"selected_head must be 0 or 1, got {self.selected_head!r}")

    def update_head(self, new_head: int) -> None:
        """Select the head used by ``forward`` and print the new selection.

        The check is an ``assert`` (skipped under ``python -O``); ``forward``
        validates ``selected_head`` again in any case.
        """
        assert new_head in [0,1]
        self.selected_head = new_head
        print(f"model head set to {self.selected_head}")







def initialize_weights(model: nn.Module, seed: int, init_method: str='xavier',a: float=1.0)->None:
    """
    Re-initialize, in place and reproducibly, every nn.Linear / nn.Conv2d in ``model``.

    Weights are drawn with the chosen scheme and biases (when present) are set
    to zero. ``model`` itself is included if it is a Linear/Conv2d layer.

    Note: this seeds the *global* torch RNG with ``torch.manual_seed(seed)``, so
    random draws made after the call are affected as well.

    Args:
        model (nn.Module): model whose Linear/Conv2d submodules are re-initialized.
        seed (int): random seed passed to ``torch.manual_seed``.
        init_method (str, optional): 'xavier' (``xavier_uniform_``), 'kaiming'
            (``kaiming_uniform_`` with ``nonlinearity='relu'``) or 'uniform'
            (U(-a, a)). Defaults to 'xavier'.
        a (float, optional): half-width of the interval for 'uniform'; ignored
            by the other methods. Defaults to 1.0.

    Raises:
        ValueError: if ``init_method`` is unsupported. Checked before the seed
            is set or any weight is touched, so it is reported even for models
            without Linear/Conv2d layers.
    """
    initializers = {
        'xavier': nn.init.xavier_uniform_,
        'kaiming': lambda weight: nn.init.kaiming_uniform_(weight, nonlinearity='relu'),
        'uniform': lambda weight: nn.init.uniform_(weight, -a, a),
    }
    init_weight = initializers.get(init_method)
    if init_weight is None:
        raise ValueError(f"Unsupported initialization method: {init_method}. Use 'xavier', 'kaiming', or 'uniform'.")

    torch.manual_seed(seed)

    for m in model.modules():
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            init_weight(m.weight)
            # Initialize bias (if present)
            if m.bias is not None:
                nn.init.zeros_(m.bias)






def get_model(model_query: str, kwargs: "dict | None" = None):
    """
    Construct a model by registry name.

    Args:
        model_query (str): one of "Linear", "SimpleMLP", "TwoLayerMLP",
            "ThreeLayerMLP", "SimpleCNN", "DoubleHeadLinear", "DoubleHeadMLP",
            "DoubleHeadTwoLayerMLP", "DoubleHeadThreeLayerMLP", "DoubleHeadCNN".
        kwargs (dict, optional): keyword arguments forwarded to the model's
            constructor. ``None`` means no arguments. (The CNNs require
            ``num_channels``.)

    Returns:
        nn.Module: a freshly constructed model instance.

    Raises:
        ValueError: if ``model_query`` is not a registered name.
    """
    model_db = {
        "Linear"                 : Linear,
        "SimpleMLP"              : SimpleMLP,
        "TwoLayerMLP"            : TwoLayerMLP,
        "ThreeLayerMLP"          : ThreeLayerMLP,
        "SimpleCNN"              : SimpleCNN,
        "DoubleHeadLinear"       : DoubleHeadLinear,
        "DoubleHeadMLP"          : DoubleHeadMLP,
        "DoubleHeadTwoLayerMLP"  : DoubleHeadTwoLayerMLP,
        "DoubleHeadThreeLayerMLP": DoubleHeadThreeLayerMLP,
        "DoubleHeadCNN"          : DoubleHeadCNN,
    }

    # Explicit check instead of `assert`, which `python -O` strips (the lookup
    # below would then fail with an uninformative KeyError).
    if model_query not in model_db:
        raise ValueError(
            f"Invalid model query {model_query!r}; available: {', '.join(model_db)}"
        )

    return model_db[model_query](**(kwargs or {}))







def main():
    """Smoke test: feed one random MNIST-shaped batch through several models
    and print the logits shape and activity-mask shapes each one produces."""
    batch = torch.randn(64,1,28,28)

    # (label, zero-argument constructor, number of 'layer_k' masks to report)
    cases = [
        ("SimpleCNN", lambda: SimpleCNN(num_channels=1, num_classes=10), 2),
        ("SimpleMLP", SimpleMLP, 1),
        ("TwoLayerMLP", TwoLayerMLP, 2),
        ("ThreeLayerMLP", ThreeLayerMLP, 3),
        ("DoubleHeadMLP", DoubleHeadMLP, 1),
    ]

    for idx, (label, build, depth) in enumerate(cases):
        # Each model is constructed right before it is run, so anything its
        # constructor prints appears just ahead of its own report.
        logits, masks = build()(batch)
        separator = "\n" if idx else ""
        print(f"{separator}{label} output shape: {logits.shape}")
        print(f"{label} pre-activations keys: {list(masks)}")
        for k in range(1, depth + 1):
            layer = f"layer_{k}"
            print(f"{label} {layer} pre-activation shape: {masks[layer].shape}")

    print("\nAll tests completed successfully!")


if __name__ == "__main__":
    main()
