8.8. Data Set Evaluator
DataSetEvaluator lets you access:
- summary statistics such as the mean absolute error (MAE) and root-mean-squared error (RMSE)
- partial contributions to the loss function value
- tables with reference and predicted values in adjacent columns
- grouped summary statistics, partial contributions, and reference-vs-prediction tables, grouped by extractor, expression, or any metadata key-value pair
To follow along with the example below, download either data_set_evaluator_demo.py or data_set_evaluator_demo.ipynb (see also: how to install Jupyterlab).
8.8.1. Load a data_set_predictions.yaml file
The most common way to use a DataSetEvaluator is to load the data_set_predictions.yaml file produced during an optimization.
from scm.params import *
import os
#if you go via ParAMSJob:
#job = ParAMSJob.load_external('/path/results/')
#dse = job.results.get_data_set_evaluator()
#to just load the .yaml file:
yaml_file = os.path.expandvars('$AMSHOME/scripting/scm/params/examples/ZnS_ReaxFF/example_output/best/data_set_predictions.yaml')
dse = DataSetEvaluator(yaml_file)
8.8.1.1. Summary statistics (stats.txt)
The results can be grouped in different ways. By default, the data is grouped first by 'Extractor' and then by 'Expression'. To get output like the stats.txt file, call the .str() method:
print(dse.str())
Group/Expression N MAE RMSE* Unit* Weight Loss* Contribution[%]
-------------------------------------------------------------------------------------------------------------------------------------------------
Total 466 0.59112 +0.97646 Mixed! 50.000 236.128 100.00 Total
forces 297 0.76691 +1.02728 eV/angstrom 2.000 89.301 37.82 Extractor
band_distorted_clean_110 144 0.89717 +1.14169 eV/angstrom 1.000 54.772 23.20 Expression
band_distorted_ads_110 153 0.64432 +0.90649 eV/angstrom 1.000 34.529 14.62 Expression
pes 30 0.09157 +0.14448 eV 13.000 65.858 27.89 Extractor
pes('bondscan_h2s_pbesol', relative_to=5) 10 0.16696 +0.22823 eV 3.000 52.759 22.34 Expression
pes('rocksalt', relative_to=2) 5 0.06917 +0.08621 eV 4.000 10.036 4.25 Expression
pes('anglescan_h2s_pbesol', relative_to=2) 10 0.06506 +0.08106 eV 1.000 2.218 0.94 Expression
pes('zincblende', relative_to=2) 5 0.01622 +0.02237 eV 5.000 0.844 0.36 Expression
energy 9 0.63328 +1.63647 eV 9.000 26.796 11.35 Extractor
1.0*band_h2s_110-1.0*band_110-1.0*band_h2s 1 0.18263 -0.18263 eV 1.000 11.261 4.77 Expression
1.0*band_110-1.0*band_distorted_clean_110 1 4.89415 -4.89415 eV 1.000 5.988 2.54 Expression
0.041667*band_110_noconstraints-0.5*wurtzite_sp 1 0.13315 -0.13315 eV 1.000 5.985 2.53 Expression
1.0*zincblende_sp-0.5*wurtzite_sp 1 0.06582 +0.06582 eV 1.000 1.463 0.62 Expression
0.5*wurtzite_sp-0.03125*sulfur-0.5*Zn 1 0.06549 -0.06549 eV 1.000 1.448 0.61 Expression
1.0*rocksalt_sp-0.03125*sulfur-0.5*Zn 1 0.03069 +0.03069 eV 1.000 0.318 0.13 Expression
1.0*zincblende_sp-1.0*rocksalt_sp 1 0.03036 -0.03036 eV 1.000 0.311 0.13 Expression
1.0*band_h2s_110-1.0*band_distorted_ads_110 1 0.29687 +0.29687 eV 1.000 0.022 0.01 Expression
1.0*zincblende_sp-0.03125*sulfur-0.5*Zn 1 0.00033 +0.00033 eV 1.000 0.000 0.00 Expression
angle 7 3.59374 +3.79878 degree 7.000 25.254 10.69 Extractor
angle('band_110_noconstraints',39,19,24) 1 4.90284 +4.90284 degree 1.000 6.009 2.55 Expression
angle('band_110_noconstraints',42,39,19) 1 4.75211 -4.75211 degree 1.000 5.646 2.39 Expression
angle('band_110_noconstraints',39,19,31) 1 4.74455 -4.74455 degree 1.000 5.628 2.38 Expression
angle('band_110_noconstraints',39,42,47) 1 4.11002 +4.11002 degree 1.000 4.223 1.79 Expression
angle('band_110_noconstraints',23,27,47) 1 2.45845 -2.45845 degree 1.000 1.511 0.64 Expression
angle('band_h2s',1,0,2) 1 2.39175 +2.39175 degree 1.000 1.430 0.61 Expression
angle('band_110_noconstraints',27,47,4) 1 1.79647 +1.79647 degree 1.000 0.807 0.34 Expression
distance 12 0.04638 +0.06270 angstrom 12.000 18.869 7.99 Extractor
distance('band_h2s_110',34,19) 1 0.14270 -0.14270 angstrom 1.000 8.145 3.45 Expression
distance('band_110_noconstraints',4,47) 1 0.09542 -0.09542 angstrom 1.000 3.642 1.54 Expression
distance('band_h2s',0,1) 1 0.07365 -0.07365 angstrom 1.000 2.170 0.92 Expression
distance('band_110_noconstraints',3,12) 1 0.06262 +0.06262 angstrom 1.000 1.568 0.66 Expression
distance('band_110_noconstraints',39,42) 1 0.05802 +0.05802 angstrom 1.000 1.347 0.57 Expression
distance('band_110_noconstraints',42,47) 1 0.05528 +0.05528 angstrom 1.000 1.222 0.52 Expression
distance('band_110_noconstraints',27,47) 1 0.03551 +0.03551 angstrom 1.000 0.504 0.21 Expression
distance('band_h2s_110',14, 37) 1 0.02452 -0.02452 angstrom 1.000 0.240 0.10 Expression
distance('band_110_noconstraints',4,12) 1 0.00868 +0.00868 angstrom 1.000 0.030 0.01 Expression
distance('band_110_noconstraints',19, 42) 1 0.00011 -0.00011 angstrom 1.000 0.000 0.00 Expression
distance('band_110_noconstraints',12,15) 1 0.00004 -0.00004 angstrom 1.000 0.000 0.00 Expression
distance('band_110_noconstraints',5,28) 1 0.00000 -0.00000 angstrom 1.000 0.000 0.00 Expression
charges 110 0.10562 +0.11519 au 6.000 9.138 3.87 Extractor
zincblende_sp 2 0.15138 +0.15138 au 1.000 2.292 0.97 Expression
wurtzite_sp 4 0.14789 +0.14789 au 1.000 2.187 0.93 Expression
rocksalt_sp 2 0.13159 +0.13159 au 1.000 1.731 0.73 Expression
band_110_noconstraints 48 0.11118 +0.11430 au 1.000 1.306 0.55 Expression
band_h2s_110 51 0.09724 +0.11327 au 1.000 1.283 0.54 Expression
band_h2s 3 0.05482 +0.05814 au 1.000 0.338 0.14 Expression
dihedral 1 1.91010 +1.91010 degree 1.000 0.912 0.39 Extractor
dihedral('band_110_noconstraints',9,28,27,23) 1 1.91010 +1.91010 degree 1.000 0.912 0.39 Expression
-------------------------------------------------------------------------------------------------------------------------------------------------
The weighted total loss function is 236.128.
N: number of numbers averaged for the MAE/RMSE
MAE and RMSE: These are not weighted!
RMSE*: if N == 1 the signed residual (reference-prediction) is given instead of the RMSE.
Unit*: if the unit is "Mixed!" it means that the MAE and RMSE are meaningless.
Loss function type: SSE(). The loss function value is affected by the Weight and Sigma of data_set entries.
Contribution[%]: The contribution to the weighted loss function.
Note that the extractor names for the various expressions are not shown if there are no arguments to the extractor. This makes the output more readable.
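The legend above says that MAE and RMSE are unweighted and that the RMSE* column holds the signed residual (reference-prediction) when N == 1. The following is a minimal plain-Python sketch of that rule, not the library's implementation. The helper name summarize is hypothetical; the five reference/prediction pairs are the pes('zincblende', relative_to=2) values printed later in this section, and the single-entry numbers are made up for illustration.

```python
import math


def summarize(reference, prediction):
    """Return (N, MAE, RMSE*) for one group; weights and sigma are ignored."""
    residuals = [r - p for r, p in zip(reference, prediction)]
    n = len(residuals)
    mae = sum(abs(x) for x in residuals) / n
    if n == 1:
        # With a single number, the table shows the signed residual instead
        rmse_star = residuals[0]
    else:
        rmse_star = math.sqrt(sum(x * x for x in residuals) / n)
    return n, mae, rmse_star


# Values for pes('zincblende', relative_to=2), in eV
ref = [0.419314, 0.09164115, 0.0, 0.07841534, 0.27771948]
pred = [0.39185008, 0.09196291, 0.0, 0.09227838, 0.31714676]
n, mae, rmse = summarize(ref, pred)
print(f"{n} {mae:.5f} {rmse:+.5f}")

# Made-up single entry: the third value is the signed residual
print(summarize([1.0], [1.5]))
```

The first print should agree with the pes('zincblende', relative_to=2) row of the table (N = 5, MAE 0.01622, RMSE 0.02237).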
You can access individual entries from the above table as follows:
print(len(dse.results['charges'].residuals)) # the N for the charges
110
print(dse.results['charges']['zincblende_sp'].mae) # MAE for an expression
0.151381045
print(dse.results['forces'].rmse) # RMSE for an extractor
1.0272752938737622
print(dse.results['forces'].unit) # unit for an extractor
eV/angstrom
print(dse.results['charges']['wurtzite_sp'].weight) # the weight is returned as a scalar, even for array reference values
1.0
print(dse.results['energy']['1.0*zincblende_sp-0.5*wurtzite_sp'].my_loss) #"my_loss" refers to the loss of the individual entry
1.4626476245701765
print(dse.results['forces'].contribution) # fractional contribution to the weighted loss function
0.3781877244315181
print(dse.results.total_loss) # total loss function value
236.1283753165323
print(dse.results.loss_type) # type of loss function
SSE()
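The contribution attribute is the fraction of the total loss carried by a group. As a sanity check, it can be reproduced from the rounded Loss column of the table above. This is a plain-Python sketch that only uses numbers copied from that table; the variable names are illustrative.

```python
# Loss per extractor, copied from the Loss column of the summary table
extractor_losses = {
    "forces": 89.301,
    "pes": 65.858,
    "energy": 26.796,
    "angle": 25.254,
    "distance": 18.869,
    "charges": 9.138,
    "dihedral": 0.912,
}

# The extractor losses add up to the weighted total loss function
total_loss = sum(extractor_losses.values())

# Fractional contribution of each extractor
contributions = {name: loss / total_loss for name, loss in extractor_losses.items()}

print(f"total loss: {total_loss:.3f}")
for name, fraction in contributions.items():
    print(f"{name:10s} {100 * fraction:6.2f} %")
```

Because the inputs are rounded to three decimals, the fractions match dse.results[...].contribution only to about four significant digits.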
You can also just print a summary of a part of the table:
print(dse.results['forces'].str())
Group/Expression N MAE RMSE* Unit* Weight Loss* Contribution[%]
-----------------------------------------------------------------------------------------------------------------------
forces 297 0.76691 +1.02728 eV/angstrom 2.000 89.301 37.82 Extractor
band_distorted_clean_110 144 0.89717 +1.14169 eV/angstrom 1.000 54.772 23.20 Expression
band_distorted_ads_110 153 0.64432 +0.90649 eV/angstrom 1.000 34.529 14.62 Expression
-----------------------------------------------------------------------------------------------------------------------
The weighted total loss function is 236.128.
N: number of numbers averaged for the MAE/RMSE
MAE and RMSE: These are not weighted!
RMSE*: if N == 1 the signed residual (reference-prediction) is given instead of the RMSE.
Unit*: if the unit is "Mixed!" it means that the MAE and RMSE are meaningless.
Loss function type: None. The loss function value is affected by the Weight and Sigma of data_set entries.
Contribution[%]: The contribution to the weighted loss function.
You can also modify the grouping to only go one level deep:
dse.group_by(('Extractor',)) # the default is group_by(('Extractor', 'Expression'))
print(dse.str())
Group/Expression N MAE RMSE* Unit* Weight Loss* Contribution[%]
-------------------------------------------------------------------------------------------------------
Total 466 0.59112 +0.97646 Mixed! 50.000 236.128 100.00 Total
forces 297 0.76691 +1.02728 eV/angstrom 2.000 89.301 37.82 Extractor
pes 30 0.09157 +0.14448 eV 13.000 65.858 27.89 Extractor
energy 9 0.63328 +1.63647 eV 9.000 26.796 11.35 Extractor
angle 7 3.59374 +3.79878 degree 7.000 25.254 10.69 Extractor
distance 12 0.04638 +0.06270 angstrom 12.000 18.869 7.99 Extractor
charges 110 0.10562 +0.11519 au 6.000 9.138 3.87 Extractor
dihedral 1 1.91010 +1.91010 degree 1.000 0.912 0.39 Extractor
-------------------------------------------------------------------------------------------------------
The weighted total loss function is 236.128.
N: number of numbers averaged for the MAE/RMSE
MAE and RMSE: These are not weighted!
RMSE*: if N == 1 the signed residual (reference-prediction) is given instead of the RMSE.
Unit*: if the unit is "Mixed!" it means that the MAE and RMSE are meaningless.
Loss function type: SSE(). The loss function value is affected by the Weight and Sigma of data_set entries.
Contribution[%]: The contribution to the weighted loss function.
If metadata is attached to the training set entries, you can also group by it. For example, when a training set is created with a ResultsImporter, the Group and SubGroup metadata are set automatically:
dse.group_by(('Group', 'SubGroup'))
print(dse.str())
Group/Expression N MAE RMSE* Unit* Weight Loss* Contribution[%]
--------------------------------------------------------------------------------------------------------------------------
Total 466 0.59112 +0.97646 Mixed! 50.000 236.128 100.00 Total
Forces 297 0.76691 +1.02728 eV/angstrom 2.000 89.301 37.82 Group
band_distorted_clean_110 144 0.89717 +1.14169 eV/angstrom 1.000 54.772 23.20 SubGroup
band_distorted_ads_110 153 0.64432 +0.90649 eV/angstrom 1.000 34.529 14.62 SubGroup
None 31 0.15023 +0.37134 Mixed! 14.000 66.770 28.28 Group
bondscan_h2s_pbesol 10 0.16696 +0.22823 eV 3.000 52.759 22.34 SubGroup
rocksalt 5 0.06917 +0.08621 eV 4.000 10.036 4.25 SubGroup
anglescan_h2s_pbesol 10 0.06506 +0.08106 eV 1.000 2.218 0.94 SubGroup
band_110_noconstraints 1 1.91010 +1.91010 degree 1.000 0.912 0.39 SubGroup
zincblende 5 0.01622 +0.02237 eV 5.000 0.844 0.36 SubGroup
ReactionEnergy 9 0.63328 +1.63647 eV 9.000 26.796 11.35 Group
None 9 0.63328 +1.63647 eV 9.000 26.796 11.35 SubGroup
Angles 7 3.59374 +3.79878 degree 7.000 25.254 10.69 Group
band_110_noconstraints 6 3.79407 +3.98528 degree 6.000 23.824 10.09 SubGroup
band_h2s 1 2.39175 +2.39175 degree 1.000 1.430 0.61 SubGroup
Distances 12 0.04638 +0.06270 angstrom 12.000 18.869 7.99 Group
band_h2s_110 2 0.08361 +0.10238 angstrom 2.000 8.386 3.55 SubGroup
band_110_noconstraints 9 0.03508 +0.04806 angstrom 9.000 8.314 3.52 SubGroup
band_h2s 1 0.07365 -0.07365 angstrom 1.000 2.170 0.92 SubGroup
Charges 110 0.10562 +0.11519 au 6.000 9.138 3.87 Group
zincblende_sp 2 0.15138 +0.15138 au 1.000 2.292 0.97 SubGroup
wurtzite_sp 4 0.14789 +0.14789 au 1.000 2.187 0.93 SubGroup
rocksalt_sp 2 0.13159 +0.13159 au 1.000 1.731 0.73 SubGroup
band_110_noconstraints 48 0.11118 +0.11430 au 1.000 1.306 0.55 SubGroup
band_h2s_110 51 0.09724 +0.11327 au 1.000 1.283 0.54 SubGroup
band_h2s 3 0.05482 +0.05814 au 1.000 0.338 0.14 SubGroup
--------------------------------------------------------------------------------------------------------------------------
The weighted total loss function is 236.128.
N: number of numbers averaged for the MAE/RMSE
MAE and RMSE: These are not weighted!
RMSE*: if N == 1 the signed residual (reference-prediction) is given instead of the RMSE.
Unit*: if the unit is "Mixed!" it means that the MAE and RMSE are meaningless.
Loss function type: SSE(). The loss function value is affected by the Weight and Sigma of data_set entries.
Contribution[%]: The contribution to the weighted loss function.
print(dse.results['Forces'].mae) # capital F in the Group metadata
0.7669139308273065
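To illustrate what regrouping does to the statistics, the sketch below groups the seven angle entries by metadata keys in plain Python and recomputes N, MAE, and loss per group. It is an illustration only, not how DataSetEvaluator is implemented. The function group_stats and the dictionary layout are hypothetical; the absolute errors and losses are copied from the Extractor/Expression table, and the SubGroup of each entry is assumed to be the job name in its expression.

```python
from collections import defaultdict

# (SubGroup, absolute error in degree, loss) for the seven angle expressions
rows = [
    ("band_110_noconstraints", 4.90284, 6.009),
    ("band_110_noconstraints", 4.75211, 5.646),
    ("band_110_noconstraints", 4.74455, 5.628),
    ("band_110_noconstraints", 4.11002, 4.223),
    ("band_110_noconstraints", 2.45845, 1.511),
    ("band_h2s", 2.39175, 1.430),
    ("band_110_noconstraints", 1.79647, 0.807),
]
entries = [
    {"Group": "Angles", "SubGroup": sub, "abs_error": err, "loss": loss}
    for sub, err, loss in rows
]


def group_stats(entries, keys):
    """Collect entries by the given metadata keys and summarize each bucket."""
    buckets = defaultdict(list)
    for entry in entries:
        buckets[tuple(entry[k] for k in keys)].append(entry)
    return {
        label: {
            "N": len(members),
            "MAE": sum(m["abs_error"] for m in members) / len(members),
            "loss": sum(m["loss"] for m in members),
        }
        for label, members in buckets.items()
    }


two_level = group_stats(entries, ("Group", "SubGroup"))
one_level = group_stats(entries, ("Group",))
for label, stats in {**one_level, **two_level}.items():
    print(label, stats["N"], f"{stats['MAE']:.5f}", f"{stats['loss']:.3f}")
```

The bucket ('Angles', 'band_110_noconstraints') should reproduce the SubGroup row with N = 6 in the table above, and ('Angles',) the Group row with N = 7.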
8.8.1.2. Access individual predictions and reference values (scatter_plots/)
Call the .detailed_string() method to get output similar to the files scatter_plots/forces.txt etc.
dse.group_by(('Extractor', 'Expression')) # reset to the original grouping
results = dse.results['pes'] # look at the results for the pes extractor
print(results.detailed_string())
#Reference Prediction Unit Sigma Weight WSE* Row* Col* Expression
#------------------------------------------------------------------------------------------------------------------------
+0.419 +0.392 eV 0.054 1.0000 0.255 0 0 pes('zincblende', relative_to=2)
+0.092 +0.092 eV 0.054 1.0000 0.000 1 0 pes('zincblende', relative_to=2)
+0.000 +0.000 eV 0.054 1.0000 0.000 2 0 pes('zincblende', relative_to=2)
+0.078 +0.092 eV 0.054 1.0000 0.065 3 0 pes('zincblende', relative_to=2)
+0.278 +0.317 eV 0.054 1.0000 0.525 4 0 pes('zincblende', relative_to=2)
+0.474 +0.336 eV 0.054 0.8000 5.171 0 0 pes('rocksalt', relative_to=2)
+0.111 +0.083 eV 0.054 0.8000 0.202 1 0 pes('rocksalt', relative_to=2)
+0.000 +0.000 eV 0.054 0.8000 0.000 2 0 pes('rocksalt', relative_to=2)
+0.073 +0.006 eV 0.054 0.8000 1.220 3 0 pes('rocksalt', relative_to=2)
+0.274 +0.161 eV 0.054 0.8000 3.442 4 0 pes('rocksalt', relative_to=2)
+1.335 +0.857 eV 0.054 0.3000 23.227 0 0 pes('bondscan_h2s_pbesol', relative_to=5)
+0.732 +0.560 eV 0.054 0.3000 2.990 1 0 pes('bondscan_h2s_pbesol', relative_to=5)
+0.342 +0.333 eV 0.054 0.3000 0.007 2 0 pes('bondscan_h2s_pbesol', relative_to=5)
+0.114 +0.169 eV 0.054 0.3000 0.306 3 0 pes('bondscan_h2s_pbesol', relative_to=5)
+0.011 +0.060 eV 0.054 0.3000 0.252 4 0 pes('bondscan_h2s_pbesol', relative_to=5)
+0.000 +0.000 eV 0.054 0.3000 0.000 5 0 pes('bondscan_h2s_pbesol', relative_to=5)
+0.059 -0.019 eV 0.054 0.3000 0.614 6 0 pes('bondscan_h2s_pbesol', relative_to=5)
+0.170 -0.002 eV 0.054 0.3000 2.999 7 0 pes('bondscan_h2s_pbesol', relative_to=5)
+0.319 +0.044 eV 0.054 0.3000 7.654 8 0 pes('bondscan_h2s_pbesol', relative_to=5)
+0.495 +0.114 eV 0.054 0.3000 14.710 9 0 pes('bondscan_h2s_pbesol', relative_to=5)
+0.650 +0.700 eV 0.054 0.1000 0.086 0 0 pes('anglescan_h2s_pbesol', relative_to=2)
+0.204 +0.186 eV 0.054 0.1000 0.010 1 0 pes('anglescan_h2s_pbesol', relative_to=2)
+0.000 +0.000 eV 0.054 0.1000 0.000 2 0 pes('anglescan_h2s_pbesol', relative_to=2)
+0.040 +0.096 eV 0.054 0.1000 0.106 3 0 pes('anglescan_h2s_pbesol', relative_to=2)
+0.307 +0.422 eV 0.054 0.1000 0.451 4 0 pes('anglescan_h2s_pbesol', relative_to=2)
+0.768 +0.903 eV 0.054 0.1000 0.616 5 0 pes('anglescan_h2s_pbesol', relative_to=2)
+1.376 +1.453 eV 0.054 0.1000 0.197 6 0 pes('anglescan_h2s_pbesol', relative_to=2)
+2.043 +1.990 eV 0.054 0.1000 0.093 7 0 pes('anglescan_h2s_pbesol', relative_to=2)
+2.597 +2.458 eV 0.054 0.1000 0.657 8 0 pes('anglescan_h2s_pbesol', relative_to=2)
+2.818 +2.825 eV 0.054 0.1000 0.002 9 0 pes('anglescan_h2s_pbesol', relative_to=2)
#------------------------------------------------------------------------------------------------------------------------
#WSE*: Weighted Squared Error: weight*([reference-prediction]/sigma)**2
#Row*, Col*: For scalars both numbers are 0. For 1D arrays Col is 0.
print(results.reference_values) # list of reference values
[0.419314, 0.09164115, 0.0, 0.07841534, 0.27771948, 0.47439122, 0.11065178, 0.0, 0.07282489, 0.27377707, 1.33538276, 0.73173446, 0.34163308, 0.11433319, 0.01054105, 0.0, 0.05929865, 0.17040413, 0.3193521, 0.49523743, 0.64987859, 0.20385679, 0.0, 0.04011389, 0.30652919, 0.76818668, 1.37634888, 2.04264645, 2.59729761, 2.8175004]
print(results.predictions) # list of predicted values
[0.39185008, 0.09196291, 0.0, 0.09227838, 0.31714676, 0.33602331, 0.08328586, 0.0, 0.00560645, 0.16088318, 0.85651058, 0.55991041, 0.33346583, 0.16932453, 0.06042187, 0.0, -0.0185412, -0.00165861, 0.04446513, 0.11414837, 0.70031273, 0.18640377, 0.0, 0.09607728, 0.42212549, 0.90330381, 1.45275737, 1.99029061, 2.45778418, 2.825236]
print(results.unit) # the unit
eV
print(results.accuracies) # the Sigma values (per expression)
[0.054422772491975996, 0.054422772491975996, 0.054422772491975996, 0.054422772491975996]
print(results.weights) # the Weights (per reference/prediction)
[1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
print(results.contributions) # list of individual contributions (per expression)
[0.0035761475951556648, 0.04250420478293943, 0.2234329155570571, 0.009394469642807931]
print(results.expressions) # list of expressions
["pes('zincblende', relative_to=2)", "pes('rocksalt', relative_to=2)", "pes('bondscan_h2s_pbesol', relative_to=5)", "pes('anglescan_h2s_pbesol', relative_to=2)"]
Note that the number of reference values is different from the number of expressions when the reference values are arrays. To get the reference values per expression:
for e in results.expressions:
    print(f"Expression: {e}, Ref. values: {results[e].reference_values}")
Expression: pes('zincblende', relative_to=2), Ref. values: [0.419314, 0.09164115, 0.0, 0.07841534, 0.27771948]
Expression: pes('rocksalt', relative_to=2), Ref. values: [0.47439122, 0.11065178, 0.0, 0.07282489, 0.27377707]
Expression: pes('bondscan_h2s_pbesol', relative_to=5), Ref. values: [1.33538276, 0.73173446, 0.34163308, 0.11433319, 0.01054105, 0.0, 0.05929865, 0.17040413, 0.3193521, 0.49523743]
Expression: pes('anglescan_h2s_pbesol', relative_to=2), Ref. values: [0.64987859, 0.20385679, 0.0, 0.04011389, 0.30652919, 0.76818668, 1.37634888, 2.04264645, 2.59729761, 2.8175004]
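The WSE column in the detailed table is defined as weight*([reference-prediction]/sigma)**2. The sketch below applies that formula in plain Python to the pes('rocksalt', relative_to=2) values shown above (weight 0.8, sigma as printed by results.accuracies) and sums the rows. For the SSE() loss used here, that sum should agree with the Loss column of the summary table for this expression, up to rounding. The function name is hypothetical.

```python
def weighted_squared_errors(reference, prediction, sigma, weight):
    """Per-row WSE: weight * ((reference - prediction) / sigma) ** 2."""
    return [weight * ((r - p) / sigma) ** 2 for r, p in zip(reference, prediction)]


# pes('rocksalt', relative_to=2): values in eV, copied from the output above
sigma = 0.054422772491975996
rocksalt_ref = [0.47439122, 0.11065178, 0.0, 0.07282489, 0.27377707]
rocksalt_pred = [0.33602331, 0.08328586, 0.0, 0.00560645, 0.16088318]

wse = weighted_squared_errors(rocksalt_ref, rocksalt_pred, sigma, weight=0.8)
for row, value in enumerate(wse):
    print(f"row {row}: WSE = {value:.3f}")

# Summing the rows gives the loss of this expression
expression_loss = sum(wse)
print(f"loss = {expression_loss:.3f}")

# Dividing by the total loss gives the contribution of this expression
print(f"contribution = {expression_loss / 236.1283753165323:.4f}")
```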
8.8.2. Calculate reference or run jobs
DataSetEvaluator can also be used in the following ways:
- DataSetEvaluator.calculate_reference() will evaluate a data set with any engine settings and set the reference values.
- DataSetEvaluator.run() will evaluate the data_set with any engine settings, and provides many functions to compare the predicted values with the reference values. You can only use run() if all data_set entries already have reference values.
8.8.2.1. Example: DataSetEvaluator.calculate_reference()
Note
The examples below use the plams.Settings class to define a computational engine. See the PLAMS documentation for more information about it.
from scm.params import *
from scm.plams import Settings
dse = DataSetEvaluator()
# any engine settings are possible
engine_settings = Settings()
engine_settings.input.ForceField.Type = 'UFF'
# a job collection is needed, can for example be loaded from disk
job_collection = JobCollection('job_collection.yaml')
# the data_set to be evaluated, can for example be loaded from disk
data_set = DataSet('data_set.yaml')
# print the original expression : reference value
print("Original reference values:")
for ds_entry in data_set:
    print("{}: {}".format(ds_entry.expression, ds_entry.reference))
# calculate reference. Set folder=None to not store the finished jobs on disk (can be faster)
# set overwrite=True to overwrite existing reference values
dse.calculate_reference(job_collection, data_set, engine_settings, overwrite=False, folder='saved_results')
# print the new expression : reference value
print("New reference values:")
for ds_entry in data_set:
    print("{}: {}".format(ds_entry.expression, ds_entry.reference))
8.8.2.2. Example: DataSetEvaluator.run()
from scm.params import *
from scm.plams import Settings
dse = DataSetEvaluator()
# any engine settings are possible
engine_settings = Settings()
engine_settings.input.ForceField.Type = 'UFF'
# a job collection is needed, can for example be loaded from disk
job_collection = JobCollection('job_collection.yaml')
# the data_set to be evaluated, can for example be loaded from disk
data_set = DataSet('data_set.yaml')
# run. Set folder=None to not store the finished jobs on disk (can be faster)
dse.run(job_collection, data_set, engine_settings, folder='saved_results')
# group the results by Extractor and then by Expression
dse.group_by(('Extractor', 'Expression'))
print(dse.str(stats=True, details=True))
# store the calculated results in a format that can later be
# used to initialize another DataSetEvaluator
dse.store('data_set_predictions.yaml')
8.8.3. DataSetEvaluator API
- class DataSetEvaluator(data_set=None, total_loss=None, residuals=None, contributions=None, raw_predictions=None, predictions=None, modified_reference=None, loss=None)
Convenience class for evaluating a data_set with any engine.
Run the evaluation with the run() function.
Then group the results based on the Extractor, Expression, or metadata key-value pairs with the group_by() method.
Print the results with str(stats=True, details=True):
stats=True will give the mean absolute error, root mean squared error, and partial contributions to the loss function
details=True will give a table of prediction vs. reference
The results are stored in the results attribute. It is of type GroupedResults, and can be accessed as follows:
>>> dse = DataSetEvaluator()
>>> dse.run(job_collection, data_set, engine_settings)
>>> dse.group_by(('Group', 'SubGroup'))  # for grouping by Group and SubGroup metadata keys
>>> dse.results.mae
>>> dse.results.rmse
>>> dse.results['Forces'].mae
>>> dse.results['Forces']['trajectory_1'].mae
>>> str(dse.results)
>>> dse.results.detailed_string()
>>> dse.results['Forces'].str()
>>> dse.results['Forces'].detailed_string()
>>> dse.results['Forces'].residuals
>>> dse.results['Forces'].predictions
>>> dse.results['Forces'].reference_values
etc.
- __init__(data_set=None, total_loss=None, residuals=None, contributions=None, raw_predictions=None, predictions=None, modified_reference=None, loss=None)
Typically you should initialize this class without arguments, i.e., as
>>> dse = DataSetEvaluator()
data_set, predictions, residuals, contributions, and total_loss can either be set in this constructor or are calculated internally by the run() method.
- data_set: DataSet
Dataset that was evaluated
- total_loss: float
Return value from data_set.evaluate(results, return_residuals=True)[0]
- residuals: list
Return value from data_set.evaluate(results, return_residuals=True)[1]
- contributions: list
Return value from data_set.evaluate(results, return_residuals=True)[2]
- raw_predictions: list
Return value from data_set.evaluate(results, return_residuals=True)[3]
- predictions: list
Return value from data_set.get_predictions(raw_predictions, return_reference=True)[0]
- modified_reference: list
Return value from data_set.get_predictions(raw_predictions, return_reference=True)[1]
- loss: a LossFunction or str
The type of loss function that was used to calculate total_loss
- calculate_reference(job_collection: JobCollection, data_set: DataSet, engine_settings, overwrite=False, use_pipe=True, folder=None, parallel=None, use_origin=False)
Method to calculate and set the reference values for the entries in data_set. This method will change the data_set! The method does not modify the DataSetEvaluator instance.
- engine_settings: Settings or EngineCollection
If a Settings instance, it defines the reference engine used to calculate all the jobs.
If an EngineCollection, every job in the job_collection must have a ReferenceEngineID (reference_engine) that is present in the EngineCollection. The settings will then be taken from the engine collection. If more than one engine is needed to evaluate the jobs, then you must pass in an EngineCollection.
- overwrite: bool
If False, only calculate reference values for data set entries that have no reference value. If True, calculate all reference values.
- use_origin: bool
If a job in the job_collection has the “Origin” metadata pointing to an ams.rkf results file on disk, then load results from that file instead of rerunning the job.
If both the “Origin” and “Frame” metadata keys exist, data will be taken from the correct frame in the trajectory.
If the Origin, Frame, and OriginalEnergyHartree metadata keys exist, then the energy will be taken from the OriginalEnergyHartree metadata if the ams.rkf in Origin cannot be loaded (for example if it exists on a different machine).
If loading data from the “Origin” or “OriginalEnergyHartree” fails, the job will be run.
job_collection, data_set, use_pipe, folder, and parallel have the same meaning as in the run() method.
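The use_origin fallback order described above can be summarized as a small decision function. This is only a plain-Python sketch of the documented order; the function choose_result_source, its arguments, and the returned labels are hypothetical and are not part of the library.

```python
def choose_result_source(metadata, origin_loadable):
    """Sketch of where results would come from when use_origin=True.

    metadata: dict of job metadata (hypothetical representation)
    origin_loadable: whether the ams.rkf named by "Origin" can be read
    """
    if "Origin" in metadata and origin_loadable:
        # Results are read from the file; "Frame" selects a trajectory frame
        if "Frame" in metadata:
            return "origin file, frame {}".format(metadata["Frame"])
        return "origin file"
    if all(key in metadata for key in ("Origin", "Frame", "OriginalEnergyHartree")):
        # The file cannot be loaded (e.g. it lives on another machine):
        # only the energy is recovered, from the metadata
        return "energy from OriginalEnergyHartree"
    # Nothing usable: the job is run
    return "run job"


print(choose_result_source({"Origin": "/path/ams.rkf", "Frame": 3}, True))
print(choose_result_source(
    {"Origin": "/path/ams.rkf", "Frame": 3, "OriginalEnergyHartree": -1.0}, False))
print(choose_result_source({}, False))
```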
- run(job_collection: JobCollection, data_set: DataSet, engine_settings: Settings, loss='sse', use_pipe=True, folder=None, parallel=None, group_by=None)
Runs the jobs in the job collection using the engine defined by engine_settings, and evaluates the data_set expressions.
- job_collection: JobCollection
The job collection containing the jobs
- data_set: DataSet
The data_set containing the expressions to be evaluated
- engine_settings: Settings
The engine settings to be used. Example:
>>> engine_settings = Settings()
>>> engine_settings.input.ForceField.Type = 'UFF'
- loss: str or Loss
The type of loss function
- use_pipe: bool
Whether to use the pipe interface if possible. This will speed up the calculation. Cannot be combined with folder.
- folder: str
If folder is not None, the results will be stored on disk in that folder. If the folder already exists, a new one is created. Setting folder automatically disables the pipe interface.
- parallel: ParallelLevels
Defaults to ParallelLevels(parametervectors=1, processes=1, threads=1). This will run N jobs in parallel, where N is the number of cores on the machine.
- group_by: tuple of str
Group results according to the tuple. The grouping can also be changed after the run with the group_by() method.
- group_by(group_by)
Group the results according to group_by. The run() method needs to be called before calling this method.
- group_by: tuple of str
>>> group_by(('Extractor',))  # group by extractor
>>> group_by(('Extractor', 'Expression'))  # group by extractor, then expression. The expression will be filtered
>>> group_by(('Group', 'SubGroup'))  # group by the metadata key Group, then by the metadata key SubGroup
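Since group_by is documented as a tuple of str, a one-level grouping needs a trailing comma: in Python, parentheses alone do not create a tuple. The snippet below demonstrates the difference in plain Python. How group_by() itself treats a bare string is not stated here, so the helper as_group_keys is a hypothetical way to normalize the argument before passing it on.

```python
# Parentheses without a comma are just grouping: this is a str
print(type(('Extractor')).__name__)   # str

# The trailing comma makes a one-element tuple
print(type(('Extractor',)).__name__)  # tuple


def as_group_keys(group_by):
    """Hypothetical helper: always return a tuple of str."""
    if isinstance(group_by, str):
        return (group_by,)
    return tuple(group_by)


print(as_group_keys('Extractor'))
print(as_group_keys(['Group', 'SubGroup']))
```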
- __str__()
Return str(self).