import torch
from rewards import ProperPairwiseReward, load_true_protein
from chroma import Chroma
from chroma.data.protein import Protein
import numpy as np
import csv
from pathlib import Path

class FKSteeringSampler:
    """Feynman-Kac (FK) steering sampler on top of a Chroma backbone model.

    A population of ``n_chains`` particles is denoised along a shared time
    grid running from ``t_start`` down to ``t_end``.  After every step each
    particle is scored by a reward evaluated on a look-ahead denoised
    structure, the reward is turned into a potential (importance weight), and
    the population is resampled when its effective sample size (ESS) drops
    too low.

    Bookkeeping convention: ``R_t`` and ``G_t`` are
    ``(steps + 1, n_chains)`` tensors.  Row ``steps`` corresponds to the
    noisiest time (``t_start``) and row ``0`` to ``t_end``; column ``k`` holds
    the history of the lineage that particle ``k`` currently belongs to.
    """

    def __init__(
        self,
        lmbda=5.0,
        n_chains=3,
        num_residuals=127,
        device="cuda",
        batch_size=1,
        steps=200,
        t_start=1.0,
        t_end=1e-3,               # avoid exact 0
        potential="difference",   # "difference" | "max" | "sum"
        ess_threshold=0.5,        # resample when ESS <= ess_threshold * n_chains
        reference_path="reference_proteins/7r5b.cif",
    ):
        """Load the Chroma model, the reward and the reference structure.

        Args:
            lmbda: Scale applied to rewards inside ``exp`` when building
                potentials.
            n_chains: Number of particles in the population.
            num_residuals: Residue count used for the initial noise tensor.
                NOTE(review): presumably has to match the length of the
                reference protein's chain map -- confirm.
            device: Torch device for the particle, reward and weight tensors.
            batch_size: Leading batch dimension of each particle.
            steps: Number of intervals of the diffusion time grid.
            t_start: Noisiest time of the grid.
            t_end: Final (least noisy) time of the grid.
            potential: One of ``"difference"``, ``"max"`` or ``"sum"``.
            ess_threshold: Resampling happens when
                ``ESS <= ess_threshold * n_chains``.
            reference_path: Structure file handed to ``ProperPairwiseReward``
                and ``load_true_protein``.

        Raises:
            ValueError: If ``potential`` is unknown or ``t_start <= t_end``.
        """
        self.device = device
        self.n_chains = n_chains
        self.batch_size = batch_size
        self.num_residuals = num_residuals

        self.steps = steps
        self.t_start = float(t_start)
        self.t_end = float(t_end)

        self.lmbda = float(lmbda)
        self.potential = potential
        print(f'self.potential: {self.potential}    ')
        # Validate before the (expensive) model load; explicit raises survive
        # ``python -O`` whereas ``assert`` does not.
        if self.potential not in ("difference", "max", "sum"):
            raise ValueError(
                f"potential must be 'difference', 'max' or 'sum', got {potential!r}"
            )
        if not self.t_start > self.t_end:
            raise ValueError(
                f"t_start ({self.t_start}) must be greater than t_end ({self.t_end})"
            )

        self.ess_threshold = float(ess_threshold)

        self.model = Chroma()

        self.reward_fn = ProperPairwiseReward(reference_path)

        self.X_ref, self.C_ref, self.S_ref = load_true_protein(
            reference_path,
            device=self.device
        )
        self.backbone_network = self.model.backbone_network

        # Applies the backbone network's base-Gaussian ``_multiply_R`` to a
        # noise tensor Z given chain map C.
        self.multiply_R = lambda Z, C: self.backbone_network.noise_perturb.base_gaussian._multiply_R(Z, C)

    def _select_ancestors(self, weights):
        """Draw ancestor indices for one resampling event.

        Args:
            weights: 1-D tensor of non-negative, unnormalised particle weights.

        Returns:
            A ``long`` tensor of ``n_chains`` ancestor indices, or ``None``
            when ``ESS > ess_threshold * n_chains`` and no resampling is due.
            Non-finite or all-zero weights give a NaN ESS; the comparison is
            then false and ``None`` is returned (population left untouched).
        """
        probs = weights / torch.sum(weights)
        ess = 1.0 / torch.sum(probs ** 2)
        if not bool(ess <= self.ess_threshold * self.n_chains):
            return None

        # Renormalise in float64 so numpy's "probabilities sum to 1" check is
        # not tripped by float32 rounding.
        p = probs.detach().to("cpu", torch.float64).numpy()
        p = p / p.sum()
        ancestors = np.random.choice(
            self.n_chains,
            size=self.n_chains,
            p=p,
            replace=True
        )
        return torch.as_tensor(ancestors, dtype=torch.long, device=weights.device)

    def _resample_with_history(self, x_t, R_t, G_t, step):
        """Resample particles and permute their reward/potential histories.

        Column ``k`` of ``R_t``/``G_t`` must keep describing the lineage of
        particle ``k``; otherwise the next potential would compare a particle
        against the reward of an unrelated parent.

        Returns:
            ``(x_t, R_t, G_t)``, unchanged when no resampling was needed.
        """
        ancestors = self._select_ancestors(G_t[step])
        if ancestors is None:
            return x_t, R_t, G_t
        return (
            x_t[ancestors.to(x_t.device)],
            R_t[:, ancestors.to(R_t.device)],
            G_t[:, ancestors.to(G_t.device)],
        )

    def resample(self, x_t, G_t, step=0):
        """Resample the particles according to the weights in ``G_t[step]``.

        Args:
            x_t: Particle tensor whose first dimension indexes the chains.
            G_t: Weight tensor; row ``step`` holds the weights to use.
            step: Row of ``G_t`` to read.

        Returns:
            ``x_t`` itself when the ESS is above
            ``ess_threshold * n_chains``, otherwise a new tensor of particles
            drawn with replacement proportionally to the weights.
        """
        ancestors = self._select_ancestors(G_t[step])
        if ancestors is None:
            return x_t
        return x_t[ancestors.to(x_t.device)]

    def compute_rewards(self, x_hat_t):
        """Return ``self.reward_fn`` evaluated on the given coordinates."""
        return self.reward_fn(x_hat_t)

    def calculate_initial_rewards(self, x_t):
        """Score the initial particles and build the reward/weight buffers.

        Every particle is denoised with a 50-step ODE solve and the reward of
        the result is stored in the last row (index ``steps``) of ``R_t``.

        Args:
            x_t: Initial particles, indexed by chain along the first axis.

        Returns:
            ``(R_t, G_t)``, both zero-initialised ``(steps + 1, n_chains)``
            tensors with only the last row filled in.
        """
        R_t = torch.zeros(self.steps + 1, self.n_chains, device=self.device)
        G_t = torch.zeros(self.steps + 1, self.n_chains, device=self.device)

        for k in range(self.n_chains):
            out = self.model.sample(
                sde_func="ode",
                protein_init=Protein.from_XCS(x_t[k], self.C_ref, self.S_ref),
                steps=50,
                # Only the coordinates are scored, so skip sequence design.
                design_method=None,
            )
            x_full_denoised, _, _ = out.to_XCS()
            R_t[-1, k] = self.compute_rewards(x_full_denoised)

        if self.potential == "difference":
            # exp(0): there is no earlier reward to take a difference with.
            # (``torch.exp(0)`` is a TypeError, hence the literal.)
            G_t[-1] = 1.0
        else:
            G_t[-1] = torch.exp(self.lmbda * R_t[-1])

        return R_t, G_t

    def initialize_particles(self, z):
        """Score the initial particles and perform the first resampling.

        Returns:
            ``(z, R_t, G_t)`` with particles and histories resampled together
            if the initial weights were degenerate enough.
        """
        R_t, G_t = self.calculate_initial_rewards(z)
        return self._resample_with_history(z, R_t, G_t, self.steps)

    @torch.no_grad()
    def propagate_one_step(self, x_t, t0, t1):
        """Advance every particle from time ``t0`` to the smaller time ``t1``.

        For each particle this (1) integrates the Langevin SDE over
        ``(t0, t1)``, (2) denoises a look-ahead copy down to ``t_end`` to
        evaluate the reward, and (3) appends the ELBO reported by
        ``score_backbone`` for the propagated state to
        ``elbo_partial_log.csv``.

        Args:
            x_t: Particle tensor; updated in place.
            t0: Current (noisier) time.
            t1: Target (less noisy) time.

        Returns:
            ``(x_t, r_new)`` where ``r_new`` holds one reward per particle.

        Raises:
            ValueError: If ``t0`` is not strictly greater than ``t1``.
        """
        if not float(t0) > float(t1):
            raise ValueError(f"t0 ({t0}) must be greater than t1 ({t1})")

        # Grid times may be float32 tensors, so compare against t_end with a
        # small relative tolerance.  Once t1 has reached t_end there is nothing
        # left to denoise and the propagated state is scored directly.
        needs_lookahead = float(t1) > self.t_end * (1.0 + 1e-4)

        r_new = torch.empty(self.n_chains, device=self.device)

        log_path = Path("elbo_partial_log.csv")

        # create file + header if needed
        write_header = not log_path.exists()
        with open(log_path, "a", newline="") as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(["k", "elbo_partial"])

            for k in range(self.n_chains):

                print(f'integrating one step from {t0} to {t1}')

                out = self.model.sample(
                    sde_func="langevin",
                    protein_init=Protein.from_XCS(x_t[k], self.C_ref, self.S_ref),
                    tspan=(t0, t1),
                    steps=2,
                    design_method=None,
                )

                traj, _, _ = out.to_XCS()

                if needs_lookahead:
                    xhat_traj = self.model.sample(
                        sde_func="langevin",
                        protein_init=Protein.from_XCS(traj, self.C_ref, self.S_ref),
                        tspan=(t1, self.t_end),
                        steps=50,
                        # Only the coordinates are scored, so skip sequence design.
                        design_method=None,
                    )
                    xhat, _, _ = xhat_traj.to_XCS()
                else:
                    xhat = traj

                # Reuse the already-loaded model; building a fresh Chroma() per
                # particle and step would reload the network every time.
                results_partial_inf = self.model.score_backbone(proteins=out)
                elbo_partial = results_partial_inf['elbo'].score

                print(f'partial inference elbo: {elbo_partial}')

                writer.writerow([k, float(elbo_partial)])

                r_new[k] = self.reward_fn(xhat)
                x_t[k] = traj

        return x_t, r_new

    def compute_weights(self, R, G, r_t, step_less_noisy, step_noisy):
        """Fill row ``step_less_noisy`` of ``R`` and ``G`` from new rewards.

        * ``"difference"``: ``R = r_t``, ``G = exp(lmbda * (R - R_prev))``
        * ``"max"``: ``R = max(R_prev, r_t)``, ``G = exp(lmbda * R)``
        * ``"sum"``: ``R = r_t + R_prev``, ``G = exp(lmbda * R)``

        where ``R_prev`` is row ``step_noisy`` of ``R``.

        Args:
            R: Reward history, modified in place.
            G: Potential history, modified in place.
            r_t: One new reward per particle.
            step_less_noisy: Row to write.
            step_noisy: Row holding the previous step's values.

        Returns:
            The same ``(R, G)`` objects.

        Raises:
            ValueError: If ``self.potential`` is not a known potential.
        """
        r_t = torch.as_tensor(r_t, dtype=R.dtype, device=R.device)

        if self.potential == "difference":
            R[step_less_noisy] = r_t
            G[step_less_noisy] = torch.exp(
                self.lmbda * (R[step_less_noisy] - R[step_noisy])
            )
        elif self.potential == "max":
            R[step_less_noisy] = torch.maximum(R[step_noisy], r_t)
            G[step_less_noisy] = torch.exp(self.lmbda * R[step_less_noisy])
        elif self.potential == "sum":
            R[step_less_noisy] = r_t + R[step_noisy]
            G[step_less_noisy] = torch.exp(self.lmbda * R[step_less_noisy])
        else:
            raise ValueError(f"unknown potential {self.potential!r}")
        return R, G

    @torch.no_grad()
    def sample(
        self
    ):
        """Run FK steering from ``t_start`` down to ``t_end``.

        Returns:
            ``(best_protein, final_rewards)``: the ``Protein`` built from the
            highest-scoring particle (with the reference chain map and
            sequence) and the per-particle final scores.
        """
        # Initialize particles and time grid
        z = torch.randn(
            self.n_chains,
            self.batch_size,
            self.num_residuals,
            4,
            3,
            device=self.device
        )

        z = torch.stack(
            [self.multiply_R(z[k], self.C_ref) for k in range(self.n_chains)],
            dim=0
        )

        # time_grid[0] == t_end (least noisy), time_grid[steps] == t_start.
        time_grid = torch.linspace(
            self.t_end, self.t_start, self.steps + 1, device=self.device
        )

        x_t, R_t, G_t = self.initialize_particles(z)

        # Walk every interval of the grid, including the last one down to
        # index 0; skipping it leaves row 0 of G_t at zero, which zeroes the
        # product of potentials below and leaves particles short of t_end.
        for i in range(self.steps, 0, -1):
            i_m1 = i - 1

            x_t, r_t = self.propagate_one_step(x_t, time_grid[i], time_grid[i_m1])

            R_t, G_t = self.compute_weights(
                R_t, G_t, r_t, step_noisy=i, step_less_noisy=i_m1
            )

            x_t, R_t, G_t = self._resample_with_history(x_t, R_t, G_t, i_m1)

        final_rewards = torch.empty(self.n_chains, device=self.device)
        # NOTE(review): ranking by reward divided by the lineage's product of
        # potentials is kept from the original implementation -- confirm that
        # this is the intended selection criterion.
        prod_per_chain = torch.prod(G_t, dim=0)

        for k in range(self.n_chains):
            final_rewards[k] = self.compute_rewards(x_t[k]) / prod_per_chain[k]

        best_chain = torch.argmax(final_rewards)
        best_protein = Protein.from_XCS(x_t[best_chain], self.C_ref, self.S_ref)
        return best_protein, final_rewards
