import tensorflow as tf
import keras
import numpy as np
import pandas as pd
import time
from MPCclass import MPCNetwork

class MyReduceLR(keras.callbacks.Callback):
    """
    Callback that lowers the learning rate when the monitored metric worsens.

    At the end of every epoch the monitored value is compared with the value
    recorded for the previous epoch (not with the best value seen so far).
    Every epoch whose value is larger than the previous one increments a
    counter; once the counter exceeds `patience` the learning rate is
    multiplied by `factor`, bounded from below by `min_lr`, and the counter is
    reset. Epochs that improve on the previous one do not reset the counter.

    Args:
        monitor (str): Key of the metric to read from `logs`. Defaults to 'val_loss'.
        factor (float): Multiplier applied to the learning rate on a reduction. Defaults to 0.5.
        patience (int): Number of worsening epochs tolerated before a reduction. Defaults to 0.
        min_lr (float): Lower bound for the learning rate. Defaults to 1e-5.
    """

    def __init__(self, monitor='val_loss', factor=0.5, patience=0, min_lr=1e-5):
        """
        Store the configuration; no training state is created here.

        Args:
            monitor (str): Key of the metric to read from `logs`. Defaults to 'val_loss'.
            factor (float): Multiplier applied to the learning rate on a reduction. Defaults to 0.5.
            patience (int): Number of worsening epochs tolerated before a reduction. Defaults to 0.
            min_lr (float): Lower bound for the learning rate. Defaults to 1e-5.
        """
        super().__init__()
        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr

    def on_train_begin(self, logs=None):
        """
        Reset the per-run state and grab the optimizer's learning rate.

        Args:
            logs (dict): Unused.
        """
        self.wait = 0
        # No baseline yet: the first epoch that reports the metric seeds it,
        # whatever its index (training may resume with initial_epoch > 0).
        self.last_record = None
        # NOTE(review): `.numpy()` / `.assign()` are called on this object
        # later, so it is assumed to be a variable, not a schedule - confirm.
        self.lr = self.model.optimizer.learning_rate

    def on_epoch_end(self, epoch, logs=None):
        """
        Log the current learning rate and reduce it if the metric got worse.

        Args:
            epoch (int): The current epoch index.
            logs (dict): Epoch logs; must contain the monitored key. The
                current learning rate is written back under 'lr'.

        Raises:
            KeyError: If `logs` has no entry for the monitored metric.
        """
        logs['lr'] = self.lr.numpy()
        record = logs[self.monitor]
        if self.last_record is not None and self.last_record < record:
            self.wait += 1
            if self.wait > self.patience:
                self.lr.assign(max(self.lr * self.factor, self.min_lr))
                self.wait = 0
        self.last_record = record

class get_callback(keras.callbacks.Callback):
    """
    Base callback that records something either every epoch or every
    `period` batches; subclasses decide what to record in `do_log`.

    NOTE: a class with the same name is defined again further down in this
    module, so at import time the later definition replaces this one.

    Args:
        period (str | int): 'epoch' to record at the end of every epoch, or an
            integer N to record every N batches (and once more when training ends).

    Attributes:
        period (str | int): The recording period.
        log (list): Recorded values; converted to a numpy array when training ends.
        count_batch (int): Number of batches seen so far (only counted in batch mode).
    """

    def __init__(self, period='epoch'):
        super().__init__()
        self.period = period

    def do_log(self, logs):
        """
        Record one entry; must be overridden by subclasses.

        Args:
            logs (dict): Logs forwarded from the Keras hook that triggered the call.
        """
        raise NotImplementedError

    def on_train_begin(self, logs=None):
        """
        Start a fresh log and batch counter.

        Args:
            logs (dict): Unused.
        """
        self.count_batch = 0
        self.log = []

    def on_batch_end(self, batch, logs=None):
        """
        In batch mode, record every `period`-th batch and count the batch.

        Args:
            batch (int): The current batch index (unused; an own counter is kept).
            logs (dict): Logs forwarded to `do_log`.
        """
        if self.period != 'epoch':
            if self.count_batch % self.period == 0:
                self.do_log(logs)
            self.count_batch += 1

    def on_epoch_end(self, epoch, logs=None):
        """
        In epoch mode, record once per epoch.

        Args:
            epoch (int): The current epoch index.
            logs (dict): Logs forwarded to `do_log`.
        """
        if self.period != 'epoch':
            return
        self.do_log(logs)

    def on_train_end(self, logs=None):
        """
        In batch mode record a final entry, then freeze the log as a numpy array.

        Args:
            logs (dict): Logs forwarded to `do_log` in batch mode.
        """
        if self.period != 'epoch':
            self.do_log(logs)
        self.log = np.array(self.log)

class get_callback(keras.callbacks.Callback):
    """
    Base callback for periodic logging during training.

    Subclasses implement `do_log`; this class only decides when it is called.

    Args:
        period (str | int): 'epoch' to log at the end of every epoch, or an
            integer N to log every N batches (plus once when training ends).

    Attributes:
        period (str | int): The period at which to log.
        log (list): Logged values; replaced by a numpy array in `on_train_end`.
        count_batch (int): Counter for the number of batches (batch mode only).

    Methods:
        do_log: Abstract method to be implemented by subclasses for logging.
        on_train_begin: Callback method called at the beginning of training.
        on_batch_end: Callback method called at the end of each batch.
        on_epoch_end: Callback method called at the end of each epoch.
        on_train_end: Callback method called at the end of training.
    """

    def __init__(self, period='epoch'):
        super().__init__()
        self.period = period
        # Created here as well as in on_train_begin so that `log` can be read
        # on a callback that has not been used for training yet.
        self.log = []
        self.count_batch = 0

    def do_log(self, logs):
        """
        Abstract method to be implemented by subclasses for logging.

        Args:
            logs (dict): Dictionary containing the metrics for the current batch/epoch.
        """
        raise NotImplementedError

    def on_train_begin(self, logs=None):
        """
        Callback method called at the beginning of training; resets the log
        and the batch counter.

        Args:
            logs (dict, optional): Unused.
        """
        self.log = []
        self.count_batch = 0

    def on_batch_end(self, batch, logs=None):
        """
        Callback method called at the end of each batch.

        Does nothing in 'epoch' mode; otherwise logs whenever the internal
        batch counter is a multiple of `period` and then increments it.

        Args:
            batch (int): The current batch index (unused).
            logs (dict, optional): Dictionary containing the metrics for the current batch.
        """
        if self.period == 'epoch':
            return
        if not (self.count_batch % self.period):
            self.do_log(logs)
        self.count_batch += 1

    def on_epoch_end(self, epoch, logs=None):
        """
        Callback method called at the end of each epoch; logs in 'epoch' mode only.

        Args:
            epoch (int): The current epoch index.
            logs (dict, optional): Dictionary containing the metrics for the current epoch.
        """
        if self.period == 'epoch':
            self.do_log(logs)

    def on_train_end(self, logs=None):
        """
        Callback method called at the end of training.

        In batch mode one final entry is logged; in both modes `log` is then
        converted to a numpy array.

        Args:
            logs (dict, optional): Logs forwarded to `do_log` in batch mode.
        """
        if self.period != 'epoch':
            self.do_log(logs)
        self.log = np.array(self.log)

        
class get_loss_callback1(get_callback):
    """
    Callback that logs, for a fixed data set, the loss obtained after each
    block of the model when the last loss block is applied to that block's output.

    Args:
        x: Input data.
        y: Target data.
        period: Frequency at which to log the loss (default: 'epoch').

    Attributes:
        x: Input data.
        y: Target data.
    """

    def __init__(self, x, y, period='epoch'):
        super().__init__(period=period)
        self.x = x
        self.y = y

    def do_log(self, logs):
        """
        Append one array holding the mean loss after every block.

        The input is passed through the model's first layer and then through
        `model.blocks` one at a time; after each block the final entry of
        `model.lossblocks` maps the activation to a prediction that is scored
        with `model.loss`.

        Args:
            logs: Dictionary containing the current training metrics (unused).
        """
        final_lossblock = self.model.lossblocks[-1]
        hidden = self.model.layers[0](self.x)
        per_block = []
        for block in self.model.blocks:
            hidden = block(hidden)
            prediction = final_lossblock(hidden)
            per_block.append(tf.reduce_mean(self.model.loss(self.y, prediction)))
        self.log.append(np.array(per_block))

        
class get_loss_callback(get_callback):
    """
    Callback class to log loss and accuracy during training.

    Args:
        period (str | int): 'epoch' or a number of batches between log entries.
            Default is 'epoch'.

    Attributes:
        log (list): A list of `[loss, acc]` pairs.

    Methods:
        do_log(logs): Logs the loss and accuracy values.

    """
    def __init__(self, period='epoch'):
        super().__init__(period=period)
        self.log = []

    def do_log(self, logs):
        """
        Append the `[loss, acc]` pair found in `logs`.

        Values are read from the 'loss' and 'acc' keys; a missing key, or
        `logs` being None (the default of the hooks that forward it), yields
        None for that entry instead of raising.

        Args:
            logs (dict | None): A dictionary containing the loss and accuracy values.

        """
        logs = logs or {}
        self.log.append([logs.get('loss'), logs.get('acc')])
    

class get_memory_callback(get_callback):
    """
    Callback that records peak device memory relative to the memory in use
    when training started.

    Args:
        period (str | int): 'epoch' or a number of batches between log entries.
            Default is 'epoch'.
        device (str): Device name passed to TensorFlow's memory query. Default is 'GPU:0'.

    Attributes:
        device (str): The device on which memory usage is monitored.
        size0 (int): 'current' memory reported at the beginning of training.
        log (list): Peak memory minus `size0`, one entry per log event.

    Methods:
        on_train_begin(logs=None): Takes the baseline and logs the first entry.
        do_log(logs): Logs one further entry.
    """

    def __init__(self, period='epoch', device='GPU:0'):
        super().__init__(period=period)
        self.device = device

    def on_train_begin(self, logs=None):
        """
        Take the baseline reading and start the log with a first entry.

        Args:
            logs (dict): Dictionary of logs (forwarded to the base class).

        Returns:
            None
        """
        super().on_train_begin(logs)
        memory_info = tf.config.experimental.get_memory_info
        self.size0 = memory_info(self.device)['current']
        initial_peak = memory_info(self.device)['peak']
        self.log = [initial_peak - self.size0]

    def do_log(self, logs):
        """
        Append the current peak memory minus the baseline.

        Args:
            logs (dict): Dictionary of logs (unused).

        Returns:
            None
        """
        peak = tf.config.experimental.get_memory_info(self.device)['peak']
        self.log.append(peak - self.size0)
    
class get_time_callback(get_callback):
    """
    Callback that records the time elapsed since the start of training.

    Args:
        period (str | int): 'epoch' or a number of batches between log entries.
            Default is 'epoch'.

    Attributes:
        time0 (float): `time.perf_counter()` reading taken when training begins.
        log (list): Seconds elapsed since `time0`, one entry per log event.

    Methods:
        on_train_begin(logs=None): Restarts the clock.
        do_log(logs): Logs the elapsed time.

    """

    def __init__(self, period='epoch'):
        super().__init__(period=period)
        self.log = []
        self.time0 = time.perf_counter()

    def on_train_begin(self, logs=None):
        """
        Reset the base-class state and restart the clock.

        Args:
            logs (dict): Dictionary of logs (forwarded to the base class).

        """
        super().on_train_begin(logs)
        self.time0 = time.perf_counter()

    def do_log(self, logs):
        """
        Append the number of seconds elapsed since training began.

        Args:
            logs (dict): Dictionary of logs (unused).

        """
        elapsed = time.perf_counter() - self.time0
        self.log.append(elapsed)
    
class TimingCallback(keras.callbacks.Callback):
    """
    Callback for tracking the time taken for each epoch during training.

    Attributes:
        times (list): Duration of each completed epoch, in seconds.
        starttime (float): `time.perf_counter()` reading taken at the start of
            the current epoch.

    Methods:
        on_train_begin: Called at the beginning of training.
        on_epoch_begin: Called at the beginning of each epoch.
        on_epoch_end: Called at the end of each epoch.
    """

    def on_train_begin(self, logs=None):
        """
        Called at the beginning of training.
        Initializes the `times` list.
        """
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        """
        Called at the beginning of each epoch.
        Starts the timer.
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        self.starttime = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        """
        Called at the end of each epoch.
        Calculates the time taken for the epoch and appends it to the `times` list.
        """
        self.times.append(time.perf_counter() - self.starttime)
        
class get_intermedium_loss_callback(get_callback):
    """
    Callback that evaluates the loss at intermediate depths of the model.

    An auxiliary multi-output model is built when training starts; each
    output is one intermediate activation mapped to a prediction. Every log
    event evaluates that auxiliary model on the given data.

    Args:
        x: Input data for evaluation.
        y: Target data for evaluation.
        period: Frequency at which to compute intermediate losses (default: 'epoch').

    Attributes:
        testx: Input data for evaluation.
        testy: Target data for evaluation.
        model1: Auxiliary model used for computing intermediate losses.

    Methods:
        on_train_begin: Builds and compiles `model1`.
        do_log: Compute and log intermediate losses.

    """
    def __init__(self, x, y, period='epoch'):
        super().__init__(period=period)
        self.testx = x
        self.testy = y

    def on_train_begin(self, logs=None):
        """
        Build the auxiliary model used for computing intermediate losses.

        For an `MPCNetwork` every block output is passed through its paired
        loss block; for any other model every layer output except the first
        and the last is passed through the model's last layer.

        Args:
            logs: Dictionary of logs (default: None).

        """
        super().on_train_begin(logs)
        net = self.model
        if isinstance(net, MPCNetwork):
            heads = [loss_block(block.output)
                     for loss_block, block in zip(net.lossblocks, net.blocks)]
        else:
            heads = [net.layers[-1](layer.output) for layer in net.layers[1:-1]]
        self.model1 = keras.Model(net.input, heads)
        self.model1.compile(loss=net.loss)

    def do_log(self, logs):
        """
        Evaluate the auxiliary model on the stored data and append the result.

        Args:
            logs: Dictionary of logs (unused).

        """
        result = self.model1.evaluate(self.testx, self.testy, batch_size=1000, verbose=0)
        self.log.append(result)
        
        
class Select_horizon_Callback(keras.callbacks.Callback):
    """
    Callback class for selecting horizon and adjusting learning rate during training.

    Args:
        dataset (tf.data.Dataset): The dataset used for training.
        period (int): The period at which to perform horizon selection and learning rate adjustment.
        mode (str): The mode of operation. Can be one of 'select', 'test', or 'adjust_lr'.
        init_select (bool): Whether to perform initial horizon selection.
        batch_num (int): The number of batches to use for horizon selection.
        delta (float): The threshold for selecting the horizon.
        df_name (str): The name of the file to save the results to.
        horizon_stride (list): A list of tuples representing different horizons and strides.
        schedule (dict): A dictionary representing a schedule of horizon changes.
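        cost (callable, str or None): Cost specification read by the objective methods: None, the string 'linear', or a callable applied to `self.memory`.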

    Methods:
        get_grad(horizon, stride, xs, y): Calculate the gradients for a given horizon and stride.
        get_horizon_grad(return_df=False): Calculate the gradients for different horizons and strides.
        adjust_learning_rate(horizon): Adjust the learning rate based on the horizon.
        on_train_begin(logs=None): Actions to be performed at the beginning of training.
        on_epoch_begin(epoch, logs=None): Actions to be performed at the beginning of each epoch.
        on_train_batch_begin(batch, logs=None): Actions to be performed at the beginning of each training batch.
    """
    def __init__(self, dataset, period=5, mode='select', init_select=True, batch_num=10,
                 delta=0.5, df_name=None, horizon_stride=None, schedule=None, cost=None):
        """
        Store the configuration of the callback.

        Args:
            dataset: Dataset that is shuffled and sampled for gradient comparisons.
            period (int): Number of epochs between periodic actions.
            mode (str): One of 'select', 'test' or 'adjust_lr'.
            init_select (bool): Whether to also act at the very start of training.
            batch_num (int): Number of batches used per gradient comparison.
            delta (float): Threshold used by `objective1`.
            df_name (str | None): CSV path written in 'test' mode; None disables writing.
            horizon_stride (list | None): (horizon, stride) pairs to compare; a
                default is derived in `on_train_begin` when None.
            schedule (dict | None): Mapping epoch -> horizon; when given it
                replaces the periodic selection in `on_epoch_begin`.
            cost: None, 'linear', or a callable applied to `self.memory` by the
                objective methods.

        Raises:
            ValueError: If `mode` is not one of the supported values.
        """
        super().__init__()
        # Fail here with a clear message; previously an unknown mode left
        # `self.mode` unset and surfaced as an AttributeError further down.
        if mode not in ('select', 'test', 'adjust_lr'):
            raise ValueError(
                f"mode must be 'select', 'test' or 'adjust_lr', got {mode!r}")
        self.init_select = init_select
        self.period = period
        self.dataset = dataset
        self.batch_num = batch_num
        self.delta = delta
        self.schedule = schedule
        self.horizon_stride = horizon_stride
        self.mode = mode
        # Only read in 'test' mode, but always defined so that the attribute exists.
        self.df_name = df_name
        self.cost = cost

    def get_grad(self, horizon, stride, xs, y):
        """
        Collect gradients by sweeping a window of blocks over the model.

        Starting at block 0, `model.train_one_step` is called on the window
        `blocks[start:start + horizon]` (cut off at the last block) together
        with the loss block belonging to the window's last block. The window
        start then advances by `stride`, continuing from the inputs returned
        by the previous call, until a window reaches the end of the blocks.

        Args:
            horizon (int): Number of blocks per window.
            stride (int): Number of blocks the window advances per step.
            xs: Input passed to the first window.
            y: The target data.

        Returns:
            list: The first element of every pair returned by
            `train_one_step`, over all windows in order (presumably the
            gradients of (gradient, weight) pairs - confirm against the model).
        """
        start = 0
        is_first = True
        collected = []
        while True:
            stop = min(self.T, start + horizon)
            is_last = stop >= self.T
            grad_weight, next_xs, _y_pred, _loss = self.model.train_one_step(
                xs, y, self.model.blocks[start:stop],
                lossblock=self.model.lossblocks[stop - 1],
                stride=stride,
                first_run=is_first, last_run=is_last,
                training=False)
            collected = collected + [g for g, _ in grad_weight]
            if is_last:
                return collected
            # Slide the window and continue from the propagated inputs.
            xs = next_xs
            start = start + stride
            is_first = False

    def get_horizon_grad(self, return_df=False):
        """
        Compare the gradients of each (horizon, stride) setting with the
        full-horizon gradient.

        For `batch_num` batches drawn from the shuffled dataset, the gradient
        from `get_grad(self.T, 1, ...)` serves as reference. For every pair
        in `self.horizon_stride` the cosine with the reference and the log of
        the norm ratio are accumulated and then divided by `batch_num`.
        With at least four settings a cubic polynomial is fitted over the
        horizons and evaluated at 1..T; otherwise the raw means are kept per
        horizon. Results are stored in `self.cos_mean` and `self.norm_ratio`
        (pandas Series indexed by horizon).

        Loss blocks are frozen for the duration of the call. The model
        weights and the `trainable` flags are restored afterwards, also when
        an exception interrupts the computation.

        Args:
            return_df (bool): Whether to also build and return a per-batch,
                per-setting DataFrame of statistics.

        Returns:
            pandas.DataFrame | None: The statistics table if `return_df` is
            True, otherwise None.
        """
        model = self.model
        # NOTE(review): the snapshot is taken before the trainable flags are
        # flipped; confirm get_weights/set_weights ordering does not depend on them.
        weights = model.get_weights()
        for l in model.lossblocks:
            l.trainable = False
        try:
            cos_sum = np.zeros(len(self.horizon_stride))
            ln_norm_ratio_sum = np.zeros(len(self.horizon_stride))
            if return_df:
                # Start offsets of each weight group (pre-block layers, then
                # one group per block) inside the flattened gradient vector.
                tweights2 = []
                for l in model.layers[:model.block_start]:
                    tweights2 += l.trainable_weights
                tweights2 = [tweights2] + [b.trainable_weights for b in model.blocks]
                weight_len = [sum(w.shape.num_elements() for w in group) for group in tweights2]
                weight_ind = np.cumsum([0] + weight_len)[:-1]
                df = pd.DataFrame()
            for x, y in self.dataset.shuffle(100).take(self.batch_num):
                model.set_weights(weights)
                tdw = self.get_grad(self.T, 1, x, y)
                tdw = np.concatenate([g.numpy().flatten() for g in tdw])
                ntdw = np.linalg.norm(tdw)
                if return_df:
                    # Per-group norms of the reference gradient.
                    tntdw = np.sqrt(np.add.reduceat(tdw ** 2, weight_ind))
                for i, (h, s) in enumerate(self.horizon_stride):
                    dw = self.get_grad(h, s, x, y)
                    dw = np.concatenate([g.numpy().flatten() for g in dw])
                    ndw = np.linalg.norm(dw)
                    cos_sum[i] += (tdw * dw).sum() / (ntdw + 1e-8) / ndw
                    ln_norm_ratio_sum[i] += np.log(ndw / (ntdw + 1e-8))
                    if return_df:
                        tndw = np.sqrt(np.add.reduceat(dw ** 2, weight_ind))
                        dfi = pd.Series(
                            dict(update_together=model.update_together, update_stride=model.update_stride,
                                 horizon=h, stride=s,
                                 ndw=ndw, proj_tdw=(tdw * dw).sum() / ntdw,
                                 grad_ratio=np.linalg.norm(dw - tdw) / ntdw,
                                 time_norm_ratio=(tntdw * tndw).sum() / ntdw / ndw,
                                 cos_theta=(tdw * dw).sum() / ntdw / ndw,))
                        df = pd.concat([df, dfi.to_frame().T], ignore_index=True, )
            cos_mean = cos_sum / self.batch_num
            if len(self.horizon_stride) >= 4:
                # Enough sample points: interpolate to every horizon 1..T with a cubic fit.
                cos_mean = np.poly1d(np.polyfit([h for h, s in self.horizon_stride], cos_mean, deg=3))(
                    np.arange(self.T) + 1)
                self.cos_mean = pd.Series({h + 1: cos for h, cos in enumerate(cos_mean)})
                ln_norm_ratios = np.poly1d(np.polyfit([h for h, s in self.horizon_stride],
                                                      ln_norm_ratio_sum / self.batch_num, deg=3))(
                    np.arange(self.T) + 1)
                self.norm_ratio = pd.Series(
                    {h + 1: np.exp(norm_ratio) for h, norm_ratio in enumerate(ln_norm_ratios)})
            else:
                self.cos_mean = pd.Series({h: cos_meanh for (h, s), cos_meanh in
                                           zip(self.horizon_stride, cos_mean)})
                self.norm_ratio = pd.Series({h: norm_ratioh for (h, s), norm_ratioh in
                                             zip(self.horizon_stride,
                                                 np.exp(ln_norm_ratio_sum / self.batch_num))})
        finally:
            # Always leave the model as it was found.
            model.set_weights(weights)
            for l in model.lossblocks:
                l.trainable = True
        if return_df:
            return df

    def adjust_learning_rate(self, horizon):
        """
        Rescale the optimizer's learning rate for the given horizon.

        A new coefficient `ch = max(cos_mean[horizon], 0.1) / norm_ratio[horizon]`
        is computed; the learning rate is divided by the previous coefficient
        and multiplied by the new one.

        Args:
            horizon (int): Horizon used to look up `cos_mean` and `norm_ratio`.
        """
        previous_ch = self.ch
        self.ch = max(self.cos_mean[horizon], 0.1) / self.norm_ratio[horizon]
        learning_rate = self.model.optimizer.learning_rate
        learning_rate.assign(learning_rate / previous_ch * self.ch)

    def on_train_begin(self, logs=None):
        """
        Initialise the per-run state for the configured mode.

        Sets the number of blocks `T`, resets `model.lossT`, `minloss` and
        the learning-rate coefficient `ch`, and derives a default
        `horizon_stride` when none was given. In 'test' mode with
        `init_select` an initial statistics table is computed, tagged with
        epoch 0 and, if `df_name` is set, written to CSV.

        Args:
            logs (dict): Unused.
        """
        self.T = len(self.model.blocks)
        self.model.lossT = 1.
        self.minloss = np.inf
        self.ch = 1.
        if self.horizon_stride is None:
            self.horizon_stride = [(1, 1), (int(self.T / 3.), 1), (int(self.T / 3. * 2), 1), (self.T, 1)]
        # Created for every mode: the schedule branch of on_epoch_begin
        # appends to it regardless of the mode.
        self.log = {'cos_mean': [], 'epoch': [], 'norm_ratio': []}
        if self.mode == 'select':
            self.horizon = self.model.horizon
            self.batch = 0
        elif self.mode == 'test':
            self.df = pd.DataFrame()
        elif self.mode == 'adjust_lr':
            # Only the model's current setting is compared with the full gradient.
            self.horizon_stride = [(self.model.horizon, self.model.stride)]
            self.horizon = self.model.horizon
            self.batch = 0
        if self.init_select and self.mode == 'test':
            tmpdf = self.get_horizon_grad(return_df=True)
            tmpdf['epoch'] = 0
            self.df = pd.concat([self.df, tmpdf], ignore_index=True)
            if self.df_name is not None:
                self.df.to_csv(self.df_name)

    def _log_and_apply_horizon(self, horizon, epoch):
        """Record the current statistics, switch the model to `horizon`, rescale the learning rate and recompile."""
        self.log['cos_mean'].append(self.cos_mean.copy())
        self.log['norm_ratio'].append(self.norm_ratio.copy())
        self.log['epoch'].append(epoch)
        self.model.horizon = horizon
        self.adjust_learning_rate(horizon)
        # Recompile with the same optimizer and loss; plain `Mean` metrics are
        # filtered out of the metric list while `MeanMetricWrapper`s are kept.
        self.model.compile(self.model.optimizer, self.model.loss,
                           [m for m in self.model.metrics if not isinstance(m, keras.metrics.Mean)
                            or isinstance(m, keras.metrics.MeanMetricWrapper)])
        self.model.make_train_function()

    def on_epoch_begin(self, epoch, logs=None):
        """
        Re-select the horizon or rescale the learning rate when due.

        Without a schedule, every `period` epochs (skipping epoch 0 unless
        `init_select` is set): in 'select' mode the gradient statistics are
        recomputed, `self.objective()` picks the horizon and the model is
        recompiled; in 'adjust_lr' mode only the learning rate is rescaled.
        With a schedule, the horizon listed for this epoch (if any) is applied.

        Args:
            epoch (int): The current epoch number.
            logs (dict): Unused.
        """
        if epoch % self.period == 0 and self.schedule is None:
            if epoch == 0 and not self.init_select:
                return
            if self.mode == 'select':
                self.get_horizon_grad()
                self.horizon = self.objective()
                self._log_and_apply_horizon(self.horizon, epoch)
            elif self.mode == 'adjust_lr':
                self.get_horizon_grad()
                self.adjust_learning_rate(self.model.horizon)
        if self.schedule is not None and epoch in self.schedule:
            self.get_horizon_grad()
            self._log_and_apply_horizon(self.schedule[epoch], epoch)

    def on_train_batch_begin(self, batch, logs=None):
        """
        Run one extra selection / adjustment on the 100th training batch.

        Only active in 'select' and 'adjust_lr' mode. An internal counter is
        incremented on each call until it reaches 100; on the call that takes
        it to 100 the gradient statistics are recomputed once. In 'select'
        mode a new horizon is chosen, logged under epoch -1 and the model is
        recompiled; in 'adjust_lr' mode only the learning rate is rescaled.

        Args:
            batch (int): The current batch number (unused; an own counter is kept).
            logs (dict): Unused.
        """
        if self.mode not in ('select', 'adjust_lr'):
            return
        if self.batch >= 100:
            # The one-off action has already happened.
            return
        self.batch += 1
        if self.batch < 100:
            return

        self.get_horizon_grad()
        if self.mode != 'select':
            self.adjust_learning_rate(self.model.horizon)
            return

        chosen = self.objective()
        self.horizon = chosen
        self.log['cos_mean'].append(self.cos_mean.copy())
        self.log['norm_ratio'].append(self.norm_ratio.copy())
        self.log['epoch'].append(-1)
        self.model.horizon = chosen
        self.adjust_learning_rate(chosen)
        kept_metrics = [
            m for m in self.model.metrics
            if isinstance(m, keras.metrics.MeanMetricWrapper) or not isinstance(m, keras.metrics.Mean)
        ]
        self.model.compile(self.model.optimizer, self.model.loss, kept_metrics)
        self.model.make_train_function()

    def on_epoch_end(self, epoch, logs=None):
        """
        Publish horizon statistics to `logs` and update the relative loss.

        Writes 'horizon' and, once gradient statistics exist, 'cos1' (cosine
        at horizon - 1, or -2 when that horizon is not available), 'cos',
        'ch' and 'norm_ratio'. In 'test' mode a statistics table is appended
        every `period` epochs and, if `df_name` is set, written to CSV.
        Finally `model.lossT` is set to the smallest value of the model's
        first metric seen so far divided by its first recorded value.

        Args:
            epoch (int): The current epoch number.
            logs (dict | None): Epoch logs, updated in place.
        """
        # `logs or {}` would swap an empty dict passed by the caller for a
        # new one and lose the entries written below.
        if logs is None:
            logs = {}
        horizon = self.model.horizon
        logs["horizon"] = horizon
        if hasattr(self, 'cos_mean'):
            logs['cos1'] = self.cos_mean[horizon - 1] if horizon - 1 in self.cos_mean.keys() else -2
            logs['cos'] = self.cos_mean[horizon]
            logs['ch'] = self.ch
            logs['norm_ratio'] = self.norm_ratio[horizon]
        if (epoch + 1) % self.period == 0 and self.mode == 'test':
            tmpdf = self.get_horizon_grad(return_df=True)
            tmpdf['epoch'] = epoch
            self.df = pd.concat([self.df, tmpdf], ignore_index=True)
            if self.df_name is not None:
                self.df.to_csv(self.df_name)
        current_loss = self.model.metrics[0].result()
        # Also seed the reference when training did not start at epoch 0
        # (e.g. resumed with initial_epoch > 0).
        if epoch == 0 or not hasattr(self, 'loss0'):
            self.loss0 = current_loss
        self.minloss = min(self.minloss, current_loss)
        self.model.lossT = self.minloss / self.loss0

    def on_train_end(self, logs=None):
        """
        In 'select' mode, turn the collected statistics into long-format tables.

        `self.log1` becomes a table with columns epoch / layer / norm_ratio
        and `self.log` becomes the merge of the cos_theta table with it on
        (epoch, layer). Other modes leave the log untouched.

        Args:
            logs (dict): Unused.
        """
        if self.mode != 'select':
            return

        layers = np.arange(1, self.T + 1)
        epochs = self.log['epoch']

        def to_long(records, value_name):
            # One row per logged epoch, one column per layer, then unpivot.
            wide = pd.DataFrame(records, columns=layers, index=epochs)
            wide = wide.reset_index(names='epoch')
            return wide.melt(id_vars='epoch', value_vars=layers,
                             value_name=value_name, var_name='layer')

        self.log1 = to_long(self.log['norm_ratio'], 'norm_ratio')
        cos_table = to_long(self.log['cos_mean'], 'cos_theta')
        self.log = pd.merge(cos_table, self.log1, on=['epoch', 'layer'])

    def loss_estimate(self):
        """
        Return the negated, sign-preserving square of `cos_mean`.

        Returns:
            `-(cos_mean ** 2) * sign(cos_mean)`, element-wise.
        """
        cos = self.cos_mean
        squared = cos ** 2
        return -squared * np.sign(cos)

    def objective2(self):
        """
        Pick the horizon minimising the loss estimate plus a memory penalty.

        The penalty is `oblambda * memory` when `cost == 'linear'`, otherwise
        `oblambda * cost(memory)`. Only horizons with positive `cos_mean`
        are considered.

        NOTE(review): `self.oblambda` and `self.memory` are not set in the
        visible part of this class - confirm where they are defined.

        Returns:
            int: Index label of the minimal-cost horizon.
        """
        if self.cost == 'linear':
            penalty = self.oblambda * self.memory
        else:
            penalty = self.oblambda * self.cost(self.memory)
        cost = self.loss_estimate() + penalty
        positive_horizons = self.cos_mean[self.cos_mean > 0].index
        return int(cost[positive_horizons].idxmin())

    def objective1(self):
        """
        Pick a horizon among those whose cosine is close enough to 1.

        A horizon is acceptable when `1 - cos_mean**2 - 1e-4 <= delta**2`
        and `cos_mean > 0`. Without a cost the smallest acceptable horizon
        is returned; otherwise the largest acceptable horizon among those
        with minimal `cost(memory)` (looked up at horizon - 1).

        Returns:
            int: The selected horizon.
        """
        cos = self.cos_mean
        close_enough = 1 - cos ** 2 - 1e-4 <= self.delta ** 2
        accept_horizon = cos[close_enough & (cos > 0)].index
        if self.cost is None:
            return int(accept_horizon.min())
        cost = self.cost(self.memory)[accept_horizon - 1]
        is_cheapest = cost == min(cost)
        return int(accept_horizon[is_cheapest].max())