import torch
import math
from chroma import Chroma
from chroma.data.protein import Protein
from rewards import ProperPairwiseReward, load_true_protein

from chroma.layers.structure.rmsd import CrossRMSD

import pandas as pd
import os




def ca_rmsd(X_t, X_g, C_gt):
    """Return the optimally-aligned RMSD between two structures.

    Only positions where the first row of ``C_gt`` equals 1 are compared.
    The first batch element of each coordinate tensor is masked, flattened
    into a single ``(1, P, 3)`` point cloud, moved to the CPU and handed to
    ``CrossRMSD().pairedRMSD`` with alignment enabled.

    NOTE(review): despite the name, no CA-atom selection happens here; every
    coordinate that survives the residue mask is included. If the inputs are
    ``(B, N, 4, 3)`` backbones this is an all-backbone-atom RMSD -- confirm
    that this is the intended metric.

    Args:
        X_t: Coordinates of the first structure; only batch index 0 is used.
        X_g: Coordinates of the second structure; only batch index 0 is used.
        C_gt: Chain map; residues with value 1 in row 0 are kept.

    Returns:
        The RMSD value produced by ``pairedRMSD`` (first element of its
        return tuple).
    """
    keep = C_gt[0] == 1  # [R]

    def _as_point_cloud(coords):
        # Select kept residues of batch element 0 and flatten to (1, P, 3).
        return coords[0, keep, :].reshape(1, -1, 3).cpu()

    cloud_t = _as_point_cloud(X_t)
    cloud_g = _as_point_cloud(X_g)

    aligned_rmsd, _ = CrossRMSD().pairedRMSD(
        cloud_g, cloud_t, compute_alignment=True
    )
    return aligned_rmsd



class SourceTemperingSampler:
    """
    Source-space Parallel Tempering MCMC for a pretrained Chroma backbone model.
    Non-adaptive pCN + Parallel Tempering.

    Each of the ``n_chains`` replicas holds a source-space backbone ``z`` of
    shape ``(batch_size, num_residuals, 4, 3)``. A replica is scored by
    transporting ``z`` through ``Chroma.sample`` (ODE mode) and evaluating
    ``reward_fn`` on the generated coordinates. Chain ``k`` targets
    ``exp(betas[k] * reward) * prior``; ``betas`` rise linearly from 0
    (index 0, hottest) to ``beta`` (last index, coldest) while the pCN step
    angles ``thetas`` shrink from ``pi / 2`` to ``0.05``.

    The acceptance logic reshapes energies to ``(n_chains, 1)``, i.e. it
    expects one scalar reward per chain.
    """

    # Number of integration steps passed to ``Chroma.sample`` per transport.
    ODE_STEPS = 100

    def __init__(
        self,
        beta=5.0,
        n_chains=10,
        num_residuals=127,
        device="cuda",
        batch_size=1,
        reference_path="reference_proteins/7r5b.cif",
    ):
        """Load the model, the reference protein and the tempering schedule.

        Args:
            beta: Inverse temperature of the coldest chain.
            n_chains: Number of tempered replicas.
            num_residuals: Residue count used when drawing source noise.
            device: Device for the schedule tensors, the reference protein
                and the source samples.
            batch_size: Batch dimension of each replica's source sample.
            reference_path: Structure file used both for the reward and as
                the reference for RMSD / chain map / sequence.
        """
        self.model = Chroma()

        self.device = device
        self.n_chains = n_chains

        self.betas = torch.linspace(0.0, beta, n_chains, device=device)
        self.thetas = torch.linspace(
            math.pi / 2, 0.05, n_chains, device=device
        )

        self.reward_fn = ProperPairwiseReward(reference_path)
        self.num_residuals = num_residuals
        self.batch_size = batch_size

        self.X_ref, self.C_ref, self.S_ref = load_true_protein(
            reference_path,
            device=self.device
        )
        self.backbone_network = self.model.backbone_network
        # Maps white noise to the model's correlated Gaussian prior.
        self.multiply_R = lambda Z, C: self.backbone_network.noise_perturb.base_gaussian._multiply_R(Z, C)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _correlate(self, white):
        """Apply ``multiply_R`` to every chain of a ``(K, B, N, 4, 3)`` tensor."""
        return torch.stack(
            [self.multiply_R(chain, self.C_ref) for chain in white],
            dim=0
        )

    def _integrate(self, z):
        """Run the ODE sampler from one source backbone and return the Protein."""
        start = Protein.from_XCS(z, self.C_ref, self.S_ref)
        return self.model.sample(
            sde_func="ode",
            protein_init=start,
            steps=self.ODE_STEPS
        )

    @staticmethod
    def _scalar(value):
        """Convert a tensor, a ``.score``-bearing record or a number to float."""
        value = getattr(value, "score", value)
        if torch.is_tensor(value):
            return value.item()
        return float(value)

    @staticmethod
    def _write_history(history, output_path):
        """Write the per-step metrics to ``output_path`` as CSV (if given)."""
        if not output_path:
            return
        parent = os.path.dirname(output_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        frame = pd.DataFrame(history, columns=["step", "rmsd", "elbo"])
        frame.to_csv(output_path, index=False)

    # ------------------------------------------------------------------
    # Transport and energy
    # ------------------------------------------------------------------

    @torch.no_grad()
    def transport(self, z):
        """Push source samples through the model.

        Dispatch is on the rank of ``z`` so that the chain axis is never
        confused with the batch axis:

        * 5-D ``(K, B, N, 4, 3)``: every chain is integrated separately and
          a tuple ``(X, C, S)`` of stacked tensors is returned. As before,
          all three are stored in the default floating dtype.
        * otherwise: ``z`` is treated as a single source backbone and the
          generated ``Protein`` object is returned.

        NOTE(review): this relies on ``Chroma.sample`` integrating from
        ``protein_init`` without re-noising it -- confirm against the
        Chroma API.
        """
        if z.dim() < 5:  # final inference step
            print("(Inference Step) Generating Target Protein!")
            return self._integrate(z)

        xs, cs, ss = [], [], []
        for chain_z in z:  # steering step
            gen_X, gen_C, gen_S = self._integrate(chain_z).to_XCS()
            xs.append(gen_X)
            cs.append(gen_C)
            ss.append(gen_S)

        dtype = torch.get_default_dtype()

        def _pack(parts):
            # Sized from the generated tensors, not from num_residuals.
            return torch.stack(parts, dim=0).to(device=z.device, dtype=dtype)

        return _pack(xs), _pack(cs), _pack(ss)

    @torch.no_grad()
    def get_energy(self, z):
        """
        Pullback energy: E(z) = R(T(z))
        z: stacked source samples, one entry per chain.
        Returns: stacked rewards, one per generated protein.
        """
        X, _, _ = self.transport(z)
        return torch.stack([self.reward_fn(x) for x in X], dim=0)

    # ------------------------------------------------------------------
    # Within-chain pCN move
    # ------------------------------------------------------------------

    @torch.no_grad()
    def propose_updates(self, z, energies):
        """
        Propose pCN updates for each chain and accept/reject.

        The proposal is ``cos(theta) * z + sin(theta) * xi``. ``xi`` is drawn
        from the same correlated Gaussian used to initialise the chains
        (white noise passed through ``multiply_R``); pCN is only reversible
        with respect to the prior when the noise shares its covariance.

        Returns:
            ``(z_new, energies_new)`` with ``energies_new`` of shape
            ``(n_chains, 1)``.
        """
        theta = self.thetas.view(self.n_chains, 1, 1, 1, 1)
        beta_k = self.betas.view(-1, 1)

        # One scalar per chain so that it broadcasts against beta_k.
        energies = energies.view(self.n_chains, 1)

        xi = self._correlate(torch.randn_like(z))
        z_prop = torch.cos(theta) * z + torch.sin(theta) * xi

        E_prop = self.get_energy(z_prop).view(self.n_chains, 1)

        # Prior terms cancel for pCN; only the tempered reward ratio remains.
        log_alpha = beta_k * (E_prop - energies)
        accept = torch.rand_like(log_alpha).log() < log_alpha

        z_new = torch.where(
            accept.view(self.n_chains, 1, 1, 1, 1), z_prop, z
        )
        energies_new = torch.where(accept, E_prop, energies)

        return z_new, energies_new

    # ------------------------------------------------------------------
    # Replica exchange
    # ------------------------------------------------------------------

    @torch.no_grad()
    def swap_between_chains(self, z, energies, ):
        """Sweep once over adjacent chain pairs and exchange their states.

        A swap between chains ``k`` and ``k + 1`` is accepted with log
        probability ``(beta[k+1] - beta[k]) * (E[k] - E[k+1])``. ``z`` and
        ``energies`` are modified in place and also returned.
        """
        # Accept a flat (K,) energy vector as well as the usual (K, 1).
        energies = energies.reshape(self.n_chains, -1)

        for k in range(self.n_chains - 1):
            delta_beta = self.betas[k + 1] - self.betas[k]
            delta_E = energies[k] - energies[k + 1]
            log_alpha = delta_beta * delta_E

            # One accept/reject decision per batch point of this pair.
            accept = torch.rand(z.shape[1], device=z.device).log() < log_alpha
            if not accept.any():
                continue

            z_k = z[k].clone()
            E_k = energies[k].clone()

            z[k, accept] = z[k + 1, accept]
            energies[k, accept] = energies[k + 1, accept]

            z[k + 1, accept] = z_k[accept]
            energies[k + 1, accept] = E_k[accept]
        return z, energies

    # ------------------------------------------------------------------
    # Main sampling loop
    # ------------------------------------------------------------------

    def sample(self, n_iterations, output_path="sampling_metrics.csv",
               rmsd_threshold=14.0):
        """
        Run the sampler and return samples from the coldest chain.

        Each iteration performs a pCN move on every chain, one replica
        exchange sweep, then transports the coldest chain (last index) and
        records its RMSD to the reference and its backbone ELBO.

        Args:
            n_iterations: Maximum number of sampler iterations.
            output_path: CSV file receiving one row per iteration with the
                columns ``step``, ``rmsd`` and ``elbo``. Pass a falsy value
                to skip writing.
            rmsd_threshold: Sampling stops as soon as the RMSD drops below
                this value.

        Returns:
            ``(last_chain, ideal_protein)``: the coldest chain's source
            sample and the ``Protein`` generated from it.
        """
        # Initialize from the correlated prior.
        white = torch.randn(
            self.n_chains,
            self.batch_size,
            self.num_residuals,
            4,
            3,
            device=self.device
        )
        z = self._correlate(white)

        energies = self.get_energy(z)
        history = []
        ideal_xcs = None

        for it in range(n_iterations):
            step_num = it + 1
            print(f'Entering SPT Step {step_num}')

            z, energies = self.propose_updates(z, energies)
            z, energies = self.swap_between_chains(z, energies)

            # Access coldest chain (last index)
            ideal_xcs = self.transport(z[-1]).to_XCS()
            ideal_x = ideal_xcs[0]

            rmsd_val = self._scalar(ca_rmsd(ideal_x, self.X_ref, self.C_ref))

            # Score with the already-loaded model instead of constructing a
            # fresh Chroma() on every iteration.
            scores = self.model.score_backbone(Protein.from_XCS(*ideal_xcs))
            elbo_val = self._scalar(scores['elbo'])

            history.append(
                {"step": step_num, "rmsd": rmsd_val, "elbo": elbo_val}
            )

            if rmsd_val < rmsd_threshold:
                print(f'Early stop: RMSD {rmsd_val:.3f} < {rmsd_threshold} at SPT Step {step_num}')
                break

            print(f'RMSD: {rmsd_val} | ELBO: {elbo_val} at SPT Step {step_num}')

        if ideal_xcs is None:
            # No iteration ran; still return a protein for the coldest chain.
            ideal_xcs = self.transport(z[-1]).to_XCS()

        self._write_history(history, output_path)

        last_chain = z[-1]
        ideal_protein = Protein.from_XCS(*ideal_xcs)

        return last_chain, ideal_protein
