# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the PyMVPA package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Cross-validation of a learner's performance

A learner is repeatedly trained and tested on partitions of an input
dataset that are generated by a configurable partitioning scheme.
Partitions usually consist of training and testing portions.  The learner
is trained on the training portion of the dataset, and the learner's
generalization is then tested by evaluating its predictions on the testing
portion.

A summary of the learner's performance is written to STDOUT. Depending on the
particular setup of the cross-validation analysis, either the learner's raw
predictions or summary statistics are returned in an output dataset.

If Monte-Carlo permutation testing is enabled (see --permutations) a second
output dataset with the corresponding p-values is stored as well (filename
suffix '_nullprob').

"""

# magic line for manpage summary
# man: -*- % cross-validation of a learner's performance

__docformat__ = 'restructuredtext'

import numpy as np
import copy
import sys
import argparse
from mvpa2.base import verbose, warning, error
from mvpa2.datasets import vstack
if __debug__:
    from mvpa2.base import debug
from mvpa2.cmdline.helpers \
    import parser_add_common_opt, \
           ds2hdf5, hdf2ds, learner_opt, partitioner_opt, \
           learner_space_opt, arg2errorfx, get_crossvalidation_instance, \
           crossvalidation_opts_grp


# Extra keyword arguments for this command's argument parser: the raw
# description formatter is requested, presumably so that the module
# docstring above keeps its layout in the help output (confirm against the
# code that consumes `parser_args`).
parser_args = dict(formatter_class=argparse.RawDescriptionHelpFormatter)

def setup_parser(parser):
    """Register this command's arguments on `parser`.

    Adds the common 'multidata' input option (as required), the
    cross-validation option group, and the single-required-HDF5-output
    option group.

    Parameters
    ----------
    parser
      Parser object handed to the option-registration helpers; it is
      modified in place and nothing is returned.
    """
    from .helpers import (
        parser_add_optgroup_from_def,
        parser_add_common_attr_opts,
        single_required_hdf5output,
    )

    parser_add_common_opt(parser, 'multidata', required=True)

    # Work on a deep copy so the shared option-group definition imported
    # from the helpers module is not altered for other users.
    cv_group = copy.deepcopy(crossvalidation_opts_grp)
    option_defs = cv_group[1]
    # Entries 0 and 2 of the group are flagged as required; these are meant
    # to be the learner and partitioner options (positions depend on the
    # ordering of the group definition -- re-check if that list changes).
    for position in (0, 2):
        option_kwargs = option_defs[position][1]
        option_kwargs['required'] = True

    for group_def in (cv_group, single_required_hdf5output):
        parser_add_optgroup_from_def(parser, group_def)

def run(args):
    """Run the cross-validation analysis described by `args`.

    The input dataset(s) are loaded and stacked into a single dataset, the
    cross-validation instance configured from the command line options is
    called on it, a summary is printed to STDOUT, and the result is stored
    via `ds2hdf5`. If permutation testing was requested
    (``args.permutations > 0``) the null probabilities are stored as a
    second output whose name carries the suffix '_nullprob'.

    Parameters
    ----------
    args
      Parsed command line arguments. Attributes read here: ``data``,
      ``learner``, ``partitioner``, ``errorfx``, ``sampling_repetitions``,
      ``learner_space``, ``balance_training``, ``permutations``,
      ``avg_datafold_results``, ``prob_tail``, ``output`` and
      ``hdf5_compression``. The object is not modified.

    Returns
    -------
    The result dataset returned by calling the cross-validation instance.
    """
    dss = hdf2ds(args.data)
    verbose(3, 'Loaded %i dataset(s)' % len(dss))
    ds = vstack(dss)
    # Report only the first two shape entries: formatting the complete shape
    # tuple would raise a TypeError for samples with more than two dimensions.
    verbose(3, 'Concatenation yielded %i samples with %i features'
               % tuple(ds.shape[:2]))
    # get CV instance
    cv = get_crossvalidation_instance(
            args.learner, args.partitioner, args.errorfx, args.sampling_repetitions,
            args.learner_space, args.balance_training, args.permutations,
            args.avg_datafold_results, args.prob_tail)
    res = cv(ds)
    # some meaningful output
    # XXX make condition on classification analysis only?
    print(cv.ca.stats)
    print('Results\n-------')
    with_permutations = args.permutations > 0
    if with_permutations:
        nprob = cv.ca.null_prob.samples
    if res.shape[1] == 1:
        # simple result structure: one value per fold, so print a small table
        if with_permutations:
            header_suffix = ', p-value (%s tail)' % args.prob_tail
        else:
            header_suffix = ''
        print('Fold, Result%s' % header_suffix)
        for i in range(len(res)):
            if with_permutations:
                row_suffix = ', %f' % nprob[i, 0]
            else:
                row_suffix = ''
            print('%s, %f%s'
                  % (res.sa.cvfolds[i], res.samples[i, 0], row_suffix))
    # and store
    ds2hdf5(res, args.output, compression=args.hdf5_compression)
    if with_permutations:
        # Build the companion filename in a local variable instead of
        # rewriting args.output, so the caller's namespace stays untouched.
        output_base = args.output
        if output_base.endswith('.hdf5'):
            output_base = output_base[:-len('.hdf5')]
        ds2hdf5(cv.ca.null_prob, '%s_nullprob' % output_base,
                compression=args.hdf5_compression)
    return res
