import torch
import math
from chroma import Chroma
from chroma.data.protein import Protein
from rewards import ProperPairwiseReward, load_true_protein

from chroma.layers.structure.rmsd import CrossRMSD




def calculate_all_atom_RMSD(X_t, X_g):
    """
    RMSD between two coordinate tensors, computed with Chroma's CrossRMSD.

    Both inputs are moved to the CPU and flattened to shape (1, -1, 3), i.e.
    every atom is treated as one point of a single structure, so the two
    tensors must hold the same total number of 3-D points.

    X_t, X_g: coordinate tensors whose element count is a multiple of 3.
    Returns: the first element of the pair returned by ``pairedRMSD``
    (called with ``compute_alignment=True``; presumably the structures are
    superimposed before the deviation is measured -- confirm against the
    Chroma documentation).
    """
    rmsd_module = CrossRMSD()

    # X_g is passed first and X_t second, exactly as pairedRMSD is invoked
    # by the rest of this file.
    first_points = X_g.cpu().reshape(1, -1, 3)
    second_points = X_t.cpu().reshape(1, -1, 3)

    rmsd, _unused = rmsd_module.pairedRMSD(
        first_points,
        second_points,
        compute_alignment=True,
    )
    return rmsd




class SourceTemperingSampler:
    """
    Source-space Parallel Tempering MCMC over the latent input of a
    pretrained Chroma model.
    Non-adaptive pCN + Parallel Tempering.

    A ladder of ``n_chains`` replicas is kept in latent space.  Replica k has
    inverse temperature ``betas[k]`` (linearly spaced from 0 to ``beta``) and
    pCN step angle ``thetas[k]`` (linearly spaced from pi/2 down to 0.05), so
    replica 0 resamples from the prior and the last replica is the coldest.
    The quantity called "energy" throughout is the reward of the decoded
    structure: larger values are favoured by the acceptance rules.
    """

    def __init__(
        self,
        beta=5.0,
        n_chains=10,
        num_residuals=127,
        device="cuda",
        batch_size=1,
        reference_path="reference_proteins/7r5b.cif",
        ode_steps=500,
    ):
        """
        beta: inverse temperature of the coldest replica.
        n_chains: number of replicas in the ladder (must be >= 1).
        num_residuals: number of residues of every latent backbone.
        device: device used for the model, the latents and the schedules.
        batch_size: size of the second latent axis.  One energy is stored
            per replica, so all batch entries of a replica move together.
        reference_path: structure file used by the reward and by the RMSD
            report printed during sampling.
        ode_steps: number of integration steps passed to ``Chroma.sample``.

        Raises ValueError if ``n_chains`` is smaller than 1.
        """
        if n_chains < 1:
            raise ValueError(f"n_chains must be >= 1, got {n_chains}")

        # Build the model on the requested device so that it lives next to
        # the latents created in ``sample``.
        self.model = Chroma(device=device)

        self.device = device
        self.n_chains = n_chains
        self.num_residuals = num_residuals
        self.batch_size = batch_size
        self.reference_path = reference_path
        self.ode_steps = ode_steps

        self.betas = torch.linspace(0.0, beta, n_chains, device=device)
        self.thetas = torch.linspace(
            math.pi / 2, 0.05, n_chains, device=device
        )

        self.reward_fn = ProperPairwiseReward(reference_path)
        self.X_ref, self.C_ref, self.S_ref = load_true_protein(
            reference_path,
            device=self.device
        )

        self.backbone_network = self.model.backbone_network

        # Thin wrappers around the base Gaussian of the backbone network.
        # They are not used by the sampler itself but are kept as public
        # attributes for callers that rely on them.
        self.multiply_R = lambda Z, C: self.backbone_network.noise_perturb.base_gaussian._multiply_R(Z, C)
        self.multiply_R_inverse = lambda X, C: self.backbone_network.noise_perturb.base_gaussian._multiply_R_inverse(X, C)

    def _decode(self, z, C, S):
        """Wrap ``z`` as a Protein and run Chroma's ODE sampler from it."""
        initial_backbone = Protein.from_XCS(z, C, S)
        return self.model.sample(
            sde_func="ode",
            protein_init=initial_backbone,
            steps=self.ode_steps
        )

    @torch.no_grad()
    def transport(self, z, as_protein=None):
        """
        Map latent ``z`` to structures with the ODE sampler.

        z: latent tensor.  In stacked mode its first axis enumerates the
            replicas and every ``z[c]`` is decoded separately; in single
            mode the whole tensor is decoded in one call.
        as_protein: ``True`` decodes ``z`` as a single input and returns the
            Protein object; ``False`` decodes every ``z[c]`` and returns
            stacked ``(X, C, S)`` tensors.  ``None`` (default) keeps the
            historical rule: stacked mode iff ``z.size(0) > 1``.  That rule
            cannot tell one replica from a batch of one, so callers inside
            this class always pass the flag explicitly.

        Returns: a Protein (single mode) or a tuple ``(X, C, S)`` whose
        first axis has length ``z.size(0)`` (stacked mode).
        Raises ValueError in stacked mode if ``z`` has no replicas.
        """
        if as_protein is None:
            as_protein = z.size(0) <= 1

        device = z.device
        N = self.num_residuals

        # Placeholder chain map (all ones) and sequence (all zeros).
        C = torch.ones((1, N), device=device)
        S = torch.zeros_like(C)

        if as_protein:  # final inference step
            print("(Inference Step) Generating Target Protein!")
            return self._decode(z, C, S)

        n_chains = z.size(0)
        if n_chains == 0:
            raise ValueError("transport needs at least one chain to decode")

        generated_X = generated_C = generated_S = None
        for c, z_c in enumerate(z):  # steering step
            print(f"(Steering Step) Generating Protein {c + 1}")
            gen_X, gen_C, gen_S = self._decode(z_c, C, S).to_XCS()

            # Allocate the output buffers once the decoded shapes are known.
            if generated_X is None:
                generated_X = torch.zeros((n_chains, *gen_X.shape), device=device)
                generated_C = torch.zeros((n_chains, *gen_C.shape), device=device)
                generated_S = torch.zeros((n_chains, *gen_S.shape), device=device)

            generated_X[c] = gen_X
            generated_C[c] = gen_C
            generated_S[c] = gen_S

        return generated_X, generated_C, generated_S

    @torch.no_grad()
    def get_energy(self, z):
        """
        Pullback energy: E(z) = R(T(z))
        z: stacked latents, one entry per chain along the first axis.
        Returns: the rewards of the decoded structures stacked along a new
        first axis (one per chain).
        """
        X, _, _ = self.transport(z, as_protein=False)
        # Iterate over the chains that were actually decoded rather than
        # over self.n_chains, so the two can never disagree.
        return torch.stack([self.reward_fn(x) for x in X], dim=0)

    @torch.no_grad()
    def propose_updates(self, z, energies):
        """
        Propose pCN updates for each chain and accept/reject.

        z: stacked latents whose first axis has length ``n_chains``.
        energies: current reward of every chain (``n_chains`` elements).
        Returns: ``(z_new, energies_new)`` with ``energies_new`` of shape
        ``(n_chains, 1)``.
        """
        # Shape that broadcasts a per-chain quantity over all other axes.
        per_chain = (self.n_chains,) + (1,) * (z.dim() - 1)

        theta = self.thetas.view(per_chain)
        beta_k = self.betas.view(-1, 1)
        # Force a column so that the arithmetic below stays per chain
        # instead of broadcasting to an (n_chains, n_chains) matrix.
        energies = energies.reshape(self.n_chains, 1)

        # pCN proposal: rotate the current state towards fresh prior noise.
        xi = torch.randn_like(z)
        z_prop = torch.cos(theta) * z + torch.sin(theta) * xi

        E_prop = self.get_energy(z_prop).reshape(self.n_chains, 1)

        # Metropolis acceptance
        # Equivalent to: log( P(accept) ) = beta * (Reward_new - Reward_old)
        log_alpha = beta_k * (E_prop - energies)
        accept = torch.rand_like(log_alpha).log() < log_alpha

        z_new = torch.where(accept.view(per_chain), z_prop, z)
        energies_new = torch.where(accept, E_prop, energies)

        return z_new, energies_new

    # ------------------------------------------------------------------
    # Replica exchange
    # ------------------------------------------------------------------

    @torch.no_grad()
    def swap_between_chains(self, z, energies, ):
        """
        Attempt a replica exchange between every pair of adjacent chains.

        Pair (k, k+1) is exchanged with log-probability
        ``(beta[k+1] - beta[k]) * (E[k] - E[k+1])``.  Since a chain carries
        a single energy, the decision is taken once per pair and the whole
        latent state of the chain (all batch entries) is exchanged together.

        z: stacked latents, first axis = chains.  Modified in place.
        energies: one value per chain, shape ``(n_chains,)`` or
            ``(n_chains, 1)``.  Modified in place.
        Returns: the same ``(z, energies)`` objects.
        Raises ValueError if a chain carries more than one energy value.
        """
        for k in range(z.shape[0] - 1):
            delta_beta = self.betas[k + 1] - self.betas[k]
            delta_E = energies[k] - energies[k + 1]

            if delta_E.numel() != 1:
                raise ValueError(
                    "swap_between_chains expects one energy per chain, got "
                    f"{delta_E.numel()} values for chain {k}"
                )
            log_alpha = (delta_beta * delta_E).reshape(())

            if torch.rand((), device=z.device).log() < log_alpha:
                # Advanced indexing on the right-hand side yields a copy,
                # so the two rows can be exchanged in a single assignment.
                z[[k, k + 1]] = z[[k + 1, k]]
                energies[[k, k + 1]] = energies[[k + 1, k]]
        return z, energies

    # ------------------------------------------------------------------
    # Main sampling loop
    # ------------------------------------------------------------------

    @torch.no_grad()
    def sample(
        self,
        n_iterations,
    ):
        """
        Run the sampler and return samples from the coldest chain.

        n_iterations: number of (pCN update, replica exchange) sweeps.
        Returns: ``(last_chain, ideal_protein)`` -- the latent state of the
        last chain and the Protein decoded from it.
        """
        # Initialize from prior
        z = torch.randn(
            self.n_chains,
            self.batch_size,
            self.num_residuals,
            4,
            3,
            device=self.device
        )
        energies = self.get_energy(z)

        ideal_protein = None
        for it in range(n_iterations):
            print(f'Entering SPT Step {it+1}')
            z, energies = self.propose_updates(z, energies)
            z, energies = self.swap_between_chains(z, energies)

            # Decode the coldest chain to report progress against the
            # reference structure.
            ideal_protein = self.transport(z[-1], as_protein=True)
            X_cold, _, _ = ideal_protein.to_XCS()
            rmsd = calculate_all_atom_RMSD(X_cold, self.X_ref)
            print(f'RMSD: {rmsd} at SPT Step {it + 1}')
            print(f'Reward {energies[-1]} at SPT Step {it+ 1}')

        last_chain = z[-1]
        # The decode from the final sweep already corresponds to last_chain;
        # only decode here when no sweep was run.
        if ideal_protein is None:
            ideal_protein = self.transport(last_chain, as_protein=True)
        return last_chain, ideal_protein
