import torch
import numpy as np
import tqdm
import os
import sacred
import model.util as util
from torch import nn

'''
# Define device
if torch.cuda.is_available():
	DEVICE = "cuda:3"
else:
	DEVICE = "cpu"
'''

class NeuralNetwork(nn.Module):
    """
    Fully connected network that maps each (flattened) input to one value
    squashed through a sigmoid, i.e. an output of shape B x 1 in (0, 1).
    """

    def __init__(self,initial_dim):
        """
        Arguments:
            `initial_dim`: number of features per sample after flattening
        """
        super().__init__()
        hidden_width = 200
        self.flatten = nn.Flatten()
        stack_layers = [
            nn.Linear(initial_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
            nn.Sigmoid(),
        ]
        self.linear_relu_stack = nn.Sequential(*stack_layers)

    def forward(self, x):
        """
        Flattens every non-batch dimension of `x` and runs the dense stack.
        Returns a B x 1 tensor.
        """
        return self.linear_relu_stack(self.flatten(x))

    def loss(self, pred_values, true_values):
        """
        Returns the mean-squared error between `pred_values` and
        `true_values`, averaged over every element, as a scalar tensor.
        """
        residual = true_values - pred_values
        return torch.square(residual).mean()

class NeuralNetwork_X(nn.Module):
    """
    Time-conditioned MLP. A B x D1 x D2 input is flattened and embedded, the
    time is embedded with fixed random sin/cos features, both embeddings are
    concatenated and mapped back to a tensor of the input's shape.
    """

    def __init__(self, input_size, scale = 1.0, t_limit=1.0, time_embed_std=30,
        embed_size=256):
        """
        Arguments:
            `input_size`: number of features per sample after flattening
                (D1 * D2 of the inputs given to `forward`)
            `scale`: constant the network output is multiplied by
            `t_limit`: value the times are divided by before being embedded
            `time_embed_std`: standard deviation of the fixed random
                frequencies used for the time embedding
            `embed_size`: size E of the input and time embeddings; it needs
                to be even, since E // 2 sines and E // 2 cosines are
                concatenated and fed to a layer expecting E features
        """
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(embed_size*2, 200),
            nn.ReLU(),
            nn.Linear(200, 200),
            nn.ReLU(),
            nn.Linear(200, input_size),
        )
        self.t_limit = t_limit
        self.scale = scale

        # Fixed (non-trainable) random frequencies for the time embedding
        self.time_embed_rand_weights = torch.nn.Parameter(
            torch.randn(embed_size // 2) * time_embed_std,
            requires_grad=False
        )
        # Dense layers to generate time embeddings
        self.time_dense_layers = torch.nn.Sequential(
                torch.nn.Linear(embed_size, embed_size),
                torch.nn.Sigmoid(),
                torch.nn.Linear(embed_size, embed_size)
            )
        # Dense layers to embed the flattened input
        self.linear_dense_layers = torch.nn.Sequential(
                torch.nn.Linear(input_size, embed_size),
                torch.nn.Sigmoid(),
                torch.nn.Linear(embed_size, embed_size)
            )

    def swish(self, x):
        """
        Swish activation, x * sigmoid(x). Defined as a method rather than a
        lambda stored on the instance, because a lambda attribute makes the
        module impossible to pickle.
        """
        return x * torch.sigmoid(x)

    def forward(self, x, t ):
        """
        Arguments:
            `x`: a B x D1 x D2 tensor with D1 * D2 equal to `input_size`
            `t`: a B-tensor of times
        Returns a B x D1 x D2 tensor (the network output times `scale`).
        """
        time_embed_args = (t[:, None] / self.t_limit) * (2 * np.pi) * \
            self.time_embed_rand_weights[None, :]
        # Shape: B x (E / 2)
        time_embed = self.swish(
            torch.cat([
                torch.sin(time_embed_args), torch.cos(time_embed_args)
            ], dim=1)
        )  # Shape: B x E
        time_embed = self.time_dense_layers(time_embed)

        # Remember the trailing dimensions so the output can be un-flattened
        original_size1 = x.shape[1]
        original_size2 = x.shape[2]
        x = self.flatten(x)
        x = self.linear_dense_layers(x)  # Shape: B x E

        xt = torch.cat([x, time_embed], dim=1)  # Shape: B x 2E

        final_value = self.linear_relu_stack(xt)
        return torch.reshape(
            final_value * self.scale,
            (x.shape[0], original_size1, original_size2)
        )

    def loss(self, pred_values, true_values):
        """
        Returns the mean-squared error between `pred_values` and
        `true_values`, averaged over every element, as a scalar tensor.
        (This used to be an empty stub that returned None.)
        """
        return torch.mean(torch.square(true_values - pred_values))


class Sim_NeuralNetwork_X(nn.Module):
    """
    Time-conditioned MLP for two-feature inputs. The input is flattened to
    B x 2 and embedded, the time is embedded with fixed random sin/cos
    features, and the concatenation is mapped to a B x 1 x 2 output.
    """

    def __init__(self, t_limit=1.0, time_embed_std=30,
        embed_size=256):
        """
        Arguments:
            `t_limit`: value the times are divided by before being embedded
            `time_embed_std`: standard deviation of the fixed random
                frequencies used for the time embedding
            `embed_size`: size E of the input and time embeddings; it needs
                to be even, since E // 2 sines and E // 2 cosines are
                concatenated and fed to a layer expecting E features
        """
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(embed_size*2, 200),
            nn.ReLU(),
            nn.Linear(200, 200),
            nn.ReLU(),
            nn.Linear(200, 1*2),
        )
        self.t_limit = t_limit

        # Fixed (non-trainable) random frequencies for the time embedding
        self.time_embed_rand_weights = torch.nn.Parameter(
            torch.randn(embed_size // 2) * time_embed_std,
            requires_grad=False
        )
        # Dense layers to generate time embeddings
        self.time_dense_layers = torch.nn.Sequential(
                torch.nn.Linear(embed_size, embed_size),
                torch.nn.Sigmoid(),
                torch.nn.Linear(embed_size, embed_size)
            )
        # Dense layers to embed the flattened (two-feature) input
        self.linear_dense_layers = torch.nn.Sequential(
                torch.nn.Linear(1*2, embed_size),
                torch.nn.ReLU(),
                torch.nn.Linear(embed_size, embed_size)
            )

    def swish(self, x):
        """
        Swish activation, x * sigmoid(x). Defined as a method rather than a
        lambda stored on the instance, because a lambda attribute makes the
        module impossible to pickle.
        """
        return x * torch.sigmoid(x)

    def forward(self, x, t ):
        """
        Arguments:
            `x`: a tensor with batch size B and two features per sample once
                flattened (e.g. B x 1 x 2)
            `t`: a B-tensor of times
        Returns a B x 1 x 2 tensor.
        """
        time_embed_args = (t[:, None] / self.t_limit) * (2 * np.pi) * \
            self.time_embed_rand_weights[None, :]
        # Shape: B x (E / 2)
        time_embed = self.swish(
            torch.cat([
                torch.sin(time_embed_args), torch.cos(time_embed_args)
            ], dim=1)
        )  # Shape: B x E
        time_embed = self.time_dense_layers(time_embed)

        x = self.flatten(x)
        x = self.linear_dense_layers(x)  # Shape: B x E

        xt = torch.cat([x, time_embed], dim=1)  # Shape: B x 2E

        final_value = self.linear_relu_stack(xt)
        return torch.reshape(final_value, (x.shape[0], 1 ,2))

    def loss(self, pred_values, true_values, weights=None):
        """
        Computes the loss of the neural network.
        Arguments:
            `pred_values`: a tensor of predictions from the network
            `true_values`: a tensor of true values to predict, broadcastable
                with `pred_values`
            `weights`: if provided, a tensor broadcastable with the squared
                error; the squared error is DIVIDED by it before averaging
        Returns a scalar loss: the (optionally rescaled) squared error
        averaged over every element.
        """
        # Compute loss as MSE
        squared_error = torch.square(true_values - pred_values)
        if weights is not None:
            squared_error = squared_error / weights

        return torch.mean(squared_error)

def train_y(model_y, num_epoch=20,batch_size=32,learning_rate=1e-3, DEVICE = "cuda",
            train_dataloader=None, test_dataloader=None):
    """
    Trains `model_y` with Adam and evaluates it on held-out data after every
    epoch.
    Arguments:
        `model_y`: module that is callable on a batch and has a
            `loss(pred, target)` method returning a scalar tensor
        `num_epoch`: number of passes over the training data
        `batch_size`: unused; batching is decided by the data loaders
        `learning_rate`: Adam learning rate
        `DEVICE`: device every batch is moved to
        `train_dataloader`: iterable of (inputs, labels) training batches; if
            None, the module-level global of the same name is used
        `test_dataloader`: iterable of (inputs, labels) evaluation batches;
            if None, the module-level global of the same name is used
    Returns a tuple (train losses, test losses, batches): the loss of the
    LAST training batch of each epoch, the loss of the LAST evaluation batch
    of each epoch, and the number of training batches in the final epoch.
    Raises NameError if a loader is neither passed nor defined globally, and
    IndexError if a loader yields no batches.
    """
    # Historically the loaders were read from module globals; keep that as a
    # fallback so existing call sites keep working
    if train_dataloader is None:
        try:
            train_dataloader = globals()["train_dataloader"]
        except KeyError:
            raise NameError(
                "name 'train_dataloader' is not defined; pass it as an argument"
            ) from None
    if test_dataloader is None:
        try:
            test_dataloader = globals()["test_dataloader"]
        except KeyError:
            raise NameError(
                "name 'test_dataloader' is not defined; pass it as an argument"
            ) from None

    model_y.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.Adam(model_y.parameters(), lr=learning_rate)
    whole_losses = []
    whole_losses2 = []
    check = 0  # Defined up front so that num_epoch == 0 still returns
    t_iter = tqdm.tqdm(range(num_epoch))

    for _ in t_iter:
        batch_losses = []
        check = 0
        for local_batch, local_labels in train_dataloader:
            check += 1
            inputs, targets = local_batch.to(DEVICE), local_labels.to(DEVICE)
            loss = model_y.loss(model_y(inputs), targets)
            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_y.parameters(), 1)
            optim.step()
            batch_losses.append(loss.item())

        # Evaluation pass (the model is left in training mode, as before)
        with torch.no_grad():
            batch_losses2 = []
            for local_batch, local_labels in test_dataloader:
                inputs, targets = local_batch.to(DEVICE), local_labels.to(DEVICE)
                loss = model_y.loss(model_y(inputs), targets)
                batch_losses2.append(loss.item())

        # Only the last batch of the epoch is reported/recorded
        last_test_loss = np.mean(batch_losses2[-1])
        t_iter.set_description("%f" % last_test_loss)
        whole_losses.append(batch_losses[-1])
        whole_losses2.append(last_test_loss)

    return(whole_losses,whole_losses2,check)




def fine_tune(new_model, model, sde, new_model_y, alpha = 0.0001, num_epochs=100, batch_size =128, learning_rate = 1e-2,  t_limit=1, DEVICE = "cuda", SAVE = "No"):
    """
    Fine-tunes `new_model` (the extra drift term) by back-propagating a
    reward through the whole sampling trajectory.
    Arguments:
        `new_model`: module being optimized; called as new_model(x, t) by
            the sampler
        `model`: score model used by the sampler; it is not optimized here
        `sde`: SDE object handed to the sampler
        `new_model_y`: callable mapping generated samples to rewards
        `alpha`: weight of the penalty term returned by the sampler
        `num_epochs`: number of optimization steps (one sampled batch each)
        `batch_size`: number of trajectories sampled per step
        `learning_rate`: Adam learning rate
        `t_limit`: time at which sampling starts; now forwarded to the
            sampler (a literal 1 used to be passed regardless)
        `DEVICE`: device used by the sampler
        `SAVE`: unused
    Returns three lists with one float per epoch: total loss, negative mean
    reward, and mean penalty (before multiplication by `alpha`).
    """
    # NOTE(review): this puts the score model, not `new_model`, in training
    # mode; kept as in the original
    model.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.Adam(new_model.parameters(), lr=learning_rate)

    batch_losses = []
    batch_losses1 = []
    batch_losses2 = []
    for epoch_num in range(1, num_epochs + 1):
        new_samples, auxi = finetune_generate_continuous_samples(
            new_model, model, sde, num_samples=batch_size, num_steps=50,
            t_start=0.001, t_limit=t_limit, verbose=False, DEVICE=DEVICE
        )
        reward = new_model_y(new_samples)  # Evaluate reward

        loss1 = -torch.mean(reward)
        loss2 = torch.mean(auxi)  # Penalty accumulated along the trajectory
        loss = loss1 + alpha * loss2

        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(new_model.parameters(), 1)
        optim.step()

        batch_losses.append(loss.item())
        batch_losses1.append(loss1.item())
        batch_losses2.append(loss2.item())

        # Progress report every tenth epoch
        if epoch_num % 10 == 0:
            print("%f"% loss1.item())
            print("%f"% loss2.item())
    return(batch_losses, batch_losses1, batch_losses2)
            
'''      
def finetune_generate_continuous_samples(
    new_model, model, sde, num_samples=32, num_steps=50, t_start=0.001,
    t_limit=1,  verbose=False, DEVICE = "cuda"
):  
    t = (torch.ones(num_samples) * t_limit).to(DEVICE)
    xt = sde.sample_prior(num_samples, t)

    # Euler-Maruyama
    time_steps = torch.linspace(t_limit, t_start, num_steps).to(DEVICE)
    # (descending order)
    step_size = time_steps[0] - time_steps[1]
    
    # Step backward through time starting at xt, simulating the reverse SDE
    x = xt
    t_iter = tqdm.tqdm(time_steps) if verbose else time_steps
    auxi = 0.0 
    print("masa")
    print(x.shape, 1)
    for time_step in t_iter:
        torch.set_grad_enabled(True)
        t = torch.ones(num_samples).to(DEVICE) * time_step
        f = sde.drift_coef_func(x, t)
        g = sde.diff_coef_func(x, t)
        dw = torch.randn_like(x)
        score = model(x, t)
        
        drift = (f - (torch.square(g) * score)) * step_size
        diff = g * torch.sqrt(step_size) * dw
        
        
        #The following is only Change
        
        mean_x = x - drift + new_model(x, t) * step_size # Subtract because step size is really negative # Add new drfit term 
        ggg = torch.sum(new_model(x, t) * new_model(x, t), dim =  (1,2))  
        print(x.shape, 2)
        print(new_model(x, t).shape, 3)
        print(ggg.shape, 4)

        auxi = auxi + 0.5 * ggg * step_size /(g*g)  # Entropy term 
        x = mean_x + diff 
        
    return mean_x,auxi  # Last step: don't include the diffusion/randomized term
'''

def fine_tune_random(new_model, model, sde, new_model_y, stop_choice = 40, num_steps = 50, alpha = 0.0001, num_epochs=100, batch_size =128, learning_rate = 1e-2,  t_limit=1, DEVICE = "cuda"):
    """
    Fine-tunes `new_model` while back-propagating only through the sampling
    steps whose index is larger than `stop_choice` (truncated
    back-propagation with a fixed cut point).
    Arguments:
        `new_model`: module being optimized; called as new_model(x, t) by
            the sampler
        `model`: score model used by the sampler; it is not optimized here
        `sde`: SDE object handed to the sampler
        `new_model_y`: callable mapping generated samples to rewards
        `stop_choice`: last sampling-step index simulated without gradients
        `num_steps`: number of sampling steps
        `alpha`: weight of the penalty term returned by the sampler
        `num_epochs`: number of optimization steps (one sampled batch each)
        `batch_size`: number of trajectories sampled per step
        `learning_rate`: Adam learning rate
        `t_limit`: time at which sampling starts; now forwarded to the
            sampler (a literal 1 used to be passed regardless)
        `DEVICE`: device used by the sampler
    Returns three lists with one float per epoch: total loss, negative mean
    reward, and mean penalty (before multiplication by `alpha`).
    """
    # NOTE(review): this puts the score model, not `new_model`, in training
    # mode; kept as in the original
    model.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.Adam(new_model.parameters(), lr=learning_rate)

    batch_losses = []
    batch_losses1 = []
    batch_losses2 = []
    for epoch_num in range(1, num_epochs + 1):
        new_samples, auxi = finetune_generate_continuous_samples_random(
            new_model, model, sde, num_samples=batch_size,
            stop_choice=stop_choice, num_steps=num_steps, t_start=0.001,
            t_limit=t_limit, verbose=False, DEVICE=DEVICE
        )
        reward = new_model_y(new_samples)  # Evaluate reward

        loss1 = -torch.mean(reward)
        loss2 = torch.mean(auxi)  # Penalty accumulated along the trajectory
        loss = loss1 + alpha * loss2

        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(new_model.parameters(), 1)
        optim.step()

        batch_losses.append(loss.item())
        batch_losses1.append(loss1.item())
        batch_losses2.append(loss2.item())

        # Progress report every tenth epoch
        if epoch_num % 10 == 0:
            print("%f"% loss1.item())
            print("%f"% loss2.item())
    return(batch_losses, batch_losses1, batch_losses2)
            
def fine_tune_random2(new_model, model, sde, new_model_y, alpha = 0.0001, num_steps = 50, num_epochs=100, batch_size =128, learning_rate = 1e-2,  t_limit=1, DEVICE = "cuda"):
    """
    Fine-tunes `new_model` with truncated back-propagation whose cut point
    is redrawn uniformly at random every epoch.
    Arguments:
        `new_model`: module being optimized; called as new_model(x, t) by
            the sampler
        `model`: score model used by the sampler; it is not optimized here
        `sde`: SDE object handed to the sampler
        `new_model_y`: callable mapping generated samples to rewards
        `alpha`: weight of the penalty term returned by the sampler
        `num_steps`: number of sampling steps
        `num_epochs`: number of optimization steps (one sampled batch each)
        `batch_size`: number of trajectories sampled per step
        `learning_rate`: Adam learning rate
        `t_limit`: time at which sampling starts; now forwarded to the
            sampler (a literal 1 used to be passed regardless)
        `DEVICE`: device used by the sampler
    Returns three lists with one float per epoch: total loss, negative mean
    reward, and mean penalty (before multiplication by `alpha`).
    """
    # NOTE(review): this puts the score model, not `new_model`, in training
    # mode; kept as in the original
    model.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.Adam(new_model.parameters(), lr=learning_rate)

    batch_losses = []
    batch_losses1 = []
    batch_losses2 = []
    for epoch_num in range(1, num_epochs + 1):
        # Drawn from {0, ..., num_steps - 2}, so at least the final sampling
        # step is always simulated with gradients enabled
        choice = np.random.randint(0,num_steps-1)
        new_samples, auxi = finetune_generate_continuous_samples_random(
            new_model, model, sde, stop_choice=choice,
            num_samples=batch_size, num_steps=num_steps, t_start=0.001,
            t_limit=t_limit, verbose=False, DEVICE=DEVICE
        )
        reward = new_model_y(new_samples)  # Evaluate reward

        loss1 = -torch.mean(reward)
        loss2 = torch.mean(auxi)  # Penalty accumulated along the trajectory
        loss = loss1 + alpha * loss2

        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(new_model.parameters(), 1)
        optim.step()

        batch_losses.append(loss.item())
        batch_losses1.append(loss1.item())
        batch_losses2.append(loss2.item())

        # Progress report every tenth epoch
        if epoch_num % 10 == 0:
            print("%f"% loss1.item())
            print("%f"% loss2.item())
    return(batch_losses, batch_losses1, batch_losses2)
       

def finetune_generate_continuous_samples(
    new_model, model, sde, num_samples=32, num_steps=50, t_start=0.001,
    t_limit=1,  verbose=False, DEVICE = "cuda"
):
    """
    Simulates the reverse SDE with Euler-Maruyama, adding `new_model` as an
    extra drift term, with gradients enabled for every step.
    Arguments:
        `new_model`: callable (x, t) -> tensor shaped like x; the extra drift
        `model`: callable (x, t) -> score, shaped like x
        `sde`: object providing `sample_prior(num_samples, t)`,
            `drift_coef_func(x, t)` and `diff_coef_func(x, t)`
        `num_samples`: number B of trajectories
        `num_steps`: number of time steps; must be at least 2, because the
            step size is taken as the gap between the first two
        `t_start`: final (smallest) time
        `t_limit`: initial (largest) time
        `verbose`: if True, wrap the time loop in a tqdm progress bar
        `DEVICE`: device the time tensors are created on
    Returns the pair (mean_x, auxi): the last update WITHOUT its noise term,
    and the accumulated penalty 0.5 * ||new_model||^2 * dt / g^2, which has
    one entry per trajectory when the diffusion coefficient has B elements.
    Note that this turns gradient computation on globally.
    """
    t = (torch.ones(num_samples) * t_limit).to(DEVICE)
    xt = sde.sample_prior(num_samples, t)

    # Euler-Maruyama
    time_steps = torch.linspace(t_limit, t_start, num_steps).to(DEVICE)
    # (descending order)
    step_size = time_steps[0] - time_steps[1]

    # Step backward through time starting at xt, simulating the reverse SDE
    x = xt
    t_iter = tqdm.tqdm(time_steps) if verbose else time_steps
    auxi = 0.0
    torch.set_grad_enabled(True)
    for time_step in t_iter:
        t = torch.ones(num_samples).to(DEVICE) * time_step
        f = sde.drift_coef_func(x, t)
        g = sde.diff_coef_func(x, t)
        dw = torch.randn_like(x)
        score = model(x, t)
        # Evaluate the extra drift once and reuse it; it used to be
        # recomputed three times per step, tripling the autograd graph
        control = new_model(x, t)

        drift = (f - (torch.square(g) * score)) * step_size
        diff = g * torch.sqrt(step_size) * dw

        # Subtract because step size is really negative; add new drift term
        mean_x = x - drift + control * step_size

        sum_entro = torch.sum(control * control, dim=(1, 2))
        # assumes g has one element per sample (e.g. B x 1 x 1) - TODO confirm
        g_fla = g.flatten()
        auxi = auxi + 0.5 * sum_entro * step_size / (g_fla * g_fla)
        x = mean_x + diff

    return mean_x,auxi  # Last step: don't include the diffusion/randomized term


def finetune_generate_continuous_samples_random(
    new_model, model, sde, num_samples=32, num_steps=50, t_start=0.001, stop_choice = 25,
    t_limit=1,  verbose=False, DEVICE = "cuda"
):
    """
    Simulates the reverse SDE with Euler-Maruyama, adding `new_model` as an
    extra drift term. Steps with index <= `stop_choice` run without
    gradients; later steps run with gradients (truncated back-propagation).
    Arguments:
        `new_model`: callable (x, t) -> tensor shaped like x; the extra drift
        `model`: callable (x, t) -> score, shaped like x
        `sde`: object providing `sample_prior(num_samples, t)`,
            `drift_coef_func(x, t)` and `diff_coef_func(x, t)`
        `num_samples`: number B of trajectories
        `num_steps`: number of time steps; must be at least 2, because the
            step size is taken as the gap between the first two
        `t_start`: final (smallest) time
        `stop_choice`: last step index simulated without gradients
        `t_limit`: initial (largest) time
        `verbose`: if True, wrap the time loop in a tqdm progress bar
        `DEVICE`: device the time tensors are created on
    Returns the pair (mean_x, auxi): the last update WITHOUT its noise term,
    and the accumulated penalty 0.5 * ||new_model||^2 * dt / g^2, which has
    one entry per trajectory when the diffusion coefficient has B elements.
    The caller's gradient mode is restored after every step.
    """
    t = (torch.ones(num_samples) * t_limit).to(DEVICE)
    xt = sde.sample_prior(num_samples, t)

    # Euler-Maruyama
    time_steps = torch.linspace(t_limit, t_start, num_steps).to(DEVICE)
    # (descending order)
    step_size = time_steps[0] - time_steps[1]

    # Step backward through time starting at xt, simulating the reverse SDE
    x = xt
    t_iter = tqdm.tqdm(time_steps) if verbose else time_steps
    auxi = 0.0
    for step_index, time_step in enumerate(t_iter):
        # The frozen and the differentiable phases differ only in whether
        # autograd records the step; the context manager also keeps the
        # global grad mode from being left disabled when every step is frozen
        with torch.set_grad_enabled(step_index > stop_choice):
            t = torch.ones(num_samples).to(DEVICE) * time_step
            f = sde.drift_coef_func(x, t)
            g = sde.diff_coef_func(x, t)
            dw = torch.randn_like(x)
            score = model(x, t)
            # Evaluate the extra drift once per step and reuse it
            control = new_model(x, t)

            drift = (f - (torch.square(g) * score)) * step_size
            diff = g * torch.sqrt(step_size) * dw

            # Subtract because step size is really negative; add new drift term
            mean_x = x - drift + control * step_size

            sum_entro = torch.sum(control * control, dim=(1, 2))
            # Flatten g in BOTH phases: dividing the per-sample sum by an
            # unflattened g*g broadcast the penalty to a B x 1 x B tensor
            # assumes g has one element per sample (e.g. B x 1 x 1) - TODO confirm
            g_fla = g.flatten()
            auxi = auxi + 0.5 * sum_entro * step_size / (g_fla * g_fla)
            x = mean_x + diff

    return mean_x,auxi  # Last step: don't include the diffusion/randomized term


def finetune_generate_continuous_samples_static(
    new_model, model, sde,  num_samples=64, num_steps=50, t_start=0.001,
    t_limit=1,  verbose=False, DEVICE = "cuda"
):
    """
    Gradient-free variant of the fine-tuning sampler: simulates the reverse
    SDE with Euler-Maruyama, adding `new_model` as an extra drift term,
    entirely under `torch.no_grad()`.
    Arguments:
        `new_model`: callable (x, t) -> tensor shaped like x; the extra drift
        `model`: callable (x, t) -> score, shaped like x
        `sde`: object providing `sample_prior(num_samples, t)`,
            `drift_coef_func(x, t)` and `diff_coef_func(x, t)`
        `num_samples`: number B of trajectories
        `num_steps`: number of time steps; must be at least 2, because the
            step size is taken as the gap between the first two
        `t_start`: final (smallest) time
        `t_limit`: initial (largest) time
        `verbose`: if True, wrap the time loop in a tqdm progress bar
        `DEVICE`: device the time tensors are created on
    Returns the pair (mean_x, auxi): the last update WITHOUT its noise term,
    and the accumulated penalty 0.5 * ||new_model||^2 * dt / g^2, which has
    one entry per trajectory when the diffusion coefficient has B elements.
    """
    with torch.no_grad():
        t = (torch.ones(num_samples) * t_limit).to(DEVICE)
        xt = sde.sample_prior(num_samples, t)

        # Euler-Maruyama
        time_steps = torch.linspace(t_limit, t_start, num_steps).to(DEVICE)
        # (descending order)
        step_size = time_steps[0] - time_steps[1]

        # Step backward through time starting at xt, simulating the reverse SDE
        x = xt
        t_iter = tqdm.tqdm(time_steps) if verbose else time_steps
        auxi = 0.0
        for time_step in t_iter:
            t = torch.ones(num_samples).to(DEVICE) * time_step
            f = sde.drift_coef_func(x, t)
            g = sde.diff_coef_func(x, t)
            dw = torch.randn_like(x)
            score = model(x, t)
            # One forward pass per step instead of three
            control = new_model(x, t)

            drift = (f - (torch.square(g) * score)) * step_size
            diff = g * torch.sqrt(step_size) * dw

            # Subtract because step size is really negative; add new drift term
            mean_x = x - drift + control * step_size

            sum_entro = torch.sum(control * control, dim=(1, 2))
            # assumes g has one element per sample (e.g. B x 1 x 1) - TODO confirm
            g_fla = g.flatten()
            auxi = auxi + 0.5 * sum_entro * step_size / (g_fla * g_fla)
            x = mean_x + diff
    return mean_x, auxi # Last step: don't include the diffusion/randomized term
        


def fine_tune_initial(new_model,  sde, alpha = 0.0001, num_epochs=100, learning_rate = 1e-2,  t_limit=1, DEVICE = "cuda"):
    """
    Optimizes `new_model` as the drift of a noise-driven process started at
    zero, maximizing the squared reward of the final samples minus a
    penalty on the drift.
    Relies on two names that must exist at module level when this runs:
    `initial_model_y` (callable mapping samples to rewards) and `plt`
    (matplotlib.pyplot, used for the periodic histogram).
    Arguments:
        `new_model`: module being optimized; called as new_model(x, t)
        `sde`: SDE object; the sampler uses it only to obtain the sample
            shape
        `alpha`: weight of the drift penalty (applied inside the sampler)
        `num_epochs`: number of optimization steps (32 trajectories each)
        `learning_rate`: Adam learning rate
        `t_limit`: initial time; now forwarded to the sampler (a literal 1
            used to be passed regardless)
        `DEVICE`: device used by the sampler
    Returns three lists with one float per epoch: total loss, negative sum
    of squared rewards, and the penalty divided by `alpha`.
    """
    # The original called `model.train()`, but no `model` is in scope here;
    # the network being trained is `new_model`
    new_model.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.Adam(new_model.parameters(), lr=learning_rate)

    batch_losses = []
    batch_losses1 = []
    batch_losses2 = []
    for epoch_num in range(1, num_epochs + 1):
        new_samples, auxi = finetune_initial_cosntinuou_samples(
            new_model, sde, alpha, num_samples=32, num_steps=50,
            t_start=0.001, t_limit=t_limit, verbose=False, DEVICE=DEVICE
        )
        reward = initial_model_y(new_samples)

        loss1 = -torch.sum(torch.square(reward))
        loss2 = torch.sum(auxi)  # Already scaled by alpha inside the sampler
        loss = loss1 + loss2
        loss_val = loss1.item()

        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(new_model.parameters(), 1)
        optim.step()

        unscaled_penalty = loss2.item() / alpha
        batch_losses.append(loss.item())
        batch_losses1.append(loss1.item())
        batch_losses2.append(unscaled_penalty)

        # Progress report and reward histogram every tenth epoch
        if epoch_num % 10 == 0:
            print("%f"% loss_val )
            # `"%f" % a / b` formats first and then divides the string,
            # which raises TypeError; divide before formatting
            print("%f" % unscaled_penalty)
            plt.xlim((-0.0,1.1))
            plt.hist(initial_model_y(new_samples).cpu().detach().numpy().flatten() )
            plt.show()
    return(batch_losses, batch_losses1, batch_losses2)
            
            

        
        ########epoch_loss = np.mean(batch_losses)
        #######print("Epoch %d average Loss: %.2f" % (epoch_num + 1, epoch_loss))

        ####_run.log_scalar("train_epoch_loss", epoch_loss)
        #####_run.log_scalar("train_batch_losses", batch_losses)
    

def finetune_initial_cosntinuou_samples(
    new_model, sde, alpha,  num_samples=32, num_steps=100, t_start=0.001,
    t_limit=1,  verbose=False, DEVICE = "cuda"
):
    """
    Simulates, with gradients enabled, a process that starts at zero and at
    every step adds Gaussian noise scaled by sqrt(step size) plus the raw
    output of `new_model`.
    Arguments:
        `new_model`: callable (x, t) -> tensor shaped like x
        `sde`: object providing `sample_prior(num_samples, t)`; the prior
            sample is only used for its shape/device (it is zeroed out)
        `alpha`: weight applied to the accumulated penalty
        `num_samples`: number B of trajectories
        `num_steps`: number of time steps; must be at least 2, because the
            step size is taken as the gap between the first two
        `t_start`: final (smallest) time
        `t_limit`: initial (largest) time
        `verbose`: if True, wrap the time loop in a tqdm progress bar
        `DEVICE`: device the time tensors are created on
    Returns the pair (x, auxi): the final state and, per trajectory, the sum
    over steps of 0.5 * alpha * ||new_model||^2 / step size.
    Note that this turns gradient computation on globally.
    """
    t = (torch.ones(num_samples) * t_limit).to(DEVICE)
    xt = sde.sample_prior(num_samples, t)
    x = xt * 0.0

    # Euler-Maruyama
    time_steps = torch.linspace(t_limit, t_start, num_steps).to(DEVICE)
    # (descending order)
    step_size = time_steps[0] - time_steps[1]
    noise_scale = torch.sqrt(step_size)

    t_iter = tqdm.tqdm(time_steps) if verbose else time_steps
    auxi = 0.0
    torch.set_grad_enabled(True)
    for time_step in t_iter:
        t = torch.ones(num_samples).to(DEVICE) * time_step
        dw = torch.randn_like(x)
        # Evaluate the drift once per step and reuse it (it used to be
        # recomputed three times)
        control = new_model(x, t)
        # The denominator was sqrt(step) * sqrt(step), i.e. the step size
        auxi = auxi + 0.5 * alpha * torch.sum(control * control, dim=(1, 2)) / step_size
        x = x + noise_scale * dw + control

    return x ,auxi

def finetune_initial_cosntinuou_samples_static(
    new_model, sde, alpha,  num_samples=32, num_steps=100, t_start=0.001,
    t_limit=1,  verbose=False, initial_option = False, DEVICE = "cuda"
):
    """
    Gradient-free simulation of a process that starts at zero and at every
    step adds Gaussian noise scaled by sqrt(step size) and, unless
    `initial_option` is set, the raw output of `new_model`.
    Arguments:
        `new_model`: callable (x, t) -> tensor shaped like x
        `sde`: object providing `sample_prior(num_samples, t)`; the prior
            sample is only used for its shape/device (it is zeroed out)
        `alpha`: weight applied to the accumulated penalty
        `num_samples`: number B of trajectories
        `num_steps`: number of time steps; must be at least 2, because the
            step size is taken as the gap between the first two
        `t_start`: final (smallest) time
        `t_limit`: initial (largest) time
        `verbose`: if True, wrap the time loop in a tqdm progress bar
        `initial_option`: if truthy, simulate pure noise and never call
            `new_model`
        `DEVICE`: device the time tensors are created on
    Returns the pair (x, auxi): the final state and, per trajectory, the sum
    over steps of 0.5 * alpha * ||new_model||^2 / step size (the float 0.0
    when `initial_option` is truthy).
    """
    use_control = not initial_option
    with torch.no_grad():
        t = (torch.ones(num_samples) * t_limit).to(DEVICE)
        xt = sde.sample_prior(num_samples, t)
        x = xt * 0.0

        # Euler-Maruyama
        time_steps = torch.linspace(t_limit, t_start, num_steps).to(DEVICE)
        # (descending order)
        step_size = time_steps[0] - time_steps[1]
        noise_scale = torch.sqrt(step_size)

        t_iter = tqdm.tqdm(time_steps) if verbose else time_steps
        auxi = 0.0
        for time_step in t_iter:
            t = torch.ones(num_samples).to(DEVICE) * time_step
            dw = torch.randn_like(x)
            diff = noise_scale * dw
            if use_control:
                # One forward pass per step instead of three
                control = new_model(x, t)
                # The denominator was sqrt(step) * sqrt(step), i.e. the step size
                auxi = auxi + 0.5 * alpha * torch.sum(control * control, dim=(1, 2)) / step_size
                diff = diff + control
            x = x + diff

    return x ,auxi


def finetune_PPO(new_model, model, sde, new_model_y, alpha , num_epochs=20, learning_rate = 1e-2, num_samples =1024, t_start=0.001, num_steps = 50, t_limit=1, DEVICE = "cuda"):
    """
    Fine-tunes `new_model` with a PPO-style clipped surrogate objective.
    Every epoch, trajectories are rolled out without gradients; then, for
    every time step, the likelihood ratio between the current drift and the
    drift used during the rollout is formed from the stored noise, and the
    clipped reward-weighted ratio is maximized together with a quadratic
    penalty on the drift.
    Arguments:
        `new_model`: module being optimized; called as new_model(x, t)
        `model`: score model used for the rollout; it is not optimized here
        `sde`: SDE object handed to the rollout sampler
        `new_model_y`: callable mapping final samples to rewards
        `alpha`: weight of the quadratic drift penalty
        `num_epochs`: number of optimization steps (one rollout each)
        `learning_rate`: Adam learning rate
        `num_samples`: number B of trajectories per rollout
        `t_start`: final (smallest) time
        `num_steps`: number of time steps
        `t_limit`: initial (largest) time
        `DEVICE`: device the tensors are moved to
    Returns two lists with one entry per epoch: the mean reward and the
    (unweighted) penalty term, each as a NumPy scalar array.
    """
    # NOTE(review): this puts the score model, not `new_model`, in training
    # mode; kept as in the original
    model.train()
    with torch.no_grad():
        time_steps = torch.linspace(t_limit, t_start, num_steps).to(DEVICE)

    torch.set_grad_enabled(True)
    optim = torch.optim.Adam(new_model.parameters(), lr=learning_rate)

    batch_losses = []
    batch_losses2 = []

    for epoch_num in range(1, num_epochs + 1):
        # Roll out trajectories and cache everything the surrogate needs
        with torch.no_grad():
            new_whole_samples, new_noise, new_diffusion, last_x, auxi = finetune_generate_continuous_samples_static_whole(new_model, model, sde, num_samples, num_steps, t_start,t_limit, verbose=False, DEVICE = DEVICE)

            new_whole_samples = new_whole_samples.to(DEVICE)
            new_noise = new_noise.to(DEVICE)
            new_diffusion = new_diffusion.to(DEVICE)
            last_x = last_x.to(DEVICE)
            auxi = auxi.to(DEVICE)

            reward = new_model_y(last_x)
            # Drift of the rollout ("old") policy at every visited state
            old_terms = []
            for step_index in range(num_steps):
                ttt = torch.ones(num_samples).to(DEVICE) * time_steps[step_index]
                old_terms.append(new_model(new_whole_samples[step_index], ttt).to(DEVICE))

        torch.set_grad_enabled(True)

        loss = 0
        penalty = 0.0

        for step_index in range(num_steps):
            ttt = torch.ones(num_samples).to(DEVICE) * time_steps[step_index]

            # Log-density (up to a constant) of the stored noise under the
            # rollout policy
            intem_old = -torch.square(new_noise[step_index]) * 0.5
            log_old_probability = torch.sum(intem_old, dim = (1,2) )

            # Same transition under the current policy: the drift change,
            # in units of the diffusion, shifts the implied noise
            new_reshape = new_model(new_whole_samples[step_index],ttt)
            intem_new = -torch.square(new_noise[step_index]-(new_reshape - old_terms[step_index])/new_diffusion[step_index]) * 0.5
            log_new_probability = torch.sum(intem_new, dim = (1,2))

            # NOTE(review): scaled by 1/num_steps**2 rather than by the
            # actual step size - confirm this is the intended KL estimate
            penalty = penalty + torch.square(new_reshape/new_diffusion[step_index]) * 1.0/num_steps * 1.0/num_steps

            ratio = torch.exp(log_new_probability-log_old_probability)
            clipped_ratio = torch.clamp(ratio , min = 1.0 - 0.1, max = 1.0 + 0.1)
            # NOTE(review): `ratio` has one entry per sample; presumably
            # `reward` must broadcast with it element-wise - verify the
            # shape returned by `new_model_y`
            loss = loss  - torch.mean(torch.min(reward   * ratio, reward * clipped_ratio )) * 1.0/num_steps

        kl_term = 0.5 * torch.mean(torch.sum(penalty, dim = (1,2)))
        loss = loss + alpha * kl_term

        # Progress report every tenth epoch (before the optimizer step)
        if epoch_num % 10 == 0:
            print("Reward %f"%torch.mean(reward).cpu().detach().numpy())
            print("KL term %f"%torch.mean(kl_term).cpu().detach().numpy())
            print("KL term %f"%torch.mean(auxi).cpu().detach().numpy())

        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(new_model.parameters(), 1)
        optim.step()

        batch_losses.append(torch.mean(reward).cpu().detach().numpy())
        batch_losses2.append(torch.mean(kl_term).cpu().detach().numpy())

    return batch_losses, batch_losses2

    
    

def finetune_generate_continuous_samples_static_whole(
    new_model, model, sde, num_samples=32, num_steps=50, t_start=0.001,
    t_limit=1,  verbose=False, DEVICE = "cuda"
):
    """
    Gradient-free reverse-SDE rollout (Euler-Maruyama with `new_model` as
    an extra drift term) that records the whole trajectory.
    Arguments:
        `new_model`: callable (x, t) -> tensor shaped like x; the extra drift
        `model`: callable (x, t) -> score, shaped like x
        `sde`: object providing `sample_prior(num_samples, t)` (returning a
            B x D1 x D2 tensor), `drift_coef_func(x, t)` and
            `diff_coef_func(x, t)`
        `num_samples`: number B of trajectories
        `num_steps`: number T of time steps; must be at least 2, because
            the step size is taken as the gap between the first two
        `t_start`: final (smallest) time
        `t_limit`: initial (largest) time
        `verbose`: if True, wrap the time loop in a tqdm progress bar
        `DEVICE`: device the time tensors are created on
    Returns (whole_x, whole_noise, whole_diff, x, auxi):
        `whole_x`: T x B x D1 x D2, the state at the START of every step
        `whole_noise`: T x B x D1 x D2, the Gaussian noise drawn per step
        `whole_diff`: T x B x D1 x D2, g * sqrt(step size) per step
        `x`: the final state, INCLUDING the last step's noise
        `auxi`: accumulated penalty 0.5 * ||new_model||^2 * dt / g^2
    The three trajectory buffers are allocated with torch's default device
    and dtype.
    """
    t = (torch.ones(num_samples) * t_limit).to(DEVICE)
    xt = sde.sample_prior(num_samples, t)
    buffer_shape = (num_steps, xt.shape[0], xt.shape[1], xt.shape[2])
    whole_x = torch.ones(buffer_shape)
    whole_diff = torch.ones(buffer_shape)
    whole_noise = torch.ones(buffer_shape)

    with torch.no_grad():
        # Euler-Maruyama
        time_steps = torch.linspace(t_limit, t_start, num_steps).to(DEVICE)
        # (descending order)
        step_size = time_steps[0] - time_steps[1]

        # Step backward through time starting at xt, simulating the reverse SDE
        x = xt
        t_iter = tqdm.tqdm(time_steps) if verbose else time_steps
        auxi = 0.0
        for step_index, time_step in enumerate(t_iter):
            whole_x[step_index] = x

            t = torch.ones(num_samples).to(DEVICE) * time_step
            f = sde.drift_coef_func(x, t)
            g = sde.diff_coef_func(x, t)
            dw = torch.randn_like(x)
            score = model(x, t)
            # One forward pass per step instead of three
            control = new_model(x, t)

            drift = (f - (torch.square(g) * score)) * step_size
            diff = g * torch.sqrt(step_size) * dw
            # Subtract because step size is really negative
            mean_x = x - drift + control * step_size

            # Record what is needed to re-evaluate this transition later
            whole_diff[step_index] = g * torch.sqrt(step_size)
            whole_noise[step_index] = dw

            # assumes g has one element per sample (e.g. B x 1 x 1) - TODO confirm
            g_fla = g.flatten()
            auxi = auxi + 0.5 * torch.sum(control * control, dim=(1, 2)) * step_size / (g_fla * g_fla)
            x = mean_x + diff

    return whole_x, whole_noise, whole_diff, x, auxi




def fine_tune_boot(new_model, model, sde, new_model_y, alpha = 0.0001, num_epochs=100, batch_size =128, learning_rate = 1e-2,  t_limit=1, DEVICE = "cuda", SAVE = "No"):
    """
    Fine-tunes `new_model` against an optimistic reward: the element-wise
    maximum over three reward models.
    Arguments:
        `new_model`: module being optimized; called as new_model(x, t) by
            the sampler
        `model`: score model used by the sampler; it is not optimized here
        `sde`: SDE object handed to the sampler
        `new_model_y`: indexable collection whose entries 0, 1 and 2 are
            callables mapping generated samples to rewards
        `alpha`: weight of the penalty term returned by the sampler
        `num_epochs`: number of optimization steps (one sampled batch each)
        `batch_size`: number of trajectories sampled per step
        `learning_rate`: Adam learning rate
        `t_limit`: time at which sampling starts; now forwarded to the
            sampler (a literal 1 used to be passed regardless)
        `DEVICE`: device used by the sampler
        `SAVE`: unused
    Returns three lists with one float per epoch: total loss, negative mean
    optimistic reward, and mean penalty (before multiplication by `alpha`).
    """
    # NOTE(review): this puts the score model, not `new_model`, in training
    # mode; kept as in the original
    model.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.Adam(new_model.parameters(), lr=learning_rate)

    batch_losses = []
    batch_losses1 = []
    batch_losses2 = []
    for epoch_num in range(1, num_epochs + 1):
        new_samples, auxi = finetune_generate_continuous_samples(
            new_model, model, sde, num_samples=batch_size, num_steps=50,
            t_start=0.001, t_limit=t_limit, verbose=False, DEVICE=DEVICE
        )
        # Evaluate the reward under each of the three reward models
        reward1 = new_model_y[0](new_samples)
        reward2 = new_model_y[1](new_samples)
        reward3 = new_model_y[2](new_samples)
        best_reward = torch.max(torch.max(reward1, reward2), reward3)

        loss1 = -torch.mean(best_reward)
        loss2 = torch.mean(auxi)  # Penalty accumulated along the trajectory
        loss = loss1 + alpha * loss2

        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(new_model.parameters(), 1)
        optim.step()

        batch_losses.append(loss.item())
        batch_losses1.append(loss1.item())
        batch_losses2.append(loss2.item())

        # Progress report every tenth epoch
        if epoch_num % 10 == 0:
            print("%f"% loss1.item())
            print("%f"% loss2.item())
    return(batch_losses, batch_losses1, batch_losses2)



def fine_tune_UCB(new_model, model, sde, new_model_y, alpha = 0.0001, beta = 0.01, num_epochs=100, batch_size =128, learning_rate = 1e-2,  t_limit=1, DEVICE = "cuda", SAVE = "No"):
    """
    Fine-tunes `new_model` against the mean reward plus `beta` times an
    exploration bonus supplied by the reward model.
    Arguments:
        `new_model`: module being optimized; called as new_model(x, t) by
            the sampler
        `model`: score model used by the sampler; it is not optimized here
        `sde`: SDE object handed to the sampler
        `new_model_y`: callable mapping generated samples to rewards, which
            also provides a `bonus(samples)` method
        `alpha`: weight of the penalty term returned by the sampler
        `beta`: weight of the bonus term
        `num_epochs`: number of optimization steps (one sampled batch each)
        `batch_size`: number of trajectories sampled per step
        `learning_rate`: Adam learning rate
        `t_limit`: time at which sampling starts; now forwarded to the
            sampler (a literal 1 used to be passed regardless)
        `DEVICE`: device used by the sampler
        `SAVE`: unused
    Returns three lists with one float per epoch: total loss, negative
    (mean reward + beta * bonus), and mean penalty (before multiplication
    by `alpha`).
    """
    # NOTE(review): this puts the score model, not `new_model`, in training
    # mode; kept as in the original
    model.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.Adam(new_model.parameters(), lr=learning_rate)

    batch_losses = []
    batch_losses1 = []
    batch_losses2 = []
    for epoch_num in range(1, num_epochs + 1):
        new_samples, auxi = finetune_generate_continuous_samples(
            new_model, model, sde, num_samples=batch_size, num_steps=50,
            t_start=0.001, t_limit=t_limit, verbose=False, DEVICE=DEVICE
        )
        reward = new_model_y(new_samples)  # Evaluate reward

        # NOTE(review): `bonus` presumably returns a scalar; the later
        # `.item()` and `backward()` calls need `loss1` to be a scalar
        loss1 = -torch.mean(reward)- beta * new_model_y.bonus(new_samples)
        loss2 = torch.mean(auxi)  # Penalty accumulated along the trajectory
        loss = loss1 + alpha * loss2

        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(new_model.parameters(), 1)
        optim.step()

        batch_losses.append(loss.item())
        batch_losses1.append(loss1.item())
        batch_losses2.append(loss2.item())

        # Progress report every tenth epoch
        if epoch_num % 10 == 0:
            print("%f"% loss1.item())
            print("%f"% loss2.item())
    return(batch_losses, batch_losses1, batch_losses2)