import os
import argparse

import torch
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Timer
from pytorch_lightning.strategies import DDPStrategy
from timm.utils import ModelEmaV2


from datasets.omol import get_omol_loaders
from models.baseline.esen.models.uma.escn_md import eSEN
from models.rapidash.rapidash import Rapidash
from models.platoformer.groups import PLATONIC_GROUPS
from models.platoformer.platoformer import PlatonicTransformer
from models.platoformer.platoformer_fourier import TetraFourierTransformer
from models.platoformer.block_fourier import TetraFourierRMSNormQuarterBatch
from utils import (CosineWarmupScheduler, MemoryMonitorCallback, RandomSOd, 
                   TimerCallback,format_batch_for_esen ,run_gc)


# Performance optimizations: process-wide PyTorch settings, applied once when this module is imported.
# Select the 'medium' precision setting for float32 matrix multiplications.
torch.set_float32_matmul_precision('medium')
# Enable the flash and the memory-efficient scaled-dot-product-attention backends.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
# Raise TorchDynamo's cache size limit to 64 — presumably to leave room for the
# recompilations of the torch.compile(dynamic=True) model built below; TODO confirm.
torch._dynamo.config.cache_size_limit = 64


class OMolModel(pl.LightningModule):
    """Lightning module for QM9 molecular property prediction, supporting multiple model types."""

    def __init__(self, args):
        """Build the selected backbone, normalization buffers, metrics and optional EMA copy.

        Args:
            args: Hyperparameter container (e.g. the parsed ``argparse`` namespace).
                It is stored via ``save_hyperparameters`` and read back through
                ``self.hparams`` everywhere in this module.

        Raises:
            ValueError: If ``model_type`` is not one of ``'platoformer'``,
                ``'rapidash'`` or ``'esen'``, or if ``equivariance`` cannot be
                mapped to a solid for the platoformer.
            NotImplementedError: If the Fourier implementation is requested for a
                solid other than the tetrahedron.
        """
        super().__init__()
        self.save_hyperparameters(args)
        hp = self.hparams

        # Setup rotation augmentation
        self.rotation_generator = RandomSOd(3)
        self.max_num_elements = 100  # Max atomic number

        # Base node embeddings
        if hp.learned_atom_embedding:
            self.node_emb_dim = 128  # Base embedding dimension
            self.primary_node_embedding = torch.nn.Embedding(self.max_num_elements, self.node_emb_dim)
            torch.nn.init.uniform_(self.primary_node_embedding.weight.data, -0.001, 0.001)
        else:
            self.node_emb_dim = 92  # k-hot encoding dim

        in_channels_scalar = (
            self.node_emb_dim
            + 3 * ("coords" in hp.scalar_features)   # x,y,z coordinates as scalars
            + 1 * ("charges" in hp.scalar_features)  # charges as scalars
        )
        print(f"Input channels (scalar): {in_channels_scalar}")

        # NOTE(review): forward() passes a vector feature when "coords" is listed in
        # vector_features, yet the backbones are built with zero vector input
        # channels — confirm this is intended.
        in_channels_vector = 0
        # model_type lives in the hyperparameters; the module itself has no
        # `model_type` attribute, so reading it from `self` raised AttributeError
        # whenever `orientations` was 0.
        self.lifted = hp.orientations != 0 or hp.model_type == "platoformer"
        out_channels_scalar = 1
        out_channels_vec = 1 if hp.direct_force_pred else 0

        # Model specification
        if hp.model_type == "platoformer":
            # An explicit solid name takes precedence over the equivariance setting.
            if hp.solid_name is not None:
                solid_name = hp.solid_name
            elif hp.equivariance == "Tn":
                solid_name = "trivial"
            elif hp.equivariance == "SEn":
                solid_name = "tetrahedron"
            else:
                raise ValueError(f"Unsupported equivariance type: {hp.equivariance}. "
                                 "Supported types are 'Tn' (trivial) and 'SEn' (tetrahedron).")

            # Keyword arguments shared by both transformer implementations.
            transformer_kwargs = dict(
                input_dim=in_channels_scalar,
                input_dim_vec=in_channels_vector,
                hidden_dim=hp.hidden_dim,
                output_dim=out_channels_scalar,
                output_dim_vec=out_channels_vec,
                nhead=hp.num_heads,
                num_layers=hp.layers,
                ffn_dim_factor=4,
                scalar_task_level="graph",
                dropout=hp.dropout,
                norm_first=hp.norm_first,
                rope_sigma=hp.freq_sigma,
                ape_sigma=hp.ape_sigma,
                learned_freqs=hp.learned_freqs,
                spatial_dim=3,
                dense_mode=hp.dense_mode,
                mean_aggregation=hp.mean_aggregation,
                attention=hp.attention,
                attention_type=hp.attention_type,
                post_pool_readout=hp.post_pool_readout,
                ffn_readout=hp.ffn_readout,
            )
            if hp.fourier_implementation:
                if solid_name != "tetrahedron":
                    raise NotImplementedError(
                        "The Fourier implementation is only available for the 'tetrahedron' solid, "
                        f"got '{solid_name}'."
                    )
                self.net = TetraFourierTransformer(**transformer_kwargs)
            else:
                self.net = PlatonicTransformer(solid_name=solid_name, **transformer_kwargs)

            if hp.compile:
                self.net = torch.compile(
                    self.net,
                    mode="max-autotune-no-cudagraphs",
                    dynamic=True,
                )
        elif hp.model_type == "rapidash":
            self.net = Rapidash(
                input_dim=in_channels_scalar + in_channels_vector,
                hidden_dim=hp.hidden_dim,
                output_dim=out_channels_scalar,
                output_dim_vec=out_channels_vec,
                num_layers=hp.layers,
                edge_types=hp.edge_types,
                equivariance=hp.equivariance,
                ratios=hp.ratios,
                dim=3,
                num_ori=hp.orientations,
                basis_dim=hp.basis_dim,
                basis_hidden_dim=hp.basis_hidden_dim,
                degree=hp.degree,
                widening_factor=hp.widening,
                layer_scale=hp.layer_scale,
                task_level='node',
                last_feature_conditioning=False,
                skip_connections=hp.skip_connections,
                fixed_feature_dim=True  # Automatically reduces hidden_dim by a factor self.hparams.orientations
            )
        elif hp.model_type == "esen":
            # This uses eSEN-sm-d model config by default
            self.net = eSEN(direct_forces=hp.direct_force_pred, otf_graph=True)
        else:
            # Fail here rather than later with a missing `self.net`.
            raise ValueError(f"Unsupported model type: {hp.model_type!r}. "
                             "Supported types are 'platoformer', 'rapidash' and 'esen'.")

        # Normalization parameters; the defaults (shift 0, scale 1) leave targets unchanged.
        self.register_buffer('shift', torch.tensor(0.0, dtype=torch.float32))
        self.register_buffer('scale', torch.tensor(1.0, dtype=torch.float32))
        self.register_buffer('avg_num_nodes', torch.tensor(1.0, dtype=torch.float32))

        # Metrics; the step methods feed them values already converted from eV to meV.
        self.train_metric = torchmetrics.MeanAbsoluteError()
        self.train_metric_force = torchmetrics.MeanAbsoluteError()
        self.train_metric_energy_per_atom = torchmetrics.MeanAbsoluteError()

        self.valid_metric = torchmetrics.MeanAbsoluteError()
        self.valid_metric_force = torchmetrics.MeanAbsoluteError()
        self.valid_metric_energy_per_atom = torchmetrics.MeanAbsoluteError()

        self.test_metrics_energy = torchmetrics.MeanAbsoluteError()
        self.test_metrics_force = torchmetrics.MeanAbsoluteError()
        self.test_metrics_energy_per_atom = torchmetrics.MeanAbsoluteError()

        # Optional exponential-moving-average copy of the backbone, updated in training_step.
        if hp.use_ema:
            self.model_ema = ModelEmaV2(
                self.net,
                decay=hp.ema_decay
            )
    
    def forward(self, graph):
        """Assemble the node features and run the configured backbone.

        Args:
            graph: Batched graph object. It is moved to ``self.device`` and its
                ``pos``/``batch`` attributes are read, plus ``atomic_numbers`` or
                ``x``, ``charges`` and ``edge_index`` depending on the configuration.

        Returns:
            The flattened scalar prediction, or a ``(scalar, vector)`` tuple when
            ``direct_force_pred`` is enabled.
        """
        hp = self.hparams
        graph = graph.to(self.device)

        # Scalar inputs start from either the learned embedding or the k-hot encoding.
        if hp.learned_atom_embedding:
            scalar_parts = [self.primary_node_embedding(graph.atomic_numbers)]
        else:
            scalar_parts = [graph.x]
        if "coords" in hp.scalar_features:
            scalar_parts.append(graph.pos)
        if "charges" in hp.scalar_features:
            scalar_parts.append(graph.charges[:, None])

        # Optional vector input: the positions with an added channel axis.
        vector_parts = []
        if "coords" in hp.vector_features:
            vector_parts.append(graph.pos[:, None, :])

        # The scalar list always holds at least the base embedding.
        x = torch.cat(scalar_parts, dim=-1)
        vec = torch.cat(vector_parts, dim=1) if vector_parts else None

        # Each backbone has its own call signature.
        model_type = hp.model_type
        if model_type == "rapidash":
            pred_scalar, pred_vec = self.net(x, graph.pos, graph.edge_index, graph.batch, vec=vec)
        elif model_type == "platoformer":
            avg_num_nodes = self.avg_num_nodes.to(graph.pos.device)
            pred_scalar, pred_vec = self.net(x, graph.pos, graph.batch, vec=vec, avg_num_nodes=avg_num_nodes)
        elif model_type == "esen":
            graph = format_batch_for_esen(graph)
            pred_scalar, pred_vec = self.net(graph)

        if not hp.direct_force_pred:
            return pred_scalar.view(-1)
        return pred_scalar.view(-1), pred_vec.squeeze(1)
      
    @torch.enable_grad()
    def pred_energy_and_force(self, graph):
        """Predict the energy and derive forces as its negative gradient w.r.t. positions.

        Gradient tracking is forced on by the decorator, so this also works when the
        caller runs with gradients disabled. ``graph.pos`` is replaced by a cloned
        tensor that requires grad.

        Args:
            graph: Batched graph object; its ``pos`` attribute is overwritten.

        Returns:
            Tuple ``(energy, force)``. Outside of training both are detached.
        """
        graph.pos = graph.pos.clone().requires_grad_(True)
        energy = self(graph)

        # Keep the autograd graph only while training, where the force enters the loss.
        keep_graph = self.training
        (energy_grad,) = torch.autograd.grad(
            energy,
            graph.pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=keep_graph,
            retain_graph=keep_graph,
        )
        force = -1.0 * energy_grad

        if self.training:
            return energy, force
        return energy.detach(), force.detach()

    def training_step(self, graph, batch_idx):
        """Compute the training loss for one batch and update the training metrics.

        Args:
            graph: Batched graph with ``pos``, ``forces``, ``energy``, ``batch`` and
                ``num_atoms``; ``pos`` and ``forces`` are overwritten when rotation
                augmentation is enabled.
            batch_idx: Index of the batch; gates metric logging and garbage collection.

        Returns:
            The scalar loss combining the energy and force terms.
        """
        hp = self.hparams

        # Rotation augmentation: one random rotation per graph, applied to positions and forces.
        if hp.train_augm:
            num_graphs = graph.batch.max().item() + 1
            rotations = self.rotation_generator(n=num_graphs).type_as(graph.pos)
            per_node_rotation = rotations[graph.batch]
            graph.pos = torch.einsum('bij,bj->bi', per_node_rotation, graph.pos)
            graph.forces = torch.einsum('bij,bj->bi', per_node_rotation, graph.forces)

        if hp.direct_force_pred:
            pred_energy, pred_force = self(graph)
        else:
            pred_energy, pred_force = self.pred_energy_and_force(graph)

        # Both losses compare predictions against targets normalized with shift/scale.
        if hp.esen_loss:
            target_energy_norm = (graph.energy - self.shift) / self.scale
            target_force_norm = graph.forces / self.scale
            per_atom_diff = pred_energy / graph.num_atoms - target_energy_norm / graph.num_atoms
            energy_loss = per_atom_diff.abs().mean()
            force_loss = ((pred_force - target_force_norm) ** 2).mean()
            loss = 10 * energy_loss + 5 * hp.lambda_F * force_loss
        else:
            energy_diff = pred_energy - ((graph.energy - self.shift) / self.scale)
            energy_loss = (energy_diff ** 2).mean()
            force_diff = pred_force - graph.forces / self.scale
            force_loss = torch.sqrt(torch.sum(force_diff ** 2, -1)).mean()
            loss = energy_loss + hp.lambda_F * force_loss

        # Undo the normalization and convert to meV (energy) and meV/Å (force) for the metrics.
        pred_energy_mev = (pred_energy * self.scale + self.shift) * 1000
        true_energy_mev = graph.energy * 1000
        pred_force_mev_ang = pred_force * self.scale * 1000
        true_force_mev_ang = graph.forces * 1000

        self.train_metric(pred_energy_mev, true_energy_mev)
        self.train_metric_force(pred_force_mev_ang, true_force_mev_ang)
        self.train_metric_energy_per_atom(
            pred_energy_mev / graph.num_atoms,
            true_energy_mev / graph.num_atoms,
        )

        if batch_idx % 1000 == 0:
            step_log_kwargs = dict(prog_bar=False, on_step=True, on_epoch=False, sync_dist=True)
            self.log("train MAE (energy) [meV]", self.train_metric, **step_log_kwargs)
            self.log("train MAE (force) [meV/Å]", self.train_metric_force, **step_log_kwargs)
            self.log("train MAE (energy/atom) [meV]", self.train_metric_energy_per_atom, **step_log_kwargs)

        if batch_idx % 500 == 0:
            run_gc()

        # Fold the current backbone weights into the EMA copy.
        # NOTE(review): this runs inside training_step, presumably before the optimizer
        # has applied this batch's update — confirm that lag is acceptable.
        if hp.use_ema and self.training:
            self.model_ema.update(self.net)

        return loss

    def on_train_epoch_end(self):
        """End-of-training-epoch hook; currently does nothing."""
        return None
    
    def validation_step(self, graph, batch_idx):
        """Update the validation MAE metrics for one batch.

        Args:
            graph: Batched graph with ``energy``, ``forces`` and ``num_atoms`` targets.
            batch_idx: Index of the batch; garbage collection runs every 250 batches.
        """
        if self.hparams.direct_force_pred:
            prediction = self(graph)
        else:
            prediction = self.pred_energy_and_force(graph)
        pred_energy, pred_force = prediction

        # Undo the normalization, then convert eV -> meV and eV/Å -> meV/Å.
        pred_energy_mev = (pred_energy * self.scale + self.shift) * 1000
        true_energy_mev = graph.energy * 1000
        pred_force_mev_ang = pred_force * self.scale * 1000
        true_force_mev_ang = graph.forces * 1000

        self.valid_metric(pred_energy_mev, true_energy_mev)
        self.valid_metric_force(pred_force_mev_ang, true_force_mev_ang)
        self.valid_metric_energy_per_atom(
            pred_energy_mev / graph.num_atoms,
            true_energy_mev / graph.num_atoms,
        )

        if batch_idx % 250 == 0:
            run_gc()

    def on_validation_epoch_end(self):
        """Log the three validation MAE metrics at the end of the validation epoch."""
        epoch_metrics = (
            ("valid MAE (energy) [meV]", self.valid_metric),
            ("valid MAE (force) [meV/Å]", self.valid_metric_force),
            ("valid MAE (energy/atom) [meV]", self.valid_metric_energy_per_atom),
        )
        for metric_name, metric in epoch_metrics:
            self.log(metric_name, metric, prog_bar=True, sync_dist=True)
    
    def test_step(self, graph, batch_idx):
        """Update the test MAE metrics for one batch.

        The prediction runs with gradient tracking enabled, which the
        gradient-based force computation requires.

        Args:
            graph: Batched graph with ``energy``, ``forces`` and ``num_atoms`` targets.
            batch_idx: Index of the batch (unused).
        """
        with torch.enable_grad():
            if self.hparams.direct_force_pred:
                prediction = self(graph)
            else:
                prediction = self.pred_energy_and_force(graph)
            pred_energy, pred_force = prediction

        # Undo the normalization, then convert eV -> meV and eV/Å -> meV/Å.
        pred_energy_mev = (pred_energy * self.scale + self.shift) * 1000
        true_energy_mev = graph.energy * 1000
        pred_force_mev_ang = pred_force * self.scale * 1000
        true_force_mev_ang = graph.forces * 1000

        self.test_metrics_energy(pred_energy_mev, true_energy_mev)
        self.test_metrics_force(pred_force_mev_ang, true_force_mev_ang)
        self.test_metrics_energy_per_atom(
            pred_energy_mev / graph.num_atoms,
            true_energy_mev / graph.num_atoms,
        )

    def on_test_epoch_end(self):
        """Log the three test MAE metrics at the end of the test epoch."""
        epoch_metrics = (
            ("test MAE (energy) [meV]", self.test_metrics_energy),
            ("test MAE (force) [meV/Å]", self.test_metrics_force),
            ("test MAE (energy/atom) [meV]", self.test_metrics_energy_per_atom),
        )
        for metric_name, metric in epoch_metrics:
            self.log(metric_name, metric, prog_bar=True, sync_dist=True)
  
    def configure_optimizers(self):
        """Create the AdamW optimizer and, if enabled, the cosine warmup scheduler.

        With ``layer_wise_lr`` the parameters are grouped per component with
        individual learning rates; otherwise they are split into a weight-decay
        group and a no-weight-decay group.

        Returns:
            Dict with the ``optimizer``, a ``monitor`` key naming the logged
            validation energy MAE, and an ``lr_scheduler`` entry when
            ``cosine_scheduler`` is enabled.
        """
        if self.hparams.layer_wise_lr:
            param_groups = self._layer_wise_param_groups()
        else:
            param_groups = self._weight_decay_param_groups()

        optimizer = torch.optim.AdamW(param_groups, lr=self.hparams.lr)

        # The monitor must name a metric that is actually logged; the validation
        # metrics are logged with their full unit-suffixed names.
        config = {"optimizer": optimizer, "monitor": "valid MAE (energy) [meV]"}
        if self.hparams.cosine_scheduler:
            config["lr_scheduler"] = CosineWarmupScheduler(
                optimizer, self.hparams.warmup, self.trainer.max_epochs
            )
        return config

    def _layer_wise_param_groups(self):
        """Build optimizer groups with per-component learning rates.

        For the platoformer the embedder, every layer and the readouts get their
        own learning-rate factor. Trainable parameters not covered by these
        components are collected in one extra group at the base learning rate, so
        they are not silently left out of the optimizer.
        """
        hp = self.hparams
        base_lr = hp.lr
        groups = []
        assigned = set()  # ids of parameters already placed in a group

        def add_group(params, lr):
            # A parameter may belong to only one optimizer group, so skip repeats.
            unique = []
            for param in params:
                if id(param) not in assigned:
                    assigned.add(id(param))
                    unique.append(param)
            groups.append({"params": unique, "lr": lr, "weight_decay": hp.weight_decay})

        if hp.model_type == "platoformer":
            add_group(self.net.x_embedder.parameters(), base_lr * hp.embed_lr_factor)

            num_layers = len(self.net.layers)
            for i, layer in enumerate(self.net.layers):
                # The last layer uses the base rate; each earlier layer is scaled
                # by one more power of layer_lr_factor.
                add_group(layer.parameters(), base_lr * hp.layer_lr_factor ** (num_layers - i - 1))

            readout_lr = base_lr * hp.readout_lr_factor
            add_group(self.net.scalar_readout.parameters(), readout_lr)
            if hasattr(self.net, 'vector_readout'):
                add_group(self.net.vector_readout.parameters(), readout_lr)

        # Everything else that is trainable. The EMA copy is skipped because it is
        # updated through model_ema.update() in training_step rather than by gradients.
        leftover = {
            name: param
            for name, param in self.named_parameters()
            if param.requires_grad
            and id(param) not in assigned
            and not name.startswith("model_ema.")
        }
        if leftover:
            print(f"Warning: Parameters {sorted(leftover)} have no layer-wise learning rate rule. "
                  "Optimizing them with the base learning rate.")
            groups.append({
                "params": list(leftover.values()),
                "lr": base_lr,
                "weight_decay": hp.weight_decay,
            })
        return groups

    def _weight_decay_param_groups(self):
        """Split trainable parameters into a weight-decay and a no-weight-decay group.

        Raises:
            ValueError: If there are learnable parameters but no group was created.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, TetraFourierRMSNormQuarterBatch)
        decayed_weight_suffixes = ('weight_1d', 'weight_2d_1', 'weight_2d_2', 'weight_3d')

        # named_parameters() recurses, so each parameter is visited once per ancestor
        # module; the module-type rules only match when `m` is the owning module type.
        for mn, m in self.named_modules():  # mn is module name, m is module instance
            for pn, p in m.named_parameters():  # pn is the parameter name relative to m
                fpn = f'{mn}.{pn}' if mn else pn  # fpn is the full parameter name

                if pn == 'freqs' or pn.endswith('bias') or ('layer_scale' in pn):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('kernel'):
                    decay.add(fpn)
                elif isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)
                elif pn.endswith(decayed_weight_suffixes):
                    decay.add(fpn)
                # Parameters not matching any rule are added to no_decay below.

        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}

        missing_params = param_dict.keys() - (decay | no_decay)
        if missing_params:
            print(f"Warning: Parameters {missing_params} were not explicitly assigned to decay/no_decay by specific rules. Adding to no_decay by default.")
            no_decay.update(missing_params)

        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, f"Parameters {inter_params} found in both decay and no_decay sets!"

        # Ensure all learnable parameters are covered
        assert len(param_dict.keys() - union_params) == 0, f"Parameters {param_dict.keys() - union_params} not assigned to any optimizer group!"

        # Sorting makes the parameter order within each group deterministic; names
        # of frozen parameters are dropped because they are not in param_dict.
        optim_groups = [
            {"params": [param_dict[p_name] for p_name in sorted(decay) if p_name in param_dict], "weight_decay": self.hparams.weight_decay},
            {"params": [param_dict[p_name] for p_name in sorted(no_decay) if p_name in param_dict], "weight_decay": 0.0},
        ]

        # Filter out empty groups (e.g., if 'decay' set is empty)
        optim_groups = [group for group in optim_groups if group["params"]]

        if not optim_groups:
            if param_dict:  # Should not happen if there are learnable params
                raise ValueError("No optimizer groups were created, but there are learnable parameters.")
            print("Warning: No learnable parameters found for the optimizer.")

        return optim_groups



def main(args):
    """Train and/or test an ``OMolModel`` according to the parsed arguments.

    Builds the data loaders, logger, callbacks, model and trainer. Without
    ``args.test_ckpt`` the model is fitted and then tested on the best energy
    checkpoint (or ``"last"`` if none was recorded); with ``args.test_ckpt`` only
    that checkpoint is tested.

    Args:
        args: Parsed command-line namespace.
    """
    pl.seed_everything(args.seed)

    # Only rank zero may recompute the dataset statistics; every other rank always
    # passes recalculate=False.
    # NOTE(review): no explicit barrier is visible here; other ranks presumably
    # rely on a cache already written by rank zero — confirm.
    is_rank_zero = pl.utilities.rank_zero_only.rank == 0
    train_loader, val_loader, test_loader, _, _ = get_omol_loaders(
        root=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        use_charges=False,
        seed=args.seed,
        debug_subset=args.debug_subset,
        referencing=args.referencing,
        include_hof=args.include_hof,
        scale_shift=args.scale_shift,
        recalculate=args.recalculate_stats if is_rank_zero else False,
        use_k_hot=args.use_khot_encoding,
        edge=False
    )

    if args.gpus > 0 and torch.cuda.is_available():
        accelerator = "gpu"
        devices = args.gpus
    else:
        accelerator = "cpu"
        devices = "auto"

    if args.log:
        save_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "logs")
        logger = pl.loggers.WandbLogger(
            project=args.wandb_project_name,
            name=None,
            config=vars(args),  # Log all hyperparameters
            save_dir=save_dir
        )
    else:
        logger = None

    # Best overall energy performance; also selects the checkpoint tested after fitting.
    best_energy_ckpt = pl.callbacks.ModelCheckpoint(
        monitor='valid MAE (energy) [meV]',
        mode='min',
        filename='best-energy-{epoch:02d}',
        every_n_epochs=1,
        save_top_k=3
    )
    # Best overall force performance
    best_force_ckpt = pl.callbacks.ModelCheckpoint(
        monitor='valid MAE (force) [meV/Å]',
        mode='min',
        filename='best-force-{epoch:02d}',
        every_n_epochs=1,
        save_top_k=3
    )
    periodic_ckpt = pl.callbacks.ModelCheckpoint(
        filename='epoch-{epoch:02d}',
        every_n_epochs=args.save_every_n_epochs,
        save_top_k=-1,  # Save all periodic checkpoints
        save_last=True  # Also save the final checkpoint as 'last'
    )
    callbacks = [
        best_energy_ckpt,
        best_force_ckpt,
        periodic_ckpt,
        # TimerCallback(),
        # MemoryMonitorCallback(log_frequency=500)
    ]
    if args.log:
        callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='epoch'))
    if args.timer is not None:
        callbacks.append(Timer(duration=args.timer))

    # Loading weights only starts a fresh run, so it disables resuming trainer state.
    resume_path = args.resume_ckpt if args.load_weights is None else None
    if args.load_weights:
        print(f"Loading model weights from: {args.load_weights}")
        model = OMolModel.load_from_checkpoint(
            checkpoint_path=args.load_weights,
            args=args
        )
    else:
        model = OMolModel(args)

    dataset = train_loader.dataset
    if hasattr(dataset, 'scale'):
        # Keep the buffers in float32, matching how they are registered on the
        # model, and copy the values so the buffers never alias the dataset's
        # own statistics.
        model.scale = torch.as_tensor(dataset.scale, dtype=torch.float32).detach().clone().to(model.device)
        model.shift = torch.as_tensor(dataset.shift, dtype=torch.float32).detach().clone().to(model.device)
        print(f"Set model scale: {model.scale}, shift: {model.shift}")

    trainer = pl.Trainer(
        logger=logger,
        max_epochs=args.epochs,
        callbacks=callbacks,
        gradient_clip_val=1.0,
        accelerator=accelerator,
        devices=devices,
        enable_progress_bar=args.enable_progress_bar,
        precision=args.precision,
        inference_mode=False,
        # strategy=DDPStrategy(find_unused_parameters=True)
    )

    if args.test_ckpt is None:
        trainer.fit(model, train_loader, val_loader, ckpt_path=resume_path)
        best_ckpt_path = best_energy_ckpt.best_model_path or "last"
        trainer.test(model, test_loader, ckpt_path=best_ckpt_path)
    else:
        hparams_file = os.path.join(os.path.dirname(args.test_ckpt), "hparams.yaml")
        model = OMolModel.load_from_checkpoint(args.test_ckpt, hparams_file=hparams_file, args=args)
        trainer.test(model, test_loader)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Omol Property Prediction Training')
    
     # Training parameters
    parser.add_argument('--epochs', type=int, default=30, help='Number of training epochs')
    parser.add_argument('--timer', type=str, default=None, help='Timer for training in string format, e.g., \"00:08:00:00\" for 8 hours')
    parser.add_argument('--warmup', type=int, default=5, help='Number of warmup epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
    parser.add_argument('--lr', type=float, default=7e-4, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-6, help='Weight decay')
    parser.add_argument('--seed', type=int, default=1, help='Random seed')
    parser.add_argument('--cosine_scheduler', type=eval, default=True, help='Use cosine annealing scheduler')
    parser.add_argument('--max_samples', type=int, default=1e8, help='Maximum number of samples to use from the dataset')
    parser.add_argument('--lambda_F', type=float, default=12.0, help='Weight for force loss in the total loss function')
    # Model architecture
    parser.add_argument('--model_type', type=str, default='platoformer', choices=['rapidash', 'platoformer','esen'], help='Type of model to use: rapidash or transformer.')
    parser.add_argument('--hidden_dim', type=eval, default=1152, help='Hidden dimension(s), for rapidash [256,256,256,256]')
    parser.add_argument('--layers', type=eval, default=14, help='Layers per scale, for rapidash [0, 1, 1, 1]')
    parser.add_argument('--equivariance', type=str, default="SEn", help='Type of equivariance')
    parser.add_argument('--solid_name', type=str, default='tetrahedron', help='Override solid name for PlatonicTransformer (overrides equivariance setting)')
    
    # Platonic Transformer specific parameters
    parser.add_argument('--num_heads', type=int, default=72, help='Number of attention heads (Transformer only).')
    parser.add_argument('--freq_sigma', type=float, default=0.5, help='Sigma for RFF positional encoding (Transformer only).')
    parser.add_argument('--ape_sigma', type=float, default=None, help='Sigma parameter for Absolute Positional Encoding (APE)')
    parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate (Transformer only).')
    parser.add_argument('--norm_first', type=eval, default=True, help='Use LayerNorm before attention in Transformer.')
    parser.add_argument('--learned_freqs', type=eval, default=True, help='Use Rotary Position Embedding (RoPE) in Transformer.')
    parser.add_argument('--dense_mode', type=eval, default=False, help='Use dense mode for Rapidash.')
    parser.add_argument('--mean_aggregation', type=eval, default=False, help='Use dense mode for Rapidash.')
    parser.add_argument('--attention', type=eval, default=True, help= 'Use attention in PlatonicConv (Transformer only).')
    parser.add_argument('--attention_type', type=str, default='equivariant',
                        choices=['equivariant', 'invariant','invariant-equivariant',
                                 'equivariant-invariant','equivariant-equivariant', 'invariant-invariant'],
                        help='For dual input, first part is for first layer, second part is for subsequent layers.')
    parser.add_argument('--post_pool_readout', type=eval, default=False, help= 'Do the readout after pooling (Transformer only).')
    parser.add_argument('--ffn_readout', type=eval, default=True, help= 'Feed-forward readout (Transformer only).')
    parser.add_argument('--learned_atom_embedding', type=eval, default=False, help='Use AddRoPE instead of PlatonicRoPE in PlatonicConv.')
    parser.add_argument('--compile', type=eval, default=True, help= 'Use torch.compile')
    parser.add_argument('--fourier_implementation', type=eval, default=False, help='Parameterize model in Fourier domain')

    # Rapidash specific parameters, listed as (flag, converter, default, help) rows
    # and registered in exactly the order they appear here.
    _rapidash_options = (
        ('--basis_dim', int, 256, 'Basis dimension'),
        ('--basis_hidden_dim', int, 128, 'Hidden dimension of the basis function MLP'),
        ('--orientations', int, 8, 'Number of orientations'),
        ('--degree', int, 2, 'Polynomial degree'),
        ('--edge_types', eval, ["knn-8", "knn-8", "knn-8", "fc"], 'Edge types'),
        ('--ratios', eval, [0.25, 0.25, 0.25], 'Pooling ratios'),
        ('--widening', int, 4, 'Network widening factor'),
        ('--layer_scale', eval, None, 'Layer scaling factor'),
        ('--skip_connections', eval, False, 'Use U-Net style skip connections'),
    )
    for _flag, _converter, _default, _description in _rapidash_options:
        parser.add_argument(_flag, type=_converter, default=_default, help=_description)
    
    # Training features
    parser.add_argument('--train_augm', type=eval, default=True, help='Use rotation augmentation')

    # Input features
    parser.add_argument('--scalar_features', type=eval, default=[], help='Features to use as scalars: ["coords"]')
    parser.add_argument('--vector_features', type=eval, default=[], help='Features to use as vectors: ["coords"]')
    # Parsed with eval like every other boolean flag in this parser: `type=bool`
    # converts any non-empty string to True, so "--test_rot_trans False" used to
    # switch the option ON.
    parser.add_argument('--test_rot_trans', type=eval, default=False, help='Rotate and translate the test set for evaluation')
   
    # Data and logging
    parser.add_argument('--data_dir', type=str, default='/data/omol25/', help='Data directory')
    parser.add_argument('--debug_subset', type=int, default=None, help='Use a subset of the dataset for debugging')
    # Boolean switches below are parsed with eval, so pass them as True / False.
    parser.add_argument('--recalculate_stats', type=eval, default=False, help='Recalculate dataset statistics')
    parser.add_argument('--referencing', type=eval, default=True, help='use per-atom referencing for the target energy')
    parser.add_argument('--include_hof', type=eval, default=False, help='Normalize the target property using HOF values')
    parser.add_argument('--scale_shift', type=eval, default=False, help='Use scale and shift normalization for the target property')
    # Help text aligned with the flag name: the encoding is k-hot, not one-hot.
    parser.add_argument('--use_khot_encoding', type=eval, default=True, help='Use k-hot encoding for atom types')
    parser.add_argument('--direct_force_pred', type=eval, default=True, help='Use direct force prediction')
    
    # Sweep configuration: an optional dictionary of overrides (given as a Python
    # literal and parsed with eval) plus a free-form integer label for the run.
    parser.add_argument(
        '--config',
        default=None,
        type=eval,
        help='Sweep configuration dictionary',
    )
    parser.add_argument(
        '--model_id',
        default=None,
        type=int,
        help='Model ID in case you would want to label the configuration',
    )
    
    # System and checkpointing
    parser.add_argument('--log', type=eval, default=True, help='Enable logging')
    parser.add_argument('--wandb_project_name', type=str, default='Platonic-OMol', help='WandB project name')
    parser.add_argument('--gpus', type=int, default=4, help='Number of GPUs to use (0 for CPU).')
    # Help text now lists the values that `choices` actually accepts; the old text
    # suggested "16", which argparse rejects.
    parser.add_argument('--precision', type=str, default='32', choices=['16-mixed', 'bf16-mixed', '32'],
                        help="Precision for training: '16-mixed', 'bf16-mixed' or '32'.")
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loading workers.')
    parser.add_argument('--enable_progress_bar', type=eval, default=True, help='Show progress bar during training.')
    parser.add_argument('--test_ckpt', type=str, default=None, help='Path to a checkpoint for testing.')
    parser.add_argument('--resume_ckpt', type=str, default=None, help='Path to a checkpoint to resume training from.')
    parser.add_argument('--load_weights', type=str, default=None, help='Path to a checkpoint to load model weights from, starting a new run.')
    parser.add_argument('--save_every_n_epochs', type=int, default=3, help='Save checkpoint every N epochs (minimum 1).')

    # Layer-wise learning rates
    parser.add_argument('--layer_wise_lr', type=eval, default=False, help='Use layer-wise learning rates')
    parser.add_argument('--layer_lr_factor', type=float, default=0.9, help='Learning rate decay/growth factor between layers')
    parser.add_argument('--embed_lr_factor', type=float, default=0.5, help='Learning rate factor for embedding layer')
    parser.add_argument('--readout_lr_factor', type=float, default=1.2, help='Learning rate factor for readout layers')

    # EMA and loss options
    parser.add_argument('--use_ema', type=eval, default=False, help='Use Exponential Moving Average (EMA) of model parameters')
    parser.add_argument('--ema_decay', type=float, default=0.999, help='Decay rate for EMA')
    parser.add_argument('--esen_loss', type=eval, default=False, help='Use eSEN-style loss function (energy per atom + force MSE)')
    
    args = parser.parse_args()

    # Apply sweep overrides BEFORE validating, so the checks below see the values
    # the run will actually use (an override could otherwise slip an invalid
    # save_every_n_epochs past the check). Keys that do not match a declared
    # option are attached to `args` as new attributes.
    if args.config is not None:
        if not isinstance(args.config, dict):
            raise ValueError(
                f"--config must evaluate to a dict, got {type(args.config).__name__}"
            )
        for key, value in args.config.items():
            setattr(args, key, value)

    # Validate save_every_n_epochs
    if args.save_every_n_epochs < 1:
        raise ValueError(f"save_every_n_epochs must be at least 1, got {args.save_every_n_epochs}")

    main(args)
