########################################
#
# R2D2
# scalar Gaussian demo
# by
# ~
# Implementation of demo from "An Optimal Diffusion Approach to Quadratic Rate-Distortion Problems: New Solution and Approximation Methods" (2025)
# by anonymous author(s).
#
# to execute run: python ./R2D2_v1.py
#
########################################

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from tqdm import tqdm, trange

import matplotlib
from matplotlib import pyplot as plt
# Global matplotlib styling applied to every figure created by this script.
# "text.usetex" makes matplotlib render all text through LaTeX, so a working
# LaTeX installation is required for plotting. The font sizes are very large
# (58-68 pt) -- presumably chosen for large-format figures; adjust to taste.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.size": 60,
    "xtick.labelsize" : 58,
    "ytick.labelsize": 58,
    'axes.labelsize': 68,

})

def weights_init(m):
    """Re-initialise a ``Linear``-like layer in place.

    Meant to be handed to ``nn.Module.apply``. Any module whose class name
    contains the substring ``'Linear'`` receives Xavier-normal weights
    (gain 1) and, when it has a bias, an all-zero bias. Every other module
    is left untouched.
    """
    if 'Linear' not in type(m).__name__:
        return

    torch.nn.init.xavier_normal_(m.weight.data, gain=1.)

    bias = m.bias
    if bias is not None:
        bias.data.zero_()

def moments(x1):
    """Return sample mean, unbiased sample covariance and its log-determinant.

    ``x1`` is treated as a 2-D tensor with one sample per row: it is averaged
    over dim 0 and the covariance is formed as ``x.T @ x / (n - 1)`` from the
    centred samples.

    Returns:
        ``(mean, covariance, log_det)``.

    Raises:
        AssertionError: if the covariance determinant is not above 1e-13.
    """
    mean = x1.mean(dim=0)
    centred = x1 - mean

    n_samples = centred.size(0)
    cov = torch.matmul(centred.T, centred) / (n_samples - 1)

    # For scalar samples the determinant is just the (1x1) variance itself.
    det = torch.det(cov)
    assert det > 1e-13, "zero determinant!"

    return mean, cov, torch.log(det)

## Euler-Maruyama sampling
def EuMa(unet, x0, eps, dt, T = 1., device = None, terminal_noise=True):
    """Roll out the controlled SDE with a randomised Euler-Maruyama scheme.

    Each iteration draws a step size ``dT`` uniformly from ``[0, 2*dt)``,
    clips it so the rollout never passes the horizon ``T``, evaluates the
    drift ``ft = unet(...)`` and updates
    ``xt <- xt + ft*dT + sqrt(eps*dT) * N(0, 1)``.

    Args:
        unet: callable producing the drift. It receives the horizontal stack
            ``[xt, 2*t - 1, eps, dT/dt - 1]`` (callers in this file pass
            ``x0`` of shape ``(B, 1)``, giving a ``(B, 4)`` input).
        x0: 2-D tensor of initial states, one sample per row.
        eps: noise level; must be a tensor (``torch.sqrt`` and
            ``torch.rand_like`` are applied to it).
        dt: mean step size.
        T: time horizon. NOTE: the time feature fed to ``unet`` is
            ``2*t - 1`` and is not rescaled by ``T``.
        device: unused; kept for interface compatibility.
        terminal_noise: unused; kept for interface compatibility.

    Returns:
        ``(mean_energy, xT)`` where ``mean_energy`` is the sum over steps and
        samples of ``|ft|^2 * dT`` divided by the batch size, and ``xT`` is
        the terminal state.
    """
    enrg = 0

    # Constant helper quantities; none of them needs a gradient.
    with torch.no_grad():
        sq_eps = torch.sqrt(eps)
        xt = x0
        eps_col = torch.zeros_like(x0) + eps
        one_col = torch.zeros_like(x0) + 1

    t = 0

    while t < T:

        # Random step size with mean dt, truncated to the remaining time.
        # The bound must be T - t (not 1 - t): with a hard-coded 1 the step
        # collapses to 0 once t reaches 1, so T > 1 never terminates, and
        # T < 1 overshoots the horizon.
        dT  = torch.rand_like(eps)*dt*2
        dT  = torch.clip(dT, max = T-t)
        sdT = torch.sqrt(dT)

        # Network input: state, time mapped to [-1, 1] for T = 1, noise
        # level, and the relative step size mapped to [-1, 1).
        uin = torch.hstack([xt,t*one_col*2-1,eps_col,one_col*dT/dt-1])
        ft = unet(uin)
        zt = torch.randn_like(xt, requires_grad=False)*sdT*sq_eps
        xt = xt + ft*dT +  zt
        t += dT

        # Accumulate the control energy  sum_i |f_t(x_i)|^2 * dT.
        energy = torch.sum(ft**2, dim = 1)
        enrg += torch.sum( energy*dT )

        torch.cuda.empty_cache()

    return enrg/x0.size(0), xt

## Fully connected model
class FC_DNN(nn.Module):
    """Small fully connected network bundled with its own Adam optimiser.

    Layout: ``Linear(input, 128) -> LeakyReLU(0.2) -> Linear(128, 128) ->
    LeakyReLU(0.2) -> Linear(128, output)``. The layers are registered under
    the names ``Linear_input``, ``relu0``, ``Linear1``, ``relu1`` and
    ``Linear_output`` inside ``self.model``.

    Args:
        input: number of input features.
        output: number of output features.
        args: unused.
        out_bias: whether the output layer has a bias term.
    """

    def __init__(self,input, output, args = None, out_bias=True):
        super().__init__()

        width = 128
        n_hidden = 1
        learning_rate = 1e-3
        # Maximum gradient norm applied in train_fn.
        self.clip = 1.

        layers = nn.Sequential()
        layers.add_module('Linear_input', nn.Linear(input, width, bias=True))
        layers.add_module('relu0', nn.LeakyReLU(0.2, inplace=True))
        for k in range(1, n_hidden + 1):
            layers.add_module(f'Linear{k}', nn.Linear(width, width))
            layers.add_module(f'relu{k}', nn.LeakyReLU(0.2, inplace=True))
        layers.add_module('Linear_output', nn.Linear(width, output, bias=out_bias))
        self.model = layers

        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=learning_rate,
            betas=(0.9, .999),
            eps=1e-08,
            weight_decay=0,
        )

    def train_fn(self, loss_in):
        """Take one optimiser step on ``loss_in`` with gradient-norm clipping."""
        self.optimizer.zero_grad()
        loss_in.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip)
        self.optimizer.step()

    def forward(self,input):
        """Apply the layer stack to ``input``."""
        return self.model(input)

## R2D2 scalar Gaussian model
class R2D2_Gauss(nn.Module):
    """R2D2 model for the scalar standard-Gaussian source.

    Holds two networks:

    * ``Unet`` -- drift of the controlled SDE; input columns are
      ``[x_t, time, eps, relative step]`` (see ``EuMa``).
    * ``Znet`` -- critic used by ``Negentropy``; input columns are
      ``[x, eps]``.

    Training alternates between the two: one ``Unet`` step followed by three
    ``Znet`` steps (``step_n % 4``).
    """

    def __init__(self, args = None):
        super(R2D2_Gauss, self).__init__()

        self.Unet = FC_DNN(input = 4, output=1)
        self.Znet = FC_DNN(input = 2, output=1, out_bias=False)

        self.Unet.apply(weights_init)
        self.Znet.apply(weights_init)

        self.bsize = 512
        # Range from which the noise level eps is drawn during training.
        self.eps_min = .025
        self.eps_max = .975

        # Counts train_step calls; selects which network is updated.
        self.step_n = 0

    def sample(self,bsize, device):
        """Draw ``bsize`` standard-normal source samples of shape ``(bsize, 1)``."""
        # Sampled on the CPU and then moved, so the random stream does not
        # depend on the target device.
        x0 = torch.randn((bsize,1)).to(device)
        return x0

    def U_loss(self, x0, eps):
        """Loss for the drift network: mean control energy + ``2*eps*H``.

        ``H`` is ``.5*logDet`` of the terminal covariance minus the
        negentropy estimate of the terminal samples.
        """
        mean_enrgy, x1 = EuMa(self.Unet, x0, eps, dt=.01)
        _, _,logDet = moments(x1)
        H = .5*logDet - self.Negentropy(x1,eps)

        loss_u = mean_enrgy + 2*eps*H
        return loss_u

    def Negentropy(self, x1, eps):
        """Critic-based negentropy estimate of the samples ``x1``.

        Compares ``x1`` against moment-matched Gaussian samples through
        ``Znet``:  ``mean(Znet(x1)) - log(mean(exp(Znet(z))))``, which has the
        form of a Donsker-Varadhan bound.
        """
        eps_col = torch.zeros_like(x1) + eps
        avg_x1, sig_x1, _ = moments(x1)

        # Moment-matched Gaussian reference samples. The element-wise sqrt of
        # the covariance is a valid matrix square root only in the scalar
        # (1x1) case used by this demo.
        z = avg_x1 + torch.randn_like(x1)@torch.sqrt(sig_x1).float()

        x_in = torch.hstack([x1,eps_col])
        zreal  = self.Znet(x_in)
        z_in = torch.hstack([z,eps_col])
        zgauss = self.Znet(z_in)

        # 1e-11 guards the log against an underflowing mean.
        negent = torch.mean(zreal) - torch.log(torch.mean(torch.exp(zgauss))+1e-11)

        return negent

    def Z_loss(self,x0, eps):
        """Loss for the critic network: the negated negentropy estimate."""
        # Only Znet is updated with this loss, so the SDE rollout through
        # Unet needs no autograd graph. The Znet gradients are unchanged;
        # this just avoids building and backpropagating through the rollout.
        with torch.no_grad():
            _, x1 = EuMa(self.Unet, x0, eps, dt=.01)

        negent = self.Negentropy(x1,eps)
        return -negent

    def train_step(self, device):
        """Run one optimisation step and return the loss that was minimised.

        Every fourth call updates ``Unet``; the other calls update ``Znet``.
        A fresh batch and a fresh ``eps`` drawn uniformly from
        ``[eps_min, eps_max)`` are used each time.
        """
        do_train_u = (self.step_n % 4 == 0)

        x0 = self.sample(self.bsize, device)

        eps = torch.rand(1).to(device)
        eps *= (self.eps_max - self.eps_min)
        eps += self.eps_min

        if do_train_u:
            loss = self.U_loss(x0, eps)
            self.Unet.train_fn(loss)
        else:
            loss = self.Z_loss(x0, eps)
            self.Znet.train_fn(loss)

        self.step_n += 1
        return loss

    def eval_step(self, eps, bsize, device,x0=None):
        """Estimate one rate/distortion point at noise level ``eps``.

        Args:
            eps: noise level (tensor).
            bsize: batch size used when ``x0`` is not supplied.
            device: device for freshly sampled inputs.
            x0: optional source samples; sampled when ``None``.

        Returns:
            ``(R, D/2)`` with ``D = mean((x1 - x0)**2)`` and
            ``R = (loss - D)/(2*eps) - .5*log(eps) + .5``.
        """
        # Pure evaluation: no gradients are needed, so skip graph building.
        with torch.no_grad():
            if x0 is None:
                x0 = self.sample(bsize,device)

            mean_enrgy, x1 = EuMa(self.Unet, x0, eps, dt=.01)
            _, _, logDet = moments(x1)
            H = .5*logDet - self.Negentropy(x1,eps)

            loss = mean_enrgy + 2*eps*H

            D = torch.mean((x1-x0)**2)
            R = (loss - D)/(2*eps) - .5*torch.log(eps) +.5
        return R,D/2


def eval_R2D2(R2D2, device, n_steps = 1, seed = None, testset = None):
    """Evaluate a trained model on a fixed grid of seven noise levels.

    The grid is ``2*eps_min + linspace(0, eps_max - 2*eps_min - .05, 7)``.
    For every ``eps`` the model's ``eval_step`` is called ``n_steps`` times
    and the results are summarised by their median and quartiles.

    Args:
        R2D2: model exposing ``eval()``, ``eps_min``, ``eps_max`` and
            ``eval_step(eps, bsize, device, x0=...)``.
        device: device on which the noise levels live and inputs are sampled.
        n_steps: number of evaluations per noise level.
        seed: if given, NumPy and torch are re-seeded with it before the
            evaluations of every noise level.
        testset: optional sequence with at least ``n_steps`` input batches;
            when ``None`` the model samples its own inputs.

    Returns:
        ``(D, R, errD, errR, E)``: arrays of median distortion and rate per
        noise level, lists of ``[median - q25, q75 - median]`` for both, and
        the list of median absolute errors ``|R + .5*log(2*D)|``.
    """
    R2D2.eval()

    eps_ln = 2*R2D2.eps_min + torch.linspace(0.,(R2D2.eps_max-2*R2D2.eps_min - .05), 7)
    eps_ln = eps_ln.to(device)

    ## evaluate
    R = []; D = []
    errR = []; errD = []
    E = []

    q = np.array([.25,.5,.75])

    for eps in eps_ln:
        Re = []; De = []; Ee = []
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
        for n in range(n_steps):
            # Without a test set, let eval_step draw its own inputs instead
            # of failing on ``None[n]``.
            x0 = None if testset is None else testset[n]
            r,d = R2D2.eval_step(eps, 1024, device, x0=x0)
            Re.append(r.item())
            De.append(d.item())

            # Gap to the reference curve  R(D) = -.5*log(2*D).
            Ee.append( Re[-1] + .5*np.log(2*De[-1]))

        De = np.stack(De)
        Re = np.stack(Re)
        Ee = np.stack(Ee)

        R.append(np.median(Re))
        D.append(np.median(De))
        E.append(np.median(np.abs(Ee)))

        d = np.quantile(De,q,axis=0)
        r = np.quantile(Re,q,axis=0)

        errD.append([d[1]-d[0],d[-1]-d[1]])
        errR.append([r[1]-r[0],r[-1]-r[1]])

    return np.array(D), np.array(R), errD,errR,E

## Training routine for a single experiment
def train_experiment(train_steps_n,EVAL_STEPS,test_set,seed,device):
    """Train a freshly initialised ``R2D2_Gauss`` model and return it.

    NumPy and torch are seeded with ``seed`` before the model is built, so
    both initialisation and training noise depend only on ``seed``.

    Args:
        train_steps_n: number of ``train_step`` calls to perform.
        EVAL_STEPS: accepted but not used by this function.
        test_set: accepted but not used by this function.
        seed: seed for ``np.random`` and ``torch``.
        device: device on which the model is placed and trained.

    Returns:
        The trained model (left in training mode).
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    model = R2D2_Gauss().to(device)
    model.train()

    # trange provides the progress bar; the step index itself is not needed.
    for _ in trange(train_steps_n):
        model.train_step(device=device)

    return model

## MAIN
if __name__ == '__main__':
    # Script entry point: train N_EXP independently seeded R2D2 models,
    # evaluate each on a shared test set, and plot the resulting R(D)
    # estimates and their absolute errors against the WGD and NERD baselines
    # stored in '<fle_name>.pkl'.

    import pickle

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    # device = torch.device('cpu')

    N_EXP = 64
    meta_seed = 1337

    Train_steps   = 25000
    Eval_steps    = 1

    fle_name = 'WGD_NERD_v1_64exp'

    ## load baselines
    # Results obtained by running the WGD and NERD codes. They are loaded
    # before training so that a missing or malformed file fails immediately
    # instead of after all experiments have finished.
    # NOTE: pickle can execute arbitrary code while loading -- only open
    # files that come from a trusted source.
    with open(fle_name+'.pkl', 'rb') as pickle_file:
        results = pickle.load(pickle_file)

    wgdDs = results['wgdDs']
    wgdRs = results['wgdRs']
    wgd_errs = results['wgd_errs']

    nrdDs = results['nrdDs']
    nrdRs = results['nrdRs']
    nrd_errs = results['nrd_errs']

    ## train and evaluate
    np.random.seed(meta_seed)
    torch.manual_seed(meta_seed)

    seeds = np.arange(meta_seed,meta_seed+N_EXP)

    # This instance is only used to draw the shared test set; it is built
    # (and consumes random numbers) before sampling, which fixes the test
    # set for a given meta_seed.
    R2D2 = R2D2_Gauss().to(device)
    test_set = [R2D2.sample(4096,device) for _ in range(Eval_steps)]

    Ds = []; Rs = []

    for seed in seeds:
        trained_model = train_experiment(
            Train_steps,Eval_steps,test_set,seed,device
        )
        d,r,_,_,_ = eval_R2D2(trained_model, device  = device, n_steps = Eval_steps, seed = seed, testset = test_set)
        Ds.append(d); Rs.append(r)
        # Release the model before the next experiment allocates a new one.
        del trained_model
        torch.cuda.empty_cache()

    Ds = np.stack(Ds)
    Rs = np.stack(Rs)
    # Gap to the reference curve  R(D) = -.5*log(2*D).
    errs = Rs + .5*np.log(2*Ds)

    # Quartiles across experiments; after .T the arrays are indexed as
    # [noise level, quantile] with column 1 holding the median.
    q = [.25,.5,.75]
    D = np.quantile(Ds,q,axis=0).T
    R = np.quantile(Rs,q,axis=0).T
    ERRS = np.quantile(np.abs(errs),q,axis=0).T

    wgdD = np.quantile(wgdDs,q,axis=0).T
    wgdR = np.quantile(wgdRs,q,axis=0).T
    wgdERRS = np.quantile(np.abs(wgd_errs),q,axis=0).T

    nrdD = np.quantile(nrdDs,q,axis=0).T
    nrdR = np.quantile(nrdRs,q,axis=0).T
    nrdERRS = np.quantile(np.abs(nrd_errs),q,axis=0).T

    ## plot R(D)
    # Labels are raw strings: '\m' is not a valid escape sequence in a normal
    # string literal and triggers a warning on recent Python versions.
    fig, ax1 = plt.subplots(1,1)
    ax1.plot(D[:,1],R[:,1],'b:',  zorder=1,marker = "^", markersize = 15, lw=7, label = r"$\mathsf{R2D2}$ (ours)")

    ErrR =  np.stack([R[:,1]-R[:,0],R[:,-1]-R[:,1]])
    ErrD =  np.stack([D[:,1]-D[:,0],D[:,-1]-D[:,1]])
    ax1.errorbar(D[:,1], R[:,1], yerr=ErrR, xerr=ErrD, capsize=5,ecolor='r',elinewidth=8,zorder=1)

    nerd = ax1.scatter(nrdD[:,1],nrdR[:,1],marker = 'o',color='r', zorder=2 , s = 150, label = r"$\mathsf{NERD}$ [Lei et al.22]")
    nerd.set_facecolor('none')
    nerd.set_edgecolor('r')

    wgdscat = ax1.scatter(wgdD[:,1],wgdR[:,1], marker = 'd',color='k', zorder=2 , s = 150, label = r"$\mathsf{WGD}$ [Yang et al.24]")

    wgdscat.set_facecolor('none')
    # NOTE(review): the markers are created with color='k' but the edge is
    # then set to red, like NERD -- confirm this is the intended colour.
    wgdscat.set_edgecolor('r')

    # Range of the ground-truth curve: cover the median distortions of all
    # three methods exactly as they are plotted above. (The WGD medians are
    # column 1 of wgdD, consistent with the scatter call.)
    wgd_d = wgdD[:,1]

    Dmax1 = np.max(D[:,1])
    Dmax2 = np.max(nrdD[:,1])
    Dmax3 = np.max(wgd_d)
    Dmax = 1.05*np.max([Dmax1,Dmax2,Dmax3])

    Dmin1 = np.min(D[:,1])
    Dmin2 = np.min(nrdD[:,1])
    Dmin3 = np.min(wgd_d)
    Dmin = .75*np.min([Dmin1,Dmin2,Dmin3])

    Dline = np.linspace(Dmin,Dmax,101)
    ax1.plot(Dline,-.5*np.log(2*Dline),'k',  zorder=3,marker = "", markersize = 15, lw=6, label = "$R(D)$ (groundtruth)")

    ax1.legend()
    ax1.set_ylabel(r'$R(D)$ [nat]')
    ax1.set_xlabel(r'${1 \over 2}$MSE')

    ## plot absolute error
    # Median absolute error per method with the inter-quartile band shaded.
    fig, ax2 = plt.subplots(1,1)

    ax2.plot(D[:,1],ERRS[:,1],'b',  zorder=1, lw=7, label = r"$\mathsf{R2D2}$ (ours)")
    ax2.fill_between(D[:,1], y1=ERRS[:,-1], y2=ERRS[:,0], color='b', alpha=.3)

    ax2.plot(nrdD[:,1],nrdERRS[:,1],'r',  zorder=1, lw=7, label = r"$\mathsf{NERD}$")
    ax2.fill_between(nrdD[:,1], y1=nrdERRS[:,-1], y2=nrdERRS[:,0], color='r', alpha=.3)

    ax2.plot(wgdD[:,1],wgdERRS[:,1],'k',  zorder=1, lw=7, label = r"$\mathsf{WGD}$")
    ax2.fill_between(wgdD[:,1], y1=wgdERRS[:,-1], y2=wgdERRS[:,0], color='k', alpha=.3)

    ax2.set_ylabel(r'$|\widehat{R(D)} - R(D)|$')
    ax2.set_xlabel(r'$\frac{1}{2}MSE$')
    ax2.set_xlim([0,.4])

    ax2.legend()
    plt.show()
