"""
Data handling and processing for MDShortcut molecular dynamics diffusion models.

This module provides classes and utilities for:
- Loading and processing atomic structure datasets in various formats (ExtXYZ, empty cells)
- Reading and merging material property data from JSON/YAML files
- Creating batches of material samples for training and inference
- Handling periodic boundary conditions and neighbor lists
- Property-based filtering and augmentation

The main classes are:
- MaterialDataset: PyTorch dataset for atomic structures and properties
- Sample: Individual material sample with atoms, positions, and properties
- Batch: Collection of samples for batch processing
- MaterialCollateFn: Custom collate function for DataLoader
"""
import json
import os
from random import random
from collections import Counter

import numpy as np
import torch
import yaml
from ase import Atoms
from ase.io import read
from torch.utils.data import Dataset

from models.modules.neighborlists import Neighborlist
from utils import positions_into_cell
from models.modules.bmp import formal_charges, inv_element_map

# Root directory for all cached artefacts; overridable through the SAVE_ROOT_DIR
# environment variable and defaulting to a relative 'cache' directory.
SAVE_ROOT_DIR = os.environ.get('SAVE_ROOT_DIR', 'cache')

# One sub-directory of SAVE_ROOT_DIR per artefact kind, keyed by that kind's name.
save_dirs = {
    subdir: os.path.join(SAVE_ROOT_DIR, subdir)
    for subdir in ('infer', 'models', 'log')
}


class MaterialDataset(Dataset):
    """PyTorch Dataset for loading and processing material samples with atomic structures and properties.

    This dataset class supports multiple input formats for both atomic structures and material
    properties, with options for filtering, augmentation, and density control through ghost atoms.

    Attributes:
        samples (list): List of Sample objects representing the dataset.
    """

    # Generation modes understood by `augment_properties`.
    _AUGMENT_MODES = ('linear', 'fixed', 'empty')

    def __init__(self, atom_src, property_src, target_density=None, filter_charge_bal=False):
        """Initialize MaterialDataset with atomic structures and properties.

        Args:
            atom_src (dict): Atomic structure source configuration.
                type (str): Source type ('extxyz' or '3D_empty')
                params (dict): Type-specific parameters
            property_src (dict): Property source configuration.
                type (str): Property type ('file', 'files', 'augment', 'empty')
                params (dict): Type-specific parameters
            target_density (float, optional): Target atomic density for ghost atom addition.
                If specified, ghost atoms are added to reach this density.
            filter_charge_bal (bool, optional): If True, filter out non-charge-balanced samples.
                Defaults to False.

        Raises:
            NotImplementedError: If the atom or property source type is unknown.
        """
        super().__init__()

        if atom_src['type'] == 'extxyz':
            atoms = self.read_extxyz_dataset(**atom_src['params'])
        elif atom_src['type'] == '3D_empty':
            atoms = self.make_empty_3d_atoms(**atom_src['params'])
        else:
            raise NotImplementedError(f'Unknown dataset type: {atom_src["type"]}')

        if property_src['type'] == 'file':
            props = self.read_property_file(**property_src['params'])
        elif property_src['type'] == 'files':
            props = self.read_property_files(**property_src['params'])
        elif property_src['type'] == 'augment':
            props = self.augment_properties(**property_src['params'])
        elif property_src['type'] == 'empty':
            props = [None] * len(atoms)
        else:
            raise NotImplementedError(f'Unknown property source: {property_src["type"]}')

        # Structures and property records are paired by index; zip truncates to the
        # shorter of the two, so surplus entries on either side are silently dropped.
        samples = [Sample.from_ase_atoms(a, properties=p) for a, p in zip(atoms, props)]

        if filter_charge_bal:
            samples = [sample for sample in samples if sample.is_charge_balanced()]
        if target_density is not None:
            samples = [self.add_ghost_atoms(sample, target_density) for sample in samples]

        self.samples = samples

    def __len__(self):
        """Return the number of samples in the dataset.

        Returns:
            int: Number of samples.
        """
        return len(self.samples)

    def __getitem__(self, index):
        """Get a sample by index.

        Args:
            index (int): Sample index.

        Returns:
            Sample: The sample at the specified index.
        """
        return self.samples[index]

    @staticmethod
    def read_extxyz_dataset(file):
        """Read atomic structures from Extended XYZ format file.

        Args:
            file (str): Path to the ExtXYZ file.

        Returns:
            list: List of ASE Atoms objects loaded from the file (a single frame is
                wrapped in a one-element list).
        """
        atoms = read(file, index=":")
        if not isinstance(atoms, list):
            atoms = [atoms]
        return atoms

    @staticmethod
    def make_empty_3d_atoms(lx, ly, lz, n_samples):
        """Create a dataset of empty 3D unit cells with specified dimensions.

        Creates empty ASE Atoms objects with specified cell dimensions and periodic
        boundary conditions. Useful for generating initial structures for diffusion.

        Args:
            lx (float): Length of the unit cell in x direction.
            ly (float): Length of the unit cell in y direction.
            lz (float): Length of the unit cell in z direction.
            n_samples (int): Number of empty samples to create.

        Returns:
            list: List of empty ASE Atoms objects with the specified cell dimensions.
        """
        atoms = []
        for _ in range(n_samples):
            elements = torch.zeros((0,), dtype=torch.long)
            positions = torch.zeros((0, 3), dtype=torch.float)

            # Orthorhombic cell: the cell vectors lie along x, y and z.
            lattice = torch.zeros((3, 3), dtype=torch.float)
            lattice[0, 0] = lx
            lattice[1, 1] = ly
            lattice[2, 2] = lz

            atoms.append(Atoms(
                numbers=elements,
                positions=positions,
                cell=lattice,
                pbc=(True, True, True)
            ))
        return atoms

    @staticmethod
    def read_property_file(file):
        """Read material properties from JSON or YAML file.

        Args:
            file (str): Path to the property file (.json, .yml, or .yaml).

        Returns:
            list: List of property dictionaries, one per sample.

        Raises:
            NotImplementedError: If file format is not supported.
        """
        # The extension is checked before opening, so unsupported files are never opened.
        if file.endswith('.json'):
            with open(file, 'r') as f:
                props = json.load(f)
        elif file.endswith(('.yml', '.yaml')):
            with open(file, 'r') as f:
                props = yaml.safe_load(f)
        else:
            raise NotImplementedError(
                'The properties file has an unsupported format.')
        return props

    @staticmethod
    def read_property_files(files):
        """Read multiple property files and merge their contents.

        Args:
            files (list): List of file paths to read.

        Returns:
            list: List of merged property dictionaries. Each element corresponds to
                  one sample with properties merged from all files. When a key
                  appears in several files, the value from the later file wins.

        Raises:
            NotImplementedError: If any file has an unsupported format.
            ValueError: If the files do not all contain the same number of samples.
        """
        all_props_list = [MaterialDataset.read_property_file(file) for file in files]

        if not all_props_list:
            return []

        n_samples = len(all_props_list[0])
        if any(len(props) != n_samples for props in all_props_list[1:]):
            raise ValueError("All property files must have the same number of samples.")

        # Merge the i-th record of every file into one dict per sample.
        merged_props = []
        for per_file_records in zip(*all_props_list):
            sample_props = {}
            for record in per_file_records:
                sample_props.update(record)
            merged_props.append(sample_props)

        return merged_props

    @staticmethod
    def augment_properties(n_samples, aug_args):
        """Generate augmented properties for samples using specified strategies.

        Creates properties for each sample based on augmentation arguments that specify
        how each property should be generated (linear interpolation, fixed values, etc.).

        Args:
            n_samples (int): Number of samples to generate properties for.
            aug_args (list): List of property generation configurations. Each dict contains:
                prop_name (str): Name of the property
                mode (str): Generation mode ('linear', 'fixed', 'empty')
                For 'linear': min_val, max_val for interpolation
                For 'fixed': value for constant property
                For 'empty': creates None properties

        Returns:
            list: List of property dictionaries, one per sample.

        Raises:
            NotImplementedError: If any entry of ``aug_args`` has an unknown mode.
        """
        # Validate up front so that a mistyped mode fails loudly instead of silently
        # producing samples that lack the property.
        for prop_args in aug_args:
            if prop_args['mode'] not in MaterialDataset._AUGMENT_MODES:
                raise NotImplementedError(f'Unknown augmentation mode: {prop_args["mode"]}')

        # Guard against division by zero when only one sample is requested.
        denom = max(n_samples - 1, 1)
        props = []
        for n in range(n_samples):
            properties = {}
            for prop_args in aug_args:
                mode = prop_args['mode']
                if mode == 'linear':
                    # Evenly spaced values from min_val (first sample) to max_val (last sample).
                    properties[prop_args['prop_name']] = prop_args['min_val'] + n * (prop_args['max_val'] - prop_args['min_val']) / denom
                elif mode == 'fixed':
                    properties[prop_args['prop_name']] = prop_args['value']
                else:  # 'empty'
                    properties[prop_args['prop_name']] = None
            props.append(properties)
        return props

    def add_ghost_atoms(self, sample, target_density):
        """Add ghost atoms to a sample to reach the target density.

        Ghost atoms have element number 0 and are placed uniformly at random inside
        the cell.

        Args:
            sample (Sample): The original sample.
            target_density (float): The target density to achieve (atoms per unit cell volume).

        Returns:
            Sample: A new sample with ghost atoms added. The input sample itself is
                returned unchanged if it already reaches the target atom count.
        """
        volume = torch.abs(torch.linalg.det(sample.lattice))
        current_num_atoms = sample.get_num_atoms()
        target_num_atoms = int(target_density * volume)
        num_ghost_atoms = max(0, target_num_atoms - current_num_atoms)

        if num_ghost_atoms == 0:
            return sample

        # Uniform fractional coordinates mapped into Cartesian space. Match the lattice's
        # dtype/device so the matmul also works for float64 or non-CPU lattices.
        ghost_elements = torch.zeros(num_ghost_atoms, dtype=torch.long, device=sample.elements.device)
        ghost_fractions = torch.rand(num_ghost_atoms, 3, dtype=sample.lattice.dtype, device=sample.lattice.device)
        ghost_positions = ghost_fractions @ sample.lattice

        new_elements = torch.cat([sample.elements, ghost_elements])
        new_positions = torch.cat([sample.positions, ghost_positions])

        return Sample(
            elements=new_elements,
            positions=new_positions,
            lattice=sample.lattice,
            pbc=sample.pbc,
            properties=sample.properties
        )

    def get_atom_list(self):
        """Convert all samples to ASE Atoms objects.

        Returns:
            list: List of ASE Atoms objects for all samples in the dataset.
        """
        return [sample.to_ase_atoms() for sample in self.samples]


class Sample:
    """Represents a single 3D material sample with atomic structure and properties.
    
    A Sample encapsulates all information about a material structure including:
    - Atomic positions, elements, and unit cell
    - Periodic boundary conditions
    - Neighbor lists for efficient distance computations
    - Material properties for conditional generation
    - Element embeddings for neural network input
    
    Attributes:
        elements (torch.LongTensor): Atomic numbers, shape (n_atoms,)
        positions (torch.FloatTensor): Atomic positions, shape (n_atoms, 3)
        lattice (torch.FloatTensor): Unit cell matrix, shape (3, 3)
        pbc (tuple): Periodic boundary conditions for each axis
        neighborlist (Neighborlist): Neighbor list manager for distance computations
        element_emb (torch.FloatTensor): Element embeddings, shape (n_atoms, d_embed)
        properties (dict): Material properties for conditioning
    """

    def __init__(self, elements, positions, lattice, pbc=(True, True, True),
                 neighborlist=None, init_r_cut=None,
                 element_emb=None, properties=None):
        """
        Args:
            elements (torch.LongTensor): element numbers corresponding to the atoms, with shape (n_atom).
            positions (torch.FloatTensor): positions correspondng to the atoms, with shape (n_atom, 3).
            lattice (torch.FloatTensor): the unit cell size, with shape (3, 3).
            pbc (tuple, optional): periodic boundary conditions, set for each axis respectively.
            init_r_cut (float): cutoff used for initial neighbor list construction. Neighbors are pruned to r_cut, but a larger cutoff avoids frequent rebuild of the list during inference.
            element_emb (torch.FloatTensor): Embedding of the element
            properties (dict[str, torch.FloatTensor]): Dictionary containing any additional properties of the sample (can be used for conditioning the model)
        """
        # Geometry
        self.elements = elements
        self.positions = positions
        self.lattice = lattice
        self.pbc = pbc

        if neighborlist is None:
            self.neighborlist = Neighborlist(self.lattice, self.pbc, init_r_cut)
        else:
            self.neighborlist = neighborlist
            neighborlist.set_init_r_cut(init_r_cut)

        # Additional properties
        self.element_emb = element_emb
        self.properties = properties

    def to(self, device):
        """Move sample tensors to the specified device.
        
        Args:
            device (str or torch.device): Target device (e.g., 'cpu', 'cuda:0').
            
        Returns:
            Sample: New sample with all tensors moved to the specified device.
        """
        return Sample(
            self.elements.to(device),
            self.positions.to(device),
            self.lattice.to(device),
            self.pbc,
            self.neighborlist.to(device),
            self.neighborlist.init_r_cut,
            None if self.element_emb is None else self.element_emb.to(device),
            None if self.properties is None else {k: self.properties[k].to(device) if self.properties.get(k, None) is not None else None for k in self.properties})

    @staticmethod
    def from_ase_atoms(atoms: Atoms, properties=None):
        """Create Sample object from ASE Atoms object.
        
        Args:
            atoms (ase.Atoms): ASE Atoms object containing atomic structure.
            properties (dict, optional): Additional material properties to associate
                with the sample. Defaults to None.
        
        Returns:
            Sample: New Sample object created from the ASE Atoms.
        """
        return Sample(
            torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long),
            torch.tensor(atoms.get_positions()).float(),
            torch.tensor(np.array(atoms.get_cell(complete=True))).float(),
            atoms.get_pbc(),
            properties=None if properties is None else {p: torch.tensor(properties[p]) if properties[p] is not None else None for p in properties})

    def to_ase_atoms(self):
        """Create ASE Atoms object from this Sample.
        """
        return Atoms(
            numbers=self.elements.detach().cpu().numpy(),
            positions=self.positions.detach().cpu().numpy(),
            cell=self.lattice.detach().cpu().numpy(),
            pbc=self.pbc
        )

    @torch.no_grad()
    def update_attrs(self, lattice=None, positions=None, elements=None, element_emb=None, properties=None):
        """Create a new Sample based on updated attributes.
        """
        new_lattice = lattice if lattice is not None else self.lattice
        new_positions = positions if positions is not None else self.positions
        new_elements = elements if elements is not None else self.elements
        new_element_emb = element_emb if element_emb is not None else self.element_emb
        new_properties = properties if properties is not None else self.properties

        return Sample(
            elements=new_elements,
            positions=new_positions,
            lattice=new_lattice,
            pbc=self.pbc,
            neighborlist=self.neighborlist if lattice is None else None,
            init_r_cut=self.neighborlist.init_r_cut if lattice is None else None,
            element_emb=new_element_emb,
            properties=new_properties)

    def null_properties(self, properties_to_null=None):
        """Set some or all of the properties to null

        Args:
            properties_to_null (_type_, optional): Which properties to null, if None, all are nulled. Defaults to None.

        Returns:
            Structure: The structure with the nulled properties
        """
        if properties_to_null is None:
            return self.update_attrs(properties={x: None for x in self.properties})
        else:
            return self.update_attrs(
                properties={x: None if x in properties_to_null else self.properties[x] for x in self.properties})

    def set_init_r_cut(self, init_r_cut):
        self.neighborlist.set_init_r_cut(init_r_cut)

    def update_edges(self, r_cut):
        self.neighborlist.update(self.positions, r_cut)

    def get_positions(self):
        return self.positions

    def get_elements(self):
        return self.elements

    def get_element_emb(self):
        return self.element_emb

    def get_property_arr(self, prop_name, null_placeholder=None):
        """
        Generate a tensor containing the property prop_name.
        The array is extended to the number of atoms, so it can be used as edge feature for the model.
        Returns None, if the property has been nulled.
        If null_placeholder is provided, it is used in case the property is null.
        Additionally, a mask will be returned, that is true, in case of a null property.

        Args:
            prop_name (str): Name of the requested property
            null_placeholder (FloatTensor): A placeholder value, in case the property has been nulled.

        Returns:
            FloatTensor: The requested property, extended to the number of atoms
        """
        p = self.properties[prop_name]
        if null_placeholder is None:
            if p is None:
                return None
            else:
                return p.unsqueeze(0).expand(self.get_num_atoms(), -1)
        else:
            if p is None:
                p = null_placeholder
            null_mask = torch.tensor([p is None], device=self.positions.device).expand(self.get_num_atoms())
            return p.unsqueeze(0).expand(self.get_num_atoms(), -1), null_mask

    def get_num_atoms(self):
        return len(self.elements)

    def get_edges(self, r_cut):
        return self.neighborlist.get_edges(self.positions, r_cut)

    def randomize_uniform(self):
        """Generate a random Sample that follows uniform distribution.
        """
        x = torch.rand_like(self.positions)
        x = x @ self.lattice
        return self.update_attrs(positions=x)

    def back_to_cell(self):
        """Wraps out-bound atoms back to the cell.
        """
        return self.update_attrs(positions=positions_into_cell(self.positions, self.lattice))

    def rotate(self, R):
        """Rotate the sample by a rotation matrix R.
        """
        return self.update_attrs(positions=self.positions @ R.T, lattice=self.lattice @ R.T)

    def remove_mean(self, x):
        return x - torch.mean(x, dim=0).unsqueeze(0)

    def get_batch_size(self):
        return 1

    def get_batch_indices(self):
        return torch.zeros(self.get_num_atoms(), dtype=torch.long)

    def get_charge(self):
        ase_atoms = self.to_ase_atoms()
        ele_count = Counter(ase_atoms.get_chemical_symbols())
        total_charge = sum([formal_charges[inv_element_map[e]] * ele_count[e] for e in ele_count if e != 'X'])
        return total_charge

    def is_charge_balanced(self):
        """
        Returns:
            bool: True if sum of 'valence_electrons' == 0, else False.
                  Based on ASE reference data for each element's ground state.
        """
        return (self.get_charge() == 0)

    def cal_velocity(self, positions):
        """
        Calculate the velocity pointing from itself to the given positions with PBC.
        """
        delta = positions - self.positions
        frac_delta = torch.linalg.solve(self.lattice.T, delta.T).T
        frac_delta = frac_delta - torch.round(frac_delta)
        delta_pbc = torch.matmul(frac_delta, self.lattice)
        velocity = delta_pbc
        return velocity


class Batch:
    """Representing a batch of 3D material samples.

    Per-atom tensors (positions, elements, element embeddings) of all samples are
    concatenated in sample order and cached on the batch. Per-sample slices are
    recovered using each sample's atom count.
    """

    def __init__(self, samples):
        """
        Args:
            samples (list): each item is a Sample object.
        """
        self.samples = samples
        self._refresh_cache()

    def to(self, device):
        """Return a new Batch with every sample moved to ``device``."""
        return Batch([x.to(device) for x in self.samples])

    def to_ase_atoms(self):
        """Create a list of ASE Atoms objects from this batch."""
        return [s.to_ase_atoms() for s in self.samples]

    def get_num_atoms(self):
        """Get the total number of atoms from all samples in this batch."""
        return sum(s.get_num_atoms() for s in self.samples)

    def get_positions(self):
        """Get the concatenated positions of all samples in this batch.

        Returns:
            torch.FloatTensor: concatenated positions of all samples, with shape (n_atoms, 3).
        """
        return self.positions

    def get_elements(self):
        return self.elements

    def get_element_emb(self):
        return self.element_emb

    def get_property_arr(self, prop_name, null_placeholder=None):
        """Concatenate every sample's per-atom property array.

        Without ``null_placeholder`` returns one tensor (all samples must have the
        property set). With it, returns ``(values, null_mask)`` concatenated over samples.
        """
        if null_placeholder is None:
            return torch.concat([s.get_property_arr(prop_name) for s in self.samples])
        prop_arrs = [s.get_property_arr(prop_name, null_placeholder) for s in self.samples]
        return torch.concat([p for p, _ in prop_arrs]), torch.concat([mask for _, mask in prop_arrs])

    def set_init_r_cut(self, init_r_cut):
        for s in self.samples:
            s.set_init_r_cut(init_r_cut)

    def update_edges(self, r_cut):
        for sample in self.samples:
            sample.update_edges(r_cut)

    def get_edges(self, r_cut):
        """Get edges in all samples in this batch. The row and col indices have batch offsets built-in."""
        rows, cols, offsets = [], [], []
        atoms_i = 0
        for s in self.samples:
            r, c, o = s.get_edges(r_cut)
            # Shift local atom indices into the concatenated (batch-wide) index space.
            rows.append(r + atoms_i)
            cols.append(c + atoms_i)
            offsets.append(o)
            atoms_i += s.get_num_atoms()
        return (torch.concat(rows), torch.concat(cols), torch.concat(offsets))

    def update_attrs(self, positions=None, elements=None, element_emb=None):
        """Update the attributes of all samples in this batch.

        The new attributes should have the same form as the original ones, i.e. be
        concatenated over samples along dim 0. Arguments left as None keep the current values.
        """
        new_elements = elements if elements is not None else self.get_elements()
        new_element_emb = element_emb if element_emb is not None else self.get_element_emb()
        new_positions = positions if positions is not None else self.get_positions()

        atoms_i = 0
        new_samples = []
        for s in self.samples:
            last_atoms_i = atoms_i
            atoms_i = atoms_i + s.get_num_atoms()
            new_sample = s.update_attrs(
                elements=new_elements[last_atoms_i:atoms_i],
                positions=new_positions[last_atoms_i:atoms_i, :],
                # None lets the sample keep its own embedding.
                element_emb=new_element_emb[last_atoms_i:atoms_i] if element_emb is not None else None)
            new_samples.append(new_sample)
        return Batch(new_samples)

    def null_properties(self, properties_to_null=None):
        return Batch([s.null_properties(properties_to_null) for s in self.samples])

    def remove_mean(self, x):
        """Subtract each sample's own mean (over its atoms) from the rows of ``x``."""
        mean = torch.zeros_like(x)
        # Batch indices are built on CPU; align them with x so the masks index x's device.
        batch_indices = self.get_batch_indices().to(x.device)
        for i in range(self.get_batch_size()):
            mask = batch_indices == i
            mean[mask, :] += torch.mean(x[mask, :], dim=0).unsqueeze(0)
        return x - mean

    def randomize_uniform(self):
        """Get a batch of random Samples that all follow the uniform distribution."""
        return Batch([s.randomize_uniform() for s in self.samples])

    def get_batch_size(self):
        return len(self.samples)

    def get_batch_indices(self):
        """Return, for every atom in the batch, the index of the sample it belongs to (CPU tensor)."""
        return torch.concat(
            [torch.ones((s.get_num_atoms(),), dtype=torch.long) * i for i, s in enumerate(self.samples)])

    def rotate(self, R):
        """Rotate every sample by the same rotation matrix R."""
        return Batch([s.rotate(R) for s in self.samples])

    def get_charge(self):
        return [s.get_charge() for s in self.samples]

    def is_charge_balanced(self):
        """
        Returns:
            torch.BoolTensor of shape (batch_size,):
            True where each sample is charge-balanced, False otherwise.
        """
        return torch.tensor([s.is_charge_balanced() for s in self.samples], dtype=torch.bool)

    def get_sub_batch(self, sample_indices):
        """Create a new Batch containing only the samples at `sample_indices`."""
        return Batch([self.samples[i] for i in sample_indices])

    def update_sub_batch(self, sample_indices, sub_batch):
        """Update samples in-place for the specified `sample_indices`.

        Then refresh the concatenated attributes.
        """
        for local_i, global_i in enumerate(sample_indices):
            self.samples[global_i] = sub_batch.samples[local_i]
        self._refresh_cache()

    def _refresh_cache(self):
        """Rebuild self.positions, self.elements and self.element_emb from self.samples.

        The concatenated element embedding is None unless every sample has one.
        """
        self.positions = torch.concat([s.get_positions() for s in self.samples], dim=0)
        self.elements = torch.concat([s.get_elements() for s in self.samples], dim=0)

        element_embs = [s.get_element_emb() for s in self.samples]
        if any(e is None for e in element_embs):
            self.element_emb = None
        else:
            self.element_emb = torch.concat(element_embs, dim=0)  # (n_atoms, d_embed)

    def cal_velocity(self, target_positions):
        """Calculate the displacement from current to target positions with PBC for all samples in the batch.

        Args:
            target_positions (torch.Tensor): Target positions with shape (total_atoms, 3)

        Returns:
            torch.Tensor: Velocity vectors with shape (total_atoms, 3)
        """
        velocities = []
        atoms_i = 0

        for s in self.samples:
            last_atoms_i = atoms_i
            atoms_i = atoms_i + s.get_num_atoms()
            velocities.append(s.cal_velocity(target_positions[last_atoms_i:atoms_i, :]))

        return torch.concat(velocities, dim=0)


class MaterialCollateFn:
    """DataLoader ``collate_fn`` that turns a list of samples into a device-resident batch.

    Each call wraps the incoming list of ``Sample`` objects in a ``Batch`` and moves
    that batch to the configured device.

    Attributes:
        device (str): Device the assembled batch is moved to.
    """

    def __init__(self, device):
        """Remember the device that collated batches should live on.

        Args:
            device (str): Target device for batch tensors (e.g., 'cpu', 'cuda:0').
        """
        self.device = device

    def __call__(self, sample_batch):
        """Collate ``sample_batch`` into a single batch.

        Args:
            sample_batch (list): Sample objects delivered by the DataLoader.

        Returns:
            Batch: The collated samples, moved to ``self.device``.
        """
        collated = Batch(sample_batch)
        target = self.device
        return collated.to(target)
