import numpy as np


def weights_error(x, true):
    """Return the mean squared difference between ``exp(true)`` and ``exp(x)``.

    Both arguments are exponentiated element-wise before being compared, so
    they are expected to be on a log scale.

    Args:
        x: Candidate values (array-like or scalar).
        true: Reference values, broadcastable against ``x``.

    Returns:
        The mean of ``(exp(true) - exp(x)) ** 2`` over all elements.
    """
    residual = np.exp(true) - np.exp(x)
    squared_residual = residual ** 2
    return np.mean(squared_residual)


def optimize_easy(
    d,
    lapl_scale,
    log_isweights_beta,
    synth_data,
    comp_mean,
    *,
    n_sim=1000,
    step=0.0001,
    tol=1e-6,
):
    """Search downwards for the threshold on ``comp_mean`` that minimises error.

    Starting from ``comp_mean.mean()``, the threshold is lowered by ``step``
    until the Monte-Carlo estimate of the weight error gets worse. For each
    candidate threshold, ``n_sim`` Laplace noise vectors of length ``d + 1``
    are drawn; the first entry is added as an offset and the remaining ``d``
    entries are applied to ``synth_data`` via a matrix product. Entries whose
    ``comp_mean`` exceeds the threshold then get ``comp_mean`` added, and the
    result is scored against ``log_isweights_beta`` with ``weights_error``.

    Args:
        d: Number of noise coefficients multiplied into ``synth_data``
            (the noise vector has ``d + 1`` entries including the offset).
        lapl_scale: Scale of the zero-centred Laplace noise.
        log_isweights_beta: Reference log-weights the noisy ones are compared
            to (assumed 1-D with one entry per row of ``synth_data`` --
            TODO confirm against callers).
        synth_data: Array whose product with a length-``d`` vector has the
            same shape as ``log_isweights_beta``.
        comp_mean: Per-entry correction, same shape as the log-weights; it is
            added where it exceeds the threshold.
        n_sim: Number of noise draws used to estimate the error at each
            threshold.
        step: Amount by which the threshold is lowered between evaluations.
        tol: Minimum increase in estimated error that counts as "worse".

    Returns:
        The last threshold whose estimated error was not worse than its
        predecessor's. If the threshold falls below every entry of
        ``comp_mean`` without the error increasing, that threshold is
        returned, since lowering it further cannot change the result.
    """
    thr = comp_mean.mean()
    # Start from +inf so the first evaluation can never be mistaken for a
    # deterioration, however large the initial error is.
    new_error = np.inf
    while True:
        old_error = new_error
        # The mask only depends on thr, so build it once per threshold
        # instead of twice per simulation.
        above_thr = comp_mean > thr
        correction = comp_mean[above_thr]
        losses = []
        for _ in range(n_sim):
            lapl_noise_sim = np.random.laplace(loc=0, scale=lapl_scale, size=(d + 1))
            log_isweights_noise = (
                log_isweights_beta + lapl_noise_sim[0] + synth_data @ lapl_noise_sim[1:]
            )
            log_isweights_noise[above_thr] += correction
            losses.append(weights_error(log_isweights_noise, log_isweights_beta))
        new_error = np.mean(losses)
        if new_error > old_error + tol:
            # The latest step made things worse: undo it and stop.
            thr += step
            break
        if np.all(above_thr):
            # Every entry is already above thr; lowering it further leaves
            # the mask unchanged, so the search would otherwise only end by
            # sampling noise (or never, when the noise is deterministic).
            break
        thr -= step

    return thr


# import sklearn
# import torch
# from pytorch_lightning import LightningModule, Trainer
# from torch.utils.data import TensorDataset, DataLoader


# def weights_error(x, true):
#     return ((torch.exp(true) - torch.exp(x)) ** 2).mean()


# class FindOptThr(LightningModule):
#     def __init__(
#         self, input_dim, lapl_scale, log_isweights_beta, synth_data, comp_mean
#     ):
#         super(FindOptThr, self).__init__()

#         self.thr = torch.nn.Parameter(torch.tensor(-0.0013))
#         self.thr.requires_grad = True

#         self.lapl_scale = lapl_scale
#         self.d = input_dim
#         self.log_isweights_beta = torch.from_numpy(log_isweights_beta)
#         self.synth_data = torch.from_numpy(synth_data)
#         self.comp_mean = torch.from_numpy(comp_mean)

#     def forward(self, log_isweights_noise):
#         # log_isweights_noise[self.comp_mean > self.thr] += self.comp_mean[
#         #     self.comp_mean > self.thr
#         # ]
#         return log_isweights_noise.max(self.thr)

#     def training_step(self, batch, batch_idx):
#         losses = []
#         for _ in range(1000):
#             lapl_noice_sim = np.random.laplace(
#                 loc=0, scale=self.lapl_scale, size=(self.d + 1)
#             )
#             log_isweights_noise = (
#                 self.log_isweights_beta
#                 + lapl_noice_sim[0]
#                 + self.synth_data @ lapl_noice_sim[1:]
#             )
#             log_isweights_noise = self.forward(log_isweights_noise)
#             losses.append(weights_error(log_isweights_noise, self.log_isweights_beta))

#         return torch.mean(torch.stack(losses))

#     def configure_optimizers(self):
#         return torch.optim.Adam(self.parameters(), lr=0.001)

#     def fit(self, **kwargs):

#         train_ds = TensorDataset(self.synth_data)

#         train_loader = DataLoader(train_ds, batch_size=128)

#         trainer = Trainer(
#             # auto_select_gpus=True if kwargs["gpus"] > 0 else False,
#             **kwargs,
#         )

#         # Train the model ⚡
#         trainer.fit(self, train_loader)
