import numpy as np
import copy

from numpy import linalg


class VanillaOja:
    """Oja's streaming PCA with a plain single-sample update.

    At each step the estimate ``w`` is moved along ``A_t A_t^T w`` and then
    re-normalised to unit length.  The sine-squared error against
    ``markov_chain.largest_eigenvector`` is recorded before every update.
    """

    def __init__(self):
        pass

    def run_simulation(self, data, markov_chain, lr_multiplier, lr_decay, w_init, is_iid):
        """Run Oja's algorithm on every repetition in ``data`` and aggregate errors.

        Args:
            data: sequence of repetitions; ``data[r][t][2]`` is the sample
                vector used at step ``t`` of repetition ``r``.  All
                repetitions must have the same length ``T``.
            markov_chain: object exposing ``largest_eigenvector``; when
                ``lr_decay`` is true it must also expose ``cov_eigengap`` and,
                if ``is_iid`` is false, ``transition_eigengap``.
            lr_multiplier: scale of the constant step size
                ``lr_multiplier * log(T) / T``.  It is only used when
                ``lr_decay`` is false.
            lr_decay: if true, use ``eta_t = alpha / (beta + t)`` with
                ``alpha = 10 / cov_eigengap``.  Otherwise use the constant
                step size above.
            w_init: ``w_init[r]`` is the starting vector for repetition
                ``r``.  It is copied (as float) and never modified.
            is_iid: selects ``beta = 5`` (true) or
                ``beta = 5 / transition_eigengap`` (false).

        Returns:
            ``(mean, std, indices)``:
                - ``mean`` and ``std`` are the per-step mean and standard
                  deviation of ``1 - <w_t, v>^2`` across repetitions, where
                  ``v`` is ``markov_chain.largest_eigenvector``.
                - ``indices`` is the list of step indices ``0..T-1``.

        Raises:
            ValueError: if ``data`` is empty or its repetitions differ in length.
        """
        if len(data) == 0:
            raise ValueError("data must contain at least one repetition")
        num_steps = len(data[0])
        if any(len(repetition) != num_steps for repetition in data):
            raise ValueError("all repetitions in data must have the same length")

        target = markov_chain.largest_eigenvector
        if lr_decay:
            # Schedule eta_t = alpha / (beta + t).  For Markov-chain samples
            # the offset beta is scaled by the inverse transition eigengap.
            beta = 5 if is_iid else 5 / markov_chain.transition_eigengap
            alpha = 10 / markov_chain.cov_eigengap
        else:
            constant_eta = lr_multiplier * np.log(num_steps) / num_steps

        all_sine_squared_errors = []
        for r, repetition in enumerate(data):
            # Float copy: never mutate the caller's w_init.  An integer
            # initial vector would otherwise break the in-place update below.
            w = np.array(w_init[r], dtype=float)
            sine_squared_errors = []
            for t, sample in enumerate(repetition):
                sine_squared_errors.append(1 - np.dot(w, target) ** 2)
                eta = alpha / (beta + t) if lr_decay else constant_eta
                At = np.asarray(sample[2])
                # A_t A_t^T w == A_t * <A_t, w>.  This is O(d) instead of
                # forming the d x d outer product.
                w += (eta * np.dot(At, w)) * At
                w /= linalg.norm(w)
            all_sine_squared_errors.append(sine_squared_errors)

        all_sine_squared_errors = np.array(all_sine_squared_errors)
        mean_sine_squared_errors = np.mean(all_sine_squared_errors, axis=0)
        std_sine_squared_errors = np.std(all_sine_squared_errors, axis=0)
        indices = list(range(num_steps))
        return mean_sine_squared_errors, std_sine_squared_errors, indices
