import numpy as np
import pickle
import time
import scipy.io


def run_tos(problem, gamma_0, save_file, n_iterations=200, algorithm="adaptive", trial=False, verbose=True, TrueImage=None, y_0=None):
    """Run the three-operator splitting (TOS) method on ``problem``.

    Each iteration performs
        z_t = proj_g(y_t)
        u_t = df(z_t)
        x_t = proj_h(2 * z_t - y_t - gamma_t * u_t)
        y_t <- y_t - z_t + x_t
    and records uniform running averages of x_t / z_t plus per-iteration
    diagnostics (||x - z||, ||y||, objective values, cumulative time, and
    optionally PSNR against a reference image).

    Args:
        problem: object exposing the callables ``proj_g``, ``proj_h``,
            ``obj`` and ``df``.
        gamma_0: base step size.
        save_file: path the result dict is pickled to; a MATLAB copy is also
            written to ``save_file + ".mat"``. Not used when ``trial`` is True.
        n_iterations: number of iterations (must be >= 1).
        algorithm: step-size rule, one of
            "adaptive"  -> gamma_0 / sqrt(sum of ||u_s||^2 over s <= t),
            "smooth"    -> gamma_0 (constant),
            "nonsmooth" -> gamma_0 / sqrt(t).
        trial: if True, nothing is saved and the function returns
            max_t ||y_t|| / max_t ||u_t||, a heuristic for choosing gamma_0.
        verbose: print f(z_t) and ||x_t - z_t|| every 50 iterations.
        TrueImage: optional reference array. When given and non-empty, the
            PSNR (peak value 1.0) of x_t, z_t and their averages is tracked.
        y_0: optional starting point for y. Defaults to a (493, 517) array
            filled with 0.5 (the previous hard-coded initialization).

    Returns:
        The step-size heuristic when ``trial`` is True; otherwise the dict of
        final iterates and per-iteration histories that was saved.

    Raises:
        ValueError: if ``algorithm`` is not recognised or ``n_iterations < 1``.
    """
    if algorithm not in ("adaptive", "smooth", "nonsmooth"):
        raise ValueError(
            f"algorithm must be 'adaptive', 'smooth' or 'nonsmooth', got {algorithm!r}"
        )
    if n_iterations < 1:
        # With no iterations there are no final iterates to report.
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    # get functions necessary for TOS
    proj_g = problem.proj_g
    proj_h = problem.proj_h
    obj = problem.obj
    df = problem.df

    # Deterministic initialization: constant 0.5 over the default shape unless
    # the caller supplies a starting point (copied so the caller's array is
    # never aliased).
    if y_0 is None:
        y_t = np.full((493, 517), 0.5)
    else:
        y_t = np.array(y_0, dtype=float)

    # PSNR is tracked only for a non-empty reference; an empty list keeps its
    # old meaning of "no reference".
    track_psnr = TrueImage is not None and len(TrueImage) != 0

    # values to store
    u_sum = 0.0
    x_av, z_av = 0, 0
    xz_gap, xz_av_gap = [], []
    x_obj, z_obj, x_av_obj, z_av_obj = [], [], [], []
    x_psnr, z_psnr, x_av_psnr, z_av_psnr = [], [], [], []
    y_norm = []
    iter_time = []

    # step-size heuristic state (only updated when trial=True)
    max_norm_subgrad = 0
    max_norm_y = 0

    time_total = 0.0
    for t in range(1, n_iterations + 1):
        time_start = time.perf_counter()

        # tos update part 1
        z_t = proj_g(y_t)
        u_t = df(z_t)

        # step size update
        if algorithm == "adaptive":
            u_sum += np.sum(u_t * u_t)
            # While every gradient so far is zero, u_t == 0 and the step size
            # cannot affect the update; dividing by sqrt(0) would instead make
            # gamma_t * u_t = inf * 0 = nan and poison all later iterates.
            gamma_t = gamma_0 / np.sqrt(u_sum) if u_sum > 0 else gamma_0
        elif algorithm == "smooth":
            gamma_t = gamma_0
        else:  # "nonsmooth"
            gamma_t = gamma_0 / np.sqrt(t)

        # tos update part 2
        x_t = proj_h(2 * z_t - y_t - gamma_t * u_t)
        y_t = y_t - z_t + x_t

        # uniform running averages of the iterates
        x_av = (1 - 1 / t) * x_av + (1 / t) * x_t
        z_av = (1 - 1 / t) * z_av + (1 / t) * z_t

        # only the algorithmic work above is timed, not the bookkeeping below
        time_total += time.perf_counter() - time_start
        iter_time.append(time_total)

        # |x_t-z_t|
        xz_gap.append(np.linalg.norm(x_t - z_t))
        xz_av_gap.append(np.linalg.norm(x_av - z_av))

        # |y_t|
        y_norm.append(np.linalg.norm(y_t))

        # keep track of objective values
        x_obj.append(obj(x_t))
        z_obj.append(obj(z_t))
        x_av_obj.append(obj(x_av))
        z_av_obj.append(obj(z_av))

        # keep track of the psnr (each series uses its own iterate)
        if track_psnr:
            x_psnr.append(eval_psnr(x_t, TrueImage, 1.0))
            z_psnr.append(eval_psnr(z_t, TrueImage, 1.0))
            x_av_psnr.append(eval_psnr(x_av, TrueImage, 1.0))
            z_av_psnr.append(eval_psnr(z_av, TrueImage, 1.0))

        # keep track of max norms for step size heuristic
        if trial:
            max_norm_subgrad = max(max_norm_subgrad, np.linalg.norm(u_t))
            max_norm_y = max(max_norm_y, np.linalg.norm(y_t))

        if verbose and t % 50 == 0:
            print(f"Iteration {t}: f(z)={z_obj[-1]}, ||x-z||={xz_gap[-1]}")

    if trial:
        return max_norm_y / max_norm_subgrad

    data = {
        # final iterates
        "x_final": np.array(x_t),
        "z_final": np.array(z_t),
        "x_av_final": np.array(x_av),
        "z_av_final": np.array(z_av),
        # |x-z|
        "xz_gap": np.array(xz_gap),
        "xz_avg_gap": np.array(xz_av_gap),
        # f(x), f(z)
        "x_av_obj": np.array(x_av_obj),
        "z_av_obj": np.array(z_av_obj),
        "x_obj": np.array(x_obj),
        "z_obj": np.array(z_obj),
        # |y|
        "y_norm": np.array(y_norm),
        # cumulative time
        "time": np.array(iter_time),
    }

    # psnr
    if track_psnr:
        data["x_psnr"] = np.array(x_psnr)
        data["z_psnr"] = np.array(z_psnr)
        data["x_av_psnr"] = np.array(x_av_psnr)
        data["z_av_psnr"] = np.array(z_av_psnr)

    with open(save_file, "wb") as f:
        pickle.dump(data, f)

    scipy.io.savemat(save_file + ".mat", mdict=data)

    return data


def eval_psnr(img, ref, peakval=255):
    """Return the peak signal-to-noise ratio of ``img`` against ``ref`` in dB.

    PSNR = 10 * log10(peakval**2 / MSE), where MSE is the mean squared
    difference over all elements.

    Args:
        img: image (array-like) to evaluate.
        ref: reference image (array-like) broadcast-compatible with ``img``.
        peakval: maximum possible pixel value (255 for 8-bit images, 1.0 for
            images scaled to [0, 1]).

    Returns:
        The PSNR in dB, or 100 when the images are identical (MSE == 0).
    """
    # Cast to float64 before subtracting: on unsigned/narrow integer images
    # (e.g. uint8) the difference would otherwise wrap around and corrupt the MSE.
    diff = np.asarray(img, dtype=np.float64) - np.asarray(ref, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 10 * np.log10(peakval**2 / mse)