"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import os
import glob
import argparse
import torch
import random
import warnings
import numpy as np
import pandas as pd
from pymatgen.core.structure import Structure
from pathlib import Path

from dataclasses import dataclass
from torch.utils.data import Dataset

def get_crystal_string(cif_str, permute_atoms=False):
    """Render a CIF string as a compact, line-oriented text description.

    The CIF is parsed with pymatgen and every site is translated by one
    shared random vector (three samples from ``np.random.uniform``) before
    the text is produced, so repeated calls on the same CIF give different
    coordinates unless NumPy's global RNG is seeded.

    Output layout, one item per line:
        1. the three lattice lengths, one decimal place, space separated
        2. the three lattice angles, truncated to integers, space separated
        3. for each site: the species, then on the next line its three
           fractional coordinates with two decimal places

    Args:
        cif_str: Contents of a CIF file.
        permute_atoms: When true, the sites are emitted in a random order
            drawn from ``np.random.permutation``.

    Returns:
        The crystal description as a single newline-joined string.
    """
    crystal = Structure.from_str(cif_str, fmt="cif")

    # One random shift shared by all sites, so relative positions are kept.
    shift = np.random.uniform(size=(3,))
    crystal.translate_sites(indices=range(len(crystal.sites)), vector=shift)

    cell = crystal.lattice.parameters
    species = crystal.species
    positions = crystal.frac_coords

    if permute_atoms:
        order = np.random.permutation(len(species))
        species = [species[j] for j in order]
        positions = positions[order]

    lengths_line = " ".join("{0:.1f}".format(length) for length in cell[:3])
    # int() truncates toward zero; angles are not rounded.
    angles_line = " ".join(str(int(angle)) for angle in cell[3:])

    site_entries = []
    for element, position in zip(species, positions):
        coord_line = " ".join("{0:.2f}".format(value) for value in position)
        site_entries.append(str(element) + "\n" + coord_line)

    return lengths_line + "\n" + angles_line + "\n" + "\n".join(site_entries)

class CifDataset(Dataset):
    """Map-style dataset that turns CIF records stored in CSV files into text.

    Every CSV row becomes one sample. A row must hold the CIF text under the
    ``cif`` key or, when that key is absent, under ``cif_str``.
    """

    def __init__(
        self,
        csv_fn,
        format_options=None,
        w_attributes=False,
    ):
        """Load all rows of the CSV file(s) named by ``csv_fn`` into memory.

        Args:
            csv_fn: Path to a CSV file, or a glob pattern matching several
                CSV files whose rows are concatenated.
            format_options: Optional dict of formatting flags. It is stored
                on the instance; no method of this class reads it. ``None``
                (the default) gives each instance its own empty dict.
            w_attributes: Stored on the instance unchanged; no method of this
                class reads it.

        Raises:
            ValueError: If ``csv_fn`` neither exists nor matches any file.
        """
        super().__init__()

        csv_fn = os.fspath(csv_fn)
        csv_paths = glob.glob(csv_fn)
        if not csv_paths:
            if not os.path.exists(csv_fn):
                raise ValueError(f"CSV file {csv_fn} does not exist")
            # The path exists but glob expansion found nothing, which happens
            # when the name itself contains glob metacharacters such as "[".
            # Read it literally instead of handing pandas an empty list.
            csv_paths = [csv_fn]

        df = pd.concat([pd.read_csv(fn) for fn in csv_paths])
        self.inputs = df.to_dict(orient="records")

        # Never share one default dict between instances.
        self.format_options = {} if format_options is None else format_options
        self.w_attributes = w_attributes

    def crystal_string(self, input_dict):
        """Return the crystal string for one record.

        Args:
            input_dict: Record holding the CIF text under ``cif`` or
                ``cif_str``; ``cif`` wins when both are present.

        Returns:
            The text produced by ``get_crystal_string`` with its default
            ``permute_atoms`` setting.
        """
        # NOTE(review): self.format_options is not forwarded here, so the
        # permute flags currently have no effect on the output -- confirm
        # whether that is intended.
        k = 'cif' if 'cif' in input_dict else 'cif_str'
        return get_crystal_string(input_dict[k])

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.inputs)

    def __getitem__(self, index):
        """Return the crystal string for the record at ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``; negative
                indices are rejected as well.
        """
        if not 0 <= index < len(self):
            raise IndexError("Index out of range")

        vals = self.inputs[index]
        crystal_str = self.crystal_string(vals)
        return crystal_str

def setup_datasets(args):
    """Build the training and validation ``CifDataset`` objects.

    Args:
        args: Namespace providing ``data_path`` (a path object supporting
            ``/``), ``w_attributes``, ``format_permute_composition`` and
            ``format_permute_structure``.

    Returns:
        Dict with keys ``"train"`` and ``"val"``, built from ``train.csv``
        and ``val.csv`` under ``args.data_path``. Both datasets are given the
        same ``format_options`` dict object.
    """
    format_options = {
        "permute_composition": args.format_permute_composition,
        "permute_structure": args.format_permute_structure,
    }

    # Split name -> CSV file name inside the data directory.
    split_files = (("train", "train.csv"), ("val", "val.csv"))

    return {
        split: CifDataset(
            str(args.data_path / file_name),
            format_options,
            w_attributes=args.w_attributes,
        )
        for split, file_name in split_files
    }

def main(args):
    """Build the datasets and print up to three training samples.

    Args:
        args: Parsed command-line namespace, passed to ``setup_datasets``.
    """
    train_set = setup_datasets(args)["train"]
    divider = "-" * 50

    # Example: print first few crystal strings
    print("First 3 training examples:")
    shown = min(3, len(train_set))
    for idx in range(shown):
        print(f"Example {idx}:")
        print(train_set[idx])
        print(divider)

if __name__ == "__main__":
    # Script entry point: parse the command line, then run main().
    parser = argparse.ArgumentParser()

    # Run identification and filesystem locations.
    parser.add_argument("--run-name", type=str, required=True)
    parser.add_argument("--expdir", type=Path, default="exp")
    parser.add_argument("--data-path", type=Path, default="data/basic")

    # Formatting switches, both off unless given on the command line.
    for flag in ("--format-permute-composition", "--format-permute-structure"):
        parser.add_argument(flag, action="store_true", default=False)

    parser.add_argument("--w-attributes", type=int, default=1)
    parser.add_argument("--debug", action="store_true", default=False)

    args = parser.parse_args()
    main(args)
