""" models """

import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import math
        
class MLP(nn.Module):
    """Bias-free fully connected network that also returns each layer's input.

    Args:
        dims: widths ``[in, hidden_1, ..., out]``; a bias-free ``nn.Linear``
            connects every consecutive pair.
        activation: module placed between consecutive linear layers. The same
            instance is reused at every position and none follows the last
            linear layer. The default ``nn.ReLU()`` is created once at class
            definition and therefore shared by every MLP built without an
            explicit activation.
        seed: seed for weight initialisation. The global CPU RNG state is
            saved first and restored once the layers have been created.
    """

    def __init__(self, dims, activation=nn.ReLU(), seed=42):
        super(MLP, self).__init__()
        saved_rng = torch.random.get_rng_state()
        torch.manual_seed(seed)
        self.activation = activation

        last_hidden = len(dims) - 2
        stack = []
        for position, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
            stack.append(nn.Linear(fan_in, fan_out, bias=False))
            # No activation after the final (output) layer.
            if position < last_hidden:
                stack.append(activation)

        self.network = nn.Sequential(*stack)
        torch.random.set_rng_state(saved_rng)

    def forward(self, X):
        """Apply the network to ``X``.

        Returns:
            ``(output, acts)`` where ``acts[i]`` is the input of the i-th
            linear layer, read through ``.data`` (outside autograd) and
            transposed with ``.T``.
        """
        layer_inputs = []
        h = X
        for module in self.network:
            if isinstance(module, nn.Linear):
                layer_inputs.append(h.data.T)
            h = module(h)
        return h, layer_inputs
    
    

class CNN(nn.Module):
    """Three-stage bias-free CNN followed by a two-layer bias-free MLP head.

    Each stage is conv(3x3, padding 1) -> BatchNorm -> activation -> max-pool.
    The pools are 2x2, 2x2 and 3x3, each with stride equal to its kernel.
    ``fc1`` expects 512 = 128 * 2 * 2 flattened features. These pools produce
    that for square inputs of 24 to 35 pixels (e.g. 28x28 or 32x32).

    Args:
        in_channels: channels of the input image.
        num_classes: size of the output logits.
        activation: module applied after every BatchNorm and after ``fc1``.
        seed: seed for weight initialisation. The global CPU RNG state is
            saved beforehand and restored afterwards.
    """

    def __init__(self, in_channels, num_classes, activation, seed=42):
        super(CNN, self).__init__()
        saved_rng = torch.random.get_rng_state()
        self.seed = seed
        torch.manual_seed(seed)
        self.in_channels = in_channels
        # Input channels of each conv stage, followed by the last stage's width.
        self.channel_ls = [in_channels, 32, 64, 128]

        # Creation order decides which random draws each layer's init consumes;
        # keep it stable so a given seed always reproduces the same weights.
        self.conv_layer1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.max_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_layer2 = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.max_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_layer3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.max_pool3 = nn.MaxPool2d(kernel_size=3, stride=3)

        self.relu = activation

        self.fc1 = nn.Linear(512, 1024, bias=False)
        self.fc2 = nn.Linear(1024, num_classes, bias=False)
        torch.random.set_rng_state(saved_rng)

    def forward(self, X):
        """Run the network and record the input of every weight layer.

        Args:
            X: image batch fed to ``conv_layer1`` (batch first).

        Returns:
            A tuple ``(logits, activations, maps, channel_ls)``:

            - ``activations``: for each conv stage, its detached input
              zero-padded by one pixel on every side; then the detached inputs
              of ``fc1`` and ``fc2``, transposed to ``[features, batch]``.
            - ``maps``: the last-dimension size of each conv stage's input.
            - ``channel_ls``: ``self.channel_ls``.
        """
        border = (1, 1, 1, 1)
        recorded = []
        widths = []

        stages = (
            (self.conv_layer1, self.bn1, self.max_pool1),
            (self.conv_layer2, self.bn2, self.max_pool2),
            (self.conv_layer3, self.bn3, self.max_pool3),
        )
        for conv, norm, pool in stages:
            widths.append(X.size(-1))
            recorded.append(F.pad(X.detach(), border, mode='constant', value=0))
            X = pool(self.relu(norm(conv(X))))

        flat = X.view(X.size(0), -1)
        recorded.append(flat.detach().transpose(0, 1))

        hidden = self.relu(self.fc1(flat))
        recorded.append(hidden.detach().transpose(0, 1))

        return self.fc2(hidden), recorded, widths, self.channel_ls



class Resnet(nn.Module):
    """Small bias-free residual CNN with a two-layer bias-free MLP head.

    Three groups of three 3x3 convs with 64, 128 and 256 channels. The first
    conv of each group has stride 3, so the spatial size shrinks about 27x
    overall (e.g. 224 -> 75 -> 25 -> 9, or 32 -> 11 -> 4 -> 2). The second and
    third convs of a group form a residual block whose skip connection starts
    at the first conv's output. A global average pool then reduces the final
    256-channel map to the 256 features ``fc1`` expects, for any input size.

    Args:
        in_channels: channels of the input image.
        num_classes: size of the output logits.
        activation: module applied after every BatchNorm / residual sum and
            after ``fc1``.
        seed: seed for weight initialisation. The global CPU RNG state is
            saved beforehand and restored afterwards.
    """

    def __init__(self, in_channels, num_classes, activation, seed = 42):
        super(Resnet, self).__init__()
        rnd_state = torch.random.get_rng_state()
        self.seed = seed
        torch.manual_seed(self.seed)
        self.in_channels = in_channels
        self.channel_ls = [self.in_channels] + [64, 64, 64, 128, 128, 128, 256, 256, 256]

        # Creation order fixes which random draws each layer's init consumes;
        # keep it stable so a given seed reproduces the same weights.
        # Group 1: 64 channels.
        self.conv_layer1 = nn.Conv2d(in_channels=self.in_channels, out_channels=64, kernel_size=3, padding = 1, bias = False, stride = 3)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv_layer2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding = 1, bias = False, stride = 1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv_layer3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding = 1, bias = False, stride = 1)
        self.bn3 = nn.BatchNorm2d(64)

        # Group 2: 128 channels.
        self.conv_layer4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding = 1, bias = False, stride = 3)
        self.bn4 = nn.BatchNorm2d(128)
        self.conv_layer5 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding = 1, bias = False, stride = 1)
        self.bn5 = nn.BatchNorm2d(128)
        self.conv_layer6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding = 1, bias = False, stride = 1)
        self.bn6 = nn.BatchNorm2d(128)

        # Group 3: 256 channels.
        self.conv_layer7 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding = 1, bias = False, stride = 3)
        self.bn7 = nn.BatchNorm2d(256)
        self.conv_layer8 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding = 1, bias = False, stride = 1)
        self.bn8 = nn.BatchNorm2d(256)
        self.conv_layer9 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding = 1, bias = False, stride = 1)
        self.bn9 = nn.BatchNorm2d(256)

        self.relu = activation

        self.fc1 = nn.Linear(256, 1024, bias = False)
        self.fc2 = nn.Linear(1024, num_classes, bias = False)
        torch.random.set_rng_state(rnd_state)

    def forward(self, X):
        """Run the network and record the input of every weight layer.

        Args:
            X: image batch fed to ``conv_layer1`` (batch first).

        Returns:
            A tuple ``(logits, activations, maps, channel_ls)``:

            - ``activations``: for each of the nine convs, its detached input
              zero-padded by one pixel on every side; then the detached inputs
              of ``fc1`` and ``fc2``, transposed to ``[features, batch]``.
            - ``maps``: the last-dimension size of each conv's input.
            - ``channel_ls``: ``self.channel_ls``.
        """
        activations = []
        maps = []
        padding = (1, 1, 1, 1)

        def record(t):
            # Store the map width and a zero-padded, detached copy of the conv input.
            maps.append(t.size(-1))
            activations.append(F.pad(t.detach(), padding, mode='constant', value=0))

        # Group 1
        record(X)
        X = self.relu(self.bn1(self.conv_layer1(X)))
        record(X)
        res = X
        X = self.relu(self.bn2(self.conv_layer2(X)))
        record(X)
        X = self.relu(self.bn3(self.conv_layer3(X)) + res)

        # Group 2
        record(X)
        X = self.relu(self.bn4(self.conv_layer4(X)))
        record(X)
        res = X
        X = self.relu(self.bn5(self.conv_layer5(X)))
        record(X)
        X = self.relu(self.bn6(self.conv_layer6(X)) + res)

        # Group 3
        record(X)
        X = self.relu(self.bn7(self.conv_layer7(X)))
        record(X)
        res = X
        X = self.relu(self.bn8(self.conv_layer8(X)))
        record(X)
        X = self.relu(self.bn9(self.conv_layer9(X)) + res)

        # Global average pool. A fixed 9x9 average pool only fits a 9x9 final
        # map (224x224-class inputs). Adaptive pooling computes the same
        # global mean there (up to float rounding) and works for any input size.
        X = F.adaptive_avg_pool2d(X, 1)
        X = X.view(X.size(0), -1)
        activations.append(X.detach().transpose(0, 1))

        X = self.relu(self.fc1(X))
        activations.append(X.detach().transpose(0, 1))
        X = self.fc2(X)

        return X, activations, maps, self.channel_ls
    
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence tensor.

    Even channels ``2i`` carry ``sin(pos * w_i)`` and odd channels ``2i + 1``
    carry ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: embedding size of the inputs. Odd sizes are supported; the
            last channel then carries a sine term only.
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd channel than frequencies,
        # so only the first d_model // 2 frequencies get a cosine channel.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encoding of each position.

        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]`` (batch
                first); ``seq_len`` must not exceed ``max_len``.
        """
        x = x.transpose(0, 1)  # -> [seq_len, batch_size, embedding_dim]
        x = x + self.pe[:x.size(0)].to(x.device)
        return x.transpose(0, 1)

class CustomMultiheadAttention(nn.Module):
    """Bias-free multi-head self-attention that can record its pre-projection output.

    The fused QKV projection's output is split per head: each head's
    ``3 * head_dim`` slice holds that head's query, key and value, in that order.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        mask: if truthy, apply a causal mask so that position ``i`` attends
            only to positions ``<= i``.
    """

    def __init__(self, embed_dim, num_heads, mask = False):
        super(CustomMultiheadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.mask = mask

        # Single linear layer for the combined QKV projection.
        self.qkv_linear = nn.Linear(embed_dim, embed_dim * 3, bias=False)

        # Output projection applied to the concatenated heads.
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, query, key, value, activations=None, attention_mask = None):
        """Apply self-attention to ``query``.

        Args:
            query: input of shape ``[batch, seq_len, embed_dim]``.
            key, value: accepted but ignored. Keys and values are projected
                from ``query``, so this is self-attention only.
            activations: optional list. If given, the concatenated head
                outputs (the input of ``out_proj``) are appended as a detached
                ``[embed_dim, batch * seq_len]`` tensor.
            attention_mask: optional ``[batch, seq_len]`` tensor. Key
                positions whose value is not 1 are excluded from attention.

        Returns:
            Tensor of shape ``[batch, seq_len, embed_dim]``.
        """
        batch_size = query.size(0)

        # Fused projection, split per head into q, k, v of shape (B, H, S, head_dim).
        qkv = self.qkv_linear(query)
        qkv = qkv.view(batch_size, -1, self.num_heads, 3 * self.head_dim).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=-1)

        # Scaled dot-product scores: (B, H, S, S).
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Masks are boolean, built on the scores' device, and applied with
        # masked_fill. The fill then uses the scores' own dtype, so half/bfloat16
        # scores work (adding a float32 mask in place would fail there).
        if self.mask:
            seq_len = q.size(-2)
            # True strictly above the diagonal, i.e. at future positions.
            causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=attn_weights.device).triu(diagonal=1)
            attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
        if attention_mask is not None:
            # (B, S) -> (B, 1, 1, S): hide masked-out key positions for every head and query.
            padding_mask = (attention_mask != 1).unsqueeze(1).unsqueeze(2).to(attn_weights.device)
            attn_weights = attn_weights.masked_fill(padding_mask, float('-inf'))

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # Concatenate the heads back to (B, S, embed_dim).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)

        # Record the activation before the output projection.
        if activations is not None:
            activations.append(attn_output.detach().reshape(-1, attn_output.shape[-1]).transpose(0, 1))

        return self.out_proj(attn_output)
        
class CustomTransformerEncoderLayer(nn.Module):
    """Post-norm transformer encoder layer with bias-free linears and affine-free LayerNorms.

    ``forward`` appends to ``activations`` the detached input of every linear
    map it applies, each flattened to ``[features, tokens]``, in this order:
    the layer input, the concatenated attention heads (appended by the
    attention module), the post-attention hidden state, and the activated
    feed-forward hidden state.

    Args:
        d_model: model width.
        num_heads: number of attention heads.
        dim_feedforward: hidden width of the feed-forward block.
        activation: module applied between the two feed-forward linears.
        mask: forwarded to the attention module as its causal-mask flag.
    """

    def __init__(self, d_model, num_heads, dim_feedforward, activation=nn.ReLU(), mask=None):
        super(CustomTransformerEncoderLayer, self).__init__()
        self.self_attn = CustomMultiheadAttention(d_model, num_heads, mask)

        # Feed-forward block (no biases).
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.activation = activation
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)

        # LayerNorms without learnable scale/shift.
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)

    @staticmethod
    def _as_columns(t):
        """Detach ``t`` and flatten it to ``[features, tokens]``."""
        return t.detach().reshape(-1, t.shape[-1]).transpose(0, 1)

    def forward(self, src, activations, attention_mask=None):
        """Run attention and feed-forward sub-blocks, each followed by residual add + LayerNorm."""
        activations.append(self._as_columns(src))
        attended = self.self_attn(src, src, src, activations, attention_mask)
        hidden = self.norm1(src + attended)

        activations.append(self._as_columns(hidden))
        ff_hidden = self.activation(self.linear1(hidden))
        activations.append(self._as_columns(ff_hidden))

        return self.norm2(hidden + self.linear2(ff_hidden))

class Transformer(nn.Module):
    """Transformer encoder classifier over real-valued token features.

    Pipeline: bias-free linear embedding -> sinusoidal position encoding ->
    ``num_layers`` encoder layers -> mean over the sequence -> bias-free
    linear classification head.

    Args:
        input_dim: feature size of each input token.
        d_model: model width.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers.
        dim_feedforward: hidden width of each feed-forward block.
        num_classes: number of output logits.
        activation: feed-forward activation (the same instance in every layer).
        max_len: longest supported sequence length.
        seed: seed for weight initialisation. The global CPU RNG state is
            restored afterwards.
    """

    def __init__(self, input_dim, d_model, num_heads, num_layers, dim_feedforward, num_classes, activation=nn.ReLU(), max_len=5000, seed=42):
        super(Transformer, self).__init__()
        saved_rng = torch.random.get_rng_state()
        torch.manual_seed(seed)
        self.linear_embedding = nn.Linear(input_dim, d_model, bias=False)
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList(
            CustomTransformerEncoderLayer(
                d_model=d_model,
                num_heads=num_heads,
                dim_feedforward=dim_feedforward,
                activation=activation,
            )
            for _ in range(num_layers)
        )
        # Final norm is created but not used in forward.
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.classification_head = nn.Linear(d_model, num_classes, bias=False)
        torch.random.set_rng_state(saved_rng)

    def forward(self, X):
        """Classify ``X``; return ``(logits, activations)``.

        ``activations`` begins with the detached embedding input flattened to
        ``[input_dim, tokens]``. It then holds what each encoder layer
        appends, and ends with the pooled features as ``[d_model, batch]``.
        """
        activations = [X.detach().reshape(-1, X.shape[-1]).transpose(0, 1)]

        hidden = self.positional_encoding(self.linear_embedding(X))
        for block in self.layers:
            hidden = block(hidden, activations)

        pooled = hidden.mean(dim=1)  # average over the sequence dimension
        activations.append(pooled.detach().transpose(0, 1))

        return self.classification_head(pooled), activations

class Transformer_Regression(nn.Module):
    """Transformer encoder that predicts an output vector for every token.

    Pipeline: bias-free linear embedding -> sinusoidal position encoding ->
    ``num_layers`` encoder layers -> bias-free linear head applied at every
    position (no pooling over the sequence).

    Args:
        input_dim: feature size of each input token.
        d_model: model width.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers.
        dim_feedforward: hidden width of each feed-forward block.
        num_classes: output size per token.
        activation: feed-forward activation (the same instance in every layer).
        max_len: longest supported sequence length.
        seed: seed for weight initialisation. The global CPU RNG state is
            restored afterwards.
    """

    def __init__(self, input_dim, d_model, num_heads, num_layers, dim_feedforward, num_classes, activation=nn.ReLU(), max_len=5000, seed = 42):
        # Must name this class: super(Transformer, self) raised TypeError
        # because Transformer_Regression is not a subclass of Transformer.
        super(Transformer_Regression, self).__init__()
        rnd_state = torch.random.get_rng_state()
        torch.manual_seed(seed)
        self.linear_embedding = nn.Linear(input_dim, d_model, bias = False)
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            CustomTransformerEncoderLayer(
                d_model=d_model,
                num_heads=num_heads,
                dim_feedforward=dim_feedforward,
                activation=activation
            )
            for _ in range(num_layers)
        ])
        # Final norm is created but not used in forward.
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.classification_head = nn.Linear(d_model, num_classes, bias=False)  # per-token output head, no bias
        torch.random.set_rng_state(rnd_state)

    def forward(self, X):
        """Predict an output per token; return ``(outputs, activations)``.

        Args:
            X: token features with last dimension ``input_dim`` (batch first,
                matching the positional encoding's layout).

        Returns:
            ``outputs`` has the head applied at every position. ``activations``
            begins with the detached embedding input flattened to
            ``[input_dim, tokens]``, then holds what each encoder layer
            appends, and ends with the head's input flattened to
            ``[d_model, tokens]``.
        """
        activations = []

        # Record activation before the linear embedding layer.
        activations.append(X.detach().reshape(-1, X.shape[-1]).transpose(0, 1))

        X = self.linear_embedding(X)
        X = self.positional_encoding(X)

        # Each encoder layer appends its own recorded activations.
        for layer in self.layers:
            X = layer(X, activations)

        # Record activation before the output head (no sequence pooling here).
        activations.append(X.detach().reshape(-1, X.shape[-1]).transpose(0, 1))

        logits = self.classification_head(X)

        return logits, activations

class Transformer_LM(nn.Module):
    """Causal transformer language model over token ids.

    Pipeline: ``nn.Embedding`` -> sinusoidal position encoding ->
    ``num_layers`` causally masked encoder layers -> bias-free linear head
    producing ``vocab_size`` logits at every position.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        d_model: model width.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers.
        dim_feedforward: hidden width of each feed-forward block.
        num_classes: accepted but unused; the head always outputs
            ``vocab_size`` logits.
        activation: feed-forward activation (the same instance in every layer).
        max_len: longest supported sequence length.
        seed: seed for weight initialisation. The global CPU RNG state is
            restored afterwards.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, dim_feedforward, num_classes, activation=nn.ReLU(), max_len=5000, seed=42):
        super(Transformer_LM, self).__init__()
        saved_rng = torch.random.get_rng_state()
        torch.manual_seed(seed)
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList(
            CustomTransformerEncoderLayer(
                d_model=d_model,
                num_heads=num_heads,
                dim_feedforward=dim_feedforward,
                activation=activation,
                mask=True,
            )
            for _ in range(num_layers)
        )
        # Final norm is created but not used in forward.
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.classification_head = nn.Linear(d_model, vocab_size, bias=False)
        torch.random.set_rng_state(saved_rng)

    def forward(self, X):
        """Return ``(logits, activations)`` for token ids ``X``.

        Nothing is recorded for the embedding lookup. ``activations`` holds
        what each encoder layer appends, followed by the head's input
        flattened to ``[d_model, tokens]``.
        """
        activations = []
        hidden = self.positional_encoding(self.embedding(X))
        for block in self.layers:
            hidden = block(hidden, activations)

        activations.append(hidden.detach().reshape(-1, hidden.shape[-1]).transpose(0, 1))
        return self.classification_head(hidden), activations


class Transformer_Classification_LM(nn.Module):
    """Bidirectional transformer classifier over token ids with a padding mask.

    Pipeline: ``nn.Embedding`` -> sinusoidal position encoding ->
    ``num_layers`` unmasked encoder layers that honour ``attention_mask`` ->
    mean over the sequence -> bias-free linear head with ``num_classes`` logits.

    Args:
        vocab_size: number of token ids.
        d_model: model width.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers.
        dim_feedforward: hidden width of each feed-forward block.
        num_classes: number of output logits.
        activation: feed-forward activation (the same instance in every layer).
        max_len: longest supported sequence length.
        seed: seed for weight initialisation. The global CPU RNG state is
            restored afterwards.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, dim_feedforward, num_classes, activation=nn.ReLU(), max_len=5000, seed=42):
        super(Transformer_Classification_LM, self).__init__()
        saved_rng = torch.random.get_rng_state()
        torch.manual_seed(seed)
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList(
            CustomTransformerEncoderLayer(
                d_model=d_model,
                num_heads=num_heads,
                dim_feedforward=dim_feedforward,
                activation=activation,
                mask=False,
            )
            for _ in range(num_layers)
        )
        # Final norm is created but not used in forward.
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.classification_head = nn.Linear(d_model, num_classes, bias=False)
        torch.random.set_rng_state(saved_rng)

    def forward(self, X, attention_mask):
        """Return ``(logits, activations)`` for token ids ``X``.

        ``attention_mask`` is passed to every encoder layer (may be ``None``).
        ``activations`` holds what each encoder layer appends, followed by the
        pooled features as ``[d_model, batch]``.
        """
        activations = []
        hidden = self.positional_encoding(self.embedding(X))
        for block in self.layers:
            hidden = block(hidden, activations, attention_mask)

        # NOTE(review): this mean covers every position, including those
        # excluded by attention_mask, so padding length affects the pooled
        # features. Confirm this is intended.
        pooled = hidden.mean(dim=1)
        activations.append(pooled.detach().reshape(-1, pooled.shape[-1]).transpose(0, 1))
        return self.classification_head(pooled), activations