import os
from pathlib import Path
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

def save_model(model, train_iterations=0, prefix="checkpoint", output_save_dir="./trained_models"):
    """Save ``model.state_dict()`` as a checkpoint file.

    The file is written to ``<output_save_dir>/model/<prefix>_<train_iterations>.pt``
    and holds the dict ``{'model': state_dict}``. The ``model`` sub-directory is
    created if needed.

    Args:
        model: Module whose ``state_dict()`` is saved.
        train_iterations: Appended to the file name, e.g. the number of training steps.
        prefix: File-name prefix.
        output_save_dir: Root output directory (``str`` or ``os.PathLike``).

    Returns:
        pathlib.Path: Path of the written checkpoint.
    """
    model_dir = Path(output_save_dir, 'model')
    os.makedirs(model_dir, exist_ok=True)

    # Build the path with pathlib so Path-like output_save_dir values work too
    # (string concatenation raised TypeError for them).
    save_path = model_dir / f'{prefix}_{train_iterations}.pt'
    save_obj = dict(
        model=model.state_dict()
    )
    torch.save(save_obj, save_path)
    print(f'\ncheckpoint saved to {output_save_dir}/.')
    return save_path

class MPFourier(torch.nn.Module):
    """Fixed random Fourier features of a scalar input (e.g. a noise level).

    Each input scalar ``x`` is mapped to ``sqrt(2) * cos(x * freqs + phases)``.
    ``freqs = 2*pi*bandwidth*N(0, 1)`` and ``phases = 2*pi*U[0, 1)`` are drawn
    once at construction. They are stored as buffers, so they are saved in the
    state_dict but are not trainable parameters.

    Args:
        num_channels: Number of output features per input scalar.
        bandwidth: Scale applied to the random frequencies.
    """

    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))

    def forward(self, x):
        """Embed ``x`` into an ``(N, num_channels)`` tensor with ``x``'s dtype.

        ``x`` may be 0-d (treated as a batch of one), 1-D of length ``N``, or
        ``(N, 1)`` (the trailing singleton axis is squeezed away).
        """
        if x.ndim > 1:
            x = x.squeeze(-1)
        elif x.ndim < 1:
            x = x.unsqueeze(-1)
        # Compute in fp32 regardless of the input dtype, then cast back at the end.
        y = x.to(torch.float32)
        # torch.outer replaces the deprecated Tensor.ger (identical result).
        y = torch.outer(y, self.freqs.to(torch.float32))
        y = y + self.phases.to(torch.float32)
        y = y.cos() * np.sqrt(2)
        return y.to(x.dtype)

class ConditionalLinear(nn.Module):
    """Linear layer whose output is scaled elementwise by a learned function of a scalar condition.

    ``forward(x, sigma) = embnetwork(MPFourier(sigma)) * lin(x)``.

    Args:
        num_in: Input feature width of ``lin``.
        num_out: Output feature width (of both ``lin`` and ``embnetwork``).
        bandwidth: Frequency scale of the Fourier embedding of ``sigma``.
        num_channels: Width of the Fourier embedding.
        embnetworklayers: Number of linear layers in ``embnetwork`` (>= 1).
            With more than one layer, each hidden layer is
            ``Linear(num_channels, num_channels)`` followed by ``Softplus``.

    Raises:
        ValueError: If ``embnetworklayers < 1``.
    """

    def __init__(self, num_in, num_out, bandwidth=1, num_channels=50, embnetworklayers=1):
        super().__init__()
        if embnetworklayers < 1:
            raise ValueError(f"embnetworklayers must be >= 1, got {embnetworklayers}")
        self.num_out = num_out
        self.lin = nn.Linear(num_in, num_out)
        self.embed = MPFourier(num_channels, bandwidth)
        if embnetworklayers == 1:
            # Kept as a bare Linear (not a Sequential) so checkpoint keys stay
            # "embnetwork.weight" / "embnetwork.bias".
            self.embnetwork = nn.Linear(num_channels, num_out)
        else:
            # Flatten (Linear, Softplus) pairs into one layer list. The previous
            # `[tuple, ...] + nn.Linear` expression raised TypeError.
            layers = []
            for _ in range(embnetworklayers - 1):
                layers += [nn.Linear(num_channels, num_channels), nn.Softplus()]
            layers.append(nn.Linear(num_channels, num_out))
            self.embnetwork = nn.Sequential(*layers)

    def forward(self, x, sigma):
        """Return ``lin(x)`` scaled by the learned embedding of ``sigma``.

        Assumes ``sigma`` is 0-d, ``(B,)`` or ``(B, 1)`` (as accepted by MPFourier),
        so the gain broadcasts against ``lin(x)`` — TODO confirm for other shapes.
        """
        out = self.lin(x)
        gamma = self.embnetwork(self.embed(sigma))
        out = gamma * out
        return out

class ConditionalModel(nn.Module):
    """Small MLP denoiser conditioned on a scalar noise input.

    Architecture: ``ConditionalLinear(dim -> n_embed)`` -> softplus ->
    ``ConditionalLinear(n_embed -> n_embed)`` -> softplus -> ``Linear(n_embed -> dim)``.
    Both conditional layers receive the same ``sigma``.
    """

    def __init__(
            self,
            dim=2,
            n_embed=128,
            bandwidth=10,
            num_channels=100,
            embnetworklayers=1,
        ):
        super().__init__()
        # Shared conditioning hyper-parameters for both conditional layers.
        cond_kwargs = dict(
            bandwidth=bandwidth,
            num_channels=num_channels,
            embnetworklayers=embnetworklayers,
        )
        self.lin1 = ConditionalLinear(dim, n_embed, **cond_kwargs)
        self.lin2 = ConditionalLinear(n_embed, n_embed, **cond_kwargs)
        self.lin3 = nn.Linear(n_embed, dim)

    def forward(self, x, sigma):
        """Map ``x`` (conditioned on ``sigma``) back to a ``dim``-wide output."""
        hidden = x
        for conditional_layer in (self.lin1, self.lin2):
            hidden = F.softplus(conditional_layer(hidden, sigma))
        return self.lin3(hidden)

class Precond(torch.nn.Module):
    """EDM-style preconditioning wrapper around ``ConditionalModel``.

    With ``sd = sigma_data`` the wrapper returns
    ``D(x) = c_skip * x + c_out * F(c_in * x, c_noise)`` where
    ``c_skip = sd^2 / (sigma^2 + sd^2)``, ``c_out = sigma * sd / sqrt(sigma^2 + sd^2)``,
    ``c_in = 1 / sqrt(sd^2 + sigma^2)`` and ``c_noise = log(sigma) / 4``.

    Args:
        use_fp16: Run the inner network input in fp16 (CUDA only, unless ``force_fp32``).
        sigma_data: Expected standard deviation of the training data.
        **net_kwargs: Keyword arguments forwarded to ``ConditionalModel``.
    """

    def __init__(self,
        use_fp16        = False,
        sigma_data      = 0.5,
        **net_kwargs,
    ):
        super().__init__()
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        self.net = ConditionalModel(**net_kwargs)

    def forward(self, x, sigma, force_fp32=False, **net_kwargs):
        """Denoise ``x`` at noise level(s) ``sigma``; ``sigma`` is reshaped to ``(-1, 1)``."""
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1)

        wants_half = self.use_fp16 and (not force_fp32) and x.device.type == 'cuda'
        dtype = torch.float16 if wants_half else torch.float32

        # Preconditioning coefficients.
        sd_sq = self.sigma_data ** 2
        sigma_sq = sigma ** 2
        c_skip = sd_sq / (sigma_sq + sd_sq)
        c_out = sigma * self.sigma_data / (sigma_sq + sd_sq).sqrt()
        c_in = 1 / (sd_sq + sigma_sq).sqrt()
        c_noise = sigma.flatten().log() / 4

        # Inner network on the scaled input; mix with the skip connection in fp32.
        F_x = self.net((c_in * x).to(dtype), c_noise, **net_kwargs)
        return c_skip * x + c_out * F_x.to(torch.float32)

def expand_to_batch(x, batch_size):
    """Return a 2-D tensor with ``batch_size`` rows.

    A ``(1, C)`` tensor is broadcast (as a view, via ``expand``) to
    ``(batch_size, C)``. A tensor that already has ``batch_size`` rows is
    returned unchanged.

    Raises:
        ValueError: If ``x`` is not 2-D, or if its leading dimension is neither
            1 nor ``batch_size``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.ndim != 2:
        raise ValueError(f"Expected a 2-D tensor, got shape {tuple(x.shape)}")
    if x.shape[0] == 1:
        return x.expand(batch_size, *x.shape[1:])
    elif x.shape[0] == batch_size:
        return x
    else:
        raise ValueError(f"Input shape {x.shape} is not compatible with batch size {batch_size}")

class SoftConstraintDiffusion(nn.Module):
    """Preconditioned diffusion denoiser with optional soft-constraint guidance.

    The base denoiser is ``Precond`` (wrapping ``ConditionalModel``). When guidance
    is on, the denoised estimate ``x_hat`` is shifted by a term proportional to the
    gradient of ``logsumexp_k(-constraint_f(x_hat + eps_k))``. Here ``eps_k`` are
    ``n_mc`` Gaussian perturbations with std ``sqrt(sigma^2 / (1 + sigma^2))``;
    in deterministic mode a single unperturbed evaluation is used instead.

    Args:
        dim: Data dimensionality (also the width of ``x`` fed to ``autoscale_mlp``).
        n_embed: Hidden width of the base denoiser.
        constraint_f: Residual function. Presumably maps ``(B, dim) -> (B,)``,
            given the ``grad_outputs`` shape used in ``forward`` — TODO confirm.
        sigma_data: Data standard deviation used by the preconditioner.
        n_mc: Default number of Monte Carlo perturbations.
        lam: Default decay rate of the guidance weight ``exp(-lam * sigma)``.
        use_guidance: Whether guidance is applied by default.
        guidance_scale: Default multiplier on the guidance term.
        deterministic: If True, evaluate the constraint only at ``x_hat``.
        autoscale_grads: If True, replace the ``sigma**2`` weighting with a learned
            scale from ``autoscale_mlp``.
        grad_wrt_denoising_output: If True, differentiate w.r.t. ``x_hat``
            instead of backpropagating through the denoiser to ``x``.
        **net_kwargs: Extra keyword arguments forwarded to ``Precond``.
    """

    def __init__(
            self,
            dim,
            n_embed,
            constraint_f,
            sigma_data = 0.5,
            n_mc = 8,
            lam=1.0,
            use_guidance=False,
            guidance_scale=1.0,
            deterministic=False,
            autoscale_grads=False,
            grad_wrt_denoising_output=False,
            **net_kwargs
        ):
        super().__init__()
        self.net = Precond(use_fp16=False, sigma_data=sigma_data, dim=dim, n_embed=n_embed, **net_kwargs)
        self.constraint_f = constraint_f
        self.lam = lam
        self.n_mc = n_mc
        self.guidance_scale = guidance_scale
        self.use_guidance = use_guidance
        self.deterministic = deterministic
        self.grad_wrt_denoising_output = grad_wrt_denoising_output

        # Fourier embedding of sigma; concatenated with x as the autoscale MLP input.
        sigma_channels = 100
        self.embed_sigma = MPFourier(num_channels=sigma_channels, bandwidth=1)
        # Input width is sigma_channels + dim; the former hard-coded 102 only fit dim == 2.
        self.autoscale_mlp = nn.Sequential(
            nn.Linear(sigma_channels + dim, 100),
            nn.ELU(),
            nn.Linear(100, 100),
            nn.ELU(),
            nn.Linear(100, 1),
        )
        self.autoscale_grads = autoscale_grads

    def forward(self, x, sigma, lam = None, n_mc = None, guidance_scale = None, use_guidance=None, deterministic=None, return_grads=False, **kwargs):
        """Denoise ``x`` at noise level ``sigma``, optionally adding constraint guidance.

        Per-call arguments left as ``None`` fall back to the constructor values.
        Extra ``**kwargs`` are accepted and ignored. Unless
        ``grad_wrt_denoising_output`` is set, ``x.requires_grad_()`` is called
        in place on the caller's tensor.

        Returns:
            ``x_hat``. If ``return_grads`` is True, returns the tuple
            ``(x_hat, original_x_hat.detach(), guidance_term.detach())``. When
            guidance is off, that tuple is
            ``(x_hat, x_hat.detach(), zeros_like(x_hat))``.
        """
        if guidance_scale is None:
            guidance_scale = self.guidance_scale
        if n_mc is None:
            n_mc = self.n_mc
        if use_guidance is None:
            use_guidance = self.use_guidance
        if lam is None:
            lam = self.lam
        if deterministic is None:
            deterministic = self.deterministic

        batch_size = x.shape[0]
        if not self.grad_wrt_denoising_output:
            # The guidance gradient will flow back through the denoiser to x.
            x = x.requires_grad_()
        original_x_hat = self.net(x, sigma)
        if self.grad_wrt_denoising_output:
            original_x_hat = original_x_hat.requires_grad_()

        if not use_guidance:
            if return_grads:
                # Same arity as the guided branch, so callers can always unpack 3 values.
                return original_x_hat, original_x_hat.detach(), torch.zeros_like(original_x_hat)
            return original_x_hat

        neg_constraint = torch.vmap(lambda z: -self.constraint_f(z))
        if deterministic:
            # One evaluation at x_hat; logsumexp over a singleton axis is the identity.
            lse_input = neg_constraint(original_x_hat.unsqueeze(0))
        else:
            samples = torch.sqrt(sigma.pow(2)/(1+sigma.pow(2)))*torch.randn([n_mc, *x.shape], device=x.device)
            lse_input = neg_constraint(original_x_hat + samples)

        wrt = original_x_hat if self.grad_wrt_denoising_output else x
        # create_graph=True keeps the guidance term differentiable for training.
        grads = torch.autograd.grad(
            torch.logsumexp(lse_input, dim=0),
            wrt,
            grad_outputs=torch.ones(batch_size, device=x.device),
            create_graph=True,
            retain_graph=True,
        )[0]

        if not self.autoscale_grads:
            guidance_term = guidance_scale*torch.exp(-lam*sigma)*sigma.pow(2)*grads
        else:
            scale_factor = self.autoscale_mlp(torch.cat([expand_to_batch(self.embed_sigma(sigma), batch_size), x], dim=-1))
            guidance_term = grads * scale_factor * guidance_scale*torch.exp(-lam*sigma)
        x_hat = original_x_hat + guidance_term
        if return_grads:
            return x_hat, original_x_hat.detach(), guidance_term.detach()
        return x_hat
        
def edm_sampler(
    net, noise, gnet=None,
    num_steps=100, sigma_min=0.1, sigma_max=80, rho=10, guidance=1,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
    dtype=torch.float32, randn_like=torch.randn_like,
    **net_kwargs
):
    """EDM-style Heun sampler: integrate from ``sigma_max`` down to 0.

    Noise levels are
    ``t_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho``
    for ``i = 0..N-1`` (just ``[sigma_max]`` when ``N == 1``), followed by
    ``t_N = 0``. Each step is an Euler step plus a 2nd-order correction,
    except the last step (to t = 0). When ``S_churn > 0`` and
    ``S_min <= t <= S_max``, noise is temporarily increased before the step.
    Every intermediate state is detached, so no autograd graph is kept across
    steps.

    Args:
        net: Denoiser called as ``net(x, t, **net_kwargs)``.
        noise: Standard-normal noise; scaled by ``t_0`` to form the initial state.
        gnet: Reference denoiser, required when ``guidance != 1``.
        guidance: Interpolation weight ``gnet_out.lerp(net_out, guidance)``.

    Raises:
        ValueError: If ``num_steps < 1``, or if ``guidance != 1`` and no ``gnet`` is given.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if guidance != 1 and gnet is None:
        raise ValueError("gnet must be provided when guidance != 1")

    # Guided denoiser.
    def denoise(x, t):
        Dx = net(x, t, **net_kwargs).to(dtype)
        if guidance == 1:
            return Dx
        ref_Dx = gnet(x, t).to(dtype)
        return ref_Dx.lerp(Dx, guidance)

    # Time step discretization. Guard the 0/0 (-> NaN schedule) when num_steps == 1.
    step_indices = torch.arange(num_steps, dtype=dtype, device=noise.device)
    denom = max(num_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + step_indices / denom * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    x_next = noise.to(dtype) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        if S_churn > 0 and S_min <= t_cur <= S_max:
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1)
            t_hat = t_cur + gamma * t_cur
            x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
        else:
            t_hat = t_cur
            x_hat = x_cur

        # Euler step (detached: the denoiser may build an autograd graph internally).
        d_cur = ( (x_hat - denoise(x_hat, t_hat)) / t_hat ).detach()
        x_next = ( x_hat + (t_next - t_hat) * d_cur ).detach()

        # 2nd-order correction; skipped on the last step, where t_next == 0.
        if i < num_steps - 1:
            d_prime = ( (x_next - denoise(x_next, t_next)) / t_next ).detach()
            x_next = ( x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) ).detach()

    return x_next

import torch.distributions as dist

def sample_truncated_normal_inverse_cdf(mean, std, min_val, sample_shape=(1000,)):
    """
    Sample from a normal distribution truncated below at ``min_val`` (inverse-CDF method).

    Relies only on ``torch.distributions.Normal`` (``cdf`` / ``icdf``), not on a
    dedicated truncated-normal distribution.

    Args:
        mean (float): Mean of the underlying normal distribution.
        std (float): Standard deviation of the underlying normal distribution; must be > 0.
        min_val (float): Lower truncation bound.
        sample_shape (tuple or list): Shape of the returned tensor.

    Returns:
        torch.Tensor: Samples of shape ``sample_shape``, created on the CPU by
        ``torch.rand``. All samples are finite and ``>= min_val``.

    Raises:
        ValueError: If ``std`` is not positive.
    """
    if std <= 0:
        raise ValueError(f"std must be positive, got {std}")

    # Standard normal distribution.
    normal = dist.Normal(0, 1)

    # CDF value at the standardised truncation point.
    alpha = (min_val - mean) / std
    cdf_alpha = normal.cdf(torch.as_tensor(alpha))

    # Map U[0, 1) onto [cdf_alpha, 1).
    u = torch.rand(sample_shape)
    u_scaled = cdf_alpha + u * (1 - cdf_alpha)

    # Keep the quantile argument strictly inside (0, 1). Otherwise icdf returns
    # +inf when cdf_alpha rounds to 1 (truncation far in the upper tail), or
    # -inf when cdf_alpha underflows to 0 and u is exactly 0.
    finfo = torch.finfo(u_scaled.dtype)
    u_scaled = u_scaled.clamp(min=finfo.tiny, max=1.0 - finfo.eps / 2)

    # Inverse CDF, then back to the requested mean/std.
    z = normal.icdf(u_scaled)
    samples = mean + std * z

    # Rounding in cdf/icdf can land marginally below the bound; enforce it exactly.
    return samples.clamp_min(min_val)

def _make_toy_dataset(dataset_type, n_points, radius, center, chop_angle):
    """Return an ``(n_points, 2)`` CPU tensor of points on a toy 2-D curve.

    Points are indexed by ``n_points`` angles evenly spaced in ``[-pi, pi]``.
    ``dataset_type`` selects the curve:

    * ``"normal"`` -- the unit circle.
    * ``"radius"`` -- a circle of radius ``radius``.
    * ``"shift"``  -- the unit circle translated by ``center`` (indexable as ``center[0]``, ``center[1]``).
    * ``"chop"``   -- the unit circle with the arc ``|theta| < chop_angle`` replaced by
      the vertical chord ``x = cos(chop_angle)`` (y keeps the ``sin(theta)`` spacing).
    * ``"dent"``   -- the unit circle with a smooth inward dent centred at ``theta = pi/2``.

    Raises:
        ValueError: If ``dataset_type`` is not one of the above.
    """
    radii = torch.ones(n_points)
    thetas = torch.linspace(-torch.pi, torch.pi, n_points)

    if dataset_type == "radius":
        radii = radius * radii

    y_vals = radii*torch.sin(thetas)
    x_vals = torch.zeros_like(y_vals)
    if dataset_type == "chop":
        # float32 so isclose/comparisons below never mix dtypes with `thetas`.
        alpha = torch.as_tensor(chop_angle, dtype=torch.float32)
        outside = thetas.abs() >= alpha
        if alpha.isclose( torch.tensor(torch.pi/2) ):
            # cos(pi/2) == 0: the chord is x = 0, which the zeros already encode.
            x_vals[outside] = radii[outside]*torch.cos(thetas[outside])
        else:
            # Stretch radii so the chopped arc lands on the chord x = cos(alpha).
            inside = thetas.abs() < alpha
            radii[inside] = torch.cos(alpha)/torch.cos( thetas[inside] )
            x_vals = radii*torch.cos(thetas)
    elif dataset_type == "shift":
        # Translated unit circle (x previously stayed 0, giving a vertical segment).
        x_vals = radii*torch.cos(thetas) + center[0]
        y_vals = y_vals + center[1]
    elif dataset_type == "dent":
        # ------- Parameters -------
        eps = 0.25        # dent depth (bigger = deeper)
        Delta = 1.2      # angular half-width (radians)
        theta_zero = torch.pi/2  # dent center angle (radians); pi/2 = top
        cc = 0.60         # shoulder shaping (0..~0.8)

        def wrap(angle):
            """Map angle to (-pi, pi]"""
            return ((angle + torch.pi) % (2 * torch.pi)) - torch.pi

        def pbump(u, k):
            """Compact-support polynomial bump: (1-u^2)^k for |u|<1, else 0"""
            return torch.where(torch.abs(u) < 1, (1 - u**2)**k, 0)

        def rr(theta, kk=5):
            """Radius r(theta): exactly 1 outside the selected arc"""
            u = wrap(theta - theta_zero) / Delta
            B = pbump(u, kk)
            return 1 - eps * B * (1 + cc * (1 - 2 * u**2))

        t = torch.linspace(-torch.pi, torch.pi, n_points)
        x_vals = rr(t) * torch.cos(t)
        y_vals = rr(t) * torch.sin(t)
    elif dataset_type in ("normal", "radius"):
        # "radius" previously fell through to the ValueError below.
        x_vals = radii*torch.cos(thetas)
    else:
        raise ValueError(f"Invalid dataset type: {dataset_type}")

    return torch.hstack([x_vals.unsqueeze(1), y_vals.unsqueeze(1)])


def main(dataset_type="normal", n_points=10000, radius=1, center=torch.tensor([0.1, 0.1]),
         chop_angle=torch.pi/4, lr=5e-4, loss_mean=-1.2, loss_std=1.0, name='', lam=1,
         guidance_scale=1, autoscale_grads=False,
         residual_func=lambda x: torch.abs(torch.sum(x**2, dim=-1) - 1.0),
         sigma_min=1e-9, grad_wrt_denoising_output=False):
    """Train a vanilla and a constraint-guided diffusion model side by side on a toy 2-D dataset.

    Reads the module-level ``config`` dict (keys ``'name'``, ``'dim'``,
    ``'train_num_steps'`` and ``'batch_size'``). Each of the
    ``config['train_num_steps']`` iterations is one shuffled pass over the whole
    dataset. Every ``max(1, train_num_steps // 10)`` passes, both models are
    sampled with ``edm_sampler`` from shared noise. At the end, both models are
    checkpointed with ``save_model`` under
    ``./trained_models/toy/<config['name']>``.

    Args:
        dataset_type: "normal", "radius", "shift", "chop" or "dent" (see ``_make_toy_dataset``).
        n_points, radius, center, chop_angle: Dataset parameters.
        lr: Adam learning rate for both models.
        loss_mean, loss_std: Mean/std of log(sigma) in the EDM2 loss.
        name: Suffix for the checkpoint file prefixes.
        lam, guidance_scale, autoscale_grads, grad_wrt_denoising_output:
            Guidance settings of the constrained model.
        residual_func: Per-point constraint residual used for guidance.
        sigma_min: Lower truncation of the training noise level.

    Returns:
        Tuple ``(model_van, model_cgg, van_loss_list, ours_loss_list,
        van_samples_list, ours_samples_list, van_not_agg_list, cgg_not_agg_list,
        van_sigmas_list, cgg_sigmas_list)``. The loss lists hold the mean batch
        loss of each pass. The ``*_not_agg`` and ``*_sigmas`` lists hold the
        per-example losses and sigmas of the last batch of each pass.

    Raises:
        ValueError: For an unknown ``dataset_type`` or an empty dataset.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    output_save_dir = f'./trained_models/toy/{config["name"]}'
    os.makedirs(output_save_dir, exist_ok=True)

    dataset = _make_toy_dataset(dataset_type, n_points, radius, center, chop_angle).to(device)

    sigma_data = dataset.std()
    print(f"sigma_data:{sigma_data}", flush=True)

    class EDM2Loss:
        """EDM2-style weighted denoising loss; log(sigma) is a lower-truncated normal."""
        def __init__(self, P_mean=-2, P_std=1.2, sigma_data=0.5, sigma_min=1e-9):
            self.P_mean = P_mean
            self.P_std = P_std
            self.sigma_data = sigma_data
            self.sigma_min = sigma_min

        def __call__(self, net, in_data, sigma_min=-np.inf, **net_fwd_kwargs):
            # log(sigma) ~ N(P_mean, P_std^2) truncated below at log(self.sigma_min);
            # the `sigma_min` argument of this call is unused.
            rnd_normal = sample_truncated_normal_inverse_cdf(self.P_mean, self.P_std, np.log(self.sigma_min), [in_data.shape[0], 1]).to(in_data.device)
            sigma = rnd_normal.exp()
            weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
            noise = torch.randn_like(in_data) * sigma
            denoised = net(in_data + noise, sigma, **net_fwd_kwargs)
            loss = (weight) * ((denoised - in_data) ** 2)
            return loss.mean(dim=-1), sigma

    class ResidualFunc(nn.Module):
        '''Wrap a residual callable as a module, with an optional output scale.'''
        def __init__(self, residual_func):
            super(ResidualFunc, self).__init__()
            self.residual_func = residual_func

        def forward(self, x, scale=1):
            return scale*self.residual_func(x)

    residual_func = ResidualFunc(residual_func)

    def optimizer_step(optimizer, loss_fn, model, batch_x, **net_kwargs):
        """One clipped (max grad-norm 1.0) optimizer step; returns (mean loss, per-example loss, sigma)."""
        optimizer.zero_grad()
        loss_not_agg, sigma = loss_fn(model, batch_x, **net_kwargs)
        loss = loss_not_agg.mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        return loss, loss_not_agg, sigma

    # Define models and optimizers
    n_embed = 128
    model_van = SoftConstraintDiffusion(config['dim'], n_embed=n_embed, constraint_f=residual_func,
                                           bandwidth=1, sigma_data=sigma_data).to(device)
    optimizer_van = optim.Adam(model_van.parameters(), lr=lr, betas=(0.9, 0.999))
    model_cgg = SoftConstraintDiffusion(config['dim'], n_embed=n_embed, constraint_f=residual_func, sigma_data=sigma_data, deterministic=True,
                                           bandwidth=1, n_mc=128, lam=lam, guidance_scale=guidance_scale, use_guidance=True, autoscale_grads=autoscale_grads,
                                           grad_wrt_denoising_output=grad_wrt_denoising_output).to(device)
    optimizer_cgg = optim.Adam(model_cgg.parameters(), weight_decay=1e-6, lr=lr, betas=(0.9, 0.999))

    P_mean = loss_mean
    P_std = loss_std
    loss_fn_van = EDM2Loss(sigma_data=sigma_data, P_mean=P_mean, P_std=P_std, sigma_min=sigma_min)
    loss_fn_cgg = EDM2Loss(sigma_data=sigma_data, P_mean=P_mean, P_std=P_std, sigma_min=sigma_min)

    van_loss_list = []
    ours_loss_list = []
    van_samples_list = []
    ours_samples_list = []
    van_not_agg_list = []
    cgg_not_agg_list = []
    van_sigmas_list = []
    cgg_sigmas_list = []

    train_num_steps = config['train_num_steps']
    batch_size = config['batch_size']
    n_data = dataset.shape[0]
    # Actual number of (possibly partial) minibatches per pass; floor(n/bs)
    # under-counts the last partial batch and is 0 when n < bs.
    n_batches = len(range(0, n_data, batch_size))
    if n_batches == 0:
        raise ValueError("dataset is empty")
    # At least 1, so short runs (< 10 passes) do not hit a modulo by zero.
    sample_every = max(1, train_num_steps // 10)

    for t in range(train_num_steps):
        permutation = torch.randperm(n_data)
        # Accumulate on the training device ('cuda' was hard-coded before).
        van_avg_epoch_loss = torch.tensor(0., device=device)
        cgg_avg_epoch_loss = torch.tensor(0., device=device)
        for i in range(0, n_data, batch_size):
            indices = permutation[i:i + batch_size]
            batch_x = dataset[indices]
            loss_van, van_loss_not_agg, van_sigma = optimizer_step(optimizer_van, loss_fn_van, model_van, batch_x)
            loss_cgg, cgg_loss_not_agg, cgg_sigma = optimizer_step(optimizer_cgg, loss_fn_cgg, model_cgg, batch_x)
            van_avg_epoch_loss += loss_van.detach()
            cgg_avg_epoch_loss += loss_cgg.detach()

        van_avg_epoch_loss /= n_batches
        cgg_avg_epoch_loss /= n_batches

        van_loss_list.append(van_avg_epoch_loss.detach().cpu())
        ours_loss_list.append(cgg_avg_epoch_loss.detach().cpu())
        van_not_agg_list.append(van_loss_not_agg.detach().cpu())
        cgg_not_agg_list.append(cgg_loss_not_agg.detach().cpu())
        van_sigmas_list.append(van_sigma.detach().cpu())
        cgg_sigmas_list.append(cgg_sigma.detach().cpu())

        if t % sample_every == 0:
            noise = torch.randn([200, 2]).to(device)
            van_samples_list.append( edm_sampler(model_van, noise).detach().cpu() )
            ours_samples_list.append( edm_sampler(model_cgg, noise).detach().cpu() )
            print(t, van_avg_epoch_loss.detach().cpu())
            print(t, cgg_avg_epoch_loss.detach().cpu())
    print("done")

    save_model(model_van, train_num_steps, f"vanillaNoStdization_{dataset_type}{name}", output_save_dir)
    save_model(model_cgg, train_num_steps, f"dedicoNoStdization_{dataset_type}{name}", output_save_dir)

    return model_van, model_cgg, van_loss_list, ours_loss_list, van_samples_list, ours_samples_list, van_not_agg_list, cgg_not_agg_list, van_sigmas_list, cgg_sigmas_list

# ---- Run configuration (read by main() through the module-level name `config`) ----
config = {
    'name': 'run_figs',       # Run name; used in the output directory path
    'dim': 2,                 # Dimension of the hypersphere to be learned (2 for unit circle)
    'n_steps': 100,           # Number of denoising timesteps (not read by main(); edm_sampler uses its own default)
    'train_num_steps': 4000,  # Number of training passes over the dataset
    'batch_size': 128,        # Minibatch size
}

# ---- Dataset ----
dataset_type = "chop"
n_points = 1_000  # Alternate setup: n_points=100000, train_num_steps=40
radius = 1.0
center_x = 0.0
center_y = 0.0
chop_angle = torch.pi/3

# ---- Optimisation / training noise distribution ----
lr = 5e-4
# Earlier setting: loss_mean = -1.2, loss_std = 1.2
loss_mean = -2.0
loss_std = 1.7
name = 1.0

# ---- Guidance ----
lam = 1
guidance_scale = 1
autoscale_grads = True
# Alternative residual: torch.abs(torch.sum(x**2, dim=-1) - 1.0)
residual_func = lambda x: (torch.sqrt(torch.sum(x**2, dim=-1)) - 1.0)**2
sigma_min = 1e-3
# If True, the guidance gradient is taken w.r.t. the denoiser output instead of
# being backpropagated through the denoiser to its input.
grad_wrt_denoising_output = True

(
    model_van, model_constrained,
    van_loss_list, ours_loss_list,
    van_samples_list, ours_samples_list,
    van_not_agg_list, cgg_not_agg_list,
    van_sigmas_list, cgg_sigmas_list,
) = main(
    dataset_type=dataset_type,
    n_points=n_points,
    radius=radius,
    center=torch.tensor([center_x, center_y]),
    chop_angle=chop_angle,
    lr=lr,
    loss_mean=loss_mean,
    loss_std=loss_std,
    name=name,
    lam=lam,
    guidance_scale=guidance_scale,
    autoscale_grads=autoscale_grads,
    residual_func=residual_func,
    sigma_min=sigma_min,
    grad_wrt_denoising_output=grad_wrt_denoising_output,
)
