import torch
import math
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    DDIMScheduler
)
from reward import CLIPScorer, ImageRewardScorer
import time
import torch
import csv
from pathlib import Path


class SourceTemperingSampler:
    """
    Source-space parallel-tempering MCMC over the initial noise latents of a
    pretrained Stable Diffusion pipeline.

    Each of ``n_chains`` replicas holds ``batch_size`` latents ``z``. A latent
    is turned into an image by running the pipeline's DDIM sampler with
    ``eta=0`` on ``text_prompt`` (``transport``). That image is then scored
    with the ImageReward scorer (``get_reward``). Chain ``k`` targets, up to
    normalisation,

        pi_k(z) ∝ N(z; 0, I) * exp(beta_k * R(z)),

    where ``beta_k`` is evenly spaced from 0 (pure prior) to ``beta`` (the
    cold chain, last index). Within-chain moves use a preconditioned
    Crank-Nicolson (pCN) proposal with a per-chain angle ``theta_k``. After
    those moves, adjacent chains attempt replica exchanges.
    """

    def __init__(
        self,
        beta=5.0,
        n_chains=10,
        device="cuda",
        batch_size=1,
        text_prompt="",
        model="v1.4",
    ):
        """
        Args:
            beta: inverse temperature of the coldest chain (chain 0 uses 0).
            n_chains: number of tempered replicas.
            device: torch device for the pipeline, the scorer and the latents.
            batch_size: number of latents (images) per chain.
            text_prompt: prompt used both for generation and for scoring.
            model: pipeline to load, one of "v1.4", "v1.5" or "xl"
                (case-insensitive). "v2.1" or any other value raises
                ValueError.
        """
        self.device = device
        self.n_chains = n_chains
        self.batch_size = batch_size
        self.text_prompt = text_prompt

        # Tempering ladder: beta_0 = 0 (prior only) up to `beta` for the
        # cold chain.
        self.betas = torch.linspace(
            0.0, beta, n_chains, device=device, dtype=torch.float16
        )
        # pCN angles: pi/2 (z' = xi, an independent prior draw) for the
        # hottest chain, shrinking to small local moves for the cold chain.
        self.thetas = torch.linspace(
            math.pi / 2, 0.05, n_chains, device=device, dtype=torch.float16
        )

        self.model = model
        self._load_pipeline(device)

        # Replace the pipeline's scheduler with DDIM built from its config.
        self.pipe.scheduler = DDIMScheduler.from_config(
            self.pipe.scheduler.config
        )
        # NOTE: DDIMScheduler takes `eta` as an argument of step(), so setting
        # this attribute alone does not control stochasticity. `transport`
        # passes eta=0.0 to the pipeline call explicitly. The attribute is
        # kept only in case other code reads it.
        self.pipe.scheduler.eta = 0.0

        self.pipe.set_progress_bar_config(disable=True)

        # Reward model: loaded once and reused for every evaluation.
        self.reward_fn = ImageRewardScorer(device=device)

    def _load_pipeline(self, device):
        """
        Load ``self.pipe`` for ``self.model`` in float16 and move it to
        ``device``.

        Raises:
            ValueError: for "v2.1", whose weights could not be loaded from the
                Hugging Face Hub, or for any unrecognised model name.
        """
        model = self.model.lower()
        print(f'USING MODEL {model}')

        if model == "v2.1":
            # "stabilityai/stable-diffusion-2-1-base" could not be loaded from
            # the Hub, so this option is rejected explicitly.
            raise ValueError("huggingface doesn't currently support 2.1")

        checkpoints = {
            "xl": (
                StableDiffusionXLPipeline,
                "stabilityai/stable-diffusion-xl-base-1.0",
            ),
            "v1.4": (StableDiffusionPipeline, "CompVis/stable-diffusion-v1-4"),
            "v1.5": (StableDiffusionPipeline, "runwayml/stable-diffusion-v1-5"),
        }
        if model not in checkpoints:
            raise ValueError(f"Unknown model: {self.model}")

        pipeline_cls, repo_id = checkpoints[model]
        self.pipe = pipeline_cls.from_pretrained(
            repo_id,
            torch_dtype=torch.float16,
        ).to(device)

    def _latent_hw(self):
        """
        Return the (height, width) of the UNet's native latent grid.

        The value comes from ``unet.config.sample_size`` instead of being
        hard-coded. SD v1.x UNets use 64, but the SDXL base UNet uses 128, so
        a fixed 64 would run SDXL below its native resolution. The helper
        falls back to 64 when the config has no such field. It accepts both
        an int and an (h, w) pair.
        """
        size = getattr(self.pipe.unet.config, "sample_size", 64)
        if isinstance(size, int):
            return size, size
        height, width = size
        return int(height), int(width)

    @torch.no_grad()
    def transport(self, z):
        """
        Map source latents to images with the pipeline's DDIM sampler (eta=0).

        Args:
            z: latents of shape (n_chains, n_imgs, C, H, W).

        Returns:
            List[List[PIL.Image]]: ``n_chains`` lists of ``n_imgs`` images,
            in the same order as ``z``.
        """
        _, n_imgs, _, _, _ = z.shape
        prompts = [self.text_prompt] * n_imgs

        # One pipeline call per chain rather than one call over all
        # n_chains * n_imgs latents, presumably to bound peak VRAM.
        all_chains = []
        for z_chain in z:  # z_chain: (n_imgs, C, H, W)
            result = self.pipe(
                prompt=prompts,
                latents=z_chain,
                num_inference_steps=50,  # was 100, but fk steering uses 100
                guidance_scale=7.5,
                eta=0.0,  # no DDIM step noise: images depend only on latents
            )
            all_chains.append(result.images)
        return all_chains

    @torch.no_grad()
    def get_reward(self, z):
        """
        Compute the reward of every latent.

        Each latent is transported to an image, and the image is scored
        against ``self.text_prompt``.

        Args:
            z: latents of shape (n_chains, n_imgs, C, H, W).

        Returns:
            Tensor of shape (n_chains, n_imgs) on ``self.device``.
        """
        chains = self.transport(z)
        n_chains = len(chains)
        n_imgs = len(chains[0])

        # Chain-major flattening, so the reshape below restores
        # (chain, image).
        images_flat = [img for chain in chains for img in chain]
        rewards_flat = self.reward_fn.score_images(
            images_flat,
            text_prompt=self.text_prompt,
        )

        # as_tensor + reshape also accept a list, an array or a
        # non-contiguous tensor from the scorer. Placing the result on
        # self.device keeps later arithmetic with self.betas on a single
        # device.
        rewards = torch.as_tensor(rewards_flat, device=self.device)
        return rewards.reshape(n_chains, n_imgs)

    @torch.no_grad()
    def propose_updates(self, z, energies):
        """
        Make one pCN proposal per latent, then apply Metropolis accept/reject.

        The proposal is z' = cos(theta_k) z + sin(theta_k) xi, with
        xi ~ N(0, I). This proposal is reversible with respect to the standard
        Gaussian prior, so for pi_k the acceptance ratio only involves the
        tempered reward:

            log alpha = beta_k * (R(z') - R(z)).

        Args:
            z: latents of shape (n_chains, n_imgs, C, H, W).
            energies: current rewards R(z), shape (n_chains, n_imgs).

        Returns:
            (z_new, energies_new), with the same shapes as the inputs.
        """
        K, B, _, _, _ = z.shape

        # Per-chain parameters, broadcast over images (and latent dims).
        # Casting keeps z_prop in z's dtype, which the fp16 UNet expects.
        theta = self.thetas.view(K, 1, 1, 1, 1).to(z.dtype)
        beta_k = self.betas.view(K, 1).to(energies.dtype)

        xi = torch.randn_like(z)
        z_prop = torch.cos(theta) * z + torch.sin(theta) * xi
        E_prop = self.get_reward(z_prop)  # (K, B)

        log_alpha = beta_k * (E_prop - energies)  # (K, B)
        accept = torch.log(torch.rand_like(log_alpha)) < log_alpha  # (K, B)

        z_new = torch.where(accept.view(K, B, 1, 1, 1), z_prop, z)
        energies_new = torch.where(accept, E_prop, energies)
        return z_new, energies_new

    @torch.no_grad()
    def swap_between_chains(self, z, energies):
        """
        Attempt replica exchanges between adjacent chains, sweeping
        k = 0 .. n_chains - 2.

        For each image slot, the states of chains k and k+1 swap with
        probability

            min(1, exp((beta_{k+1} - beta_k) * (R_k - R_{k+1}))).

        ``z`` and ``energies`` are modified in place and also returned.

        Args:
            z: latents of shape (n_chains, n_imgs, C, H, W).
            energies: rewards of shape (n_chains, n_imgs).
        """
        for k in range(self.n_chains - 1):
            beta_k = self.betas[k].to(energies.dtype)
            beta_kp1 = self.betas[k + 1].to(energies.dtype)

            log_alpha = torch.clamp(
                (beta_kp1 - beta_k) * (energies[k] - energies[k + 1]),
                max=0.0,
            )
            swap = torch.log(torch.rand_like(log_alpha)) < log_alpha  # (n_imgs,)
            if not swap.any():
                continue

            # Boolean-mask indexing returns copies, so these keep chain k's
            # old values while that chain is overwritten.
            z_k_old = z[k, swap]
            e_k_old = energies[k, swap]

            z[k, swap] = z[k + 1, swap]
            energies[k, swap] = energies[k + 1, swap]

            z[k + 1, swap] = z_k_old
            energies[k + 1, swap] = e_k_old

        return z, energies

    @torch.no_grad()
    def sample(self, n_iterations):
        """
        Run parallel tempering starting from latents drawn from N(0, I).

        Args:
            n_iterations: number of (pCN update, replica swap) sweeps.

        Returns:
            A tuple (latents, images, rewards) for the cold chain, i.e. the
            last chain, with the highest beta:
            - latents: shape (batch_size, C, H, W), where H and W are the
              UNet's native latent grid.
            - images: a one-element list holding that chain's images.
            - rewards: shape (batch_size,).
        """
        latent_h, latent_w = self._latent_hw()
        z = torch.randn(
            (
                self.n_chains,
                self.batch_size,
                self.pipe.unet.config.in_channels,
                latent_h,
                latent_w,
            ),
            device=self.device,
            dtype=torch.float16,
        )

        energies = self.get_reward(z)
        for _ in range(n_iterations):
            z, energies = self.propose_updates(z, energies)
            z, energies = self.swap_between_chains(z, energies)

        # The cold chain's images are re-rendered here. DDIM with eta=0
        # reproduces the images that were scored.
        cold_chain = z[-1].unsqueeze(0)
        best_images = self.transport(cold_chain)

        return cold_chain.squeeze(0), best_images, energies[-1]




# It's currently running on tremendous-ocelot.
