from solvers.Solver import Solver

from tqdm import tqdm

import numpy as np


class HeatSolver(Solver):
    """Solver whose one-step propagator is ``I + gamma * L``.

    ``L`` is produced by ``Solver.generate_laplacian`` from a random,
    non-negative diagonal matrix seeded with ``self.op_seed``.
    Presumably this is an explicit (forward-Euler) heat step with
    spatially varying diffusivity, with ``gamma`` acting as the step
    size. NOTE(review): confirm against ``generate_laplacian``.
    """

    name = 'heat'

    def __init__(self, settings):
        super().__init__(settings)
        # Stack of operator powers built by ``calculate_fields_alt``.
        # It is rebuilt on every call of that method (see there).
        self.stacked_ops = None

    def generate_operator(self):
        """Build and return the dense one-step operator ``I + gamma * L``.

        The diagonal coefficient matrix fed to ``generate_laplacian`` is
        ``diag(|z|)`` with ``z ~ N(0, 1)`` of length ``grid_dimension ** 2``.

        Returns:
            A ``(grid_dimension**2, grid_dimension**2)`` NumPy array.
        """
        n_cells = self.grid_dimension ** 2

        # NOTE: this reseeds NumPy's *global* RNG, which is a side effect
        # other code may rely on, so the same op_seed always yields the
        # same operator.
        np.random.seed(self.op_seed)
        coefficients = np.diag(np.abs(np.random.randn(n_cells)))
        # NOTE: alternatives used during development:
        #   - np.diag(self.a_matrix(self.seed + 1).flatten())
        #   - np.diag(np.ones(n_cells)) to check the non-observable case.

        laplacian = self.generate_laplacian(coefficients)
        return np.identity(n_cells) + self.gamma * laplacian

    def calculate_fields_alt(self):
        """Compute every time step directly from the initial state ``u[0]``.

        Instead of iterating, step ``t`` is ``op_powers[t] @ u[0]``,
        evaluated as one batched ``self.mult`` call. This assumes
        ``generate_operator_powers`` fills ``self.op_powers`` with one
        matrix per time step (presumably ``op**t``). TODO confirm in Solver.

        Side effects: sets ``self.op``, ``self.op_powers``,
        ``self.stacked_ops``, ``self.u`` (replaced by the ``mult`` result)
        and ``self.m_u``.
        """
        self.op = self.generate_operator()
        self.generate_operator_powers()
        self.op_powers = self.tensor_wrapper(self.op_powers)

        # Always restack. op_powers was just regenerated, and a stack cached
        # from an earlier call could correspond to a different operator.
        self.stacked_ops = np.stack(self.op_powers)

        n_cells = self.grid_dimension ** 2
        # Repeat u[0] once per time step -> shape (timesteps, n_cells).
        broadcast_initial = np.broadcast_to(
            np.expand_dims(self.u[0], axis=0), (self.timesteps, n_cells)
        )
        # Trailing axis turns each state into a column vector for the
        # batched product; it is squeezed away again afterwards.
        self.u = self.mult(
            self.stacked_ops, broadcast_initial[..., np.newaxis]
        ).squeeze(-1)

        self.m_u = self.calculate_masked_u_field(self.u)

    def calculate_fields(self):
        """Propagate ``self.u`` in place: ``u[i] = op @ u[i - 1]``.

        The operator is generated only if ``self.op`` is None, so an
        existing operator is reused. Rows ``1 .. timesteps - 1`` of
        ``self.u`` are overwritten; row 0 is the initial state. Finally
        ``self.m_u`` is refreshed from the new field.
        """
        if self.op is None:
            self.op = self.generate_operator()

        for i in range(1, self.timesteps):
            self.u[i, :] = self.mult(self.op, self.u[i - 1, :])

        self.m_u = self.calculate_masked_u_field(self.u)
