#!/usr/bin/perl -w
#
# vim: sw=4 ts=4 expandtab smarttab
#
# vw-varinfo      Summarize features of a training-set using VW
#   Input:          A vw training set file
#   Output:         A list of features, their VW hash values, min/max
#                   values, regressor weights, and distance from
#                   the best constant.
#
# Algorithm:
#   1)  Collect all variables and their ranges from training-set
#   2)  Train with VW to determine regressor weights
#   3)  Build a test-set with a single example including all variables
#   4)  run VW with --audit on 3) to map variable names to hash values
#       and weights.
#   5)  Output collected information about the input variables.
#
# NOTE: distance from the best constant predictor is not really
# variable 'importance', in the sense that it is not a reliable indicator
# of prediction performance in general.  For example, if you have two
# equal co-occurring features, the sum of the weights is what matters
# and individual weight values are arbitrary.
#
# Distance represents the relative distance of the regressor weight
# from the weight of the 'zero' which is vw's best constant prediction.
#
# Merits of this distance:
#   1) It is a relative metric, making it easier to compare two features
#   2) Easier to interpret vs an arbitrary weight (max is set to 100%)
#   3) It is signed, so it shows which features are positively vs
#      negatively correlated with the target label.
#
# (c) 2012 - ariel faigon for vowpal-wabbit
# This software may be distributed under the same terms as vowpal-wabbit
#
# use Getopt::Std;
use vars (qw($opt_v $opt_V $opt_O $opt_K));

# Name of the vw executable that is run for training and for auditing
my $VW = 'vw';
# default vw options are more optimal now, don't try to make them so
my $VWARGS = '';

# File names used by the various stages (get_args() fills in the
# train-set, model, example and audit names)
my ($TrainFile, $ModelFile, $RmodelFile, $ExampleFile, $AuditFile);

# Per-feature value range seen in the training-set: feature-name => number
my (%FeatureMax, %FeatureMin);

# Per-feature data collected from the vw --audit output
my (%Feature2Hash, %Feature2Weight);

# NOTE: this must be spelled %NameSpaces, which is the name every sub in
# this file uses.  Declaring a differently spelled lexical leaves the
# real table an undeclared package global.
my %NameSpaces;     # hash of hashes: namespace => { key => val }...
my @Features;       # feature names list

my @QPairs = ();    # list of pairs ([a, b], [c, d] ...) for namespace pairing
my %Ignore;         # support for --ignore X
my %Keep;           # support for --keep X
my $DoKeep;         # flag for whether we need to use --keep or not
my $MultiClass = 0; # flag for multi-class (--oaa --csoaa --wap* --sequence?)
my %Labels;         # multi-class: label => weight, as seen in the train-set
my @Labels = (1);   # List of labels for super example generation

my %Label2FW;       # for multi-class: every label has feature->weight
my %Prediction;     # prediction of each isolated multi-class label

my @TmpFiles;       # files removed by cleanup() unless -K is in effect

#
# v() & V():  verbose/debug prints for -v and -V
#
# v(): print to STDERR, but only when -v (verbose) is in effect
sub v(@) {
    $opt_v or return;

    # A lone argument is emitted verbatim, so a '%' in it is never taken
    # for a printf conversion.  With more arguments the first one is the
    # printf format for the rest.
    return print STDERR @_ if (@_ == 1);
    printf STDERR @_;
}

# V(): print to STDERR, but only when -V (more verbose) is in effect
sub V(@) {
    $opt_V or return;

    # A lone argument is emitted verbatim, so a '%' in it is never taken
    # for a printf conversion.  With more arguments the first one is the
    # printf format for the rest.
    return print STDERR @_ if (@_ == 1);
    printf STDERR @_;
}

#
# usage(@message)
#   Print the (optional) message on STDERR, followed by a newline,
#   then die with the usage text.  Never returns.
#
sub usage(@) {
    if (@_) {
        print STDERR @_, "\n";
    }

    # The text ends with a newline so die() doesn't append file/line info
    my $usage_text = "Usage: $0 [options] [vw-options] <training-set-file>
    Options:
        -v          verbose
        -V          more verbose
        -K          keep temporary files
        -O<which>   Use order/ranking metric <which>
                    Supported metrics:
                        ... not implemented yet ...

    vw-options:
        Note that all the above options do not clash with vw options
        All other options will be passed as-is to the vw training step.

    See the script source head comments for more details.
";
    die $usage_text;
}

#
# get_args()
#   Parse @ARGV:
#     - the last argument must be an existing training-set file
#     - -v -V -K -O (and the obsolete -P) are our own options
#     - everything else is collected into $VWARGS for the vw training step
#   Also derives @QPairs, %Keep, %Ignore, $DoKeep and $MultiClass from
#   the vw options, and picks the model/example/audit file names.
#   Calls usage() (which dies) on bad arguments.
#
sub get_args {
    $0 =~ s{.*/}{};

    # Check that there is a last argument before file-testing it
    if (@ARGV && -f $ARGV[-1]) {
        $TrainFile = pop(@ARGV);
    } else {
        usage("last command line arg must be a training-set file");
    }

    my @vw_opts_and_args = ();
    for (my $i = 0; $i < @ARGV; $i++) {
        my $arg = $ARGV[$i];

        if ($arg =~ /^-[vVKOP]+$/) {
            # These options are for us, not for vw
            $opt_v = 1 if ($arg =~ /v/);
            $opt_V = 1 if ($arg =~ /V/);
            $opt_K = 1 if ($arg =~ /K/);
            $opt_v = 1 if ($opt_V);
            if ($arg =~ /O/) {
                # The metric name is the next argument.  It belongs to
                # us, so consume it instead of forwarding it to vw.
                my $metric = $ARGV[$i + 1];
                if (defined($metric) && $metric !~ /^-/) {
                    $opt_O = $metric;
                    $i++;
                }
            }
            if ($arg =~ /P/) {
                usage("-P: option no longer supported.\n" .
                       "Please pass VW options directly.\n");
            }
            next;
        }
        if (-f $arg) {
            # An existing file in the middle of the vw args is either the
            # argument of a vw option that takes a file, or a misplaced
            # training-set, which we drop with a warning.
            my $prev_arg = @vw_opts_and_args ? $vw_opts_and_args[-1] : '';
            my $skip_ts = 0;
            if ($prev_arg =~ /^(?:-d|--data)$/) {
                # Exactly -d/--data: drop the option along with its file
                pop(@vw_opts_and_args);
                $skip_ts = 1;
            # These options have file-args following them
            } elsif ($prev_arg !~
                        /^(?:
                            -p
                            |--predictions
                            |-i
                            |--initial_regressor
                            |-f
                            |--final_regressor
                            |--feature_mask
                            |-r
                            |--raw_predictions
                            |--cache_file
                            |--pid_file
                            |--readable_model
                            |--output_feature_\S+
                        )$/x) {
                $skip_ts = 1;
            }
            if ($skip_ts) {
                warn "ignoring trainset: $arg in vw-args: train-set must be last arg\n";
                next;
            }
        }
        push(@vw_opts_and_args, $arg);
    }
    $opt_O = '' unless (defined $opt_O);
    usage("You must supply a training-set file")
        unless (defined $TrainFile);

    usage("training-set file: $TrainFile: $!")
        unless (-f $TrainFile);

    if (@vw_opts_and_args) {
        $VWARGS = "@vw_opts_and_args";
    }
    # -q XY: remember every pair of name-space initials to cross
    while ($VWARGS =~ /-q\s*(\S)(\S)/g) {
        push(@QPairs, [$1, $2]);
    }
    while ($VWARGS =~ /--keep\s*(\S)/g) {
        $DoKeep = 1;
        $Keep{$1} = 1;
    }
    $Keep{''} = 1;          # to be consistent with vw no-namespace behavior
    while ($VWARGS =~ /--ignore\s*(\S)/g) {
        $Ignore{$1} = 1;
    }
    if ($VWARGS =~ /--(?:(?:cs)?oaa|wap|ect|sequence)/) {
        if ($VWARGS =~ /--(?:wap|ect)/) {
            # Please send a patch when/if you can figure these out
            die "$0: --wap, --ect multi-class is not supported - sorry\n";
        }
        $MultiClass = 1;
    }

    # Since we need to pass '-f model', ensure it doesn't clash
    # with the user explicitly passing these
    if ($VWARGS =~ /(?:(?:^|\s)(?:-f|--final_regressor)\s+(\S+))/) {
        # User-chosen model file: not ours to delete
        $ModelFile = $1;
    } else {
        $ModelFile = "$TrainFile.model";
        push(@TmpFiles, $ModelFile);
    }

    $ExampleFile = "$TrainFile.examples";
    $AuditFile = "$TrainFile.audit";

    push(@TmpFiles, $ExampleFile, $AuditFile);
}

#
# cleanup()
#   Remove all files listed in @TmpFiles, unless -K asked to keep them
#
sub cleanup {
    unless ($opt_K) {
        # unlink results are not checked: a missing file is not an error
        unlink($_) foreach (@TmpFiles);
        return;
    }
    v("keeping temporary files: @TmpFiles\n");
    return;
}


#
# feature_name(@parts)
#   Build the symbolic full feature name (name-space + feature-name)
#   as used by --audit: all parts joined with '^'
#
sub feature_name(@) {
    my @parts = @_;
    return join('^', @parts);
}

#
# pair_features()
#   Initialize %FeatureMin and %FeatureMax (to 0) for all paired
#   name-spaces based on the @QPairs list which was constructed
#   from VW -q ... arguments.
#
#   Each -q pair is handled on its own: only name-spaces starting with
#   the pair's 1st character are crossed with name-spaces starting with
#   its 2nd character.  Collecting the matches of all pairs first and
#   crossing them afterwards would also cross name-spaces that belong
#   to different -q options.
#
#   Respect all --ignore and --keep logic while doing so.
#
sub pair_features {
    my @name_spaces = keys %NameSpaces;

    foreach my $pair_ref (@QPairs) {
        my ($x, $y) = @$pair_ref;

        # A side of the pair takes part if its character is explicitly
        # kept, or at least not ignored
        my $x_active = (exists($Keep{$x}) or !exists($Ignore{$x}));
        my $y_active = (exists($Keep{$y}) or !exists($Ignore{$y}));

        # -q matches name-spaces by their 1st character only
        my @matching_ns1 = $x_active
            ? (grep { $x eq substr($_, 0, 1) } @name_spaces)
            : ();
        my @matching_ns2 = $y_active
            ? (grep { $y eq substr($_, 0, 1) } @name_spaces)
            : ();

        foreach my $ns1 (@matching_ns1) {
            my $nsref1 = $NameSpaces{$ns1};
            foreach my $ns2 (@matching_ns2) {
                my $nsref2 = $NameSpaces{$ns2};
                foreach my $key1 (keys %$nsref1) {
                    foreach my $key2 (keys %$nsref2) {
                        my $feature = feature_name($ns1, $key1, $ns2, $key2);
                        $FeatureMax{$feature} = 0;
                        $FeatureMin{$feature} = 0;
                    }
                }
            }
        }
    }
}

#
# parse_labels($labels)
#   $labels is the part of an example preceding the 1st '|'.
#   Record every label (with its optional :weight, default 1) in the
#   global %Labels, which accumulates over calls, and return all labels
#   recorded so far in ascending numeric order.
#
sub parse_labels($) {
    my ($label_text) = @_;

    $label_text =~ s/\s+\S+$//; # trim optional tag (touching the '|')

    while ($label_text =~ /([^:\s]+):?(\S+)?/g) {
        # match labels and optional weights
        my ($label, $weight) = ($1, $2);
        $weight = 1 unless (defined $weight);
        $Labels{$label} = $weight;
    }

    return sort { $a <=> $b } keys %Labels;
}

#
# read_features($trainingset_file)
#   Read the training set & parse it, collect all name-spaces,
#   feature-names, min/max values
#
#   A file whose name ends in .gz is read through 'gunzip -c'.
#   Fills %NameSpaces, %FeatureMin and %FeatureMax (and @Labels in the
#   multi-class case), then adds the -q crossed features and the
#   'Constant' feature.  Dies on an unreadable file or on an example
#   without a '|'.
#
sub read_features($) {
    my ($trainset) = @_;

    # Explicit open modes + list-form pipe: the file name is never
    # interpreted as an open mode or by a shell, so names with
    # whitespace or shell meta-characters are safe.
    my $ts;
    if ($trainset =~ /\.gz$/) {
        open($ts, '-|', 'gunzip', '-c', $trainset)
            || die "$0: gunzip -c $trainset|: $!\n";
    } else {
        open($ts, '<', $trainset) || die "$0: $trainset: $!\n";
    }
    while (my $line = <$ts>) {
        # -- examples loop
        next unless ($line =~ /\S/);    # skip empty lines

        # -- grab anything following the 1st '|'
        my ($labels, $input_features) = ($line =~ /^([^|]*)\|(.*)$/);
        die "$0: $trainset line $.: malformed example: missing '|'\n"
            unless (defined $input_features);

        if ($MultiClass) {
            @Labels = parse_labels($labels);
        }

        my @name_space_region = split(/\|/, $input_features);
        foreach my $nsr (@name_space_region) {
            # -- name-spaces loop (note: name-space may be ''):
            #    extract the name-space string, ignore (optional) :weight
            my ($ns) = ($nsr =~ /^([^:\s]+)(?:\:\S+)?/);
            $ns = '' unless ((defined $ns) && length($ns));

            # --ignore and --keep select name-spaces by 1st character
            my $ns_ch1 = substr($ns, 0, 1);

            next if (exists $Ignore{$ns_ch1});
            next if ($DoKeep && !exists $Keep{$ns_ch1});

            unless (exists $NameSpaces{$ns}) {
                $NameSpaces{$ns} = {};
            }
            my $nsref = $NameSpaces{$ns};

            # Trim (the optionally empty) name-space prefix,
            # including the optional :weight
            $nsr =~ s/^\Q$ns\E\S*\s*//;

            # Following the name-space: loop over feature+value pairs:
            foreach my $keyval (split(/\s+/, $nsr)) {
                # -- features loop
                # explicit key:value (split at the 1st ':')
                # or implicit value == 1
                my ($key, $val) = ($keyval =~ /:/)
                                    ? split(/:/, $keyval, 2)
                                    : ($keyval, 1);
                $nsref->{$key} = $val;

                my $f = feature_name($ns, $key);

                # -- record min, max per feature
                #    (both start at 0, so min <= 0 <= max always holds)
                unless (exists $FeatureMax{$f}) {
                    $FeatureMax{$f} = 0;
                    $FeatureMin{$f} = 0;
                }
                if ($FeatureMax{$f} < $val) {
                    $FeatureMax{$f} = $val;
                }
                if ($FeatureMin{$f} > $val) {
                    $FeatureMin{$f} = $val;
                }
            }
        }
    }
    close $ts;

    # Add the -q pairs of features
    pair_features();

    # Add the Constant feature
    $FeatureMin{'Constant'} = $FeatureMax{'Constant'} = 0;
}

#
# do_train($trainset, $model)
#   Run the vw training stage which computes the per feature weights.
#   '-d $trainset' and '-f $model' are appended to the command only
#   when the vw arguments don't already carry them.
#   Dies if the command exits with a non-zero status.
#
sub do_train($$) {
    my ($trainset, $model) = @_;

    my $cmd = "$VW --quiet $VWARGS";
    $cmd .= " -d $trainset" if ($cmd !~ /(?:\s(?:-d|--data)\s)/);
    $cmd .= " -f $model"    if ($cmd !~ /(?:\s(?:-f|--final_regressor))/);

    # In verbose mode let vw report its progress too
    $cmd =~ s/ --quiet / / if ($opt_v);

    v("training: %s\n", $cmd);
    system($cmd);
    if ($? != 0) {
        die "$0: vw training failed (see details above)\n";
    }
}

#
# generate_one_example($fd, $label)
#   Write to $fd one example line which has the given label and
#   includes every feature of every collected name-space, each with
#   a value of 1.
#
sub generate_one_example($$) {
    my ($fd, $label) = @_;

    # Multi-class labels are written as 'label:1' (the other labels are
    # not listed); in the simple case the label is written as is.
    print {$fd} ($MultiClass ? sprintf("%s:1", $label) : $label);

    # All possible input features: ' |namespace key:1 key:1 ...'
    foreach my $ns (keys %NameSpaces) {
        my @key_values = map { sprintf('%s:%s', $_, 1) }
                             sort keys %{ $NameSpaces{$ns} };
        print {$fd} join(' ', " |$ns", @key_values);
    }
    print {$fd} "\n";
}

#
# generate_examples($example_file)
#   Write the all-features example file: one example line per label
#   in @Labels.  Dies if the file can't be created or fully written.
#
sub generate_examples($) {
    my ($example_file) = shift;

    # Explicit '>' mode: the file name is taken literally, whatever
    # characters it contains
    open(my $fd, '>', $example_file) ||
        die "$0: can't write full_example file: '$example_file': $!\n";

    v("Labels: @Labels\n");
    foreach my $label (@Labels) {
        # One line per label:
        # multiclass deprecates to singleton: label=1
        generate_one_example($fd, $label);
    }

    # Buffered write errors (e.g. a full disk) only show up at close
    close($fd) ||
        die "$0: error closing full_example file: '$example_file': $!\n";
}

# State which persists over successive audit_one_example() calls
my %SeenFeatureNames;       # every feature name met in the audit output
my $MCLabel;                # multi-class: label of the current example
my $MCLabelIndex = -1;      # multi-class: index of $MCLabel in @Labels

#
# audit_one_example($audit_stream)
#   Consume one audited example (two lines: the prediction, then the
#   audited features) from $audit_stream and record for each feature
#   its hash value and weight.
#   In the multi-class case the examples are matched to @Labels in
#   order (cycling), and weights + prediction are stored per label.
#
sub audit_one_example($) {
    my ($audit_stream) = @_;

    # 1st line is the prediction, the audit data is on the 2nd one
    my $prediction    = <$audit_stream>;
    my $features_data = <$audit_stream>;

    my $weight_href;
    if ($MultiClass) {
        # Advance to the next label, wrapping around at the end
        $MCLabelIndex++;
        $MCLabelIndex = 0 if ($MCLabelIndex >= @Labels);
        $MCLabel = $Labels[$MCLabelIndex];

        $weight_href = {};
        $Label2FW{$MCLabel} = $weight_href;

        chomp($prediction);
        $Prediction{$MCLabel} = $prediction;
    }

    chomp($features_data);

    foreach my $audited_item (split(' ', $features_data)) {
        next if (!$audited_item);

        # Audited feature format:   namespace^varname:142703:1:0.0435613 ...
        # Only the last 4 ':' separated fields are used
        my @fields = split(':', $audited_item);
        my ($feature, $hashval, $value, $weight) = @fields[-4 .. -1];

        # Trim '@0' and similar suffixes.
        $weight =~ s/@.*$//;

        # Multi-class: a feature without a name gets a per-label
        # constant name
        if (!$feature && $MultiClass) {
            $feature = "Constant_$MCLabel";
            $FeatureMax{$feature} = 0;
            $FeatureMin{$feature} = 0;
        }

        $SeenFeatureNames{$feature} = 1;
        $Feature2Hash{$feature} = $hashval;

        # Weights are per label in the multi-class case, global otherwise
        my $target_href = $MultiClass ? $weight_href : \%Feature2Weight;
        $target_href->{$feature} = $weight;

        V("%s\t%s\t%s\t%s\n", $feature, $hashval, $value, $weight);
    }
}

#
# audit_features()
#   Generate the all-features example file, run vw -t --audit on it
#   with the trained model, and read the output to extract hash
#   values and weights.
#   With -K the audit output is also saved (via tee) in $AuditFile.
#   Return the sorted list of all feature-names seen in the audit.
#
sub audit_features {
    generate_examples($ExampleFile);

    # Bug in vw multiclass, looks like we need to pass the multiclass
    # params to --audit even though they should be in the model
    my $vw_mcargs = '';
    if (${VWARGS} =~ /--(?:(?:cs)?oaa|wap|sequence)(?:_ldf)?\s+\d+/) {
        # Passing the multi-class args along is currently disabled:
#        $vw_mcargs = $&;
    }

    my $audit_cmd = join(' ', $VW, $vw_mcargs,
                         "--quiet -t --audit -i $ModelFile -d $ExampleFile");
    if ($opt_K) {
        $audit_cmd .= "|tee $AuditFile";
    }

    open(my $audit_stream, "$audit_cmd |") ||
        die "$0: can't run \$audit_cmd: '$audit_cmd |'\n";

    # Every call consumes one audited example from the stream
    until (eof($audit_stream)) {
        audit_one_example($audit_stream);
    }

    close $audit_stream;

    # Return the list of features actually seen in the audit
    my @seen_features = sort keys %SeenFeatureNames;
    return @seen_features;
}


#
# score($class_label, $feature [, $metric])
#   Return the 'score' metric for a given feature: currently its
#   weight, looked up per class label in the multi-class case.
#   $metric is accepted but not used yet.
#   Warn and return undef when the feature has no known weight.
#
sub score($$;$) {
    my ($class_label, $feature, $metric) = @_;

    my $f2w_hashref = \%Feature2Weight;
    $f2w_hashref = $Label2FW{$class_label} if ($MultiClass);

    my $fweight = $f2w_hashref->{$feature};

    # Support more metrics, are there any others that make sense?
    # if ($metric eq '...') ...
    if (defined $fweight) {
        return $fweight;
    }

    warn "$0: BUG?: score($class_label, $feature): undef weight!\n";
    return undef;
}

#
# feature_flen_min_max($w_href, @features)
#   Find the maximum feature-name length (at least 10) and the
#   min/max values in $w_href (feature => value) over @features.
#   Min and max both start at 0.  Features without a value are skipped.
#   Returns ($max_flen, $min_weight, $max_weight)
#
sub feature_flen_min_max($@) {
    my ($w_href, @feature_names) = @_;

    my $max_flen   = 10;
    my $min_weight = 0;
    my $max_weight = 0;

    for my $name (@feature_names) {
        my $w = $w_href->{$name};
        if (!defined $w) {
            # Should already be caught in score() above,
            # so warn only in verbose mode
            v("%s: feature_flen_min_max: %s: undefined weight\n", $0, $name);
            next;
        }

        if (length($name) > $max_flen) {
            $max_flen = length($name);
        }
        if ($w > $max_weight) {
            $max_weight = $w;
        }
        if ($w < $min_weight) {
            $min_weight = $w;
        }
    }

    return ($max_flen, $min_weight, $max_weight);
}

#
# feature_min_max_score($class_label, @features)
#   Compute the score of every feature and find the min/max score
#   (both start at 0).  Features named Constant* get a score of 0 and
#   features without a score are left out.
#   Returns ($min_score, $max_score, $score_href) where $score_href
#   maps feature => score.
#
sub feature_min_max_score($@) {
    my ($class_label, @feature_names) = @_;

    my %score_of;           # feature->score
    my $lowest  = 0;
    my $highest = 0;

    for my $name (@feature_names) {     # features loop
        if ($name =~ /^Constant/) {
            # The constant is the reference point: zero by definition
            $score_of{$name} = 0;
            next;
        }

        my $score = score($class_label, $name, $opt_O);
        next if (!defined $score);
        $score_of{$name} = $score;

        if ($score > $highest) {
            $highest = $score;
        }
        if ($score < $lowest) {
            $lowest = $score;
        }
    }

    return ($lowest, $highest, \%score_of);
}

#
# print_feature_report($max_flen, $max_distance_from_0,
#                      $class_label, $score_href, @features)
#   Print a header line followed by one line per feature, ordered by
#   descending score: name, hash value, min/max value, weight and the
#   score relative to $max_distance_from_0 (as a percentage).
#   In the multi-class case the report ends with an empty line.
#
sub print_feature_report($$$$@) {
    my ($max_flen, $max_distance_from_0,
        $class_label, $score_href, @features) = @_;

    my $weight_href = \%Feature2Weight;
    $weight_href = $Label2FW{$class_label} if ($MultiClass);

    printf "%-*s\t%+10s %8s %8s %+9s %+10s\n", $max_flen,
           'FeatureName', 'HashVal', 'MinVal', 'MaxVal', 'Weight', 'RelScore';

    # Avoid a division by zero when all scores are 0
    $max_distance_from_0 = 1e-10 if ($max_distance_from_0 == 0);

    # TODO? support different orders:
    # by weight, by feature-range-normalized weights?
    my @by_descending_score =
        sort { $score_href->{$b} <=> $score_href->{$a} } @features;

    foreach my $f (@by_descending_score) {
        my $normalized_score = $score_href->{$f} / $max_distance_from_0;

        # FIXME: support different normalization schemes
        $normalized_score = abs($normalized_score) if ($opt_O =~ /^a/i);

        # On the fly generated (e.g. -q crossed features) don't have a min/max
        $FeatureMin{$f} = 0 if (!exists $FeatureMin{$f});
        $FeatureMax{$f} = 0 if (!exists $FeatureMax{$f});

        printf "%-*s\t%10u %8.2f %8.2f %+9.4f %9.2f%%\n",
               $max_flen, $f, $Feature2Hash{$f},
               $FeatureMin{$f}, $FeatureMax{$f},
               $weight_href->{$f}, (100.0 * $normalized_score);
    }
    print "\n" if ($MultiClass);
}


#
# summarize_features
#   Output what we know about all features + relative score:
#   one feature report per label in @Labels.
#
sub summarize_features {
    # Per-class loop for multi-class,
    # only one loop for non multi-class
    foreach my $label (@Labels) {
        # Multi-class: each label has its own set of audited features
        my @features = @Features;
        @features = keys %{$Label2FW{$label}} if ($MultiClass);

        my ($min_score, $max_score, $score_href) =
                            feature_min_max_score($label, @features);

        # Only the name length is needed here (for column alignment)
        my ($max_flen) = feature_flen_min_max($score_href, @features);

        # Scores are reported relative to the largest absolute score
        my $max_distance_from_0 = abs($max_score);
        if (abs($min_score) > $max_distance_from_0) {
            $max_distance_from_0 = abs($min_score);
        }

        if ($MultiClass) {
            printf "=== Class Label: %s\tPrediction: %s\n",
                    $label, $Prediction{$label};
        }

        print_feature_report($max_flen, $max_distance_from_0,
                             $label, $score_href, @features);
    }
}

# -- main
#    The stages below depend on state set by the previous ones,
#    so their order matters.
get_args();                         # parse options, pick the file names
read_features($TrainFile);          # collect name-spaces, features, min/max
do_train($TrainFile, $ModelFile);   # run vw to compute the weights
@Features = audit_features();       # run vw --audit: feature hashes + weights
summarize_features();               # print the report(s)
cleanup();                          # remove temporary files (unless -K)

