Due to the lapse in federal government funding, NASA is not updating this website. We sincerely regret this inconvenience.
NASA Logo
Ocean Color Science Software

ocssw V2022
callbacks.py
Go to the documentation of this file.
import time
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import tensorflow as tf

from .TrainingPlot import TrainingPlot
from ..metrics import mdsa, sspb
9 
10 
class PlottingCallback(tf.keras.callbacks.Callback):
    ''' Display a real-time training progress plot '''

    def __init__(self, args, data, model):
        super().__init__()
        self._step_count = 0
        self.args = args

        # Build and initialize the live plot for this model / dataset pair
        self.TP = TrainingPlot(args, model, data)
        self.TP.setup()

    def on_train_batch_end(self, batch, logs=None):
        ''' Redraw the plot every (n_iter // n_redraws) training steps '''
        self._step_count += 1
        redraw_interval = self.args.n_iter // self.args.n_redraws
        if self._step_count % redraw_interval == 0:
            self.TP.update()

    def on_train_end(self, *args, **kwargs):
        ''' Finalize the plot once training is complete '''
        self.TP.finish()
28 
29 
30 
class StatsCallback(tf.keras.callbacks.Callback):
    ''' Save performance statistics as the model is trained.

    Periodically (every n_iter // n_redraws steps) runs the model over all
    datasets in `data`, computes the given metrics against the targets, and
    appends the results to per-round CSV files under `folder`.

    Parameters
    ----------
    args    : namespace with n_iter, n_redraws, config_name, curr_round
    data    : dict mapping dataset name -> {'x': features, 'y': targets}
    mdn     : model object exposing .predict(x)
    metrics : iterable of metric functions f(y_true, y_est) -> scalar
    folder  : root directory for the saved statistics
    '''

    # NOTE: default is a tuple rather than a list to avoid the shared
    # mutable-default-argument pitfall; `time` must be imported at module level.
    def __init__(self, args, data, mdn, metrics=(mdsa, sspb), folder='Results_gpu'):
        super().__init__()
        self._step_count = 0
        self.start_time = time.time()
        self.args = args
        self.data = data
        self.mdn = mdn
        self.metrics = metrics
        self.folder = folder

    def on_train_batch_end(self, batch, logs=None):
        ''' Every (n_iter // n_redraws) steps, estimate all datasets and log stats '''
        if (self._step_count % (self.args.n_iter // self.args.n_redraws)) == 0:
            all_keys = sorted(self.data.keys())
            all_data = [self.data[k]['x'] for k in all_keys]

            # Starting offset of each dataset within the stacked feature matrix,
            # so the single batched prediction can be split back apart
            all_sums = np.cumsum(list(map(len, [[]] + all_data[:-1])))
            all_idxs = [slice(c, len(d) + c) for c, d in zip(all_sums, all_data)]
            all_data = np.vstack(all_data)

            # Create all estimates at once, then split back into the original datasets
            estimates = self.mdn.predict(all_data)
            estimates = {k: estimates[idxs] for k, idxs in zip(all_keys, all_idxs)}
            assert(all([estimates[k].shape == self.data[k]['y'].shape for k in all_keys])), \
                [(estimates[k].shape, self.data[k]['y'].shape) for k in all_keys]

            save_folder = Path(self.folder, self.args.config_name).resolve()
            if not save_folder.exists():
                print(f'\nSaving training results at {save_folder}\n')
                save_folder.mkdir(parents=True, exist_ok=True)

            self._write_round_stats(save_folder, all_keys, estimates)
            self._write_estimates(save_folder.joinpath('Estimates'), all_keys, estimates)
        self._step_count += 1

    def _write_round_stats(self, save_folder, all_keys, estimates):
        ''' Append overall per-dataset metric statistics for the current round '''
        round_stats_file = save_folder.joinpath(f'round_{self.args.curr_round}.csv')

        # Write the CSV header on first use
        if not round_stats_file.exists() or self._step_count == 0:
            with round_stats_file.open('w+') as fn:
                fn.write(','.join(['iteration', 'cumulative_time'] + [f'{k}_{m.__name__}' for k in all_keys for m in self.metrics]) + '\n')

        # One bracketed list of per-output metric values per (dataset, metric) pair
        stats = [[str(m(y1, y2)) for y1, y2 in zip(self.data[k]['y'].T, estimates[k].T)] for k in all_keys for m in self.metrics]
        stats = ','.join([f'[{s}]' for s in [','.join(stat) for stat in stats]])
        with round_stats_file.open('a+') as fn:
            fn.write(f'{self._step_count},{time.time() - self.start_time},{stats}\n')

    def _write_estimates(self, save_folder, all_keys, estimates):
        ''' Append the first-output model estimates for each dataset '''
        if not save_folder.exists():
            save_folder.mkdir(parents=True, exist_ok=True)

        for k in all_keys:
            filename = save_folder.joinpath(f'round_{self.args.curr_round}_{k}.csv')

            # Write the target values as the first line of a new file
            if not filename.exists():
                with filename.open('w+') as fn:
                    fn.write(f'target,{list(self.data[k]["y"][:,0])}\n')

            with filename.open('a+') as fn:
                fn.write(f'{self._step_count},{list(estimates[k][:,0])}\n')
88 
89 
90 
class ModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    ''' Save models during training, and load the best performing
        on the validation set once training is completed.
        Currently untested.
    '''

    def __init__(self, path):
        # Checkpoints live in a temporary directory under `path`,
        # which is removed once the best weights are restored
        Path(path).mkdir(exist_ok=True, parents=True)
        self.tmp_folder = TemporaryDirectory(dir=path)
        self.checkpoint = Path(self.tmp_folder.name).joinpath('checkpoint')

        super().__init__(
            filepath         = self.checkpoint,
            save_weights_only= True,
            monitor          = 'val_MSA',  # need to add to metrics
            mode             = 'min',
            save_best_only   = True,
        )

    def on_train_end(self, *args, **kwargs):
        # Restore the best weights seen so far, then discard the temp directory
        self.model.load_weights(self.checkpoint)
        self.tmp_folder.cleanup()
108 
109 
110 
class DecayHistory(tf.keras.callbacks.Callback):
    ''' Verify tf parameters are being decayed as they should;
        call show_plot() on object once training is completed '''

    def on_train_begin(self, logs=None):
        # Per-step histories of the learning rate and weight decay values
        self.lr = []
        self.wd = []

    def on_batch_end(self, batch, logs=None):
        # Snapshot the current scalar values. Appending the optimizer attribute
        # directly would store N references to the same tf.Variable, so every
        # history entry would show only the final value and the plot would be flat.
        self.lr.append(float(self.model.optimizer.lr))
        self.wd.append(float(self.model.optimizer.weight_decay))

    def show_plot(self):
        ''' Plot the recorded learning rate / weight decay schedules '''
        import matplotlib.pyplot as plt
        plt.plot(self.lr, label='learning rate')
        plt.plot(self.wd, label='weight decay')
        plt.xlabel('step')
        plt.ylabel('param value')
        plt.legend()
        plt.show()
Definition: setup.py:1
def __init__(self, args, data, model)
Definition: callbacks.py:14
def on_train_batch_end(self, batch, logs=None)
Definition: callbacks.py:44
def on_train_end(self, *args, **kwargs)
Definition: callbacks.py:105
def on_train_begin(self, logs={})
Definition: callbacks.py:115
def on_batch_end(self, batch, logs={})
Definition: callbacks.py:119
def on_train_batch_end(self, batch, logs=None)
Definition: callbacks.py:21
list(APPEND LIBS ${NETCDF_LIBRARIES}) find_package(GSL REQUIRED) include_directories($
Definition: CMakeLists.txt:8
void print(std::ostream &stream, const char *format)
Definition: PrintDebug.hpp:38
def on_train_end(self, *args, **kwargs)
Definition: callbacks.py:26
def __init__(self, args, data, mdn, metrics=[mdsa, sspb], folder='Results_gpu')
Definition: callbacks.py:34
Definition: aerosol.c:136