import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from transformers import GPT2LMHeadModel, GPT2Config
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# GPT-2
# --------------------------------------------------------------------------
class GPT2(torch.nn.Module):
    """Randomly initialised GPT-2 language model with an optional training loss.

    Builds a ``GPT2Config`` from the given hyper-parameters and wraps a
    ``GPT2LMHeadModel`` created from that config (no pretrained weights are
    loaded).
    """

    def __init__(self, d_model, n_layers, n_heads, vocab_size, max_len):
        """Create the underlying GPT-2 model.

        Args:
            d_model: Embedding / hidden width (``n_embd``).
            n_layers: Number of transformer blocks (``n_layer``).
            n_heads: Number of attention heads per block (``n_head``).
            vocab_size: Number of token ids the model can emit.
            max_len: Maximum number of positions (``n_positions``).
        """
        super().__init__()
        self.config = GPT2Config(
            vocab_size=vocab_size,
            n_embd=d_model,
            n_layer=n_layers,
            n_head=n_heads,
            n_positions=max_len,
        )
        self.transformer = GPT2LMHeadModel(self.config)

    def forward(self, x, targets=None):
        """Run the model and optionally compute the token-level loss.

        Args:
            x: Token ids, ``[batch_size, src_len]``.
            targets: Optional target token ids with the same shape as ``x``.

        Returns:
            A ``(logits, loss)`` tuple. ``logits`` is
            ``[batch_size, src_len, vocab_size]``; ``loss`` is the mean
            cross-entropy over all positions, or ``None`` when ``targets``
            is not given.
        """
        # x: [batch_size, src_len]
        logits = self.transformer(x).logits
        # logits: [batch_size, src_len, vocab_size]
        if targets is None:
            return logits, None

        # reshape (not view) so that non-contiguous targets, e.g. a shifted
        # slice such as tokens[:, 1:], are flattened instead of raising.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_targets = targets.reshape(-1)
        loss = F.cross_entropy(flat_logits, flat_targets)
        return logits, loss

# Mamba
# --------------------------------------------------------------------------
class MAMBA(nn.Module):
    """Randomly initialised Mamba language model with an optional training loss.

    Builds a ``MambaConfig`` from the given hyper-parameters and wraps a
    ``MambaLMHeadModel`` created from that config.
    """

    def __init__(self, d_model, n_layers, vocab_size, d_state=16, d_conv=4, expand=2):
        """Create the underlying Mamba model.

        Args:
            d_model: Model (hidden) width.
            n_layers: Number of stacked layers (``n_layer``).
            vocab_size: Number of token ids in the vocabulary.
            d_state: Passed to the mixer through ``ssm_cfg['d_state']``.
            d_conv: Passed to the mixer through ``ssm_cfg['d_conv']``.
            expand: Passed to the mixer through ``ssm_cfg['expand']``.
        """
        super().__init__()
        self.config = MambaConfig(
            d_model=d_model,
            n_layer=n_layers,
            vocab_size=vocab_size,
            ssm_cfg={'d_state': d_state, 'd_conv': d_conv, 'expand': expand},
        )
        self.model = MambaLMHeadModel(config=self.config)

    def forward(self, x, targets=None):
        """Run the model and optionally compute the token-level loss.

        Args:
            x: Token ids, ``[batch_size, seq_len]``.
            targets: Optional target token ids with the same shape as ``x``.

        Returns:
            A ``(logits, loss)`` tuple. ``logits`` is
            ``[batch_size, seq_len, vocab_size]``; ``loss`` is the mean
            cross-entropy over all positions, or ``None`` when ``targets``
            is not given.
        """
        # x: [batch_size, seq_len]
        logits = self.model(x).logits
        # logits: [batch_size, seq_len, vocab_size]
        if targets is None:
            return logits, None

        # The class count is read from the logits rather than the config
        # because mamba_ssm may pad the vocabulary dimension
        # (MambaConfig.pad_vocab_size_multiple) -- NOTE(review): confirm.
        # reshape (not view) so that non-contiguous targets, e.g. a shifted
        # slice such as tokens[:, 1:], are flattened instead of raising.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_targets = targets.reshape(-1)
        loss = F.cross_entropy(flat_logits, flat_targets)
        return logits, loss

# Mamba-2
# --------------------------------------------------------------------------
class MAMBA2(nn.Module):
    """Randomly initialised Mamba-2 language model with an optional training loss.

    Identical in structure to a plain Mamba wrapper, except that ``ssm_cfg``
    carries ``'layer': 'Mamba2'`` so that ``MambaLMHeadModel`` is asked for
    Mamba-2 mixer blocks.
    """

    def __init__(self, d_model, n_layers, vocab_size, d_state=16, d_conv=4, expand=2):
        """Create the underlying Mamba-2 model.

        Args:
            d_model: Model (hidden) width.
            n_layers: Number of stacked layers (``n_layer``).
            vocab_size: Number of token ids in the vocabulary.
            d_state: Passed to the mixer through ``ssm_cfg['d_state']``.
            d_conv: Passed to the mixer through ``ssm_cfg['d_conv']``.
            expand: Passed to the mixer through ``ssm_cfg['expand']``.
        """
        super().__init__()
        self.config = MambaConfig(
            d_model=d_model,
            n_layer=n_layers,
            vocab_size=vocab_size,
            ssm_cfg={
                'd_state': d_state,
                'd_conv': d_conv,
                'expand': expand,
                'layer': 'Mamba2',
            },
        )
        self.model = MambaLMHeadModel(config=self.config)

    def forward(self, x, targets=None):
        """Run the model and optionally compute the token-level loss.

        Args:
            x: Token ids, ``[batch_size, seq_len]``.
            targets: Optional target token ids with the same shape as ``x``.

        Returns:
            A ``(logits, loss)`` tuple. ``logits`` is
            ``[batch_size, seq_len, vocab_size]``; ``loss`` is the mean
            cross-entropy over all positions, or ``None`` when ``targets``
            is not given.
        """
        # x: [batch_size, seq_len]
        logits = self.model(x).logits
        # logits: [batch_size, seq_len, vocab_size]
        if targets is None:
            return logits, None

        # The class count is read from the logits rather than the config
        # because mamba_ssm may pad the vocabulary dimension
        # (MambaConfig.pad_vocab_size_multiple) -- NOTE(review): confirm.
        # reshape (not view) so that non-contiguous targets, e.g. a shifted
        # slice such as tokens[:, 1:], are flattened instead of raising.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_targets = targets.reshape(-1)
        loss = F.cross_entropy(flat_logits, flat_targets)
        return logits, loss


def print_model(model):
    """Print a size summary of ``model``.

    Emits, in order: the number of trainable parameters, the module's repr,
    one line per named parameter (name, shape, element count) and the total
    element count over all named parameters.
    """
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("model_size: ", trainable)
    print("Model Architecture:")
    print(model)
    print("\nParameter Count for Each Part:")
    total_params = 0
    for name, param in model.named_parameters():
        num_params = param.numel()
        total_params += num_params
        print(f"{name}: {param.shape} -> {num_params} parameters")
    print(f"\nTotal Parameters: {total_params}")


def main():
    """Smoke-test a model by fitting it to two toy sentences for 20 steps."""
    max_len = 1024
    vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'S': 5, 'E': 6}
    # Each entry is [input sentence, target sentence]; the target is the
    # input shifted left by one token and padded with 'P'.
    data_root = [
        ['S i want a beer E', 'i want a beer E P'],
        ['S want a beer i E', 'want a beer i E P'],
    ]
    vocab_size = len(vocab)

    # size -> (n_layers, d_model, n_heads); head width is d_model // n_heads.
    presets = {
        'mini': (12, 512, 8),
        'small': (12, 768, 12),
        'medium': (24, 1024, 16),
        'large': (24, 1536, 16),
    }
    size = 'mini'
    # size = 'small'
    # size = 'medium'
    # size = 'large'
    if size not in presets:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError(
            f"unknown model size {size!r}; expected one of {sorted(presets)}"
        )
    n_layers, d_model, n_heads = presets[size]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = GPT2(d_model, n_layers, n_heads, vocab_size, max_len).to(device)
    # model = MAMBA(d_model, n_layers, vocab_size).to(device)
    # model = MAMBA2(d_model, n_layers, vocab_size).to(device)

    print_model(model)

    # Tokenise by whitespace and map every token to its vocabulary id.
    input_batch = [[vocab[tok] for tok in src.split()] for src, _ in data_root]
    target_batch = [[vocab[tok] for tok in tgt.split()] for _, tgt in data_root]
    inputs = torch.tensor(input_batch, dtype=torch.long, device=device)
    targets = torch.tensor(target_batch, dtype=torch.long, device=device)

    # The model computes its own cross-entropy loss, so no separate
    # criterion object is needed.
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(20):
        optimizer.zero_grad()
        _, loss = model(inputs, targets=targets)
        print(f"Epoch: {epoch + 1:04d} cost = {loss.item():.6f}")
        loss.backward()
        optimizer.step()


if __name__ == '__main__':
    main()