Ocean Color Science Software

ocssw V2022
utils.py
1 from .meta import get_sensor_bands, ANCILLARY, PERIODIC
2 from .parameters import update, hypers, flags, get_args
3 from .__version__ import __version__
4 
5 from collections import defaultdict as dd
6 from importlib import import_module
7 from datetime import datetime as dt
8 from pathlib import Path
9 from tqdm import trange
10 
11 import pickle as pkl
12 import numpy as np
13 import hashlib, re, warnings, functools, sys, zipfile
14 
15 
16 def ignore_warnings(func):
17  ''' Decorator to silence all warnings (Runtime, User, Deprecation, etc.) '''
18  @functools.wraps(func)
19  def helper(*args, **kwargs):
20  with warnings.catch_warnings():
21  warnings.filterwarnings('ignore')
22  return func(*args, **kwargs)
23  return helper
24 
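# Example of the decorator in use (a sketch; 'nan_mean' is a made-up helper):
#   @ignore_warnings
#   def nan_mean(values):
#       return np.nanmean(values)       # would normally warn on an all-NaN slice
#   nan_mean([np.nan, np.nan])          # -> nan, with the RuntimeWarning silenced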
25 
26 def find_wavelength(k, waves, validate=True, tol=5):
27  ''' Index of closest wavelength '''
28  waves = np.array(waves)
29  w = np.atleast_1d(k)
30  i = np.abs(waves - w[:, None]).argmin(1)
31  assert(not validate or (np.abs(w-waves[i]).max() <= tol)), f'Needed {k}, but closest was {waves[i]} in {waves} ({np.abs(w-waves[i]).max()} > {tol})'
32  return i.reshape(np.array(k).shape)
33 
34 
35 def closest_wavelength(k, waves, validate=True, tol=5):
36  ''' Value of closest wavelength '''
37  waves = np.array(waves)
38  return waves[find_wavelength(k, waves, validate, tol)]
39 
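# Example with illustrative band values: find_wavelength returns indices into the
# available wavelengths, while closest_wavelength returns the wavelength values.
#   find_wavelength([440, 550], [412, 443, 490, 555])        # -> array([1, 3])
#   closest_wavelength(440, [412, 443, 490, 555])             # -> 443
#   closest_wavelength(440, [412, 443, 490, 555], tol=1)      # AssertionError (443 is 3nm away)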
40 
41 def safe_int(v):
42  ''' Parse int if possible, and return None otherwise '''
43  try: return int(v)
44  except: return None
45 
46 
47 def get_wvl(nc_data, key):
48  ''' Get all wavelengths associated with the given key, available within the netcdf '''
49  wvl = [safe_int(v.replace(key, '')) for v in nc_data.variables.keys() if key in v]
50  return np.array(sorted([w for w in wvl if w is not None]))
51 
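# Example (hypothetical netCDF with variables Rrs_412, Rrs_443, Rrs_490, ...):
# get_wvl strips the key from each matching variable name, keeps the values which
# parse as integers via safe_int, and returns them sorted.
#   safe_int('443')           # -> 443
#   safe_int('443_unc')       # -> None (non-integer suffixes are skipped)
#   get_wvl(nc_data, 'Rrs_')  # -> array([412, 443, 490, ...])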
52 
53 def line_messages(messages, nbars=1):
54  '''
55  Allow multi-line message updates via tqdm.
56  After the tqdm loop, call print() once for
57  each message printed via this function, in
58  order to reset the cursor position.
59 
60  nbars is the number of tqdm bars the line
61  messages appear after.
62 
63  Usage:
64  nbars = 2
65  for i in trange(5):
66  for j in trange(5, leave=False):
67  messages = [i, i/2, i*2]
68  line_messages(messages, nbars)
69  for _ in range(len(messages) + nbars - 1): print()
70  '''
71  for _ in range(nbars): print()
72  for m in messages: print('\033[K' + str(m))
73  sys.stdout.write('\x1b[A'.join([''] * (nbars + len(messages) + 1)))
74 
75 
76 def get_labels(wavelengths, slices, n_out=None):
77  '''
78  Helper to get label for each target output. Assumes
79  that any variable in <slices> which has more than a
80  single slice index will have an associated wavelength
81  label.
82 
83  Usage:
84  wavelengths = [443, 483, 561, 655]
85  slices = {'bbp':slice(0,4), 'chl':slice(4,5), 'tss':slice(5,6)}
86  n_out = 5
87  labels = get_labels(wavelengths, slices, n_out)
88  # labels -> ['bbp443', 'bbp483', 'bbp561', 'bbp655', 'chl']
89  '''
90  return [k + (f'{wavelengths[i]:.0f}' if (v.stop - v.start) > 1 else '')
91  for k,v in sorted(slices.items(), key=lambda s: s[1].start)
92  for i in range(v.stop - v.start)][:n_out]
93 
94 
95 def compress(path, overwrite=False):
96  ''' Compress a folder into a .zip archive '''
97  if overwrite or not path.with_suffix('.zip').exists():
98  with zipfile.ZipFile(path.with_suffix('.zip'), 'w', zipfile.ZIP_DEFLATED) as zf:
99  for item in path.rglob('*'):
100  zf.write(item, item.relative_to(path))
101 
102 
103 def uncompress(path, overwrite=False):
104  ''' Uncompress a .zip archive '''
105  if overwrite or not path.exists():
106  if path.with_suffix('.zip').exists():
107  with zipfile.ZipFile(path.with_suffix('.zip'), 'r') as zf:
108  zf.extractall(path)
109 
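# Example (hypothetical model folder): both helpers key off the same path, with the
# archive differing only by the '.zip' suffix.
#   folder = Path('Weights/abc123')
#   compress(folder)      # creates Weights/abc123.zip unless it already exists
#   uncompress(folder)    # re-extracts Weights/abc123.zip if the folder is missing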
110 
111 class CustomUnpickler(pkl.Unpickler):
112  ''' Ensure the classes are found, without requiring an import '''
113  _transformers = [p.stem for p in Path(__file__).parent.joinpath('transformers').glob('*Transformer.py')]
114  _warned = False
115 
116  def find_class(self, module, name):
117  # pathlib/pickle doesn't correctly deal with instantiating
118  # a system-specific path on the opposite system (e.g. WindowsPath
119  # on a linux OS). Instead, we just provide the general Path class.
120  if name in ['WindowsPath', 'PosixPath']:
121  return Path
122 
123  elif name in self._transformers:
124  module = Path(__file__).parent.stem
125  imported = import_module(f'{module}.transformers.{name}')
126  return getattr(imported, name)
127 
128  elif name == 'TransformerPipeline':
129  from .transformers import TransformerPipeline
130  return TransformerPipeline
131 
132  return super().find_class(module, name)
133 
134 def store_pkl(filename, output):
135  ''' Helper to write pickle file '''
136  with Path(filename).open('wb') as f:
137  pkl.dump(output, f)
138  return output
139 
140 def read_pkl(filename):
141  ''' Helper to read pickle file '''
142  with Path(filename).open('rb') as f:
143  return CustomUnpickler(f).load()
144 
145 def cache(filename, recache=False):
146  ''' Decorator for caching function outputs '''
147  path = Path(filename)
148 
149  def wrapper(function):
150  def inner(*args, **kwargs):
151  if not recache and path.exists():
152  return read_pkl(path)
153  return store_pkl(path, function(*args, **kwargs))
154  return inner
155  return wrapper
156 
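# Sketch of the cache decorator (the filename is illustrative): the first call stores
# the result via store_pkl; subsequent calls load it back with read_pkl / CustomUnpickler.
#   @cache('slow_result.pkl')
#   def slow_function():
#       return np.random.random((1000, 1000))
#   a = slow_function()   # computed, then written to slow_result.pkl
#   b = slow_function()   # read back from disk; identical to 'a'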
157 
158 def using_feature(args, flag):
159  '''
160  Certain hyperparameter flags do not yet have a settled default value,
161  so the same option can appear under one of two names: one for using
162  the feature, and one for not using it. This method combines both into
163  a single boolean signal, which indicates whether to add the feature.
164  For example:
165  use_flag = hasattr(args, 'use_ratio') and args.use_ratio
166  no_flag = hasattr(args, 'no_ratio') and not args.no_ratio
167  signal = use_flag or no_flag # if true, we add ratios
168  becomes
169  signal = using_feature(args, 'ratio') # if true, we add ratios
170  '''
171  flag = flag.replace('use_', '').replace('no_', '')
172  assert(hasattr(args,f'use_{flag}') or hasattr(args, f'no_{flag}')), f'"{flag}" flag not found'
173  return getattr(args, f'use_{flag}', False) or not getattr(args, f'no_{flag}', True)
174 
175 
176 def split_data(x_data, other_data=[], n_train=0.5, n_valid=0, seed=None, shuffle=True):
177  '''
178  Split the given data into training, validation, and testing
179  subsets, randomly shuffling the original data order.
180  '''
181  if not isinstance(other_data, list): other_data = [other_data]
182 
183  data = [d.iloc if hasattr(d, 'iloc') else d for d in [x_data] + other_data]
184  random = np.random.RandomState(seed)
185  idxs = np.arange(len(x_data))
186  if shuffle: random.shuffle(idxs)
187 
188  # Allow both a percent to be passed in, as well as an absolute number
189  if 0 < n_train <= 1: n_train = int(n_train * len(idxs))
190  if 0 < n_valid <= 1: n_valid = int(n_valid * len(idxs))
191  assert((n_train+n_valid) <= len(x_data)), \
192  f'Too many training/validation samples requested: {n_train}, {n_valid} ({len(x_data)} available)'
193 
194  train = [d[ idxs[:n_train] ] for d in data]
195  valid = [d[ idxs[n_train:n_valid+n_train] ] for d in data]
196  test = [d[ idxs[n_train+n_valid:] ] for d in data]
197 
198  # Return just the split x_data if no other data was given
199  if len(data) == 1:
200  train = train[0]
201  valid = valid[0]
202  test = test[0]
203 
204  # If no validation data was requested, just return train/test
205  if n_valid == 0:
206  return train, test
207  return train, valid, test
208 
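# Example with random data: n_train / n_valid may be fractions of the dataset or
# absolute sample counts; with n_valid=0 only the (train, test) pair is returned.
#   x = np.random.random((100, 5))
#   y = np.random.random((100, 1))
#   (x_trn, y_trn), (x_tst, y_tst) = split_data(x, y, n_train=0.8, seed=42)
#   x_trn.shape   # -> (80, 5);  x_tst.shape -> (20, 5)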
209 
210 @ignore_warnings
211 def mask_land(data, bands, threshold=0.1, verbose=False):
212  ''' Modified Normalized Difference Water Index, or NDVI if 1500nm+ is not available '''
213  green = closest_wavelength(560, bands, validate=False)
214  red = closest_wavelength(700, bands, validate=False)
215  nir = closest_wavelength(900, bands, validate=False)
216  swir = closest_wavelength(1600, bands, validate=False)
217 
218  b1, b2 = (green, swir) if swir > 1500 else (red, nir) if red != nir else (min(bands), max(bands))
219  i1, i2 = find_wavelength(b1, bands), find_wavelength(b2, bands)
220  n_diff = lambda a, b: np.ma.masked_invalid((a-b) / (a+b))
221  if verbose: print(f'Using bands {b1} & {b2} for land masking')
222  return n_diff(data[..., i1], data[..., i2]).filled(fill_value=threshold-1) <= threshold
223 
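# Sketch with a synthetic reflectance cube (shape [lat, lon, band]): the returned
# boolean mask is True over land, so those pixels can be set to NaN before inference.
#   bands = [443, 483, 561, 655, 865]
#   cube  = np.random.random((10, 10, len(bands)))
#   land  = mask_land(cube, bands)   # -> boolean array of shape (10, 10)
#   cube[land] = np.nan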
224 
225 @ignore_warnings
226 def _get_tile_wavelengths(nc_data, key, sensor, allow_neg=True, landmask=False, args=None):
227  ''' Return the Rrs/rhos data within the netcdf file, for wavelengths of the given sensor '''
228  has_key = lambda k: any([k in v for v in nc_data.variables])
229  wvl_key = f'{key}_' if has_key(f'{key}_') or key != 'Rrs' else 'Rw' # Polymer stores Rw=Rrs*pi
230 
231  if has_key(wvl_key):
232  avail = get_wvl(nc_data, wvl_key)
233  bands = [closest_wavelength(b, avail) for b in get_sensor_bands(sensor, args)]
234  div = np.pi if wvl_key == 'Rw' else 1
235  data = np.ma.stack([nc_data[f'{wvl_key}{b}'][:] / div for b in bands], axis=-1)
236 
237  if not allow_neg: data[data <= 0] = np.nan
238  if landmask: data[ mask_land(data, bands) ] = np.nan
239 
240  return bands, data.filled(fill_value=np.nan)
241  return [], np.array([])
242 
243 def get_tile_data(filenames, sensor, allow_neg=True, rhos=False, anc=False, **kwargs):
244  ''' Gather the correct Rrs/rhos bands from a given scene, as well as ancillary features if necessary '''
245  from netCDF4 import Dataset
246 
247  filenames = np.atleast_1d(filenames)
248  features = ['rhos' if rhos else 'Rrs'] + (ANCILLARY if anc or rhos else [])
249  data = {}
250  available = []
251 
252  # Some sensors use different bands for their rhos models
253  if rhos and '-rho' not in sensor: sensor += '-rho'
254 
255  args = get_args(sensor=sensor, **kwargs)
256  for filename in filenames:
257  with Dataset(filename, 'r') as nc_data:
258  if 'geophysical_data' in nc_data.groups.keys():
259  nc_data = nc_data['geophysical_data']
260 
261  for feature in features:
262  if feature not in data:
263  if feature in ['Rrs', 'rhos']:
264  bands, band_data = _get_tile_wavelengths(nc_data, feature, sensor, allow_neg, landmask=rhos, args=args)
265 
266  if len(bands) > 0:
267  assert(len(band_data.shape) == 3), \
268  f'Different shape than expected: {band_data.shape}'
269  data[feature] = band_data
270 
271  elif feature in nc_data.variables:
272  var = nc_data[feature][:]
273  assert(len(var.shape) == 2), f'Different shape than expected: {var.shape}'
274 
275  if feature in PERIODIC:
276  assert(var.min() >= -180 and var.max() <= 180), \
277  f'Need to adjust transformation for variables not within [-180,180]: {feature}=[{var.min()}, {var.max()}]'
278  data[feature] = np.stack([
279  np.sin(2*np.pi*(var+180)/360),
280  np.cos(2*np.pi*(var+180)/360),
281  ], axis=-1)
282  else: data[feature] = var
283 
284  # Time difference should just be 0: we want estimates for the exact time of overpass
285  if 'time_diff' in features:
286  assert(features[0] in data), f'Missing {features[0]} data: {list(data.keys())}'
287  data['time_diff'] = np.zeros_like(data[features[0]][:, :, 0])
288 
289  assert(len(data) == len(features)), f'Missing features: Found {list(data.keys())}, Expecting {features}'
290  return bands, np.dstack([data[f] for f in features])
291 
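# Illustrative call (the filename and sensor label are only examples): returns the
# band centers matched to the sensor definition, plus a [lat, lon, feature] cube of
# Rrs (or rhos) with any requested ancillary layers stacked along the last axis.
#   bands, cube = get_tile_data('S3A_OLCI_scene_L2.nc', 'OLCI', allow_neg=False)
#   cube.shape   # -> (n_lines, n_pixels, n_features)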
292 
293 def generate_config(args, create=True, verbose=True):
294  '''
295  Create a config file for the current settings, and store it in
296  a folder location determined by certain parameters:
297  MDN/model_loc/sensor/model_lbl/model_uid/config
298  "model_uid" is computed within this function, but a value can
299  also be passed in manually via args.model_uid in order to allow
300  previous MDN versions to run.
301  '''
302  root = Path(__file__).parent.resolve().joinpath(args.model_loc, args.sensor, args.model_lbl)
303 
304  # Can override the model uid in order to allow prior MDN versions to be run
305  if hasattr(args, 'model_uid'):
306  if args.verbose: print(f'Using manually set model uid: {args.model_uid}')
307  return root.joinpath(args.model_uid)
308 
309  # Hash is always dependent upon these values
310  dependents = [getattr(act, 'dest', '') for group in [hypers, update] for act in group._group_actions]
311  dependents+= ['x_scalers', 'y_scalers']
312 
313  # Hash is only partially dependent upon these values, assuming operation changes when using a feature
314  # - 'use_' flags being set cause dependency
315  # - 'no_' flags being set remove dependency
316  # This allows additional flags to be added without breaking prior model compatibility
317  partials = [getattr(act, 'dest', '') for group in [flags] for act in group._group_actions]
318 
319  config = [f'Version: {__version__}', '', 'Dependencies']
320  config+= [''.join(['-']*len(config[-1]))]
321  others = ['', 'Configuration']
322  others+= [''.join(['-']*len(others[-1]))]
323 
324  for k,v in sorted(args.__dict__.items(), key=lambda z: z[0]):
325  if k in ['x_scalers', 'y_scalers']:
326  cinfo = lambda s, sarg, skw: getattr(s, 'config_info', lambda *a, **k: '')(*sarg, **skw)
327  cfmt = lambda *cargs: f' # {cinfo(*cargs)}' if cinfo(*cargs) else ''
328  v = '\n\t' + '\n\t'.join([f'{(s[0].__name__,) + s[1:]}{cfmt(*s)}' for s in v]) # stringify scaler and its arguments
329 
330  if k in partials and using_feature(args, k):
331  config.append(f'{k:<18}: {v}')
332  elif k in dependents: config.append(f'{k:<18}: {v}')
333  else: others.append(f'{k:<18}: {v}')
334 
335  config = '\n'.join(config) # Model is dependent on some arguments, so they change the uid
336  others = '\n'.join(others) # Other arguments are stored for replicability
337  ver_re = r'(Version\: \d+\.\d+)(?:\.\d+\n)' # Capture major/minor version; match (but discard) the patch version and newline
338  h_str = re.sub(ver_re, r'\1.0\n', config) # Substitute patch version for ".0" to allow patches within the same uid
339  uid = hashlib.sha256(h_str.encode('utf-8')).hexdigest()
340  folder = root.joinpath(uid)
341  c_file = folder.joinpath('config')
342  uncompress(folder) # Unzip the archive if necessary
343 
344  if args.verbose:
345  print(f'Using model path {folder}')
346 
347  if create:
348  folder.mkdir(parents=True, exist_ok=True)
349 
350  if not c_file.exists():
351  with c_file.open('w+') as f:
352  f.write(f'Created: {dt.now()}\n{config}\n{others}')
353  elif not c_file.exists() and verbose:
354  print('\nCould not find config file with the following parameters:')
355  print('\t'+config.replace('\n','\n\t'),'\n')
356  return folder
357 
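# The model uid intentionally ignores the patch version: ver_re rewrites e.g.
# 'Version: 1.2.3' to 'Version: 1.2.0' before hashing, so patch releases resolve
# to the same model folder. A minimal sketch of that substitution:
#   re.sub(r'(Version\: \d+\.\d+)(?:\.\d+\n)', r'\1.0\n', 'Version: 1.2.3\n')
#   # -> 'Version: 1.2.0\n'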
358 
359 def _load_datasets(keys, locs, wavelengths, allow_missing=False):
360  '''
361  Load data from [<locs>] using <keys> as the columns.
362  Only loads data which has all the bands defined by
363  <wavelengths> (if necessary, e.g. for Rrs or bbp).
364  First key is assumed to be the x_data, remaining keys
365  (if any) are y_data.
366  - allow_missing=True will allow datasets which are missing bands
367  to be included in the returned data
368 
369  Usage:
370  # Here, data/loc/Rrs.csv, data/loc/Rrs_wvl.csv, data/loc/bbp.csv,
371  # and data/chl.csv all exist, with the correct wavelengths available
372  # for Rrs and bbp (which is determined by Rrs_wvl.csv)
373  keys = ['Rrs', 'bbp', '../chl']
374  locs = 'data/loc'
375  wavelengths = [443, 483, 561, 655]
376  _load_datasets(keys, locs, wavelengths) # -> [Rrs443, Rrs483, Rrs561, Rrs655],
377  [bbp443, bbp483, bbp561, bbp655, chl],
378  {'bbp':slice(0,4), 'chl':slice(4,5)}
379  '''
380  def loadtxt(name, loc, required_wvl):
381  ''' Error handling wrapper over np.loadtxt, with the addition of wavelength selection'''
382  dloc = Path(loc).joinpath(f'{name}.csv')
383 
384  # TSS / TSM / SPM are synonymous
385  if 'tss' in name and not dloc.exists():
386  dloc = Path(loc).joinpath(f'{name.replace("tss","tsm")}.csv')
387 
388  if not dloc.exists():
389  dloc = Path(loc).joinpath(f'{name.replace("tsm","spm")}.csv')
390 
391  # CDOM is just an alias for a_cdom(443) or a_g(443)
392  if 'cdom' in name and not dloc.exists():
393  dloc = Path(loc).joinpath('ag.csv')
394  required_wvl = [443]
395 
396  try:
397  required_wvl = np.array(required_wvl).flatten()
398  assert(dloc.exists()), (f'Key {name} does not exist at {loc} ({dloc})')
399 
400  data = np.loadtxt(dloc, delimiter=',', dtype=float if name not in ['../Dataset', '../meta', '../datetime'] else str, comments=None)
401  if len(data.shape) == 1: data = data[:, None]
402 
403  if data.shape[1] > 1 and data.dtype.type is not np.str_:
404 
405  # If we want to get all data, regardless of if bands are available...
406  if allow_missing:
407  new_data = [[np.nan]*len(data)] * len(required_wvl)
408  wvls = np.loadtxt(Path(loc).joinpath(f'{dloc.stem}_wvl.csv'), delimiter=',')[:,None]
409  idxs = np.abs(wvls - np.atleast_2d(required_wvl)).argmin(0)
410  valid = np.abs(wvls - np.atleast_2d(required_wvl)).min(0) < 2
411 
412  for j, (i, v) in enumerate(zip(idxs, valid)):
413  if v: new_data[j] = data[:, i]
414  data = np.array(new_data).T
415  else:
416  data = data[:, get_valid(dloc.stem, loc, required_wvl)]
417 
418  if 'cdom' in name and dloc.stem == 'ag':
419  data = data[:, find_wavelength(443, required_wvl)].flatten()[:, None]
420  return data
421  except Exception as e:
422  if name not in ['Rrs']:# ['../chl', '../tss', '../cdom']:
423  if dloc.exists():
424  print(f'\n\tError fetching {name} from {loc}:\n{e}')
425  return np.array([]).reshape((0,0))
426  raise e
427 
428  def get_valid(name, loc, required_wvl, margin=2):
429  ''' Dataset at <loc> must have all bands in <required_wvl> within <margin>nm '''
430  if 'HYPER' in str(loc): margin=1
431 
432  # First, validate all required wavelengths are within the margin of an available wavelength
433  wvls = np.loadtxt(Path(loc).joinpath(f'{name}_wvl.csv'), delimiter=',')[:,None]
434  check = np.array([np.abs(wvls-w).min() <= margin for w in required_wvl])
435  assert(check.all()), '\n\t\t'.join([
436  f'{name} is missing {(~check).sum()} wavelengths:',
437  f'Needed {required_wvl}', f'Found {wvls.flatten()}',
438  f'Missing {required_wvl[~check]}', ''])
439 
440  # Next, select the available wavelengths which are within the margin of a required wavelength
441  valid = np.array([True] * len(required_wvl))
442  if len(wvls) != len(required_wvl):
443  valid = np.abs(wvls - np.atleast_2d(required_wvl)).min(1) <= margin
444  assert(valid.sum() == len(required_wvl)), [wvls[valid].flatten(), required_wvl]
445 
446  # Then, ensure the order of the selected wavelengths is the same as the required order
447  if not all([w1 == w2 for w1,w2 in zip(wvls[valid], required_wvl)]):
448  valid = [np.abs(wvls.flatten() - w).argmin() for w in required_wvl]
449  assert(len(np.unique(valid)) == len(valid) == len(required_wvl)), [valid, wvls[valid].flatten(), required_wvl]
450  return valid
451 
452  locs = [Path(loc).resolve() for loc in np.atleast_1d(locs)]
453  print('\n-------------------------')
454  print(f'Loading data for sensor {locs[0].parts[-1]}, and targets {[v.replace("../","") for v in keys[1:]]}')
455  if allow_missing:
456  print('Allowing data regardless of whether all bands exist')
457 
458  x_data = []
459  y_data = []
460  l_data = []
461  for loc in locs:
462  try:
463  loc_data = [loadtxt(key, loc, wavelengths) for key in keys]
464  print(f'\tN={len(loc_data[0]):>5} | {loc.parts[-1]} / {loc.parts[-2]} ({[np.isfinite(ld).all(1).sum() if ld.dtype.type is not np.str_ else len(ld) for ld in loc_data[1:]]})')
465  assert(all([len(l) in [len(loc_data[0]), 0] for l in loc_data])), dict(zip(keys, map(np.shape, loc_data)))
466 
467  if all([l.shape[1] == 0 for l in loc_data[(1 if len(loc_data) > 1 else 0):]]):
468  print(f'Skipping dataset {loc}: missing all features')
469  continue
470 
471  x_data += [loc_data.pop(0)]
472  y_data += [loc_data]
473  l_data += list(zip([loc.parent.name] * len(x_data[-1]), np.arange(len(x_data[-1]))))
474 
475  except Exception as e:
476  # assert(0), e
477  # Allow invalid datasets if there are multiple to be fetched
478  print(f'\nError fetching {loc}:\n\t{e}')
479  if len(np.atleast_1d(locs)) == 1:
480  raise e
481 
482  assert(len(x_data) > 0 or len(locs) == 0), 'No datasets are valid with the given wavelengths'
483  assert(all([x.shape[1] == x_data[0].shape[1] for x in x_data])), f'Differing number of {keys[0]} wavelengths: {[x.shape for x in x_data]}'
484 
485  # Determine the number of features each key should have
486  slices = []
487  for i, key in enumerate(keys[1:]):
488  shapes = [y[i].shape[1] for y in y_data]
489  slices.append(max(shapes))
490 
491  for x, y in zip(x_data, y_data):
492  if y[i].shape[1] == 0:
493  y[i] = np.full((x.shape[0], max(shapes)), np.nan)
494  assert(all([y[i].shape[1] == y_data[0][i].shape[1] for y in y_data])), f'{key} shape mismatch: {[y.shape for y in y_data]}'
495 
496  # Drop any missing features
497  drop = []
498  for i, s in enumerate(slices):
499  if s == 0:
500  print(f'Dropping {keys[i+1]}: feature has no samples available')
501  drop.append(i)
502 
503  slices = np.cumsum([0] + [s for i,s in enumerate(slices) if i not in drop])
504  keys = [k for i,k in enumerate(keys[1:]) if i not in drop]
505  for j, y in enumerate(y_data):
506  y_data[j] = [z for i,z in enumerate(y) if i not in drop]
507 
508  # Combine everything together
509  l_data = np.vstack(l_data)
510  x_data = np.vstack(x_data)
511 
512  if len(keys) > 0:
513  y_data = np.vstack([np.hstack(y) for y in y_data])
514  assert(slices[-1] == y_data.shape[1]), [slices, y_data.shape]
515  assert(y_data.shape[0] == x_data.shape[0]), [x_data.shape, y_data.shape]
516  slices = {k.replace('../','') : slice(slices[i], s) for i,(k,s) in enumerate(zip(keys, slices[1:]))}
517  print(f'\tTotal prior to filtering: {len(x_data)}')
518 
519  # Fit exponential function to ad and ag values, and eliminate samples with too much error
520  for product in ['ad', 'ag']:
521  if product in slices:
522  from .metrics import mdsa
523  from scipy.optimize import curve_fit
524 
525  exponential = lambda x, a, b, c: a * np.exp(-b*x) + c
526  remove = np.zeros_like(y_data[:,0]).astype(bool)
527 
528  for i, sample in enumerate(y_data):
529  sample = sample[slices[product]]
530  assert(len(sample) > 5), f'Number of bands should be larger, when fitting exponential: {product}, {sample.shape}'
531  assert(len(sample) == len(wavelengths)), f'Sample size / wavelengths mismatch: {len(sample)} vs {len(wavelengths)}'
532 
533  if np.all(np.isfinite(sample)) and np.min(sample) > -0.1:
534  try:
535  x = np.array(wavelengths) - np.min(wavelengths)
536  params, _ = curve_fit(exponential, x, sample, bounds=((1e-3, 1e-3, 0), (1e2, 1e0, 1e1)))
537  new_sample = exponential(x, *params)
538 
539  # Should be < 10% error between original and fitted exponential
540  if mdsa(sample[None,:], new_sample[None,:]) < 10:
541  y_data[i, slices[product]] = new_sample
542  else: remove[i] = True # Exponential could be fit, but error was too high
543  except: remove[i] = True # Sample deviated so much from a smooth exponential decay that it could not be fit
544  # else: remove[i] = True # NaNs / negatives in the sample
545 
546  # Don't actually drop them yet, in case we are fetching all samples regardless of nan composition
547  x_data[remove] = np.nan
548  y_data[remove] = np.nan
549  l_data[remove] = np.nan
550 
551  if remove.sum():
552  print(f'Removed {remove.sum()} / {len(remove)} samples due to poor quality {product} spectra')
553  assert((~remove).sum()), f'All data removed due to {product} spectra quality...'
554 
555  return x_data, y_data, slices, l_data
556 
557 
558 def _filter_invalid(x_data, y_data, slices, allow_nan_inp=False, allow_nan_out=False, other=[]):
559  '''
560  Filter the given data to only include samples which are valid. By
561  default, valid samples include all which are not nan, and greater
562  than zero (for all target features).
563  - allow_nan_inp=True can be set to allow a sample as valid if _any_
564  of a sample's input x features are not nan and greater than zero.
565  - allow_nan_out=True can be set to allow a sample as valid if _any_
566  of a sample's target y features are not nan and greater than zero.
567  - "other" is an optional set of parameters which will be pruned with the
568  test sets (i.e. passing a list of indices will return the indices which
569  were kept)
570  Multiple data sets can also be passed simultaneously as a list to the
571  respective parameters, in order to filter the same samples out of all
572  data sets (e.g. OLI and S2B data, containing same samples but different
573  bands, can be filtered so they end up with the same samples relative to
574  each other).
575  '''
576 
577  # Allow multiple sets to be given, and align them all to the same sample subset
578  if type(x_data) is not list: x_data = [x_data]
579  if type(y_data) is not list: y_data = [y_data]
580  if type(other) is not list: other = [other]
581 
582  both_data = [x_data, y_data]
583  set_length = [len(fullset) for fullset in both_data]
584  set_shape = [[len(subset) for subset in fullset] for fullset in both_data]
585 
586  assert(np.all([length == len(x_data) for length in set_length])), \
587  f'Mismatching number of subsets: {set_length}'
588  assert(np.all([[shape == len(fullset[0]) for shape in shapes]
589  for shapes, fullset in zip(set_shape, both_data)])), \
590  f'Mismatching number of samples: {set_shape}'
591  assert(len(other) == 0 or all([len(o) == len(x_data[0]) for o in other])), \
592  f'Mismatching number of samples within other data: {[len(o) for o in other]}'
593 
594  # Ensure only positive / finite testing features, but allow the
595  # possibility of some nan values in x_data (if allow_nan_inp is
596  # set) or in y_data (if allow_nan_out is set) - so long as the
597  # sample has other non-nan values in the respective feature set
598  valid = np.ones(len(x_data[0])).astype(bool)
599  for i, fullset in enumerate(both_data):
600  for subset in fullset:
601  subset[np.isnan(subset)] = -999.
602  subset[np.logical_or(subset <= 1e-8, not i and (subset >= 10))] = np.nan
603  has_nan = np.any if (i and allow_nan_out) or (not i and allow_nan_inp) else np.all
604  valid = np.logical_and(valid, has_nan(np.isfinite(subset), 1))
605 
606  x_data = [x[valid] for x in x_data]
607  y_data = [y[valid] for y in y_data]
608  print(f'Removed {(~valid).sum()} invalid samples ({valid.sum()} remaining)')
609  assert(valid.sum()), 'All samples have nan or negative values'
610 
611  if len(other) > 0:
612  return x_data, y_data, [np.array(o)[valid] for o in other]
613  return x_data, y_data
614 
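# Example with toy arrays: by default a sample must have every input and target
# feature finite and greater than zero to be kept; allow_nan_inp / allow_nan_out
# relax this to "at least one valid feature" on the respective side.
#   x = np.array([[0.01, 0.02], [0.03, np.nan]])
#   y = np.array([[1.0], [2.0]])
#   (x_f,), (y_f,) = _filter_invalid(x, y, slices={})
#   len(x_f)   # -> 1 (the sample with a NaN input is removed)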
615 
616 def get_data(args):
617  ''' Main function for gathering datasets '''
618  np.random.seed(args.seed)
619  sensor = args.sensor.split('-')[0]
620  products = args.product.split(',')
621  bands = get_sensor_bands(args.sensor, args)
622 
623  # Using Hydrolight simulated data
624  if using_feature(args, 'sim'):
625  assert(not using_feature(args, 'ratio')), 'Too much memory needed for simulated+ratios'
626  data_folder = ['790']
627  data_keys = ['Rrs']+products #['Rrs', 'bb_p', 'a_p', '../chl', '../tss', '../cdom']
628  data_path = Path(args.sim_loc)
629 
630  else:
631  if products[0] == 'all':
632  products = ['chl', 'tss', 'cdom', 'ad', 'ag', 'aph']# + ['a*ph', 'apg', 'a']
633 
634  data_folder = []
635  data_keys = ['Rrs']
636  data_path = Path(args.data_loc)
637  get_dataset = lambda path, p: Path(path.as_posix().replace(f'/{sensor}','').replace(f'/{p}.csv','')).stem
638 
639  for product in products:
640  if product in ['chl', 'tss', 'cdom', 'pc']:
641  product = f'../{product}'
642 
643  # Find all datasets with the given product available
644  safe_prod = product.replace('*', '[*]') # Prevent glob from getting confused by wildcard
645  datasets = [get_dataset(path, product) for path in data_path.glob(f'*/{sensor}/{safe_prod}.csv')]
646 
647  if product == 'aph':
648  datasets = [d for d in datasets if d not in ['PACE']]
649 
650  if getattr(args, 'subset', ''):
651  datasets = [d for d in datasets if d in args.subset.split(',')]
652 
653  data_folder += datasets
654  data_keys += [product]
655 
656  # Get only unique entries, while also preserving insertion order
657  order_unique = lambda a: [a[i] for i in sorted(np.unique(a, return_index=True)[1])]
658  data_folder = order_unique(data_folder)
659  data_keys = order_unique(data_keys)
660  assert(len(data_folder)), f'No datasets found for {products} within {data_path}/*/{sensor}'
661  assert(len(data_keys)), f'No variables found for {products} within {data_path}/*/{sensor}'
662 
663  sensor_loc = [data_path.joinpath(f, sensor) for f in data_folder]
664  x_data, y_data, slices, sources = _load_datasets(data_keys, sensor_loc, bands, allow_missing=('-nan' in args.sensor) or (getattr(args, 'align', None) is not None))
665 
666  # Hydrolight simulated CDOM is incorrectly scaled
667  if using_feature(args, 'sim') and 'cdom' in slices:
668  y_data[:, slices['cdom']] *= 0.18
669 
670  # Allow data from one sensor to be aligned with other sensors (so the samples will be the same across sensors)
671  if getattr(args, 'align', None) is not None:
672  assert('-nan' not in args.sensor), 'Cannot allow all samples via "-nan" while also aligning to other sensors'
673  align = args.align.split(',')
674  if 'all' in align:
675  align = [s for s in SENSOR_LABELS.keys() if s != 'HYPER']
676  align_loc = [[data_path.joinpath(f, a.split('-')[0]) for f in data_folder] for a in align]
677 
678  print(f'\nLoading alignment data for {align}...')
679  x_align, y_align, slices_align, sources_align = map(list,
680  zip(*[_load_datasets(data_keys, loc, get_sensor_bands(a, args), allow_missing=True) for a, loc in zip(align, align_loc)]))
681 
682  x_data = [x_data] + x_align
683  y_data = [y_data] + y_align
684 
685  # PC shouldn't be greater than 1000 mg/m^3
686  if 'pc' in slices:
687  above = y_data[..., slices['pc']].flatten() > 1000
688  below = y_data[..., slices['pc']].flatten() < 0.1
689  y_data[above|below, slices['pc']] = np.nan
690 
691  # if -nan IS in the sensor label: do not filter samples; allow all, regardless of nan composition
692  if '-nan' not in args.sensor:
693  (x_data, *_), (y_data, *_), (sources, *_) = _filter_invalid(x_data, y_data, slices, other=[sources], allow_nan_out=not using_feature(args, 'sim') and len(data_keys) > 2)
694 
695  # Correct chl data for pheopigments
696  if 'chl' in args.product and using_feature(args, 'tchlfix'):
697  assert(not using_feature(args, 'sim')), 'Simulated data does not need TChl correction'
698  y_data = _fix_tchl(y_data, sources, slices, data_path)
699 
700  # Minimum Rrs value shouldn't be below ~1e-6
701  # print((x_data < 1e-6).any(-1).sum(), 'samples below threshold')
702  # x_data = np.maximum(1e-6, x_data)
703  # x_data = np.append(x_data[:4866], x_data[4867:], 0)
704  # y_data = np.append(y_data[:4866], y_data[4867:], 0)
705  # sources = np.append(sources[:4866], sources[4867:], 0)
706 
707  print('\nFinal counts:')
708  print('\n'.join([f'\tN={num:>5} | {loc}' for loc, num in zip(*np.unique(sources[:, 0], return_counts=True))]))
709  print(f'\tTotal: {len(sources)}')
710  return x_data, y_data, slices, sources
711 
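# Typical entry point (a sketch; the sensor / product values are only examples):
# get_args builds the argument namespace, and get_data assembles the matched matrices.
#   args = get_args(sensor='OLCI', product='chl')
#   x_data, y_data, slices, sources = get_data(args)
#   y_data[:, slices['chl']]   # chl column(s) of the target matrix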
712 
713 def _fix_tchl(y_data, sources, slices, data_path, debug=False):
714  ''' Very roughly correct chl for pheopigments '''
715  import pandas as pd
716 
717  dataset_name, sample_idx = sources.T
718  sample_idx = sample_idx.astype(int)
719 
720  fix = np.ones(len(y_data)).astype(bool)
721  old = y_data.copy()
722 
723  set_idx = np.where(dataset_name == 'Sundar')[0]
724  dataset = np.loadtxt(data_path.joinpath('Sundar', 'Dataset.csv'), delimiter=',', dtype=str)[sample_idx[set_idx]]
725  fix[set_idx[dataset == 'ACIX_Krista']] = False
726  fix[set_idx[dataset == 'ACIX_Moritz']] = False
727 
728  set_idx = np.where(dataset_name == 'SeaBASS2')[0]
729  meta = pd.read_csv(data_path.joinpath('SeaBASS2', 'meta.csv')).iloc[sample_idx[set_idx]]
730  lonlats = meta[['east_longitude', 'west_longitude', 'north_latitude', 'south_latitude']].apply(lambda v: v.apply(lambda v2: v2.split('||')[0]))
731  # assert(lonlats.apply(lambda v: v.apply(lambda v2: v2.split('::')[0] == 'rrs')).all().all()), lonlats[~lonlats.apply(lambda v: v.apply(lambda v2: v2.split('::')[0] == 'rrs')).all(1)]
732 
733  lonlats = lonlats.apply(lambda v: pd.to_numeric(v.apply(lambda v2: v2.split('::')[1].replace('[deg]','')), 'coerce'))
734  lonlats = lonlats[['east_longitude', 'north_latitude']].to_numpy()
735 
736  # Only needs correction in certain areas, and for smaller chl magnitudes
737  fix[set_idx[np.logical_and(lonlats[:,0] < -117, lonlats[:,1] > 32)]] = False
738  fix[y_data[:,0] > 80] = False
739  print(f'Correcting {fix.sum()} / {len(fix)} samples')
740 
741  coef = [0.04, 0.776, 0.015, -0.00046, 0.000004]
742  # coef = [-0.12, 0.9, 0.001]
743  y_data[fix, slices['chl']] = np.sum(np.array(coef) * y_data[fix, slices['chl']] ** np.arange(len(coef)), 1, keepdims=False)
744 
745  if debug:
746  import matplotlib.pyplot as plt
747  from .plot_utils import add_identity
748  plt.scatter(old, y_data)
749  plt.xlabel('Old')
750  plt.ylabel('New')
751  plt.xscale('log')
752  plt.yscale('log')
753  add_identity(plt.gca(), color='k', ls='--')
754  plt.xlim((y_data[y_data > 0].min()/10, y_data.max()*10))
755  plt.ylim((y_data[y_data > 0].min()/10, y_data.max()*10))
756  plt.show()
757  return y_data
758 