#!/usr/bin/perl -w
# vim: tabstop=4 shiftwidth=4 expandtab nosmarttab
#
# Generate & display convergence charts from vw progress outputs
# Requires R to generate the charts
#
use Getopt::Std;
use Scalar::Util qw(looks_like_number);
use vars qw ($opt_d $opt_x $opt_y $opt_t
             $opt_w $opt_h $opt_a $opt_q $opt_p $opt_Q
             $opt_o $opt_v $opt_l $opt_s
);

# Scratch files: the default chart image, and the R script saved in -v mode
my $TmpImgFile = '/tmp/vw-convergence.png';
my $TmpRFile   = '/tmp/vw-convergence.R';

# Loss label used in the Y-axis default and the legend title (see get_args)
my $LossStr;

# Preferred image viewer; find_image_viewer() may replace it
my $DisplayProg = 'display';

# Fallback values for the -w, -h and -s options
my $DefaultWidth     = 800;
my $DefaultHeight    = 600;
my $DefaultPointSize = 1.2;

#
# Pick the first image viewer that is installed and remember its name in
# $DisplayProg.  The current $DisplayProg is tried first, followed by the
# built-in candidates.  When none is found a warning is printed and
# $DisplayProg is left unchanged.
#
sub find_image_viewer() {
    # If your favorite OS image viewer isn't here, please add it
    my %seen;
    my @candidates = grep { !$seen{$_}++ } (    # probe each name only once
        $DisplayProg, 'display', 'gwenview', 'kuickshow',   # Linux
        'xee', 'preview'                                    # OS-X alternatives
    );
    foreach my $iv (@candidates) {
        # Some 'which' implementations complain on stderr about every
        # program they can't find: keep that noise away from the user
        my $path = `which $iv 2>/dev/null`;
        $path = '' unless (defined $path);      # backticks failed to run
        chomp($path);
        if (length($path) && -x $path) {
            $DisplayProg = $iv;
            v("find_image_viewer: found executable '%s' @ '%s'\n", $iv, $path);
            return;
        }
    }
    # if we get here without finding anything, we can't display
    warn "$0: find_image_viewer: couldn't find an image-viewer in this env\n";
}

#
# Verbose/debug output to STDERR; does nothing unless -v was given.
# A single argument is printed verbatim (so a literal '%' is safe);
# several arguments are treated as a printf format followed by its values.
#
# Deliberately declared without a prototype: v() is called from
# find_image_viewer(), which is compiled before this definition, and a
# prototype here makes 'perl -w' emit
# "called too early to check prototype".
#
sub v {
    return unless $opt_v;
    if (@_ == 1) {
        print STDERR @_;
    } else {
        printf STDERR @_;
    }
}

#
# Print the optional error message given as arguments (plus a newline)
# to STDERR, then die with the usage text.  Never returns.
#
sub usage(@) {
    print STDERR @_, "\n" if (@_);

    die "Usage: $0 [options] [input_files...]

    If input_file is missing, input is read from stdin.
    Assuming inputs are one or more 'vw' progress reports
    Requires R to generate the chart

    Options:
        -v      Verbose/debug
        -d      No display (only creates image file)
        -a      Convert from squared-loss X to abs-loss (apply sqrt(X))
        -p      Convert from squared-loss X to abs-percent (100*sqrt(X))
        -Q      Convert from squared-loss X to (exp(sqrt(X)) - 1.0) * 100
        -q      Convert from quantile-loss X to (exp(X) - 1.0) * 100
        -l      Use since-last (2nd) instead of average-loss (1st) column
                Note: this option ignores the final avg loss
        -o<IMG> Output chart to <IMG> file
        -x<XL>  Use <XL> as X-axis label in chart
        -y<YL>  Use <YL> as Y-axis label in chart
        -t<T>   Use <T> as title of chart
        -w<W>   set image width in pixels to <W> (default $DefaultWidth)
        -h<H>   set image height in pixels to <H> (default $DefaultHeight)
        -s<S>   set point size to <S> (default $DefaultPointSize)
        -i<OPT> Pass <OPT> as extra option(s) to the image viewer

        -o <image_file> is optional, if not given, will create a temp-file:
            $TmpImgFile
        and display it.
        An argument ending in .png, .jpg or .jpeg is not read as input:
        it is used as the output image file, like -o.
";
}

#
# Map from squared-error to %percent error
#
# Takes the square root of the argument, exponentiates it, and returns
# how far the result is above 1.0, as a percentage.
#
sub sqlosslog2pct($) {
    my ($sq_loss) = @_;
    my $ratio = exp(sqrt($sq_loss));
    return 100 * ($ratio - 1.0);
}

#
# Map from quantile-error to %percent error
#
# Exponentiates the argument and returns how far the result is
# above 1.0, as a percentage.
#
sub losslog2pct($) {
    my ($log_loss) = @_;
    my $ratio = exp($log_loss);
    return 100 * ($ratio - 1.0);
}

#
# Largest of the defined arguments, never less than 0
# (the starting point is 0 since all our loss values are non-negative).
# Undefined arguments are skipped.
#
sub vector_max(@) {
    my $largest = 0;
    for my $candidate (grep { defined } @_) {
        $largest = $candidate if ($candidate > $largest);
    }
    return $largest;
}

#
# Convert one raw loss value according to the conversion option in effect.
# When several options are given the first match wins, in this order:
# -q, -Q, -a, -p.  Without any of them the value is returned unchanged.
#
sub transform_loss($) {
    my ($raw) = @_;

    if ($opt_q) {
        # quantile on log scale
        return losslog2pct($raw);
    } elsif ($opt_Q) {
        # squared on log scale
        return sqlosslog2pct($raw);
    } elsif ($opt_a) {
        # squared to absolute
        return sqrt($raw);
    } elsif ($opt_p) {
        # squared to absolute, as a percentage
        return (100.0*sqrt($raw));
    }

    return $raw;
}

#
# input: vw progress (can concatenate multiple) from STDIN or file
# output: a list of vectors of avg loss values
#
# A run (data-set) ends at its 'average loss = ...' summary line, or at
# two or more consecutive empty lines, or at end of input.
# With -l the summary value itself is not added to the run (it is an
# average, not a since-last value), but the line still closes the run.
# Side effect: $Ymax is raised to the largest loss value seen.
# Calls usage(), which dies, when no progress data is found at all.
#

# Largest (transformed) loss seen so far; do_plot() uses it as the top
# of the Y axis.  This must be the package variable $Ymax: that is the
# name both subs actually refer to.
our $Ymax = 0;

sub average_loss_arrays() {
    my @avg_losses_array = ();
    my @avg_loss = ();
    my $gap = 0;    # number of consecutive empty lines just seen

    # Close the run collected so far (if any) and get ready for the next
    my $end_of_run = sub {
        if (@avg_loss) {
            push(@avg_losses_array, [ @avg_loss ]);
            $Ymax = vector_max($Ymax, @avg_loss);
        }
        @avg_loss = ();
        $gap = 0;
    };

    while (<>) {
        # progress lines
        my ($n1, $n2) = (/^([0-9]+\.[0-9]+)\s+([0-9]+\.[0-9]+)/);
        if (defined $n1) {
            push(@avg_loss, transform_loss($opt_l ? $n2 : $n1));
            $gap = 0;
            next;
        }
        # summary line: always the end of a run
        if (/^average loss\s*=\s*([0-9.e+-]+)/i) {
            push(@avg_loss, transform_loss($1)) unless ($opt_l);
            $end_of_run->();
            next;
        }
        # 2 empty lines or more are also a data-set separator
        if (/^\s*$/) {
            $end_of_run->() if (++$gap > 1);
            next;
        }
        # any other line interrupts a sequence of empty lines
        $gap = 0;
    }
    # if summary line wasn't included, and we have avg_loss data,
    # use whatever we have
    $end_of_run->();

    usage("Couldn't identify 'vw' progress report(s) in input")
        unless (@avg_losses_array);

    v("=== Found %d progress runs in input\n", scalar(@avg_losses_array));
    @avg_losses_array;
}

#
# Generate the chart by piping a generated script into R.
#   $loss_vec_ref   reference to a list of loss vectors (array refs),
#                   one vector per run; each becomes one line in the chart
#   $imgfile        output image file (optional, defaults to $TmpImgFile)
# The R device is chosen from the file extension: .ps/.eps -> postscript,
# .jpg/.jpeg -> jpeg, anything else -> png (CairoPNG when available).
# The legend shows the last value of each vector.
# With -v the generated R script is also saved to $TmpRFile.
# Dies if R can't be started; warns if R doesn't finish cleanly.
#
# NOTE(review): R's postscript() device documents width/height in inches,
# while -w/-h are described as pixels -- confirm the intended PS size.
#
sub do_plot($;$) {
    my ($loss_vec_ref, $imgfile) = @_;
    # The prototype makes the file name optional: fall back to the default
    $imgfile = $TmpImgFile unless (defined $imgfile);

    my $set_r_device_line =
        ($imgfile =~ /\.e?ps$/) ?
            "postscript(file='$imgfile', width=$opt_w, height=$opt_h)"
        : ($imgfile =~ /\.jpe?g$/) ?    # same extensions get_args() accepts
            "jpeg(file='$imgfile', width=$opt_w, height=$opt_h)"
        : # default: if you don't have 'Cairo', just use 'png' instead
            "library(graphics); autoload('CairoPNG', 'Cairo');\n" .
            "if (exists('Cairo', mode='function')) {\n" .
                "CairoPNG(file='$imgfile', width=$opt_w, height=$opt_h)\n" .
            "} else {\n" .
                "png(file='$imgfile', width=$opt_w, height=$opt_h)\n" .
            "}"
    ;

    my $R_input = "$set_r_device_line;\n";
    my $lineno = 0;
    my $cex = $opt_s;
    my @pchs = ();
    my @namelist = ();      # legend labels: final loss value of each run
    $R_input .= "colors = colors()[c(26, 554, 257, 115, 644, 132, 92)]\n";
    $R_input .= "col.list = vector()\n";
    foreach my $lossref (@$loss_vec_ref) {
        my $R_losses_array = sprintf('c(%s)', join(',', @$lossref));
        # NB: plain string, not a printf format: '%%' reaches R as-is
        # (it is R's modulo operator), cycling through the color list
        $R_input .= "loss = $R_losses_array\n" .
                    "color = colors[1 + ($lineno %% length(colors))]\n";
        # The first run sets up the plot (axes, labels, grid);
        # the following ones are drawn on top of it.
        # $Ymax is the package variable filled in by average_loss_arrays()
        $R_input .=
            ($lineno == 0) ?
                "par(mar=c(5,6,3.5,1))\n" .
                "plot(loss, pch=20, t='o', col=color, lwd=2, cex=$cex,\n" .
                "\tlab=c(20,20,7), xlab='$opt_x', ylab='$opt_y\n', las=1,\n" .
                "\tylim=c(0, $Ymax), cex.lab=1.5, font.lab=4,\n" .
                "\tpanel.first=grid(col='gray63'));\n"
            :
                "lines(loss, pch=20, t='o', col=color, lwd=2, cex=$cex);\n";

        $R_input .= "col.list[1+$lineno] = color\n";
        push(@namelist, sprintf("%.4f", $lossref->[-1]));
        push(@pchs, 20);
        $lineno++;
    }
    $R_input .= "title(main='$opt_t', cex.main=1.5, col='black')\n";

    $R_input .= sprintf(
        "legend('topright', " .
            "legend=c(%s), col=col.list, text.col=col.list, pch=c(%s), %s);\n",
            join(',', @namelist),
            join(',', @pchs),
            "inset=0.05, title='final mean $LossStr', pt.cex=2, lwd=2.5"
    );

    my $verbosity = $opt_v ? '--verbose' : '--slave -q --silent 2>/dev/null';
    {
        # If R is missing or exits early, writing to the pipe raises
        # SIGPIPE, which would kill us silently: ignore it here so the
        # failure is reported by close() below instead
        local $SIG{PIPE} = 'IGNORE';
        open(my $r_in, '|-', "R --vanilla --no-readline $verbosity")
            || die "$0: cannot start R: $!\n";
        print $r_in $R_input;
        # For a pipe, close() waits for R and fails on a non-zero exit
        # ($! is 0 when a non-zero exit status was the only problem)
        unless (close($r_in)) {
            warn "$0: R did not finish cleanly: ",
                 ($! ? "$!" : 'exit status ' . ($? >> 8)), "\n";
        }
    }

    # For debugging
    if ($opt_v) {
        open(my $r_sav, '>', $TmpRFile) || die "$0: $TmpRFile: $!\n";
        print $r_sav $R_input;
        close($r_sav) || die "$0: $TmpRFile: $!\n";
        v("Wrote R file for debug: $TmpRFile\n");
    }
}

#
# Parse the command line into the $opt_* variables and apply defaults.
#   - strips the directory part from $0 (used in messages)
#   - derives $LossStr and the default axis labels/title from the options
#   - -w, -h and -s must be positive numbers, otherwise the default is used
#     (they are pasted verbatim into the generated R script)
#   - a non-option argument ending in .png/.jpg/.jpeg is taken as the
#     output image file; every other argument must be an existing regular
#     file and is left in @ARGV as input
#   - a stale default temp image is removed; any other pre-existing output
#     image is renamed to <name>.prev
# Calls usage() (which dies) on bad options or unusable input arguments.
#
sub get_args() {
    $0 =~ s{.*/}{};
    getopts('dx:y:t:w:h:s:laqQpo:vi:') || usage();

    $opt_l = 0 unless ((defined $opt_l) && $opt_l > 0);
    $LossStr = ($opt_Q || $opt_p || $opt_q) ? '%loss' : 'loss';
    $LossStr .= ' since last' if ($opt_l);
    $opt_x = 'vw progress iteration (log-scale)' unless (defined $opt_x);
    $opt_y = "mean $LossStr"
        unless (defined $opt_y);
    $opt_t = "online training $opt_y convergence" unless (defined $opt_t);
    $opt_w = $DefaultWidth unless (
            (defined $opt_w) and
            looks_like_number($opt_w) and
            ($opt_w > 0)
    );
    $opt_h = $DefaultHeight unless (
            (defined $opt_h) and
            looks_like_number($opt_h) and
            ($opt_h > 0)
    );
    $opt_s = $DefaultPointSize unless (
            (defined $opt_s) and
            looks_like_number($opt_s) and
            ($opt_s > 0)
    );

    $opt_o = $TmpImgFile
        unless (defined $opt_o);

    unlink($TmpImgFile) if (-e $TmpImgFile);

    my @file_args = ();
    foreach my $arg (@ARGV) {
        if ($arg =~ /\.(?:png|jpe?g)$/) {
            $opt_o = $arg;
            next;
        }
        if (-f $arg) {
            push(@file_args, $arg);
        } else {
            # $! is only meaningful when the file test itself failed
            my $why = (-e $arg) ? 'not a regular file' : "$!";
            usage("$0: $arg: $why");
        }
    }
    @ARGV = @file_args;
    if (-e $opt_o && $opt_o ne $TmpImgFile) {
        warn "$0: image file '$opt_o' already exists: moving to .prev\n";
        rename($opt_o, "$opt_o.prev") ||
            die "$0: rename($opt_o, $opt_o.prev) failed: $!\n";
    }
}

#
# Show the image file with an external viewer, or, with -d, just report
# its name on STDERR.  Dies if the image file doesn't exist.
# Warns if the viewer command fails.
#
sub display($) {
    my $imgfile = $_[0];

    die "$0: display($imgfile): $! - must be a bug\n"
        unless (-e $imgfile);

    if ($opt_d) {
        # User requested no display,
        # Be friendly, give a hint that everything is OK
        printf STDERR "image file is: %s\n", $imgfile;
        return;
    }

    # Only look for a viewer when we are actually going to use one,
    # so that -d never triggers a "couldn't find an image-viewer" warning
    find_image_viewer();

    # Postscript files are generated in portrait: need 90-degrees rotation
    # 'display' supports a rotate arg, YMMV
    my $rotate = ($imgfile =~ /ps$/ && $DisplayProg =~ /display$/)
                    ? '-rotate 90'
                    : '';
    $opt_i = '' unless (defined $opt_i);

    # Quote the file name for the shell so that spaces and other special
    # characters survive.  $opt_i is left unquoted on purpose: it may hold
    # several viewer options.
    (my $quoted = $imgfile) =~ s/'/'\\''/g;
    my $status = system("$DisplayProg $opt_i $rotate '$quoted'");
    if ($status != 0) {
        warn "$0: '$DisplayProg' failed on $imgfile: ",
             ($status == -1 ? "$!" : 'exit status ' . ($status >> 8)), "\n";
    }
}


# -- main
# Parse the options, collect the loss vectors from the input, chart them
# with R, then show (or, with -d, just report) the resulting image.
get_args();
do_plot([ average_loss_arrays() ], $opt_o);
display($opt_o);

