import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import argparse
from tqdm.auto import tqdm
import os
import math

# Argument parser
def _build_parser():
    """Create the command-line parser holding the training hyperparameters.

    NOTE(review): the sweep loop in this file hard-codes the model width,
    depth and dropout, so --d_model / --n_layers / --dropout look unused here.
    """
    p = argparse.ArgumentParser(description='Train a Sequential Model')
    options = (
        # (flag, type, default, help)
        ('--lr', float, 0.01, 'Learning rate'),
        ('--epochs', int, 3, 'Training epochs'),
        ('--batch_size', int, 64, 'Batch size'),
        ('--d_model', int, 128, 'Model dimension'),
        ('--n_layers', int, 4, 'Number of layers'),
        ('--dropout', float, 0.1, 'Dropout'),
    )
    for flag, kind, fallback, text in options:
        p.add_argument(flag, default=fallback, type=kind, help=text)
    return p


parser = _build_parser()
args = parser.parse_args()

# Device setup: use CUDA whenever torch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Number of channels (rows) in every synthetic sequence; also the length of the
# random vector written into column 0 and of SSM's decoder scale vector.
WIDTH = 32

# Define dataset class
class SequenceDataset(data.Dataset):
    """Eagerly generated synthetic dataset of (WIDTH, sequence_length) sequences.

    Each sample starts as an all-zero float array. A scalar drawn from
    N(0, sig1) is added to every row of column ``sequence_length - 2``, and
    column 0 is overwritten with a length-WIDTH vector drawn from N(0, sig2).
    The target is that scalar multiplied by zero, i.e. always 0.0.

    Samples are drawn from NumPy's global RNG at construction time, in the
    order (scalar, vector) per sample, so seeding NumPy reproduces the data.
    """

    def __init__(self, num_samples, sig1, sig2, sequence_length=16):
        self.num_samples = num_samples
        self.sig1 = sig1
        self.sig2 = sig2
        self.sequence_length = sequence_length
        self.data = [self._make_sample() for _ in range(num_samples)]

    def _make_sample(self):
        """Build one (sequence, target) pair of float32 tensors."""
        length = self.sequence_length
        grid = np.zeros((WIDTH, length))

        # Draw order (scalar first, then the vector) must stay fixed so the
        # RNG stream produces the same data for the same seed.
        planted = np.random.normal(0, self.sig1)
        header = np.random.normal(0, self.sig2, size=(WIDTH))

        grid[:, length - 2] += planted
        grid[:, 0] = header

        sequence = torch.tensor(grid, dtype=torch.float32)
        target = torch.tensor(planted * 0, dtype=torch.float32)
        return sequence, target

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]

# Generate datasets (a single default noise setting; the sweep below rebuilds
# all of these for every (sig1, sig2) pair).
sig1, sig2 = 0.1, 1
train_dataset, val_dataset, test_dataset = (
    SequenceDataset(num_samples=count, sig1=sig1, sig2=sig2)
    for count in (10000, 1000, 2000)
)

# Dataloaders: only the training split is shuffled.
train_loader, val_loader, test_loader = (
    data.DataLoader(split, batch_size=args.batch_size, shuffle=(split is train_dataset))
    for split in (train_dataset, val_dataset, test_dataset)
)

# Define Model
class SSM(nn.Module):
    """Stack of sequence layers followed by a learned per-channel scale.

    The forward pass reads a reference value from channel 0 at position L-2 of
    the raw input. It runs the layer stack and keeps the last time step. It
    scales that step elementwise by ``self.decoder * self.scale``. It returns,
    per sample, the smallest absolute difference between the reference value
    and any scaled channel.

    NOTE(review): ``Model`` is neither defined nor imported in the visible
    part of this file. It is assumed to map (B, L, d_model) -> (B, L, d_model);
    confirm.
    """

    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=256,
        n_layers=4,
        dropout=0.2,
        prenorm=False,
        scale=1,
    ):
        """
        Args:
            d_input: unused; kept for signature compatibility.
            d_output: unused; kept for signature compatibility.
            d_model: width passed to every ``Model`` layer.
            n_layers: number of stacked ``Model`` layers.
            dropout: unused; kept for signature compatibility.
            prenorm: unused; kept for signature compatibility.
            scale: constant multiplier applied after the decoder scale vector.
        """
        super().__init__()

        # Plain sequential stack: each layer's output feeds the next; no
        # residual connection or normalisation is applied here.
        self.s4_layers = nn.ModuleList(
            [Model(d_model=d_model) for _ in range(n_layers)]
        )

        # Learned elementwise scale vector of length WIDTH (not a Linear layer).
        # The forward pass multiplies it with a (B, d_model) tensor, so it
        # requires d_model == WIDTH.
        self.decoder = nn.Parameter(torch.randn(WIDTH))
        self.scale = scale

    def forward(self, x, savedata=False, save_prefix=''):
        """
        Args:
            x: tensor indexed as (B, channels, L).
            savedata: unused; kept for call compatibility.
            save_prefix: unused; kept for call compatibility.

        Returns:
            Tensor of shape (B,): the minimum over channels of
            |x[:, 0, L-2] - decoder * last_step * scale|.
        """
        # Derive the reference position from the actual sequence length rather
        # than a hard-coded 16 - 2, so datasets with another length still work.
        answer = x[:, 0, x.shape[-1] - 2]
        x = x.transpose(-1, -2)  # (B, C, L) -> (B, L, C)
        for layer in self.s4_layers:
            x = layer(x)
        x = x.transpose(-1, -2)  # (B, L, C) -> (B, C, L)
        x = x[..., -1]  # keep the last time step: (B, C)
        x = self.decoder * x * self.scale

        diff = torch.abs(answer.unsqueeze(1) - x)  # (B, C)
        min_diff, _ = diff.min(dim=1)  # (B,)

        return min_diff

def setup_optimizer(model, lr, weight_decay, epochs):
    """
    S4 requires a specific optimizer setup.

    Parameters without an ``_optim`` attribute go into one AdamW group that uses
    ``lr`` and ``weight_decay``. Parameters that carry an ``_optim`` dict are
    grouped by identical dicts. Each such group is added with that dict's
    hyperparameters, which override the defaults.

    Args:
        model: module whose ``parameters()`` are optimised.
        lr: learning rate for the general group (and default for special groups).
        weight_decay: AdamW weight decay for the general group (and default for
            special groups that do not set their own).
        epochs: ``T_max`` of the cosine-annealing scheduler.

    Returns:
        (optimizer, scheduler): an ``AdamW`` optimizer and a ``CosineAnnealingLR``.
    """
    # All parameters in the model
    all_parameters = list(model.parameters())

    # General parameters (without special hyperparameters)
    general_params = [p for p in all_parameters if not hasattr(p, "_optim")]

    # Honour the caller's weight_decay (it was previously hard-coded to 0).
    optimizer = optim.AdamW(general_params, lr=lr, weight_decay=weight_decay)

    # Add parameters with special hyperparameters, one group per distinct dict.
    hps = [getattr(p, "_optim") for p in all_parameters if hasattr(p, "_optim")]
    # dict.fromkeys de-duplicates while keeping first-seen order. Sorting
    # frozensets uses subset ordering, which is only a partial order; kept as-is
    # for compatibility with the original S4 recipe.
    unique_hps = [
        dict(s) for s in sorted(list(dict.fromkeys(frozenset(hp.items()) for hp in hps)))
    ]
    for hp in unique_hps:
        params = [p for p in all_parameters if getattr(p, "_optim", None) == hp]
        optimizer.add_param_group({"params": params, **hp})

    # Create a learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Print optimizer info
    keys = sorted(set([k for hp in unique_hps for k in hp.keys()]))
    for i, g in enumerate(optimizer.param_groups):
        group_hps = {k: g.get(k, None) for k in keys}
        print(' | '.join([
            f"Optimizer group {i}",
            f"{len(g['params'])} tensors",
        ] + [f"{k} {v}" for k, v in group_hps.items()]))

    return optimizer, scheduler

criterion = nn.MSELoss()

# The per-cell model, criterion, optimizer and scheduler are built inside the
# (sig1, sig2) sweep. Building an optimizer here referenced `model` before it
# existed and raised NameError. The parameter listing is therefore exposed as
# helpers that take the model and optimizer explicitly.


def optimizer_param_names(model, optimizer):
    """Return, for each optimizer param group, the model's names for its tensors.

    Parameters not found in ``model.named_parameters()`` are reported as
    ``"Unnamed Parameter"``.

    Returns:
        list[list[str]]: one list of names per entry of ``optimizer.param_groups``.
    """
    # Map parameter identity -> qualified name; tensors are matched by id().
    name_by_id = {id(param): name for name, param in model.named_parameters()}
    return [
        [name_by_id.get(id(param), "Unnamed Parameter") for param in group["params"]]
        for group in optimizer.param_groups
    ]


def print_optimizer_param_names(model, optimizer):
    """Print each optimizer group (1-based) followed by its parameter names."""
    for group_idx, names in enumerate(optimizer_param_names(model, optimizer)):
        print(f"Group {group_idx+1}:")
        for param_name in names:
            print(param_name)

# Training loop
def train():
    """Run one optimisation epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``. The sweep loop rebinds these for every
    (sig1, sig2) cell. The progress bar shows the square root of the running
    per-sample mean loss.

    Returns:
        float: per-sample mean loss over the epoch (0.0 if the loader is empty).
        Existing callers ignore the return value.
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_idx, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # criterion (nn.MSELoss, mean reduction) averages over the batch;
        # weight by batch size so a short final batch does not skew the mean.
        batch_size = targets.size(0)
        loss_sum += loss.item() * batch_size
        seen += batch_size

        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss: %.6f' %
            (batch_idx, len(train_loader), math.sqrt(loss_sum / seen))
        )

    return loss_sum / seen if seen else 0.0

# Evaluation function
def eval(epoch, dataloader, checkpoint=False, savefile=False):
    """Return the per-sample mean ``criterion`` loss of ``model`` on ``dataloader``.

    Uses the module-level ``model``, ``criterion`` and ``device``. The model is
    put in eval mode and run under ``torch.no_grad()``. The progress bar shows
    the square root of the running mean.

    NOTE: this name shadows the builtin ``eval`` within the module. It is kept
    because the sweep loop calls it by this name.

    Args:
        epoch: current epoch index; accepted for call compatibility, unused.
        dataloader: iterable of ``(inputs, targets)`` batches that supports ``len``.
        checkpoint: unused; kept for call compatibility.
        savefile: unused; kept for call compatibility.

    Returns:
        float: loss averaged over samples (batches weighted by size), or
        ``nan`` when the loader yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    seen = 0
    with torch.no_grad():
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # criterion averages over the batch; re-weight by batch size so a
            # partial final batch counts in proportion to its samples.
            batch_size = targets.size(0)
            loss_sum += loss.item() * batch_size
            seen += batch_size

            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.6f' %
                (batch_idx, len(dataloader), math.sqrt(loss_sum / seen))
            )

    if seen == 0:
        # An empty loader has no defined mean; avoid dividing by zero.
        return float('nan')
    return loss_sum / seen


# Sweep over noise scales: sig1 scales the planted scalar, sig2 the column-0
# vector. Each cell trains a fresh model and records its final test loss.
sig_range = np.exp(np.arange(-4.5, 5.0, 0.5))  # 19 values: exp(-4.5) .. exp(4.5)
test_errors = np.zeros((len(sig_range), len(sig_range)))

for i, sig1 in enumerate(sig_range):
    for j, sig2 in enumerate(sig_range):
        print(f"Training for sig1 = {sig1:.4f}, sig2 = {sig2:.4f}")

        # Regenerate datasets
        train_dataset = SequenceDataset(num_samples=10000, sig1=sig1, sig2=sig2)
        val_dataset = SequenceDataset(num_samples=1000, sig1=sig1, sig2=sig2)
        test_dataset = SequenceDataset(num_samples=2000, sig1=sig1, sig2=sig2)

        # These module-level names are read by train() / eval() as globals.
        train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
        test_loader = data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

        # Reinitialize model and optimizer
        model = SSM(
            d_input=1,
            d_output=1,
            d_model=WIDTH,
            n_layers=1,
            dropout=0,
            scale=1
        ).to(device)

        criterion = nn.MSELoss()
        optimizer, scheduler = setup_optimizer(
            model, lr=args.lr, weight_decay=0, epochs=args.epochs
        )

        if torch.cuda.device_count() > 1:
            print(f"Let's use {torch.cuda.device_count()} GPUs!")
            # Wrap the model with DataParallel
            model = nn.DataParallel(model)

        # Training loop: only the validation split is monitored per epoch.
        for epoch in range(args.epochs):
            train()
            eval(epoch, val_loader)
            scheduler.step()

        # Evaluate the test split once on the final model. This also keeps
        # test_loss defined when --epochs is 0 (previously a NameError).
        test_loss = eval(args.epochs, test_loader)

        # Save final test loss
        test_errors[i, j] = test_loss