import torch
import numpy as np
import torch.nn.functional as F
from modules import extract, betas_for_alpha_bar, linear_beta_schedule, exp_beta_schedule, cosine_beta_schedule

class BaseDiffusion():
    """Gaussian diffusion schedule with guided DDPM and DDIM samplers.

    Precomputes forward-process and posterior coefficients from a beta
    schedule. Sampling runs the denoiser under a positive and a negative
    condition and mixes the two x_0 predictions with guidance weight ``w``:
    ``x_0 = (1 + w) * x0_pos - w * x0_neg``.

    Args:
        config: dict read for the keys ``'timesteps'``, ``'device'``, ``'w'``,
            ``'beta_sche'``, ``'beta_start'``, ``'beta_end'`` and
            ``'ddim_step'``.
    """

    def __init__(self, config):
        self.config = config
        self.timesteps = config['timesteps']
        self.device = config['device']
        self.w = config['w']

        self.betas = self.get_beta_schedule(config['beta_sche'], config['timesteps'], config['beta_start'],
                                            config['beta_end'])
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}; the value "before" the first step is defined as 1.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod - 1)

        # Mean coefficients (for x_0 and x_t) and variance of the DDPM
        # posterior q(x_{t-1} | x_t, x_0).
        self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1. - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (
                1. - self.alphas_cumprod)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

        self.init_ddim_variables()

    def get_beta_schedule(self, beta_sche, timesteps, beta_start, beta_end):
        """Return the beta schedule named ``beta_sche``.

        ``beta_start``/``beta_end`` are only used by the ``'linear'``
        schedule. The schedule helpers come from ``modules``; the result is
        presumably a 1-D tensor of length ``timesteps`` (TODO confirm).

        Raises:
            ValueError: if ``beta_sche`` is not one of
                ``'linear'``, ``'exp'``, ``'cosine'``, ``'sqrt'``.
        """
        if beta_sche == 'linear':
            return linear_beta_schedule(timesteps=timesteps, beta_start=beta_start, beta_end=beta_end)
        elif beta_sche == 'exp':
            return exp_beta_schedule(timesteps=timesteps)
        elif beta_sche == 'cosine':
            return cosine_beta_schedule(timesteps=timesteps)
        elif beta_sche == 'sqrt':
            return torch.tensor(betas_for_alpha_bar(timesteps, lambda t: 1 - np.sqrt(t + 0.0001))).float()
        else:
            raise ValueError("Invalid beta schedule")

    def init_ddim_variables(self):
        """Precompute the coefficients of the strided, deterministic DDIM sampler.

        Every ``config['ddim_step']``-th schedule position is kept. Sub-step
        ``n >= 1`` uses schedule index ``n * ddim_step - 1``, and sub-step 0
        uses index 0. Those indices are stored in ``self.ddim_timesteps`` so
        the model is conditioned on the same timestep as the noise level.
        """
        indices = list(range(0, self.timesteps + 1, self.config['ddim_step']))
        self.sub_timesteps = len(indices)
        # Shift to 0-based schedule indices; position 0 maps to index 0.
        indices_now = [max(index - 1, 0) for index in indices]
        self.ddim_timesteps = indices_now
        self.alphas_cumprod_ddim = self.alphas_cumprod[indices_now]
        self.alphas_cumprod_ddim_prev = F.pad(self.alphas_cumprod_ddim[:-1], (1, 0), value=1.0)
        self.sqrt_recipm1_alphas_cumprod_ddim = torch.sqrt(1. / self.alphas_cumprod_ddim - 1)
        # x_prev = coef1 * x_0 + coef2 * x_t: the eta = 0 DDIM update with the
        # implied noise eps = (x_t - sqrt(a) * x_0) / sqrt(1 - a) substituted in.
        self.posterior_ddim_coef1 = torch.sqrt(self.alphas_cumprod_ddim_prev) - torch.sqrt(
            1. - self.alphas_cumprod_ddim_prev) / self.sqrt_recipm1_alphas_cumprod_ddim
        self.posterior_ddim_coef2 = torch.sqrt(1. - self.alphas_cumprod_ddim_prev) / torch.sqrt(
            1. - self.alphas_cumprod_ddim)

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to timestep ``t``.

        Returns ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``.
        Standard-normal noise is drawn when ``noise`` is None. Coefficients
        are gathered per sample with ``extract`` (defined in ``modules``).
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    @torch.no_grad()
    def i_sample(self, model, x, h_pos, h_neg, t, t_index):
        """Perform one deterministic DDIM step from DDIM sub-step ``t_index``.

        Args:
            model: object whose ``denoise(x, h, t)`` output is used as the
                x_0 prediction.
            x: current noisy sample x_t.
            h_pos, h_neg: positive / negative conditions.
            t: per-sample model timestep tensor.
            t_index: DDIM sub-step index into the precomputed coefficients.
        """
        # Denoise under both positive and negative conditions
        predicted_x_pos = model.denoise(x, h_pos, t)
        predicted_x_neg = model.denoise(x, h_neg, t)

        # Apply the Positive-Negative Guidance (PNG) rule for inference.
        x_start = (1 + self.w) * predicted_x_pos - self.w * predicted_x_neg
        x_t = x
        model_mean = (
                self.posterior_ddim_coef1[t_index] * x_start +
                self.posterior_ddim_coef2[t_index] * x_t
        )
        return model_mean

    @torch.no_grad()
    def sample_from_noise(self, model, h_pos, h_neg):
        """Run the DDIM sampler from Gaussian noise shaped like ``h_pos``."""
        x = torch.randn_like(h_pos)
        for n in reversed(range(self.sub_timesteps)):
            # Condition the model on the schedule index whose alpha_bar the
            # DDIM coefficients were built from. This matches the DDPM path,
            # where the same t indexes both the model and the schedule.
            step = torch.full((h_pos.shape[0],), self.ddim_timesteps[n], device=h_pos.device, dtype=torch.long)
            x = self.i_sample(model, x, h_pos, h_neg, step, n)
        return x

    @torch.no_grad()
    def p_sample(self, model, x, h_pos, h_neg, t, t_index):
        """Perform one ancestral DDPM step with guided x_0 prediction.

        Returns the posterior mean at ``t_index == 0``. Otherwise adds
        Gaussian noise scaled by the posterior standard deviation.
        """
        predicted_x_pos = model.denoise(x, h_pos, t)
        predicted_x_neg = model.denoise(x, h_neg, t)
        x_start = (1 + self.w) * predicted_x_pos - self.w * predicted_x_neg
        model_mean = (
            extract(self.posterior_mean_coef1, t, x.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x.shape) * x
        )
        if t_index == 0:
            return model_mean
        else:
            posterior_variance_t = extract(self.posterior_variance, t, x.shape)
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def sample(self, model, h_pos, h_neg, mode='ddim'):
        """Generate samples conditioned on ``h_pos`` / ``h_neg``.

        Args:
            mode: ``'ddpm'`` runs all ``timesteps`` ancestral steps;
                ``'ddim'`` runs the strided deterministic sampler.

        Raises:
            ValueError: on any other ``mode``.
        """
        if mode == 'ddpm':
            x = torch.randn_like(h_pos)
            for i in reversed(range(0, self.timesteps)):
                step = torch.full((h_pos.shape[0],), i, device=h_pos.device, dtype=torch.long)
                x = self.p_sample(model, x, h_pos, h_neg, step, i)
        elif mode == 'ddim':
            x = self.sample_from_noise(model, h_pos, h_neg)
        else:
            raise ValueError(f"Unknown sampling mode: {mode}")
        return x


class SteerRecDiffusion(BaseDiffusion):
    """Diffusion trained with reconstruction plus a margin-based alignment loss.

    The alignment loss asks the prediction made under the positive condition
    to be closer to the clean target than the prediction made under the
    negative condition, by at least ``margin``.

    Extra config keys (besides those read by ``BaseDiffusion``):
        ``'mu'`` (default 0.4): weight of the alignment term; the
            reconstruction term is weighted by ``1 - mu``.
        ``'margin'`` (default 0.1): hinge margin of the alignment term.
    """

    def __init__(self, config):
        super().__init__(config=config)
        self.mu = config.get('mu', 0.4)
        self.margin = config.get('margin', 0.1)

    def p_losses(self, denoise_model, x_start_pos, h_pos, h_neg, t, noise=None, loss_type="l2"):
        """Compute the training loss for one batch.

        Args:
            denoise_model: object whose ``denoise(x, h, t)`` output is treated
                as an x_0 prediction.
            x_start_pos: clean target; the feature axis is assumed to be the
                last dimension (TODO confirm against callers).
            h_pos, h_neg: positive / negative conditions.
            t: per-sample diffusion timesteps.
            noise: optional noise for ``q_sample``; standard normal if None.
            loss_type: ``'l1'``, ``'l2'``, ``'huber'`` or ``'cosine'``.

        Returns:
            ``(final_loss, loss_dict)``. ``final_loss`` is a scalar tensor.
            ``loss_dict`` holds detached CPU copies of the total,
            reconstruction and alignment losses.

        Raises:
            NotImplementedError: on an unknown ``loss_type``.
        """
        if noise is None:
            noise = torch.randn_like(x_start_pos)

        def cosine_loss(pred, gt):
            pred_norm = F.normalize(pred, p=2, dim=-1)
            gt_norm = F.normalize(gt, p=2, dim=-1)
            return (torch.sum(pred_norm * gt_norm, dim=-1) - 1) ** 2

        elementwise_losses = {'l1': F.l1_loss, 'l2': F.mse_loss, 'huber': F.smooth_l1_loss}
        if loss_type == 'cosine':
            loss_func = cosine_loss
        elif loss_type in elementwise_losses:
            elementwise_loss = elementwise_losses[loss_type]

            def loss_func(pred, gt):
                # Reduce only the feature dimension so every loss type yields
                # per-sample distances, as 'cosine' already does. With the
                # default reduction='mean' the hinge below would act on batch
                # means instead of on each sample.
                return elementwise_loss(pred, gt, reduction='none').mean(dim=-1)
        else:
            raise NotImplementedError(f"Unsupported loss_type: {loss_type!r}")

        x_noisy_pos = self.q_sample(x_start=x_start_pos, t=t, noise=noise)

        # Denoise the same noisy input under opposing guidance conditions.
        predicted_x_pos = denoise_model.denoise(x_noisy_pos, h_pos, t)
        predicted_x_neg = denoise_model.denoise(x_noisy_pos, h_neg, t)

        # Per-sample distances from each prediction to the clean target.
        dist_ap = loss_func(predicted_x_pos, x_start_pos)
        dist_an = loss_func(predicted_x_neg, x_start_pos)

        # Triplet-style hinge: zero once the negative-condition prediction is at
        # least `margin` farther from the target than the positive one.
        alignment_loss = torch.clamp(dist_ap - dist_an + self.margin, min=0)

        reconstruction_loss = dist_ap

        final_loss = (1 - self.mu) * reconstruction_loss.mean() + self.mu * alignment_loss.mean()

        loss_dict = {
            'main_loss': final_loss.detach().cpu(),
            'reconstruction_loss': reconstruction_loss.mean().detach().cpu(),
            'alignment_loss': alignment_loss.mean().detach().cpu(),
        }

        return final_loss, loss_dict

