Wrapper to boost generalized additive models for each feature.
Source: R/boost_splines.R
This wrapper function automatically initializes the model by adding every numerical
feature as a spline base-learner. Categorical features are dummy encoded and added
as linear base-learners without intercept. In addition to initializing the model,
boostSplines also trains it.
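The column-to-base-learner rule described above can be sketched in a few lines of base R. This is not compboost's code: plan_baselearners() is a hypothetical helper, and the suffixes _spline and _ridge are borrowed from the base-learner names printed in the example further down. The sketch only decides which kind of base-learner each feature would receive and shows what a dummy encoding without intercept looks like.

```r
# Hypothetical helper, not part of compboost: applies the documented rule
# "numeric feature -> spline, categorical feature -> dummy-encoded linear".
plan_baselearners <- function(data, target) {
  features <- setdiff(names(data), target)
  # TRUE for numeric columns, FALSE for factors/characters
  is_num <- vapply(data[features], is.numeric, logical(1))
  type <- ifelse(unname(is_num), "spline", "ridge")
  data.frame(feature = features, type = type,
             name = paste0(features, "_", type), stringsAsFactors = FALSE)
}

plan <- plan_baselearners(iris, target = "Sepal.Length")
plan$name
# "Sepal.Width_spline" "Petal.Length_spline" "Petal.Width_spline" "Species_ridge"

# Dummy encoding without intercept: one indicator column per factor level
dummies <- model.matrix(~ Species - 1, data = iris)
colnames(dummies)
# "Speciessetosa" "Speciesversicolor" "Speciesvirginica"
```

Dropping the intercept (the - 1 in the formula) keeps one column per level; this is a common choice when the model already carries an offset, as the $offset entry in the example output suggests.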
The returned object belongs to the Compboost class and can be used for further
analyses (see ?Compboost for details).
Usage
boostSplines(
data,
target,
optimizer = NULL,
loss = NULL,
learning_rate = 0.05,
iterations = 100,
trace = -1,
degree = 3,
n_knots = 20,
penalty = 2,
df = 0,
differences = 2,
data_source = InMemoryData,
oob_fraction = NULL,
bin_root = 0,
cache_type = "inverse",
stop_args = NULL,
df_cat = 1,
stop_time = "microseconds",
additional_risk_logs = list()
)
Arguments
- data (data.frame())
A data frame containing the data.
- target (character(1) | ResponseRegr | ResponseBinaryClassif)
Character value containing the name of the target variable, or a response object. Note that the loss must match the data type of the target.
- optimizer (OptimizerCoordinateDescent | OptimizerCoordinateDescentLineSearch | OptimizerAGBM | OptimizerCosineAnnealing)
An initialized S4 optimizer object (requires calling Optimizer*$new(...)). See the respective help page for further information.
- loss (LossQuadratic | LossBinomial | LossHuber | LossAbsolute | LossQuantile)
An initialized S4 loss object (requires calling Loss*$new(...)). See the respective help page for further information.
- learning_rate (numeric(1))
Learning rate used to shrink the parameter update in each step.
- iterations (integer(1))
Number of iterations to train. If iterations == 0, the untrained object is returned. This can be useful if other base-learners (e.g. an interaction via a tensor base-learner) should be added before training.
- trace (integer(1))
Integer indicating how often a trace should be printed. With trace = 10, every 10th iteration is printed. If no trace should be printed, set trace = 0. The default, -1, means that 40 iterations in total are printed.
- degree (integer(1))
Polynomial degree of the splines.
- n_knots (integer(1))
Number of equidistant "inner knots". The actual number of knots used also depends on the polynomial degree.
- penalty (numeric(1))
Penalty term for P-splines. If the penalty equals 0, ordinary B-splines are fitted. The higher the penalty, the smoother the fit.
- df (numeric(1))
Degrees of freedom of the base-learner(s).
- differences (integer(1))
Order of the differences used for the penalty. The higher the order of the differences, the higher the smoothness.
- data_source (Data*)
Uninitialized Data* object which is used to store the data. At the moment, only in-memory training is supported.
- oob_fraction (numeric(1))
Fraction of the data used to track the out-of-bag risk.
- bin_root (integer(1))
The binning root to reduce the data to \(n^{1/\text{binroot}}\) data points (bin_root = 1 keeps all \(n^{1/1} = n\) points, which means no binning is applied). A value of bin_root = 2 is suggested for the best approximation error (cf. Wood et al. (2017) Generalized additive models for gigadata: modeling the UK black smoke network daily data).
- cache_type (character(1))
String indicating which method should be used to estimate the parameters in each iteration. cache_type = "cholesky" computes the Cholesky decomposition, caches it, and reuses the factorization in every iteration. cache_type = "inverse" (the default shown in Usage) does the same but caches the inverse.
- stop_args (list(2))
List containing the two elements patience and eps_for_break, which can be set to use early stopping on the data left out by setting oob_fraction. If !is.null(stop_args), early stopping is enabled.
- df_cat (numeric(1))
Degrees of freedom of the categorical base-learner.
- stop_time (character(1))
Unit of measured time.
- additional_risk_logs (list(Logger))
Additional loggers passed to the Compboost object.
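The spline-related arguments (degree, n_knots, penalty, differences, and cache_type) can be illustrated with a small base-R sketch of a single P-spline base-learner. This is not compboost's internal code: pspline_design(), difference_penalty(), and fit_pspline() are hypothetical helpers, and the knot placement (n_knots equidistant inner knots, extended by degree knots on each side) is an assumption. With degree = 3 and n_knots = 20 it yields 24 basis functions, which matches the 24 coefficients per spline in the example output.

```r
library(splines)

# Hypothetical helper (not compboost's code): B-spline design matrix with
# n_knots equidistant inner knots, extended by `degree` knots on each side.
pspline_design <- function(x, degree = 3, n_knots = 20) {
  # Boundary knots at min(x) and max(x) plus n_knots equidistant inner knots
  inner <- seq(min(x), max(x), length.out = n_knots + 2)
  step <- inner[2] - inner[1]
  knots <- c(min(x) - step * (degree:1), inner, max(x) + step * seq_len(degree))
  # ord = degree + 1, so there are length(knots) - ord = n_knots + degree + 1 columns
  splineDesign(knots, x, ord = degree + 1)
}

# Difference penalty K = t(D) %*% D, where D takes `differences`-order differences
difference_penalty <- function(n_basis, differences = 2) {
  D <- diff(diag(n_basis), differences = differences)
  crossprod(D)
}

# Penalized least squares: solve (t(X) X + penalty * K) beta = t(X) y.
# "cholesky" reuses a triangular factor, "inverse" reuses the explicit inverse.
fit_pspline <- function(X, y, K, penalty = 2, cache_type = c("cholesky", "inverse")) {
  cache_type <- match.arg(cache_type)
  A <- crossprod(X) + penalty * K
  Xty <- crossprod(X, y)
  if (cache_type == "cholesky") {
    R <- chol(A)  # upper triangular, A = t(R) %*% R
    backsolve(R, forwardsolve(t(R), Xty))
  } else {
    A_inv <- solve(A)
    A_inv %*% Xty
  }
}

X <- pspline_design(iris$Sepal.Width, degree = 3, n_knots = 20)
K <- difference_penalty(ncol(X), differences = 2)
beta_chol <- fit_pspline(X, iris$Sepal.Length, K, penalty = 2, cache_type = "cholesky")
beta_inv  <- fit_pspline(X, iris$Sepal.Length, K, penalty = 2, cache_type = "inverse")

ncol(X)                               # 24 = n_knots + degree + 1
all.equal(c(beta_chol), c(beta_inv))  # TRUE: both cache types solve the same system
```

Setting penalty = 0 in this sketch turns the fit into an ordinary B-spline regression, and a larger penalty pulls the coefficients toward sequences whose differences of the chosen order vanish.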
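The effect of patience and eps_for_break in stop_args can be illustrated with a toy stopping rule applied to a vector of out-of-bag risks. The exact criterion is an assumption here, not taken from compboost's source: the sketch stops once the relative improvement of the out-of-bag risk stays below eps_for_break for patience consecutive iterations. early_stop_iteration() is a hypothetical helper and the risk values are made up.

```r
# Hypothetical sketch of a patience-based stopping rule (assumed semantics,
# not compboost's implementation).
early_stop_iteration <- function(oob_risk, patience = 5, eps_for_break = 0) {
  counter <- 0
  for (m in seq_along(oob_risk)[-1]) {
    rel_improvement <- (oob_risk[m - 1] - oob_risk[m]) / oob_risk[m - 1]
    # Count consecutive iterations with insufficient improvement, reset otherwise
    counter <- if (rel_improvement < eps_for_break) counter + 1 else 0
    if (counter >= patience) return(m)
  }
  length(oob_risk)  # rule never fired: all iterations are used
}

risk <- c(1.00, 0.80, 0.70, 0.69, 0.70, 0.71, 0.72, 0.73)
early_stop_iteration(risk, patience = 3, eps_for_break = 0)     # 7
early_stop_iteration(risk, patience = 3, eps_for_break = 0.05)  # 6
```

A larger eps_for_break treats small improvements as stagnation and therefore stops earlier, while a larger patience tolerates longer plateaus before stopping.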
Value
A model of the Compboost class. This model is an R6 object which can be used for
retraining, predicting, plotting, and anything described in ?Compboost.
Examples
mod = boostSplines(data = iris, target = "Sepal.Length", loss = LossQuadratic$new(),
oob_fraction = 0.3)
#> 1/100 risk = 0.31 oob_risk = 0.32 time = 0
#> 2/100 risk = 0.28 oob_risk = 0.29 time = 101
#> 4/100 risk = 0.24 oob_risk = 0.26 time = 245
#> 6/100 risk = 0.21 oob_risk = 0.23 time = 374
#> 8/100 risk = 0.18 oob_risk = 0.2 time = 512
#> 10/100 risk = 0.15 oob_risk = 0.18 time = 634
#> 12/100 risk = 0.13 oob_risk = 0.17 time = 754
#> 14/100 risk = 0.12 oob_risk = 0.15 time = 872
#> 16/100 risk = 0.11 oob_risk = 0.14 time = 993
#> 18/100 risk = 0.095 oob_risk = 0.13 time = 1114
#> 20/100 risk = 0.087 oob_risk = 0.13 time = 1234
#> 22/100 risk = 0.08 oob_risk = 0.12 time = 1355
#> 24/100 risk = 0.074 oob_risk = 0.11 time = 1501
#> 26/100 risk = 0.069 oob_risk = 0.11 time = 1639
#> 28/100 risk = 0.066 oob_risk = 0.11 time = 1758
#> 30/100 risk = 0.062 oob_risk = 0.11 time = 1932
#> 32/100 risk = 0.06 oob_risk = 0.11 time = 2059
#> 34/100 risk = 0.058 oob_risk = 0.1 time = 2198
#> 36/100 risk = 0.056 oob_risk = 0.1 time = 2319
#> 38/100 risk = 0.055 oob_risk = 0.1 time = 2440
#> 40/100 risk = 0.053 oob_risk = 0.1 time = 2569
#> 42/100 risk = 0.052 oob_risk = 0.1 time = 2705
#> 44/100 risk = 0.051 oob_risk = 0.1 time = 2816
#> 46/100 risk = 0.05 oob_risk = 0.1 time = 2925
#> 48/100 risk = 0.049 oob_risk = 0.1 time = 3038
#> 50/100 risk = 0.048 oob_risk = 0.1 time = 3150
#> 52/100 risk = 0.048 oob_risk = 0.1 time = 3261
#> 54/100 risk = 0.047 oob_risk = 0.1 time = 3370
#> 56/100 risk = 0.046 oob_risk = 0.1 time = 3480
#> 58/100 risk = 0.046 oob_risk = 0.1 time = 3591
#> 60/100 risk = 0.045 oob_risk = 0.099 time = 3702
#> 62/100 risk = 0.044 oob_risk = 0.099 time = 3814
#> 64/100 risk = 0.044 oob_risk = 0.099 time = 3925
#> 66/100 risk = 0.043 oob_risk = 0.099 time = 4044
#> 68/100 risk = 0.043 oob_risk = 0.099 time = 4156
#> 70/100 risk = 0.043 oob_risk = 0.099 time = 4267
#> 72/100 risk = 0.042 oob_risk = 0.099 time = 4379
#> 74/100 risk = 0.042 oob_risk = 0.098 time = 4490
#> 76/100 risk = 0.041 oob_risk = 0.098 time = 4603
#> 78/100 risk = 0.041 oob_risk = 0.098 time = 4718
#> 80/100 risk = 0.041 oob_risk = 0.098 time = 4942
#> 82/100 risk = 0.04 oob_risk = 0.098 time = 5077
#> 84/100 risk = 0.04 oob_risk = 0.098 time = 5210
#> 86/100 risk = 0.04 oob_risk = 0.098 time = 5352
#> 88/100 risk = 0.039 oob_risk = 0.098 time = 5484
#> 90/100 risk = 0.039 oob_risk = 0.097 time = 5629
#> 92/100 risk = 0.039 oob_risk = 0.097 time = 5762
#> 94/100 risk = 0.039 oob_risk = 0.097 time = 5893
#> 96/100 risk = 0.038 oob_risk = 0.097 time = 6039
#> 98/100 risk = 0.038 oob_risk = 0.097 time = 6153
#> 100/100 risk = 0.038 oob_risk = 0.097 time = 6266
#>
#>
#> Train 100 iterations in 0 Seconds.
#> Final risk based on the train set: 0.038
#>
mod$getBaselearnerNames()
#> [1] "Sepal.Width_spline" "Petal.Length_spline" "Petal.Width_spline"
#> [4] "Species_ridge"
mod$getEstimatedCoef()
#> Depricated, use `$getCoef()` instead.
#> $Petal.Length_spline
#> [,1]
#> [1,] -1.020038063
#> [2,] -0.842700185
#> [3,] -0.708210074
#> [4,] -0.616528258
#> [5,] -0.629992257
#> [6,] -0.674377781
#> [7,] -0.700732067
#> [8,] -0.697215039
#> [9,] -0.651967365
#> [10,] -0.537788304
#> [11,] -0.341686833
#> [12,] -0.177164903
#> [13,] 0.007505898
#> [14,] 0.290089474
#> [15,] 0.451293220
#> [16,] 0.400021292
#> [17,] 0.366344295
#> [18,] 0.483695287
#> [19,] 0.784027942
#> [20,] 1.123146019
#> [21,] 1.463578011
#> [22,] 1.656349917
#> [23,] 1.738770154
#> [24,] 1.806016019
#> attr(,"blclass")
#> [1] "Rcpp_BaselearnerPSpline"
#>
#> $Petal.Width_spline
#> [,1]
#> [1,] -0.368363815
#> [2,] -0.230727829
#> [3,] -0.097569154
#> [4,] -0.016578666
#> [5,] -0.028837656
#> [6,] -0.057857155
#> [7,] -0.057668639
#> [8,] -0.028703497
#> [9,] -0.002385605
#> [10,] -0.010779152
#> [11,] -0.054068445
#> [12,] -0.040466735
#> [13,] 0.064127383
#> [14,] 0.136240247
#> [15,] 0.133221263
#> [16,] 0.102792184
#> [17,] 0.067773564
#> [18,] 0.110728014
#> [19,] 0.161049540
#> [20,] 0.160193965
#> [21,] 0.130363799
#> [22,] 0.046188478
#> [23,] -0.067301221
#> [24,] -0.181037323
#> attr(,"blclass")
#> [1] "Rcpp_BaselearnerPSpline"
#>
#> $Sepal.Width_spline
#> [,1]
#> [1,] -0.391634041
#> [2,] -0.234533599
#> [3,] -0.086240363
#> [4,] 0.009214706
#> [5,] 0.007799400
#> [6,] -0.065038926
#> [7,] -0.104271259
#> [8,] -0.096040595
#> [9,] -0.054096040
#> [10,] -0.017548114
#> [11,] -0.012499394
#> [12,] 0.013963397
#> [13,] -0.016707110
#> [14,] -0.064197029
#> [15,] -0.059409731
#> [16,] 0.010539003
#> [17,] 0.075101623
#> [18,] 0.111945966
#> [19,] 0.157947620
#> [20,] 0.253521014
#> [21,] 0.357967518
#> [22,] 0.374894173
#> [23,] 0.371380710
#> [24,] 0.369954604
#> attr(,"blclass")
#> [1] "Rcpp_BaselearnerPSpline"
#>
#> $offset
#> [1] 5.79619
#>
table(mod$getSelectedBaselearner())
#>
#> Petal.Length_spline Petal.Width_spline Sepal.Width_spline
#> 45 26 29
mod$predict()
#> [,1]
#> [1,] 5.001037
#> [2,] 4.971113
#> [3,] 4.890050
#> [4,] 5.003350
#> [5,] 5.410319
#> [6,] 4.975044
#> [7,] 4.956201
#> [8,] 4.891916
#> [9,] 5.129464
#> [10,] 5.001642
#> [11,] 4.859679
#> [12,] 4.715600
#> [13,] 5.228497
#> [14,] 5.077558
#> [15,] 5.303721
#> [16,] 5.262642
#> [17,] 5.223762
#> [18,] 5.012599
#> [19,] 5.036910
#> [20,] 4.939170
#> [21,] 4.994930
#> [22,] 5.029948
#> [23,] 5.266600
#> [24,] 5.342117
#> [25,] 5.003350
#> [26,] 4.842022
#> [27,] 4.958629
#> [28,] 4.942359
#> [29,] 4.928705
#> [30,] 4.975044
#> [31,] 5.035150
#> [32,] 4.890050
#> [33,] 5.127204
#> [34,] 5.317974
#> [35,] 5.212719
#> [36,] 4.932458
#> [37,] 5.129464
#> [38,] 4.910691
#> [39,] 6.220816
#> [40,] 6.336589
#> [41,] 6.109429
#> [42,] 6.280500
#> [43,] 5.204318
#> [44,] 6.182060
#> [45,] 5.668333
#> [46,] 5.142687
#> [47,] 6.005284
#> [48,] 5.688583
#> [49,] 5.457387
#> [50,] 6.140691
#> [51,] 5.707504
#> [52,] 6.247146
#> [53,] 6.262478
#> [54,] 5.694635
#> [55,] 6.247862
#> [56,] 6.144328
#> [57,] 6.276506
#> [58,] 6.296937
#> [59,] 5.282725
#> [60,] 5.495523
#> [61,] 5.459322
#> [62,] 5.537781
#> [63,] 6.242001
#> [64,] 6.223752
#> [65,] 6.067748
#> [66,] 5.802462
#> [67,] 5.635231
#> [68,] 5.917087
#> [69,] 6.273111
#> [70,] 5.819931
#> [71,] 5.833788
#> [72,] 5.873289
#> [73,] 5.015609
#> [74,] 5.769503
#> [75,] 6.704795
#> [76,] 6.870084
#> [77,] 6.445517
#> [78,] 6.740277
#> [79,] 7.253793
#> [80,] 6.581417
#> [81,] 6.285423
#> [82,] 6.232806
#> [83,] 6.242707
#> [84,] 6.179686
#> [85,] 6.267134
#> [86,] 7.795450
#> [87,] 7.541891
#> [88,] 6.312717
#> [89,] 7.576167
#> [90,] 6.578978
#> [91,] 6.874760
#> [92,] 6.719281
#> [93,] 7.603128
#> [94,] 6.493966
#> [95,] 6.281089
#> [96,] 6.372074
#> [97,] 6.376361
#> [98,] 6.374356
#> [99,] 6.404666
#> [100,] 6.777900
#> [101,] 6.352684
#> [102,] 6.289097
#> [103,] 6.203257
#> [104,] 6.311141
#> [105,] 6.264772