import initial

import numpy as np

from tqdm import tqdm


class Solver:
    """Base class for a finite-difference solver on a periodic gd x gd grid.

    Fields are stored flattened row-major (index gd * x + y), one row per
    timestep. Child classes are expected to implement ``generate_operator``
    and ``calculate_fields``.
    """

    def __init__(self, settings):
        """Read solver parameters from ``settings`` and set the initial field.

        Args:
            settings: dict of configuration values. Keys read here: 'LD'
                (grid side length), 'c' (mask tile side length; must divide
                'LD'), 'delta_x', 'alpha', 'time_factor', 'T' (number of
                stored timesteps), 'seed', 'law', 'op_seed', and - via
                generate_initial_conditions - 'init_nu' and 'init_lambda'.
                The whole dict is also forwarded to ``initial.sim_rf``.
        """
        self.grid_dimension = settings['LD']
        self.mask_dimension = settings['c']
        self.delta_x = settings['delta_x']  # Will almost always be 1
        self.alpha = settings['alpha']
        self.time_factor = settings['time_factor']  # Needs to be <= 1 for numerical stability
        self.timesteps = settings['T']
        self.seed = settings['seed']
        self.law = settings['law']
        self.op_seed = settings['op_seed']  # Not used in this class; presumably read by subclasses
        self.settings = settings

        # With this delta_t, gamma = alpha * dt / dx^2 works out to time_factor / 4,
        # so time_factor <= 1 keeps gamma <= 1/4.
        self.delta_t = self.time_factor * (self.delta_x ** 2) / (4 * self.alpha)
        self.gamma = (self.alpha * self.delta_t) / (self.delta_x ** 2)

        # The coarse (masked) grid is built from whole c x c tiles.
        assert self.grid_dimension % self.mask_dimension == 0, (
            "grid dimension 'LD' must be a multiple of mask dimension 'c'"
        )
        self.output_dimension = self.grid_dimension // self.mask_dimension

        # Array-backend hooks (mult/transpose are used by calculate_masked_u_field;
        # tensor_wrapper is unused here). Presumably attributes so a subclass can
        # swap in another array library - confirm.
        self.mult = np.matmul
        self.tensor_wrapper = np.array
        self.transpose = np.transpose

        # One flattened field per timestep: fine grid (u, u_dot), coarse grid (m_u).
        self.u = np.zeros((self.timesteps, self.grid_dimension ** 2))
        self.u_dot = np.zeros((self.timesteps, self.grid_dimension ** 2))
        self.m_u = np.zeros((self.timesteps, self.output_dimension ** 2))

        self.mask = self.generate_mask_matrix()
        self.op = None  # Presumably set by a subclass's generate_operator - confirm
        self.op_powers = None  # Filled lazily by generate_operator_powers

        x = self.generate_initial_conditions(self.seed)
        self.u[0, :] = Solver.normalise(x)

    def a_matrix(self, seed, offset=1):
        """Sample a random field and return its elementwise exponential.

        Args:
            seed: seed forwarded to ``initial.sim_rf``.
            offset: scale applied to the field before exponentiating.

        Returns:
            ``exp(offset * z)``, where z is the 'field' entry that
            ``initial.sim_rf`` returns for this solver's law and settings
            (its shape is whatever sim_rf produces; not checked here).
        """
        z = initial.sim_rf(law=self.law, settg=self.settings, seed=seed)['field']
        return np.exp(z * offset)

    def generate_initial_conditions(self, seed, offset=1):
        """Sample an initial field using the 'init_nu' / 'init_lambda' parameters.

        A shallow copy of the settings has 'nu' and 'lambda' replaced by
        'init_nu' and 'init_lambda', so ``self.settings`` itself is untouched.

        Args:
            seed: seed forwarded to ``initial.sim_rf``.
            offset: multiplicative scale applied to the sampled field.

        Returns:
            The sampled field flattened to 1D and multiplied by ``offset``.
        """
        settings = self.settings.copy()
        settings['nu'] = self.settings['init_nu']
        settings['lambda'] = self.settings['init_lambda']
        z = initial.sim_rf(law=self.law, settg=settings, seed=seed)['field']
        return z.flatten() * offset

    def generate_more_data(self, seed=None):
        """Return a fresh solution array seeded with new initial conditions.

        Previously the array was built and then discarded; it is now returned.

        Args:
            seed: seed for generate_initial_conditions. Defaults to
                ``self.seed``, which is the seed __init__ used. That presumably
                reproduces the same initial field if sim_rf is deterministic in
                its seed, so pass a different seed to get new data.

        Returns:
            Array shaped like ``self.u``: row 0 holds the normalised initial
            field, and all later rows are zero. ``self.u`` is not modified.
        """
        if seed is None:
            seed = self.seed
        x = self.generate_initial_conditions(seed)
        new_u = np.zeros_like(self.u)
        new_u[0, :] = Solver.normalise(x)
        return new_u

    def calculate_masked_u_field(self, u):
        """Coarse-grain ``u`` by summing it over each c x c mask tile.

        Args:
            u: a single flattened field of length gd**2, or a 2D stack of
                shape (n, gd**2).

        Returns:
            ``(mask @ u.T).T``: length q**2 for a 1D input, (n, q**2) for 2D.
        """
        x = self.mult(self.mask, self.transpose(u))
        return self.transpose(x)

    def generate_mask_matrix(self):
        """Build this solver's (q**2, gd**2) tile-summing mask.

        Delegates to the module-level ``generate_mask_matrix`` function.
        """
        gd = self.grid_dimension
        c = self.mask_dimension
        q = self.output_dimension

        return generate_mask_matrix(gd, c, q)

    def del_matrix(self, direction, parity):
        """Return a periodic first-order difference matrix on the flattened grid.

        Args:
            direction: True to difference along the first grid index (x),
                False along the second (y). Must be a bool.
            parity: +1 for a forward difference (u[x+1] - u[x]), or -1 for a
                backward difference (u[x] - u[x-1]).

        Returns:
            Dense (gd**2, gd**2) float array D such that D @ u applies the
            difference to u flattened row-major (index gd * x + y, the layout
            used by matrix_to_vector). Indices wrap modulo gd (periodic grid).
        """
        assert isinstance(direction, bool)
        assert parity in (-1, +1)

        gd = self.grid_dimension
        xs, ys = np.meshgrid(np.arange(gd), np.arange(gd), indexing='ij')
        if direction:
            xs_shifted, ys_shifted = (xs + parity) % gd, ys
        else:
            xs_shifted, ys_shifted = xs, (ys + parity) % gd
        rows = (gd * xs + ys).ravel()
        cols = (gd * xs_shifted + ys_shifted).ravel()

        matrix = -parity * np.identity(gd ** 2)
        # Accumulate instead of overwriting. When gd == 1 the shifted cell is the
        # cell itself, and the two terms must cancel to 0 (u - u). Each row index
        # occurs once, so there are no duplicate (row, col) pairs.
        matrix[rows, cols] += parity
        return matrix

    def generate_laplacian(self, a):  # This is for 2d
        """Return D+x @ a @ D-x + D+y @ a @ D-y on the flattened grid.

        Args:
            a: (gd**2, gd**2) matrix placed between the difference operators;
                presumably a diagonal coefficient matrix - confirm with callers.

        Returns:
            (gd**2, gd**2) array. With ``a`` the identity this is the periodic
            5-point Laplacian stencil (-4 on the diagonal, 1 at each neighbour).
        """
        del_plus_x = self.del_matrix(True, +1)
        del_minus_x = self.del_matrix(True, -1)
        del_plus_y = self.del_matrix(False, +1)
        del_minus_y = self.del_matrix(False, -1)
        x = del_plus_x @ a @ del_minus_x
        y = del_plus_y @ a @ del_minus_y
        return x + y

    def generate_operator_powers(self):
        """Precompute op**0 ... op**(timesteps - 1) into ``self.op_powers``.

        Does nothing if the powers already exist. ``self.op`` must already be a
        square matrix; while it is still None this fails on ``self.op.shape``.
        The result is a dense (timesteps, n, n) array, so memory grows as
        timesteps * n**2.
        """
        if self.op_powers is not None:
            return
        self.op_powers = np.zeros((self.timesteps, *self.op.shape))
        with tqdm(total=self.timesteps, unit='powers') as bar:
            self.op_powers[0, :] = np.identity(self.op.shape[0])  # Set A^0 = I_{L^2}
            bar.update(1)
            for i in range(1, self.timesteps):  # Will generate n matrices, from A^0 to A^{t-1}.
                self.op_powers[i, :] = self.op_powers[i - 1] @ self.op
                bar.update(1)  # Can't parallelize further - each loop depends on the previous.

    @staticmethod
    def normalise(x):
        """Affinely rescale ``x`` so its minimum maps to -1 and its maximum to +1.

        A constant input has no spread to rescale. It is returned as all zeros
        (the midpoint of [-1, 1]) rather than NaNs from a 0/0 division.
        """
        x_min = x.min()
        span = x.max() - x_min
        if span == 0:
            return np.zeros_like(x, dtype=float)
        return 2 * ((x - x_min) / span) - 1

    def calculate_fields(self):  # Child classes should implement these methods
        """Compute the solution fields; must be implemented by a subclass."""
        raise NotImplementedError

    def generate_operator(self):
        """Build the time-stepping operator; must be implemented by a subclass."""
        raise NotImplementedError


def matrix_to_vector(i, j, L):
    """Map grid coordinates (i, j) on an L x L grid to a flat row-major index.

    Both coordinates are reduced modulo L first, so out-of-range or negative
    coordinates refer to their periodic image on the grid.
    """
    row = i % L
    col = j % L
    return row * L + col

def generate_mask_matrix(gd, c, q):
    """Build the 0/1 matrix that sums each c x c tile of a gd x gd grid.

    Args:
        gd: side length of the fine grid.
        c: side length of one square tile.
        q: number of tiles per side of the coarse grid (Solver passes gd // c).

    Returns:
        int array of shape (q**2, gd**2). Row ``matrix_to_vector(n, m, q)`` has
        ones at the flattened (row-major, via matrix_to_vector) positions of
        tile (n, m), so ``mask @ u`` sums a flattened field u over every tile.
        Fine-grid positions wrap modulo gd if q * c exceeds gd.
    """
    mask = np.zeros((q ** 2, gd ** 2), dtype=int)

    for n in range(q):
        tile_rows = range(n * c, (n + 1) * c)
        for m in range(q):
            tile_cols = range(m * c, (m + 1) * c)
            # Flattened fine-grid positions covered by tile (n, m).
            cells = [matrix_to_vector(i, j, gd) for i in tile_rows for j in tile_cols]
            mask[matrix_to_vector(n, m, q), cells] = 1
    return mask
