import torch
import numpy as np

from .loss import weighted_KSD
import time


# Code inspired by https://github.com/MatthewAlexanderFisher/MTKSD

def train_weighted_KSD(sample, learner, weights, gamma=.1, n_steps=10000, save_out=False, m=200,
                       print_loss=False, lr=1e-3, batch=100, loss_fn=None):
    """Fit ``learner`` by minimising a weighted KSD loss with plain SGD.

    Each step evaluates ``loss_fn(sample, learner.score, weights, gamma)``
    (``weighted_KSD`` by default), back-propagates it and takes one SGD step
    on ``learner.parameters()``.

    Args:
        sample: Passed unchanged to the loss on every step. The full sample is
            used each step; no mini-batching is performed.
        learner: Object exposing ``parameters()`` (given to the optimizer) and
            a ``score`` callable handed to the loss. Presumably a
            ``torch.nn.Module`` -- TODO confirm.
        weights: Passed unchanged to the loss. Presumably one weight per
            sample point -- TODO confirm against ``weighted_KSD``.
        gamma: Forwarded unchanged to the loss. Its meaning is defined by
            ``weighted_KSD``.
        n_steps: Number of optimisation steps.
        save_out: If truthy, every ``m`` steps record the elapsed wall-clock
            time in ``timings`` and the 0-based step index in ``iter_num``.
        m: Recording/printing interval in steps. Must be positive.
        print_loss: If truthy, print the current loss every ``m`` steps.
        lr: SGD learning rate.
        batch: Currently unused; kept for signature compatibility.
        loss_fn: Optional replacement for ``weighted_KSD`` with the same call
            signature ``(sample, score, weights, gamma) -> scalar tensor``.

    Returns:
        Tuple ``(loss_vec, timings, iter_num)``:

        * ``loss_vec`` -- tensor of length ``n_steps`` holding the loss
          evaluated at each step, before that step's update.
        * ``timings`` -- NumPy array of length ``n_steps // m``. It is all
          zeros unless ``save_out`` is set.
        * ``iter_num`` -- list of the 0-based step indices at which timings
          were recorded. It is empty unless ``save_out`` is set.
    """
    if loss_fn is None:
        # Resolved at call time so the default is the module-level import.
        loss_fn = weighted_KSD

    loss_vec = torch.zeros(n_steps)
    timings = np.zeros(n_steps // m)
    iter_num = []
    start = time.time()

    optimizer = torch.optim.SGD(learner.parameters(), lr=lr)

    for i in range(n_steps):
        optimizer.zero_grad()
        loss = loss_fn(sample, learner.score, weights, gamma)
        loss.backward()
        optimizer.step()
        # Element assignment already copies the value; clone() is unnecessary.
        loss_vec[i] = loss.detach()

        if (i + 1) % m == 0:
            if save_out:
                timings[i // m] = time.time() - start
                iter_num.append(i)
            if print_loss:
                print(f"step {i + 1}/{n_steps}: loss = {loss_vec[i].item():.6f}")

    return loss_vec, timings, iter_num
