import jax
import jax.numpy as jnp
import json
import optax
import time
import wandb
import pickle
import argparse
from functools import partial
from model import CrystalFourierTransformer
from pretrain.mlp import load_trained_state, MLP
from utils.data_processing import prepare_data
from utils.space_graphs import SpaceGraph
from jax import random
from flax.training import train_state, checkpoints
import numpy as np

def parse_args():
    """Define the training CLI and return the parsed options as a plain dict.

    Each key is the flag name without the leading dashes
    (e.g. ``--batch_size`` -> ``'batch_size'``), in registration order.
    """
    cli = argparse.ArgumentParser(description="Train the model with custom parameters")

    # (flag, add_argument keyword arguments), registered in this order.
    model_options = [
        ('--embedding_dim', dict(type=int, default=256, help="Dimension of input embedding")),
        ('--num_attn_blocks', dict(type=int, default=4, help="Number of attention blocks in the transformer")),
        ('--num_heads', dict(type=int, default=4, help="Number of attention heads in each attention block")),
        ('--ff_dim', dict(type=int, default=512, help="Dimension of projection in attention block MLP")),
        ('--encoding_hidden_dim', dict(type=int, default=256, help="Dimension of space group-dependent hidden layer in positional encoding MLP")),
        ('--final_hidden_1', dict(type=int, default=256, help="Dimension of hidden layer 1 in final MLP")),
        ('--final_hidden_2', dict(type=int, default=32, help="Dimension of hidden layer 2 in final MLP")),
        ('--fourier', dict(action='store_true', help="Whether to use Fourier positional encoding")),
        ('--log', dict(type=str, default='')),
        # NOTE(review): train_step in this file applies its own L2 coefficient (0.0002)
        # rather than reading config['l2_reg']; confirm whether anything consumes this flag.
        ('--l2_reg', dict(type=float, default=0.001, help="L2 regularization strength")),
    ]
    training_options = [
        ('--dropout_rate', dict(type=float, default=0.0)),
        ('--weight_decay', dict(type=float, default=0.001)),
        ('--max_grad_norm', dict(type=float, default=1.0)),
        ('--learning_rate', dict(type=float, default=0.0003)),
        ('--warmup_steps', dict(type=int, default=10000)),
        ('--num_epochs', dict(type=int, default=500)),
        ('--batch_size', dict(type=int, default=32)),
        ('--wandb_project', dict(type=str, default='cft-formation6')),
        ('--ckpt_dir', dict(type=str, default='/n/fs/cfs/crystal-fourier-transformer/checkpoints2')),
        ('--data_dir', dict(type=str, default='/n/fs/cfs/cgcnn-old/cgcnn/data/mp-all')),
        ('--seed', dict(type=int, default=42)),
    ]

    for flag, options in model_options + training_options:
        cli.add_argument(flag, **options)

    return vars(cli.parse_args())

def create_learning_rate_fn(config, num_train_examples):
    """Create a cosine decay learning rate schedule with linear warmup.

    The rate ramps linearly from ``0.1 * learning_rate`` to ``learning_rate``
    over ``warmup_steps`` optimiser steps. It then follows a cosine decay down
    to ``1e-6 * learning_rate`` over the remaining steps of training.

    Args:
        config: dict providing 'batch_size', 'num_epochs', 'warmup_steps' and
            'learning_rate'.
        num_train_examples: number of training examples iterated per epoch.

    Returns:
        An optax schedule (step count -> learning rate).
    """
    batch_size = config['batch_size']
    # Ceiling division: the training loop also steps on the final, partial
    # batch, so one epoch is ceil(n / batch_size) optimiser updates. Floor
    # division made the schedule bottom out before training actually ended.
    steps_per_epoch = -(-num_train_examples // batch_size)
    total_steps = steps_per_epoch * config['num_epochs']
    warmup_steps = config['warmup_steps']
    # optax requires a positive decay length; keep at least one step when the
    # warmup covers (or exceeds) the whole run.
    decay_steps = max(total_steps - warmup_steps, 1)

    warmup_fn = optax.linear_schedule(
        init_value=0.1 * config['learning_rate'],
        end_value=config['learning_rate'],
        transition_steps=warmup_steps,
    )
    cosine_decay_fn = optax.cosine_decay_schedule(
        init_value=config['learning_rate'],
        decay_steps=decay_steps,
        alpha=1e-6,
    )
    # After `warmup_steps`, join_schedules feeds (step - warmup_steps) to the cosine decay.
    return optax.join_schedules(
        schedules=[warmup_fn, cosine_decay_fn],
        boundaries=[warmup_steps],
    )

def save_model_components(config, save_dir):
    """Write the run configuration to ``<save_dir>/config.json``.

    ``save_dir`` (and any missing parents) is created first. This function runs
    before the first checkpoint is saved, so on a fresh run nothing else has
    created the directory yet.

    Args:
        config: JSON-serialisable dict of run options.
        save_dir: directory to write ``config.json`` into.
    """
    import os

    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "config.json"), 'w') as f:
        json.dump(config, f)

@partial(jax.jit, static_argnums=(0,))
def train_step(apply_fn, state, batch, dropout_rng, l2_reg=0.0002):
    """Run one jitted optimisation step on a batch.

    Args:
        apply_fn: the model's ``apply`` function (static under jit, so it must
            be hashable).
        state: train state carrying ``params``, ``batch_stats`` and the optimiser.
        batch: dict with 'atom_numbers', 'positions', 'lattice_matrices',
            'space_groups', 'masks' (model inputs) and 'targets'.
        dropout_rng: PRNG key used for the 'dropout' rng stream.
        l2_reg: coefficient of the explicit ``0.5 * sum(||p||^2)`` penalty.
            It defaults to 0.0002, the value this step has always used. Pass
            ``config['l2_reg']`` to honour the ``--l2_reg`` CLI flag, whose
            default (0.001) differs.

    Returns:
        ``(new_state, metrics)``. ``metrics['loss']`` is MSE plus the L2
        penalty, and ``metrics['mae']`` is the mean absolute error.
    """
    targets = batch['targets']

    def loss_fn(params):
        variables = {'params': params, 'batch_stats': state.batch_stats}
        predictions, new_model_state = apply_fn(
            variables,
            batch['atom_numbers'],
            batch['positions'],
            batch['lattice_matrices'],
            batch['space_groups'],
            batch['masks'],
            training=True,
            rngs={'dropout': dropout_rng},
            mutable=['batch_stats'],  # batch_stats is updated during training
        )
        predictions = jnp.squeeze(predictions)
        mse_loss = jnp.mean((predictions - targets) ** 2)

        # Explicit L2 penalty over every parameter leaf.
        # NOTE(review): adamw in create_train_state also applies weight decay,
        # so parameters are regularised twice; confirm this is intended.
        l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
        total_loss = mse_loss + l2_reg * l2_loss

        return total_loss, (predictions, new_model_state)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (predictions, new_model_state)), grads = grad_fn(state.params)

    # Apply the optimiser update and store the new running statistics.
    state = state.apply_gradients(
        grads=grads,
        batch_stats=new_model_state['batch_stats'],
    )

    metrics = {'loss': loss, 'mae': jnp.mean(jnp.abs(predictions - targets))}
    return state, metrics

@partial(jax.jit, static_argnums=(0,))
def eval_step(apply_fn, params, batch_stats, batch):
    """Score one batch in inference mode without touching model state.

    Args:
        apply_fn: the model's ``apply`` function (static under jit).
        params: the 'params' collection.
        batch_stats: the 'batch_stats' collection (read only here).
        batch: dict with the five model inputs plus 'targets'.

    Returns:
        ``(mse, mae)`` over the batch.
    """
    model_inputs = (
        batch['atom_numbers'],
        batch['positions'],
        batch['lattice_matrices'],
        batch['space_groups'],
        batch['masks'],
    )
    outputs = apply_fn(
        {'params': params, 'batch_stats': batch_stats},
        *model_inputs,
        training=False,
        mutable=False,  # evaluation must not update batch_stats
    )
    residual = jnp.squeeze(outputs) - batch['targets']
    return jnp.mean(residual ** 2), jnp.mean(jnp.abs(residual))

def create_train_state(config, cft, train_features, key):
    """Initialise ``cft`` and wrap it in a TrainState that also carries batch stats.

    Args:
        config: run options. The optimiser reads 'max_grad_norm' and
            'weight_decay'; the learning-rate schedule reads further keys.
        cft: the model module to initialise.
        train_features: dict of training arrays. Only the first example is used
            to trace shapes for ``init``, and its length sizes the LR schedule.
        key: PRNG key; a subkey split from it seeds parameter initialisation.

    Returns:
        A ``TrainState`` subclass instance with an extra ``batch_stats`` field.
    """
    _, init_key = random.split(key)

    input_names = ('atom_numbers', 'positions', 'lattice_matrices', 'space_groups', 'masks')
    example_inputs = [train_features[name][:1] for name in input_names]
    variables = cft.init(
        {'params': init_key},  # rng for the 'params' collection
        *example_inputs,
        training=False,
    )

    schedule = create_learning_rate_fn(config, len(train_features['atom_numbers']))
    tx = optax.chain(
        optax.clip_by_global_norm(config['max_grad_norm']),
        optax.adamw(learning_rate=schedule, weight_decay=config['weight_decay']),
    )

    # TrainState variant that threads batch statistics through training steps.
    class TrainStateWithBatchStats(train_state.TrainState):
        batch_stats: dict

    return TrainStateWithBatchStats.create(
        apply_fn=cft.apply,
        params=variables['params'],
        tx=tx,
        batch_stats=variables.get('batch_stats', {}),  # empty if the model has none
    )

def load_pretrained_state(checkpoint_dir):
    """Load a pretrained MLP checkpoint together with the config it was built from.

    Args:
        checkpoint_dir: directory containing ``config.pkl`` and the checkpoint
            that ``load_trained_state`` reads.

    Returns:
        ``(config, state)``: the unpickled MLP config and its restored state.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # trusted checkpoint directories.
    with open(f"{checkpoint_dir}/config.pkl", "rb") as handle:
        mlp_config = pickle.load(handle)
    restored = load_trained_state(checkpoint_dir, MLP(mlp_config))
    return mlp_config, restored

def preprocess_data_batch(cft, state, features, batch_size=64, tolerance=0.0001):
    """Preprocess each crystal to remove equivalent atoms.

    Each atom is represented by ``cft``'s atom embedding plus its positional
    encoding. Two valid atoms of the same crystal are considered equivalent
    when the Euclidean distance between these representations is below
    ``tolerance``. An atom is kept when it is the lowest-index valid atom
    equivalent to some valid atom. Padded (masked-out) slots are never kept.

    Args:
        cft: the model; its ``atom_embedding`` and ``positional_encoding``
            methods are applied.
        state: train state providing ``params``, and ``batch_stats`` if present.
        features: dict of per-crystal arrays keyed 'atom_numbers', 'positions',
            'lattice_matrices', 'space_groups', 'masks' (leading axis = crystal).
        batch_size: number of crystals per jitted call.
        tolerance: distance below which two atoms count as equivalent.

    Returns:
        Dict with the same keys as ``features``. 'masks' holds boolean
        uniqueness masks; every other array is passed through unchanged.
    """
    total_samples = features['atom_numbers'].shape[0]
    if total_samples == 0:
        # Nothing to process (jnp.concatenate would fail on empty lists).
        return dict(features)

    # Flax's apply expects a variables dict keyed by collection, not the bare
    # params tree stored in state.params.
    variables = {'params': state.params}
    batch_stats = getattr(state, 'batch_stats', None)
    if batch_stats:
        variables['batch_stats'] = batch_stats

    @jax.jit
    def process_batch(variables, atom_numbers, positions, lattice_matrices, space_groups, masks):
        atom_embeddings = cft.apply(variables, atom_numbers, method=cft.atom_embedding)
        positional_encodings = cft.apply(
            variables, positions, lattice_matrices, space_groups, method=cft.positional_encoding
        )
        x = atom_embeddings + positional_encodings

        @jax.vmap
        def process_single_crystal(x_crystal, mask_crystal):
            # Per-crystal slot count, instead of a hard-coded maximum.
            num_slots = mask_crystal.shape[0]
            valid = mask_crystal.astype(bool)
            pairwise_distances = jnp.linalg.norm(
                x_crystal[:, None, :] - x_crystal[None, :, :], axis=-1
            )
            # Only valid atoms may be equivalent to each other, so padded slots
            # can never be selected as a representative.
            within_tolerance = (
                (pairwise_distances < tolerance) & valid[:, None] & valid[None, :]
            )
            # For each atom, the lowest-index atom it is equivalent to.
            keep_indices = jnp.argmax(within_tolerance, axis=1)
            # Padded rows point out of bounds; their scatter is dropped.
            scatter_indices = jnp.where(valid, keep_indices, num_slots)
            unique_mask = (
                jnp.zeros(num_slots, dtype=bool).at[scatter_indices].set(True, mode='drop')
            )
            jax.debug.print("num non-unique atoms: {diff}", diff=jnp.sum(mask_crystal) - jnp.sum(unique_mask))
            return unique_mask

        return process_single_crystal(x, masks)

    processed_features = {k: [] for k in features}

    for i in range(0, total_samples, batch_size):
        batch = {k: v[i:i + batch_size] for k, v in features.items()}
        unique_masks = process_batch(variables, **batch)
        for k, v in batch.items():
            processed_features[k].append(unique_masks if k == 'masks' else v)
    return {k: jnp.concatenate(v, axis=0) for k, v in processed_features.items()}

def _load_adjacency_matrices(path):
    """Load the 'matrices' array of an .npz archive as a JAX array, closing the file."""
    # np.load on an .npz keeps the file open until closed; the with-block closes it.
    with np.load(path) as archive:
        return jnp.array(archive['matrices'])


def _split_indices(key, num_examples, train_frac=0.8, val_frac=0.1):
    """Randomly partition range(num_examples) into train / val / test index arrays."""
    shuffled = random.permutation(key, num_examples)
    train_size = int(train_frac * num_examples)
    val_size = int(val_frac * num_examples)
    return (
        shuffled[:train_size],
        shuffled[train_size:train_size + val_size],
        shuffled[train_size + val_size:],
    )


def _evaluate(apply_fn, state, features, targets, batch_size):
    """Return example-weighted (MSE, MAE) of ``state`` over a dataset; (nan, nan) if empty."""
    weighted_losses, weighted_maes, count = [], [], 0
    for start in range(0, len(targets), batch_size):
        stop = start + batch_size
        batch = {k: v[start:stop] for k, v in features.items()}
        batch['targets'] = targets[start:stop]
        n = len(batch['targets'])
        loss, mae = eval_step(apply_fn, state.params, state.batch_stats, batch)
        # Weight by batch size so a short final batch does not skew the mean.
        weighted_losses.append(loss * n)
        weighted_maes.append(mae * n)
        count += n
    if count == 0:
        return float('nan'), float('nan')
    return (
        float(jnp.sum(jnp.array(weighted_losses))) / count,
        float(jnp.sum(jnp.array(weighted_maes))) / count,
    )


def main():
    """Train a CrystalFourierTransformer and report test metrics for the best checkpoint.

    Steps:
      * Parse CLI options.
      * Load the precomputed adjacency matrices, the lattice node lists and two
        pretrained MLPs.
      * Split the data 80/10/10.
      * Train, validating each epoch and checkpointing whenever validation MSE
        improves.
      * Restore the latest (i.e. best) checkpoint and evaluate it on the test split.

    Metrics are printed and logged to Weights & Biases.
    """
    config = parse_args()
    # NOTE(review): --fourier is currently not consulted here; both encoding
    # variants below are always loaded.

    # Pretrained / precomputed components.
    cubic_adj_matrices = _load_adjacency_matrices(
        "/n/fs/cfs/crystal-fourier-transformer/data/adjacency_matrices_320.npz")
    hexagonal_adj_matrices = _load_adjacency_matrices(
        "/n/fs/cfs/crystal-fourier-transformer/data/adjacency_matrices_600.npz")

    cubic_abc_combinations = jnp.array(SpaceGraph(1, 320).get_nodelist())
    hexagonal_abc_combinations = jnp.array(SpaceGraph(168, 600).get_nodelist())

    cubic_encoding_config, cubic_pretrained_state = load_pretrained_state(
        "/n/fs/cfs/crystal-fourier-transformer/mlp-ckpt/2184867_2")
    hexagonal_encoding_config, hexagonal_pretrained_state = load_pretrained_state(
        "/n/fs/cfs/crystal-fourier-transformer/mlp-ckpt/2177644_2")

    key = random.PRNGKey(config['seed'])
    features, targets = prepare_data(config['data_dir'])

    key, split_key = random.split(key)
    train_idx, val_idx, test_idx = _split_indices(split_key, len(targets))

    train_features = {k: v[train_idx] for k, v in features.items()}
    train_targets = targets[train_idx]
    val_features = {k: v[val_idx] for k, v in features.items()}
    val_targets = targets[val_idx]
    test_features = {k: v[test_idx] for k, v in features.items()}
    test_targets = targets[test_idx]

    cft = CrystalFourierTransformer(
        config,
        cubic_abc_combinations,
        hexagonal_abc_combinations,
        cubic_adj_matrices,
        hexagonal_adj_matrices,
        cubic_pretrained_state,
        hexagonal_pretrained_state,
        cubic_encoding_config,
        hexagonal_encoding_config,
    )
    # Give initialisation its own subkey. Passing `key` itself and splitting it
    # again for the first epoch's shuffle would reproduce the init subkey.
    key, init_key = random.split(key)
    state = create_train_state(config, cft, train_features, init_key)
    save_model_components(config, config['ckpt_dir'])

    wandb.init(project=config['wandb_project'], config=config)

    batch_size = config['batch_size']
    num_train = len(train_targets)
    best_val_loss = float('inf')
    for epoch in range(config['num_epochs']):
        start_time = time.time()
        train_losses, train_maes = [], []
        key, shuffle_key = random.split(key)
        # Positions into train_features/train_targets, which already hold only
        # the training subset. These are not indices into the full dataset.
        order = random.permutation(shuffle_key, num_train)

        for i in range(0, num_train, batch_size):
            batch_indices = order[i:i + batch_size]
            batch = {k: v[batch_indices] for k, v in train_features.items()}
            batch['targets'] = train_targets[batch_indices]

            key, dropout_key = random.split(key)
            state, metrics = train_step(cft.apply, state, batch, dropout_key)

            train_losses.append(metrics['loss'])
            train_maes.append(metrics['mae'])

        train_loss = float(jnp.mean(jnp.array(train_losses)))
        train_mae = float(jnp.mean(jnp.array(train_maes)))
        val_loss, val_mae = _evaluate(cft.apply, state, val_features, val_targets, batch_size)

        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{config['num_epochs']}: "
              f"Train Loss: {train_loss:.4f}, Train MAE: {train_mae:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val MAE: {val_mae:.4f}, "
              f"Time: {epoch_time:.2f}s")

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_mae": train_mae,
            "val_loss": val_loss,
            "val_mae": val_mae,
            "epoch_time": epoch_time
        })

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoints.save_checkpoint(config['ckpt_dir'], state, step=epoch, keep=2)

    # Checkpoints are only written on improvement, so the latest one is the best.
    state = checkpoints.restore_checkpoint(config['ckpt_dir'], state)
    test_loss, test_mae = _evaluate(cft.apply, state, test_features, test_targets, batch_size)
    print(f"Final Test Loss: {test_loss:.4f}, Final Test MAE: {test_mae:.4f}")
    wandb.log({
        "test_loss": test_loss,
        "test_mae": test_mae
    })
    wandb.finish()

# Script entry point: run main() (CLI parsing + training) only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()