import sys
from time import perf_counter, time
from typing import Tuple

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

from src import iga, utils


class Test:
    def __init__(self, test_config):
        self._pde = test_config.pde
        self._output_filename = test_config.output_filename
        self._output_figurename = test_config.output_figurename
        self._output_name = test_config.output_name
        self._plot_type = test_config.plot_type
        self._experiment_dict = {
            "number": [],
            "l2": [],
            "relative_l2": [],
            "beta": [],
            "time": []
        }

    def plot_results(self):
        match self._plot_type:
            case "error":
                plt.semilogy(self._experiment_dict["number"], self._experiment_dict["l2"], marker=".", label="rmse")
                plt.semilogy(self._experiment_dict["number"], self._experiment_dict["relative_l2"], marker=".", label="relative L2 error")
                plt.title(self._output_name)
                plt.xlabel("number of functions")
                plt.ylabel("error")
                plt.legend()
                plt.savefig(self._output_figurename)
            case "beta":
                plt.loglog(self._experiment_dict["beta"], self._experiment_dict["l2"], marker=".", label="rmse")
                plt.loglog(self._experiment_dict["beta"], self._experiment_dict["relative_l2"], marker=".", label="relative L2 error")
                plt.title(self._output_name)
                plt.xlabel("beta")
                plt.ylabel("error")
                plt.legend()
                plt.savefig(self._output_figurename)
            case _:
                print(f"Unknown plot type: {self._plot_type}")
                sys.exit(1)

    def _run_single_test(
            self,
            beta: float | None,
            n: int,
            k: int
        ) -> Tuple[int, np.ndarray, np.ndarray]:
        match self._pde:
            case "advection_periodic":
                self._experiment_dict["beta"].append(beta)
                t0 = time()
                solver = iga.AdvectionPeriodic(beta=beta, n=n, k=k)
                n_functions, sol = solver.solve()
                t1 = time()
                self._experiment_dict["time"].append(t1-t0)
                gt = solver.compute_gt()
            case "euler_bernoulli":
                t0 = time()
                solver = iga.EulerBernoulli(n=n, k=k)
                n_functions, sol = solver.solve()
                t1 = time()
                self._experiment_dict["time"].append(t1-t0)
                gt = solver.compute_gt()
            case "burger_dirichlet":
                t0 = time()
                solver = iga.BurgerDirichlet(n=n, k=k)
                n_functions, sol = solver.solve()
                t1 = time()
                self._experiment_dict["time"].append(t1-t0)
                gt = solver.compute_gt()
            case _:
                print(f"Unknown PDE: {self._pde}")
                sys.exit(1)
        return n_functions, sol, gt
    
    def _compute_error(
            self,
            sol: np.ndarray,
            gt: np.ndarray
        ) -> Tuple[float, float]:

        # root mean squared error
        rmse = np.sqrt(np.linalg.norm(sol - gt)**2 / (sol.shape[0] * sol.shape[1]))
        # relative l2 error
        relative_error = np.linalg.norm(sol - gt) / np.linalg.norm(gt)

        return rmse, relative_error
    

class SingleTest(Test):
    def __init__(self, test_config):
        super(SingleTest, self).__init__(test_config)
        if test_config.list_beta is None:
            self._beta = None
        else:
            self._beta = test_config.list_beta[-1]
        self._n = test_config.list_n[-1]
        self._k = test_config.list_k[-1]

    def plot_result(self):
        match self._pde:
            case "advection_periodic":
                x_lim = [0, 1, 0, 2*np.pi]
                fontsize = 22
                aspect = 0.07
            case "euler_bernoulli":
                x_lim = [0, 1, 0, np.pi]
                fontsize = 24
                aspect = 0.09
            case "burger_dirichlet":
                x_lim = [0, 1, -1, 1]
                fontsize = 22
                aspect = 0.2
            case _:
                print(f"Unknown PDE: {self._pde}")
                sys.exit(1)

        error_filename = "outputs/" + self._output_name + "_error.pdf"

        # visualize the solution
        fig, ax = plt.subplots(1, 1, figsize=(6, 5), constrained_layout=True)
        img = ax.imshow(abs(self._sol-self._gt), extent=x_lim, origin='lower', aspect=aspect, cmap="jet")
        cb = fig.colorbar(img, ax=ax, location='bottom',aspect=20)

        # Set the formatter for the colorbar to scientific notation
        tick_locator = ticker.MaxNLocator(nbins=3)
        cb.locator = tick_locator
        cb.formatter = ticker.FuncFormatter(lambda x, pos: f'{x:.1e}')
        cb.update_ticks()

        # Set the font size of the colorbar labels
        cb.ax.tick_params(labelsize=fontsize)  # Change 12 to your desired font size
        ax.set_xlabel('t', fontsize=fontsize)
        ax.set_ylabel('x', fontsize=fontsize)
        ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
        plt.tick_params(axis='both', labelsize=fontsize)
        plt.savefig(error_filename)
        plt.clf()


    def run_test(self):
        print("---------runtime test---------")
        n_functions, sol, gt = self._run_single_test(
            beta=self._beta,
            n = self._n,
            k = self._k
        )
        self._gt = gt
        self._sol = sol
        
        self._experiment_dict["number"].append(n_functions)
        error_l2, error_relative_l2 = self._compute_error(sol=sol, gt=gt)
        self._experiment_dict["l2"].append(error_l2)
        self._experiment_dict["relative_l2"].append(error_relative_l2)

        print(f"number of functions: {n_functions}")
        print(f"l2 error: {error_l2}")
        print(f"relative l2 error: {error_relative_l2}")
        solve_time = self._experiment_dict["time"][-1]
        print(f"solve time: {(solve_time):.05f}")
        utils.save_json(filename=self._output_filename, content=self._experiment_dict)
        

class MultipleTest(Test):
    """This class defines multiple tests for a PDE

    Methods:
    --------
    run_tests():
        Solve PDE multiple times
    plot_results():
        Make convergence plot (inherited from Test)
    """

    def __init__(self, test_config):
        """Truncate the configured parameter lists to a common length.

        The number of tests is the length of the shortest of ``list_n``,
        ``list_k`` and, when it is not ``None``, ``list_beta``. When
        ``list_beta`` is ``None``, every test uses ``beta=None``.
        """
        super().__init__(test_config)
        param_lists = [test_config.list_n, test_config.list_k]
        if test_config.list_beta is not None:
            param_lists.append(test_config.list_beta)
        self._n_tests = min(len(values) for values in param_lists)

        if test_config.list_beta is None:
            self._list_beta = [None] * self._n_tests
        else:
            self._list_beta = test_config.list_beta[:self._n_tests]
        self._list_n = test_config.list_n[:self._n_tests]
        self._list_k = test_config.list_k[:self._n_tests]

    def run_tests(self):
        """Run every configured test, print per-test metrics, then save the results as JSON."""
        params = zip(self._list_beta, self._list_n, self._list_k)
        for test_idx, (beta, n, k) in enumerate(params, start=1):
            print(f"---------test:{test_idx}/{self._n_tests}---------")
            n_functions, sol, gt = self._run_single_test(beta=beta, n=n, k=k)
            self._experiment_dict["number"].append(n_functions)

            # The value printed as "l2 error" is the RMSE returned by _compute_error.
            error_l2, error_relative_l2 = self._compute_error(sol=sol, gt=gt)
            self._experiment_dict["l2"].append(error_l2)
            self._experiment_dict["relative_l2"].append(error_relative_l2)

            print(f"number of functions: {n_functions}")
            print(f"l2 error: {error_l2}")
            print(f"relative l2 error: {error_relative_l2}")
            solve_time = self._experiment_dict["time"][-1]
            print(f"solve time: {(solve_time):.05f}")

        utils.save_json(filename=self._output_filename, content=self._experiment_dict)
                
                
    
    

        
    