import numpy as np
import copy

from numpy import linalg


class DataDropOja:
    """Oja-style streaming estimate of a leading eigenvector with data dropping.

    Only every ``drop_number``-th sample of each data stream is used for an
    update; the samples in between are skipped.
    """

    def __init__(self):
        # Stateless: everything the simulation needs is passed to run_simulation.
        pass

    def run_simulation(self, data, markov_chain, drop_number, lr_multiplier, lr_decay, w_init):
        """Run the update rule on every repetition and aggregate the errors.

        For each repetition ``r`` the iterate starts at ``w_init[r]`` and, for
        ``t = 0, drop_number, 2 * drop_number, ...``, performs

            w <- w + eta * A_t A_t^T w,    w <- w / ||w||

        where ``A_t = data[r][t][2]``.  Before each update the error
        ``1 - <w, markov_chain.largest_eigenvector> ** 2`` is recorded.

        Parameters
        ----------
        data : sequence
            One entry per repetition; ``data[r][t][2]`` must be the sample
            vector used at position ``t``.
        markov_chain : object
            Must expose ``transition_eigengap``, ``cov_eigengap`` and
            ``largest_eigenvector``.
        drop_number : int
            Stride between the samples that are actually used.
        lr_multiplier : float
            Scale of the constant step size (used only when ``lr_decay`` is
            false).
        lr_decay : bool
            If true, the step size at update ``k`` is
            ``(5 / cov_eigengap) / (beta + k)`` with
            ``beta = 5 / (transition_eigengap * drop_number)``.  Otherwise it
            is the constant ``lr_multiplier * log(m) / m`` with
            ``m = len(data[0]) / drop_number``.
        w_init : sequence
            ``w_init[r]`` is the starting vector of repetition ``r``.  It is
            copied (as floats) and never modified.

        Returns
        -------
        tuple
            ``(mean_errors, std_errors, indices)``: the per-update mean and
            standard deviation of the error across repetitions, and the list
            of sample positions that were used.  With no repetitions, two
            empty arrays and an empty list are returned.

        Raises
        ------
        ValueError
            If the repetitions do not all yield the same number of updates.
        """
        beta = 5 / (markov_chain.transition_eigengap * drop_number)
        alpha = 5 / markov_chain.cov_eigengap

        num_repetitions = len(data)
        if num_repetitions == 0:
            # Nothing to average; avoid NaN results and unbound locals.
            return np.array([]), np.array([]), []

        # Every repetition must contribute the same number of error values,
        # otherwise the per-update statistics below are meaningless.
        step_counts = {len(range(0, len(run), drop_number)) for run in data}
        if len(step_counts) != 1:
            raise ValueError(
                "all repetitions must yield the same number of updates, got "
                f"{sorted(step_counts)}"
            )

        if lr_decay:
            eta_init = alpha
        else:
            # The constant step size is derived from the first repetition's
            # length and shared by all repetitions.
            effective_samples = len(data[0]) / drop_number
            eta_init = lr_multiplier * np.log(effective_samples) / effective_samples

        target = markov_chain.largest_eigenvector
        # Same start/stride and same count for every run, so one list suffices.
        indices = list(range(0, len(data[-1]), drop_number))

        all_sine_squared_errors = []
        for r in range(num_repetitions):
            # Float copy: keeps the caller's vector intact and makes the
            # in-place updates below valid even for integer-valued input.
            w = np.array(w_init[r], dtype=float)
            sine_squared_errors = []
            for counter, t in enumerate(indices):
                sine_squared_errors.append(1 - np.dot(w, target) ** 2)
                eta = eta_init / (beta + counter) if lr_decay else eta_init
                sample = np.asarray(data[r][t][2], dtype=float).ravel()
                # (A A^T) w == (A . w) A: rank-one update without building
                # the d x d outer product.
                w += eta * np.dot(sample, w) * sample
                w /= linalg.norm(w)
            all_sine_squared_errors.append(sine_squared_errors)

        all_sine_squared_errors = np.array(all_sine_squared_errors)
        mean_sine_squared_errors = np.mean(all_sine_squared_errors, axis=0)
        std_sine_squared_errors = np.std(all_sine_squared_errors, axis=0)
        return mean_sine_squared_errors, std_sine_squared_errors, indices
