8.8. Data Set Evaluator

DataSetEvaluator lets you access:

  • summary statistics like mean absolute error (MAE) and root-mean-squared error (RMSE)

  • partial contributions to the loss function value

  • tables with reference and predicted values in columns next to each other

  • summary statistics, partial contributions, and reference-vs-prediction tables grouped by extractor, expression, or any metadata key-value pair.

8.8.1. Load a data_set_predictions.yaml file

The most common way to use a DataSetEvaluator is to load the data_set_predictions.yaml file produced during an optimization.

from scm.params import *
import os

# Option 1: obtain the evaluator from a finished ParAMSJob
# job = ParAMSJob.load_external('/path/results/')
# dse = job.results.get_data_set_evaluator()

# Option 2: load the .yaml file directly
yaml_file = os.path.expandvars('$AMSHOME/scripting/scm/params/examples/ZnS_ReaxFF/example_output/best/data_set_predictions.yaml')
dse = DataSetEvaluator(yaml_file)

8.8.1.1. Summary statistics (stats.txt)

The results can be grouped in different ways. By default, the data is grouped first by 'Extractor' and then by 'Expression'. To get a file like stats.txt, simply call the .str() method:

print(dse.str())
Group/Expression                                        N        MAE          RMSE*        Unit*          Weight     Loss*        Contribution[%]
-------------------------------------------------------------------------------------------------------------------------------------------------
Total                                                   466      0.59112      +0.97646     Mixed!         50.000     236.128      100.00     Total

   forces                                               297      0.76691      +1.02728     eV/angstrom    2.000      89.301       37.82      Extractor
      band_distorted_clean_110                          144      0.89717      +1.14169     eV/angstrom    1.000      54.772       23.20      Expression
      band_distorted_ads_110                            153      0.64432      +0.90649     eV/angstrom    1.000      34.529       14.62      Expression

   pes                                                  30       0.09157      +0.14448     eV             13.000     65.858       27.89      Extractor
      pes('bondscan_h2s_pbesol', relative_to=5)         10       0.16696      +0.22823     eV             3.000      52.759       22.34      Expression
      pes('rocksalt', relative_to=2)                    5        0.06917      +0.08621     eV             4.000      10.036       4.25       Expression
      pes('anglescan_h2s_pbesol', relative_to=2)        10       0.06506      +0.08106     eV             1.000      2.218        0.94       Expression
      pes('zincblende', relative_to=2)                  5        0.01622      +0.02237     eV             5.000      0.844        0.36       Expression

   energy                                               9        0.63328      +1.63647     eV             9.000      26.796       11.35      Extractor
      1.0*band_h2s_110-1.0*band_110-1.0*band_h2s        1        0.18263      -0.18263     eV             1.000      11.261       4.77       Expression
      1.0*band_110-1.0*band_distorted_clean_110         1        4.89415      -4.89415     eV             1.000      5.988        2.54       Expression
      0.041667*band_110_noconstraints-0.5*wurtzite_sp   1        0.13315      -0.13315     eV             1.000      5.985        2.53       Expression
      1.0*zincblende_sp-0.5*wurtzite_sp                 1        0.06582      +0.06582     eV             1.000      1.463        0.62       Expression
      0.5*wurtzite_sp-0.03125*sulfur-0.5*Zn             1        0.06549      -0.06549     eV             1.000      1.448        0.61       Expression
      1.0*rocksalt_sp-0.03125*sulfur-0.5*Zn             1        0.03069      +0.03069     eV             1.000      0.318        0.13       Expression
      1.0*zincblende_sp-1.0*rocksalt_sp                 1        0.03036      -0.03036     eV             1.000      0.311        0.13       Expression
      1.0*band_h2s_110-1.0*band_distorted_ads_110       1        0.29687      +0.29687     eV             1.000      0.022        0.01       Expression
      1.0*zincblende_sp-0.03125*sulfur-0.5*Zn           1        0.00033      +0.00033     eV             1.000      0.000        0.00       Expression

   angle                                                7        3.59374      +3.79878     degree         7.000      25.254       10.69      Extractor
      angle('band_110_noconstraints',39,19,24)          1        4.90284      +4.90284     degree         1.000      6.009        2.55       Expression
      angle('band_110_noconstraints',42,39,19)          1        4.75211      -4.75211     degree         1.000      5.646        2.39       Expression
      angle('band_110_noconstraints',39,19,31)          1        4.74455      -4.74455     degree         1.000      5.628        2.38       Expression
      angle('band_110_noconstraints',39,42,47)          1        4.11002      +4.11002     degree         1.000      4.223        1.79       Expression
      angle('band_110_noconstraints',23,27,47)          1        2.45845      -2.45845     degree         1.000      1.511        0.64       Expression
      angle('band_h2s',1,0,2)                           1        2.39175      +2.39175     degree         1.000      1.430        0.61       Expression
      angle('band_110_noconstraints',27,47,4)           1        1.79647      +1.79647     degree         1.000      0.807        0.34       Expression

   distance                                             12       0.04638      +0.06270     angstrom       12.000     18.869       7.99       Extractor
      distance('band_h2s_110',34,19)                    1        0.14270      -0.14270     angstrom       1.000      8.145        3.45       Expression
      distance('band_110_noconstraints',4,47)           1        0.09542      -0.09542     angstrom       1.000      3.642        1.54       Expression
      distance('band_h2s',0,1)                          1        0.07365      -0.07365     angstrom       1.000      2.170        0.92       Expression
      distance('band_110_noconstraints',3,12)           1        0.06262      +0.06262     angstrom       1.000      1.568        0.66       Expression
      distance('band_110_noconstraints',39,42)          1        0.05802      +0.05802     angstrom       1.000      1.347        0.57       Expression
      distance('band_110_noconstraints',42,47)          1        0.05528      +0.05528     angstrom       1.000      1.222        0.52       Expression
      distance('band_110_noconstraints',27,47)          1        0.03551      +0.03551     angstrom       1.000      0.504        0.21       Expression
      distance('band_h2s_110',14, 37)                   1        0.02452      -0.02452     angstrom       1.000      0.240        0.10       Expression
      distance('band_110_noconstraints',4,12)           1        0.00868      +0.00868     angstrom       1.000      0.030        0.01       Expression
      distance('band_110_noconstraints',19, 42)         1        0.00011      -0.00011     angstrom       1.000      0.000        0.00       Expression
      distance('band_110_noconstraints',12,15)          1        0.00004      -0.00004     angstrom       1.000      0.000        0.00       Expression
      distance('band_110_noconstraints',5,28)           1        0.00000      -0.00000     angstrom       1.000      0.000        0.00       Expression

   charges                                              110      0.10562      +0.11519     au             6.000      9.138        3.87       Extractor
      zincblende_sp                                     2        0.15138      +0.15138     au             1.000      2.292        0.97       Expression
      wurtzite_sp                                       4        0.14789      +0.14789     au             1.000      2.187        0.93       Expression
      rocksalt_sp                                       2        0.13159      +0.13159     au             1.000      1.731        0.73       Expression
      band_110_noconstraints                            48       0.11118      +0.11430     au             1.000      1.306        0.55       Expression
      band_h2s_110                                      51       0.09724      +0.11327     au             1.000      1.283        0.54       Expression
      band_h2s                                          3        0.05482      +0.05814     au             1.000      0.338        0.14       Expression

   dihedral                                             1        1.91010      +1.91010     degree         1.000      0.912        0.39       Extractor
      dihedral('band_110_noconstraints',9,28,27,23)     1        1.91010      +1.91010     degree         1.000      0.912        0.39       Expression
-------------------------------------------------------------------------------------------------------------------------------------------------
The weighted total loss function is 236.128.
N: number of numbers averaged for the MAE/RMSE
MAE and RMSE: These are not weighted!
RMSE*: if N == 1 the signed residual (reference-prediction) is given instead of the RMSE.
Unit*: if the unit is "Mixed!" it means that the MAE and RMSE are meaningless.
Loss function type: SSE(). The loss function value is affected by the Weight and Sigma of data_set entries.
Contribution[%]: The contribution to the weighted loss function.

Note that the extractor names for the various expressions are not shown if there are no arguments to the extractor. This makes the output more readable.

You can access individual entries from the above table as follows:

print(len(dse.results['charges'].residuals)) # the N for the charges
110
print(dse.results['charges']['zincblende_sp'].mae) # MAE for an expression
0.151381045
print(dse.results['forces'].rmse) # RMSE for an extractor
1.0272752938737622
print(dse.results['forces'].unit) # unit for an extractor
eV/angstrom
print(dse.results['charges']['wurtzite_sp'].weight) # the weight is returned as a scalar, even for array reference values
1.0
print(dse.results['energy']['1.0*zincblende_sp-0.5*wurtzite_sp'].my_loss) #"my_loss" refers to the loss of the individual entry
1.4626476245701765
print(dse.results['forces'].contribution) # fractional contribution to the weighted loss function
0.3781877244315181
print(dse.results.total_loss) # total loss function value
236.1283753165323
print(dse.results.loss_type) # type of loss function
SSE()
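
The Contribution[%] column can be reproduced by hand: for the SSE loss used here, it is each group's loss divided by the total loss. The following is a minimal sketch in plain Python (it does not use the ParAMS API), with the per-extractor Loss values copied from the table above. The names extractor_losses and contributions are only illustrative.

```python
# Per-extractor loss values copied from the Loss* column of the table above.
extractor_losses = {
    "forces": 89.301,
    "pes": 65.858,
    "energy": 26.796,
    "angle": 25.254,
    "distance": 18.869,
    "charges": 9.138,
    "dihedral": 0.912,
}

# For the SSE loss, the total is the sum of the per-extractor losses.
total_loss = sum(extractor_losses.values())

# Contribution in percent = 100 * (group loss) / (total loss).
contributions = {
    name: 100.0 * loss / total_loss for name, loss in extractor_losses.items()
}

for name, percent in contributions.items():
    print(f"{name:10s} {percent:6.2f} %")
```

Because the inputs are rounded to three decimals, the last printed digit can differ from the table. Note also that the .contribution attribute shown above returns a fraction (0.378...), whereas the table lists percentages.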

You can also just print a summary of a part of the table:

print(dse.results['forces'].str())
Group/Expression              N        MAE          RMSE*        Unit*          Weight     Loss*        Contribution[%]
-----------------------------------------------------------------------------------------------------------------------
forces                        297      0.76691      +1.02728     eV/angstrom    2.000      89.301       37.82      Extractor
   band_distorted_clean_110   144      0.89717      +1.14169     eV/angstrom    1.000      54.772       23.20      Expression
   band_distorted_ads_110     153      0.64432      +0.90649     eV/angstrom    1.000      34.529       14.62      Expression
-----------------------------------------------------------------------------------------------------------------------
The weighted total loss function is 236.128.
N: number of numbers averaged for the MAE/RMSE
MAE and RMSE: These are not weighted!
RMSE*: if N == 1 the signed residual (reference-prediction) is given instead of the RMSE.
Unit*: if the unit is "Mixed!" it means that the MAE and RMSE are meaningless.
Loss function type: None. The loss function value is affected by the Weight and Sigma of data_set entries.
Contribution[%]: The contribution to the weighted loss function.

You can also modify the grouping to only go one level deep:

dse.group_by(('Extractor',)) # the default is group_by(('Extractor', 'Expression'))
print(dse.str())
Group/Expression N        MAE          RMSE*        Unit*          Weight     Loss*        Contribution[%]
-------------------------------------------------------------------------------------------------------
Total         466      0.59112      +0.97646     Mixed!         50.000     236.128      100.00     Total
   forces     297      0.76691      +1.02728     eV/angstrom    2.000      89.301       37.82      Extractor
   pes        30       0.09157      +0.14448     eV             13.000     65.858       27.89      Extractor
   energy     9        0.63328      +1.63647     eV             9.000      26.796       11.35      Extractor
   angle      7        3.59374      +3.79878     degree         7.000      25.254       10.69      Extractor
   distance   12       0.04638      +0.06270     angstrom       12.000     18.869       7.99       Extractor
   charges    110      0.10562      +0.11519     au             6.000      9.138        3.87       Extractor
   dihedral   1        1.91010      +1.91010     degree         1.000      0.912        0.39       Extractor
-------------------------------------------------------------------------------------------------------
The weighted total loss function is 236.128.
N: number of numbers averaged for the MAE/RMSE
MAE and RMSE: These are not weighted!
RMSE*: if N == 1 the signed residual (reference-prediction) is given instead of the RMSE.
Unit*: if the unit is "Mixed!" it means that the MAE and RMSE are meaningless.
Loss function type: SSE(). The loss function value is affected by the Weight and Sigma of data_set entries.
Contribution[%]: The contribution to the weighted loss function.

If there is metadata attached to the training set entries, you can also group by those. For example, when creating a training set with a ResultsImporter, the Group and SubGroup metadata are automatically set:

dse.group_by(('Group', 'SubGroup'))
print(dse.str())
Group/Expression                 N        MAE          RMSE*        Unit*          Weight     Loss*        Contribution[%]
--------------------------------------------------------------------------------------------------------------------------
Total                            466      0.59112      +0.97646     Mixed!         50.000     236.128      100.00     Total

   Forces                        297      0.76691      +1.02728     eV/angstrom    2.000      89.301       37.82      Group
      band_distorted_clean_110   144      0.89717      +1.14169     eV/angstrom    1.000      54.772       23.20      SubGroup
      band_distorted_ads_110     153      0.64432      +0.90649     eV/angstrom    1.000      34.529       14.62      SubGroup

   None                          31       0.15023      +0.37134     Mixed!         14.000     66.770       28.28      Group
      bondscan_h2s_pbesol        10       0.16696      +0.22823     eV             3.000      52.759       22.34      SubGroup
      rocksalt                   5        0.06917      +0.08621     eV             4.000      10.036       4.25       SubGroup
      anglescan_h2s_pbesol       10       0.06506      +0.08106     eV             1.000      2.218        0.94       SubGroup
      band_110_noconstraints     1        1.91010      +1.91010     degree         1.000      0.912        0.39       SubGroup
      zincblende                 5        0.01622      +0.02237     eV             5.000      0.844        0.36       SubGroup

   ReactionEnergy                9        0.63328      +1.63647     eV             9.000      26.796       11.35      Group
      None                       9        0.63328      +1.63647     eV             9.000      26.796       11.35      SubGroup

   Angles                        7        3.59374      +3.79878     degree         7.000      25.254       10.69      Group
      band_110_noconstraints     6        3.79407      +3.98528     degree         6.000      23.824       10.09      SubGroup
      band_h2s                   1        2.39175      +2.39175     degree         1.000      1.430        0.61       SubGroup

   Distances                     12       0.04638      +0.06270     angstrom       12.000     18.869       7.99       Group
      band_h2s_110               2        0.08361      +0.10238     angstrom       2.000      8.386        3.55       SubGroup
      band_110_noconstraints     9        0.03508      +0.04806     angstrom       9.000      8.314        3.52       SubGroup
      band_h2s                   1        0.07365      -0.07365     angstrom       1.000      2.170        0.92       SubGroup

   Charges                       110      0.10562      +0.11519     au             6.000      9.138        3.87       Group
      zincblende_sp              2        0.15138      +0.15138     au             1.000      2.292        0.97       SubGroup
      wurtzite_sp                4        0.14789      +0.14789     au             1.000      2.187        0.93       SubGroup
      rocksalt_sp                2        0.13159      +0.13159     au             1.000      1.731        0.73       SubGroup
      band_110_noconstraints     48       0.11118      +0.11430     au             1.000      1.306        0.55       SubGroup
      band_h2s_110               51       0.09724      +0.11327     au             1.000      1.283        0.54       SubGroup
      band_h2s                   3        0.05482      +0.05814     au             1.000      0.338        0.14       SubGroup
--------------------------------------------------------------------------------------------------------------------------
The weighted total loss function is 236.128.
N: number of numbers averaged for the MAE/RMSE
MAE and RMSE: These are not weighted!
RMSE*: if N == 1 the signed residual (reference-prediction) is given instead of the RMSE.
Unit*: if the unit is "Mixed!" it means that the MAE and RMSE are meaningless.
Loss function type: SSE(). The loss function value is affected by the Weight and Sigma of data_set entries.
Contribution[%]: The contribution to the weighted loss function.
print(dse.results['Forces'].mae) # capital F in the Group metadata
0.7669139308273065
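
Regrouping only changes how the individual entries are aggregated. As an illustration, the sketch below (plain Python, not the ParAMS implementation) sums the per-expression Loss values of the distance extractor by job name, which reproduces the Distances SubGroup losses up to rounding. The variable names, and the reading of the SubGroup as the job name, are assumptions made for this example.

```python
from collections import defaultdict

# (job name, loss) pairs copied from the distance rows of the
# Extractor/Expression table shown earlier in this section.
distance_entries = [
    ("band_h2s_110", 8.145),
    ("band_110_noconstraints", 3.642),
    ("band_h2s", 2.170),
    ("band_110_noconstraints", 1.568),
    ("band_110_noconstraints", 1.347),
    ("band_110_noconstraints", 1.222),
    ("band_110_noconstraints", 0.504),
    ("band_h2s_110", 0.240),
    ("band_110_noconstraints", 0.030),
    ("band_110_noconstraints", 0.000),
    ("band_110_noconstraints", 0.000),
    ("band_110_noconstraints", 0.000),
]

# Accumulate the loss and the number of entries per subgroup.
subgroup_loss = defaultdict(float)
subgroup_count = defaultdict(int)
for job_name, loss in distance_entries:
    subgroup_loss[job_name] += loss
    subgroup_count[job_name] += 1

for job_name in subgroup_loss:
    print(f"{job_name:25s} N={subgroup_count[job_name]:<3d} loss={subgroup_loss[job_name]:.3f}")
```

The sums differ from the SubGroup rows only in the last digit, because the inputs are the rounded values from the table.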

8.8.1.2. Access individual predictions and reference values (scatter_plots/)

Call the .detailed_string() method to get output similar to the files in the scatter_plots/ directory (scatter_plots/forces.txt, etc.).

dse.group_by(('Extractor', 'Expression')) # reset to the original grouping
results = dse.results['pes'] # look at the results for the pes extractor
print(results.detailed_string())
#Reference     Prediction     Unit           Sigma    Weight    WSE*      Row*  Col*  Expression
#------------------------------------------------------------------------------------------------------------------------
+0.419         +0.392          eV             0.054    1.0000    0.255     0     0     pes('zincblende', relative_to=2)
+0.092         +0.092          eV             0.054    1.0000    0.000     1     0     pes('zincblende', relative_to=2)
+0.000         +0.000          eV             0.054    1.0000    0.000     2     0     pes('zincblende', relative_to=2)
+0.078         +0.092          eV             0.054    1.0000    0.065     3     0     pes('zincblende', relative_to=2)
+0.278         +0.317          eV             0.054    1.0000    0.525     4     0     pes('zincblende', relative_to=2)
+0.474         +0.336          eV             0.054    0.8000    5.171     0     0     pes('rocksalt', relative_to=2)
+0.111         +0.083          eV             0.054    0.8000    0.202     1     0     pes('rocksalt', relative_to=2)
+0.000         +0.000          eV             0.054    0.8000    0.000     2     0     pes('rocksalt', relative_to=2)
+0.073         +0.006          eV             0.054    0.8000    1.220     3     0     pes('rocksalt', relative_to=2)
+0.274         +0.161          eV             0.054    0.8000    3.442     4     0     pes('rocksalt', relative_to=2)
+1.335         +0.857          eV             0.054    0.3000    23.227    0     0     pes('bondscan_h2s_pbesol', relative_to=5)
+0.732         +0.560          eV             0.054    0.3000    2.990     1     0     pes('bondscan_h2s_pbesol', relative_to=5)
+0.342         +0.333          eV             0.054    0.3000    0.007     2     0     pes('bondscan_h2s_pbesol', relative_to=5)
+0.114         +0.169          eV             0.054    0.3000    0.306     3     0     pes('bondscan_h2s_pbesol', relative_to=5)
+0.011         +0.060          eV             0.054    0.3000    0.252     4     0     pes('bondscan_h2s_pbesol', relative_to=5)
+0.000         +0.000          eV             0.054    0.3000    0.000     5     0     pes('bondscan_h2s_pbesol', relative_to=5)
+0.059         -0.019          eV             0.054    0.3000    0.614     6     0     pes('bondscan_h2s_pbesol', relative_to=5)
+0.170         -0.002          eV             0.054    0.3000    2.999     7     0     pes('bondscan_h2s_pbesol', relative_to=5)
+0.319         +0.044          eV             0.054    0.3000    7.654     8     0     pes('bondscan_h2s_pbesol', relative_to=5)
+0.495         +0.114          eV             0.054    0.3000    14.710    9     0     pes('bondscan_h2s_pbesol', relative_to=5)
+0.650         +0.700          eV             0.054    0.1000    0.086     0     0     pes('anglescan_h2s_pbesol', relative_to=2)
+0.204         +0.186          eV             0.054    0.1000    0.010     1     0     pes('anglescan_h2s_pbesol', relative_to=2)
+0.000         +0.000          eV             0.054    0.1000    0.000     2     0     pes('anglescan_h2s_pbesol', relative_to=2)
+0.040         +0.096          eV             0.054    0.1000    0.106     3     0     pes('anglescan_h2s_pbesol', relative_to=2)
+0.307         +0.422          eV             0.054    0.1000    0.451     4     0     pes('anglescan_h2s_pbesol', relative_to=2)
+0.768         +0.903          eV             0.054    0.1000    0.616     5     0     pes('anglescan_h2s_pbesol', relative_to=2)
+1.376         +1.453          eV             0.054    0.1000    0.197     6     0     pes('anglescan_h2s_pbesol', relative_to=2)
+2.043         +1.990          eV             0.054    0.1000    0.093     7     0     pes('anglescan_h2s_pbesol', relative_to=2)
+2.597         +2.458          eV             0.054    0.1000    0.657     8     0     pes('anglescan_h2s_pbesol', relative_to=2)
+2.818         +2.825          eV             0.054    0.1000    0.002     9     0     pes('anglescan_h2s_pbesol', relative_to=2)
#------------------------------------------------------------------------------------------------------------------------
#WSE*: Weighted Squared Error: weight*([reference-prediction]/sigma)**2
#Row*, Col*: For scalars both numbers are 0. For 1D arrays Col is 0.
print(results.reference_values) # list of reference values
[0.419314, 0.09164115, 0.0, 0.07841534, 0.27771948, 0.47439122, 0.11065178, 0.0, 0.07282489, 0.27377707, 1.33538276, 0.73173446, 0.34163308, 0.11433319, 0.01054105, 0.0, 0.05929865, 0.17040413, 0.3193521, 0.49523743, 0.64987859, 0.20385679, 0.0, 0.04011389, 0.30652919, 0.76818668, 1.37634888, 2.04264645, 2.59729761, 2.8175004]
print(results.predictions) # list of predicted values
[0.39185008, 0.09196291, 0.0, 0.09227838, 0.31714676, 0.33602331, 0.08328586, 0.0, 0.00560645, 0.16088318, 0.85651058, 0.55991041, 0.33346583, 0.16932453, 0.06042187, 0.0, -0.0185412, -0.00165861, 0.04446513, 0.11414837, 0.70031273, 0.18640377, 0.0, 0.09607728, 0.42212549, 0.90330381, 1.45275737, 1.99029061, 2.45778418, 2.825236]
print(results.unit) # the unit
eV
print(results.accuracies) # the Sigma values (per expression)
[0.054422772491975996, 0.054422772491975996, 0.054422772491975996, 0.054422772491975996]
print(results.weights) # the Weights (per reference/prediction)
[1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
print(results.contributions) # list of individual contributions (per expression)
[0.0035761475951556648, 0.04250420478293943, 0.2234329155570571, 0.009394469642807931]
print(results.expressions) # list of expressions
["pes('zincblende', relative_to=2)", "pes('rocksalt', relative_to=2)", "pes('bondscan_h2s_pbesol', relative_to=5)", "pes('anglescan_h2s_pbesol', relative_to=2)"]
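
The WSE* column follows directly from the formula in the table footer, weight*([reference-prediction]/sigma)**2. Below is a minimal sketch in plain Python, using the pes('zincblende', relative_to=2) values printed above. The helper name weighted_squared_errors is hypothetical and not part of ParAMS.

```python
def weighted_squared_errors(reference, prediction, sigma, weights):
    """Return weight * ((reference - prediction) / sigma)**2 for each data point."""
    return [
        w * ((r - p) / sigma) ** 2
        for r, p, w in zip(reference, prediction, weights)
    ]

# Values for pes('zincblende', relative_to=2), copied from the output above.
reference = [0.419314, 0.09164115, 0.0, 0.07841534, 0.27771948]
prediction = [0.39185008, 0.09196291, 0.0, 0.09227838, 0.31714676]
sigma = 0.054422772491975996
weights = [1.0, 1.0, 1.0, 1.0, 1.0]

wse = weighted_squared_errors(reference, prediction, sigma, weights)
print([round(x, 3) for x in wse])

# For the SSE loss, the loss of an entry is the sum of its WSE values.
print(round(sum(wse), 3))
```

The sum comes out at about 0.844, which matches the Loss listed for this expression in the summary statistics.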

Note that the number of reference values is different from the number of expressions when the reference values are arrays. To get the reference values per expression:

for e in results.expressions:
    print(f"Expression: {e}, Ref. values: {results[e].reference_values}")
Expression: pes('zincblende', relative_to=2), Ref. values: [0.419314, 0.09164115, 0.0, 0.07841534, 0.27771948]
Expression: pes('rocksalt', relative_to=2), Ref. values: [0.47439122, 0.11065178, 0.0, 0.07282489, 0.27377707]
Expression: pes('bondscan_h2s_pbesol', relative_to=5), Ref. values: [1.33538276, 0.73173446, 0.34163308, 0.11433319, 0.01054105, 0.0, 0.05929865, 0.17040413, 0.3193521, 0.49523743]
Expression: pes('anglescan_h2s_pbesol', relative_to=2), Ref. values: [0.64987859, 0.20385679, 0.0, 0.04011389, 0.30652919, 0.76818668, 1.37634888, 2.04264645, 2.59729761, 2.8175004]
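
The unweighted MAE and RMSE in the summary statistics can be recomputed from these per-expression lists. The sketch below uses plain Python rather than the ParAMS API, with the pes('zincblende', relative_to=2) reference values from the output above and the matching first five entries of results.predictions. The function names are illustrative.

```python
import math

def mean_absolute_error(reference, prediction):
    """Unweighted mean of |reference - prediction|."""
    residuals = [r - p for r, p in zip(reference, prediction)]
    return sum(abs(x) for x in residuals) / len(residuals)

def root_mean_squared_error(reference, prediction):
    """Unweighted square root of the mean squared residual."""
    residuals = [r - p for r, p in zip(reference, prediction)]
    return math.sqrt(sum(x * x for x in residuals) / len(residuals))

reference = [0.419314, 0.09164115, 0.0, 0.07841534, 0.27771948]
prediction = [0.39185008, 0.09196291, 0.0, 0.09227838, 0.31714676]

mae = mean_absolute_error(reference, prediction)
rmse = root_mean_squared_error(reference, prediction)
print(f"MAE  = {mae:.5f} eV")
print(f"RMSE = {rmse:.5f} eV")
```

Neither quantity uses the Weight or Sigma of the entry, which is why the summary notes that MAE and RMSE are not weighted.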

8.8.2. Calculate reference or run jobs

DataSetEvaluator can also be used in the following ways:

  • DataSetEvaluator.calculate_reference() will evaluate a data set with any engine settings and set the reference values

  • DataSetEvaluator.run() will evaluate the data set with any engine settings, and provides many functions to compare the predicted values with the reference values. You can only use run() if all data_set entries already have reference values.

8.8.2.1. Example: DataSetEvaluator.calculate_reference()

Note

The below examples use the plams.Settings class to define a computational engine. See the PLAMS documentation for more information about it.

from scm.params import *
from scm.plams import Settings

dse = DataSetEvaluator()

# any engine settings are possible
engine_settings = Settings()
engine_settings.input.ForceField.Type = 'UFF'

# a job collection is needed, can for example be loaded from disk
job_collection = JobCollection('job_collection.yaml')

# the data_set to be evaluated, can for example be loaded from disk
data_set = DataSet('data_set.yaml')

# print the original expression : reference value
print("Original reference values:")
for ds_entry in data_set:
    print(f"{ds_entry.expression}: {ds_entry.reference}")

# calculate reference. Set folder=None to not store the finished jobs on disk (can be faster)
# set overwrite=True to overwrite existing reference values
dse.calculate_reference(job_collection, data_set, engine_settings, overwrite=False, folder='saved_results')

# print the new expression : reference value
print("New reference values:")
for ds_entry in data_set:
    print(f"{ds_entry.expression}: {ds_entry.reference}")

8.8.2.2. Example: DataSetEvaluator.run()

from scm.params import *
from scm.plams import Settings

dse = DataSetEvaluator()

# any engine settings are possible
engine_settings = Settings()
engine_settings.input.ForceField.Type = 'UFF'

# a job collection is needed, can for example be loaded from disk
job_collection = JobCollection('job_collection.yaml')

# the data_set to be evaluated, can for example be loaded from disk
data_set = DataSet('data_set.yaml')

# run. Set folder=None to not store the finished jobs on disk (can be faster)
dse.run(job_collection, data_set, engine_settings, folder='saved_results')

# group the results by Extractor and then by Expression
dse.group_by(('Extractor', 'Expression'))

print(dse.str(stats=True, details=True))

# store the calculated results in a format that can later be
# used to initialize another DataSetEvaluator
dse.store('data_set_predictions.yaml')

8.8.3. DataSetEvaluator API

class DataSetEvaluator(data_set=None, total_loss=None, residuals=None, contributions=None, raw_predictions=None, predictions=None, modified_reference=None, loss=None)

Convenience class for evaluating a data_set with any engine.

Run the evaluation with the run() method.

Then group the results based on the Extractor, Expression, or metadata key-value pairs with the group_by() method.

Print the results with str(stats=True, details=True)

  • stats=True will give the mean absolute error, root mean squared error, and partial contributions to the loss function

  • details=True will give a table of prediction vs. reference

The results are stored in the results attribute. It is of type GroupedResults, and can be accessed as follows:

>>> dse = DataSetEvaluator()
>>> dse.run(job_collection, data_set, engine_settings)
>>> dse.group_by(('Group', 'SubGroup')) # for grouping by Group and SubGroup metadata keys
>>> dse.results.mae
>>> dse.results.rmse
>>> dse.results['Forces'].mae
>>> dse.results['Forces']['trajectory_1'].mae
>>> str(dse.results)
>>> dse.results.detailed_string()
>>> dse.results['Forces'].str()
>>> dse.results['Forces'].detailed_string()
>>> dse.results['Forces'].residuals
>>> dse.results['Forces'].predictions
>>> dse.results['Forces'].reference_values
etc.

__init__(data_set=None, total_loss=None, residuals=None, contributions=None, raw_predictions=None, predictions=None, modified_reference=None, loss=None)

Typically you should initialize this class without arguments, i.e., as

>>> dse = DataSetEvaluator()

data_set, predictions, residuals, contributions, and total_loss can either be set in this constructor or are calculated internally by the run() method.

data_set: DataSet

Dataset that was evaluated

total_loss: float

Return value from data_set.evaluate(results, return_residuals=True)[0]

residuals: list

Return value from data_set.evaluate(results, return_residuals=True)[1]

contributions: list

Return value from data_set.evaluate(results, return_residuals=True)[2]

raw_predictions: list

Return value from data_set.evaluate(results, return_residuals=True)[3]

predictions: list

Return value from data_set.get_predictions(raw_predictions, return_reference=True)[0]

modified_reference: list

Return value from data_set.get_predictions(raw_predictions, return_reference=True)[1]

loss: LossFunction or str

The type of loss function that was used to calculate total_loss

calculate_reference(job_collection: JobCollection, data_set: DataSet, engine_settings, overwrite=False, use_pipe=True, folder=None, parallel=None, use_origin=False)

Method to calculate and set the reference values for the entries in data_set. This method will change the data_set!

The method does not modify the DataSetEvaluator instance.

engine_settings: Settings or EngineCollection

If a Settings instance, it defines the reference engine used to run all the jobs.

If an EngineCollection, every job in the job_collection must have a ReferenceEngineID (reference_engine) that is present in the EngineCollection. The settings are then taken from the engine collection. If more than one engine is needed to evaluate the jobs, you must pass in an EngineCollection.

overwrite: bool

If False, only calculate reference values for data set entries that have no reference value. If True, calculate all reference values.

use_origin: bool

If a job in the job_collection has the “Origin” metadata pointing to an ams.rkf results file on disk, then load results from that file instead of rerunning the job.

If both the “Origin” and “Frame” metadata keys exist, data will be taken from the correct frame in the trajectory.

If the Origin, Frame, and OriginalEnergyHartree metadata keys exist, then the energy will be taken from the OriginalEnergyHartree metadata if the ams.rkf in Origin cannot be loaded (for example if it exists on a different machine).

If loading data from the “Origin” or “OriginalEnergyHartree” fails, the job will be run.
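
The fallback order described above can be summarized as a small decision function. This is only a sketch of the documented behavior under simplifying assumptions, not the ParAMS implementation: the function name choose_reference_source, its return labels, and the can_load callback (standing in for whatever check decides that an ams.rkf file can be loaded) are all hypothetical.

```python
def choose_reference_source(metadata, can_load, use_origin=True):
    """Return a label for where the reference data of one job would come from.

    metadata: dict with the job's metadata key-value pairs
    can_load: callable taking a path and returning True if the file can be
              loaded (hypothetical stand-in for the real loading attempt)
    """
    if not use_origin:
        return "run_job"

    origin = metadata.get("Origin")
    if origin is not None and can_load(origin):
        # With a Frame key, data is read from that frame of the trajectory.
        if "Frame" in metadata:
            return "origin_file_frame"
        return "origin_file"

    # The stored energy is only documented as a fallback when Origin,
    # Frame, and OriginalEnergyHartree are all present.
    required = ("Origin", "Frame", "OriginalEnergyHartree")
    if all(key in metadata for key in required):
        return "original_energy_hartree"

    # Nothing could be loaded: the job has to be run.
    return "run_job"


# Made-up example: the ams.rkf file lives on another machine and cannot be loaded.
meta = {"Origin": "/remote/ams.rkf", "Frame": 3, "OriginalEnergyHartree": -1.234}
print(choose_reference_source(meta, can_load=lambda path: False))
```

The sketch ignores that OriginalEnergyHartree can only supply an energy; data set entries that need other properties would presumably still require the job to be run.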

job_collection, data_set, use_pipe, folder, and parallel have the same meaning as in the run() method.

run(job_collection: JobCollection, data_set: DataSet, engine_settings: Settings, loss='sse', use_pipe=True, folder=None, parallel=None, group_by=None)

Runs the jobs in the job collection using the engine defined by engine_settings, and evaluates the data_set expressions.

job_collection: JobCollection

The job collection containing the jobs

data_set: DataSet

The data_set containing the expressions to be evaluated

engine_settings: Settings

The engine settings to be used. Example:

>>> engine_settings = Settings()
>>> engine_settings.input.ForceField.Type = 'UFF'

loss: str or Loss

The type of loss function

use_pipe: bool

Whether to use the pipe interface if possible. This will speed up the calculation. Cannot be combined with folder.

folder: str

If folder is not None, the results will be stored on disk in that folder. If the folder already exists, a new one is created. If set, will automatically disable the pipe interface.

parallel: ParallelLevels

Defaults to ParallelLevels(parametervectors=1, processes=1, threads=1). This will run N jobs in parallel, where N is the number of cores on the machine.

group_by: tuple of str

Group results according to the tuple. The grouping can also be changed after the run with the group_by() method.

group_by(group_by)

Group the results according to group_by. The run() method needs to be called before calling this method.

group_by: tuple of str

>>> group_by(('Extractor',)) # group by extractor
>>> group_by(('Extractor', 'Expression')) # group by extractor, then expression. The expression will be filtered
>>> group_by(('Group', 'SubGroup')) # group by the metadata key Group, then by the metadata key SubGroup
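
When grouping by a single key, keep in mind that Python only creates a one-element tuple when there is a trailing comma; parentheses alone leave a plain string. A quick check in plain Python (no ParAMS needed):

```python
without_comma = ('Extractor')   # just a parenthesized string
with_comma = ('Extractor',)     # a one-element tuple

print(type(without_comma).__name__)  # str
print(type(with_comma).__name__)     # tuple

# Iterating over the string yields single characters, not the key name.
print(list(without_comma)[:3])
print(list(with_comma))
```

How group_by() reacts to a plain string is not stated here, so passing an explicit tuple such as ('Extractor',) is the safer choice.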
__str__()

Return str(self).