## Copyright (C) 2025 Andreas Bertsatos <abertsatos@biol.uoa.gr>
## Copyright (C) 2025 Avanish Salunke <avanishsalunke16@gmail.com>
##
## This file is part of the statistics package for GNU Octave.
##
## This program is free software; you can redistribute it and/or modify it under
## the terms of the GNU General Public License as published by the Free Software
## Foundation; either version 3 of the License, or (at your option) any later
## version.
##
## This program is distributed in the hope that it will be useful, but WITHOUT
## ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
## FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
## details.
##
## You should have received a copy of the GNU General Public License along with
## this program; if not, see <http://www.gnu.org/licenses/>.

classdef cvpartition
  ## -*- texinfo -*-
  ## @deftp {statistics} cvpartition
  ##
  ## Partition data for cross-validation
  ##
  ## The @code{cvpartition} class generates a partitioning scheme on a dataset
  ## to facilitate cross-validation of statistical models utilizing training and
  ## testing subsets of the dataset.
  ##
  ## @seealso{crossval}
  ## @end deftp

  properties (GetAccess = public, SetAccess = private)
    ## -*- texinfo -*-
    ## @deftp {cvpartition} {property} NumObservations
    ##
    ## Number of observations
    ##
    ## A positive integer scalar specifying the number of observations in the
    ## dataset (including any missing data, where applicable).  This property
    ## is read-only.
    ##
    ## @end deftp
    NumObservations = [];

    ## -*- texinfo -*-
    ## @deftp {cvpartition} {property} NumTestSets
    ##
    ## Number of test sets
    ##
    ## A positive integer scalar specifying the number of folds for partition
    ## types @qcode{'kfold'} and @qcode{'leaveout'}.  When the partition type
    ## is @qcode{'holdout'} or @qcode{'resubstitution'}, @qcode{NumTestSets}
    ## is 1.  This property is read-only.
    ##
    ## @end deftp
    NumTestSets     = [];

    ## -*- texinfo -*-
    ## @deftp {cvpartition} {property} TrainSize
    ##
    ## Size of each train set
    ##
    ## A positive integer scalar specifying the size of the train set for
    ## partition types @qcode{'holdout'} and @qcode{'resubstitution'} or a
    ## vector of positive integers specifying the size of each training set for
    ## partition types @qcode{'kfold'} and @qcode{'leaveout'}.  This property
    ## is read-only.
    ##
    ## @end deftp
    TrainSize       = [];

    ## -*- texinfo -*-
    ## @deftp {cvpartition} {property} TestSize
    ##
    ## Size of each test set
    ##
    ## A positive integer scalar specifying the size of the test set for
    ## partition types @qcode{'holdout'} and @qcode{'resubstitution'} or a
    ## vector of positive integers specifying the size of each testing set for
    ## partition types @qcode{'kfold'} and @qcode{'leaveout'}.  This property
    ## is read-only.
    ##
    ## @end deftp
    TestSize        = [];

    ## -*- texinfo -*-
    ## @deftp {cvpartition} {property} Type
    ##
    ## Type of validation partition
    ##
    ## A character vector specifying the type of the @qcode{cvpartition} object.
    ## It can be @qcode{kfold}, @qcode{holdout}, @qcode{leaveout}, or
    ## @qcode{resubstitution}.  This property is read-only.
    ##
    ## @end deftp
    Type            = '';

    ## -*- texinfo -*-
    ## @deftp {cvpartition} {property} IsCustom
    ##
    ## Flag for custom partition
    ##
    ## A logical scalar specifying whether the @qcode{cvpartition} object
    ## was created using a custom partition (@qcode{true}) or not
    ## (@qcode{false}).  This property is read-only.
    ##
    ## @end deftp
    IsCustom        = [];

    ## -*- texinfo -*-
    ## @deftp {cvpartition} {property} IsGrouped
    ##
    ## Flag for grouped partition
    ##
    ## A logical scalar specifying whether the @qcode{cvpartition} object was
    ## created using grouping variables (@qcode{true}) or not (@qcode{false}).
    ## This property is read-only.
    ##
    ## @end deftp
    IsGrouped       = [];

    ## -*- texinfo -*-
    ## @deftp {cvpartition} {property} IsStratified
    ##
    ## Flag for stratified partition
    ##
    ## A logical scalar specifying whether the @qcode{cvpartition} object was
    ## created with a @qcode{'stratify'} value of @qcode{true}.
    ## This property is read-only.
    ##
    ## @end deftp
    IsStratified    = [];

  endproperties

  properties (Access = private, Hidden)
    ## Logical mask of the missing entries (as reported by ismissing) that
    ## were removed from X, or from the grouping variables, before partitioning
    missidx = [];
    ## Test-set assignment of the observations (when set): fold number of each
    ## observation for k-fold partitions, or a logical test-set mask for
    ## holdout partitions
    indices = [];
    ## Human-readable description of the partition type, printed by disp
    cvptype = '';
    ## Class index of each nonmissing observation in X (third output of
    ## unique); only set when the partition is stratified
    classes = [];
    ## Unique class values found in X (first output of unique); only set
    ## when the partition is stratified
    classID = [];
    ## Grouping variables (with missing entries removed) used to build a
    ## group k-fold partition
    grpvars = [];
  endproperties

  methods (Hidden)

    ## Custom display
    function display (this)
      ## Show the object as "NAME =" followed by its summary when it is bound
      ## to a named variable in the caller; otherwise show only the summary
      ## produced by disp.
      var_name = inputname (1);
      if (isempty (var_name))
        disp (this);
      else
        fprintf ('%s =\n', var_name);
        disp (this);
      endif
    endfunction

    ## Custom display
    function disp (this)
      ## Print a summary of the partition: its description line followed by
      ## one right-aligned "Name: value" row per public property.  For the
      ## TrainSize and TestSize vectors, at most the first 10 elements are
      ## shown; longer vectors are truncated with a trailing " ... ]".
      fprintf ("\n%s\n", this.cvptype);
      for name = {'NumObservations', 'NumTestSets'}
        fprintf ("%+25s: %d\n", name{1}, this.(name{1}));
      endfor

      ## Choose which elements to show and how to close the bracket
      n_elems = numel (this.TrainSize);
      if (n_elems > 10)
        n_shown = 10;
        train_vals = this.TrainSize(1:n_shown);
        test_vals = this.TestSize(1:n_shown);
        closing = ' ... ]';
      else
        n_shown = n_elems;
        train_vals = this.TrainSize;
        test_vals = this.TestSize;
        closing = ']';
      endif
      fmt = ['[', strjoin(repmat ({"%d"}, 1, n_shown), ' '), closing];

      fprintf ("%+25s: %s\n", 'TrainSize', sprintf (fmt, train_vals));
      fprintf ("%+25s: %s\n", 'TestSize', sprintf (fmt, test_vals));
      fprintf ("%+25s: %d\n", 'IsCustom', this.IsCustom);
      fprintf ("%+25s: %d\n", 'IsGrouped', this.IsGrouped);
      fprintf ("%+25s: %d\n\n", 'IsStratified', this.IsStratified);
    endfunction

    ## Class specific subscripted reference
    function varargout = subsref (this, s)
      ## Class specific subscripted reference.
      ##
      ## Only '.' indexing is supported at the first level: it returns the
      ## value of the named property.  '()' and '{}' indexing raise an error,
      ## as does a '.' subscript whose argument is not a character vector.
      ## Any remaining subscript levels in S (e.g. C.TrainSize(2)) are applied
      ## to the returned value.
      chain_s = s(2:end);
      s = s(1);
      t = "Invalid %s indexing for referencing values in a cvpartition object.";
      switch (s.type)
        case '()'
          error (t, '()');
        case '{}'
          error (t, '{}');
        case '.'
          if (! ischar (s.subs))
            error (strcat ("cvpartition.subsref: '.' indexing", ...
                           " argument must be a character vector."));
          endif
          try
            out = this.(s.subs);
          catch
            ## Any failure to resolve the name is reported as an unknown
            ## property; the prefix names this method ("subsref") so it is
            ## consistent with the other messages raised here.
            error ("cvpartition.subsref: unrecognized property: '%s'", s.subs);
          end_try_catch
      endswitch
      ## Apply chained references to the extracted value
      if (! isempty (chain_s))
        out = subsref (out, chain_s);
      endif
      varargout{1} = out;
    endfunction

    ## Class specific subscripted assignment
    function this = subsasgn (this, s, val)
      ## Class specific subscripted assignment.
      ##
      ## Every assignment is rejected because all public properties of a
      ## cvpartition object are read-only.  The error raised depends on S:
      ##   more than one subscript level -> chained subscripts error
      ##   '()' or '{}' indexing         -> invalid indexing error
      ##   '.' with a non-char argument  -> invalid '.' argument error
      ##   '.' with a property name      -> unrecognized/read-only error
      if (numel (s) > 1)
        error (strcat ("cvpartition.subsasgn:", ...
                       " chained subscripts not allowed."));
      endif
      idx_type = s.type;
      if (strcmp (idx_type, '()') || strcmp (idx_type, '{}'))
        msg = "Invalid %s indexing for assigning values to a cvpartition object.";
        error (msg, idx_type);
      elseif (strcmp (idx_type, '.'))
        if (! ischar (s.subs))
          error (strcat ("cvpartition.subsasgn: '.' indexing", ...
                         " argument must be a character vector."));
        endif
        error (strcat ("cvpartition.subsasgn: unrecognized", ...
                       " or read-only property: '%s'"), s.subs);
      endif
    endfunction

  endmethods

  methods (Access = public)

    ## -*- texinfo -*-
    ## @deftypefn  {cvpartition} {@var{C} =} cvpartition (@var{n}, @qcode{'KFold'})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{n}, @qcode{'KFold'}, @var{k})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{n}, @qcode{'KFold'}, @var{k}, @qcode{'GroupingVariables'}, @var{grpvars})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{n}, @qcode{'Holdout'})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{n}, @qcode{'Holdout'}, @var{p})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{n}, @qcode{'Leaveout'})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{n}, @qcode{'Resubstitution'})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{X}, @qcode{'KFold'})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{X}, @qcode{'KFold'}, @var{k})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{X}, @qcode{'KFold'}, @var{k}, @qcode{'Stratify'}, @var{opt})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{X}, @qcode{'Holdout'})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{X}, @qcode{'Holdout'}, @var{p})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@var{X}, @qcode{'Holdout'}, @var{p}, @qcode{'Stratify'}, @var{opt})
    ## @deftypefnx {cvpartition} {@var{C} =} cvpartition (@qcode{'CustomPartition'}, @var{testSets})
    ##
    ## Partition data for cross-validation.
    ##
    ## @code{@var{C} = cvpartition (@var{n}, @qcode{'KFold'})} creates a
    ## @qcode{cvpartition} object @var{C}, which defines a random nonstratified
    ## partition for k-fold cross-validation on @var{n} observations with each
    ## fold (subsample) having approximately the same number of observations.
    ## The default number of folds is 10 for @code{@var{n} >= 10} or equal to
    ## @var{n} otherwise.
    ##
    ## @code{@var{C} = cvpartition (@var{n}, @qcode{'KFold'}, @var{k})} also
    ## creates a nonstratified random partition for k-fold cross-validation with
    ## the number of folds defined by @var{k}, which must be a positive integer
    ## scalar not greater than the number of observations @var{n}.
    ##
    ## @code{@var{C} = cvpartition (@var{n}, @qcode{'KFold'}, @var{k},
    ## @qcode{'GroupingVariables'}, @var{grpvars})} creates a @qcode{cvpartition}
    ## object @var{C} that defines a random partition for k-fold cross-validation
    ## with each fold containing the same combination of group labels as defined
    ## by @var{grpvars}.  The grouping variables specified in @var{grpvars} can
    ## be one of the following:
    ##
    ## @itemize
    ## @item A numeric vector, logical vector, categorical vector, character
    ## array, string array, or cell array of character vectors containing one
    ## grouping variable.
    ## @item A numeric matrix or cell array containing two or more grouping
    ## variables. Each column in the matrix or array must correspond to one
    ## grouping variable.
    ## @end itemize
    ##
    ## @code{@var{C} = cvpartition (@var{n}, @qcode{'Holdout'})} creates a
    ## @qcode{cvpartition} object @var{C}, which defines a random nonstratified
    ## partition for holdout validation on @var{n} observations.  90% of the
    ## observations are assigned to the training set and the remaining 10% to
    ## the test set.
    ##
    ## @code{@var{C} = cvpartition (@var{n}, @qcode{'Holdout'}, @var{p})} also
    ## creates a nonstratified random partition for holdout validation with the
    ## percentage of training and test sets defined by @var{p}, which can be a
    ## scalar value in the range @math{(0,1)} or a positive integer scalar in
    ## the range @math{[1,@var{n})}.
    ##
    ## @code{@var{C} = cvpartition (@var{n}, @qcode{'Leaveout'})} creates a
    ## @qcode{cvpartition} object @var{C}, which defines a random partition for
    ## leave-one-out cross-validation on @var{n} observations.  This is a
    ## special case of k-fold cross-validation with the number of folds equal to
    ## the number of observations.
    ##
    ## @code{@var{C} = cvpartition (@var{n}, @qcode{'Resubstitution'})} creates
    ## a @qcode{cvpartition} object @var{C} without partitioning the data and
    ## both training and test sets containing all observations @var{n}.
    ##
    ## @code{@var{C} = cvpartition (@var{X}, @qcode{'KFold'})} creates a
    ## @qcode{cvpartition} object @var{C}, which defines a stratified random
    ## partition for k-fold cross-validation according to the class proportions
    ## in @var{X}.  @var{X} can be a numeric, logical, categorical, or string
    ## vector, or a character array or a cell array of character vectors.
    ## Missing values in @var{X} are discarded.  The default number of folds is
    ## 10 for @code{numel (@var{X}) >= 10} or equal to @code{numel (@var{X})}
    ## otherwise.
    ##
    ## @code{@var{C} = cvpartition (@var{X}, @qcode{'KFold'}, @var{k})} also
    ## creates a stratified random partition for k-fold cross-validation with
    ## the number of folds defined by @var{k}, which must be a positive integer
    ## scalar not greater than the number of observations in @var{X}.
    ##
    ## @code{@var{C} = cvpartition (@var{X}, @qcode{'KFold'}, @var{k},
    ## @qcode{'Stratify'}, @var{opt})} creates a random partition for k-fold
    ## cross-validation, which is stratified if @var{opt} is @qcode{true}, or
    ## nonstratified if @var{opt} is @qcode{false}.
    ##
    ## @code{@var{C} = cvpartition (@var{X}, @qcode{'Holdout'})} creates a
    ## @qcode{cvpartition} object @var{C}, which defines a stratified random
    ## partition for holdout validation while maintaining the class proportions
    ## in @var{X}.  90% of the observations are assigned to the training set and
    ## the remaining 10% to the test set.
    ##
    ## @code{@var{C} = cvpartition (@var{X}, @qcode{'Holdout'}, @var{p})} also
    ## creates a stratified random partition for holdout validation with the
    ## percentage of training and test sets defined by @var{p}, which can be a
    ## scalar value in the range @math{(0,1)} or a positive integer scalar in
    ## the range @math{[1,@var{n})}.
    ##
    ## @code{@var{C} = cvpartition (@var{X}, @qcode{'Holdout'}, @var{p},
    ## @qcode{'Stratify'}, @var{opt})} creates a random partition for holdout
    ## validation, which is stratified if @var{opt} is @qcode{true}, or
    ## nonstratified if @var{opt} is @qcode{false}.
    ##
    ## @code{@var{C} = cvpartition (@qcode{'CustomPartition'}, @var{testSets})}
    ## creates a custom partition according to @var{testSets}, which can be a
    ## positive integer vector, a logical vector, or a logical matrix according
    ## to the following options:
    ## @itemize
    ## @item A positive integer vector of length @math{N} with values in the
    ## range @math{[1,K]}, where @math{K < N}, will specify a K-fold
    ## cross-validation partition, in which each value indicates the test set
    ## of each observation.  Alternatively, the same vector with values in the
    ## range @math{[1,N]} will specify a leave-one-out cross-validation.
    ## @item A logical vector will specify a holdout validation, in which the
    ## @qcode{true} elements correspond to the test set and the @qcode{false}
    ## elements correspond to the training set.
    ## @item A logical matrix with @math{K} columns will specify a K-fold
    ## cross-validation partition, in which each column corresponds to a fold
    ## and each row to an observation.  Alternatively, an @math{NxN} logical
    ## matrix will specify a leave-one-out cross-validation, where @math{N} is
    ## the number of observations.  @qcode{true} elements correspond to the
    ## test set and the @qcode{false} elements correspond to the training set.
    ## @end itemize
    ##
    ## @seealso{cvpartition, summary, test, training}
    ## @end deftypefn

    function this = cvpartition (X, varargin)

      ## Check for appropriate number of input arguments
      if (nargin < 2)
        error ("cvpartition: too few input arguments.");
      endif
      if (nargin > 5)
        error ("cvpartition: too many input arguments.");
      endif

      ## Check for custom partition
      if (strcmpi (X, "CustomPartition"))
        testSets = varargin{1};
        ## Check for valid test set
        if (! (isnumeric (testSets) || islogical (testSets)))
          error ("cvpartition: TESTSETS must be numeric of logical.");
        endif
        if (isnumeric (testSets))
          if (! isvector (testSets))
            error ("cvpartition: TESTSETS must be a numeric vector.");
          endif
          [~, idx, inds] = unique (testSets);
          this.NumObservations = numel (testSets);
          this.NumTestSets = numel (idx);
          nvec = this.NumObservations * ones (1, this.NumTestSets);
          if (this.NumTestSets < this.NumObservations)
            this.indices = inds;
            for i = 1:this.NumTestSets
              this.TestSize(i) = sum (inds == i);
            endfor
            this.TrainSize = nvec - this.TestSize;
            this.Type = 'kfold';
            this.cvptype = 'K-fold cross validation partition';
          else
            this.TrainSize = nvec - 1;
            this.TestSize = nvec - this.TrainSize;
            this.Type = 'leaveout';
            this.cvptype = 'Leave-one-out cross validation partition';
          endif
        else  # logical vector of matrix
          if (! ismatrix (testSets))
            error ("cvpartition: TESTSETS must be a logical vector or matrix.");
          elseif (isvector (testSets))
            this.NumObservations = numel (testSets);
            this.NumTestSets = 1;
            this.indices = testSets;
            this.TrainSize = sum (! testSets);
            this.TestSize = sum (testSets);
            this.Type = 'holdout';
            this.cvptype = 'Hold-out cross validation partition';
          else  # logical matrix
            ## Each observation must be present in exactly one test set
            if (any (sum (testSets, 2) > 1))
              error (strcat ("cvpartition: each observation in TESTSETS", ...
                             " must be exactly one in each row."));
            endif
            [this.NumObservations, this.NumTestSets] = size (testSets);
            nvec = this.NumObservations * ones (1, this.NumTestSets);
            if (this.NumTestSets < this.NumObservations)
              this.indices = zeros (this.NumObservations, 1);
              for i = 1:this.NumTestSets
                this.TestSize(i) = sum (testSets(:,i));
                this.indices(testSets(:,i)) = i;
              endfor
              this.TrainSize = nvec - this.TestSize;
              this.Type = 'kfold';
              this.cvptype = 'K-fold cross validation partition';
            elseif (this.NumTestSets == this.NumObservations)
              this.TrainSize = nvec - 1;
              this.TestSize = nvec - this.TrainSize;
              this.Type = 'leaveout';
              this.cvptype = 'Leave-one-out cross validation partition';
            else
              error (strcat ("cvpartition: a logical matrix in TESTSETS", ...
                             " must not have more columns that rows."));
            endif
          endif
        endif
        this.IsCustom = true;
        this.IsGrouped = false;
        this.IsStratified = false;

      ## Check first input being a scalar value
      elseif (isscalar (X))
        if (! (isnumeric (X) && X > 0 && fix (X) == X))
          error ("cvpartition: X must be a scalar positive integer value.");
        endif
        ## Get number of observations and partition type
        this.NumObservations = X;
        type = varargin{1};
        this.IsCustom = false;
        this.IsStratified = false;

        ## "Resubstitution"
        if (strcmpi (type, 'resubstitution'))
          this.NumTestSets = 1;
          this.TrainSize = X;
          this.TestSize = X;
          this.Type = 'resubstitution';
          this.cvptype = 'Resubstitution (no partition of data)';
          this.IsGrouped = false;

        ## "Leaveout"
        elseif (strcmpi (type, 'leaveout'))
          this.NumTestSets = X;
          this.TrainSize = (X - 1) * ones (1, X);
          this.TestSize = ones (1, X);
          this.Type = 'leaveout';
          this.cvptype = 'Leave-one-out cross validation partition';
          this.IsGrouped = false;

        ## "Holdout"
        elseif (strcmpi (type, 'holdout'))
          if (nargin > 2)
            p = varargin{2};
            if (! isnumeric (p) || ! isscalar (p))
              error (strcat ("cvpartition: P value for 'holdout'", ...
                             " must be a numeric scalar."));
            endif
            if (! ((p > 0 && p < 1) || (p == fix (p) && p > 0 && p < X)))
              error (strcat ("cvpartition: P value for 'holdout' must be", ...
                             " a scalar in the range (0,1) or an integer", ...
                             " scalar in the range [1, N)."));
            endif
          else
            p = 0.1;
          endif
          this.NumTestSets = 1;
          if (p < 1)            # target fraction to sample
            p = round (p * X);  # number of samples
          endif
          inds = false (X, 1);
          inds(randsample (X, p)) = true;  # indices for test set
          this.indices = inds;
          this.TrainSize = sum (! inds);
          this.TestSize = sum (inds);
          this.Type = 'holdout';
          this.cvptype = 'Hold-out cross validation partition';
          this.IsGrouped = false;

        ## "KFold"
        elseif (strcmpi (type, 'kfold'))
          this.Type = 'kfold';
          if (nargin > 2)
            k = varargin{2};
            if (! isnumeric (k) || ! isscalar (k))
              error (strcat ("cvpartition: K value for 'kfold'", ...
                             " must be a numeric scalar."));
            endif
          else
            if (X < 10)
              k = X;
            else
              k = 10;
            endif
          endif
          ## No grouping variables
          if (nargin < 4)
            if (! (k == fix (k) && k > 0 && k <= X))
              error (strcat ("cvpartition: K value for 'kfold' must be", ...
                             " an integer scalar in the range [1, N]."));
            endif
            this.NumTestSets = k;
            indices = floor ((0:(X - 1))' * (k / X)) + 1;
            indices = randsample (indices, X);
            nvec = X * ones (1, k);
            for i = 1:k
              this.TestSize(i) = sum (indices == i);
            endfor
            this.indices = indices;
            this.TrainSize = nvec - this.TestSize;
            this.cvptype = 'K-fold cross validation partition';
            this.IsGrouped = false;
          else  # with grouping variables
            if (! strcmpi (varargin{3}, 'groupingvariables'))
              error (strcat ("cvpartition: invalid optional paired", ...
                             " argument for 'GroupingVariables'."));
            endif
            if (nargin < 5)
              error (strcat ("cvpartition: missing value for optional", ...
                             " paired argument 'GroupingVariables'."));
            endif
            grpvars = varargin{4};
            if (isvector (grpvars))
              ## Remove any missing values
              this.missidx = ismissing (grpvars);
              if (any (this.missidx))
                grpvars(this.missidx) = [];
                X -= sum (this.missidx);
              endif
              ## Get indices for each group
              if (isa (grpvars, 'categorical'))
                [~, idx, inds] = unique (grpvars, 'stable');
              else
                [~, idx, inds] = __unique__ (grpvars, 'stable');
              endif
            elseif (ismatrix (grpvars))
              ## Remove any missing values
              this.missidx = any (ismissing (grpvars), 2);
              if (any (this.missidx))
                grpvars(this.missidx, :) = [];
                X -= sum (this.missidx);
              endif
              ## Get indices for each group
              if (isa (grpvars, 'categorical'))
                [~, idx, inds] = unique (grpvars, 'rows', 'stable');
              else
                [~, idx, inds] = __unique__ (grpvars, 'rows', 'stable');
              endif
            else
              error (strcat ("cvpartition: invalid value for optional", ...
                             " paired argument 'GroupingVariables'."));
            endif
            if (X != numel (inds))
              error (strcat ("cvpartition: grouping variable does", ...
                             " not match the number of observations."));
            endif
            this.grpvars = grpvars;
            ## Get number of groups and group sizes
            NumGroups = numel (idx);
            for i = 1:NumGroups
              GroupSize(i) = sum (inds == i);
            endfor
            ## Compare k-fold to number of groups and reduce K accordingly
            if (k > NumGroups)
              warning (strcat ("cvpartition: number of folds K is greater", ...
                               " than the groups in 'GroupingVariables'.", ...
                               " K is set to the number of groups."));
                k = NumGroups;
            endif
            ## If k == NumGroups, then each group becomes a test in a fold.
            ## If k < NumGroups, then cluster NumGroups to k folds.
            indices = zeros (X, 1);
            if (k == NumGroups)
              for i = 1:k
                indices(inds == i) = i;
              endfor
            else
              [GroupIdx, ~, GroupSz] = multiway (GroupSize, k, 'completeKK');
              for i = 1:k
                idxGV = find (GroupIdx == i);
                vecGV = arrayfun(@(x) x == inds, idxGV, "UniformOutput", false);
                index = vecGV{1};
                if (numel (vecGV) > 1)
                  for j = 2:numel (vecGV)
                    index = index | vecGV{j};
                  endfor
                endif
                indices(index) = i;
              endfor
            endif
            ## Randomize the order of folds
            random_idx = randsample ([1:k], k);
            randomized = zeros (size (inds));
            for i = 1:k
              randomized(indices == i) = random_idx(i);
            endfor
            ## Save values to properties
            this.indices = randomized;
            this.NumTestSets = k;
            nvec = X * ones (1, k);
            for i = 1:k
              this.TestSize(i) = sum (this.indices == i);
            endfor
            this.TrainSize = nvec - this.TestSize;
            this.cvptype = 'Group K-fold cross validation partition';
            this.IsGrouped = true;
          endif

        ## Invalid paired argument
        else
          error ("cvpartition: invalid optional paired argument.");
        endif

      ## Check first input being a vector for stratification
      elseif (isvector (X))
        ## Get number of observations (including missing values)
        this.NumObservations = numel (X);

        ## Remove missing values from partitioning.
        ## Keep missing index to include them in the test indices.
        this.missidx = ismissing (X);
        X(this.missidx) = [];

        ## Get stratify option
        if (nargin < 4)
          this.IsStratified = true;
        else
          if (! strcmpi (varargin{3}, 'stratify'))
              error (strcat ("cvpartition: invalid optional paired", ...
                             " argument for stratification."));
          endif
          if (nargin < 5)
            error (strcat ("cvpartition: missing value for optional", ...
                           " paired argument 'stratify'."));
          endif
          if (! isscalar (varargin{4}) || ! islogical (varargin{4}))
            error (strcat ("cvpartition: invalid value for optional", ...
                           " paired argument 'stratify'."));
          endif
          this.IsStratified = varargin{4};
        endif

        ## Handle stratification
        if (this.IsStratified)
          [classID, idx, classes] = unique (X);
          NumClasses = numel (idx);
          for i = 1:NumClasses
            ClassSize(i) = sum (classes == i);
          endfor
          this.classes = classes;
          this.classID = classID;
        endif
        X = numel (X);

        ## Get partition type
        type = varargin{1};
        this.IsCustom = false;
        this.IsGrouped = false;

        ## "Holdout"
        if (strcmpi (type, 'holdout'))
          this.Type = 'holdout';
          if (nargin > 2)
            p = varargin{2};
            if (! isnumeric (p) || ! isscalar (p))
              error (strcat ("cvpartition: P value for 'holdout'", ...
                             " must be a numeric scalar."));
            endif
            if (! ((p > 0 && p < 1) || (p == fix (p) && p > 0 && p < X)))
              error (strcat ("cvpartition: P value for 'holdout' must be", ...
                             " a scalar in the range (0,1) or an integer", ...
                             " scalar in the range [1, N), where N is the", ...
                             " number of nonmissing observations in X."));
            endif
          else
            p = 0.1;
          endif
          this.NumTestSets = 1;
          if (this.IsStratified)
            if (p < 1)
              f = p;              # target fraction to sample
              p = round (p * X);  # number of test samples
            else
              f = p / X;
            endif
            inds = zeros (X, 1, "logical");
            k_check = 0;
            for i = 1:NumClasses
              ki = round (f * ClassSize(i));
              inds(find (classes == i)(randsample (ClassSize(i), ki))) = true;
              k_check += ki;
            endfor
            if (k_check < p)      # add random elements to test set to make it p
              inds(find (! inds)(randsample (X - k_check, p - k_check))) = true;
            elseif (k_check > p)  # remove random elements from test set
              inds(find (inds)(randsample (k_check, k_check - p))) = false;
            endif
            this.cvptype = 'Stratified hold-out cross validation partition';
          else
            if (p < 1)            # target fraction to sample
              p = round (p * X);  # number of samples
            endif
            inds = false (X, 1);
            inds(randsample (X, p)) = true;  # indices for test set
            this.cvptype = 'Hold-out cross validation partition';
          endif
          this.indices = inds;
          this.TrainSize = sum (! inds);
          this.TestSize = sum (inds);

        ## "KFold"
        elseif (strcmpi (type, 'kfold'))
          this.Type = 'kfold';
          if (nargin > 2)
            k = varargin{2};
            if (! isnumeric (k) || ! isscalar (k))
              error (strcat ("cvpartition: K value for 'kfold'", ...
                             " must be a numeric scalar."));
            endif
            if (! (k == fix (k) && k > 0 && k <= X))
              error (strcat ("cvpartition: K value for 'kfold' must be", ...
                             " an integer scalar in the range [1, N],", ...
                             " where N is the number of nonmissing", ...
                             " observations in X."));
            endif
          else
            if (X < 10)
              k = X;
            else
              k = 10;
            endif
          endif
          this.NumTestSets = k;
          if (this.IsStratified)
            inds = nan (X, 1);
            pooled_idx = false (X, 1);
            do_warn = true;
            do_ceil = false;
            for i = 1:NumClasses
              cls_size = ClassSize(i);
              cls_k_eq = fix (cls_size / k) == (cls_size / k);
              ## Check that the elements in each class exceed the number of
              ## requested folds, otherwise emit a warning and add the class
              ## elements into a pooled class
              if (cls_size < k)
                if (do_warn)
                  warning (strcat ("One or more of the unique class values", ...
                                   " in the stratification variable is not", ...
                                   " present in one or more folds."));
                  do_warn = false;
                endif
                pooled_idx = pooled_idx | classes == i;
              elseif (fix (X / k) == X / k)
                ## Make sure that when X / k = integer, all
                ## test/training sizes must be equal across all folds
                if (do_ceil && ! cls_k_eq)
                  idx = ceil ((0:(cls_size - 1))' * (k / cls_size));
                  idx(idx == 0) = max (idx);
                  do_ceil = false;
                else
                  idx = floor ((0:(cls_size - 1))' * (k / cls_size)) + 1;
                  tmp = arrayfun (@(x) numel (find (x == idx)), [1:k]);
                  if (any (diff (tmp)))
                    do_ceil = true;
                  endif
                endif
                inds(classes == i) = randsample (idx, cls_size);
              else
                ## Alternate ordering over classes so that
                ## the subsets are more nearly the same size
                if (! do_ceil || cls_k_eq)
                  idx = floor ((0:(cls_size - 1))' * (k / cls_size)) + 1;
                  if (! cls_k_eq)
                    do_ceil = true;
                  endif
                else
                  idx = floor (((cls_size - 1):-1:0)' * (k / cls_size)) + 1;
                  do_ceil = false;
                endif
                inds(classes == i) = randsample (idx, cls_size);
              endif
            endfor
            ## Stratify pooled classes (if any).  They must be distributed
            ## in a way to make the test/training sizes as equal as possible
            ## across folds.
            pooled_inds = find (pooled_idx);
            while (numel (pooled_inds) > 0)
              tmp = arrayfun (@(x) numel (find (x == inds)), [1:k]);
              [min_cls, min_idx] = min (tmp);
              [max_cls, max_idx] = max (tmp);
              if (min_cls != max_cls)
                inds(pooled_inds(1)) = min_idx;
              else
                inds(pooled_inds(1)) = randsample (k, 1);
              endif
              pooled_inds(1) = [];
            endwhile
            this.cvptype = 'Stratified K-fold cross validation partition';
          else
            inds = floor ((0:(X - 1))' * (k / X)) + 1;
            inds = randsample (inds, X);
            this.cvptype = 'K-fold cross validation partition';
          endif
          this.indices = inds;
          nvec = X * ones (1, k);
          for i = 1:k
            this.TestSize(i) = sum (inds == i);
          endfor
          this.TrainSize = nvec - this.TestSize;

        ## Invalid paired argument
        else
          error ("cvpartition: invalid optional paired argument.");
        endif

      ## Otherwise first input is invalid
      else
        error ("cvpartition: invalid first input argument.");
      endif

    endfunction

    ## -*- texinfo -*-
    ## @deftypefn  {cvpartition} {@var{Cnew} =} repartition (@var{C})
    ## @deftypefnx {cvpartition} {@var{Cnew} =} repartition (@var{C}, @var{sval})
    ## @deftypefnx {cvpartition} {@var{Cnew} =} repartition (@var{C}, @qcode{'legacy'})
    ##
    ## Repartition data for cross-validation.
    ##
    ## @code{@var{Cnew} = repartition (@var{C})} creates a @qcode{cvpartition}
    ## object @var{Cnew} that defines a new random partition of the same type as
    ## the @qcode{cvpartition} @var{C}.
    ##
    ## @code{@var{Cnew} = repartition (@var{C}, @var{sval})} also uses the value
    ## of @var{sval} to set the state of the random generator used in
    ## repartitioning @var{C}.  If @var{sval} is a vector, then the random
    ## generator is set using the @qcode{"state"} keyword as in
    ## @code{rand ("state", @var{sval})}.  If @var{sval} is a scalar, then the
    ## @qcode{"seed"} keyword is used as in @code{rand ("seed", @var{sval})} to
    ## specify that old generators should be used.
    ##
    ## @code{@var{Cnew} = repartition (@var{C}, @qcode{'legacy'})} only applies
    ## to @qcode{cvpartition} objects @var{C} that use k-fold partitioning and
    ## it will repartition @var{C} in the same non-random manner that was
    ## previously used by the old-style @qcode{cvpartition} class of the
    ## statistics package.  The @qcode{'legacy'} option does not apply to
    ## stratified or grouped partitions.
    ##
    ## @seealso{cvpartition, summary, test, training}
    ## @end deftypefn

    function this = repartition (this, sval = [])

      ## Regenerate the partition with the same type and parameters.
      ## SVAL (optional) is either the string 'legacy' (non-random kfold
      ## assignment) or a real numeric scalar/vector used to reset the random
      ## generator before repartitioning: a scalar uses rand ("seed", SVAL),
      ## a vector uses rand ("state", SVAL).

      ## Custom partitions carry user-supplied indices and cannot be redrawn
      if (this.IsCustom)
        error ("cvpartition.repartition: cannot repartition a custom partition.");
      endif

      ## 'legacy': deterministic, contiguous fold assignment for plain kfold
      if (strcmpi (sval, "legacy"))
        if (! strcmpi (this.Type, "kfold"))
          error (strcat ("cvpartition.repartition: 'legacy' flag is only", ...
                         " valid for 'kfold' partitioned objects."));
        endif
        if (this.IsGrouped || this.IsStratified)
          error (strcat ("cvpartition.repartition: 'legacy' flag does", ...
                         " not apply to stratified or grouped 'kfold'", ...
                         " partitioned objects."));
        endif
        X = this.NumObservations;
        k = this.NumTestSets;
        ## Observations 1..X are assigned to folds 1..k in contiguous blocks
        inds = floor ((0:(X - 1))' * (k / X)) + 1;
        this.indices = inds;
        nvec = X * ones (1, k);
        for i = 1:k
          this.TestSize(i) = sum (inds == i);
        endfor
        this.TrainSize = nvec - this.TestSize;
        return;
      endif

      ## Optionally reset the random generator.  A scalar selects the old
      ## generators through the "seed" keyword; a vector sets the "state".
      if (! isempty (sval))
        if (! (isvector (sval) && isnumeric (sval) && isreal (sval)))
          error (strcat ("cvpartition.repartition: SVAL must be", ...
                         " a real scalar or vector."));
        endif
        if (isscalar (sval))
          rand ("seed", sval);
        else
          rand ("state", sval);
        endif
      endif

      ## Redraw randomized holdout and kfold partitions.  Partitions of any
      ## other type fall through both branches and are returned unchanged.
      if (strcmpi (this.Type, "holdout"))
        p = this.TestSize;  # keep the same test-set size
        if (this.IsStratified)
          ## Stratified indices cover only the non-missing observations
          X = sum (! this.missidx);
          inds = false (X, 1);
          NumClasses = numel (this.classID);
          classes = this.classes;
          ClassSize = zeros (1, NumClasses);
          for i = 1:NumClasses
            ClassSize(i) = sum (classes == i);
          endfor
          ## Sample each class in proportion to its share of the data
          f = p / X;
          k_check = 0;
          for i = 1:NumClasses
            ki = round (f * ClassSize(i));
            inds(find (classes == i)(randsample (ClassSize(i), ki))) = true;
            k_check += ki;
          endfor
          ## Per-class rounding may miss the target size p: add or remove
          ## randomly chosen elements until the test set has exactly p
          if (k_check < p)
            inds(find (! inds)(randsample (X - k_check, p - k_check))) = true;
          elseif (k_check > p)
            inds(find (inds)(randsample (k_check, k_check - p))) = false;
          endif
        else
          X = this.NumObservations;
          inds = false (X, 1);
          inds(randsample (X, p)) = true;  # indices for test set
        endif
        this.indices = inds;

      elseif (strcmpi (this.Type, "kfold"))
        k = this.NumTestSets;
        if (! (this.IsGrouped || this.IsStratified))
          ## Balanced fold labels, randomly permuted over the observations
          X = this.NumObservations;
          inds = floor ((0:(X - 1))' * (k / X)) + 1;
          inds = randsample (inds, X);
          this.indices = inds;
          nvec = X * ones (1, k);
          for i = 1:k
            this.TestSize(i) = sum (inds == i);
          endfor
          this.TrainSize = nvec - this.TestSize;

        elseif (this.IsGrouped)
          ## Grouped partitions keep their fold membership; only the fold
          ## labels themselves are randomly permuted
          random_idx = randsample ([1:k], k);
          randomized = zeros (size (this.indices));
          for i = 1:k
            randomized(this.indices == i) = random_idx(i);
          endfor
          this.indices = randomized;
          nvec = sum (! this.missidx) * ones (1, k);
          for i = 1:k
            this.TestSize(i) = sum (this.indices == i);
          endfor
          this.TrainSize = nvec - this.TestSize;

        else  # is stratified
          X = sum (! this.missidx);
          NumClasses = numel (this.classID);
          classes = this.classes;
          ClassSize = zeros (1, NumClasses);
          for i = 1:NumClasses
            ClassSize(i) = sum (classes == i);
          endfor
          inds = nan (X, 1);
          pooled_idx = false (X, 1);
          do_warn = true;
          do_ceil = false;
          for i = 1:NumClasses
            cls_size = ClassSize(i);
            cls_k_eq = fix (cls_size / k) == (cls_size / k);
            ## Classes with fewer elements than folds cannot appear in every
            ## fold: warn once and move their elements into a pooled class
            if (cls_size < k)
              if (do_warn)
                warning (strcat ("One or more of the unique class values", ...
                                 " in the stratification variable is not", ...
                                 " present in one or more folds."));
                do_warn = false;
              endif
              pooled_idx = pooled_idx | classes == i;
            elseif (fix (X / k) == X / k)
              ## When X / k is an integer, all test/training sizes must be
              ## equal across folds, so alternate floor/ceil fold assignment
              if (do_ceil && ! cls_k_eq)
                idx = ceil ((0:(cls_size - 1))' * (k / cls_size));
                idx(idx == 0) = max (idx);
                do_ceil = false;
              else
                idx = floor ((0:(cls_size - 1))' * (k / cls_size)) + 1;
                tmp = arrayfun (@(x) numel (find (x == idx)), [1:k]);
                if (any (diff (tmp)))
                  do_ceil = true;
                endif
              endif
              inds(classes == i) = randsample (idx, cls_size);
            else
              ## Alternate ordering over classes so that the subsets are
              ## more nearly the same size
              if (! do_ceil || cls_k_eq)
                idx = floor ((0:(cls_size - 1))' * (k / cls_size)) + 1;
                if (! cls_k_eq)
                  do_ceil = true;
                endif
              else
                idx = floor (((cls_size - 1):-1:0)' * (k / cls_size)) + 1;
                do_ceil = false;
              endif
              inds(classes == i) = randsample (idx, cls_size);
            endif
          endfor
          ## Distribute pooled elements one at a time into the currently
          ## smallest fold (or a random fold when all folds are equal)
          pooled_inds = find (pooled_idx);
          while (numel (pooled_inds) > 0)
            tmp = arrayfun (@(x) numel (find (x == inds)), [1:k]);
            [min_cls, min_idx] = min (tmp);
            [max_cls, max_idx] = max (tmp);
            if (min_cls != max_cls)
              inds(pooled_inds(1)) = min_idx;
            else
              inds(pooled_inds(1)) = randsample (k, 1);
            endif
            pooled_inds(1) = [];
          endwhile
          this.indices = inds;
          nvec = X * ones (1, k);
          for i = 1:k
            this.TestSize(i) = sum (inds == i);
          endfor
          this.TrainSize = nvec - this.TestSize;
        endif
      endif

    endfunction

    ## -*- texinfo -*-
    ## @deftypefn {cvpartition} {@var{tbl} =} summary (@var{c})
    ##
    ## Summarize stratified or grouped cross-validation partitions.
    ##
    ## @code{@var{tbl} = summary (@var{c})} returns a summary table @var{tbl} of
    ## the validation partition contained in the @code{cvpartition} object
    ## @var{c}.
    ##
    ## This method calculates the distribution of classes (if stratified) or
    ## groups (if grouped) across the entire dataset, as well as within every
    ## training and test set generated by the partition.
    ##
    ## @subheading Inputs
    ## @itemize
    ## @item @var{c}
    ## A @code{cvpartition} object.  The object must satisfy two conditions:
    ## @enumerate
    ## @item The partition type (@code{c.Type}) must be @qcode{"kfold"} or
    ## @qcode{"holdout"}.
    ## @item The partition must be created with a stratification or grouping
    ## variable (i.e., @code{c.IsStratified} or @code{c.IsGrouped} must be
    ## @code{true}).
    ## @end enumerate
    ## @end itemize
    ##
    ## @subheading Outputs
    ## @itemize
    ## @item @var{tbl}
    ## A @code{table} object containing the summary statistics.  The table
    ## contains one row for every unique label/group in every set (all, train,
    ## test).  The columns are:
    ## @table @code
    ## @item Set
    ## The specific subset being described.  Values include @qcode{"all"} (the
    ## full dataset), @qcode{"train1"}, @qcode{"test1"}, etc.
    ## @item SetSize
    ## The total number of observations in that specific set.
    ## @item Label
    ## The class or group identifier.  If @code{c.IsStratified} is true, this
    ## column is named @code{StratificationLabel}.  If @code{c.IsGrouped} is
    ## true, it is named @code{GroupLabel}.
    ## @item Count
    ## The number of observations of that label within the set.  If stratified,
    ## this column is named @code{StratificationCount}; otherwise,
    ## @code{GroupCount}.
    ## @item PercentInSet
    ## The percentage of the set composed of that specific label.
    ## @end table
    ## @end itemize
    ##
    ## @seealso{cvpartition, repartition, test, training}
    ## @end deftypefn

   function tbl = summary (this)

      ## Build a table with one row per (set, label) pair.  The sets are
      ## "all", followed by "trainK"/"testK" for each test set K.

      ## Only stratified or grouped kfold/holdout partitions are supported
      if (! (this.IsStratified || this.IsGrouped))
        error ("cvpartition.summary: partition must be stratified or grouped.");
      endif
      is_holdout = strcmpi (this.Type, 'holdout');
      if (! (is_holdout || strcmpi (this.Type, 'kfold')))
        error ("cvpartition.summary: partition type must be 'kfold' or 'holdout'.");
      endif

      ## Column names, unique labels and the label index of each observation
      if (this.IsStratified)
        label_var = 'StratificationLabel';
        count_var = 'StratificationCount';
        labels = this.classID;
        label_idx = this.classes;
      else
        label_var = 'GroupLabel';
        count_var = 'GroupCount';
        ## 'stable' keeps the labels in order of first appearance
        if (isa (this.grpvars, 'categorical'))
          [labels, ~, label_idx] = unique (this.grpvars, 'rows', 'stable');
        else
          [labels, ~, label_idx] = __unique__ (this.grpvars, 'rows', 'stable');
        endif
      endif

      ## List every set by name together with its membership mask
      n_folds = this.NumTestSets;
      n_sets = 1 + 2 * n_folds;
      set_names = cell (1, n_sets);
      set_masks = cell (1, n_sets);
      set_names{1} = 'all';
      set_masks{1} = true (size (label_idx));
      for f = 1:n_folds
        if (is_holdout)
          in_test = this.indices;
        else
          in_test = (this.indices == f);
        endif
        set_names{2*f} = sprintf ('train%d', f);
        set_masks{2*f} = ! in_test;
        set_names{2*f+1} = sprintf ('test%d', f);
        set_masks{2*f+1} = in_test;
      endfor

      ## Preallocate the output columns; text labels are stored in a cell
      n_labels = size (labels, 1);
      n_rows = n_labels * n_sets;
      text_labels = iscell (labels) || isstring (labels) || ischar (labels);
      set_col = cell (n_rows, 1);
      size_col = zeros (n_rows, 1);
      count_col = zeros (n_rows, 1);
      pct_col = zeros (n_rows, 1);
      if (text_labels)
        label_col = cell (n_rows, 1);
      else
        label_col = zeros (n_rows, 1);
      endif

      ## Fill one row per label for each set, in set order
      row = 0;
      for s = 1:n_sets
        members = label_idx(set_masks{s});
        set_size = numel (members);
        for u = 1:n_labels
          row += 1;
          n_u = sum (members == u);
          set_col{row} = set_names{s};
          size_col(row) = set_size;
          if (! text_labels)
            label_col(row) = labels(u);
          elseif (iscell (labels))
            label_col{row} = labels{u};
          elseif (isstring (labels))
            ## string objects are stored as char inside the cell column
            label_col{row} = char (labels(u));
          else
            label_col{row} = labels(u, :);
          endif
          count_col(row) = n_u;
          pct_col(row) = (n_u / set_size) * 100;
        endfor
      endfor

      ## Use string columns when a string class is available
      if (exist ('string', 'class'))
        set_col = string (set_col);
        if (text_labels)
          label_col = string (label_col);
        endif
      endif

      tbl = table (set_col, size_col, label_col, count_col, pct_col, ...
                   'VariableNames', {'Set', 'SetSize', label_var, ...
                                     count_var, 'PercentInSet'});

    endfunction

    ## -*- texinfo -*-
    ## @deftypefn  {cvpartition} {@var{idx} =} test (@var{C})
    ## @deftypefnx {cvpartition} {@var{idx} =} test (@var{C}, @var{i})
    ## @deftypefnx {cvpartition} {@var{idx} =} test (@var{C}, @qcode{"all"})
    ##
    ## Test indices for cross-validation.
    ##
    ## @code{@var{idx} = test (@var{C})} returns a logical vector @var{idx} with
    ## @qcode{true} values indicating the elements corresponding to the test
    ## set defined in the @qcode{cvpartition} object @var{C}.  For K-fold and
    ## leave-one-out partitions, the indices corresponding to the first test set
    ## are returned.
    ##
    ## @code{@var{idx} = test (@var{C}, @var{i})} returns a logical vector or
    ## matrix with the indices of the test set indicated by @var{i}.  If @var{i}
    ## is a scalar, then @var{idx} is a logical vector with the indices of the
    ## @math{i-th} set.  If @var{i} is a vector, then @var{idx} is a logical
    ## matrix in which @code{@var{idx}(:,j)} specifies the observations in the
    ## test set @code{@var{i}(j)}.  The value(s) in @var{i} must not exceed the
    ## number of test sets in the @qcode{cvpartition} object @var{C}.
    ##
    ## @code{@var{idx} = test (@var{C}, @qcode{"all"})} returns a logical vector
    ## or matrix for all test sets defined in the @qcode{cvpartition} object
    ## @var{C}.  For holdout and resubstitution partition types, a vector is
    ## returned.  For K-fold and leave-one-out, a matrix is returned.
    ##
    ## @seealso{cvpartition, repartition, summary, training}
    ## @end deftypefn

    function idx = test (this, varargin)

      ## Return logical test-set masks over all NumObservations observations.
      ## No selector or [] selects the first set; a numeric vector selects
      ## one column per requested set; "all" selects every set.

      if (nargin > 2)
        error ("cvpartition.test: too many input arguments.");
      elseif (nargin == 2)
        i = varargin{1};
        if (strcmpi (i, "all"))
          idx = logical ([]);
          switch (this.Type)
            case "kfold"
              for j = 1:this.NumTestSets
                if (this.IsStratified || this.IsGrouped)
                  ## indices cover only the non-missing observations;
                  ## missing ones never belong to a test set
                  cid = false (this.NumObservations, 1);
                  cid(! this.missidx) = this.indices == j;
                else
                  cid = this.indices == j;
                endif
                idx = [idx, cid];
              endfor
            case "leaveout" # no stratification
              for j = 1:this.NumTestSets
                cid = false (this.NumObservations, 1);
                cid(j) = true;
                idx = [idx, cid];
              endfor
            case "holdout"
              ## Stratified holdout indices exclude missing observations, so
              ## expand them to full length (the same mask as test (C, 1))
              if (this.IsStratified)
                idx = false (this.NumObservations, 1);
                idx(! this.missidx) = this.indices;
              else
                idx = this.indices;
              endif
            case "resubstitution" # no stratification
              idx = true (this.NumObservations, 1);
          endswitch
          return;
        elseif (isempty (i))
          i = 1;
        endif
      else
        i = 1;
      endif

      if (! (isvector (i) && isnumeric (i) &&
             all (fix (i) == i) && all (i > 0)))
        error ("cvpartition.test: set index must be a positive integer vector.");
      elseif (any (i > this.NumTestSets))
        error ("cvpartition.test: set index exceeds 'NumTestSets'.");
      endif

      switch (this.Type)
        case  "kfold"
          if (isscalar (i))
            if (this.IsStratified || this.IsGrouped)
              idx = false (this.NumObservations, 1);
              idx(! this.missidx) = this.indices == i;
            else
              idx = this.indices == i;
            endif
          else
            ## One column per requested set; i(:)' iterates element-wise
            ## whether i is given as a row or a column vector
            idx = logical ([]);
            for j = i(:)'
              if (this.IsStratified || this.IsGrouped)
                cid = false (this.NumObservations, 1);
                cid(! this.missidx) = this.indices == j;
              else
                cid = this.indices == j;
              endif
              idx = [idx, cid];
            endfor
          endif
        case "leaveout" # no stratification
          if (isscalar (i))
            idx = false (this.NumObservations, 1);
            idx(i) = true;
          else
            idx = logical ([]);
            for j = i(:)'
              cid = false (this.NumObservations, 1);
              cid(j) = true;
              idx = [idx, cid];
            endfor
          endif
        case "holdout"
          if (this.IsStratified)
            idx = false (this.NumObservations, 1);
            idx(! this.missidx) = this.indices;
          else
            idx = this.indices;
          endif
        case "resubstitution" # no stratification
          idx = true (this.NumObservations, 1);
      endswitch

    endfunction

    ## -*- texinfo -*-
    ## @deftypefn  {cvpartition} {@var{idx} =} training (@var{C})
    ## @deftypefnx {cvpartition} {@var{idx} =} training (@var{C}, @var{i})
    ## @deftypefnx {cvpartition} {@var{idx} =} training (@var{C}, @qcode{"all"})
    ##
    ## Training indices for cross-validation.
    ##
    ## @code{@var{idx} = training (@var{C})} returns a logical vector @var{idx}
    ## with @qcode{true} values indicating the elements corresponding to the
    ## training set defined in the @qcode{cvpartition} object @var{C}.  For
    ## K-fold and leave-one-out partitions, the indices corresponding to the
    ## first training set are returned.
    ##
    ## @code{@var{idx} = training (@var{C}, @var{i})} returns a logical vector
    ## or matrix with the indices of the training set indicated by @var{i}.  If
    ## @var{i} is a scalar, then @var{idx} is a logical vector with the indices
    ## of the @math{i-th} set.  If @var{i} is a vector, then @var{idx} is a
    ## logical matrix in which @code{@var{idx}(:,j)} specifies the observations
    ## in the training set @code{@var{i}(j)}.  The value(s) in @var{i} must not
    ## exceed the number of test sets in the @qcode{cvpartition} object @var{C}.
    ##
    ## @code{@var{idx} = training (@var{C}, @qcode{"all"})} returns a logical
    ## vector or matrix for all training sets defined in the @qcode{cvpartition}
    ## object @var{C}.  For holdout and resubstitution partition types, a vector
    ## is returned.  For K-fold and leave-one-out, a matrix is returned.
    ##
    ## @seealso{cvpartition, repartition, summary, test}
    ## @end deftypefn

    function idx = training (this, varargin)

      ## Return logical training-set masks over all NumObservations
      ## observations.  No selector or [] selects the first set; a numeric
      ## vector selects one column per requested set; "all" selects every set.

      if (nargin > 2)
        error ("cvpartition.training: too many input arguments.");
      elseif (nargin == 2)
        i = varargin{1};
        if (strcmpi (i, "all"))
          idx = logical ([]);
          switch (this.Type)
            case "kfold"
              for j = 1:this.NumTestSets
                if (this.IsStratified || this.IsGrouped)
                  ## indices cover only the non-missing observations;
                  ## missing ones are excluded from every training set
                  cid = false (this.NumObservations, 1);
                  cid(! this.missidx) = this.indices != j;
                else
                  cid = this.indices != j;
                endif
                idx = [idx, cid];
              endfor
            case "leaveout" # no stratification
              for j = 1:this.NumTestSets
                cid = true (this.NumObservations, 1);
                cid(j) = false;
                idx = [idx, cid];
              endfor
            case "holdout"
              if (this.IsStratified)
                idx = false (this.NumObservations, 1);
                idx(! this.missidx) = ! this.indices;
              else
                idx = ! this.indices;
              endif
            case "resubstitution" # no stratification
              idx = true (this.NumObservations, 1);
          endswitch
          return;
        elseif (isempty (i))
          i = 1;
        endif
      else
        i = 1;
      endif

      if (! (isvector (i) && isnumeric (i) &&
             all (fix (i) == i) && all (i > 0)))
        error (strcat ("cvpartition.training: set index must", ...
                       " be a positive integer vector."));
      elseif (any (i > this.NumTestSets))
        error ("cvpartition.training: set index exceeds 'NumTestSets'.");
      endif

      switch (this.Type)
        case  "kfold"
          if (isscalar (i))
            if (this.IsStratified || this.IsGrouped)
              idx = false (this.NumObservations, 1);
              idx(! this.missidx) = this.indices != i;
            else
              idx = this.indices != i;
            endif
          else
            ## One column per requested set; i(:)' iterates element-wise
            ## whether i is given as a row or a column vector
            idx = logical ([]);
            for j = i(:)'
              if (this.IsStratified || this.IsGrouped)
                cid = false (this.NumObservations, 1);
                cid(! this.missidx) = this.indices != j;
              else
                cid = this.indices != j;
              endif
              idx = [idx, cid];
            endfor
          endif
        case "leaveout" # no stratification
          if (isscalar (i))
            idx = true (this.NumObservations, 1);
            idx(i) = false;
          else
            idx = logical ([]);
            for j = i(:)'
              cid = true (this.NumObservations, 1);
              cid(j) = false;
              idx = [idx, cid];
            endfor
          endif
        case "holdout"
          if (this.IsStratified)
            idx = false (this.NumObservations, 1);
            idx(! this.missidx) = ! this.indices;
          else
            idx = ! this.indices;
          endif
        case "resubstitution" # no stratification
          idx = true (this.NumObservations, 1);
      endswitch

    endfunction

  endmethods

endclassdef

## Test output results for custom partition
%!test
%! custom = [1, 1, 1, 2, 2, 2, 1, 2, 3, 2, 3, 3, 2, 1, 3]';
%! cv = cvpartition ('CustomPartition', custom);
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 15);
%! assert (cv.NumTestSets, 3);
%! assert (cv.TrainSize, [10, 9, 11]);
%! assert (cv.TestSize, [5, 6, 4]);
%! assert (cv.IsCustom, true);
%! assert (cv.IsGrouped, false);
%! assert (cv.IsStratified, false);
%! idx = training (cv, 1);
%! assert (idx, custom != 1);
%! idx = test (cv, 1);
%! assert (idx, custom == 1);
%! idx = training (cv, 2);
%! assert (idx, custom != 2);
%! idx = test (cv, 2);
%! assert (idx, custom == 2);
%! idx = training (cv, 3);
%! assert (idx, custom != 3);
%! idx = test (cv, 3);
%! assert (idx, custom == 3);
%! idx1 = training (cv, 'all');
%! idx2 = test (cv, 'all');
%! assert (idx1, ! idx2);
%!test
%! custom = logical ([1, 1, 1, 0, 0, 0, 1, 0, 1, 1])';
%! cv = cvpartition ('CustomPartition', custom);
%! assert (cv.Type, 'holdout');
%! assert (cv.NumObservations, 10);
%! assert (cv.NumTestSets, 1);
%! assert (cv.TrainSize, 4);
%! assert (cv.TestSize, 6);
%! assert (cv.IsCustom, true);
%! assert (cv.IsGrouped, false);
%! assert (cv.IsStratified, false);
%! idx = training (cv, 1);
%! assert (idx, custom != 1);
%! assert (idx, training (cv, 'all'));
%! idx = test (cv, 1);
%! assert (idx, custom == 1);
%! assert (idx, test (cv, 'all'));
%!test
%! custom = logical ([1, 0, 0; 0, 1, 0; 1, 0, 0; 0, 0, 1]);
%! cv = cvpartition ('CustomPartition', custom);
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 4);
%! assert (cv.NumTestSets, 3);
%! assert (cv.TrainSize, [2, 3, 3]);
%! assert (cv.TestSize, [2, 1, 1]);
%! assert (cv.IsCustom, true);
%! assert (cv.IsGrouped, false);
%! assert (cv.IsStratified, false);
%! idx = training (cv, 1);
%! assert (idx, custom(:,1) == false);
%! idx = test (cv, 1);
%! assert (idx, custom(:,1) == true);
%! idx = training (cv, 2);
%! assert (idx, custom(:,2) == false);
%! idx = test (cv, 2);
%! assert (idx, custom(:,2) == true);
%! assert (! custom, training (cv, 'all'));
%! assert (custom, test (cv, 'all'));
%!test
%! cv = cvpartition ('CustomPartition', [1:8]);
%! assert (cv.Type, 'leaveout');
%! assert (cv.NumObservations, 8);
%! assert (cv.NumTestSets, 8);
%! assert (cv.TrainSize, [7, 7, 7, 7, 7, 7, 7, 7]);
%! assert (cv.TestSize, [1, 1, 1, 1, 1, 1, 1, 1]);
%! assert (cv.IsCustom, true);
%! assert (cv.IsGrouped, false);
%! assert (cv.IsStratified, false);
%! assert (class (training (cv, 1)), 'logical');
%! assert (sum (training (cv, 1)), 7);
%! assert (sum (training (cv, 'all')), cv.TrainSize);
%! assert (class (test (cv, 1)), 'logical');
%! assert (sum (test (cv, 1)), 1);
%! assert (sum (test (cv, 'all')), cv.TestSize);
%! assert (! training (cv, 'all'), test (cv, 'all'));
%!test
%! cv = cvpartition ('CustomPartition', logical (eye (8)));
%! assert (cv.Type, 'leaveout');
%! assert (cv.NumObservations, 8);
%! assert (cv.NumTestSets, 8);
%! assert (cv.TrainSize, [7, 7, 7, 7, 7, 7, 7, 7]);
%! assert (cv.TestSize, [1, 1, 1, 1, 1, 1, 1, 1]);
%! assert (cv.IsCustom, true);
%! assert (cv.IsGrouped, false);
%! assert (cv.IsStratified, false);
%! assert (class (training (cv, 1)), 'logical');
%! assert (sum (training (cv, 1)), 7);
%! assert (sum (training (cv, 'all')), cv.TrainSize);
%! assert (class (test (cv, 1)), 'logical');
%! assert (sum (test (cv, 1)), 1);
%! assert (sum (test (cv, 'all')), cv.TestSize);
%! assert (! training (cv, 'all'), test (cv, 'all'));

## Test output results for scalar input N
%!test
%! ## Resubstitution: every observation is both a training and a test point
%! N = 10;
%! cv = cvpartition (N, 'resubstitution');
%! assert (cv.Type, 'resubstitution');
%! assert (cv.NumObservations, N);
%! assert (cv.NumTestSets, 1);
%! assert (cv.TrainSize, N);
%! assert (cv.TestSize, N);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], false (1, 3));
%! everyone = true (N, 1);
%! assert (class (training (cv, 1)), 'logical');
%! assert (class (test (cv, 1)), 'logical');
%! assert (sum (training (cv, 1)), N);
%! assert (sum (test (cv, 1)), N);
%! assert (training (cv, 'all'), everyone);
%! assert (test (cv, 'all'), everyone);
%! assert (test (cv), training (cv));
%!test
%! ## Leave-one-out: fold k tests observation k and trains on all the others
%! N = 10;
%! cv = cvpartition (N, 'leaveout');
%! assert (cv.Type, 'leaveout');
%! assert (cv.NumObservations, N);
%! assert (cv.NumTestSets, N);
%! assert (cv.TrainSize, (N - 1) * ones (1, N));
%! assert (cv.TestSize, ones (1, N));
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], false (1, 3));
%! diag_mask = logical (eye (N));
%! assert (class (training (cv, 1)), 'logical');
%! assert (class (test (cv, 1)), 'logical');
%! assert (sum (training (cv, 1)), N - 1);
%! assert (sum (test (cv, 1)), 1);
%! assert (training (cv, 'all'), ! diag_mask);
%! assert (test (cv, 'all'), diag_mask);
%! assert (test (cv), ! training (cv));
%! assert (test (cv, 'all'), ! training (cv, 'all'));
%!test
%! ## Fractional holdout: 30% of 10 observations are held out for testing
%! rand ('seed', 5);  # for reproducibility
%! cv = cvpartition (10, 'holdout', 0.3);
%! assert (cv.Type, 'holdout');
%! assert (cv.NumObservations, 10);
%! assert (cv.NumTestSets, 1);
%! assert (cv.TrainSize, 7);
%! assert (cv.TestSize, 3);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], false (1, 3));
%! held_out = false (10, 1);
%! held_out([2, 5, 9]) = true;   # test indices expected under seed 5
%! assert (class (training (cv, 1)), 'logical');
%! assert (class (test (cv, 1)), 'logical');
%! assert (sum (training (cv, 1)), 7);
%! assert (sum (test (cv, 1)), 3);
%! assert (training (cv, 'all'), ! held_out);
%! assert (test (cv, 'all'), held_out);
%! assert (test (cv), ! training (cv));
%! assert (test (cv, 'all'), ! training (cv, 'all'));
%!test
%! ## Integer holdout: exactly 4 of 10 observations are held out
%! cv = cvpartition (10, 'holdout', 4);
%! assert (cv.Type, 'holdout');
%! assert (cv.NumObservations, 10);
%! assert (cv.NumTestSets, 1);
%! assert (cv.TrainSize, 6);
%! assert (cv.TestSize, 4);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], false (1, 3));
%! trn = training (cv, 1);
%! tst = test (cv, 1);
%! assert (class (trn), 'logical');
%! assert (class (tst), 'logical');
%! assert (sum (trn), 6);
%! assert (sum (tst), 4);
%! assert (test (cv), ! training (cv));
%! assert (test (cv, 'all'), ! training (cv, 'all'));
%!test
%! ## Integer holdout that leaves a single training observation
%! cv = cvpartition (5, 'holdout', 4);
%! assert (cv.Type, 'holdout');
%! assert (cv.NumObservations, 5);
%! assert (cv.NumTestSets, 1);
%! assert (cv.TrainSize, 5 - 4);
%! assert (cv.TestSize, 4);
%! assert (sum (test (cv, 1)), 4);
%!test
%! ## Integer holdout with a single test observation
%! cv = cvpartition (5, 'holdout', 1);
%! assert (cv.Type, 'holdout');
%! assert (cv.NumObservations, 5);
%! assert (cv.NumTestSets, 1);
%! assert (cv.TrainSize, 5 - 1);
%! assert (cv.TestSize, 1);
%! assert (sum (test (cv, 1)), 1);
%!test
%! ## With K omitted and only 5 observations, there are 5 folds
%! part = cvpartition (5, 'kfold');
%! assert (part.Type, 'kfold');
%! assert ([part.NumObservations, part.NumTestSets], [5, 5]);
%!test
%! ## With K omitted and 20 observations, there are 10 folds
%! part = cvpartition (20, 'kfold');
%! assert (part.Type, 'kfold');
%! assert ([part.NumObservations, part.NumTestSets], [20, 10]);
%!test
%! ## Explicit K = 5 on 10 observations gives five balanced folds of two
%! K = 5;
%! cv = cvpartition (10, 'kfold', K);
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 10);
%! assert (cv.NumTestSets, K);
%! assert (cv.TrainSize, repmat (8, 1, K));
%! assert (cv.TestSize, repmat (2, 1, K));
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], false (1, 3));
%! all_tst = test (cv, 'all');
%! assert (test (cv, 1), ! training (cv, 1));
%! assert (all_tst, ! training (cv, 'all'));
%! assert (size (all_tst), [10, K]);
%!test
%! ## Grouped 5-fold split of 12 observations in 5 groups (group 5 has 4)
%! grpvar = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 5, 5];
%! rand ('seed', 5);
%! cv = cvpartition (12, 'kfold', 5, 'GroupingVariables', grpvar);
%! tst_n = [2, 2, 2, 4, 2];
%! trn_n = 12 - tst_n;
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 12);
%! assert (cv.NumTestSets, 5);
%! assert (cv.TrainSize, trn_n);
%! assert (cv.TestSize, tst_n);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], [false, true, false]);
%! tst_all = test (cv, 'all');
%! trn_all = training (cv, 'all');
%! assert (test (cv, 1), ! training (cv, 1));
%! assert (tst_all, ! trn_all);
%! assert (size (tst_all), [12, 5]);
%! assert (sum (tst_all), tst_n);
%! assert (sum (trn_all), trn_n);
%!test
%! ## Grouped 3-fold split of three unevenly sized groups (3, 7 and 2)
%! grpvar = [1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3];
%! rand ('seed', 5);
%! cv = cvpartition (12, 'kfold', 3, 'GroupingVariables', grpvar);
%! tst_n = [3, 2, 7];
%! trn_n = 12 - tst_n;
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 12);
%! assert (cv.NumTestSets, 3);
%! assert (cv.TrainSize, trn_n);
%! assert (cv.TestSize, tst_n);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], [false, true, false]);
%! tst_all = test (cv, 'all');
%! trn_all = training (cv, 'all');
%! assert (test (cv, 1), ! training (cv, 1));
%! assert (tst_all, ! trn_all);
%! assert (size (tst_all), [12, 3]);
%! assert (sum (tst_all), tst_n);
%! assert (sum (trn_all), trn_n);
%!test
%! ## Grouped 2-fold split where the groups can be balanced 6/6
%! grpvar = [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3];
%! rand ('seed', 5);
%! cv = cvpartition (12, 'kfold', 2, 'GroupingVariables', grpvar);
%! tst_n = [6, 6];
%! trn_n = 12 - tst_n;
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 12);
%! assert (cv.NumTestSets, 2);
%! assert (cv.TrainSize, trn_n);
%! assert (cv.TestSize, tst_n);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], [false, true, false]);
%! tst_all = test (cv, 'all');
%! trn_all = training (cv, 'all');
%! assert (test (cv, 1), ! training (cv, 1));
%! assert (tst_all, ! trn_all);
%! assert (size (tst_all), [12, 2]);
%! assert (sum (tst_all), tst_n);
%! assert (sum (trn_all), trn_n);
%!test
%! ## A NaN group label leaves that observation out of every set
%! grpvar = [1, 1, 1, 2, 2, 2, 2, NaN, 2, 3, 3, 3];
%! rand ('seed', 5);
%! cv = cvpartition (12, 'kfold', 2, 'GroupingVariables', grpvar);
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 12);
%! assert (cv.NumTestSets, 2);
%! assert (cv.TrainSize, [6, 5]);
%! assert (cv.TestSize, [5, 6]);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], [false, true, false]);
%! valid = ! isnan (grpvar);
%! tst_all = test (cv, 'all');
%! trn_all = training (cv, 'all');
%! assert (test (cv, 1)(valid), ! training (cv, 1)(valid));
%! assert (tst_all(valid, :), ! trn_all(valid, :));
%! assert (size (tst_all), [12, 2]);
%! assert (sum (tst_all), [5, 6]);
%! assert (sum (trn_all), [6, 5]);
%!test
%! ## Three groups into two folds: group 2 (7 obs) forms test fold 1
%! grpvar = [1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3];
%! rand ('seed', 5);
%! cv = cvpartition (12, 'kfold', 2, 'GroupingVariables', grpvar);
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 12);
%! assert (cv.NumTestSets, 2);
%! assert (cv.TrainSize, [5, 7]);
%! assert (cv.TestSize, [7, 5]);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], [false, true, false]);
%! tst_all = test (cv, 'all');
%! trn_all = training (cv, 'all');
%! assert (test (cv, 1), ! training (cv, 1));
%! assert (tst_all, ! trn_all);
%! assert (size (tst_all), [12, 2]);
%! assert (sum (tst_all), [7, 5]);
%! assert (sum (trn_all), [5, 7]);
%! in_grp = (grpvar == 2)';
%! assert (test (cv, 1), in_grp);
%! assert (test (cv, 2), ! in_grp);
%!test
%! ## Three groups into two folds: group 2 (5 obs) forms test fold 1
%! grpvar = [1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3];
%! rand ('seed', 5);
%! cv = cvpartition (12, 'kfold', 2, 'GroupingVariables', grpvar);
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 12);
%! assert (cv.NumTestSets, 2);
%! assert (cv.TrainSize, [7, 5]);
%! assert (cv.TestSize, [5, 7]);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], [false, true, false]);
%! tst_all = test (cv, 'all');
%! trn_all = training (cv, 'all');
%! assert (test (cv, 1), ! training (cv, 1));
%! assert (tst_all, ! trn_all);
%! assert (size (tst_all), [12, 2]);
%! assert (sum (tst_all), [5, 7]);
%! assert (sum (trn_all), [7, 5]);
%! in_grp = (grpvar == 2)';
%! assert (test (cv, 1), in_grp);
%! assert (test (cv, 2), ! in_grp);
%!test
%! ## Three groups into two folds: group 3 (5 obs) forms test fold 1
%! grpvar = [1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3];
%! rand ('seed', 5);
%! cv = cvpartition (12, 'kfold', 2, 'GroupingVariables', grpvar);
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 12);
%! assert (cv.NumTestSets, 2);
%! assert (cv.TrainSize, [7, 5]);
%! assert (cv.TestSize, [5, 7]);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], [false, true, false]);
%! tst_all = test (cv, 'all');
%! trn_all = training (cv, 'all');
%! assert (test (cv, 1), ! training (cv, 1));
%! assert (tst_all, ! trn_all);
%! assert (size (tst_all), [12, 2]);
%! assert (sum (tst_all), [5, 7]);
%! assert (sum (trn_all), [7, 5]);
%! in_grp = (grpvar == 3)';
%! assert (test (cv, 1), in_grp);
%! assert (test (cv, 2), ! in_grp);
%!test
%! ## A cellstr grouping variable with an empty entry: the empty label counts
%! ## as missing, only 2 groups remain, and the requested K = 5 becomes 2 folds.
%! gv = {'a';'a';'b';'b';''};
%! status = warning ();
%! unwind_protect
%!   ## Silence the warning about K exceeding the number of groups.  The
%!   ## cleanup section restores the caller's warning state even if the
%!   ## constructor throws, so later tests are not run with warnings off.
%!   warning ('off');
%!   cv = cvpartition (5, 'kfold', 5, 'GroupingVariables', gv);
%! unwind_protect_cleanup
%!   warning (status);
%! end_unwind_protect
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 5);
%! assert (cv.NumTestSets, 2);
%! assert (cv.TrainSize, [2, 2]);
%! assert (cv.TestSize, [2, 2]);
%! assert (cv.IsCustom, false);
%! assert (cv.IsGrouped, true);
%! assert (cv.IsStratified, false);
%! idx = ! ismissing (gv);
%! assert (test (cv, 1)(idx), ! training (cv, 1)(idx));
%! assert (test (cv, 'all')(idx,:), ! training (cv, 'all')(idx,:));
%! assert (size (test (cv, 'all')), [5, 2]);
%! assert (sum (test (cv, 'all')), [2, 2]);
%! ## The missing observation (row 5) is in no test set
%! assert (sum (test (cv, 'all'), 2), [1; 1; 1; 1; 0]);

## Test output results for vector input X
%!test
%! ## Stratified holdout of 3 from two equally sized classes
%! rand ('seed', 5);
%! cv = cvpartition ([1, 1, 1, 1, 1, 2, 2, 2, 2, 2], 'holdout', 3);
%! assert (cv.Type, 'holdout');
%! assert (cv.NumObservations, 10);
%! assert (cv.NumTestSets, 1);
%! assert (cv.TrainSize, 7);
%! assert (cv.TestSize, 3);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], [false, false, true]);
%! expected = false (10, 1);
%! expected([5, 7, 10]) = true;   # test indices expected under seed 5
%! assert (test (cv, 1), ! training (cv, 1));
%! assert (test (cv), expected);
%!test
%! ## Stratified holdout of 4 draws two observations from each class
%! cv = cvpartition ([1, 1, 1, 1, 1, 2, 2, 2, 2, 2], 'holdout', 4);
%! assert (cv.Type, 'holdout');
%! assert (cv.NumObservations, 10);
%! assert (cv.NumTestSets, 1);
%! assert (cv.TrainSize, 6);
%! assert (cv.TestSize, 4);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], [false, false, true]);
%! assert (test (cv, 1), ! training (cv, 1));
%! held = test (cv);
%! assert (sum (held(1:5)), 2);
%! assert (sum (held(6:10)), 2);
%!test
%! ## With 'Stratify' false the holdout ignores the class labels
%! grpvar = [1, 1, 1, 1, 1, 2, 2, 2, 2, 2];
%! rand ('seed', 5);
%! cv = cvpartition (grpvar, 'holdout', 4, 'Stratify', false);
%! assert (cv.Type, 'holdout');
%! assert (cv.NumObservations, 10);
%! assert (cv.NumTestSets, 1);
%! assert (cv.TrainSize, 6);
%! assert (cv.TestSize, 4);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], false (1, 3));
%! assert (test (cv, 1), ! training (cv, 1));
%! held = test (cv);
%! assert (sum (held(1:5)), 3);
%! assert (sum (held(6:10)), 1);
%!test
%! ## Stratified 2-fold split of a two-class vector (six 1s, four 2s)
%! cv = cvpartition ([1 1 1 1 1 2 2 2 2 1], 'kfold', 2);
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 10);
%! assert (cv.NumTestSets, 2);
%! assert (cv.TrainSize, [5, 5]);
%! assert (cv.TestSize, [5, 5]);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], [false, false, true]);
%! assert (test (cv, 1), ! training (cv, 1));
%! assert (test (cv, 'all'), ! training (cv, 'all'));
%! fold1 = test (cv, 1);
%! fold2 = test (cv, 2);
%! assert (sum (fold1(1:5)), 3);
%! assert (sum (fold1(6:10)), 2);
%! assert (sum (fold2(1:5)), 2);
%! assert (sum (fold2(6:10)), 3);
%!test
%! ## Same vector with 'Stratify' false: folds ignore the class labels
%! grpvar = [1 1 1 1 1 2 2 2 2 1];
%! rand ('seed', 5);
%! cv = cvpartition (grpvar, 'kfold', 2, 'Stratify', false);
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 10);
%! assert (cv.NumTestSets, 2);
%! assert (cv.TrainSize, [5, 5]);
%! assert (cv.TestSize, [5, 5]);
%! assert ([cv.IsCustom, cv.IsGrouped, cv.IsStratified], false (1, 3));
%! assert (test (cv, 1), ! training (cv, 1));
%! assert (test (cv, 'all'), ! training (cv, 'all'));
%! fold1 = test (cv, 1);
%! fold2 = test (cv, 2);
%! assert (sum (fold1(1:5)), 4);
%! assert (sum (fold1(6:10)), 1);
%! assert (sum (fold2(1:5)), 1);
%! assert (sum (fold2(6:10)), 4);
%!test
%! ## Stratified kfold on a cellstr whose last label is empty (missing): the
%! ## missing observation is excluded from every test set.
%! X = {'a','a','b','b',''};
%! status = warning ();
%! unwind_protect
%!   ## Silence any warnings the constructor emits for this tiny input.  The
%!   ## cleanup section restores the caller's warning state even if the
%!   ## constructor throws, so later tests are not run with warnings off.
%!   warning ('off');
%!   cv = cvpartition (X, 'kfold');
%! unwind_protect_cleanup
%!   warning (status);
%! end_unwind_protect
%! assert (cv.Type, 'kfold');
%! assert (cv.NumObservations, 5);
%! assert (cv.NumTestSets, 4);
%! assert (cv.TrainSize, [3, 3, 3, 3]);
%! assert (cv.TestSize, [1, 1, 1, 1]);
%! assert (cv.IsCustom, false);
%! assert (cv.IsGrouped, false);
%! assert (cv.IsStratified, true);
%! idx = ! ismissing (X);
%! assert (test (cv, 1)(idx), ! training (cv, 1)(idx));
%! assert (test (cv, 'all')(idx,:), ! training (cv, 'all')(idx,:));
%! assert (sum (test (cv, 'all'), 2), [1; 1; 1; 1; 0]);

## Test input validation
## Constructor argument count
%!error <cvpartition: too few input arguments.> cvpartition (2)
%!error <cvpartition: too many input arguments.> cvpartition (1, 2, 3, 4, 5, 6)
## 'CustomPartition' TESTSETS validation: element type, shape, and exactly one
## test fold per observation (row).
## NOTE(review): %!error patterns must match the constructor's message text,
## which here reads "numeric of logical" and "more columns that rows" --
## presumably typos for "or" / "than"; fix them in the constructor and in these
## patterns together.
%!error <cvpartition: TESTSETS must be numeric of logical.> ...
%! cvpartition ("CustomPartition", 'a')
%!error <cvpartition: TESTSETS must be a numeric vector.> ...
%! cvpartition ("CustomPartition", [2, 3; 2, 3])
%!error <cvpartition: TESTSETS must be a logical vector or matrix.> ...
%! cvpartition ("CustomPartition", false (3, 3, 3))
%!error <cvpartition: each observation in TESTSETS must be exactly one in each row.> ...
%! cvpartition ("CustomPartition", [false, true; true, true; true, false])
%!error <cvpartition: a logical matrix in TESTSETS must not have more columns that rows.> ...
%! cvpartition ("CustomPartition", false (3, 5))
## A scalar X must be a positive integer count of observations
%!error <cvpartition: X must be a scalar positive integer value.> ...
%! cvpartition (-20, "LeaveOut")
%!error <cvpartition: X must be a scalar positive integer value.> ...
%! cvpartition (20.5, "LeaveOut")
## 'holdout' P must be a numeric scalar: either a fraction in (0,1) or an
## integer count in [1, N).  Parentheses and brackets in the patterns are
## backslash-escaped because %!error patterns are regular expressions.
%!error <cvpartition: P value for 'holdout' must be a numeric scalar.> ...
%! cvpartition (20, "HoldOut", [0.2, 0.3])
%!error <cvpartition: P value for 'holdout' must be a numeric scalar.> ...
%! cvpartition (20, "HoldOut", 'a')
%!error <cvpartition: P value for 'holdout' must be a scalar in the range \(0,1\) or an integer scalar in the range \[1, N\).> ...
%! cvpartition (20, "HoldOut", 0)
%!error <cvpartition: P value for 'holdout' must be a scalar in the range \(0,1\) or an integer scalar in the range \[1, N\).> ...
%! cvpartition (20, "HoldOut", -0.1)
%!error <cvpartition: P value for 'holdout' must be a scalar in the range \(0,1\) or an integer scalar in the range \[1, N\).> ...
%! cvpartition (20, "HoldOut", 21)
## 'kfold' K must be a numeric integer scalar in [1, N]
%!error <cvpartition: K value for 'kfold' must be a numeric scalar.> ...
%! cvpartition (20, "kfold", [2, 3])
%!error <cvpartition: K value for 'kfold' must be a numeric scalar.> ...
%! cvpartition (20, "kfold", 'a')
%!error <cvpartition: K value for 'kfold' must be an integer scalar in the range \[1, N\].> ...
%! cvpartition (20, "kfold", 2.5)
%!error <cvpartition: K value for 'kfold' must be an integer scalar in the range \[1, N\].> ...
%! cvpartition (20, "kfold", 21)
## 'GroupingVariables' name/value pair: the abbreviated name "Group" is
## rejected, and the value must be present, at most 2-D, and have one row per
## observation.
%!error <cvpartition: invalid optional paired argument for 'GroupingVariables'.> ...
%! cvpartition (10, "kfold", 3, "Group")
%!error <cvpartition: missing value for optional paired argument 'GroupingVariables'.> ...
%! cvpartition (10, "kfold", 3, "GroupingVariables")
%!error <cvpartition: invalid value for optional paired argument 'GroupingVariables'.> ...
%! cvpartition (10, "kfold", 3, "GroupingVariables", ones (3, 3, 3))
%!error <cvpartition: grouping variable does not match the number of observations.> ...
%! cvpartition (10, "kfold", 3, "GroupingVariables", {'a', 'a', 'a', 'b', 'b'})
## Requesting more folds than groups only warns (K is reduced); it is not an error
%!warning <cvpartition: number of folds K is greater than the groups in 'GroupingVariables'. K is set to the number of groups.> ...
%! cvpartition (5, "kfold", 3, "GroupingVariables", {'a', 'a', 'a', 'b', 'b'});
## Unrecognized partition type for a scalar N
%!error <cvpartition: invalid optional paired argument.> ...
%! cvpartition (20, "some")
## 'Stratify' name/value pair (vector X): the abbreviated name "strat" is
## rejected, and the value must be present and a logical scalar.
%!error <cvpartition: invalid optional paired argument for stratification.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 2, "strat")
%!error <cvpartition: missing value for optional paired argument 'stratify'.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 2, "stratify")
%!error <cvpartition: invalid value for optional paired argument 'stratify'.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 2, "stratify", [true, true])
%!error <cvpartition: invalid value for optional paired argument 'stratify'.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 2, "stratify", 'no')
## With a vector X, 'holdout' P is validated against N, the number of
## nonmissing observations in X.  Each case is checked both with and without an
## explicit 'stratify' pair, so the validation order does not depend on it.
%!error <cvpartition: P value for 'holdout' must be a numeric scalar.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", 'a')
%!error <cvpartition: P value for 'holdout' must be a numeric scalar.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", 'a', "stratify", true)
%!error <cvpartition: P value for 'holdout' must be a numeric scalar.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", [0.2, 0.3])
%!error <cvpartition: P value for 'holdout' must be a numeric scalar.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", [0.2, 0.3], "stratify", true)
%!error <cvpartition: P value for 'holdout' must be a scalar in the range \(0,1\) or an integer scalar in the range \[1, N\), where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", 0)
%!error <cvpartition: P value for 'holdout' must be a scalar in the range \(0,1\) or an integer scalar in the range \[1, N\), where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", 0, "stratify", true)
%!error <cvpartition: P value for 'holdout' must be a scalar in the range \(0,1\) or an integer scalar in the range \[1, N\), where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", -0.1)
%!error <cvpartition: P value for 'holdout' must be a scalar in the range \(0,1\) or an integer scalar in the range \[1, N\), where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", -0.1, "stratify", true)
%!error <cvpartition: P value for 'holdout' must be a scalar in the range \(0,1\) or an integer scalar in the range \[1, N\), where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", 1.2)
%!error <cvpartition: P value for 'holdout' must be a scalar in the range \(0,1\) or an integer scalar in the range \[1, N\), where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", 1.2, "stratify", false)
%!error <cvpartition: P value for 'holdout' must be a scalar in the range \(0,1\) or an integer scalar in the range \[1, N\), where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", 6)
%!error <cvpartition: P value for 'holdout' must be a scalar in the range \(0,1\) or an integer scalar in the range \[1, N\), where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "holdout", 6, "stratify", false)
## With a vector X, 'kfold' K is likewise validated against the nonmissing count
%!error <cvpartition: K value for 'kfold' must be a numeric scalar.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 'a')
%!error <cvpartition: K value for 'kfold' must be a numeric scalar.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 'a', "stratify", true)
%!error <cvpartition: K value for 'kfold' must be a numeric scalar.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", [2, 3])
%!error <cvpartition: K value for 'kfold' must be a numeric scalar.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", [2, 3], "stratify", false)
%!error <cvpartition: K value for 'kfold' must be an integer scalar in the range \[1, N\], where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 0)
%!error <cvpartition: K value for 'kfold' must be an integer scalar in the range \[1, N\], where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 0, "stratify", true)
%!error <cvpartition: K value for 'kfold' must be an integer scalar in the range \[1, N\], where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 1.5)
%!error <cvpartition: K value for 'kfold' must be an integer scalar in the range \[1, N\], where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 1.5, "stratify", true)
%!error <cvpartition: K value for 'kfold' must be an integer scalar in the range \[1, N\], where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 6)
%!error <cvpartition: K value for 'kfold' must be an integer scalar in the range \[1, N\], where N is the number of nonmissing observations in X.> ...
%! cvpartition ([1, 1, 1, 2, 2], "kfold", 6, "stratify", true)
## 'leaveout', 'resubstitution' and unknown types are rejected for a vector X
%!error <cvpartition: invalid optional paired argument.> ...
%! cvpartition ([1, 1, 1, 2, 2], "leaveout")
%!error <cvpartition: invalid optional paired argument.> ...
%! cvpartition ([1, 1, 1, 2, 2], "resubstitution")
%!error <cvpartition: invalid optional paired argument.> ...
%! cvpartition ([1, 1, 1, 2, 2], "some")
## A cell matrix is not an accepted first argument
%!error <cvpartition: invalid first input argument.> ...
%! cvpartition ({1, 1; 2, 2}, "kfold")

## repartition: custom partitions cannot be redrawn; the 'legacy' flag is
## rejected for stratified/grouped 'kfold' and for non-'kfold' objects; SVAL
## must be a real scalar or vector.  The constructor calls themselves must
## succeed so that the repartition error is the one raised.
%!error <cvpartition.repartition: cannot repartition a custom partition.> ...
%! repartition (cvpartition ('CustomPartition', [1,1,2,2,3,3]))
%!error <cvpartition.repartition: 'legacy' flag does not apply to stratified or grouped 'kfold' partitioned objects.> ...
%! repartition (cvpartition ([1 1 1 1 1 2 2 2 2 1], 'kfold', 2, 'Stratify', true), 'legacy')
%!error <cvpartition.repartition: 'legacy' flag is only valid for 'kfold' partitioned objects.> ...
%! repartition (cvpartition (20, 'Leaveout', 0.2), 'legacy')
%!error <cvpartition.repartition: SVAL must be a real scalar or vector.> ...
%! repartition (cvpartition (20, 'Leaveout', 0.2), 'asd')
%!error <cvpartition.repartition: SVAL must be a real scalar or vector.> ...
%! repartition (cvpartition (20, 'Leaveout', 0.2), 2+i)
%!error <cvpartition.repartition: SVAL must be a real scalar or vector.> ...
%! repartition (cvpartition (20, 'KFold', 5), [34, 56; 2, 3])

## test(): at most one set index, which must be a positive integer vector whose
## values do not exceed NumTestSets (10 for a default 'kfold' of 20 observations)
%!error <cvpartition.test: too many input arguments.> ...
%! test (cvpartition (20, "kfold"), 2, 3)
%!error <cvpartition.test: set index must be a positive integer vector.> ...
%! test (cvpartition (20, "kfold"), 0)
%!error <cvpartition.test: set index must be a positive integer vector.> ...
%! test (cvpartition (20, "kfold"), 1.5)
%!error <cvpartition.test: set index must be a positive integer vector.> ...
%! test (cvpartition (20, "kfold"), [1, 1.5])
%!error <cvpartition.test: set index must be a positive integer vector.> ...
%! test (cvpartition (20, "kfold"), [2, 3; 2, 3])
%!error <cvpartition.test: set index exceeds 'NumTestSets'.> ...
%! test (cvpartition (20, "kfold"), 21)
%!error <cvpartition.test: set index exceeds 'NumTestSets'.> ...
%! test (cvpartition (20, "kfold"), [18, 21])

## training(): the same set-index validation as test()
%!error <cvpartition.training: too many input arguments.> ...
%! training (cvpartition (20, "kfold"), 2, 3)
%!error <cvpartition.training: set index must be a positive integer vector.> ...
%! training (cvpartition (20, "kfold"), 0)
%!error <cvpartition.training: set index must be a positive integer vector.> ...
%! training (cvpartition (20, "kfold"), 1.5)
%!error <cvpartition.training: set index must be a positive integer vector.> ...
%! training (cvpartition (20, "kfold"), [1, 1.5])
%!error <cvpartition.training: set index must be a positive integer vector.> ...
%! training (cvpartition (20, "kfold"), [2, 3; 2, 3])
%!error <cvpartition.training: set index exceeds 'NumTestSets'.> ...
%! training (cvpartition (20, "kfold"), 21)
%!error <cvpartition.training: set index exceeds 'NumTestSets'.> ...
%! training (cvpartition (20, "kfold"), [18, 21])

## Test 'summary' method
%!test
%! ## 1. Stratified K-Fold: Basic Text Labels
%! species = [repmat({"Setosa"}, 10, 1); repmat({"Versicolor"}, 10, 1)];
%! rand ("state", 42);
%! c = cvpartition (species, "KFold", 2);
%! T = summary (c);
%! ## 2 labels x 5 sets (all, train1, test1, train2, test2)
%! assert (height (T), 10);
%! required = {"Set", "SetSize", "StratificationLabel", ...
%!             "StratificationCount", "PercentInSet"};
%! assert (all (ismember (required, T.Properties.VariableNames)));
%!
%! ## Label columns are string arrays when the string class exists; otherwise
%! ## fall back to strcmp-based matching.
%! has_string = exist ("string", "class");
%! if (has_string)
%!   assert (isa (T.Set, "string"));
%!   assert (isa (T.StratificationLabel, "string"));
%!   is_all = (T.Set == "all");
%!   is_setosa = (T.StratificationLabel == "Setosa");
%! else
%!   is_all = strcmp (T.Set, "all");
%!   is_setosa = strcmp (T.StratificationLabel, "Setosa");
%! endif
%! assert (T.StratificationCount(is_all & is_setosa), 10);
%!test
%! ## 2. Grouped K-Fold: Basic Numeric Labels
%! groups = [1; 1; 1; 2; 2; 3; 3; 3; 3; 3];
%! rand ("state", 100);
%! c = cvpartition (numel (groups), "KFold", 2, "GroupingVariables", groups);
%! T = summary (c);
%! assert (any (strcmp ("GroupLabel", T.Properties.VariableNames)));
%!
%! ## Group 3 (five observations) lies wholly inside or wholly outside test1
%! labels = T.GroupLabel;
%! if (iscell (labels))
%!   labels = cell2mat (labels);
%! endif
%! if (exist ("string", "class"))
%!   in_test1 = (T.Set == "test1");
%! else
%!   in_test1 = strcmp (T.Set, "test1");
%! endif
%! count_g3 = T.GroupCount((labels == 3) & in_test1);
%! assert (count_g3 == 0 || count_g3 == 5);
%!test
%! ## 3. Grouped K-Fold: Matrix Grouping (rows of [g1, g2] define the groups)
%! groups = [1, 1; 1, 1; 1, 2; 2, 1; 2, 2; 2, 2];
%! c = cvpartition (6, "KFold", 2, "GroupingVariables", groups);
%! T = summary (c);
%! n_groups = 4;           # distinct rows: (1,1), (1,2), (2,1), (2,2)
%! n_sets = 1 + 2 * 2;     # all + 2 train + 2 test
%! assert (height (T), n_groups * n_sets);
%!test
%! ## 4. Stratified Holdout: Basic
%! species = [repmat({"A"}, 10, 1); repmat({"B"}, 10, 1)];
%! c = cvpartition (species, "Holdout", 0.5);
%! T = summary (c);
%! ## A holdout has exactly three sets: all, train1, test1
%! assert (numel (unique (T.Set)), 3);
%!test
%! ## 5. Mathematical Consistency: Percentages
%! classes = [1; 1; 2; 2; 3; 3];
%! c = cvpartition (classes, "KFold", 2);
%! T = summary (c);
%! set_names = T.Set;
%! if (exist ("string", "class"))
%!   in_all = (set_names == "all");
%!   in_train1 = (set_names == "train1");
%! else
%!   in_all = strcmp (set_names, "all");
%!   in_train1 = strcmp (set_names, "train1");
%! endif
%! ## The per-label percentages within a single set add up to 100
%! assert (sum (T.PercentInSet(in_all)), 100, 1e-10);
%! assert (sum (T.PercentInSet(in_train1)), 100, 1e-10);
%!test
%! ## 6. Mathematical Consistency: Set Sizes
%! N = 20;
%! c = cvpartition (ones (N, 1), "KFold", 4);
%! T = summary (c);
%! if (exist ("string", "class"))
%!   first_tr1 = find (T.Set == "train1", 1);
%!   first_ts1 = find (T.Set == "test1", 1);
%! else
%!   first_tr1 = find (strcmp (T.Set, "train1"), 1);
%!   first_ts1 = find (strcmp (T.Set, "test1"), 1);
%! endif
%! ## A training fold and its matching test fold cover the whole dataset
%! assert (T.SetSize(first_tr1) + T.SetSize(first_ts1), N);
%!test
%! ## 7. Logical Grouping Variables
%! groups = [true; true; true; false; false];
%! c = cvpartition (5, "KFold", 2, "GroupingVariables", groups);
%! T = summary (c);
%! n_groups = 2;
%! n_sets = 5;   # all, train1, test1, train2, test2
%! assert (height (T), n_groups * n_sets);
%! glabels = T.GroupLabel;
%! if (iscell (glabels))
%!   glabels = cell2mat (glabels);
%! endif
%! assert (numel (unique (glabels)), n_groups);
%!test
%! ## 8. Char Array Grouping Variables
%! groups = ['A'; 'A'; 'B'; 'B'; 'C'];
%! c = cvpartition (5, "KFold", 2, "GroupingVariables", groups);
%! T = summary (c);
%! ## 3 distinct groups x 5 sets
%! assert (height (T), 3 * 5);
%! assert (ismember ("GroupLabel", T.Properties.VariableNames));
%!test
%! ## 9. Floating Point Grouping Variables
%! groups = [1.1; 1.1; 2.2; 2.2];
%! c = cvpartition (4, "KFold", 2, "GroupingVariables", groups);
%! T = summary (c);
%! vals = T.GroupLabel;
%! if (iscell (vals))
%!   vals = cell2mat (vals);
%! endif
%! ## Each non-integer group value must survive as a label
%! for g = [1.1, 2.2]
%!   assert (any (abs (vals - g) < 1e-10));
%! endfor
%!test
%! ## 10. Negative Numeric Grouping
%! groups = [-5; -5; -10; -10];
%! c = cvpartition (4, "KFold", 2, "GroupingVariables", groups);
%! T = summary (c);
%! ## 2 distinct groups x 5 sets
%! assert (height (T), 2 * 5);
%!test
%! ## 11. Missing Values in Stratification (NaN)
%! classes = [1; 1; 2; 2; NaN; NaN];
%! c = cvpartition (classes, "KFold", 2);
%! T = summary (c);
%! if (exist ("string", "class"))
%!   first_all = find (T.Set == "all", 1);
%! else
%!   first_all = find (strcmp (T.Set, "all"), 1);
%! endif
%! ## The two NaN-labelled observations are not counted in the "all" set
%! assert (T.SetSize(first_all), 4);
%!test
%! ## 12. Missing Values in Grouping (NaN)
%! groups = [1; 1; 2; 2; NaN];
%! c = cvpartition (5, "KFold", 2, "GroupingVariables", groups);
%! T = summary (c);
%! if (exist ("string", "class"))
%!   first_all = find (T.Set == "all", 1);
%! else
%!   first_all = find (strcmp (T.Set, "all"), 1);
%! endif
%! ## The NaN-grouped observation is not counted in the "all" set
%! assert (T.SetSize(first_all), 4);
%!test
%! ## 13. Unbalanced Stratification
%! ## With 90 "C1" and 10 "C2" labels, a stratified 2-fold split should put
%! ## half of each class into test1.
%! species = [repmat({"C1"}, 90, 1); repmat({"C2"}, 10, 1)];
%! c = cvpartition (species, "KFold", 2);
%! T = summary (c);
%! if (exist ("string", "class"))
%!   mask_ts1 = (T.Set == "test1");
%!   subT = T(mask_ts1, :);
%!   c1_count = subT.StratificationCount(subT.StratificationLabel == "C1");
%!   c2_count = subT.StratificationCount(subT.StratificationLabel == "C2");
%! else
%!   mask_ts1 = strcmp (T.Set, "test1");
%!   subT = T(mask_ts1, :);
%!   c1_count = subT.StratificationCount(strcmp (subT.StratificationLabel, "C1"));
%!   c2_count = subT.StratificationCount(strcmp (subT.StratificationLabel, "C2"));
%! endif
%! ## Compare values rather than asserting a logical expression: this reports
%! ## the observed count on failure and fails (instead of passing vacuously)
%! ## when the mask selects no rows.
%! assert (c1_count, 45);
%! assert (c2_count, 5);
%!test
%! ## 14. Single Observation per Group (Edge Case)
%! groups = [1; 2; 3; 4];
%! c = cvpartition (4, "KFold", 2, "GroupingVariables", groups);
%! T = summary (c);
%! if (exist ("string", "class"))
%!   counts = T.GroupCount(T.Set == "test1");
%! else
%!   counts = T.GroupCount(strcmp (T.Set, "test1"));
%! endif
%! ## test1 holds two singleton groups; the other two have count 0 there
%! assert (nnz (counts == 1), 2);
%! assert (nnz (counts == 0), 2);
%!test
%! ## 15. Set Name Generation Verification
%! species = [1; 1; 2; 2];
%! c = cvpartition (species, "KFold", 2);
%! T = summary (c);
%! names = unique (T.Set);
%! if (exist ("string", "class"))
%!   ## Compare as cellstr so sort and assert work on cell arrays
%!   names = cellstr (names);
%! endif
%! assert (sort (names), sort ({"all"; "train1"; "test1"; "train2"; "test2"}));
%!test
%! ## 16. Label Column Consistency
%! groups = ['A'; 'B'];
%! c = cvpartition (2, "KFold", 2, "GroupingVariables", groups);
%! T = summary (c);
%! glabel = T.GroupLabel;
%! if (exist ("string", "class"))
%!   assert (isa (glabel, "string"));
%! else
%!   assert (iscellstr (glabel));
%! endif
%!test
%! ## 17. Whitespace labels (" ") are valid, non-missing stratification labels
%! species = {"A"; "A"; " "; " "};
%! c = cvpartition (species, "KFold", 2);
%! T = summary (c);
%! if (exist ("string", "class"))
%!   labels = cellstr (T.StratificationLabel);
%!   sets = cellstr (T.Set);
%!   assert (any (strcmp (labels, " ")));
%!   mask_space = strcmp (labels, " ");
%!   mask_all = strcmp (sets, "all");
%! else
%!   assert (any (strcmp (T.StratificationLabel, " ")));
%!   mask_space = strcmp (T.StratificationLabel, " ");
%!   mask_all = strcmp (T.Set, "all");
%! endif
%! assert (sum (T.StratificationCount(mask_space & mask_all)), 2);
%!test
%! ## 18. Large K (Leave-One-Out Simulation)
%! species = [1; 1; 2; 2];
%! warn_state = warning ("off", "all");
%! unwind_protect
%!   c = cvpartition (species, "KFold", 4);
%! unwind_protect_cleanup
%!   ## Restore the warning state even if the constructor throws
%!   warning (warn_state);
%! end_unwind_protect
%! T = summary (c);
%! ## 2 labels x 9 sets (all + 4 train + 4 test)
%! assert (height (T), 18);
%! if (exist ("string", "class"))
%!   mask_test = startsWith (cellstr(T.Set), "test");
%! else
%!   mask_test = strncmp (T.Set, "test", 4);
%! endif
%! assert (all (T.SetSize(mask_test) == 1));
%!test
%! ## 19. Holdout Test-Set Size
%! species = [1; 1; 2; 2];
%! rand ("state", 42);
%! c = cvpartition (species, "Holdout", 0.5);
%! T = summary (c);
%! if (exist ("string", "class"))
%!   first_ts1 = find (T.Set == "test1", 1);
%! else
%!   first_ts1 = find (strcmp (T.Set, "test1"), 1);
%! endif
%! ## Half of the four observations are held out
%! assert (T.SetSize(first_ts1), 2);
%!test
%! ## 20. Empty String Handling (Missing Data)
%! species = {"A"; "A"; ""; ""};
%! c = cvpartition (species, "KFold", 2);
%! T = summary (c);
%! if (exist ("string", "class"))
%!   assert (! any (T.StratificationLabel == ""));
%!   first_all = find (T.Set == "all", 1);
%! else
%!   assert (! any (strcmp (T.StratificationLabel, "")));
%!   first_all = find (strcmp (T.Set, "all"), 1);
%! endif
%! ## Empty labels are missing, so only the two "A" rows are counted
%! assert (T.SetSize(first_all), 2);
%!test
%! ## 21. Basic Unstacking (Stratified K-Fold)
%! species = [repmat({"Alpha"}, 10, 1); repmat({"Beta"}, 10, 1)];
%! c = cvpartition (species, "KFold", 2);
%! T = summary (c);
%! T_wide = unstack (T(:, 1:4), "StratificationCount", "StratificationLabel");
%! ## 5 rows: one per set (all, train1, train2, test1, test2).
%! assert (height (T_wide), 5);
%! ## 4 columns: Set, SetSize, plus one count column per label (Alpha, Beta).
%! assert (width (T_wide), 4);
%! assert (all (ismember ({"Alpha", "Beta"}, T_wide.Properties.VariableNames)));
%!test
%! ## 22. Data Integrity Check (Row Sums)
%! ## In every set, the per-label counts must add up to the set size.
%! species = [repmat({"Control"}, 20, 1); repmat({"Treatment"}, 80, 1)];
%! c = cvpartition (species, "Holdout", 0.25);
%! T = summary (c);
%! T_wide = unstack (T(:, 1:4), "StratificationCount", "StratificationLabel");
%! per_label = [T_wide.Control, T_wide.Treatment];
%! assert (all (sum (per_label, 2) == T_wide.SetSize));
%!test
%! ## 23. Unstacking Grouped Data (Numeric Labels)
%! groups = [1; 1; 2; 2; 2];
%! c = cvpartition (5, "KFold", 2, "GroupingVariables", groups);
%! T = summary (c);
%! T_wide = unstack (T(:, 1:4), "GroupCount", "GroupLabel");
%! ## Each numeric group label must show up in at least one new column name.
%! vnames = T_wide.Properties.VariableNames;
%! has_1 = ! cellfun ("isempty", strfind (vnames, "1"));
%! has_2 = ! cellfun ("isempty", strfind (vnames, "2"));
%! assert (any (has_1));
%! assert (any (has_2));
%!test
%! ## 24. Unstacking with Missing/NaN Groups
%! ## The NaN group must not get a column of its own.  The wide table should
%! ## therefore hold Set, SetSize and one column each for groups 1 and 2.
%! groups = [1; 1; 2; 2; NaN];
%! c = cvpartition (5, "KFold", 2, "GroupingVariables", groups);
%! T = summary (c);
%! T_wide = unstack (T(:, 1:4), "GroupCount", "GroupLabel");
%! n_cols = width (T_wide);
%! assert (n_cols, 4);
%!test
%! ## 25. Unstacking a summary table whose columns may be string arrays
%! species = {"Red"; "Blue"; "Red"; "Blue"};
%! c = cvpartition (species, "KFold", 2);
%! T = summary (c);
%! ## When the string class is available, summary must return string columns.
%! if (exist ("string", "class"))
%!   assert (isa (T.Set, "string"));
%! endif
%! T_wide = unstack (T(:, 1:4), "StratificationCount", "StratificationLabel");
%! ## The single 'all' row counts every observation of each label.
%! mask = strcmp (cellstr (T_wide.Set), "all");
%! assert (T_wide.Red(mask), 2);
%! assert (T_wide.Blue(mask), 2);
%!test
%! ## 26. Large K Unstacking (Many Rows)
%! k = 10;
%! species = [repmat({"High"}, 10, 1); repmat({"Low"}, 10, 1)];
%! c = cvpartition (species, "KFold", k);
%! T = summary (c);
%! T_wide = unstack (T(:, 1:4), "StratificationCount", "StratificationLabel");
%! ## One row per set: k training sets, k test sets and the 'all' set.
%! assert (height (T_wide), 2 * k + 1);
%!test
%! ## 27. Unstacking with Special Characters in Labels
%! ## Labels containing a space or a dash must each still produce one column,
%! ## giving Set, SetSize and two label columns in total.
%! species = {"Type A"; "Type A"; "Type-B"; "Type-B"};
%! c = cvpartition (species, "KFold", 2);
%! T = summary (c);
%! T_wide = unstack (T(:, 1:4), "StratificationCount", "StratificationLabel");
%! assert (numel (T_wide.Properties.VariableNames), 4);
%!test
%! ## 28. Verification of 'all' row logic after Unstacking
%! species = [repmat({"Yes"}, 50, 1); repmat({"No"}, 50, 1)];
%! c = cvpartition (species, "Holdout", 0.2);
%! T = summary (c);
%! T_wide = unstack (T(:, 1:4), "StratificationCount", "StratificationLabel");
%! ## The 'all' row holds the full per-label counts whatever the random split.
%! ## The two-argument assert also fails if there is no 'all' row or several.
%! mask = strcmp (cellstr (T_wide.Set), "all");
%! assert (T_wide.Yes(mask), 50);
%! assert (T_wide.No(mask), 50);
%!test
%! ## 29. Robustness against re-ordering
%! ## Unstacking must not depend on row order.  Reverse every row of the
%! ## summary table and check that the 'all' counts are still recovered.
%! species = {"Left"; "Left"; "Right"; "Right"};
%! c = cvpartition (species, "Holdout", 0.5);
%! T = summary (c);
%! T_shuffled = T(height (T):-1:1, :);
%! T_wide = unstack (T_shuffled(:, 1:4), "StratificationCount", "StratificationLabel");
%! mask = strcmp (cellstr (T_wide.Set), "all");
%! assert (T_wide.Left(mask), 2);
%! assert (T_wide.Right(mask), 2);
%!error <cvpartition.summary: partition must be stratified or grouped.>
%! summary (cvpartition (20, "KFold", 5));
%!error <cvpartition.summary: partition must be stratified or grouped.>
%! summary (cvpartition (10, "LeaveOut"));


