import argparse
import os
from pathlib import Path
import time

# Project root: three directory levels above this file.
BASE_DIR = Path(__file__).resolve().parent.parent.parent

# Early, lenient CLI parse. The device selection and XLA memory setting have to
# be in the environment before JAX is imported further down, so only the
# options needed here are read and unrecognised arguments are tolerated.
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str,
                    help='GPU or MIG UUID to use for training')
parser.add_argument('--task', type=str, default=None,
                    help='Manual task ID.')
args, _ = parser.parse_known_args()

if args.device:
    # Restrict visible CUDA devices to the requested GPU / MIG UUID.
    os.environ.update(CUDA_VISIBLE_DEVICES=args.device)
# Fraction of device memory the XLA client may preallocate (JAX env var).
os.environ.update(XLA_PYTHON_CLIENT_MEM_FRACTION='0.95')

import json
from typing import Any

from chemutils.datasets import pepsol
from chemutils.models import mace
from chemtrain import trainers
from chemtrain.data import preprocessing

import numpy as onp
from jax import numpy as jnp, random, tree_util
from jax_md import partition, space
import optax

def parse_args() -> argparse.Namespace:
    """
    Strictly parse the full command line for training.

    Unlike a lenient ``parse_known_args`` call, unknown arguments make
    argparse report an error and exit.

    Returns
    -------
    argparse.Namespace
        Namespace with string attributes ``device`` (GPU or MIG UUID) and
        ``task`` (manual task ID); each is None when the flag is omitted.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ('--device', 'GPU or MIG UUID to use for training'),
        ('--task', 'Manual task ID.'),
    )
    for flag, help_text in option_specs:
        cli.add_argument(flag, type=str, help=help_text)
    return cli.parse_args()

def load_and_split_data(
    data_path: str,
    scale_R: float = 1.0,
    scale_U: float = 1.0,
    fractional: bool = True,
    train_frac: float = 1.0,
    seed: int = 11
) -> dict[str, dict[str, Any]]:
    """
    Load and split NPZ dataset into train/validation/test, with scaling.

    The data are shuffled and split with ``preprocessing.train_val_test_split``.
    Each split is then rescaled with ``pepsol.scale_dataset``. Finally the
    training split is truncated to its first ``train_frac`` fraction, which is
    a random subset because the split was shuffled.

    Parameters
    ----------
    data_path : str
        Path to NPZ file.
    scale_R : float
        Coordinate scaling factor.
    scale_U : float
        Energy scaling factor.
    fractional : bool
        Use fractional coordinates.
    train_frac : float
        Fraction of training data to keep. Must be non-negative; values of 1
        or more keep the whole training split.
    seed : int
        RNG seed for splitting.

    Returns
    -------
    dict[str, dict[str, Any]]
        Split datasets with keys 'training', 'validation', 'testing'.

    Raises
    ------
    ValueError
        If ``train_frac`` is negative or NaN.
    """
    # Validate before the (potentially large) load. A negative fraction would
    # otherwise turn ``[:keep]`` into "drop the last rows".
    if not train_frac >= 0:
        raise ValueError(f"train_frac must be non-negative, got {train_frac!r}")

    # NOTE(review): allow_pickle=True can execute code embedded in the file;
    # only load trusted datasets.
    raw = onp.load(data_path, allow_pickle=True)
    try:
        data = dict(raw)
    finally:
        # An NPZ archive (NpzFile) keeps its file handle open until closed;
        # other onp.load results (plain arrays, unpickled objects) have no close().
        close = getattr(raw, 'close', None)
        if close is not None:
            close()

    train_data, val_data, test_data = preprocessing.train_val_test_split(
        data, shuffle=True, shuffle_seed=seed
    )
    named_splits = (
        ('training', train_data),
        ('validation', val_data),
        ('testing', test_data),
    )
    splits = {
        name: pepsol.scale_dataset(
            subset, scale_R=scale_R, scale_U=scale_U, fractional=fractional
        )
        for name, subset in named_splits
    }

    # Reduce training size; slicing past the end simply keeps everything.
    training = splits['training']
    keep = int(train_frac * training['R'].shape[0])
    for field in training:
        training[field] = training[field][:keep]

    return splits


def setup_neighborlist(
    dataset_split: dict[str, Any],
    cutoff: float,
    box: jnp.ndarray,
    batch_size: int,
    fractional: bool = True
) -> tuple[Any, tuple[int, int, float]]:
    """
    Allocate a sparse neighbor list sized for the given dataset split.

    A periodic displacement function is built from ``box``. The allocation is
    delegated to ``preprocessing.allocate_neighborlist``, using the split's
    ``'mask'`` and ``'box'`` entries.

    Parameters
    ----------
    dataset_split : dict[str, Any]
        Dataset dict with 'R' and 'mask' (and 'box').
    cutoff : float
        Neighbor cutoff radius.
    box : jnp.ndarray
        Simulation box used for the periodic displacement.
    batch_size : int
        Batch size for preallocation.
    fractional : bool
        Whether positions are in fractional coordinates.

    Returns
    -------
    tuple
        ``(neighbor_fn, (max_neighbors, max_edges, avg_neighbors))`` as
        produced by ``preprocessing.allocate_neighborlist``.
    """
    displacement, _shift = space.periodic_general(
        box=box, fractional_coordinates=fractional
    )
    allocation_options = {
        'r_cutoff': cutoff,
        'mask_key': 'mask',
        'box_key': 'box',
        'format': partition.Sparse,
        'batch_size': batch_size,
    }
    return preprocessing.allocate_neighborlist(
        dataset_split, displacement, box, **allocation_options
    )


def build_model(
    displacement_fn: Any,
    hidden_irreps: str,
    readout_irreps: str,
    output_irreps: str,
    max_ell: int,
    num_interactions: int,
    correlation: int,
    cutoff: float,
    n_species: int,
    max_edges: int,
    avg_neighbors: float,
    seed: int = 0
) -> tuple[Any, Any]:
    """
    Build MACE model initializer and energy-function template.

    Parameters
    ----------
    displacement_fn : Any
        Displacement function forwarded to ``mace.mace_neighborlist_pp``.
    hidden_irreps, readout_irreps, output_irreps : str
        Irreps strings for hidden features, readout MLP and model output.
    max_ell : int
        Maximum spherical-harmonic degree.
    num_interactions : int
        Number of interaction layers.
    correlation : int
        Body-order correlation of the MACE product basis.
    cutoff : float
        Neighbor cutoff radius.
    n_species : int
        Number of species embeddings.
    max_edges : int
        Maximum number of edges for the sparse neighbor list.
    avg_neighbors : float
        Average neighbor count used for normalisation.
    seed : int
        Not used inside this function; kept for interface compatibility.

    Returns
    -------
    tuple
        ``(init_fn, energy_fn_template)``. ``energy_fn_template(params)``
        returns ``energy_fn(pos, neighbor, mode=None, **dynamic_kwargs)``,
        which requires a ``species`` keyword argument.
    """
    init_fn, gnn_energy_fn = mace.mace_neighborlist_pp(
        displacement_fn,
        r_cutoff=cutoff,
        n_species=n_species,
        max_edges=max_edges,
        per_particle=False,
        avg_num_neighbors=avg_neighbors,
        mode='energy',
        hidden_irreps=hidden_irreps,
        max_ell=max_ell,
        num_interactions=num_interactions,
        correlation=correlation,
        readout_mlp_irreps=readout_irreps,
        output_irreps=output_irreps,
    )

    def energy_fn_template(energy_params):
        def energy_fn(pos, neighbor, mode=None, **dynamic_kwargs):
            # `mode` is accepted for interface compatibility but ignored.
            # Validate explicitly: an `assert` would vanish under `python -O`.
            if 'species' not in dynamic_kwargs:
                raise ValueError('species not in dynamic_kwargs')

            # Under jit these prints run at trace time, not on every call.
            if "mask" not in dynamic_kwargs:
                print("Add defaul all-positive mask.")
                dynamic_kwargs["mask"] = jnp.ones(pos.shape[0], dtype=jnp.bool_)

            if "box" in dynamic_kwargs:
                print("Found box in energy kwargs")

            return gnn_energy_fn(
                energy_params, pos, neighbor, **dynamic_kwargs
            )
        return energy_fn

    return init_fn, energy_fn_template


def train_model(
    init_params: Any,
    energy_fn: Any,
    neighbor_fn: Any,
    dataset: dict[str, dict[str, Any]],
    output_dir: str,
    batch_size: int,
    num_epochs: int,
    init_lr: float,
    decay_rate: float
) -> trainers.ForceMatching:
    """
    Construct a force-matching trainer with a clipped, decaying Adam optimizer.

    No optimisation is run here; the trainer is only created and returned.
    ``dataset`` is used solely to count training samples. From that count the
    function derives the number of steps over which the learning rate decays
    exponentially to ``init_lr * decay_rate``.

    Parameters
    ----------
    init_params : Any
        Initial model parameters.
    energy_fn : Any
        Energy-function template taking model parameters.
    neighbor_fn : Any
        Neighbor-list function.
    dataset : dict[str, dict[str, Any]]
        Splits; only ``dataset['training']['R']`` is read.
    output_dir : str
        Directory for ``force_matching.log``.
    batch_size : int
        Per-device batch size.
    num_epochs : int
        Epoch count used to size the learning-rate schedule.
    init_lr : float
        Initial learning rate.
    decay_rate : float
        Learning-rate decay factor reached after all scheduled steps.

    Returns
    -------
    trainers.ForceMatching
    """
    n_train = dataset['training']['R'].shape[0]
    decay_steps = (num_epochs * n_train) // batch_size

    lr_schedule = optax.exponential_decay(
        init_value=init_lr,
        transition_steps=decay_steps,
        decay_rate=decay_rate,
    )
    transforms = (
        optax.clip_by_global_norm(1.0),   # bound the gradient norm first
        optax.scale_by_adam(),
        optax.scale_by_schedule(lr_schedule),
        optax.scale(-1.0),                # negate so applied updates descend
    )
    optimizer = optax.chain(*transforms)

    log_path = os.path.join(output_dir, 'force_matching.log')
    return trainers.ForceMatching(
        init_params,
        optimizer,
        energy_fn,
        neighbor_fn,
        log_file=log_path,
        batch_per_device=batch_size,
    )

def main(task_id):
    """
    Build a MACE force-matching trainer and predict energies of proposed samples.

    The reference dataset is loaded to size the neighbor list and model. A
    ``ForceMatching`` trainer is then built; no training is run here.
    Pre-trained parameters are loaded, and energies are predicted for
    ``output/ala2/<task_id>/proposed_samples.npz``. The results are written to
    ``predicted_energy.npz``, and a timing line is appended to ``run.log``,
    both in that same directory.

    Parameters
    ----------
    task_id : str
        Task identifier selecting ``output/ala2/<task_id>``. IDs containing
        ``"atom6"`` load the parameter checkpoint without a ``'params'`` key.

    Raises
    ------
    ValueError
        If ``task_id`` is None (no ``--task`` given).
    """
    import jax  # local import: used to wait on asynchronously dispatched results

    if task_id is None:
        raise ValueError("No task ID given; pass --task <id>.")

    # Per-task inputs and outputs live here. Create it before the trainer so
    # its log file can be placed inside.
    output_path = BASE_DIR / f"output/ala2/{task_id}"
    os.makedirs(output_path, exist_ok=True)

    MACE_CONFIG = {
        "hidden_irreps": "32x0e+32x1o",
        "readout_mlp_irreps": "16x0e",  # Default
        "output_irreps": "1x0e",  # Energy
        "max_ell": 3,
        "num_interactions": 2,
        "correlation": 3,  # 3 or 2
        "r_cutoff": 0.5,
        "mol": "ala2",
        "CG_map": "heavyOnly",
        "type": "CG",  # "AT" or "CG"
        "PRNGKey_seed": 1,
        "data_path": "",  # TODO: empty in this revision; set before running
        "train_frac": 1.0,  # fraction of training split kept (e.g. 0.2 of 500k -> 100k)
    }

    TRAIN_CONFIG = {
        "batch_size": 32,
        "num_epochs": 100,
        "init_lr": 0.001,
        "decay_rate": 0.9,  # Decay rate for learning rate
        "optimizer": "adam+decay",
    }

    # Must be a path string: train_model joins it with the log-file name.
    MACE_CONFIG['output_dir'] = str(output_path)

    dataset = load_and_split_data(
        MACE_CONFIG['data_path'],
        scale_R=1,
        scale_U=1,
        fractional=True,
        train_frac=MACE_CONFIG['train_frac'],
        seed=MACE_CONFIG['PRNGKey_seed']
    )

    # Neighbor list
    box = jnp.asarray(dataset['training']['box'][0])
    neighbor_fn, stats = setup_neighborlist(
        dataset['training'],
        cutoff=MACE_CONFIG['r_cutoff'],
        box=box,
        batch_size=TRAIN_CONFIG['batch_size']
    )

    # Model init
    # NOTE(review): this is presumably the particle count of the first sample,
    # not the number of distinct species types. It must match the value used
    # to train the checkpoint loaded below, so it is left unchanged — confirm.
    n_species = len(dataset['training']['species'][0])
    init_fn, energy_fn_template = build_model(
        displacement_fn=space.periodic_general(box=box, fractional_coordinates=True)[0],
        hidden_irreps=MACE_CONFIG['hidden_irreps'],
        readout_irreps=MACE_CONFIG['readout_mlp_irreps'],
        output_irreps=MACE_CONFIG['output_irreps'],
        max_ell=MACE_CONFIG['max_ell'],
        num_interactions=MACE_CONFIG['num_interactions'],
        correlation=MACE_CONFIG['correlation'],
        cutoff=MACE_CONFIG['r_cutoff'],
        n_species=n_species,
        max_edges=stats[1],
        avg_neighbors=stats[2],
        seed=MACE_CONFIG['PRNGKey_seed']
    )

    key = random.PRNGKey(MACE_CONFIG['PRNGKey_seed'])
    init_params = init_fn(
        key,
        jnp.asarray(dataset['training']['R'][0]),
        neighbor_fn,
        species=jnp.asarray(dataset['training']['species'][0]),
        mask=jnp.asarray(dataset['training']['mask'][0])
    )
    print(dataset['testing'].keys())

    # Build the trainer (used below only for prediction).
    trainer = train_model(
        init_params,
        energy_fn_template,
        neighbor_fn,
        dataset,
        output_dir=MACE_CONFIG['output_dir'],
        batch_size=TRAIN_CONFIG['batch_size'],
        num_epochs=TRAIN_CONFIG['num_epochs'],
        init_lr=TRAIN_CONFIG['init_lr'],
        decay_rate=TRAIN_CONFIG['decay_rate']
    )

    # Load the proposed samples and close the NPZ archive once materialised.
    with onp.load(output_path / "proposed_samples.npz", allow_pickle=True) as npz:
        raw_data = dict(npz)
    raw_data['R'] = raw_data['R'].reshape(raw_data['R'].shape[0], -1, 3)
    n_samples = raw_data['R'].shape[0]

    sample = {**raw_data}
    # NOTE(review): converts to fractional coordinates with box[0, 0, 0] only,
    # i.e. assumes a cubic box shared by all samples — confirm.
    sample['R'] = raw_data['R'] / raw_data['box'][0, 0, 0]
    sample['species'] = jnp.tile(dataset['training']['species'][0], (n_samples, 1))
    sample['U'] = jnp.zeros((n_samples,))

    # TODO: checkpoint paths are empty in this revision; set them before running.
    # NOTE(review): the "atom6" checkpoint is used as loaded while the other is
    # indexed with 'params' — presumably different file layouts; confirm.
    if "atom6" in task_id:
        energy_params = onp.load("", allow_pickle=True)
    else:
        energy_params = onp.load("", allow_pickle=True)
        energy_params = energy_params['params']

    energy_params = tree_util.tree_map(jnp.asarray, energy_params)

    start_time = time.perf_counter()
    energies = trainer.predict(sample, energy_params, batch_size=500)
    # JAX dispatches work asynchronously; wait for the results so the measured
    # time covers the actual computation.
    energies = jax.block_until_ready(energies)
    prediction_time = time.perf_counter() - start_time

    print("Energies shape:", energies['U'].shape)
    if 'logp' in raw_data:
        print("logp shape:", raw_data['logp'].shape)
    else:
        print("'logp' key not found in raw data.")

    # Omit 'logp' when absent. Saving None would store a pickled object array
    # that cannot be read back without allow_pickle=True.
    outputs = {'R': raw_data['R'], 'U': energies['U']}
    if 'logp' in raw_data:
        outputs['logp'] = raw_data['logp']
    onp.savez(output_path / "predicted_energy.npz", **outputs)

    with open(output_path / "run.log", "a") as f:
        f.write(f"Energy Evaluation time: {prediction_time} seconds\n")

if __name__ == '__main__':
    # Use the task ID from the lenient parse_known_args() call made at import time.
    main(task_id=args.task)