Creating a Custom Time-Series Cross-Validator in PySpark

02 July 2024

I often build machine learning pipelines in PySpark as part of my job, and while the built-in machine learning library is powerful, I sometimes find it lacking. Fortunately, we can extend it to suit our needs. Here, I show how I built a custom time-series cross-validator in PySpark.

Understanding Time-Series Cross-Validation

Unlike traditional k-fold cross-validation, which splits the data randomly into \(k\) folds, time-series cross-validation involves splitting the data into consecutive periods, such that the temporal order is respected. The training set of each fold consists of past data, while the validation set consists of more recent data. This approach mimics how the model would be used in practice for forecasting.

There are different ways of creating these folds. Two of the most common are rolling windows and expanding windows. In the rolling window setup, every training set has the same size, because the start of the training set rolls forward along with its end. In the expanding window setup, the start of the training set is kept fixed, so the training set grows with each fold.

Figure: Comparison of expanding and rolling window train-validation splits for time-series data.
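
To make the difference concrete, here is a minimal sketch in plain Python (no Spark) that builds both kinds of splits over integer time indices. The helper name window_splits and its parameters are made up for illustration and are not part of PySpark.

```python
def window_splits(n_periods, n_folds, val_size, expanding=True):
    """Return (train_range, val_range) index pairs, oldest fold first."""
    first_val_start = n_periods - n_folds * val_size
    if first_val_start <= 0:
        raise ValueError("not enough periods for the requested folds")
    splits = []
    for fold in range(n_folds):
        val_start = first_val_start + fold * val_size
        # Expanding: training always starts at index 0.
        # Rolling: the start moves forward by val_size per fold,
        # so the training set keeps a constant size.
        train_start = 0 if expanding else fold * val_size
        splits.append((range(train_start, val_start),
                       range(val_start, val_start + val_size)))
    return splits


expanding = window_splits(n_periods=10, n_folds=3, val_size=2)
rolling = window_splits(n_periods=10, n_folds=3, val_size=2, expanding=False)

for name, splits in (("expanding", expanding), ("rolling", rolling)):
    for train, val in splits:
        print(name, "train:", list(train), "validation:", list(val))
```

In both setups every validation index comes after every training index of the same fold; only the start of the training range differs.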

Implementing Time-Series Cross-Validation in PySpark

To implement time-series cross-validation in PySpark, we can extend the functionality of the CrossValidator class from pyspark.ml.tuning. By inspecting the PySpark source code, we see that the _fit method first creates the folds by calling the _kFold method. In particular, this is the code that creates the folds:

# PySpark version: 3.5.0

def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
    nFolds = self.getOrDefault(self.numFolds)
    foldCol = self.getOrDefault(self.foldCol)

    datasets = []
    if not foldCol:
        # Do random k-fold split.
        seed = self.getOrDefault(self.seed)
        h = 1.0 / nFolds
        randCol = self.uid + "_rand"
        df = dataset.select("*", rand(seed).alias(randCol))
        for i in range(nFolds):
            validateLB = i * h
            validateUB = (i + 1) * h
            condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
            validation = df.filter(condition)
            train = df.filter(~condition)
            datasets.append((train, validation))
    else:
      [...]
    return datasets

The method creates a new column with random numbers and uses it to create train-validation pairs by filtering on it. We can implement our custom time-series cross-validator by overriding this method.

Solution

from pyspark import keyword_only
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.tuning import CrossValidator

class tsCrossValidator(CrossValidator):
    """
    Custom validator for time-series cross-validation.

    This class extends the functionality of PySpark's CrossValidator to support
    walk-forward time-series cross-validation. It splits the dataset into
    consecutive periods with each fold using data from the past as training
    and the most recent period as validation.

    In particular, it overrides the _kFold method, which is used by the fit method to create the folds.
    """
    datetimeCol = Param(
        Params._dummy(), 
        "datetimeCol", 
        "Column name for splitting the data",
        typeConverter=TypeConverters.toString)
    
    timeSplit = Param(
        Params._dummy(), 
        "timeSplit", 
        "Length of time to leave in validation set. Should be some sort of timedelta or relativedelta")
    
    gap = Param(
        Params._dummy(), 
        "gap", 
        "Length of time to leave as gap between train and validation")
    
    disableExpandingWindow = Param(
        Params._dummy(),
        "disableExpandingWindow",
        "Boolean for disabling expanding window folds and taking rolling windows instead.",
        typeConverter=TypeConverters.toBoolean)

    @keyword_only
    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None,
                 numFolds=3, datetimeCol='date', timeSplit=None,
                 gap=None, disableExpandingWindow=False, parallelism=1, collectSubModels=False):

        super(tsCrossValidator, self).__init__(
            estimator=estimator,
            estimatorParamMaps=estimatorParamMaps,
            evaluator=evaluator,
            numFolds=numFolds,
            parallelism=parallelism,
            collectSubModels=collectSubModels
        )

        self._setDefault(gap=None, datetimeCol='date', timeSplit=None, disableExpandingWindow=False)

        # Set the custom params explicitly: the parent's keyword_only __init__
        # overwrites self._input_kwargs with its own arguments, so the custom
        # values cannot be recovered from it here.
        self._set(gap=gap, datetimeCol=datetimeCol, timeSplit=timeSplit, disableExpandingWindow=disableExpandingWindow)
    
    def getDatetimeCol(self):
        return self.getOrDefault(self.datetimeCol)
    
    def setDatetimeCol(self, datetimeCol):
        return self._set(datetimeCol=datetimeCol)
    
    def getTimeSplit(self):
        return self.getOrDefault(self.timeSplit)
    
    def setTimeSplit(self, timeSplit):
        return self._set(timeSplit=timeSplit)
    
    def getDisableExpandingWindow(self):
        return self.getOrDefault(self.disableExpandingWindow)
    
    def setDisableExpandingWindow(self, disableExpandingWindow):
        return self._set(disableExpandingWindow=disableExpandingWindow)
    
    def getGap(self):
        return self.getOrDefault(self.gap)

    def setGap(self, gap):
        return self._set(gap=gap)

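    def _kFold(self, dataset):
        nFolds = self.getOrDefault(self.numFolds)
        datetimeCol = self.getOrDefault(self.datetimeCol)
        timeSplit = self.getOrDefault(self.timeSplit)
        gap = self.getOrDefault(self.gap)
        disableExpandingWindow = self.getOrDefault(self.disableExpandingWindow)

        if timeSplit is None:
            raise ValueError("timeSplit must be set, e.g. to a timedelta or relativedelta.")

        datasets = []
        # The latest and earliest timestamps bound the validation and training windows.
        endDate = dataset.agg({datetimeCol: 'max'}).collect()[0][0]
        trainLB = dataset.agg({datetimeCol: 'min'}).collect()[0][0]
        # Walk forward in time: the last fold validates on the most recent period.
        for i in reversed(range(nFolds)):
            validateUB = endDate - i * timeSplit
            validateLB = endDate - (i + 1) * timeSplit
            # Optionally leave a gap between the end of training and the start of validation.
            trainUB = validateLB - gap if gap is not None else validateLB

            # Validation covers (validateLB, validateUB]; training covers [trainLB, trainUB].
            val_condition = (dataset[datetimeCol] > validateLB) & (dataset[datetimeCol] <= validateUB)
            train_condition = (dataset[datetimeCol] <= trainUB) & (dataset[datetimeCol] >= trainLB)

            validation = dataset.filter(val_condition)
            train = dataset.filter(train_condition)

            datasets.append((train, validation))

            # Rolling window: move the start of the training set forward for the next fold.
            if disableExpandingWindow:
                trainLB += timeSplit

        return datasets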
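
Before running this on a cluster, it can help to check which dates end up in which fold. The sketch below repeats the boundary arithmetic of _kFold in plain Python, without Spark. The function fold_boundaries and the example dates are made up for illustration, and a timedelta is assumed for both timeSplit and gap.

```python
from datetime import date, timedelta


def fold_boundaries(start, end, n_folds, time_split, gap=None, expanding=True):
    """Mirror the boundary arithmetic of _kFold without touching Spark.

    Returns (train_lb, train_ub, validate_lb, validate_ub) per fold, oldest
    first. Training covers train_lb <= t <= train_ub and validation covers
    validate_lb < t <= validate_ub.
    """
    bounds = []
    train_lb = start
    for i in reversed(range(n_folds)):
        validate_ub = end - i * time_split
        validate_lb = end - (i + 1) * time_split
        train_ub = validate_lb - gap if gap is not None else validate_lb
        bounds.append((train_lb, train_ub, validate_lb, validate_ub))
        if not expanding:
            # Rolling window: shift the training start for the next fold.
            train_lb += time_split
    return bounds


# Hypothetical data range: the first half of 2024.
folds = fold_boundaries(
    start=date(2024, 1, 1),
    end=date(2024, 6, 30),
    n_folds=3,
    time_split=timedelta(days=30),
    gap=timedelta(days=7),
)
for train_lb, train_ub, validate_lb, validate_ub in folds:
    print(f"train [{train_lb}, {train_ub}]  validate ({validate_lb}, {validate_ub}]")
```

With the class above, the same settings would correspond to passing datetimeCol='date', timeSplit=timedelta(days=30) and gap=timedelta(days=7) to tsCrossValidator, alongside the usual estimator, estimatorParamMaps and evaluator arguments. Note that if numFolds times timeSplit covers more than the available date range, the earliest folds end up with empty training sets.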

Conclusion

That’s it! You can now use tsCrossValidator in place of CrossValidator, fitting your pipelines and tuning your hyperparameters as you usually would.

Further reading

I found the following particularly helpful: