"""Pytorch script for training MotiFlow."""
import os
# Selecting the PyTorch backend for geomstats (next line) switches some tensors
# to double precision, so the default dtype is reset to float32 later.
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import copy
import logging
import time
from collections import defaultdict, deque

import GPUtil
import hydra
import numpy as np
import pandas as pd
import torch
import tree
import wandb
from einops import rearrange
import motiflow.utils.experiments_utils as eu
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from torch.nn import DataParallel as DP
from motiflow.utils.so3_helpers import hat_inv, pt_to_identity
from motiflow.data import fragment_dataset
from motiflow.data import utils as du
from motiflow.models import se3_fm
from motiflow.models.components import network
from motiflow.utils import rigid_utils as ru
from motiflow.utils.rigid_helpers import assemble_rigid_mat
from motiflow.utils.reconstruction import MoleculeScorer, reconstruct_atoms

def create_dense_symmetry_lib(library_dict, vocab_size, s_max):
    """Pack a per-class symmetry dictionary into two dense lookup tensors.

    Args:
        library_dict: Mapping ``{cid: (rots, mask)}``. Each entry is written
            as-is into row ``cid`` of the outputs; entries whose ``cid`` is
            greater than ``vocab_size`` are ignored.
        vocab_size: Number of fragment classes. One extra row (index
            ``vocab_size``) is allocated for the mask token.
        s_max: Number of symmetry slots per class.

    Returns:
        Tuple ``(lib_rots, lib_mask)``: ``lib_rots`` has shape
        ``[vocab_size + 1, s_max, 3, 3]`` and defaults to identity matrices;
        ``lib_mask`` is a boolean ``[vocab_size + 1, s_max]`` tensor in which
        only slot 0 is valid by default.
    """
    num_rows = vocab_size + 1

    # Every slot starts as the identity rotation.
    identity = torch.eye(3).view(1, 1, 3, 3)
    lib_rots = identity.repeat(num_rows, s_max, 1, 1)

    # Only the first slot (the identity) is marked valid by default.
    lib_mask = torch.zeros(num_rows, s_max, dtype=torch.bool)
    lib_mask[:, 0] = True

    for cid, entry in library_dict.items():
        if cid > vocab_size:
            continue
        rots, mask = entry
        lib_rots[cid] = rots
        lib_mask[cid] = mask

    return lib_rots, lib_mask


class Experiment:
    def __init__(
        self,
        *,
        conf: DictConfig,
        model=None,
    ):
        """Initialize the experiment from a composed configuration.

        Args:
            conf: Full configuration. The sub-configs ``experiment``,
                ``flow_matcher``, ``model``, ``data``, ``conditioning`` and
                ``wandb`` are read from it.
            model: Optional pre-built model. When ``None`` a
                ``network.VectorFieldNetwork`` is built (and warm-started if
                ``handle_warmstart`` found a checkpoint) together with an
                Adam optimizer.

        Raises:
            ValueError: If no model is supplied and ``conf.model.model_name``
                is not ``"ipa"``.
        """
        self.first_batch = None
        self._log = logging.getLogger(__name__)
        # Keep GPU ids as a list of strings. Joining them into a single string
        # and indexing it by character breaks for ids with more than one digit
        # (e.g. id 10 would be read as "1" and "0").
        self._available_gpus = [
            str(x) for x in GPUtil.getAvailable(order="memory", limit=8)
        ]

        # Configs
        self._conf = conf
        self._exp_conf = conf.experiment
        if HydraConfig.initialized() and "num" in HydraConfig.get().job:
            self._exp_conf.name = f"{self._exp_conf.name}_{HydraConfig.get().job.num}"
        self._fm_conf = conf.flow_matcher
        self._model_conf = conf.model
        self._data_conf = conf.data
        # Inject the conditioning config into the model config; struct mode is
        # lifted temporarily so the new key can be added.
        OmegaConf.set_struct(self._model_conf, False)
        self._model_conf.conditioning = conf.conditioning
        OmegaConf.set_struct(self._model_conf, True)
        self._wandb_conf = conf.wandb
        self._use_wandb = self._wandb_conf.use_wandb
        self._use_ddp = self._exp_conf.use_ddp
        print(f"Number of threads {self._exp_conf.torch_num_threads}")
        torch.set_num_threads(self._exp_conf.torch_num_threads)
        # Reduce matmul precision for better performance on GPU.
        torch.set_float32_matmul_precision("medium")
        # Reset the default dtype to float32.
        torch.set_default_dtype(torch.float32)
        torch.backends.cuda.matmul.allow_tf32 = True
        # This process always acts as the master (checkpointing/evaluation).
        self._master_proc = True

        ckpt_model, ckpt_opt = self.handle_warmstart(conf)

        # Initialize experiment objects
        self.fragment_library = torch.load(conf.data.library_path)
        self._flow_matcher = se3_fm.SE3FlowMatcher(self._fm_conf)
        self._model = model
        # NOTE(review): when a model is passed in, neither the warm-start
        # weights nor an optimizer are set up here, so update_fn would fail on
        # self._optimizer -- confirm external models are only used for evaluation.
        if self._model is None:
            if self._model_conf.model_name == "ipa":
                self._model = network.VectorFieldNetwork(
                    self._model_conf, self.flow_matcher, self.fragment_library
                )
            else:
                raise ValueError(f"Unrecognized model name {self._model_conf.model_name}")
            if ckpt_model is not None:
                # Strip the DataParallel prefix and map legacy parameter names.
                ckpt_model = {k.replace("module.", ""): v for k, v in ckpt_model.items()}
                ckpt_model = {
                    k.replace("score_model.", "vectorfield."): v
                    for k, v in ckpt_model.items()
                }
                self._model.load_state_dict(ckpt_model, strict=True)

            num_parameters = sum(p.numel() for p in self._model.parameters())
            self._exp_conf.num_parameters = num_parameters
            self._log.info(f"Number of model parameters {num_parameters}")
            self._optimizer = torch.optim.Adam(
                self._model.parameters(), lr=self._exp_conf.learning_rate
            )
            if ckpt_opt is not None:
                self._optimizer.load_state_dict(ckpt_opt)
                if conf.experiment.use_gpu:
                    # Move the restored optimizer state tensors to the GPU.
                    for state in self._optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.cuda()

        self.scorer = MoleculeScorer()

        if self._exp_conf.full_ckpt_dir is not None:
            # Set up the checkpoint location; exist_ok makes a prior existence
            # check unnecessary.
            ckpt_dir = self._exp_conf.full_ckpt_dir
            os.makedirs(ckpt_dir, exist_ok=True)
            self._exp_conf.full_ckpt_dir = ckpt_dir
            self._log.info(f"Checkpoints saved to: {ckpt_dir}")
        else:
            self._log.info("Checkpoint not being saved.")
        if self._exp_conf.eval_dir is not None:
            eval_dir = os.path.join(
                self._exp_conf.eval_dir,
                self._exp_conf.name,
                self.conf.start_time_string,
            )
            self._exp_conf.eval_dir = eval_dir
            self._log.info(f"Evaluation saved to: {eval_dir}")
        else:
            # os.devnull marks "do not keep evaluation output".
            self._exp_conf.eval_dir = os.devnull
            self._log.info("Evaluation will not be saved.")
        self._aux_data_history = deque(maxlen=100)

        # DEBUG Variables
        self._first_train_feats = None
        self._global_rank = 0

    def handle_warmstart(self, conf):
        # Warm starting
        ckpt_model = None
        ckpt_opt = None
        self.trained_epochs = 0
        self.trained_steps = 0
        if not conf.experiment.warm_start:
            return None, None

        assert conf.experiment.warm_start in ["auto", "force"]

        # check path exists
        full_ckpt_dir = conf.experiment.full_ckpt_dir
        print(f"THIS IS THE FULL CONF!\n {conf.experiment}")
        if full_ckpt_dir is not None and not os.path.exists(full_ckpt_dir):
            if conf.experiment.warm_start == "auto":
                return None, None
            if conf.experiment.warm_start == "force":
                raise ValueError(f"full_ckpt_dir {full_ckpt_dir} does not exist")

        ckpt_files = [x for x in os.listdir(full_ckpt_dir) if "pkl" in x or ".pth" in x]
        if len(ckpt_files) == 0:
            if conf.experiment.warm_start == "auto":
                return None, None
            if conf.experiment.warm_start == "force":
                raise ValueError(f"full_ckpt_dir {full_ckpt_dir} has no checkpoints")

        self._log.info(f"Warm starting from: {full_ckpt_dir}")

        ckpt_name = ckpt_files[0]
        if len(ckpt_files) != 1:
            paths = [os.path.join(full_ckpt_dir, ckpt_file) for ckpt_file in ckpt_files]
            ckpt_name = max(paths, key=os.path.getmtime).split("/")[-1]
            self._log.info("Loading most recent ckpt")
        ckpt_path = os.path.join(full_ckpt_dir, ckpt_name)
        self._log.info(f"Loading checkpoint from {ckpt_path}")
        ckpt_pkl = du.read_pkl(ckpt_path, use_torch=True)
        ckpt_model = ckpt_pkl["model"]

        if conf.experiment.use_warm_start_conf:
            OmegaConf.set_struct(conf, False)
            conf = OmegaConf.merge(conf, ckpt_pkl["conf"])
            OmegaConf.set_struct(conf, True)
        conf.experiment.warm_start = full_ckpt_dir

        # For compatibility with older checkpoints.
        if "optimizer" in ckpt_pkl:
            ckpt_opt = ckpt_pkl["optimizer"]
        if "epoch" in ckpt_pkl:
            self.trained_epochs = ckpt_pkl["epoch"]
        if "step" in ckpt_pkl:
            self.trained_steps = ckpt_pkl["step"]
        return ckpt_model, ckpt_opt

    @property
    def flow_matcher(self):
        """Flow matcher instance stored on this experiment (read-only)."""
        return self._flow_matcher

    @property
    def model(self):
        """Model currently held by this experiment (read-only)."""
        return self._model

    @property
    def conf(self):
        """Full configuration object this experiment was built from (read-only)."""
        return self._conf

    def create_dataset(self):
        """Build the train/validation datasets, the train sampler and both loaders.

        Returns:
            Tuple ``(train_loader, valid_loader, train_sampler, valid_sampler,
            lib_symmetries)``. ``valid_sampler`` is always ``None`` and
            ``lib_symmetries`` is taken from the training dataset.
        """
        exp_conf = self._exp_conf

        # Datasets are created in the order train -> val.
        datasets = {
            split: fragment_dataset.FragmentDataset(
                split=split,
                data_conf=self._data_conf,
                is_training=is_training,
            )
            for split, is_training in (("train", True), ("val", False))
        }
        train_dataset = datasets["train"]
        valid_dataset = datasets["val"]

        train_sampler = fragment_dataset.TrainSampler(
            dataset=train_dataset,
            batch_size=exp_conf.batch_size,
        )
        valid_sampler = None

        # Options shared by the training and validation loaders.
        common_loader_kwargs = dict(
            np_collate=False,
            length_batch=True,
            num_workers=exp_conf.num_loader_workers,
            drop_last=False,
        )

        train_loader = du.create_data_loader(
            train_dataset,
            sampler=train_sampler,
            batch_size=exp_conf.batch_size,
            shuffle=False,
            max_squared_res=exp_conf.max_squared_res,
            prefetch_factor=exp_conf.prefetch_factor,
            **common_loader_kwargs,
        )

        # Validation data is shuffled only for the final evaluation
        # (signalled by eval_freq == 1).
        shuffle_validation = exp_conf.eval_freq == 1
        valid_loader = du.create_data_loader(
            valid_dataset,
            sampler=valid_sampler,
            batch_size=exp_conf.eval_batch_size,
            shuffle=shuffle_validation,
            **common_loader_kwargs,
        )

        return (
            train_loader,
            valid_loader,
            train_sampler,
            valid_sampler,
            train_dataset.lib_symmetries,
        )

    def init_wandb(self):
        """Start a Weights & Biases run configured from the experiment config.

        Generates ``experiment.run_id`` when it is unset and, after the run is
        created, stores the run directory back into ``wandb.dir`` of the
        config.
        """
        self._log.info("Initializing Wandb.")
        exp_conf = self._exp_conf
        wandb_conf = self._wandb_conf

        # Snapshot the config before a run id is generated below.
        conf_dict = OmegaConf.to_container(self._conf, resolve=True)
        if exp_conf.run_id is None:
            exp_conf.run_id = wandb.util.generate_id()

        run_mode = "online"
        if wandb_conf.offline:
            run_mode = "offline"

        resume_mode = None
        if exp_conf.warm_start is not None:
            resume_mode = "auto"

        init_kwargs = {
            "project": wandb_conf.project,
            "entity": wandb_conf.entity,
            "name": exp_conf.name,
            "config": dict(eu.flatten_dict(conf_dict)),
            "dir": wandb_conf.dir,
            "id": exp_conf.run_id,
            "tags": wandb_conf.tags,
            "group": wandb_conf.group,
            "mode": run_mode,
            "job_type": wandb_conf.job_type,
            "resume": resume_mode,
        }
        wandb.init(**init_kwargs)

        wandb_conf.dir = wandb.run.dir
        self._log.info(
            f"Wandb: run_id={exp_conf.run_id}, run_dir={wandb_conf.dir}"
        )

    def start_training(self, return_logs=False):
        """Place the model on its device, build the data pipeline and train.

        Args:
            return_logs: When True, collect what each ``train_epoch`` call
                returns and return the list at the end.

        Returns:
            The list of per-epoch logs if ``return_logs`` is True, else ``0``.

        Raises:
            ValueError: In GPU mode, if ``num_gpus`` is smaller than 1 or
                larger than the number of available GPUs.
        """
        # The Hydra job number (multirun) selects which GPU this replica uses.
        if HydraConfig.initialized() and "num" in HydraConfig.get().job:
            replica_id = int(HydraConfig.get().job.num)
        else:
            replica_id = 0
        if self._use_wandb and replica_id == 0:
            self.init_wandb()
        assert not self._exp_conf.use_ddp or self._exp_conf.use_gpu

        # GPU mode
        if torch.cuda.is_available() and self._exp_conf.use_gpu:
            num_gpus = self._exp_conf.num_gpus
            # Single-GPU mode
            if num_gpus == 1:
                try:
                    gpu_id = self._available_gpus[replica_id]
                    device = f"cuda:{gpu_id}"
                except IndexError:
                    device = "cuda:0"
                    self._log.warning("Error on available gpus, trying with device 0")
                self._model = self.model.to(device)
                self._log.info(f"Using device: {device}")
            # Multi-GPU mode (DataParallel)
            elif num_gpus > 1:
                # Fail when FEWER GPUs are available than requested; having
                # more than requested is fine.
                if len(self._available_gpus) < num_gpus:
                    raise ValueError(
                        f"require {self._exp_conf.num_gpus} GPUs, but only {len(self._available_gpus)} GPUs available "
                    )
                device_ids = [
                    f"cuda:{i}"
                    for i in self._available_gpus[:num_gpus]
                ]
                self._log.info(
                    f"Multi-GPU training on GPUs in DP mode: {device_ids}"
                )
                gpu_id = self._available_gpus[replica_id]
                device = f"cuda:{gpu_id}"
                self._model = DP(self._model, device_ids=device_ids)
                self._model = self.model.to(device)
            else:
                # Without this, `device` would stay unbound and fail later
                # with an unrelated error.
                raise ValueError(
                    f"num_gpus must be >= 1 in GPU mode, got {num_gpus}"
                )
        else:
            device = "cpu"
            self._log.info(f"Using device: {device}")
            self._model = self.model.to(device)

        self._model.train()
        (
            train_loader,
            valid_loader,
            train_sampler,
            valid_sampler,
            lib_symmetries,
        ) = self.create_dataset()

        # Build the dense symmetry library tensors for the model.
        self._log.info("Creating dense symmetry library for model...")
        lib_rots, lib_mask = create_dense_symmetry_lib(
            lib_symmetries,
            self._data_conf.vocab_size,
            self._data_conf.S_max,
        )

        # DataParallel does not forward custom methods, so reach through the
        # wrapper to the underlying network when injecting the library.
        core_model = self._model.module if isinstance(self._model, DP) else self._model
        core_model.set_symmetry_library(lib_rots, lib_mask, device=device)

        logs = []
        for epoch in range(self.trained_epochs, self._exp_conf.num_epoch):
            if train_sampler is not None:
                train_sampler.set_epoch(epoch)
            if valid_sampler is not None:
                valid_sampler.set_epoch(epoch)
            self.trained_epochs = epoch
            epoch_log = self.train_epoch(
                train_loader, valid_loader, device, return_logs=return_logs
            )
            if return_logs:
                logs.append(epoch_log)

        self._log.info("Done")
        if return_logs:
            return logs
        return 0

    def update_fn(self, data, debug=False):
        """Updates the state using some data and returns metrics."""
        self._optimizer.zero_grad()
        # sample time t for the forward process on the batch
        data = self._flow_matcher.sample_times(data) # shape (B,)
        # Apply flow matching forward marginal to get noisy state X_t
        data = self._flow_matcher.forward_marginal(data)
        loss, aux_data = self.loss_fn(data)
        loss.backward()

        if debug:
            for name, param in self._model.named_parameters():
                if param.grad is None:
                    print(f"NO GRAD FOR PARAMETERS  {name}")

        torch.nn.utils.clip_grad_norm_(self._model.parameters(), 1.0)
        self._optimizer.step()
        return loss, aux_data

    def train_epoch(self, train_loader, valid_loader, device, return_logs=False):
        """Train for one pass over ``train_loader``.

        For every batch this runs ``update_fn`` and then, on the schedules set
        in the experiment config, logs to the terminal (``log_freq``), writes
        a checkpoint (``ckpt_freq``), runs ``eval_fn`` (``eval_freq``) and
        logs to Wandb. With ``early_ckpt`` set, a checkpoint and an
        evaluation are also triggered at step 10.

        Args:
            train_loader: Iterable of training feature dicts. Batches that
                contain the key ``"dummy_batch"`` are skipped.
            valid_loader: Validation loader forwarded to ``eval_fn``.
            device: Device that tensor features are moved to.
            return_logs: When True, collect the per-step losses.

        Returns:
            The list of per-step losses if ``return_logs`` is True, else
            ``None``.

        Raises:
            Exception: If the loss becomes NaN.
        """
        # Local import: only needed for throwaway evaluation output.
        import tempfile

        log_losses = defaultdict(list)
        global_logs = []
        log_time = time.time()
        step_time = time.time()
        # Number of steps taken since the last terminal log line.
        steps_since_log = 0

        for train_feats in train_loader:
            if "dummy_batch" in train_feats:
                self._log.error("Dummy batch")
                continue
            train_feats = tree.map_structure(
                lambda x: x.to(device) if torch.is_tensor(x) else x, train_feats
            )
            loss, aux_data = self.update_fn(train_feats)

            if return_logs:
                global_logs.append(loss)
            for k, v in aux_data.items():
                log_losses[k].append(du.move_to_np(v))
            self.trained_steps += 1
            steps_since_log += 1

            # Logging to terminal
            if (
                self.trained_steps == 1
                or self.trained_steps % self._exp_conf.log_freq == 0
            ):
                elapsed_time = time.time() - log_time
                log_time = time.time()
                # Use the steps actually taken in this window (the first log
                # line covers a single step, not log_freq steps).
                step_per_sec = steps_since_log / max(elapsed_time, 1e-9)
                # Average each quantity over the whole window; per-step values
                # are reduced to a scalar first so any array shape is accepted.
                loss_log = " ".join(
                    [
                        f"{k}={np.mean([np.mean(x) for x in v]):.4f}"
                        for k, v in log_losses.items()
                        if "batch" not in k
                    ]
                )
                self._log.info(
                    f"[{self.trained_steps}]: {loss_log}, steps/sec={step_per_sec:.5f}"
                )
                log_losses = defaultdict(list)
                steps_since_log = 0
            # Take checkpoint
            if ((self.trained_steps % self._exp_conf.ckpt_freq) == 0) or (
                self._exp_conf.early_ckpt and self.trained_steps == 10
            ):
                if self._master_proc and self._exp_conf.full_ckpt_dir is not None:
                    self._log.info("Take checkpoint")
                    ckpt_path = os.path.join(
                        self._exp_conf.full_ckpt_dir, f"step_{self.trained_steps}.pth"
                    )
                    du.write_checkpoint(
                        ckpt_path,
                        self.model.state_dict(),
                        self._conf,
                        self._optimizer.state_dict(),
                        self.trained_epochs,
                        self.trained_steps,
                        logger=self._log,
                        use_torch=True,
                    )
            ckpt_metrics = None
            eval_time = None
            if ((self.trained_steps % self._exp_conf.eval_freq) == 0) or (
                self._exp_conf.early_ckpt and self.trained_steps == 10
            ):
                if self._master_proc:
                    # Run evaluation
                    start_time = time.time()
                    scratch = None
                    if self._exp_conf.eval_dir == os.devnull:
                        # Evaluation output is not kept. A directory cannot be
                        # created beneath os.devnull, so evaluate into a
                        # temporary directory that is removed afterwards.
                        scratch = tempfile.TemporaryDirectory()
                        eval_dir = scratch.name
                    else:
                        eval_dir = os.path.join(
                            self._exp_conf.eval_dir, f"step_{self.trained_steps}"
                        )
                        os.makedirs(eval_dir, exist_ok=True)

                    try:
                        self._log.info(
                            f"Running evaluation at EP "
                            f"{self.trained_epochs} step {self.trained_steps} in {eval_dir}"
                        )

                        ckpt_metrics = self.eval_fn(
                            eval_dir,
                            valid_loader,
                            device,
                            self._exp_conf.max_eval_batches,
                        )
                    finally:
                        if scratch is not None:
                            scratch.cleanup()
                    eval_time = time.time() - start_time
                    self._log.info(f"Finished evaluation in {eval_time:.2f}s")

            # Remote log to Wandb.
            if self._use_wandb and self._master_proc:
                step_time = time.time() - step_time
                example_per_sec = self._exp_conf.batch_size / step_time
                step_time = time.time()
                wandb_logs = {
                    "loss": loss,
                    "rotation_loss": aux_data["rot_loss"],
                    "translation_loss": aux_data["trans_loss"],
                    "cat_loss": aux_data["cat_loss"],
                    "batch_size": aux_data["examples_per_step"],
                    "frag_length": aux_data["frag_length"],
                    "examples_per_sec": example_per_sec,
                    "num_epochs": self.trained_epochs,
                }

                # Losses stratified by the sampled time t.
                wandb_logs.update(
                    eu.t_stratified_loss(
                        du.move_to_np(train_feats["t"]),
                        du.move_to_np(aux_data["batch_rot_loss"]),
                        loss_name="rot_loss",
                    )
                )

                wandb_logs.update(
                    eu.t_stratified_loss(
                        du.move_to_np(train_feats["t"]),
                        du.move_to_np(aux_data["batch_trans_loss"]),
                        loss_name="trans_loss",
                    )
                )

                wandb_logs.update(
                    eu.t_stratified_loss(
                        du.move_to_np(train_feats["t"]),
                        du.move_to_np(aux_data["batch_cat_loss"]),
                        loss_name="cat_loss",
                    )
                )

                if ckpt_metrics is not None:
                    # Log the mean of every numeric evaluation metric column.
                    wandb_logs["eval_time"] = eval_time
                    for metric_name in ckpt_metrics.columns:
                        if pd.api.types.is_numeric_dtype(ckpt_metrics[metric_name]):
                            wandb_logs[metric_name] = ckpt_metrics[metric_name].mean()

                wandb.log(wandb_logs, step=self.trained_steps)

            if torch.isnan(loss):
                if self._use_wandb:
                    wandb.alert(
                        title="Encountered NaN loss",
                        text=f"Loss NaN after {self.trained_epochs} epochs, {self.trained_steps} steps",
                    )
                raise Exception("NaN encountered")

        if return_logs:
            return global_logs

    def eval_fn(
        self,
        eval_dir,
        valid_loader,
        device,
        max_batches,
    ):
        """Sample molecules for validation batches, reconstruct atoms and score them.

        Args:
            eval_dir: Output directory (created if missing). Receives
                ``generated_rigids.pt``, ``generated_atomic_data.pt`` and
                ``metrics.csv``, plus ``generated_molecules.pt`` for the
                ``qmugs_fragmented`` dataset.
            valid_loader: Iterable of validation feature dicts; the keys
                ``frag_ids``, ``frag_mask`` and ``rigids_0`` are read, and
                ``eval_target`` when present.
            device: Device that tensor features are moved to.
            max_batches: Maximum number of batches to process, or ``None``
                for no limit.

        Returns:
            A one-row ``pandas.DataFrame`` holding the computed metrics.
        """
        os.makedirs(eval_dir, exist_ok=True)

        samples = []
        target_chunks = []

        # 1. Sampling over the validation loader
        for batch_idx, valid_feats in enumerate(valid_loader):
            if max_batches is not None and batch_idx >= max_batches:
                self._log.info(f"Reached max batches ({max_batches}). Stopping inference.")
                break

            valid_feats = tree.map_structure(
                lambda x: x.to(device) if torch.is_tensor(x) else x, valid_feats
            )

            n_samples, n_fragments = valid_feats["frag_ids"].shape
            flow_mask = valid_feats["frag_mask"]

            # Ground truth, passed to sample_ref as the imputation values.
            true_rigids = ru.Rigid.from_tensor_7(valid_feats["rigids_0"])
            true_cat = torch.nn.functional.one_hot(
                valid_feats["frag_ids"].long(),
                num_classes=self._fm_conf.cat.vocab_size,
            ).float()

            # Keep the evaluation targets (as numpy) when the batch has them.
            if "eval_target" in valid_feats:
                target_chunks.append(du.move_to_np(valid_feats["eval_target"]))

            # No impute_mask is given; the ground truth is supplied via
            # `impute` / `impute_cat`.
            start_state = self._flow_matcher.sample_ref(
                n_samples=n_samples,
                n_fragments=n_fragments,
                device=device,
                flow_mask=flow_mask,
                impute_mask=None,
                impute=true_rigids,
                impute_cat=true_cat,
                as_tensor_7=True,
            )
            # Merge the sampled starting state into the batch before inference.
            valid_feats.update(start_state)

            frag_ids = du.move_to_np(valid_feats["frag_ids"])
            frag_mask = du.move_to_np(valid_feats["frag_mask"].bool())

            infer_out = self.inference_fn(valid_feats)

            # Last entry of the trajectory = final generated frames.
            final_rigids = infer_out["rigid_traj"][-1]
            use_pred_classes = self._fm_conf.flow_cat
            if use_pred_classes:
                cat_ids_pred = infer_out['cat_ids']

            for mol_idx in range(final_rigids.shape[0]):
                n_valid = int(np.sum(frag_mask[mol_idx]).item())

                # Predicted classes if they were sampled, ground truth otherwise.
                if use_pred_classes:
                    mol_frag_ids = cat_ids_pred[mol_idx][:n_valid].astype(np.int64)
                else:
                    mol_frag_ids = frag_ids[mol_idx][:n_valid]

                samples.append(
                    {
                        "frag_ids": mol_frag_ids,
                        "rigids": final_rigids[mol_idx][:n_valid],
                        "num_frags": n_valid,
                    }
                )

        # 2. Reconstruction (fragments -> atoms)
        self._log.info("Reconstructing atomic coordinates from collected rigids...")
        pos_list, z_list, bond_overrides = reconstruct_atoms(samples, self.fragment_library)

        # Flatten the targets and trim them to the number of generated samples.
        all_targets = None
        if target_chunks:
            all_targets = np.concatenate(target_chunks, axis=0)[:len(pos_list)]

        # 3. Save the raw generated frames and the reconstructed atoms
        torch.save(samples, os.path.join(eval_dir, "generated_rigids.pt"))
        atomic_payload = {
            "generated_pos": pos_list,
            "generated_z": z_list,
            "bond_overrides": bond_overrides,
        }
        torch.save(atomic_payload, os.path.join(eval_dir, "generated_atomic_data.pt"))

        # 4. Scoring, selected by dataset name
        metrics_dict = {}
        dataset_name = self._data_conf.dataset_name
        if dataset_name == "qm9_fragmented":
            self._log.info("Computing QM9 metrics...")
            metrics_dict.update(self.scorer.score_qm9(pos_list, z_list, bond_overrides))
            # Conditional-generation metrics need evaluation targets.
            cond_type = self._model_conf.conditioning.type
            if all_targets is not None:
                if cond_type == "composition":
                    self._log.info("Computing Composition metrics...")
                    metrics_dict.update(
                        self.scorer.score_composition(z_list, all_targets)
                    )
                elif cond_type == "structure":
                    self._log.info("Computing Structure (Tanimoto) metrics...")
                    metrics_dict.update(
                        self.scorer.score_structure(pos_list, z_list, all_targets)
                    )
        elif dataset_name == "geom_fragmented":
            self._log.info("Computing GEOM-Drugs metrics...")
            metrics_dict = self.scorer.score_geom_drugs(pos_list, z_list, bond_overrides)
        elif dataset_name == "qmugs_fragmented":
            self._log.info("Computing QMugs metrics...")
            metrics_dict, gen_mols = self.scorer.score_geom_drugs(
                pos_list, z_list, bond_overrides, return_molecules=True
            )
            # Also keep the generated molecules returned by the scorer.
            torch.save(gen_mols, os.path.join(eval_dir, "generated_molecules.pt"))

        self._log.info(f"Evaluation Results: {metrics_dict}")
        metrics_df = pd.DataFrame([metrics_dict])
        metrics_df.to_csv(os.path.join(eval_dir, "metrics.csv"), index=False)
        self.scorer.metrics.reset()

        return metrics_df
    
    def evaluate_composition_task(self, eval_dataset_path, output_dir, device):
        """Generate molecules conditioned on target compositions and score them.

        Args:
            eval_dataset_path: Path to a ``torch.save``-d list of items with
                the keys ``'target_counts'`` and ``'num_frags'``.
            output_dir: Directory (created if missing) that receives
                ``generated_composition_data.pt`` and
                ``composition_metrics.csv``.
            device: Device on which batches are built and sampled.

        Returns:
            A one-row ``pandas.DataFrame`` with the composition metrics.
        """
        self._log.info(f"Starting Composition Task Evaluation using {eval_dataset_path}")
        os.makedirs(output_dir, exist_ok=True)

        # NOTE: torch.load unpickles the file; only use trusted dataset files.
        dataset = torch.load(eval_dataset_path)

        batch_size = self._exp_conf.eval_batch_size
        num_samples = len(dataset)
        generated_z_list = []
        generated_pos_list = []
        all_targets_list = []

        # Group items of similar size together to limit padding.
        dataset.sort(key=lambda x: x['num_frags'])

        self._model.eval()

        for start in range(0, num_samples, batch_size):
            batch_items = dataset[start : start + batch_size]
            n_items = len(batch_items)
            max_len = max(item['num_frags'] for item in batch_items)

            # Fragment ids start out as the mask token (index vocab_size).
            frag_ids = torch.full(
                (n_items, max_len),
                self._data_conf.vocab_size,
                dtype=torch.long,
                device=device,
            )
            frag_mask = torch.zeros((n_items, max_len), dtype=torch.float32, device=device)
            eval_target = torch.zeros((n_items, 5), dtype=torch.float32, device=device)

            scaled_targets = []
            for row, item in enumerate(batch_items):
                frag_mask[row, : item['num_frags']] = 1.0
                target = item['target_counts'].to(device)
                eval_target[row] = target
                # The condition is the target scaled by 1/29.
                scaled_targets.append(target / 29.0)
            condition = torch.stack(scaled_targets)

            # Reference sample with nothing imputed.
            noise_batch = self._flow_matcher.sample_ref(
                n_samples=n_items,
                n_fragments=max_len,
                device=device,
                flow_mask=frag_mask,
                impute_mask=None,
                impute=None,
                impute_cat=None,
                as_tensor_7=True,
            )

            batch_feats = {
                "frag_ids": frag_ids,
                "frag_mask": frag_mask,
                "condition": condition,
                "rigids_0": None,
                **noise_batch,
            }

            infer_out = self.inference_fn(batch_feats)

            # Last entry of the trajectory = final generated frames.
            final_rigids = infer_out["rigid_traj"][-1]
            cat_ids_pred = infer_out['cat_ids']

            # Trim each molecule to its own number of fragments.
            frag_counts = [int(item['num_frags']) for item in batch_items]
            batch_gen_data = [
                {
                    "frag_ids": cat_ids_pred[row][:count].astype(np.int64),
                    "rigids": final_rigids[row][:count],
                }
                for row, count in enumerate(frag_counts)
            ]

            # Atom reconstruction provides the atomic numbers used for scoring.
            pos_list, z_list, _ = reconstruct_atoms(batch_gen_data, self.fragment_library)

            generated_z_list.extend(z_list)
            generated_pos_list.extend(pos_list)
            all_targets_list.append(du.move_to_np(eval_target))

            if (start // batch_size) % 10 == 0:
                self._log.info(f"Processed {start}/{num_samples} samples...")

        all_targets = np.concatenate(all_targets_list, axis=0)

        # Persist the generated data together with the targets.
        payload = {
            "generated_pos": generated_pos_list,
            "generated_z": generated_z_list,
            "targets": all_targets,
        }
        torch.save(payload, os.path.join(output_dir, "generated_composition_data.pt"))

        self._log.info("Computing Composition Match Rate...")
        results = self.scorer.score_composition(generated_z_list, all_targets)

        self._log.info(f"Final Task 1 Results: {results}")

        metrics_df = pd.DataFrame([results])
        metrics_df.to_csv(os.path.join(output_dir, "composition_metrics.csv"), index=False)
        return metrics_df
    
    def evaluate_structure_task(self, eval_dataset_path, output_dir, device):
        """Generate structures for an evaluation set and save them with the ground truth.

        The dataset is loaded with ``torch.load`` and sorted in place by
        ``num_frags``. For every batch the model samples fragments conditioned
        on each item's ``target_fingerprint``; both the generated fragments and
        the item's own ground-truth fragments (``gt_frag_ids``, ``gt_rots``,
        ``gt_trans``) are passed through ``reconstruct_atoms`` so the saved
        lists stay aligned index by index.

        Args:
            eval_dataset_path: Path handed to ``torch.load``; the loaded object
                must support ``sort``/slicing and hold dict-like items.
            output_dir: Directory (created if missing) that receives
                ``generated_structure_data.pt``.
            device: Device on which the batch tensors are allocated.

        Returns:
            None. Results are written to disk only.
        """
        self._log.info(f"Starting Structure Task Evaluation using {eval_dataset_path}")
        os.makedirs(output_dir, exist_ok=True)

        # Sorting is safe because every item carries its own ground truth.
        dataset = torch.load(eval_dataset_path)
        dataset.sort(key=lambda entry: entry['num_frags'])

        batch_size = self._exp_conf.eval_batch_size
        num_samples = len(dataset)

        # Index-aligned accumulators; this dict is what ends up on disk.
        collected = {
            "generated_pos": [],
            "generated_z": [],
            "true_pos": [],
            "true_z": [],
        }

        self._model.eval()

        for i in range(0, num_samples, batch_size):
            batch_items = dataset[i : i + batch_size]
            B = len(batch_items)

            # --- Batch tensors, conditions and raw ground truth ---
            max_len = max(item['num_frags'] for item in batch_items)
            frag_ids = torch.full(
                (B, max_len), self._data_conf.vocab_size, dtype=torch.long, device=device
            )
            frag_mask = torch.zeros((B, max_len), dtype=torch.float32, device=device)

            fingerprints = []
            gt_raw = []
            for row, item in enumerate(batch_items):
                frag_mask[row, :item['num_frags']] = 1.0
                fingerprints.append(item['target_fingerprint'].to(device))
                gt_raw.append({
                    'frag_ids': item['gt_frag_ids'],
                    'rots': item['gt_rots'],
                    'trans': item['gt_trans'],
                })

            # --- Sample from the reference distribution and run inference ---
            batch_feats = {
                "frag_ids": frag_ids,
                "frag_mask": frag_mask,
                "condition": torch.stack(fingerprints),
                "rigids_0": None,
            }
            batch_feats.update(
                self._flow_matcher.sample_ref(
                    n_samples=B,
                    n_fragments=max_len,
                    device=device,
                    flow_mask=frag_mask,
                    impute_mask=None,
                    impute=None,
                    impute_cat=None,
                    as_tensor_7=True,
                )
            )
            infer_out = self.inference_fn(batch_feats)

            # --- Generated molecules: trim padding, then rebuild atoms ---
            final_rigids = infer_out["rigid_traj"][-1]
            cat_ids_pred = infer_out['cat_ids']
            frag_counts = [int(item['num_frags']) for item in batch_items]
            gen_records = [
                {
                    "frag_ids": cat_ids_pred[row][:count].astype(np.int64),
                    "rigids": final_rigids[row][:count],
                }
                for row, count in enumerate(frag_counts)
            ]
            gen_pos, gen_z, _ = reconstruct_atoms(gen_records, self.fragment_library)
            collected["generated_pos"].extend(gen_pos)
            collected["generated_z"].extend(gen_z)

            # --- Ground truth: centre translations, build rigids, rebuild atoms ---
            gt_records = []
            for raw in gt_raw:
                gt_trans = raw['trans'].float()
                gt_rots = raw['rots'].float()
                gt_trans = gt_trans - gt_trans.mean(dim=0, keepdim=True)
                gt_rigids = assemble_rigid_mat(gt_rots, gt_trans)
                gt_records.append({
                    "frag_ids": raw['frag_ids'].long(),
                    "rigids": gt_rigids.to_tensor_7(),
                })
            gt_pos, gt_z, _ = reconstruct_atoms(gt_records, self.fragment_library)
            collected["true_pos"].extend(gt_pos)
            collected["true_z"].extend(gt_z)

            # Progress message every tenth batch.
            if (i // batch_size) % 10 == 0:
                self._log.info(f"Processed {i}/{num_samples}...")

        torch.save(collected, os.path.join(output_dir, "generated_structure_data.pt"))

        self._log.info("Evaluation done. Data saved.")
        return None

    def loss_fn(self, batch):
        """Computes loss and auxiliary data.

        Args:
            batch: Batched data.
            model_out: Output of model ran on batch.

        Returns:
            loss: Final training loss scalar.
            aux_data: Additional logging data.
        """
        # initialize self-conditioned features as none
        batch["sc_logits"] = None
        # self-conditioning with 50% probability during training if enabled
        if self._model_conf.use_self_conditioning and self._model.training and torch.rand(1) < self._model_conf.sc_training_prob:
            with torch.no_grad():
                out_pre = self.model(batch)
                batch["sc_logits"] = out_pre["cat_logits"].detach()
        model_out = self.model(batch)
        loss_mask = batch["frag_mask"]
        B, N = loss_mask.shape
        batch_loss_mask = torch.any(loss_mask, dim=-1)

        # Initialize losses to 0
        trans_loss = torch.zeros((B,), device=loss_mask.device)
        rot_loss = torch.zeros((B,), device=loss_mask.device)
        cat_loss = torch.zeros((B,), device=loss_mask.device)
        
        # --- 1. Translation Loss ---
        if self._fm_conf.flow_trans:
            gt_trans_u_t = batch["trans_vectorfield"]
            trans_vectorfield_scaling = batch["trans_vectorfield_scaling"]
            # Apply mask to prediction
            pred_trans_v_t = model_out["trans_vectorfield"] * loss_mask[..., None]
            trans_vectorfield_mse = (gt_trans_u_t - pred_trans_v_t) ** 2 * loss_mask[..., None]
            trans_loss = torch.sum(
                trans_vectorfield_mse / (trans_vectorfield_scaling ** 2 + 1e-10),
                dim=(-1, -2),
            ) / (loss_mask.sum(dim=-1) + 1e-10)
            trans_loss *= self._exp_conf.trans_loss_weight
        
        # --- 2. Rotation Loss ---
        if self._fm_conf.flow_rot:
            # Prepare Candidates [B, N, S, 3]
            gt_candidates = self.flow_matcher._so3_fm.compute_symmetric_target_vectors(
                rot_0=ru.Rigid.from_tensor_7(batch["rigids_0"]).get_rots().get_rot_mats(),
                rot_t=batch["rot_t"],
                t=batch["t"],
                symmetries=batch["symmetries"]
            )
            # Prepare Prediction [B, N, 3]
            pred_rot_mat = model_out["rot_vectorfield"].double()
            rot_t_double = batch["rot_t"].double()
            B, N, _, _ = pred_rot_mat.shape
            # Matrix -> Vector conversion
            pred_rot_mat_flat = rearrange(pred_rot_mat, "b n r c -> (b n) r c")
            rot_t_flat = rearrange(rot_t_double, "b n r c -> (b n) r c")
            pred_at_id_flat = pt_to_identity(rot_t_flat, pred_rot_mat_flat)
            pred_rot_vec_flat = hat_inv(pred_at_id_flat).float()
            pred_rot_vec = rearrange(pred_rot_vec_flat, "(b n) c -> b n c", b=B, n=N)
            # Apply residue mask to prediction
            pred_rot_vec = pred_rot_vec * loss_mask[..., None]
            # Setup arguments for SO3FM loss functions
            loss_args = {
                "gt_candidates": gt_candidates,
                "pred_vec": pred_rot_vec,
                "sym_mask": batch["sym_mask"],
                "scaling": batch["rot_vectorfield_scaling"],
                "separate_rot_loss": self._exp_conf.separate_rot_loss
            }
            # Select Style
            rot_style = self._exp_conf.rotation_loss_style
            if rot_style == "af3":
                # Align target to model prediction (Min-of-N)
                best_loss = self.flow_matcher._so3_fm.compute_loss_af3(**loss_args)
            elif rot_style == "geodiff":
                # Align target to noisy state at time t
                best_loss = self.flow_matcher._so3_fm.compute_loss_geodiff(**loss_args)
            elif rot_style == "naive":
                # Direct MSE against index-0 target
                best_loss = self.flow_matcher._so3_fm.compute_loss_naive(**loss_args)
            else:
                raise ValueError(f"Unknown rotation_loss_style: {rot_style}")
            # Safety check for NaNs/Infs from masked symmetries
            if torch.any(torch.isinf(best_loss) & loss_mask.bool()):
                print("All symmetries masked for some *present* fragments; cannot compute rotation loss.")
            best_loss = torch.nan_to_num(best_loss, posinf=0.0)
            # Reduce
            rot_loss = torch.sum(best_loss * loss_mask, dim=-1) / (loss_mask.sum(dim=-1) + 1e-10)
            rot_loss *= self._exp_conf.rot_loss_weight
        
        # --- 3. Categorical Loss ---
        if self._fm_conf.flow_cat:
            cat_loss = self.flow_matcher._cat_fm.compute_loss(model_out, batch)
            cat_loss = self._exp_conf.cat_loss_weight * cat_loss

        # Sum final loss
        final_loss = rot_loss + trans_loss + cat_loss

        def normalize_loss(x):
            return x.sum() / (batch_loss_mask.sum() + 1e-10)

        aux_data = {
            "batch_train_loss": final_loss,
            "batch_rot_loss": rot_loss,
            "batch_trans_loss": trans_loss,
            "batch_cat_loss": cat_loss,
            "total_loss": normalize_loss(final_loss),
            "rot_loss": normalize_loss(rot_loss),
            "trans_loss": normalize_loss(trans_loss),
            "cat_loss": normalize_loss(cat_loss),
            "examples_per_step": torch.tensor(B),
            "frag_length": torch.mean(torch.sum(loss_mask, dim=-1)),
        }

        self._aux_data_history.append(
            {"aux_data": aux_data, "model_out": model_out, "batch": batch}
        )
        
        assert final_loss.shape == (B,)
        assert batch_loss_mask.shape == (B,)
        return normalize_loss(final_loss), aux_data

    def _set_t_feats(self, feats, t, t_placeholder):
        feats["t"] = t * t_placeholder
        (
            rot_vectorfield_scaling,
            trans_vectorfield_scaling,
        ) = self.flow_matcher.vectorfield_scaling(t)
        feats["rot_vectorfield_scaling"] = rot_vectorfield_scaling * t_placeholder
        feats["trans_vectorfield_scaling"] = trans_vectorfield_scaling * t_placeholder
        return feats
    
    def inference_fn(
        self,
        data_init
    ):
        """
        Inference function: integrates the reverse flow from t=1 down to ``min_t``.

        The SO(3) flow matcher is switched to inference scaling for the
        duration of the call and switched back to training scaling afterwards,
        even if sampling raises.

        Args:
            data_init: Initial data values for sampling. Must contain
                ``"rigids_t"`` and ``"frag_mask"``; ``"cat_t"`` is read when
                categorical flow is enabled. The dict is deep-copied, so the
                caller's object is not modified.

        Returns:
            Dict with ``"rigid_traj"``: a list holding the centred initial
            rigids and the rigids after the last reverse step (converted with
            ``du.move_to_np``). When categorical flow is enabled it also holds
            ``"cat_ids"``, the final discrete fragment ids with any remaining
            MASK tokens replaced by the argmax of the last step's logits.
        """

        # Run reverse process.
        sample_feats = copy.deepcopy(data_init)
        device = sample_feats["rigids_t"].device

        # use inference scaling for SO(3)
        self.flow_matcher._so3_fm.training = False
        try:
            # force initial centering of fragments
            # Extract current translation
            init_rigids = ru.Rigid.from_tensor_7(sample_feats["rigids_t"])
            trans = init_rigids.get_trans() # [B, N, 3]
            # Calculate Center of Mass (masked)
            mask = sample_feats["frag_mask"][..., None] # [B, N, 1]
            com = torch.sum(trans * mask, dim=1, keepdim=True) / (torch.sum(mask, dim=1, keepdim=True) + 1e-6)
            # Subtract CoM
            new_trans = trans - com
            # mask the CoM subtraction in masked positions
            new_trans *= mask
            # Update rigids
            init_rigids_centered = ru.Rigid(
                rots=init_rigids.get_rots(),
                trans=new_trans
            )
            sample_feats["rigids_t"] = init_rigids_centered.to_tensor_7()

            batch_size = sample_feats["rigids_t"].shape[0]
            t_placeholder = torch.ones((batch_size,), device=device)
            num_t = self._data_conf.num_t
            min_t = self._data_conf.min_t
            # Time grid from 1.0 down to min_t; dt is the (constant) step size.
            reverse_steps = np.linspace(min_t, 1.0, num_t)[::-1]
            dt = reverse_steps[0] - reverse_steps[1]
            start_rigids_np = du.move_to_np(sample_feats["rigids_t"])
            all_rigids = [start_rigids_np]
            current_sc_logits = None # initialize self-conditioning logits (at t=1, no SC)
            with torch.no_grad():
                for t in reverse_steps:
                    sample_feats = self._set_t_feats(sample_feats, t, t_placeholder)
                    # inject self-conditioning logits
                    sample_feats["sc_logits"] = current_sc_logits
                    # explicit time for categorical flow
                    sample_feats["t_cat"] = t * t_placeholder
                    model_out = self.model(sample_feats)
                    # update self-conditioning logits for next step if enabled
                    if self._model_conf.use_self_conditioning:
                        current_sc_logits = model_out["cat_logits"]
                    rot_vectorfield = model_out["rot_vectorfield"]
                    trans_vectorfield = model_out["trans_vectorfield"]
                    flow_mask = sample_feats["frag_mask"]

                    # ---- categorical handling ----
                    cat_vectorfield = None
                    s_t = None
                    if self._fm_conf.flow_cat:
                        s_t = sample_feats["cat_t"]
                        cat_vectorfield = model_out["cat_logits"]

                    _, _, rigids_t, cat_t = self.flow_matcher.reverse(
                        rigid_t=ru.Rigid.from_tensor_7(sample_feats["rigids_t"]),
                        cat_t=s_t,
                        rot_vectorfield=rot_vectorfield,
                        trans_vectorfield=trans_vectorfield,
                        cat_vectorfield=cat_vectorfield,
                        flow_mask=flow_mask,
                        t=t,
                        dt=dt,
                    )

                    sample_feats["rigids_t"] = rigids_t.to_tensor_7().to(device)
                    # Only the state after the last step is recorded in the trajectory.
                    if t == reverse_steps[-1]:
                        all_rigids.append(du.move_to_np(rigids_t.to_tensor_7()))
                    if self._fm_conf.flow_cat:
                        sample_feats["cat_t"] = cat_t

            # convert final categorical distribution to hard classes
            cat_final = None
            if self._fm_conf.flow_cat:
                # sample_feats["cat_t"] already contains discrete indices [B, N]
                cat_ids = sample_feats["cat_t"]
                # Robustness: If the reverse process stopped before t=0 (e.g. min_t=0.01),
                # some tokens might still be the MASK token. We force-unmask them using
                # the logits from the final step (model_out is from the last iteration).
                mask_token_idx = self._data_conf.vocab_size # Assumes MASK is at index V
                is_still_masked = (cat_ids == mask_token_idx)
                if is_still_masked.any():
                    # model_out["cat_logits"] is [B, N, V] (prediction over valid classes)
                    # Take argmax to get the most likely valid fragment
                    best_guess = torch.argmax(model_out["cat_logits"], dim=-1) # [B, N]
                    # Replace masks with best guess
                    cat_ids = torch.where(is_still_masked, best_guess, cat_ids)
                cat_final = du.move_to_np(cat_ids)

            ret = {
                "rigid_traj": all_rigids,  # List of [B, N, 7]
            }
            if cat_final is not None:
                ret["cat_ids"] = cat_final

            return ret
        finally:
            # reset scaling; done in `finally` so a failure while sampling
            # cannot leave the flow matcher in inference mode for later
            # training steps.
            self.flow_matcher._so3_fm.training = True

@hydra.main(version_base=None, config_path="config/", config_name="base")
def run(conf: DictConfig) -> None:
    """Hydra entry point: build an ``Experiment`` from ``conf`` and start training.

    Args:
        conf: Configuration composed by Hydra from ``config/base``.

    Returns:
        Whatever ``Experiment.start_training`` returns.
    """
    # Workaround for the wandb issue tracked at
    # https://github.com/wandb/wandb/issues/1525
    os.environ["WANDB_START_METHOD"] = "thread"
    return Experiment(conf=conf).start_training()


# Script entry point: only runs when the file is executed directly, not when
# it is imported. `run` is wrapped by `hydra.main`, which supplies `conf`.
if __name__ == "__main__":
    run()
