import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import numpy as np
from typing import Dict, List, Tuple, Optional


class SimpleCNN(nn.Module):
    """Small convolutional image classifier split into embed -> body -> head.

    The first fully connected layer expects ``128 * 8 * 8`` features. With the
    two 2x2 max-pools in ``body`` (and all convolutions preserving spatial
    size), that corresponds to 3-channel 32x32 inputs.

    Args:
        num_classes: Number of output logits produced by the head. Defaults
            to 10 (e.g. CIFAR-10), which matches the previous hard-coded value.
    """

    def __init__(self, num_classes: int = 10):
        super(SimpleCNN, self).__init__()
        # Stem: lift RGB input to 32 feature maps at full resolution.
        self.embed = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
        )
        # Feature extractor. Module order is kept stable so that state_dict
        # keys (body.0, body.4, body.7, body.12, ...) stay checkpoint-compatible.
        self.body = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            # 128 channels x 8 x 8 spatial positions after two 2x2 pools.
            nn.Linear(128 * 8 * 8, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        # Classifier: one logit per class.
        self.head = nn.Sequential(
            nn.Linear(512, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run embed -> body -> head and return the raw class logits.

        Args:
            x: Image batch; the Conv2d stem requires 3 input channels.

        Returns:
            Logits with ``num_classes`` entries per sample.
        """
        x = self.embed(x)
        x = self.body(x)
        x = self.head(x)
        return x


class SimpleCNNBinClass(nn.Module):
    """Two-logit variant of the small embed -> body -> head CNN classifier.

    The first fully connected layer expects ``128 * 8 * 8`` features. With the
    two 2x2 max-pools in ``body`` (and all convolutions preserving spatial
    size), that corresponds to 3-channel 32x32 inputs.

    Args:
        num_classes: Number of output logits produced by the head. Defaults
            to 2 (binary classification), the previous hard-coded value.
    """

    def __init__(self, num_classes: int = 2):
        super(SimpleCNNBinClass, self).__init__()
        # Stem: lift RGB input to 32 feature maps at full resolution.
        self.embed = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
        )
        # Feature extractor. Module order is kept stable so that state_dict
        # keys stay checkpoint-compatible.
        self.body = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            # 128 channels x 8 x 8 spatial positions after two 2x2 pools.
            nn.Linear(128 * 8 * 8, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        # Classifier: two logits by default (binary classification).
        self.head = nn.Sequential(
            nn.Linear(512, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run embed -> body -> head and return the raw class logits.

        Args:
            x: Image batch; the Conv2d stem requires 3 input channels.

        Returns:
            Logits with ``num_classes`` entries per sample.
        """
        x = self.embed(x)
        x = self.body(x)
        x = self.head(x)
        return x

class ResNet18_32x32(nn.Module):
    """ResNet-18 adapted to small (32x32) images, split into embed/body/head.

    Args:
        num_classes: Number of output logits produced by the head. Defaults
            to 10 (e.g. CIFAR-10), the previous hard-coded value.
    """

    def __init__(self, num_classes: int = 10):
        super(ResNet18_32x32, self).__init__()
        # Randomly initialised torchvision ResNet-18 used only as a source of
        # layers. Its own conv1, maxpool and fc are not reused, so the
        # num_classes passed here does not affect this model's output size.
        base_model = torchvision.models.resnet18(weights=None, num_classes=10)

        # 1. Embed:
        # The original ResNet's first conv layer is not ideal for 32x32
        # images. We replace it with a 3x3, stride-1 convolution.
        self.embed = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        )

        # 2. Body:
        # We take the rest of the ResNet-18 body, from the first batch norm
        # up to the final average pooling layer.
        self.body = nn.Sequential(
            base_model.bn1,
            base_model.relu,
            # The original maxpool (base_model.maxpool) is intentionally
            # left out, so layer1 receives the stem's output without the
            # extra downsampling.
            base_model.layer1,
            base_model.layer2,
            base_model.layer3,
            base_model.layer4,
            base_model.avgpool,
            nn.Flatten()  # Flatten the pooled features for the linear layer
        )

        # 3. Head:
        # The classifier head. The number of input features for ResNet18's
        # fully connected layer is 512.
        self.head = nn.Sequential(
            nn.Linear(512, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass follows the embed -> body -> head structure.

        Args:
            x: Image batch; the Conv2d stem requires 3 input channels.

        Returns:
            Raw class logits with ``num_classes`` entries per sample.
        """
        x = self.embed(x)
        x = self.body(x)
        x = self.head(x)
        return x
    
class ResNet18_32x32_BinClass(nn.Module):
    """Two-logit ResNet-18 for small (32x32) images, split into embed/body/head.

    Args:
        num_classes: Number of output logits produced by the head. Defaults
            to 2 (binary classification), the previous hard-coded value.
    """

    def __init__(self, num_classes: int = 2):
        super(ResNet18_32x32_BinClass, self).__init__()
        # Randomly initialised torchvision ResNet-18 used only as a source of
        # layers. Its own conv1, maxpool and fc are not reused, so the
        # num_classes passed here does not affect this model's output size.
        base_model = torchvision.models.resnet18(weights=None, num_classes=10)

        # Embed: 3x3, stride-1 stem in place of the stock ResNet conv1.
        self.embed = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        )

        # Body: bn1 through avgpool of the base model. base_model.maxpool is
        # deliberately not included.
        self.body = nn.Sequential(
            base_model.bn1,
            base_model.relu,
            base_model.layer1,
            base_model.layer2,
            base_model.layer3,
            base_model.layer4,
            base_model.avgpool,
            nn.Flatten()
        )

        # Head: linear classifier over the 512 pooled features.
        self.head = nn.Sequential(
            nn.Linear(512, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run embed -> body -> head and return the raw class logits.

        Args:
            x: Image batch; the Conv2d stem requires 3 input channels.

        Returns:
            Logits with ``num_classes`` entries per sample.
        """
        x = self.embed(x)
        x = self.body(x)
        x = self.head(x)
        return x


class CharacterLSTM(nn.Module):
    """Character-level LSTM that maps token indices to per-position vocabulary logits.

    Pipeline: embedding -> dropout -> (multi-layer) LSTM -> LayerNorm ->
    dropout -> linear projection to ``vocab_size``.
    """

    def __init__(
        self,
        vocab_size=256,
        embedding_dim=64,
        hidden_dim=256,
        num_layers=2,
        dropout_rate=0.2,
        bidirectional=False
    ):
        """
        Args:
            vocab_size: Size of character vocabulary (default 256 for extended ASCII)
            embedding_dim: Dimension of character embeddings
            hidden_dim: Hidden dimension of LSTM layers
            num_layers: Number of LSTM layers
            dropout_rate: Dropout rate for regularization
            bidirectional: Whether to use bidirectional LSTM
        """
        super(CharacterLSTM, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0
        )
        nn.init.xavier_uniform_(self.embedding.weight)
        # The Xavier init above overwrites the all-zero row nn.Embedding sets
        # up for padding_idx. That row receives no gradient, so without this
        # reset the padding token would keep a frozen random embedding.
        with torch.no_grad():
            self.embedding.weight[self.embedding.padding_idx].zero_()

        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # Inter-layer dropout is only meaningful with more than one layer.
            dropout=dropout_rate if num_layers > 1 else 0,
            bidirectional=bidirectional
        )

        self.dropout = nn.Dropout(dropout_rate)

        self.output_projection = nn.Linear(
            hidden_dim * self.num_directions,
            vocab_size
        )

        self.layer_norm = nn.LayerNorm(hidden_dim * self.num_directions)

        self._init_lstm_weights()

    def _init_lstm_weights(self):
        """Initialise LSTM parameters.

        Input-to-hidden weights use Xavier-uniform, hidden-to-hidden weights
        use orthogonal init, and biases are zero except for the second
        quarter of each bias vector, which is set to 1.
        """
        for name, param in self.lstm.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)
                n = param.size(0)
                # PyTorch packs LSTM gates as (input, forget, cell, output),
                # so the second quarter is the forget gate. Both bias_ih and
                # bias_hh match this branch, so each gets a forget bias of 1.
                with torch.no_grad():
                    param[n // 4:n // 2].fill_(1)

    def forward(self, input_ids, hidden=None):
        """
        Args:
            input_ids: Input character indices [batch_size, seq_length]
            hidden: Optional initial hidden state tuple (h_0, c_0); zeros are
                used when omitted

        Returns:
            logits: Output logits [batch_size, seq_length, vocab_size]
            hidden: Final hidden state tuple (h_n, c_n)
        """
        # Unpacking also enforces that input_ids is 2-D.
        batch_size, _ = input_ids.size()

        if hidden is None:
            hidden = self.init_hidden(batch_size, input_ids.device)

        embedded = self.embedding(input_ids)  # [batch_size, seq_length, embedding_dim]
        embedded = self.dropout(embedded)

        lstm_out, hidden = self.lstm(embedded, hidden)
        # lstm_out: [batch_size, seq_length, hidden_dim * num_directions]

        lstm_out = self.layer_norm(lstm_out)

        lstm_out = self.dropout(lstm_out)

        logits = self.output_projection(lstm_out)
        # logits: [batch_size, seq_length, vocab_size]

        return logits, hidden

    def init_hidden(self, batch_size, device):
        """
        Args:
            batch_size: Batch size
            device: Device to create tensors on

        Returns:
            Initial hidden state tuple (h_0, c_0), each of shape
            [num_layers * num_directions, batch_size, hidden_dim]
        """
        shape = (
            self.num_layers * self.num_directions,
            batch_size,
            self.hidden_dim,
        )
        # Match the LSTM's parameter dtype so the model keeps working after
        # .double() / .half(); a float32 state would otherwise be mismatched.
        dtype = self.lstm.weight_ih_l0.dtype
        h_0 = torch.zeros(shape, device=device, dtype=dtype)
        c_0 = torch.zeros(shape, device=device, dtype=dtype)
        return (h_0, c_0)
    

class CharacterTransformer(nn.Module):
    """Causally masked Transformer encoder over character tokens.

    Token and learned position embeddings are summed, passed through a
    pre-norm ``nn.TransformerEncoder`` with an additive causal mask, layer
    normalised once more and projected to ``vocab_size`` logits.
    """

    def __init__(self, vocab_size=256, embed_dim=128, num_heads=8,
                 num_layers=6, ff_dim=512, max_seq_len=256, dropout=0.2):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        # Input side: token + learned absolute position embeddings.
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self.dropout = nn.Dropout(dropout)

        # Encoder stack (pre-norm blocks followed by a final LayerNorm).
        layer_config = dict(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation='relu',
            batch_first=True,
            norm_first=True,
        )
        block = nn.TransformerEncoderLayer(**layer_config)
        stack_norm = nn.LayerNorm(embed_dim)
        self.transformer = nn.TransformerEncoder(
            block, num_layers=num_layers, norm=stack_norm
        )

        # Output side: an additional LayerNorm, then the vocabulary projection.
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.output_projection = nn.Linear(embed_dim, vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise embeddings, output projection and encoder matrices."""
        # N(0, 0.02) for both embedding tables and the projection weight.
        for module in (self.token_embed, self.pos_embed, self.output_projection):
            nn.init.normal_(module.weight, mean=0, std=0.02)
        nn.init.zeros_(self.output_projection.bias)

        # Xavier-uniform for every encoder parameter with more than one
        # dimension; 1-D parameters (biases, norm scales) are left untouched.
        for weight in self.transformer.parameters():
            if weight.dim() <= 1:
                continue
            nn.init.xavier_uniform_(weight)

    def forward(self, input_ids):
        """
        Args:
            input_ids: [batch_size, seq_len]

        Returns:
            A tuple ``(logits, None)`` where logits is
            [batch_size, seq_len, vocab_size]; the second element is always
            None.

        Raises:
            ValueError: If seq_len is larger than ``max_seq_len``.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.max_seq_len:
            raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}")

        device = input_ids.device
        attn_mask = self._generate_causal_mask(seq_len, device)

        # [1, seq_len] position ids broadcast across the batch when added.
        position_ids = torch.arange(seq_len, device=device)[None, :]
        hidden = self.token_embed(input_ids) + self.pos_embed(position_ids)
        hidden = self.dropout(hidden)  # [batch_size, seq_len, embed_dim]

        hidden = self.transformer(hidden, mask=attn_mask)

        normed = self.layer_norm(hidden)
        return self.output_projection(normed), None  # [batch_size, seq_len, vocab_size]

    def _generate_causal_mask(self, sz, device):
        """Return an [sz, sz] additive mask: -inf above the diagonal, 0 elsewhere."""
        blocked = torch.full((sz, sz), float('-inf'), device=device)
        return blocked.triu(diagonal=1)


# Registry from model name to model class. Only the four image classifiers
# are registered here; CharacterLSTM and CharacterTransformer are not.
# Note the last key has no underscore before "BinClass", unlike its class name.
model_map = dict(
    SimpleCNN=SimpleCNN,
    SimpleCNNBinClass=SimpleCNNBinClass,
    ResNet18_32x32=ResNet18_32x32,
    ResNet18_32x32BinClass=ResNet18_32x32_BinClass,
)