# from runners.ncsn_runner import RunningAverageMeter
import argparse
import os
import math
import matplotlib.pyplot as plt


class RunningAverageMeter(object):
    """Exponential moving average of a stream of values.

    Besides the smoothed value (``avg``) and the most recent sample
    (``val``), the meter can keep every sample it has seen in ``vals`` and
    every step index it was given in ``steps``.

    Args:
        momentum: Weight kept from the previous average on each update; the
            new sample contributes ``1 - momentum``.
        save_seq: When truthy, record samples and steps in ``vals`` / ``steps``.
            When falsy, those attributes are never created.
    """

    def __init__(self, momentum=0.99, save_seq=True):
        self.momentum, self.save_seq = momentum, save_seq
        if save_seq:
            self.vals = []
            self.steps = []
        self.reset()

    def reset(self):
        """Forget the current value and average.

        The recorded ``vals`` / ``steps`` sequences are left untouched.
        """
        self.val = None
        self.avg = 0

    def update(self, val, step=None):
        """Fold ``val`` into the running average and optionally record it.

        The first sample after construction or ``reset()`` becomes the
        average as-is; later samples are blended in using ``momentum``.

        Args:
            val: The new sample.
            step: Optional step index. It is stored only when it is not
                ``None``, so ``steps`` may end up shorter than ``vals``.
        """
        is_first_sample = self.val is None
        self.avg = (
            val
            if is_first_sample
            else self.avg * self.momentum + val * (1 - self.momentum)
        )
        self.val = val

        if not self.save_seq:
            return
        self.vals.append(val)
        if step is not None:
            self.steps.append(step)


def savefig(path, bbox_inches='tight', pad_inches=0.1):
    """Save the current matplotlib figure to ``path`` on a best-effort basis.

    A failure while saving is reported by printing the exception class and
    is otherwise ignored, so one bad output path does not abort the
    remaining plots.

    Args:
        path: Destination file for the figure.
        bbox_inches: Forwarded to ``plt.savefig``.
        pad_inches: Forwarded to ``plt.savefig``.

    Raises:
        KeyboardInterrupt: Never swallowed; ``except Exception`` does not
            catch it (nor ``SystemExit``), so the user can still abort.
    """
    try:
        plt.savefig(path, bbox_inches=bbox_inches, pad_inches=pad_inches)
    except Exception as exc:
        # Print the exception class, as the original handler intended with
        # sys.exc_info()[0], without depending on a module that this file
        # never imports.
        print(type(exc))


# (metric name, y-axis label, True when larger values are better).
# The name doubles as the output file stem, and the tuple order fixes both
# the field order expected in the log line and the order plots are written.
_METRIC_SPECS = (
    ('mse', "MSE", False),
    ('psnr', "PSNR", True),
    ('ssim', "SSIM", True),
    ('lpips', "LPIPS", False),
    ('fvd', "FVD", False),
)


def _parse_metrics_line(line):
    """Parse one per-checkpoint metrics line into ``(ckpt, {metric: value})``.

    The line is split on ``', '``; for every field only the text after the
    last ``':'`` is used. Field 0 holds the checkpoint id, field 1 is not
    used, and fields 2-6 hold mse, psnr, ssim, lpips and fvd in that order.

    Raises:
        IndexError: If the line has fewer than seven fields.
        ValueError: If a field's value is not numeric.
    """
    fields = line.strip().split(', ')
    ckpt = int(fields[0].split(':')[-1])
    values = {
        name: float(fields[pos].split(':')[-1])
        for pos, (name, _, _) in enumerate(_METRIC_SPECS, start=2)
    }
    return ckpt, values


def _plot_metric(log_path, name, label, steps, vals, best_ckpt, best_val):
    """Plot one metric over steps and save ``<name>.png`` and ``<name>_log.png``."""
    plt.plot(steps, vals)
    # best_ckpt stays at -1 when no sample ever beat the initial sentinel,
    # in which case there is nothing to highlight.
    if best_ckpt > -1:
        plt.scatter(best_ckpt, best_val, color='k')
        plt.text(best_ckpt, best_val, f"{best_val:.04f}\n{best_ckpt}", c='r')

    plt.xlabel("Steps")
    plt.ylabel(label)
    plt.grid(True)
    plt.grid(visible=True, which='minor', axis='y', linestyle='--')
    savefig(os.path.join(log_path, f'{name}.png'))
    # Re-save the same figure with a logarithmic y-axis.
    plt.yscale("log")
    savefig(os.path.join(log_path, f'{name}_log.png'))
    plt.clf()
    plt.close()


def main(log_path):
    """Plot the video metrics logged in ``<log_path>/stdout.txt``.

    Every line that contains ``'fvd:'`` but not ``'Best'`` is parsed as one
    checkpoint's metrics. For each of MSE, PSNR, SSIM, LPIPS and FVD the
    values are plotted against the checkpoint id, the best checkpoint
    (lowest for MSE/LPIPS/FVD, highest for PSNR/SSIM; the first one wins
    ties) is marked and annotated, and the figure is saved in ``log_path``
    twice: ``<metric>.png`` and, with a log-scaled y-axis,
    ``<metric>_log.png``.

    Args:
        log_path: Directory that contains ``stdout.txt``; the plots are
            written to the same directory.

    Raises:
        OSError: If ``stdout.txt`` cannot be opened.
        IndexError, ValueError: If a selected line does not have the
            expected field layout.
    """
    steps = []
    series = {name: [] for name, _, _ in _METRIC_SPECS}
    # Best (ckpt, value) per metric. The sentinels mean "nothing seen yet":
    # ckpt -1 together with the worst possible value for that metric.
    best = {
        name: (-1, -math.inf if higher_is_better else math.inf)
        for name, _, higher_is_better in _METRIC_SPECS
    }

    with open(os.path.join(log_path, 'stdout.txt'), 'r') as f:
        for line in f:
            if 'fvd:' not in line or 'Best' in line:
                continue
            ckpt, values = _parse_metrics_line(line)
            steps.append(ckpt)
            for name, _, higher_is_better in _METRIC_SPECS:
                val = values[name]
                series[name].append(val)
                incumbent = best[name][1]
                # Strict comparison so the earliest best checkpoint is kept.
                improved = val > incumbent if higher_is_better else val < incumbent
                if improved:
                    best[name] = (ckpt, val)

    for name, label, _ in _METRIC_SPECS:
        best_ckpt, best_val = best[name]
        # The marker uses the best value recorded during parsing, so it stays
        # correct even if the same checkpoint id is logged more than once.
        _plot_metric(log_path, name, label, steps, series[name], best_ckpt, best_val)


if __name__ == '__main__':
    # Script entry point: the logs directory comes from the command line,
    # e.g. --log_path '/path/to/SMMNIST/DDPM_big_c5t5_SPADE/logs'
    cli = argparse.ArgumentParser(description="correct FVD plots")
    cli.add_argument(
        '--log_path',
        type=str,
        required=True,
        help='Path to the logs dir',
    )
    main(cli.parse_args().log_path)

