import numpy as np
import jax.numpy as jnp
import jax
import flax.linen as nn
from flax.training import train_state
import optax
import jax.random
import os
import flax
import datetime  # Added for timestamp
try:
    import wandb
    import pybullet as pb
    import ray
    from sklearn.model_selection import train_test_split
except:
    print("import without PyBullet, Ray, wandb, and scikit-learn")

# Robot model used throughout this file: it selects the URDF that data_gen
# loads and the joint count (6 for 'shakey', 'ur5' and 'shakey_robotiq',
# 12 otherwise). Switch robots by changing which assignment is uncommented.
# ROBOT = 'ur5'
# ROBOT = 'shakey'
ROBOT = 'IM2'
# ROBOT = 'shakey_robotiq'

def data_gen(nitr=1000, visualize=False):
    '''
    Sample random joint configurations and label them for self-collision.

    A PyBullet connection is opened when none is active (GUI if ``visualize``
    is true, DIRECT otherwise); an already-active connection is reused. The
    URDF selected by the module-level ``ROBOT`` is loaded with self-collision
    enabled, ``nitr`` joint vectors are drawn uniformly between 1.1 times the
    URDF lower and upper limits, and each one is labelled by whether PyBullet
    reports any contact point of the robot with itself. The PyBullet
    connection is always closed before returning, even if sampling fails.

    Args:
        nitr: Number of configurations to sample.
        visualize: Open a GUI connection instead of a headless one.

    Returns:
        Tuple ``(q_list, col_res_list)``: a float32 array of shape
        ``(nitr, num_joints)`` with the sampled joint values and a float32
        array of shape ``(nitr,)`` that is 1.0 where a self-contact was found
        and 0.0 otherwise.

    Raises:
        ValueError: If the loaded URDF does not provide ``num_joints`` joints
            starting at ``joint_start_idx``.
    '''
    urdf_paths = {
        'shakey': 'assets/ur5/urdf/shakey_open_rg6.urdf',
        'shakey_robotiq': 'assets/ur5/urdf/shakey_robotiq_open.urdf',
        'ur5': 'assets/ur5/urdf/ur5.urdf',
        'IM2': 'assets/RobotBimanualV4/urdf/RobotBimanualV4.urdf',
    }

    # Check PyBullet connection
    if pb.getConnectionInfo()['isConnected'] == 0:
        pb.connect(pb.GUI if visualize else pb.DIRECT)

    try:
        # Load the robot regardless of who opened the connection, so that
        # robot_uid and the joint limits are always defined below.
        robot_uid = pb.loadURDF(urdf_paths[ROBOT], flags=pb.URDF_USE_SELF_COLLISION)

        # get joint limits (entries 8 and 9 of the joint info tuple)
        joint_infos = [pb.getJointInfo(robot_uid, i) for i in range(pb.getNumJoints(robot_uid))]
        joint_lower_limits = np.array([info[8] for info in joint_infos])
        joint_upper_limits = np.array([info[9] for info in joint_infos])

        if ROBOT == 'IM2':
            # Disable collision checks for selected link pairs
            # (original note: remove collision between link 2 - 4 and 8 - 10)
            for link_a, link_b in ((3, 5), (9, 11), (0, 2), (0, 8)):
                pb.setCollisionFilterPair(robot_uid, robot_uid, link_a, link_b, 0)

        if ROBOT in ['shakey', 'ur5', 'shakey_robotiq']:
            num_joints = 6
        else:
            num_joints = 12
        joint_start_idx = 1

        # Take exactly the controlled joints; trailing joints in the URDF
        # (if any) must not take part in the sampling bounds.
        controlled = slice(joint_start_idx, joint_start_idx + num_joints)
        low = 1.1 * joint_lower_limits[controlled]
        high = 1.1 * joint_upper_limits[controlled]
        if low.shape[0] != num_joints:
            raise ValueError(
                f"URDF for {ROBOT} provides {low.shape[0]} joints from index "
                f"{joint_start_idx}, expected {num_joints}"
            )

        col_res_list = np.empty(shape=(nitr,), dtype=np.bool_)
        q_list = np.empty(shape=(nitr, num_joints), dtype=np.float32)
        for cur_itr in range(nitr):
            random_joint = np.random.uniform(low, high, size=(num_joints,))
            for i, q_i in enumerate(random_joint):
                pb.resetJointState(robot_uid, i + joint_start_idx, q_i)

            # Check self-collision
            pb.performCollisionDetection()
            col_data = pb.getContactPoints(robot_uid, robot_uid)
            col_res_list[cur_itr] = len(col_data) > 0
            q_list[cur_itr] = random_joint

        return q_list, col_res_list.astype(np.float32)
    finally:
        pb.disconnect()

# for _ in range(100):
#     q_list, col_res_list = data_gen(1000, visualize=True)

def ray_data_gen(nray_env=10, nitr=1000):
    '''
    Generate a class-balanced self-collision dataset with Ray workers.

    Launches ``nray_env`` remote ``data_gen`` tasks of ``nitr`` samples each,
    merges their outputs, and subsamples without replacement so that the
    number of colliding and non-colliding samples is equal.

    Args:
        nray_env: Number of remote ``data_gen`` tasks to launch.
        nitr: Number of samples requested from each task.

    Returns:
        Tuple ``(q_list, col_res_list)`` holding the selected joint vectors
        and their labels; colliding samples come first, then the
        non-colliding ones.
    '''
    pending = []
    for _ in range(nray_env):
        pending.append(ray.remote(data_gen).remote(nitr))
    q_chunks, label_chunks = zip(*ray.get(pending))

    labels = np.concatenate(label_chunks, axis=0)
    joints = np.concatenate(q_chunks, axis=0)

    # positive negative balance: keep as many of each class as the rarer one has
    colliding = np.flatnonzero(labels == 1)
    free = np.flatnonzero(labels == 0)
    per_class = min(len(colliding), len(free))
    colliding = np.random.choice(colliding, per_class, replace=False)
    free = np.random.choice(free, per_class, replace=False)

    keep = np.concatenate([colliding, free], axis=0)
    return joints[keep], labels[keep]

class ShakeySelfCollisionNet(nn.Module):
    '''
    Self-collision classifier over joint configurations.

    Each joint value is expanded with a sinusoidal encoding (16 frequencies,
    2**-3 through 2**12), then passed through a 128-unit input layer, one
    128-unit residual block with swish activations, and a single-logit output
    layer.
    '''

    @nn.compact
    def __call__(self, x, train=False):
        '''
        x: jnp.ndarray of shape (n, d), where d is the number of joints
        train: accepted for API compatibility; not used by the network
        return: jnp.ndarray of shape (n, 1) holding logits (pre-sigmoid)
        '''

        num_encoding_dims = 16
        # The base must be a float: with an integer base and integer exponents
        # jnp.power performs integer exponentiation, which cannot represent
        # the fractional frequencies 2**-3, 2**-2 and 2**-1.
        # NOTE: parameters trained with the earlier integer encoding saw
        # different values for those three bands and should be retrained.
        emb = jnp.power(2.0, jnp.arange(num_encoding_dims) - 3)
        x = x[..., None, :] * emb[:, None]
        x = jnp.concatenate([jnp.sin(x), jnp.cos(x)], axis=-2)
        # Flatten (2 * num_encoding_dims, d) into one feature axis
        x = x.reshape(*x.shape[:-2], -1)

        # Input Layer
        x = nn.Dense(128)(x)
        x = nn.swish(x)

        # Residual Block
        residual = x
        x = nn.Dense(128)(x)
        x = nn.swish(x)
        x = nn.Dense(128)(x)
        x = x + residual
        x = nn.swish(x)

        # Output Layer
        x = nn.Dense(1)(x)

        return x

def create_train_state(rng, learning_rate):
    '''
    Build the initial training state for ShakeySelfCollisionNet.

    Args:
        rng: PRNG key used to initialize the network parameters.
        learning_rate: Learning rate handed to the AdamW optimizer.

    Returns:
        A ``train_state.TrainState`` holding the model's apply function, the
        freshly initialized parameters and the optimizer.
    '''
    # Joint count depends on the selected robot: 6 for the single-arm models, 12 otherwise.
    input_dim = 6 if ROBOT in ['shakey', 'ur5', 'shakey_robotiq'] else 12
    net = ShakeySelfCollisionNet()
    variables = net.init(rng, jnp.ones([1, input_dim]), train=True)
    return train_state.TrainState.create(
        apply_fn=net.apply,
        params=variables['params'],
        tx=optax.adamw(learning_rate),
    )

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    '''
    Mean focal loss for binary classification from raw logits.

    Args:
        logits: Model outputs before the sigmoid.
        targets: Ground truth labels, broadcastable against ``logits``.
        alpha: Weight applied to positive targets; negatives get ``1 - alpha``.
        gamma: Exponent of the ``(1 - pt)`` modulating term.

    Returns:
        Scalar mean of the per-element focal loss.
    '''
    p = nn.sigmoid(logits)
    # Probability the model assigns to the true class of each element
    p_true = p * targets + (1 - p) * (1 - targets)
    class_weight = alpha * targets + (1 - alpha) * (1 - targets)
    modulating = (1 - p_true) ** gamma
    # Small epsilon keeps the log finite when p_true underflows to zero
    log_p_true = jnp.log(p_true + 1e-8)
    per_element = -class_weight * modulating * log_p_true
    return jnp.mean(per_element)

def loss_fn(params, x, y, alpha, gamma, train=False):
    '''
    Mean sigmoid binary cross-entropy of the network on a batch.

    Args:
        params: Parameters of ShakeySelfCollisionNet.
        x: Batch of joint configurations.
        y: Batch of target labels, same shape as the logits.
        alpha: Accepted for a focal-loss variant; not used here.
        gamma: Accepted for a focal-loss variant; not used here.
        train: Forwarded to the network call.

    Returns:
        Scalar loss.
    '''
    net = ShakeySelfCollisionNet()
    logits = net.apply({'params': params}, x, train=train)
    per_sample = optax.sigmoid_binary_cross_entropy(logits, y)
    return per_sample.mean()

@jax.jit
def train_step(state, x, y, alpha, gamma):
    '''
    Run one jitted optimization step on a batch.

    Args:
        state: Current training state.
        x: Batch of joint configurations.
        y: Batch of target labels.
        alpha: Forwarded to ``loss_fn``.
        gamma: Forwarded to ``loss_fn``.

    Returns:
        Tuple ``(state, loss, acc)``: the updated state, the loss evaluated
        before the update, and the batch accuracy evaluated with the updated
        parameters.
    '''
    loss, grads = jax.value_and_grad(loss_fn)(
        state.params, x, y, alpha, gamma, train=True
    )
    updated = state.apply_gradients(grads=grads)
    return updated, loss, compute_accuracy(updated.params, x, y)

def compute_accuracy(params, x, y):
    '''
    Fraction of samples whose rounded sigmoid output equals the label.

    Args:
        params: Parameters of ShakeySelfCollisionNet.
        x: Batch of joint configurations.
        y: Labels with the same shape as the network output.

    Returns:
        Scalar accuracy as a JAX array (not converted to a Python float).
    '''
    net = ShakeySelfCollisionNet()
    predicted = jnp.round(nn.sigmoid(net.apply({'params': params}, x)))
    return jnp.mean(predicted == y)

if __name__ == "__main__":
    # Training entry point: periodically generates balanced self-collision
    # data with Ray, trains ShakeySelfCollisionNet, logs metrics to wandb and
    # writes a parameter checkpoint every 10 epochs.

    # Initialize wandb
    wandb.init(project='shakey_self_collision')

    # Add datetime to log_dir as unique id for this run
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    log_dir = f'logs_selfcollision/{timestamp}_{ROBOT}'
    os.makedirs(log_dir, exist_ok=True)

    # Parameters
    nray_env = 20
    nitr = 120000
    num_epochs = 200
    batch_size = 1024
    learning_rate = 1e-3
    rng = jax.random.PRNGKey(0)

    # Create model and training state. A dedicated key is split off so the
    # initialization key is never reused for shuffling.
    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, learning_rate)

    # Dataset refresh interval (e.g., every 30 epochs)
    dataset_refresh_interval = 30

    # Focal Loss parameters (forwarded to train_step)
    alpha = 0.25  # Balancing factor
    gamma = 2.0   # Modulating factor

    X_train = y_train = X_test = y_test = None
    total = 0
    num_positive = 0.0
    for epoch in range(1, num_epochs + 1):
        # Refresh dataset at intervals
        if (epoch - 1) % dataset_refresh_interval == 0:
            print(f"Refreshing dataset at epoch {epoch}")
            q_new, col_res_new = ray_data_gen(nray_env=nray_env, nitr=nitr)

            # Measure imbalance of the accumulated dataset
            total += len(col_res_new)
            num_positive += np.sum(col_res_new)
            num_negative = total - num_positive
            print(f"Dataset size: {total}, Positive samples: {num_positive}, Negative samples: {num_negative}")

            # Update alpha for focal loss based on class imbalance
            alpha = float(num_negative / total)

            # Split ONLY the new samples and append them to the existing
            # splits. Re-splitting the whole accumulated dataset would move
            # samples the model has already trained on into the test set and
            # inflate the reported test accuracy.
            X_train_new, X_test_new, y_train_new, y_test_new = train_test_split(
                q_new, col_res_new, test_size=0.05, random_state=42
            )
            # Convert data to JAX arrays
            X_train_new = jnp.array(X_train_new)
            y_train_new = jnp.array(y_train_new).reshape(-1, 1)
            X_test_new = jnp.array(X_test_new)
            y_test_new = jnp.array(y_test_new).reshape(-1, 1)

            if X_train is None:
                X_train, y_train = X_train_new, y_train_new
                X_test, y_test = X_test_new, y_test_new
            else:
                X_train = jnp.concatenate([X_train, X_train_new], axis=0)
                y_train = jnp.concatenate([y_train, y_train_new], axis=0)
                X_test = jnp.concatenate([X_test, X_test_new], axis=0)
                y_test = jnp.concatenate([y_test, y_test_new], axis=0)

        # Shuffle data
        rng, input_rng = jax.random.split(rng)
        perm = jax.random.permutation(input_rng, len(X_train))
        X_train = X_train[perm]
        y_train = y_train[perm]

        # Batch training (the trailing partial batch is dropped)
        num_batches = len(X_train) // batch_size
        if num_batches == 0:
            raise ValueError(
                f"Training set has {len(X_train)} samples, fewer than batch_size={batch_size}"
            )
        epoch_loss = 0.0
        epoch_acc = 0.0
        for i in range(num_batches):
            start = batch_size * i
            x_batch = X_train[start:start + batch_size]
            y_batch = y_train[start:start + batch_size]
            state, loss, acc = train_step(state, x_batch, y_batch, alpha, gamma)
            epoch_loss += loss
            epoch_acc += acc

        # Compute averages as Python floats so printing and logging do not
        # depend on JAX array formatting/serialization
        avg_loss = float(epoch_loss / num_batches)
        avg_acc = float(epoch_acc / num_batches)

        # Compute accuracy
        test_accuracy = float(compute_accuracy(state.params, X_test, y_test))
        print(f"Epoch {epoch}, Loss: {avg_loss:.4f}, Train Accuracy: {avg_acc:.4f}, Test Accuracy: {test_accuracy:.4f}")

        # Log metrics to wandb
        wandb.log({
            'epoch': epoch,
            'loss': avg_loss,
            'train_accuracy': avg_acc,
            'test_accuracy': test_accuracy
        })

        # Save model to log directory
        if epoch % 10 == 0:
            model_save_path = os.path.join(log_dir, f'model_{ROBOT}_epoch_{epoch}.ckpt')
            with open(model_save_path, 'wb') as f:
                f.write(flax.serialization.to_bytes(state.params))
            print(f"Saved model at epoch {epoch} to {model_save_path}")

    # Disconnect Ray
    ray.shutdown()
