import itertools
import json

import jax
import jax.numpy as jnp
import pytest
from flax.training import checkpoints

from model import CrystalFourierTransformer

@pytest.fixture(scope="module")
def trained_model():
    """Build the model and restore its checkpoint for the tests in this module.

    Reads ``config.json``, ``abc_combinations.npy`` and ``graphs_array.npy``
    from the hard-coded checkpoint directory, constructs a
    ``CrystalFourierTransformer`` from them, and restores a checkpoint from
    the same directory.

    The fixture is module-scoped so the files are read once per test module
    instead of once per test. The returned objects are therefore shared
    between tests and must be treated as read-only.

    Returns:
        A ``(model, params)`` tuple, where ``params`` is the ``'params'``
        entry of the restored checkpoint state.

    Raises:
        FileNotFoundError: If ``config.json`` cannot be opened, or if no
            checkpoint could be restored from the directory.
    """
    load_dir = '/n/fs/cfs/crystal-fourier-transformer/checkpoints/job_21595639_0'
    with open(f"{load_dir}/config.json", 'r', encoding="utf-8") as f:
        config = json.load(f)

    # jnp.load already yields JAX arrays, so no extra jnp.array() wrapping
    # is needed before handing them to the model.
    abc_combinations = jnp.load(f"{load_dir}/abc_combinations.npy")
    graphs_array = jnp.load(f"{load_dir}/graphs_array.npy")

    cft = CrystalFourierTransformer(config, abc_combinations, graphs_array)

    # With target=None, restore_checkpoint hands back the target (None) when
    # it finds nothing to restore; fail with a clear message in that case
    # rather than with "'NoneType' object is not subscriptable".
    state = checkpoints.restore_checkpoint(load_dir, target=None)
    if state is None:
        raise FileNotFoundError(f"No checkpoint found in {load_dir}")

    return cft, state['params']

def get_encoding(model, params, pos, reciprocal_matrices, space_group):
    """Evaluate the model's positional encoding for the given inputs.

    Calls ``model.apply`` with ``model.get_positional_encoding`` as the
    method, supplying ``params["params"]`` as the ``'params'`` variable
    collection, and forwards ``pos``, ``reciprocal_matrices`` and
    ``space_group`` positionally.

    Returns:
        Whatever ``model.apply`` returns for that call.
    """
    variables = {'params': params["params"]}
    encoder = model.get_positional_encoding
    return model.apply(
        variables, pos, reciprocal_matrices, space_group, method=encoder
    )

def test_space_group_2_symmetry(trained_model):
    """Encoding must not change when a position is shifted by one per axis.

    With ``space_group`` set to ``[2]`` and an identity reciprocal matrix,
    the positional encoding is evaluated at (0.25, 0.25, 0.25) and at
    (1.25, 1.25, 1.25); the two results must agree within ``atol=1e-3``.
    """
    model, params = trained_model
    sg = jnp.array([2])
    recip = jnp.identity(3)[jnp.newaxis, :, :]

    # Both points are laid out as a (1, 1, 3) array.
    base_point = jnp.array([[[0.25, 0.25, 0.25]]])
    shifted_point = jnp.array([[[1.25, 1.25, 1.25]]])

    encoding1, encoding2 = (
        get_encoding(model, params, point, recip, sg)
        for point in (base_point, shifted_point)
    )

    print("Encoding 1:", encoding1)
    print("Encoding 2:", encoding2)

    assert jnp.allclose(encoding1, encoding2, atol=1e-3)

def test_space_group_166_symmetry(trained_model):
    """Three related positions must receive matching encodings.

    With ``space_group`` set to ``[166]`` and an identity reciprocal matrix,
    the positional encoding is evaluated at three positions and every pair
    of results must agree within ``atol=1e-3``. A failing assertion names
    the two positions whose encodings differ.
    """
    model, params = trained_model
    space_group = jnp.array([166])
    reciprocal_matrices = jnp.identity(3).reshape(1, 3, 3)

    positions = [
        jnp.array([(1.0/3, 2.0/3, 2.0/3)]).reshape(1, 1, -1),
        # NOTE(review): presumably equivalent to the first position under a
        # symmetry operation of space group 166 -- confirm.
        jnp.array([(2.0/3, 1.0/3, 1.0/3)]).reshape(1, 1, -1),
        # The first position shifted by one along the first axis.
        jnp.array([(4.0/3, 2.0/3, 2.0/3)]).reshape(1, 1, -1),
    ]

    # Keep each flattened position next to its encoding so a failure can
    # report which pair disagreed.
    labelled_encodings = []
    for pos in positions:
        encoding = get_encoding(model, params, pos, reciprocal_matrices, space_group)
        labelled_encodings.append((pos.reshape(-1), encoding))
        print(f"Encoding for position {pos.reshape(-1)}:", encoding)

    # Check that all encodings are close to each other; combinations() yields
    # each unordered pair exactly once.
    for (pos_a, enc_a), (pos_b, enc_b) in itertools.combinations(labelled_encodings, 2):
        assert jnp.allclose(enc_a, enc_b, atol=1e-3), (
            f"Encodings differ for positions {pos_a} and {pos_b}"
        )