Ocean Color Science Software
ocssw V2022
MDN.py
import os, warnings, logging
logging.getLogger("tensorflow").setLevel(logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

from pathlib import Path
from tqdm.keras import TqdmCallback
from tqdm import trange

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from ..transformers import IdentityTransformer
from ..utils import read_pkl, store_pkl, ignore_warnings

from .callbacks import PlottingCallback, StatsCallback, ModelCheckpoint
from .utils import initialize_random_states, ensure_format, get_device
from .metrics import MSA

class MDN:
    ''' Mixture Density Network which handles multi-output, full (symmetric) covariance.

    Parameters
    ----------
    n_mix : int, optional (default=5)
        Number of mixtures used in the Gaussian mixture model.

    hidden : list, optional (default=[100, 100, 100, 100, 100])
        Number of layers and hidden units per layer in the neural network.

    lr : float, optional (default=1e-3)
        Learning rate for the model.

    l2 : float, optional (default=1e-3)
        L2 regularization scale for the model weights.

    n_iter : int, optional (default=1e4)
        Number of iterations to train the model for.

    batch : int, optional (default=128)
        Size of the minibatches for stochastic optimization.

    imputations : int, optional (default=5)
        Number of samples used in multiple imputation when handling NaN
        target values during training. More samples result in a more
        accurate likelihood estimate, but take longer and may result in
        overfitting. The assumption is that any missing data are
        MAR / MCAR, in order to allow a multiple imputation approach.

    epsilon : float, optional (default=1e-3)
        Normalization constant added to the diagonal of the covariance matrix.

    activation : str, optional (default='relu')
        Activation function applied to hidden layers.

    scalerx : transformer, optional (default=IdentityTransformer)
        Transformer which has fit, transform, and inverse_transform methods
        (i.e. follows the format of sklearn transformers). Scales the x
        values prior to training / prediction. Stored along with the saved
        model in order to have consistent inputs to the model.

    scalery : transformer, optional (default=IdentityTransformer)
        Transformer which has fit, transform, and inverse_transform methods
        (i.e. follows the format of sklearn transformers). Scales the y
        values prior to training, and the output values after prediction.
        Stored along with the saved model in order to have consistent
        outputs from the model.

    model_path : pathlib.Path, optional (default=./Weights)
        Folder location to store saved models.

    model_name : str, optional (default=MDN)
        Name to assign to the model.

    no_load : bool, optional (default=False)
        If true, train a new model rather than loading a previously
        trained one.

    no_save : bool, optional (default=False)
        If true, do not save the model when training is completed.

    seed : int, optional (default=None)
        Random seed. If set, ensures consistent output.

    verbose : bool, optional (default=False)
        If true, print various information while loading / training.

    debug : bool, optional (default=False)
        If true, use control flow dependencies to determine where NaN
        values are entering the model. The model runs slower with this
        parameter set to true.
    '''
    distribution = 'MultivariateNormalTriL'

    def __init__(self, n_mix=5, hidden=[100]*5, lr=1e-3, l2=1e-3, n_iter=1e4,
                 batch=128, imputations=5, epsilon=1e-3,
                 activation='relu',
                 scalerx=None, scalery=None,
                 model_path='Weights', model_name='MDN',
                 no_load=False, no_save=False,
                 seed=None, verbose=False, debug=False, **kwargs):

        config = initialize_random_states(seed)
        config.update({
            'n_mix'       : n_mix,
            'hidden'      : list(np.atleast_1d(hidden)),
            'lr'          : lr,
            'l2'          : l2,
            'n_iter'      : n_iter,
            'batch'       : batch,
            'imputations' : imputations,
            'epsilon'     : epsilon,
            'activation'  : activation,
            'scalerx'     : scalerx if scalerx is not None else IdentityTransformer(),
            'scalery'     : scalery if scalery is not None else IdentityTransformer(),
            'model_path'  : Path(model_path),
            'model_name'  : model_name,
            'no_load'     : no_load,
            'no_save'     : no_save,
            'seed'        : seed,
            'verbose'     : verbose,
            'debug'       : debug,
        })
        self.set_config(config)

        for k in kwargs:
            warnings.warn(f'Unused keyword given to MDN: "{k}"', UserWarning)

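    # A minimal usage sketch (illustrative, not part of the original file);
    # assumes X_train, Y_train, and X_test are numpy arrays with samples
    # along the first axis:
    #
    #     model = MDN(n_mix=5, hidden=[100]*5, model_name='demo')
    #     model.fit(X_train, Y_train)
    #     estimates = model.predict(X_test)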

    def _predict_chunk(self, X, return_coefs=False, use_gpu=False, **kwargs):
        ''' Generates estimates for the given set. X may be only a subset of the full
        data, which speeds up the prediction process and limits memory consumption.

        use_gpu : bool, optional (default=False)
            Use the GPU to generate estimates if True, otherwise use the CPU.
        '''
        with tf.device('/gpu:0' if use_gpu else '/cpu:0'):
            model_out = self.model( self.scalerx.transform(ensure_format(X)) )
            coefs_out = self.get_coefs(model_out)
            outputs   = self.extract_predictions(coefs_out, **kwargs)

        if return_coefs:
            return outputs, [c.numpy() for c in coefs_out]
        return outputs


    @ignore_warnings
    def predict(self, X, chunk_size=1e5, return_coefs=False, **kwargs):
        '''
        Top level interface to get predictions for a given dataset, which wraps _predict_chunk
        to generate estimates in smaller chunks. See the docstring of extract_predictions() for
        a description of other keyword parameters that can be given.

        chunk_size : int, optional (default=1e5)
            Controls the size of chunks which are estimated by the model. If None is passed,
            chunking is not used and the model is given all of the X dataset at once.

        return_coefs : bool, optional (default=False)
            If True, return the estimated coefficients (prior, mu, sigma) along with the
            other requested outputs. Note that rescaling the coefficients using scalerx/y
            is left up to the user, as calculations involving sigma must be performed in
            the basis learned by the model.
        '''
        chunk_size    = int(chunk_size or len(X))
        partial_coefs = []
        partial_estim = []

        for i in trange(0, len(X), chunk_size, disable=not self.verbose):
            chunk_est, chunk_coef = self._predict_chunk(X[i:i+chunk_size], return_coefs=True, **kwargs)
            partial_coefs.append(chunk_coef)
            partial_estim.append( np.array(chunk_est, ndmin=3) )

        coefs = [np.vstack(c) for c in zip(*partial_coefs)]
        preds = np.hstack(partial_estim)

        if return_coefs:
            return preds, coefs
        return preds

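    # Example (sketch): keyword arguments are forwarded to extract_predictions(),
    # so confidence bounds can be requested directly; the three arrays unpack
    # along the first axis of the stacked output:
    #
    #     preds, upper, lower = model.predict(X_test, confidence_interval=0.9)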

    def extract_predictions(self, coefs, confidence_interval=None, threshold=None, avg_est=False):
        '''
        Function used to extract model predictions from the given set of
        coefficients. Users should call the predict() method instead, if
        predictions from input data are needed.

        confidence_interval : float, optional (default=None)
            If a confidence interval value is given, then this function
            returns (along with the predictions) the upper and lower
            {confidence_interval*100}% confidence bounds around the prediction.

        threshold : float, optional (default=None)
            If set, the model outputs the maximum prior estimate when the prior
            probability is above this threshold, and outputs the average estimate
            when below the threshold. Any passed value should be in the range (0, 1],
            though the sign of the threshold can be negative in order to switch the
            estimates (i.e. a negative threshold outputs the average estimate when
            the prior is greater than the (absolute) value).

        avg_est : bool, optional (default=False)
            If true, the model outputs the prior probability weighted mean as the
            estimate. Otherwise, the model outputs the maximum prior estimate.
        '''
        assert(confidence_interval is None or (0 < confidence_interval < 1)), 'confidence_interval must be in the range (0,1)'
        assert(threshold is None or (0 < abs(threshold) <= 1)), 'threshold magnitude must be in the range (0,1]'

        target = ('avg' if avg_est else 'top') if threshold is None else 'threshold'
        kwargs = {} if threshold is None else {'threshold': threshold}
        output = getattr(self, f'_get_{target}_estimate')(coefs, **kwargs)
        scale  = lambda x: self.scalery.inverse_transform(x.numpy())

        if confidence_interval is not None:
            assert(threshold is None), 'Cannot calculate confidence on thresholded estimates'
            confidence = getattr(self, f'_get_{target}_confidence')(coefs, confidence_interval)
            upper_bar  = output + confidence
            lower_bar  = output - confidence
            return scale(output), scale(upper_bar), scale(lower_bar)
        return scale(output)

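    # Example (sketch): with threshold=0.8, samples whose maximum mixture prior
    # exceeds 0.8 receive the top estimate, and all others receive the
    # prior-weighted average:
    #
    #     estimates = model.predict(X_test, threshold=0.8)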

    @ignore_warnings
    def fit(self, X, Y, output_slices=None, **kwargs):
        with get_device(self.config):
            checkpoint = self.model_path.joinpath('checkpoint')

            if checkpoint.exists() and not self.no_load:
                if self.verbose: print(f'Restoring model weights from {checkpoint}')
                self.load()

            elif self.no_load and X is None:
                raise Exception('Model exists, but no_load is set and no training data was given.')

            elif X is not None and Y is not None:
                self.scalerx.fit( ensure_format(X), ensure_format(Y) )
                self.scalery.fit( ensure_format(Y) )

                # Gather all data (train, validation, test, ...) into a single object
                datasets = kwargs['datasets'] = kwargs.get('datasets', {})
                datasets.update({'train': {'x' : X, 'y': Y}})

                for key, data in datasets.items():
                    if data['x'] is not None:
                        datasets[key].update({
                            'x_t' : self.scalerx.transform( ensure_format(data['x']) ),
                            'y_t' : self.scalery.transform( ensure_format(data['y']) ),
                        })
                assert(np.isfinite(datasets['train']['x_t']).all()), 'NaN values found in X training data'

                self.update_config({
                    'output_slices' : output_slices or {'': slice(None)},
                    'n_inputs'      : datasets['train']['x_t'].shape[1],
                    'n_targets'     : datasets['train']['y_t'].shape[1],
                })
                self.build()

                callbacks = []
                model_kws = {
                    'batch_size' : self.batch,
                    'epochs'     : max(1, int(self.n_iter / max(1, len(X) / self.batch))),
                    'verbose'    : 0,
                    'callbacks'  : callbacks,
                }

                if self.verbose:
                    callbacks.append( TqdmCallback(model_kws['epochs'], data_size=len(X), batch_size=self.batch) )

                if self.debug:
                    callbacks.append( tf.keras.callbacks.TensorBoard(histogram_freq=1, profile_batch=(2,60)) )

                if 'args' in kwargs:

                    if getattr(kwargs['args'], 'plot_loss', False):
                        callbacks.append( PlottingCallback(kwargs['args'], datasets, self) )

                    if getattr(kwargs['args'], 'save_stats', False):
                        callbacks.append( StatsCallback(kwargs['args'], datasets, self) )

                    if getattr(kwargs['args'], 'best_epoch', False):
                        if 'valid' in datasets and 'x_t' in datasets['valid']:
                            model_kws['validation_data'] = (datasets['valid']['x_t'], datasets['valid']['y_t'])
                            callbacks.append( ModelCheckpoint(self.model_path) )

                self.model.fit(datasets['train']['x_t'], datasets['train']['y_t'], **model_kws)

                if not self.no_save:
                    self.save()

            else:
                raise Exception(f"No trained model exists at: \n{self.model_path}")
        return self
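    # Example (sketch): extra splits can be passed through the `datasets` keyword;
    # fit() merges them with the training data and scales each with the stored
    # transformers (a 'valid' split is also used for best-epoch checkpointing
    # when args.best_epoch is requested):
    #
    #     model.fit(X_train, Y_train, datasets={'valid': {'x': X_val, 'y': Y_val}})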


    def build(self):
        layer_kwargs = {
            'activation'         : self.activation,
            'kernel_regularizer' : tf.keras.regularizers.l2(self.l2),
            'bias_regularizer'   : tf.keras.regularizers.l2(self.l2),
            # 'kernel_initializer' : tf.keras.initializers.LecunNormal(),
            # 'bias_initializer'   : tf.keras.initializers.LecunNormal(),
        }
        mixture_kwargs = {
            'n_mix'     : self.n_mix,
            'n_targets' : self.n_targets,
            'epsilon'   : self.epsilon,
        }
        mixture_kwargs.update(layer_kwargs)

        create_layer = lambda inp, out: tf.keras.layers.Dense(out, input_shape=(inp,), **layer_kwargs)
        model_layers = [create_layer(inp, out) for inp, out in zip([self.n_inputs] + self.hidden[:-1], self.hidden)]
        output_layer = MixtureLayer(**mixture_kwargs)

        # Define yscaler.inverse_transform as a tensorflow function, and estimate extraction from outputs
        # yscaler_a = self.scalery.scalers[-1].min_
        # yscaler_b = self.scalery.scalers[-1].scale_
        # inv_scaler = lambda y: tf.math.exp((tf.reshape(y, shape=[-1]) - yscaler_a) / yscaler_b)
        # extract_est = lambda z: self._get_top_estimate( self._parse_outputs(z) )

        # model_layers = [tf.keras.Input(shape=(self.batch, self.n_inputs))] + model_layers
        def debug(x):
            # tf.print evaluates tensor arguments at runtime, unlike a Python
            # f-string (which would only print symbolic tensors at trace time)
            tf.print('\nShape:', tf.shape(x), 'Min:', tf.reduce_min(x), 'Max:', tf.reduce_max(x))
            tf.print(x)
            return x
        # model_layers = [tf.keras.layers.Lambda(debug)] + model_layers

        optimizer  = tf.keras.optimizers.Adam(self.lr)
        self.model = tf.keras.Sequential(model_layers + [output_layer], name=self.model_name)
        self.model.compile(loss=self.loss, optimizer=optimizer, metrics=[])  # metrics=[MSA(extract_est, inv_scaler)]
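    # For reference: with K = n_mix and T = n_targets, the MixtureLayer appended
    # above emits K * (1 + T + T*(T+1)//2) values per sample -- the mixture
    # priors, means, and lower-triangular scale entries (see MixtureLayer below).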


    @tf.function
    def loss(self, y, output):
        prior, mu, scale = self._parse_outputs(output)
        dist = getattr(tfp.distributions, self.distribution)(mu, scale)
        prob = tfp.distributions.Categorical(probs=prior)
        mix  = tfp.distributions.MixtureSameFamily(prob, dist)

        def impute(mix, y, N):
            # summation  = tf.zeros(tf.shape(y)[0])
            # imputation = lambda i, s: [i+1, tf.add(s, mix.log_prob(tf.where(tf.math.is_nan(y), mix.sample(), y)))]
            # return tf.while_loop(lambda i, x: i < N, imputation, (0, summation), maximum_iterations=N, parallel_iterations=N)[1] / N
            return tf.reduce_mean([
                mix.log_prob( tf.where(tf.math.is_nan(y), mix.sample(), y) )
            for _ in range(N)], 0)

        # Much slower due to cond executing both branches regardless of the conditional
        likelihood = tf.cond(tf.reduce_any(tf.math.is_nan(y)), lambda: impute(mix, y, self.imputations), lambda: mix.log_prob(y))
        # likelihood = mix.log_prob(y)
        return tf.reduce_mean(-likelihood) + tf.add_n([0.] + self.model.losses)

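    # Note: the quantity minimized above is the mean negative log-likelihood of
    # the Gaussian mixture, -log p(y|x) = -log( sum_k prior_k * N(y; mu_k, Sigma_k) ),
    # plus the accumulated L2 regularization terms collected in self.model.losses.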

    def __call__(self, inputs):
        return self.model(inputs)


    def get_config(self):
        return self.config


    def set_config(self, config, *args, **kwargs):
        self.config = {}
        self.update_config(config, *args, **kwargs)


    def update_config(self, config, keys=None):
        if keys is not None:
            config = {k: v for k, v in config.items() if k in keys or k not in self.config}

        self.config.update(config)
        for k, v in config.items():
            setattr(self, k, v)


    def save(self):
        self.model_path.mkdir(parents=True, exist_ok=True)
        store_pkl(self.model_path.joinpath('config.pkl'), self.get_config())
        self.model.save_weights(self.model_path.joinpath('checkpoint'))


    def load(self):
        self.update_config(read_pkl(self.model_path.joinpath('config.pkl')), ['scalerx', 'scalery', 'tf_random', 'np_random'])
        tf.random.set_global_generator(self.tf_random)
        if not hasattr(self, 'model'): self.build()
        self.model.load_weights(self.model_path.joinpath('checkpoint')).expect_partial()


    def get_coefs(self, output):
        prior, mu, scale = self._parse_outputs(output)
        return prior, mu, self._covariance(scale)


    def _parse_outputs(self, output):
        prior, mu, scale = tf.split(output, [self.n_mix, self.n_mix * self.n_targets, -1], axis=1)
        prior = tf.reshape(prior, shape=[-1, self.n_mix])
        mu    = tf.reshape(mu,    shape=[-1, self.n_mix, self.n_targets])
        scale = tf.reshape(scale, shape=[-1, self.n_mix, self.n_targets, self.n_targets])
        return prior, mu, scale

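    # Shapes after _parse_outputs: prior is (batch, n_mix); mu is
    # (batch, n_mix, n_targets); scale is (batch, n_mix, n_targets, n_targets).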

    def _covariance(self, scale):
        return tf.einsum('abij,abjk->abik', tf.transpose(scale, perm=[0,1,3,2]), scale)

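    # Note: this builds the symmetric product L^T L from the lower-triangular
    # scale factor; MultivariateNormalTriL itself defines its covariance as
    # L L^T, which has the same eigenvalues (and hence the same singular values
    # used by _calculate_confidence below).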


    '''
    Estimate Generation
    '''
    def _calculate_top(self, prior, values):
        _, idxs = tf.nn.top_k(prior, k=1)
        idxs = tf.stack([tf.range(tf.shape(idxs)[0]), tf.reshape(idxs, [-1])], axis=-1)
        return tf.gather_nd(values, idxs)

    def _get_top_estimate(self, coefs, **kwargs):
        prior, mu, _ = coefs
        return self._calculate_top(prior, mu)

    def _get_avg_estimate(self, coefs, **kwargs):
        prior, mu, _ = coefs
        return tf.reduce_sum(mu * tf.expand_dims(prior, -1), 1)

    def _get_threshold_estimate(self, coefs, threshold=0.5):
        top_estimate = self._get_top_estimate(coefs)
        avg_estimate = self._get_avg_estimate(coefs)
        prior, _, _  = coefs
        return tf.where(tf.expand_dims(tf.math.greater(tf.reduce_max(prior, 1) / threshold, tf.math.sign(threshold)), -1), top_estimate, avg_estimate)

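    # In _get_threshold_estimate, prior_max / threshold > sign(threshold) selects
    # the top estimate when prior_max exceeds a positive threshold; a negative
    # threshold flips the inequality, so the average estimate is used whenever
    # prior_max exceeds |threshold|.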

    '''
    Confidence Estimation
    '''
    def _calculate_confidence(self, sigma, level=0.9):
        # For a given confidence level probability p (0 < p < 1) and number of dimensions d,
        # rho is the error bar coefficient: rho = sqrt(2) * erfinv(p ** (1/d))
        # https://faculty.ucmerced.edu/mcarreira-perpinan/papers/cs-99-03.pdf
        s, u, v = tf.linalg.svd(sigma)
        rho = 2**0.5 * tf.math.erfinv(level ** (1. / self.n_targets))
        return tf.cast(rho, tf.float32) * 2 * s ** 0.5

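    # For _calculate_confidence: e.g. with level=0.9 and n_targets=1,
    # rho = sqrt(2) * erfinv(0.9) ~= 1.645, the familiar 90% two-sided z-score;
    # the returned bound is rho * 2 * sqrt(s) along each principal axis of sigma.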
    def _get_top_confidence(self, coefs, level=0.9):
        prior, mu, sigma = coefs
        top_sigma = self._calculate_top(prior, sigma)
        return self._calculate_confidence(top_sigma, level)

    def _get_avg_confidence(self, coefs, level=0.9):
        prior, mu, sigma = coefs
        avg_estim = self._get_avg_estimate(coefs)
        deviation = mu - tf.expand_dims(avg_estim, 1)  # (batch, n_mix, n_targets)
        avg_sigma = tf.reduce_sum(tf.expand_dims(tf.expand_dims(prior, -1), -1) *
                        (sigma + tf.einsum('nki,nkj->nkij', deviation, deviation)), axis=1)
        return self._calculate_confidence(avg_sigma, level)

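    # The average above moment-matches the mixture to a single Gaussian:
    # Sigma_avg = sum_k prior_k * (Sigma_k + (mu_k - mu_bar)(mu_k - mu_bar)^T).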


class MixtureLayer(tf.keras.layers.Layer):

    def __init__(self, n_mix, n_targets, epsilon, **layer_kwargs):
        super().__init__()
        layer_kwargs.pop('activation', None)

        self.n_mix     = n_mix
        self.n_targets = n_targets
        self.epsilon   = tf.constant(epsilon)
        self._layer    = tf.keras.layers.Dense(self.n_outputs, **layer_kwargs)


    @property
    def layer_sizes(self):
        ''' Sizes of the prior, mu, and (lower triangle) scale matrix outputs '''
        sizes = [1, self.n_targets, (self.n_targets * (self.n_targets + 1)) // 2]
        return self.n_mix * np.array(sizes)


    @property
    def n_outputs(self):
        ''' Total output size of the layer object '''
        return sum(self.layer_sizes)

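    # Worked example: with n_mix=5 and n_targets=3, layer_sizes is
    # 5 * [1, 3, 3*4//2] = [5, 15, 30], so the underlying Dense layer
    # emits n_outputs = 50 values per sample.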

    # @tf.function(experimental_compile=True)
    def call(self, inputs):
        prior, mu, scale = tf.split(self._layer(inputs), self.layer_sizes, axis=1)

        prior = tf.nn.softmax(prior, axis=-1) + tf.constant(1e-9)
        mu    = tf.stack(tf.split(mu,    self.n_mix, 1), 1)
        scale = tf.stack(tf.split(scale, self.n_mix, 1), 1)
        scale = tfp.math.fill_triangular(scale, upper=False)
        norm  = tf.linalg.diag(tf.ones((1, 1, self.n_targets)))
        sigma = tf.einsum('abij,abjk->abik', scale, tf.transpose(scale, perm=[0,1,3,2]))
        sigma += self.epsilon * norm
        scale = tf.linalg.cholesky(sigma)

        return tf.keras.layers.concatenate([
            tf.reshape(prior, shape=[-1, self.n_mix]),
            tf.reshape(mu,    shape=[-1, self.n_mix * self.n_targets]),
            tf.reshape(scale, shape=[-1, self.n_mix * self.n_targets ** 2]),
        ])
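    # Note on call(): the raw lower-triangular output is squared into a valid
    # covariance (L L^T), epsilon * I is added to keep it positive-definite,
    # and a Cholesky factorization is taken so MultivariateNormalTriL receives
    # a proper scale_tril; everything is then flattened back into one tensor.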