import numpy as np
import copy

from numpy import linalg


class ExperienceReplayOja:
    """Oja-style streaming estimate of a leading eigenvector with buffered replay.

    The data stream is consumed in consecutive buffers of ``buffer_size``
    samples.  Within each buffer, samples are drawn uniformly at random (with
    replacement) from the positions at or after ``buffer_drop_number`` and
    fed to a normalized rank-one update of the current estimate ``w``.

    All randomness comes from NumPy's *global* random state, which the
    constructor seeds.
    """

    def __init__(self, seed):
        """Seed NumPy's global random number generator with ``seed``."""
        np.random.seed(seed)

    def run_simulation(self, data, markov_chain, buffer_size, buffer_drop_number, lr_multiplier=0.1, lr_decay=False):
        """Run the buffered update on every repetition and average the errors.

        Args:
            data: Sequence of repetitions.  Each repetition is a sequence of
                records; the item at position ``2`` of a record is the sample
                vector used in the update.
            markov_chain: Object exposing ``num_dimensions`` (length of the
                estimate ``w``) and ``largest_eigenvector`` (reference
                direction for the error; presumably unit-norm -- confirm
                against the caller).
            buffer_size: Number of consecutive stream positions per buffer.
            buffer_drop_number: Number of leading positions of each buffer
                that are never sampled.
            lr_multiplier: Scale of the initial step size
                ``lr_multiplier * log(n) / n`` where ``n = len(data[0])``.
            lr_decay: If true, the step size after ``k`` completed updates is
                ``eta_init / (1 + k)``; otherwise it stays at ``eta_init``.

        Returns:
            A tuple ``(mean_sine_squared_errors, indices)``.  The first item
            is an array holding, for each update step, the value
            ``1 - <w, largest_eigenvector>**2`` averaged over repetitions.
            The second is the list of ``t + buffer_drop_number + j`` values
            recorded during the *last* repetition (``t`` is the buffer start
            and ``j`` the draw number within the buffer).

        Raises:
            ValueError: If ``data`` is empty, or if the repetitions performed
                different numbers of updates so no per-step mean exists.
        """
        num_repetitions = len(data)
        if num_repetitions == 0:
            raise ValueError("data must contain at least one repetition")

        # The initial step size is the same whether or not decay is enabled
        # and depends only on the first repetition's length, so compute it once.
        stream_length = len(data[0])
        eta_init = lr_multiplier * np.log(stream_length) / stream_length
        reference_direction = markov_chain.largest_eigenvector
        draws_per_buffer = buffer_size - buffer_drop_number

        all_sine_squared_errors = []
        indices = []
        for r in range(num_repetitions):
            samples = data[r]
            num_samples = len(samples)
            sine_squared_errors = []
            eta = eta_init
            # Random unit-norm starting point.
            w = np.random.randn(markov_chain.num_dimensions)
            w /= linalg.norm(w)
            counter = 0
            indices = []
            for t in range(0, num_samples, buffer_size):
                for j in range(draws_per_buffer):
                    index = np.random.randint(low=buffer_drop_number, high=buffer_size)
                    # Draws that land past the end of the stream (only
                    # possible in a final, partial buffer) are skipped.
                    if t + index >= num_samples:
                        continue
                    At = np.ravel(samples[t + index][2])
                    # (At At^T) w == At * (At . w); avoids building the
                    # d x d outer product on every step.
                    w += eta * np.dot(At, w) * At
                    w /= linalg.norm(w)
                    counter += 1
                    sine_squared_errors.append(1 - (np.dot(w, reference_direction)) ** 2)
                    indices.append(t + buffer_drop_number + j)
                    if lr_decay:
                        eta = eta_init / (1.0 + counter)
            all_sine_squared_errors.append(sine_squared_errors)

        # A per-step mean only exists when every repetition took the same
        # number of steps; fail explicitly instead of relying on an assert.
        if len({len(errors) for errors in all_sine_squared_errors}) > 1:
            raise ValueError(
                "repetitions performed different numbers of updates; "
                "cannot average errors step by step"
            )
        mean_sine_squared_errors = np.mean(np.array(all_sine_squared_errors), axis=0)
        return mean_sine_squared_errors, indices
