import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import argparse
from tqdm.auto import tqdm
import os

# Command-line options for the training run.
parser = argparse.ArgumentParser(description='Train a Sequential Model')
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate')
parser.add_argument('--epochs', type=int, default=10, help='Training epochs')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
parser.add_argument('--d_model', type=int, default=128, help='Model dimension')
parser.add_argument('--n_layers', type=int, default=4, help='Number of layers')
parser.add_argument('--dropout', type=float, default=0.1, help='Dropout')
args = parser.parse_args()

# Use the GPU whenever torch can see one, otherwise run on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Ground-truth weight vector shared by every data split (labels are
# ``magnitudes @ theta``).
# NOTE(review): no RNG seed is set, so theta and the data differ between runs.
theta = torch.randn(10)

# Define dataset class
class SequenceDataset(data.Dataset):
    """Synthetic regression dataset built from sums of cosines.

    Each sample is a single-channel sequence

        x[t] = sum_k m_k * cos(f_k * t / 100),   t = 0 .. sequence_length - 1

    where the magnitudes ``m`` are drawn from a standard normal distribution
    (one row per sample) and the frequencies ``f_k`` default to the primes
    3..31. The regression target is the linear projection ``m @ theta``.

    Items are ``(x, y)`` with ``x`` of shape (1, sequence_length) and ``y`` of
    shape (1,).

    Args:
        num_samples: number of sequences to generate.
        theta: weight vector of length ``len(frequencies)`` used for labels.
        sequence_length: number of time steps per sequence.
        frequencies: optional sequence of frequencies; defaults to
            ``DEFAULT_FREQUENCIES``.
    """

    DEFAULT_FREQUENCIES = (3, 5, 7, 11, 13, 17, 19, 23, 29, 31)

    def __init__(self, num_samples, theta, sequence_length=1000, frequencies=None):
        if frequencies is None:
            frequencies = self.DEFAULT_FREQUENCIES
        self.num_samples = num_samples
        self.sequence_length = sequence_length

        freqs = torch.as_tensor(frequencies)
        # One magnitude per frequency per sample (10 with the defaults, so the
        # random draw matches the original fixed-size version).
        self.magnitudes = torch.randn(num_samples, len(freqs))

        # Cosine basis of shape (n_freq, L), computed once rather than
        # rebuilding torch.arange for every term.
        steps = torch.arange(sequence_length)
        basis = torch.cos(freqs.unsqueeze(-1) * steps.unsqueeze(0) / 100)

        # (N, n_freq) @ (n_freq, L) -> (N, L), plus a channel axis -> (N, 1, L).
        # A single matmul avoids materializing one (N, L) temporary per term.
        self.data = (self.magnitudes @ basis).unsqueeze(1)
        self.labels = (self.magnitudes @ theta).unsqueeze(-1)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Three independently sampled splits sharing the same ground-truth ``theta``.
# They are built in train -> val -> test order, which fixes which random
# draws each split receives.
train_dataset = SequenceDataset(100000, theta)
val_dataset = SequenceDataset(10000, theta)
test_dataset = SequenceDataset(20000, theta)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = data.DataLoader(train_dataset, args.batch_size, shuffle=True)
val_loader, test_loader = (
    data.DataLoader(split, args.batch_size, shuffle=False)
    for split in (val_dataset, test_dataset)
)

# Define Model
class SSM(nn.Module):
    """Sequence-to-vector model: linear encoder, a stack of ``Model`` layers, linear decoder.

    The encoder projects the channel dimension at every time step. After the
    layer stack, only the features of the final time step are decoded.

    NOTE(review): ``Model`` is not defined or imported in this file view; it
    must be available at module level before ``SSM`` is instantiated.
    """

    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=256,
        n_layers=4,
        dropout=0.2,
        prenorm=False,
    ):
        """
        Args:
            d_input: number of input channels per time step.
            d_output: size of the decoded output vector.
            d_model: hidden width, passed to every ``Model`` layer.
            n_layers: number of stacked ``Model`` layers.
            dropout: currently unused (kept for interface compatibility).
            prenorm: currently unused (kept for interface compatibility).
        """
        super().__init__()

        # Channel projection applied independently at each time step.
        self.encoder = nn.Linear(d_input, d_model)

        # Layers are applied one after another. No residual connection,
        # normalization or dropout is added at this level; any such logic
        # would have to live inside ``Model`` (TODO confirm).
        self.s4_layers = nn.ModuleList(
            Model(d_model=d_model) for _ in range(n_layers)
        )

        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x, savedata=False, save_prefix=''):
        """Decode the last time step of ``x`` (shape (B, d_input, L)).

        If each layer preserves (B, L, d_model), the result has shape
        (B, d_output). ``savedata`` and ``save_prefix`` are currently unused.
        """
        x = x.transpose(-1, -2)  # (B, d_input, L) -> (B, L, d_input)
        x = self.encoder(x)      # (B, L, d_input) -> (B, L, d_model)
        for layer in self.s4_layers:
            # assumes each layer maps (B, L, d_model) -> (B, L, d_model) - TODO confirm
            x = layer(x)
        # Select the final time step directly: (B, L, d_model) -> (B, d_model).
        # This equals the old transpose(-1, -2)[..., -1] without the extra transpose.
        x = x[..., -1, :]
        return self.decoder(x)

print('==> Building model..')
# NOTE(review): the architecture is hard-coded here. The --d_model, --n_layers
# and --dropout command-line flags are parsed but never used; confirm this is
# intentional.
model = SSM(d_input=1, d_output=1, d_model=2, n_layers=1, dropout=0).to(device)

def setup_optimizer(model, lr, weight_decay, epochs):
    """Build an AdamW optimizer and a cosine LR schedule for ``model``.

    Parameters without an ``_optim`` attribute share the first param group,
    which uses ``lr`` and ``weight_decay``. Parameters that carry an
    ``_optim`` dict are placed in one extra group per distinct dict. Those
    hyperparameters override the optimizer defaults for that group.

    Args:
        model: module whose ``parameters()`` are optimized.
        lr: base learning rate.
        weight_decay: AdamW weight decay for the general group. It is also
            the default for special groups whose ``_optim`` dict does not
            set it.
        epochs: ``T_max`` passed to ``CosineAnnealingLR``.

    Returns:
        ``(optimizer, scheduler)``.
    """
    all_parameters = list(model.parameters())

    # Parameters without special hyperparameters.
    general_params = [p for p in all_parameters if not hasattr(p, "_optim")]

    # Deduplicate the special hyperparameter dicts (frozensets are hashable),
    # then sort them so the group order is stable.
    hps = [p._optim for p in all_parameters if hasattr(p, "_optim")]
    unique_hps = [
        dict(s) for s in sorted(dict.fromkeys(frozenset(hp.items()) for hp in hps))
    ]

    # Build every group up front. This still works when all parameters carry
    # ``_optim``, where AdamW(general_params=[]) would raise.
    param_groups = []
    if general_params:
        param_groups.append({"params": general_params})
    for hp in unique_hps:
        params = [p for p in all_parameters if getattr(p, "_optim", None) == hp]
        param_groups.append({"params": params, **hp})

    # Honour the caller's weight_decay (it used to be hard-coded to 0).
    optimizer = optim.AdamW(param_groups, lr=lr, weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Print a one-line summary per group, listing the special hyperparameters.
    keys = sorted({k for hp in unique_hps for k in hp})
    for i, g in enumerate(optimizer.param_groups):
        group_hps = {k: g.get(k, None) for k in keys}
        print(' | '.join([
            f"Optimizer group {i}",
            f"{len(g['params'])} tensors",
        ] + [f"{k} {v}" for k, v in group_hps.items()]))

    return optimizer, scheduler

criterion = nn.MSELoss()
optimizer, scheduler = setup_optimizer(
    model,
    lr=args.lr,
    weight_decay=0,
    epochs=args.epochs,
)

# Look up each optimized tensor's qualified name by object identity, so the
# contents of every param group can be listed by name.
param_name_map = {id(tensor): qual_name for qual_name, tensor in model.named_parameters()}

for group_no, group in enumerate(optimizer.param_groups, start=1):
    print(f"Group {group_no}:")
    for tensor in group['params']:
        print(param_name_map.get(id(tensor), "Unnamed Parameter"))

# Training loop
def train():
    """Run one training epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``device``. Each batch loss is weighted by its batch size, so a smaller
    final batch is not over-weighted. This gives the per-sample mean when
    the criterion uses mean reduction (``nn.MSELoss()`` does by default).

    Returns:
        The sample-weighted mean training loss for the epoch, or NaN if the
        loader yields no batches. Existing callers ignore the return value,
        so adding it is backward-compatible.
    """
    model.train()
    running_loss = 0.0
    seen = 0
    # Pass total= because tqdm cannot take a length from enumerate().
    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_idx, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        running_loss += loss.item() * batch_size
        seen += batch_size

        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss: %.6f' %
            (batch_idx, len(train_loader), running_loss / seen)
        )

    return running_loss / seen if seen else float('nan')

# Evaluation function
def eval(epoch, dataloader, checkpoint=False, savefile=False):
    """Evaluate the module-level ``model`` on ``dataloader`` without gradients.

    Each batch loss is weighted by its batch size, so a smaller final batch
    does not get extra weight. With a mean-reduction criterion such as the
    default ``nn.MSELoss()``, the result is the dataset-wide per-sample mean.

    Args:
        epoch: currently unused; kept so existing call sites keep working.
        dataloader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
        checkpoint: currently unused; kept for call-site compatibility.
        savefile: currently unused; kept for call-site compatibility.

    Returns:
        The sample-weighted mean loss, or NaN if ``dataloader`` is empty
        (previously this raised ZeroDivisionError).

    NOTE: this name shadows the builtin ``eval`` at module scope. It is kept
    because the training loop calls it by this name.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        # Pass total= because tqdm cannot take a length from enumerate().
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            batch_size = targets.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size

            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.6f' %
                (batch_idx, len(dataloader), total_loss / n_samples)
            )

    if n_samples == 0:
        return float('nan')
    return total_loss / n_samples


losses = []      # test-set loss recorded after each epoch
val_losses = []  # validation-set loss recorded after each epoch

pbar = tqdm(range(args.epochs))
for epoch in pbar:
    pbar.set_description('Epoch: %d' % epoch)
    train()
    # Keep the validation loss instead of computing it and throwing it away.
    val_loss = eval(epoch, val_loader, checkpoint=True)
    val_losses.append(val_loss)
    err = eval(epoch, test_loader)
    losses.append(err)
    pbar.set_postfix(val_loss=val_loss, test_loss=err)
    # The cosine schedule (T_max = args.epochs) advances once per epoch.
    scheduler.step()
