Ocean Color Science Software

ocssw V2022
metrics.py
from .utils import ignore_warnings
from scipy import stats
import numpy as np
import functools


def validate_shape(func):
    ''' Decorator to flatten all function input arrays, and ensure shapes are the same '''
    @functools.wraps(func)
    def helper(*args, **kwargs):
        flat     = [a.flatten() if hasattr(a, 'flatten') else a for a in args]
        flat_shp = [a.shape for a in flat if hasattr(a, 'shape')]
        orig_shp = [a.shape for a in args if hasattr(a, 'shape')]
        assert all(flat_shp[0] == s for s in flat_shp), f'Shapes mismatch in {func.__name__}: {orig_shp}'
        return func(*flat, **kwargs)
    return helper


def only_finite(func):
    ''' Decorator to remove samples which are nan in any input array '''
    @validate_shape
    @functools.wraps(func)
    def helper(*args, **kwargs):
        stacked = np.vstack(args)
        valid   = np.all(np.isfinite(stacked), 0)
        assert valid.sum(), f'No valid samples exist for {func.__name__} metric'
        return func(*stacked[:, valid], **kwargs)
    return helper


def only_positive(func):
    ''' Decorator to remove samples which are zero/negative in any input array '''
    @validate_shape
    @functools.wraps(func)
    def helper(*args, **kwargs):
        stacked = np.vstack(args)
        valid   = np.all(stacked > 0, 0)
        assert valid.sum(), f'No valid samples exist for {func.__name__} metric'
        return func(*stacked[:, valid], **kwargs)
    return helper


def label(name):
    ''' Label a function to aid in printing '''
    def wrapper(func):
        func.__name__ = name
        return ignore_warnings(func)
    return wrapper


# ============================================================================
'''
When a decorated function is called, execution starts with the outermost
decorator and works its way down the stack; e.g. with
    @dec1
    @dec2
    def foo(): pass
foo is equivalent to dec1(dec2(bar)), where bar() is the same function
without decorators. So calling foo will execute dec1, then dec2, then the
original function. A short commented illustration of this ordering follows
this docstring.

Below, in rmsle (for example), we have:
    rmsle = only_finite( only_positive( label(rmsle) ) )
This means only_positive() will get the input arrays only
after only_finite() removes any nan samples. As well, both
only_positive() and only_finite() will have access to the
function __name__ assigned by label().

For all functions below, y=true and y_hat=estimate
'''
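
# Illustrative sketch (not part of the original module): two hypothetical
# decorators, demo_dec1 and demo_dec2, showing the execution order described
# above. Kept commented out so it has no effect when this file is imported.
#
#   def demo_dec1(func):
#       def inner(*args, **kwargs):
#           print('demo_dec1 runs first')
#           return func(*args, **kwargs)
#       return inner
#
#   def demo_dec2(func):
#       def inner(*args, **kwargs):
#           print('demo_dec2 runs second')
#           return func(*args, **kwargs)
#       return inner
#
#   @demo_dec1
#   @demo_dec2
#   def demo_metric():
#       print('original function runs last')
#
#   demo_metric()
#   # -> demo_dec1 runs first
#   # -> demo_dec2 runs second
#   # -> original function runs last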


@only_finite
@label('RMSE')
def rmse(y, y_hat):
    ''' Root Mean Squared Error '''
    return np.mean((y - y_hat) ** 2) ** 0.5


@only_finite
@only_positive
@label('RMSLE')
def rmsle(y, y_hat):
    ''' Root Mean Squared Logarithmic Error '''
    return np.mean(np.abs(np.log(y) - np.log(y_hat)) ** 2) ** 0.5


@only_finite
@label('NRMSE')
def nrmse(y, y_hat):
    ''' Normalized Root Mean Squared Error '''
    return ((y - y_hat) ** 2).mean() ** 0.5 / y.mean()


@only_finite
@label('MAE')
def mae(y, y_hat):
    ''' Mean Absolute Error '''
    return np.mean(np.abs(y - y_hat))


@only_finite
@label('MAPE')
def mape(y, y_hat):
    ''' Mean Absolute Percentage Error '''
    return 100 * np.mean(np.abs((y - y_hat) / y))


@only_finite
@label('<=0')
def leqz(y, y_hat=None):
    ''' Less than or equal to zero (y_hat) '''
    if y_hat is None: y_hat = y
    return (y_hat <= 0).sum()


@validate_shape
@label('<=0|NaN')
def leqznan(y, y_hat=None):
    ''' Less than or equal to zero, or NaN (y_hat) '''
    if y_hat is None: y_hat = y
    return np.logical_or(np.isnan(y_hat), y_hat <= 0).sum()


@only_finite
@only_positive
@label('MdSA')
def mdsa(y, y_hat):
    ''' Median Symmetric Accuracy '''
    # https://agupubs.onlinelibrary.wiley.com/doi/full/10.1002/2017SW001669
    return 100 * (np.exp(np.median(np.abs(np.log(y_hat / y)))) - 1)


@only_finite
@only_positive
@label('MSA')
def msa(y, y_hat):
    ''' Mean Symmetric Accuracy '''
    # https://agupubs.onlinelibrary.wiley.com/doi/full/10.1002/2017SW001669
    return 100 * (np.exp(np.mean(np.abs(np.log(y_hat / y)))) - 1)


@only_finite
@only_positive
@label('SSPB')
def sspb(y, y_hat):
    ''' Symmetric Signed Percentage Bias '''
    # https://agupubs.onlinelibrary.wiley.com/doi/full/10.1002/2017SW001669
    M = np.median( np.log(y_hat / y) )
    return 100 * np.sign(M) * (np.exp(np.abs(M)) - 1)
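
# Illustrative check (not part of the original module), following the formulas
# above: a constant factor-of-2 overestimate (y_hat = 2 * y) gives
#   MdSA = MSA = 100 * (exp(|ln 2|) - 1) = 100%
#   SSPB = +100%  (positive sign indicates overestimation)
# A constant factor-of-2 underestimate gives the same MdSA/MSA, but SSPB = -100%.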


@only_finite
@label('Bias')
def bias(y, y_hat):
    ''' Mean Bias '''
    return np.mean(y_hat - y)


@only_finite
@only_positive
@label('R^2')
def r_squared(y, y_hat):
    ''' Logarithmic R^2 '''
    slope_, intercept_, r_value, p_value, std_err = stats.linregress(np.log10(y), np.log10(y_hat))
    return r_value ** 2


@only_finite
@only_positive
@label('Slope')
def slope(y, y_hat):
    ''' Logarithmic slope '''
    slope_, intercept_, r_value, p_value, std_err = stats.linregress(np.log10(y), np.log10(y_hat))
    return slope_


@only_finite
@only_positive
@label('Intercept')
def intercept(y, y_hat):
    ''' Logarithmic intercept '''
    slope_, intercept_, r_value, p_value, std_err = stats.linregress(np.log10(y), np.log10(y_hat))
    return intercept_


@validate_shape
@label('MWR')
def mwr(y, y_hat, y_bench):
    '''
    Model Win Rate - Percent of samples in which model has a closer
    estimate than the benchmark.
    y: true, y_hat: model, y_bench: benchmark
    '''
    # Treat negative values as invalid (validate_shape passes in flattened
    # copies of array inputs, so the caller's arrays are not modified)
    y_bench[y_bench < 0] = np.nan
    y_hat[y_hat < 0] = np.nan
    y[y < 0] = np.nan

    # Compare model and benchmark only where both estimates are valid
    valid = np.logical_and(np.isfinite(y_hat), np.isfinite(y_bench))
    diff1 = np.abs(y[valid] - y_hat[valid])
    diff2 = np.abs(y[valid] - y_bench[valid])

    # "wins" renamed from "stats" to avoid shadowing the scipy.stats import
    wins = np.zeros(len(y))
    wins[valid] = diff1 < diff2
    wins[~np.isfinite(y_bench)] = 1   # model wins by default if the benchmark is invalid
    wins[~np.isfinite(y_hat)]   = 0   # model loses if its own estimate is invalid
    return wins.sum() / np.isfinite(y).sum()
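
# Illustrative check (not part of the original module): with
#   y = np.array([1.0, 1.0]), y_hat = np.array([1.1, 3.0]), y_bench = np.array([2.0, 1.5])
# the model is closer than the benchmark on the first sample only, so
# mwr(y, y_hat, y_bench) == 0.5.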


def performance(key, y, y_hat, metrics=[mdsa, sspb, slope, msa, rmsle, mae, leqznan], csv=False):
    ''' Return a string containing performance using various metrics.
        y should be the true value, y_hat the estimated value. '''
    y     = y.flatten()
    y_hat = y_hat.flatten()
    try:
        if csv: return f'{key},' + ','.join([f'{f.__name__}:{f(y, y_hat)}' for f in metrics])
        else:   return f'{key:>12} | ' + ' '.join([f'{f.__name__}: {f(y, y_hat):>6.3f}' for f in metrics])
    except Exception as e: return f'{key:>12} | Exception: {e}'
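
# Example usage (illustrative sketch, not part of the original module; it
# assumes metrics.py is imported through its containing package so that the
# relative import `from .utils import ignore_warnings` resolves):
#
#   import numpy as np
#   from <package>.metrics import performance, mdsa, sspb   # adjust package path
#
#   rng   = np.random.default_rng(42)
#   y     = 10 ** rng.uniform(-2, 1, 1000)        # synthetic "true" values
#   y_hat = y * 10 ** rng.normal(0, 0.1, 1000)    # noisy estimates (log-space error)
#
#   print(performance('my_model', y, y_hat))      # one-line summary across several metrics
#   print(mdsa(y, y_hat), sspb(y, y_hat))         # or call individual metrics directly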