import torch 

def get_energy(self, z):
    """Evaluate the pullback energy E(z) = R(T(z)) for every chain.

    T is the external ``transport`` function and R is ``self.reward_fn``.

    Args:
        z: latent batch passed straight to ``transport``. The module script
            builds z with shape (n_chains, batch_size, num_residuals, 4, 3);
            the exact shape ``transport`` expects is not visible here — TODO
            confirm.

    Returns:
        A plain Python list of length ``self.n_chains`` whose i-th entry is
        ``self.reward_fn(X[i])``, where X is the first item returned by
        ``transport``. The entries are not stacked into a tensor.
    """
    # transport must return exactly three items; only the first feeds the reward.
    x_data, _c, _s = transport(z)
    chain_count = self.n_chains
    return [self.reward_fn(x_data[k]) for k in range(chain_count)]



def propose_updates(z, energies, *, chain_thetas=None, chain_betas=None):
    """Draw a preconditioned Crank-Nicolson (pCN) proposal for every chain.

    For chain k the proposal is

        z_prop[k] = cos(theta_k) * z[k] + sin(theta_k) * xi[k],   xi ~ N(0, I)

    so theta_k = pi/2 ignores the current state entirely and a small theta_k
    keeps the proposal close to z[k].

    NOTE(review): the original docstring promised an accept/reject step, but
    none is implemented here. ``energies`` is currently unused and no
    Metropolis test is applied. TODO: score ``z_prop`` and accept/reject.

    Args:
        z: current latent states. The leading dim indexes chains, and any
            number of trailing dims is supported. The module script uses
            (n_chains, batch_size, num_residuals, 4, 3).
        energies: current per-chain energies (unused until accept/reject
            exists).
        chain_thetas: 1-D tensor of per-chain step angles. Defaults to the
            module-level ``thetas``.
        chain_betas: 1-D tensor of per-chain beta values (presumably inverse
            temperatures — confirm). Defaults to the module-level ``betas``.

    Returns:
        The proposal ``z_prop``, with the same shape, dtype and device as ``z``.
    """
    if chain_thetas is None:
        chain_thetas = thetas
    if chain_betas is None:
        chain_betas = betas

    beta_k = chain_betas.view(-1, 1)

    # Reshape to (n_chains, 1, ..., 1) so theta broadcasts over every trailing
    # dim of z regardless of its rank. .to(z) matches z's dtype and device.
    theta = chain_thetas.to(z).view(-1, *([1] * (z.dim() - 1)))

    xi = torch.randn_like(z)
    z_prop = torch.cos(theta) * z + torch.sin(theta) * xi

    # Previously printed `beta.shape`. `beta` is the module-level float, so
    # that raised AttributeError on every call.
    print(f'beta shape: {beta_k.shape}')
    return z_prop


beta = 5.0  # largest per-chain beta; `betas` below runs from 0.0 up to this

batch_size = 1  # can only make one protein per chain
n_chains = 2
num_residuals = 127

# Fall back to CPU so the script also runs on machines without a GPU
# (a hard-coded "cuda" device crashes there).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Latent state per chain: (n_chains, batch_size, num_residuals, 4, 3).
z = torch.randn(
    n_chains,
    batch_size,
    num_residuals,
    4,
    3,
    device=device,
)

betas = torch.linspace(0.0, beta, n_chains, device=device)

# Placeholder per-chain energies: one float per chain, on the same device as z
# (previously a hard-coded 2-element int64 CPU tensor).
energies = torch.ones(n_chains, device=device)

# Per-chain pCN mixing angles, from pi/2 for chain 0 down to 0.05 for the last.
thetas = torch.linspace(torch.pi / 2, 0.05, n_chains, device=device)

propose_updates(z, energies)