import pytest
import jax
import jax.numpy as jnp
import pickle
from pretrain.mlp import MLP, PretrainedPositionalEncoding, load_trained_state, precompute_space_group_operations
from pretrain.gen_data import get_wyckoff_positions
from utils.space_graphs import SpaceGraph
from itertools import combinations

@pytest.fixture(scope="module")
def pretrained_model():
    """Load the pretrained MLP checkpoint and wrap it as a positional encoding.

    Module scope: loading the checkpoint and building the 230 space-group graphs
    is expensive, and the tests only read the returned objects, so it is done
    once per module instead of once per parametrized case.

    The checkpoint directory defaults to the original development path and can
    be overridden with the ``MLP_CKPT_DIR`` environment variable.

    Returns:
        ``(pretrained_encoding, trained_state)``: the
        ``PretrainedPositionalEncoding`` module and the state returned by
        ``load_trained_state``. The tests unpack these as ``model, state``.
    """
    import os

    ckpt_dir = os.environ.get(
        "MLP_CKPT_DIR",
        "/n/fs/cfs/crystal-fourier-transformer/mlp-ckpt/1026199_2",
    )
    # pickle executes arbitrary code on load; only point this at trusted,
    # locally produced checkpoints.
    with open(os.path.join(ckpt_dir, "config.pkl"), "rb") as f:
        config = pickle.load(f)

    fourier_dim = config["fourier_dim"]
    # The node list of the space-group-1 graph is reused as the common point
    # set for all 230 space-group graphs.
    abc_combinations = SpaceGraph(1, fourier_dim).get_nodelist()
    graphs = [
        SpaceGraph(i, embedding_dim=fourier_dim, points=abc_combinations)
        for i in range(1, 231)
    ]
    # Stack the dense adjacency matrices into one array, one entry per group.
    graphs_array = jnp.array([g.get_adjacency_matrix().toarray() for g in graphs])

    mlp = MLP(config)
    trained_state = load_trained_state(ckpt_dir, mlp)
    pretrained_encoding = PretrainedPositionalEncoding(
        config, jnp.array(abc_combinations), graphs_array, trained_state
    )

    return pretrained_encoding, trained_state

def get_encoding(model, state, pos, lattice_vectors, space_group):
    """Apply ``model`` to one input batch using the variables stored in ``state``.

    ``state.params`` and ``state.batch_stats`` are packed into the Flax-style
    variables mapping; ``pos``, ``lattice_vectors`` and ``space_group`` are
    forwarded positionally, in that order. Returns whatever ``model.apply``
    returns.
    """
    variables = dict(params=state.params, batch_stats=state.batch_stats)
    inputs = (pos, lattice_vectors, space_group)
    return model.apply(variables, *inputs)

# Symmetry-operation tables, computed once at import time. The code below
# indexes them with ``space_group - 1``, so presumably there is one entry per
# space group 1..230; per-group shapes assumed to be (n_ops, 3, 3) rotations and
# (n_ops, 3) translations, possibly padded to a common n_ops — TODO confirm.
ALL_ROTATIONS, ALL_TRANSLATIONS = precompute_space_group_operations()
# Wyckoff positions per space group; the tests index it with
# ``space_group_num - 1`` and iterate pairs of its entries.
ALL_WYCKOFF = get_wyckoff_positions()

def compute_orbits(positions, space_group, rotations=None, translations=None):
    """Apply every symmetry operation of ``space_group`` to each entry of ``positions``.

    Each operation ``(R, t)`` maps a fractional column vector ``x`` to
    ``R @ x + t``.

    Args:
        positions: array vmapped over its leading axis. Each entry is read as
            ``(-1, 3)`` row vectors of fractional coordinates (the tests pass
            ``(1, 1, 3)``, i.e. one point per entry).
        space_group: 1-based space-group number used to index the precomputed
            ``ALL_ROTATIONS`` / ``ALL_TRANSLATIONS`` tables.
        rotations: optional ``(n_ops, 3, 3)`` array overriding the table entry.
        translations: optional ``(n_ops, 3)`` (or ``(n_ops, 3, 1)``) array
            overriding the table entry.

    Returns:
        Array of shape ``(len(positions), n_ops, 3, k)``. Column ``j`` of
        ``result[p, o]`` is operation ``o`` applied to the ``j``-th point of
        entry ``p``.
    """
    if rotations is None:
        rotations = ALL_ROTATIONS[space_group - 1]
    if translations is None:
        translations = ALL_TRANSLATIONS[space_group - 1]

    def compute_orbit(pos):
        # Points as columns, shape (3, k). ``reshape(3, -1)`` would interleave
        # coordinates of different points whenever k > 1.
        cols = pos.reshape(-1, 3).T
        # Reshape t to a (3, 1) column: a bare (3,) translation would broadcast
        # against the (3, k) product into a (3, 3) matrix instead of shifting
        # each point.
        return jax.vmap(lambda r, t: r @ cols + t.reshape(3, 1))(rotations, translations)

    return jax.vmap(compute_orbit)(positions)

def get_orbit_distance(pos1, pos2, space_group_num, lattice_vectors,
                       rotations=None, translations=None):
    """Return the minimum periodic distance from ``pos1`` to the orbit of ``pos2``.

    Every symmetry operation ``(R, t)`` is applied to ``pos2`` (``R @ x + t`` in
    fractional coordinates). Each image is compared with ``pos1`` under the
    minimum-image convention, and the smallest Cartesian distance is returned.
    If the operations are isometries forming a group, this equals the distance
    between the two orbits, since d(g p1, h p2) = d(p1, g^-1 h p2).

    Args:
        pos1, pos2: fractional coordinates, read as ``(-1, 3)`` rows. The tests
            pass single points of shape ``(1, 1, 3)``; with several points the
            minimum is taken over all of them.
        space_group_num: 1-based space-group number indexing the precomputed
            ``ALL_ROTATIONS`` / ``ALL_TRANSLATIONS`` tables.
        lattice_vectors: ``(3, 3)`` matrix mapping fractional row vectors to
            Cartesian ones (``cart = frac @ lattice_vectors``).
        rotations: optional ``(n_ops, 3, 3)`` array overriding the table entry.
        translations: optional ``(n_ops, 3)`` (or ``(n_ops, 3, 1)``) array
            overriding the table entry.

    Returns:
        Scalar array holding the minimum distance.
    """
    if rotations is None:
        rotations = ALL_ROTATIONS[space_group_num - 1]
    if translations is None:
        translations = ALL_TRANSLATIONS[space_group_num - 1]
    # NOTE(review): if the precomputed tables are zero-padded to a common n_ops,
    # padding rows map every point to the origin and can produce spurious small
    # distances — confirm how precompute_space_group_operations pads.

    p1 = jnp.reshape(pos1, (-1, 3))              # (N, 3)
    p2 = jnp.reshape(pos2, (-1, 3))              # (M, 3)
    shifts = jnp.reshape(translations, (-1, 3))  # (n_ops, 3)

    # Images of every point of pos2 under every operation: (n_ops, M, 3).
    # pos1 must be compared against ALL images of pos2; pairing op i of one
    # orbit with op i of the other only recovers the plain PBC distance.
    images = jnp.einsum("oij,mj->omi", rotations, p2) + shifts[:, None, :]

    # Differences from each point of pos1 to each image: (N, n_ops, M, 3).
    diff = p1[:, None, None, :] - images[None, :, :, :]
    # Wrap each fractional component into [-0.5, 0.5): the minimum image for
    # orthogonal cells, an approximation for strongly skewed ones.
    diff = diff - jnp.floor(diff + 0.5)
    cart_diff = diff @ lattice_vectors
    return jnp.min(jnp.linalg.norm(cart_diff, axis=-1))

@pytest.mark.parametrize(
    "space_group_num",
    [
        2,    # Triclinic
        18,   # Monoclinic
        46,   # Orthorhombic
        100,  # Tetragonal
        170,  # Trigonal
        230,  # Cubic
    ],
)
def test_wyckoff_site_invariance(pretrained_model, space_group_num):
    """Check embedding distance against orbit distance for Wyckoff position pairs.

    For every unordered pair of entries in ``ALL_WYCKOFF`` for this space group,
    the norm of the difference of the two encodings must match the orbit
    distance within ``atol=1e-3``. Positions from the same Wyckoff site are
    expected to share an orbit, so both distances should then be ~0.
    """
    model, state = pretrained_model
    cell = jnp.eye(3)[None]  # single identity lattice, shape (1, 3, 3)
    sg_batch = jnp.array([space_group_num])
    sites = ALL_WYCKOFF[space_group_num - 1]

    print(f"Testing Wyckoff positions in space group {space_group_num}")
    print(f"Number of positions: {len(sites)}")

    for pair_idx, pair in enumerate(combinations(sites, 2)):
        # Each position becomes a batch of one point: shape (1, 1, 3).
        a, b = (p.reshape(1, 1, -1) for p in pair)

        emb_a = get_encoding(model, state, a, cell, sg_batch)
        emb_b = get_encoding(model, state, b, cell, sg_batch)

        orbit_dist = get_orbit_distance(a, b, space_group_num, cell[0])
        embedding_dist = jnp.linalg.norm(emb_a - emb_b)

        print(f"  Position pair {pair_idx}:")
        print(f"    Positions: {a[0, 0]}, {b[0, 0]}")
        print(f"    Orbit distance: {orbit_dist:.6f}")
        print(f"    Embedding distance: {embedding_dist:.6f}")
        assert jnp.allclose(orbit_dist, embedding_dist, atol=1e-3)

@pytest.mark.parametrize("space_group_num", [
    2,   # Triclinic
    10,  # Monoclinic
    37,  # Orthorhombic
    81,  # Tetragonal
    155, # Trigonal
    196  # Cubic
])
def test_random_points_distance_correlation(pretrained_model, space_group_num):
    """Check that embedding distance matches orbit distance for random point pairs.

    Ten points are drawn uniformly in the unit cell with a fixed PRNG key. For
    every unordered pair, the norm of the difference of the model's encodings
    must equal ``get_orbit_distance`` within ``atol=1e-3``. This is an equality
    check, which is stricter than mere correlation.
    """
    model, state = pretrained_model
    lattice_vectors = jnp.eye(3).reshape(1, 3, 3)
    space_group = jnp.array([space_group_num])

    # Generate random points in the unit cell
    key = jax.random.PRNGKey(42)
    num_points = 10
    random_points = jax.random.uniform(key, shape=(num_points, 3))

    print(f"\nTesting random points in space group {space_group_num}")

    # Shape each point as a batch of one (1, 1, 3) and encode it once up front,
    # instead of re-encoding both points for every pair.
    points = [p.reshape(1, 1, -1) for p in random_points]
    encodings = [
        get_encoding(model, state, p, lattice_vectors, space_group) for p in points
    ]

    # Same pair order as a nested i < j loop.
    for (pos1, enc1), (pos2, enc2) in combinations(zip(points, encodings), 2):
        orbit_dist = get_orbit_distance(pos1, pos2, space_group_num, lattice_vectors[0])
        embedding_dist = jnp.linalg.norm(enc1 - enc2)

        print("\nComparing points:")
        print(f"  Positions: {pos1[0, 0]}, {pos2[0, 0]}")
        print(f"  Orbit distance: {orbit_dist:.6f}")
        print(f"  Embedding distance: {embedding_dist:.6f}")
        assert jnp.allclose(orbit_dist, embedding_dist, atol=1e-3), (
            f"space group {space_group_num}: orbit distance {float(orbit_dist):.6f} "
            f"!= embedding distance {float(embedding_dist):.6f}"
        )
