OB.DAAC Logo
NASA Logo
Ocean Color Science Software

ocssw V2022
utils.py
Go to the documentation of this file.
1 import numpy as np
2 import tensorflow as tf
3 
4 
def initialize_random_states(seed=None):
    ''' Initialize the numpy and tensorflow random states, setting the tensorflow global random state
    since most tensorflow methods don't yet pass around a random state appropriately. TF random
    states might also not play nice with tf.functions:
    https://www.tensorflow.org/api_docs/python/tf/random/set_global_generator

    A numpy RandomState is built from `seed`. A tensorflow Generator is seeded from
    one draw of that RandomState, so seed=None also works for tensorflow.

    Side effect: replaces tensorflow's global random generator
    (tf.random.set_global_generator).

    Parameters
    ----------
    seed : optional
        Seed forwarded to np.random.RandomState (None means unseeded).

    Returns
    -------
    dict
        {'np_random': the np.random.RandomState, 'tf_random': the tf.random.Generator}
    '''
    np_random = np.random.RandomState(seed)
    # randint's bounds are integers: use the exact integer 10**10 (== int(1e10)) rather
    # than relying on implicit coercion of the float literal 1e10. The draw range is unchanged.
    tf_seed = np_random.randint(10**10, dtype=np.int64)  # tf can't take None as seed
    tf_random = tf.random.Generator.from_seed(tf_seed)
    tf.random.set_global_generator(tf_random)
    return {'np_random' : np_random, 'tf_random' : tf_random}
16 
17 
def ensure_format(arr):
    ''' Ensure passed array has two dimensions [n_sample, n_feature], and add the n_feature axis if not.

    Parameters
    ----------
    arr : array_like
        Input data; it is converted to a new float32 numpy array.

    Returns
    -------
    np.ndarray
        A float32 copy of `arr` that never shares memory with the input.
        1-D input of shape (n,) becomes shape (n, 1). Input of any other
        rank (including 0-d and >2-d) keeps its rank unchanged.
    '''
    # np.array copies by default, so the result never aliases the caller's data.
    # Converting straight to float32 avoids the two extra full copies made by
    # chaining .copy() and .astype().
    arr = np.array(arr, dtype=np.float32)
    return arr[:, None] if arr.ndim == 1 else arr
22 
23 
def get_device(model_config):
    ''' Return the tf.device a job should run on. Logic based
    on e.g. model size may be added in the future.

    Execution is currently always pinned to the first CPU ('/cpu:0').

    Parameters
    ----------
    model_config :
        Model configuration. It is not read yet; the parameter is kept so
        that future device-selection logic can use it.

    Returns
    -------
    The tf.device('/cpu:0') context manager.
    '''
    # NOTE(review): automatic device selection (e.g. picking the first entry of
    # tf.config.list_physical_devices('GPU') + ('CPU')) is not implemented. Any
    # computed device name would be ignored, so execution is pinned to CPU here.
    return tf.device('/cpu:0')
def initialize_random_states(seed=None)
Definition: utils.py:5
def ensure_format(arr)
Definition: utils.py:18
void copy(double **aout, double **ain, int n)
def get_device(model_config)
Definition: utils.py:24