OB.DAAC Logo
NASA Logo
Ocean Color Science Software

ocssw V2022
plot_utils.py
Go to the documentation of this file.
1 from .metrics import slope, sspb, mdsa, rmsle
2 from .meta import get_sensor_label
3 from .utils import closest_wavelength, ignore_warnings
4 from collections import defaultdict as dd
5 from pathlib import Path
6 import numpy as np
7 
8 
def add_identity(ax, *line_args, **line_kwargs):
    '''
    Draw a 1:1 (y = x) reference line on `ax` and keep it spanning the overlap of the
    current x and y limits whenever either limit changes. A bold "1:1" tag is placed
    near the top-right corner of the axes.
    https://stackoverflow.com/questions/22104256/does-matplotlib-have-a-function-for-drawing-diagonal-lines-in-axis-coordinates

    Positional/keyword line arguments are forwarded to `ax.plot`; unless a `label` is
    given, the line is excluded from legends ('_nolegend_').

    Usage: add_identity(plt.gca(), color='k', ls='--')
    '''
    line_kwargs.setdefault('label', '_nolegend_')
    (diagonal,) = ax.plot([], [], *line_args, **line_kwargs)

    def _fit_to_limits(_changed_axes):
        # The diagonal runs over the range shared by both axes' current limits.
        (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
        start, stop = max(x_lo, y_lo), min(x_hi, y_hi)
        diagonal.set_data([start, stop], [start, stop])

    _fit_to_limits(ax)
    for limit_event in ('xlim_changed', 'ylim_changed'):
        ax.callbacks.connect(limit_event, _fit_to_limits)

    tag_style = dict(
        transform=ax.transAxes,
        textcoords='offset points',
        xycoords='axes fraction',
        fontname='monospace',
        xytext=(0, 0),
        zorder=25,
        va='top',
        ha='left',
    )
    ax.annotate(r'$\mathbf{1:1}$', xy=(0.87, 0.99), size=11, **tag_style)
41 
42 
43 
def _create_metric(metric, y_true, y_est, longest=None, label=None):
    r''' Create a position-aligned (LaTeX) string which shows the performance via a single metric.

    Parameters
    ----------
    metric : callable called as metric(y_true, y_est). Its __name__ provides the default
        label ('SSPB' -> \beta, 'MdSA' -> \varepsilon). Its __qualname__ decides whether
        the value is shown as a percentage ('mape', 'sspb', 'mdsa').
    y_true, y_est : reference and estimated values, passed straight to `metric`.
    longest : label width to pad to; defaults to this label's own length.
    label : display label overriding the name-derived default.

    Returns
    -------
    str of the form  $\mathtt{<label>}<padding>:$ <value>[%]
    '''
    name = metric.__name__
    if label is None:
        label = name.replace('SSPB', '\\beta').replace('MdSA', '\\varepsilon\\thinspace')
    if longest is None:
        longest = len(label)

    # Metrics reported as percentages; matched on __qualname__ rather than the display __name__
    ispct = metric.__qualname__ in ['mape', 'sspb', 'mdsa']

    # Pad short labels so the ':' separators line up across lines of the stats box
    diff = longest - len(label.replace('^', ''))
    space = r''.join([r'\ '] * diff + [r'\thinspace'] * diff)

    # Evaluate the metric once; it is used both to pick the precision and for display
    value = metric(y_true, y_est)
    if name == 'N':
        prec = 0                                  # counts are integers
    elif ispct:
        prec = 1 if abs(value) < 100 else 0       # drop the decimal for large percentages
    else:
        prec = 3

    stat = f'{value:.{prec}f}'
    perc = r'$\small{\mathsf{\%}}$' if ispct else ''
    return rf'$\mathtt{{{label}}}{space}:$ {stat}{perc}'
58 
def _create_stats(y_true, y_est, metrics, title=None):
    ''' Build the stat-box lines for a single target feature: one line per metric.

    An optional underlined bold `title` line is placed first. All metric lines are padded
    to a common width, measured on the metric names with 'SSPB' -> 'Bias' and
    'MdSA' -> 'Error' substituted and '^' ignored.
    '''
    def _name_width(metric):
        shown = metric.__name__.replace('SSPB', 'Bias').replace('MdSA', 'Error')
        return len(shown.replace('^', ''))

    longest = max(_name_width(metric) for metric in metrics)

    lines = []
    if title is not None:
        lines.append(rf'$\mathbf{{\underline{{{title}}}}}$')
    lines.extend(_create_metric(metric, y_true, y_est, longest=longest) for metric in metrics)
    return lines
67 
def _create_multi_feature_stats(y_true, y_est, metrics, labels=None):
    ''' Create stat box strings for a single metric, assuming there are multiple target features.

    Only metrics[0] is reported. The box has a title line naming the metric, followed by
    one line per column (feature) of y_true / y_est.

    Parameters
    ----------
    y_true, y_est : 2-D arrays with one column per feature (read via .shape[1] and .T).
    metrics : sequence of metric callables; only the first one is used.
    labels : one display label per feature; defaults to 'Feature 0', 'Feature 1', ...

    Raises
    ------
    AssertionError if the number of labels, y_true columns and y_est columns differ.
    '''
    # `is None` rather than `== None`: with an array of labels, `== None` is elementwise
    # and using it in an `if` raises an ambiguous-truth-value error.
    if labels is None:
        labels = [f'Feature {i}' for i in range(y_true.shape[1])]
    assert(len(labels) == y_true.shape[1] == y_est.shape[1]), f'Number of labels does not match number of features: {labels} - {y_true.shape}'

    metric = metrics[0]
    title = metric.__name__.replace('SSPB', 'Bias').replace('MdSA', 'Error')
    longest = max(len(label.replace('^', '')) for label in labels)
    statbox = [_create_metric(metric, y1, y2, longest=longest, label=lbl) for y1, y2, lbl in zip(y_true.T, y_est.T, labels)]
    return [rf'$\mathbf{{\underline{{{title}}}}}$'] + statbox
79 
def add_stats_box(ax, y_true, y_est, metrics=[mdsa, sspb, slope], bottom_right=False, bottom=False, right=False, x=0.025, y=0.97, fontsize=16, label=None):
    ''' Add a text box containing a variety of performance statistics, to the given axis.

    Parameters
    ----------
    ax : matplotlib Axes to annotate.
    y_true, y_est : reference and estimated values. If y_true is 1-D or has a single
        column, every metric in `metrics` is listed. Otherwise only metrics[0] is
        reported, once per column.
    metrics : metric callables, each called as metric(y_true, y_est).
    bottom_right, bottom, right : move the box (approximately) to the bottom and/or
        right edge of the axes; bottom_right implies both.
    x, y : initial anchor of the box's top-left corner, in axes-fraction coordinates.
    fontsize : text size of the box.
    label : title line for single-feature boxes, or the per-column labels for
        multi-feature boxes.

    Returns
    -------
    The matplotlib Annotation holding the box.

    Side effect: enables LaTeX text rendering globally (plt.rc('text', usetex=True)).
    '''
    import matplotlib.pyplot as plt
    plt.rc('text', usetex=True)
    plt.rcParams['mathtext.default'] = 'regular'

    single_feature = len(y_true.shape) == 1 or y_true.shape[1] == 1
    create_box = _create_stats if single_feature else _create_multi_feature_stats
    stats_box = '\n'.join(create_box(y_true, y_est, metrics, label))
    ann_kwargs = {
        'transform'  : ax.transAxes,
        'textcoords' : 'offset points',
        'xycoords'   : 'axes fraction',
        'fontname'   : 'monospace',
        'xytext'     : (0, 0),
        'zorder'     : 25,
        'va'         : 'top',
        'ha'         : 'left',
        'bbox'       : {
            'facecolor' : 'white',
            'edgecolor' : 'black',
            'alpha'     : 0.7,
        },
    }

    ann = ax.annotate(stats_box, xy=(x, y), size=fontsize, **ann_kwargs)

    bottom |= bottom_right
    right |= bottom_right

    # Switch location to (approximately) the bottom right corner
    if bottom or right:
        # Render the figure that owns `ax` (not whichever figure happens to be current)
        # so the annotation's rendered extent can be measured.
        fig = ax.figure
        fig.canvas.draw()
        bbox_orig = ann.get_tightbbox(fig.canvas.renderer).transformed(ax.transAxes.inverted())

        new_x = bbox_orig.x0
        new_y = bbox_orig.y1
        # NOTE(review): with textcoords='offset points', set_x/set_y probably only nudge
        # the text offset; the effective move is the `ann.xy` assignment below — confirm.
        if bottom:
            new_y = bbox_orig.y1 - bbox_orig.y0 + (1 - y)
            ann.set_y(new_y)
            new_y += 0.06
        if right:
            new_x = 1 - (bbox_orig.x1 - bbox_orig.x0) + x
            ann.set_x(new_x)
            new_x -= 0.04
        ann.xy = (new_x, new_y)
    return ann
126 
127 
def draw_map(*lonlats, scale=0.2, world=False, us=True, eu=False, labels=[], ax=None, gray=False, res='i', **scatter_kws):
    ''' Helper function to plot locations on a global map.

    Each positional argument is an array-like of (lon, lat) rows (column 0 = longitude,
    column 1 = latitude), plotted as one scatter group on a Basemap.

    Parameters
    ----------
    world, us, eu : map extent. `world` takes precedence, then `us` (the default), then
        `eu`. On a world map, `us` / `eu` additionally add zoomed inset maps.
    labels : one legend label per group (must match len(lonlats)). With labels, and
        unless `color` is given in scatter_kws, each group gets its own color and marker.
    ax : axes to draw into; a new figure is created when None.
    gray : skip the relief/imagery background and rivers.
    res : Basemap coastline resolution.
    scale : currently unused.
    scatter_kws : forwarded to `ax.scatter` (never modified by this function).

    Returns
    -------
    The Basemap instance for the main map.

    Raises
    ------
    ValueError if none of world / us / eu is selected.
    '''
    import matplotlib.pyplot as plt
    from matplotlib.transforms import Bbox
    from mpl_toolkits.axes_grid1.inset_locator import TransformedBbox, BboxPatch, BboxConnector
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes
    from mpl_toolkits.basemap import Basemap

    PLOT_WIDTH = 8
    PLOT_HEIGHT = 6

    WORLD_MAP = {'cyl': [-90, 85, -180, 180]}
    US_MAP = {
        'cyl' : [24, 49, -126, -65],
        'lcc' : [23, 48, -121, -64],
    }
    EU_MAP = {
        'cyl' : [34, 65, -12, 40],
        'lcc' : [30.5, 64, -10, 40],
    }

    def mark_inset(ax, ax2, m, m2, MAP, loc1=(1, 2), loc2=(3, 4), **kwargs):
        """
        https://stackoverflow.com/questions/41610834/basemap-projection-geos-controlling-mark-inset-location
        Patched mark_inset to work with Basemap.
        Reason: Basemap converts Geographic (lon/lat) to Map Projection (x/y) coordinates

        Additionally: set connector locations separately for both axes:
            loc1 & loc2: tuple defining start and end-locations of connector 1 & 2
        """
        axzoom_geoLims = (MAP['cyl'][2:], MAP['cyl'][:2])
        rect = TransformedBbox(Bbox(np.array(m(*axzoom_geoLims)).T), ax.transData)
        pp = BboxPatch(rect, fill=False, **kwargs)
        ax.add_patch(pp)
        p1 = BboxConnector(ax2.bbox, rect, loc1=loc1[0], loc2=loc1[1], **kwargs)
        ax2.add_patch(p1)
        p1.set_clip_on(False)
        p2 = BboxConnector(ax2.bbox, rect, loc1=loc2[0], loc2=loc2[1], **kwargs)
        ax2.add_patch(p2)
        p2.set_clip_on(False)
        return pp, p1, p2

    if world:
        MAP = WORLD_MAP
        kwargs = {'projection': 'cyl', 'resolution': res}
    elif us:
        MAP = US_MAP
        kwargs = {'projection': 'lcc', 'lat_0': 30, 'lon_0': -98, 'resolution': res}
    elif eu:
        MAP = EU_MAP
        kwargs = {'projection': 'lcc', 'lat_0': 48, 'lon_0': 27, 'resolution': res}
    else:
        # ValueError subclasses Exception, so existing `except Exception` callers still work
        raise ValueError('Must plot world, US, or EU')

    kwargs.update(dict(zip(['llcrnrlat', 'urcrnrlat', 'llcrnrlon', 'urcrnrlon'], MAP['lcc' if 'lcc' in MAP else 'cyl'])))
    if ax is None:
        plt.figure(figsize=(PLOT_WIDTH, PLOT_HEIGHT), edgecolor='w')
    m = Basemap(ax=ax, **kwargs)
    ax = m.ax if m.ax is not None else plt.gca()

    if not world:
        # State boundaries come from the shapefile shipped next to this module
        m.readshapefile(Path(__file__).parent.joinpath('map_files', 'st99_d00').as_posix(), name='states', drawbounds=True, color='k', linewidth=0.5, zorder=11)
        m.fillcontinents(color=(0, 0, 0, 0), lake_color='#9abee0', zorder=9)
        if not gray:
            m.drawrivers(linewidth=0.2, color='blue', zorder=9)
        m.drawcountries(color='k', linewidth=0.5)
    else:
        m.drawcountries(color='w')

    if not gray:
        if us or eu:
            m.shadedrelief(scale=0.3 if world else 1)
        else:
            m.arcgisimage(service='World_Imagery', xpixels=2000, verbose=True)

    if labels:
        colors = ['aqua', 'orangered', 'xkcd:tangerine', 'xkcd:fresh green', 'xkcd:clay', 'magenta', 'xkcd:sky blue', 'xkcd:greyish blue', 'xkcd:goldenrod', ]
        markers = ['o', '^', 's', '*', 'v', 'X', '.', 'x', ]
        assert(len(labels) == len(lonlats)), [len(labels), len(lonlats)]

        # Decide once whether to auto-style groups, and style a per-group copy so that
        # scatter_kws (also forwarded to the inset maps below) is never modified.
        auto_style = 'color' not in scatter_kws
        for i, (label, lonlat) in enumerate(zip(labels, lonlats)):
            lonlat = np.atleast_2d(lonlat)
            group_kws = dict(scatter_kws)
            if auto_style:
                # Cycle the palettes so more groups than styles do not raise IndexError
                group_kws['color'] = colors[i % len(colors)]
                group_kws['marker'] = markers[i % len(markers)]
            ax.scatter(*m(lonlat[:, 0], lonlat[:, 1]), label=label, zorder=12, **group_kws)
        ax.legend(loc='lower left', prop={'weight': 'bold', 'size': 8}).set_zorder(20)

    else:
        for lonlat in lonlats:
            if len(lonlat):
                lonlat = np.atleast_2d(lonlat)
                ax.scatter(*m(lonlat[:, 0], lonlat[:, 1]), zorder=12, **scatter_kws)

    # Hide ticks and tick labels on every side
    hide_kwargs = {'axis': 'both', 'which': 'both'}
    hide_kwargs.update({k: False for k in ['bottom', 'top', 'left', 'right', 'labelleft', 'labelbottom']})
    ax.tick_params(**hide_kwargs)

    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_linewidth(1.5)
        ax.spines[axis].set_zorder(50)

    if world:
        size = 0.35
        # Smaller markers on the zoomed insets
        inset_kws = {**scatter_kws, 's': 6}
        if us:
            loc = (0.25, -0.1) if eu else (0.35, -0.01)
            ax_ins = inset_axes(ax, width=PLOT_WIDTH*size, height=PLOT_HEIGHT*size, loc='center', bbox_to_anchor=loc, bbox_transform=ax.transAxes, axes_kwargs={'zorder': 5})

            m2 = draw_map(*lonlats, labels=labels, ax=ax_ins, **inset_kws)

            mark_inset(ax, ax_ins, m, m2, US_MAP, loc1=(1, 1), loc2=(2, 2), edgecolor='grey', zorder=3)
            mark_inset(ax, ax_ins, m, m2, US_MAP, loc1=[3, 3], loc2=[4, 4], edgecolor='grey', zorder=0)

        if eu:
            ax_ins = inset_axes(ax, width=PLOT_WIDTH*size, height=PLOT_HEIGHT*size, loc='center', bbox_to_anchor=(0.75, -0.05), bbox_transform=ax.transAxes, axes_kwargs={'zorder': 5})

            m2 = draw_map(*lonlats, us=False, eu=True, labels=labels, ax=ax_ins, **inset_kws)

            mark_inset(ax, ax_ins, m, m2, EU_MAP, loc1=(1, 1), loc2=(2, 2), edgecolor='grey', zorder=3)
            mark_inset(ax, ax_ins, m, m2, EU_MAP, loc1=[3, 3], loc2=[4, 4], edgecolor='grey', zorder=0)

    return m
267 
268 
def default_dd(d={}, f=lambda k: k):
    ''' Build a defaultdict whose missing-key factory is called with the key itself.

    d : initial mapping; its items are copied into the new dict (d itself is untouched).
    f : callable applied to a missing key. Its result is stored under that key and
        returned. With f=None, missing keys raise KeyError as in a plain dict.
    '''

    class key_dd(dd):
        ''' DefaultDict which allows the key as the default value '''
        def __missing__(self, key):
            factory = self.default_factory
            if factory is not None:
                self[key] = value = factory(key)
                return value
            raise KeyError(key)

    return key_dd(f, d)
280 
281 
@ignore_warnings
def plot_scatter(y_test, benchmarks, bands, labels, products, sensor, title=None, methods=None, n_col=3, img_outlbl=''):
    ''' Plot measured vs. estimated values on log10 axes for one or more estimators, then
    save the figure under ./scatter_plots/ and show it.

    Parameters
    ----------
    y_test : measured values, one column per entry of `labels` (checked against
        y_test.shape[-1]); log10-transformed for plotting.
    benchmarks : estimates indexed as benchmarks[name][..., column], or as
        benchmarks[product][name] when the plot order is product-keyed. When only 'chl'
        is requested and no order is given, benchmarks['chl'] is used.
    bands : wavelengths for the columns, used to select and label per-band plots.
    labels : column labels; each is paired with every product whose name it contains.
    products : product name(s); truncated to the number of y_test columns.
    sensor : sensor name used in the (currently disabled) title and the output filename.
    title : optional override for the sensor label (currently only feeds the disabled title).
    methods : explicit estimator order; otherwise one is chosen from `products`.
    n_col : minimum number of subplot columns.
    img_outlbl : prefix for the output filename.

    Side effects: enables LaTeX text rendering globally, prints progress information,
    writes a PNG file and calls plt.show().
    '''
    import matplotlib.patheffects as pe
    import matplotlib.ticker as ticker
    import matplotlib.pyplot as plt
    import seaborn as sns

    folder = Path('scatter_plots')
    folder.mkdir(exist_ok=True, parents=True)

    product_labels = default_dd({
        'chl' : 'Chl\\textit{a}',
        'pc'  : 'PC',
        'aph' : '\\textit{a}_{ph}',
        'tss' : 'TSS',
        'cdom': '\\textit{a}_{CDOM}',
    })

    product_units = default_dd({
        'chl' : '[mg m^{-3}]',
        'pc'  : '[mg m^{-3}]',
        'tss' : '[g m^{-3}]',
        'aph' : '[m^{-1}]',
        'cdom': '[m^{-1}]',
    }, lambda k: '')

    products = [p for i, p in enumerate(np.atleast_1d(products)) if i < y_test.shape[-1]]

    plt.rc('text', usetex=True)
    plt.rcParams['mathtext.default'] = 'regular'

    # Only plot certain bands
    if len(labels) > 3 and 'chl' not in products:
        product_bands = {
            'default' : [443, 482, 561, 655],
        }

        target = [closest_wavelength(w, bands) for w in product_bands.get(products[0], product_bands['default'])]
        plot_label = [w in target for w in bands]
        plot_order = ['MDN', 'QAA', 'GIOP']
        plot_bands = True
    else:
        plot_label = [True] * len(labels)
        plot_order = methods
        plot_bands = False

    if plot_order is None:
        if 'chl' in products and len(products) == 1:
            benchmarks = benchmarks['chl']
            if 'MLP' in benchmarks:
                plot_order = ['MDN', 'MLP', 'SVM', 'XGB', 'KNN', 'OC3']
            else:
                plot_order = ['MDN', 'Smith_Blend', 'OC6', 'Mishra_NDCI', 'Gons_2band', 'Gilerson_2band']
        elif len(products) > 1 and any(k in products for k in ['chl', 'tss', 'cdom']):
            plot_order = {k: v for k, v in {
                'chl'  : ['MDN', 'Gilerson_2band'],
                'tss'  : ['MDN', 'SOLID'],
                'cdom' : ['MDN', 'Ficek'],
            }.items() if k in products}
            plot_label = [True] * len(plot_order)
            plot_bands = True
            n_col = len(plot_order)
            # NOTE(review): these two assignments override the comparison layout built
            # just above (MDN-only, 4 columns); looks like a deliberate local override.
            plot_order = {p: ['MDN'] for p in products}
            plot_label = [True] * 4

    # Pair each column label with the product(s) it refers to: (product, label)
    labels = [(p, label) for label in labels for p in products if p in label]
    print('Plotting labels:', [l for i, l in enumerate(labels) if plot_label[i]])
    assert(len(labels) == y_test.shape[-1]), [len(labels), y_test.shape]

    fig_size = 5
    n_col = max(n_col, sum(plot_label))
    n_row = 1

    fig, axes = plt.subplots(n_row, n_col, figsize=(fig_size*n_col, fig_size*n_row+1))
    axes = [ax for axs in np.atleast_1d(axes).T for ax in np.atleast_1d(axs)]
    colors = ['xkcd:sky blue', 'xkcd:tangerine', 'xkcd:fresh green', 'xkcd:greyish blue', 'xkcd:goldenrod', 'xkcd:clay', 'xkcd:bluish purple', 'xkcd:reddish']

    print('Order:', plot_order)
    print(f'Plot size: {n_row} x {n_col}')
    print(labels)

    curr_idx = 0
    # Invisible full-figure axes, used only to carry the shared x/y axis labels
    full_ax = fig.add_subplot(111, frameon=False)
    full_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False, pad=10)

    estimate_label = 'Measured'
    x_pre = 'Estimated'
    y_pre = estimate_label.replace('-', '\\textbf{-}')
    plabel = f'{product_labels[products[0]]} {product_units[products[0]]}'
    xlabel = fr'$\mathbf{{{x_pre} {plabel}}}$'
    ylabel = fr'$\mathbf{{{y_pre}}}$' + fr'$\mathbf{{ {plabel}}}$'
    # Spaces are escaped as LaTeX '\ ' so they survive math mode
    if not isinstance(plot_order, dict):
        full_ax.set_xlabel(xlabel.replace(' ', '\\ '), fontsize=24, labelpad=10)
        full_ax.set_ylabel(ylabel.replace(' ', '\\ '), fontsize=24, labelpad=10)
    else:
        full_ax.set_ylabel(fr'$\mathbf{{{x_pre}}}$'.replace(' ', '\\ '), fontsize=24, labelpad=15)

    # Figure title text; the call that would display it is currently disabled
    s_lbl = title or get_sensor_label(sensor).replace('-', ' ')
    n_pts = len(y_test)
    title = fr'$\mathbf{{\underline{{\large{{{s_lbl}}}}}}}$' + '\n' + fr'$\small{{\mathit{{N\small{{=}}}}{n_pts}}}$'

    for plt_idx, (label, y_true) in enumerate(zip(labels, y_test.T)):
        if not plot_label[plt_idx]:
            continue

        product, title = label
        plabel = f'{product_labels[product]} {product_units[product]}'

        for est_idx, est_lbl in enumerate(plot_order[product] if isinstance(plot_order, dict) else plot_order):
            # Missing estimator for this product: blank out its panel and move on
            if isinstance(plot_order, dict) and est_lbl not in benchmarks[product]:
                axes[curr_idx].tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
                axes[curr_idx].axis('off')
                curr_idx += 1
                continue

            y_est = benchmarks[product][est_lbl] if isinstance(plot_order, dict) else benchmarks[est_lbl][..., plt_idx]
            ax = axes[curr_idx]
            cidx = int(curr_idx / n_row)
            # Cycle the palette so more panels than colors do not raise IndexError
            color = colors[cidx % len(colors)]

            first_row = (curr_idx % n_row) == 0
            last_row = ((curr_idx+1) % n_row) == 0
            first_col = (curr_idx % n_col) == 0
            last_col = ((curr_idx+1) % n_col) == 0
            print(curr_idx, first_row, last_row, first_col, last_col, est_lbl, product, plabel)
            y_est_log = np.log10(y_est).flatten()
            y_true_log = np.log10(y_true).flatten()
            curr_idx += 1

            l_kws = {'color': color, 'path_effects': [pe.Stroke(linewidth=4, foreground='k'), pe.Normal()], 'zorder': 22, 'lw': 1}
            s_kws = {'alpha': 0.4, 'color': color}

            if est_lbl == 'MDN':
                # Thick frame highlights the MDN panels
                for spine in ax.spines.values():
                    spine.set_linewidth(5)
            else:
                est_lbl = est_lbl.replace('Mishra_', '').replace('Gons_2band', 'Gons').replace('Gilerson_2band', 'GI2B').replace('Smith_', '').replace('Cao_XGB', 'BST')
                est_lbl = est_lbl.replace('QAA_CDOM', 'QAA\\ CDOM')

            # Per-band plots: label the wavelength on the right edge of the last column
            if product not in ['chl', 'tss', 'cdom', 'pc'] and last_col:
                ax2 = ax.twinx()
                ax2.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False, pad=0)
                ax2.grid(False)
                ax2.set_yticklabels([])
                ax2.set_ylabel(fr'$\mathbf{{{bands[plt_idx]:.0f}nm}}$', fontsize=22)

            # Fixed log10 range 1e-2 .. 1e3 with one tick per decade
            minv = -2
            maxv = 3
            loc = ticker.LinearLocator(numticks=int(round(maxv-minv+1.5)))
            fmt = ticker.FuncFormatter(lambda i, _: r'$10$\textsuperscript{%i}' % i)

            ax.set_ylim((minv, maxv))
            ax.set_xlim((minv, maxv))
            ax.xaxis.set_major_locator(loc)
            ax.yaxis.set_major_locator(loc)
            ax.xaxis.set_major_formatter(fmt)
            ax.yaxis.set_major_formatter(fmt)
            fs = 22

            if not last_row:
                ax.set_xticklabels([])
            elif isinstance(plot_order, dict):
                ylabel = fr'$\mathbf{{{y_pre}}}$' + fr'$\mathbf{{ {plabel}}}$' + '\n' + fr'$\small{{\mathit{{N\small{{=}}}}{np.isfinite(y_true_log).sum()}}}$'
                ax.set_xlabel(ylabel.replace(' ', '\\ '), fontsize=fs)

            valid = np.logical_and(np.isfinite(y_true_log), np.isfinite(y_est_log))
            if valid.sum():
                sns.regplot(y_true_log[valid], y_est_log[valid], ax=ax, scatter_kws=s_kws, line_kws=l_kws, fit_reg=True, truncate=False, robust=True, ci=None)
                sns.kdeplot(y_true_log[valid], y_est_log[valid], shade=False, ax=ax, bw='scott', n_levels=4, legend=False, gridsize=100, color='#555')

            # Finite measurements whose estimate is not finite are drawn in red along the bottom edge
            invalid = np.logical_and(np.isfinite(y_true_log), ~np.isfinite(y_est_log))
            if invalid.sum():
                ax.scatter(y_true_log[invalid], [minv]*(invalid).sum(), color='r', alpha=0.4, label=r'$\mathbf{%s\ invalid}$' % (invalid).sum())
                ax.legend(loc='lower right', prop={'weight': 'bold', 'size': 18})

            add_identity(ax, ls='--', color='k', zorder=20)

            if valid.sum():
                add_stats_box(ax, y_true[valid], y_est[valid], metrics=[mdsa, sspb, slope], fontsize=18)

            if first_row or not plot_bands or (isinstance(plot_order, dict)):
                if est_lbl == 'BST':
                    ax.set_title(r'$\small{\textit{(Cao\ et\ al.\ 2020)}}$' + '\n' + fr'$\mathbf{{\large{{{est_lbl}}}}}$', fontsize=fs, linespacing=0.95)

                elif est_lbl == 'Ficek':
                    ax.set_title(fr'$\mathbf{{\large{{{est_lbl}}}}}$' + r'$\small{\textit{\ (et\ al.\ 2011)}}$', fontsize=fs, linespacing=0.95)

                elif est_lbl == 'Mannino':
                    ax.set_title(fr'$\mathbf{{\large{{{est_lbl}}}}}$' + r'$\small{\textit{\ (et\ al.\ 2008)}}$', fontsize=fs, linespacing=0.95)

                elif est_lbl == 'Novoa':
                    ax.set_title(fr'$\mathbf{{\large{{{est_lbl}}}}}$' + r'$\small{\textit{\ (et\ al.\ 2017)}}$', fontsize=fs, linespacing=0.95)

                elif est_lbl == 'GI2B':
                    ax.set_title(fr'$\mathbf{{\large{{Gilerson}}}}$' + r'$\small{\textit{\ (et\ al.\ 2010)}}$', fontsize=fs, linespacing=0.95)

                elif est_lbl == 'MDN':
                    ax.set_title(fr'$\mathbf{{{est_lbl}\ {product_labels[product]}}}$', fontsize=fs)
                else:
                    ax.set_title(fr'$\mathbf{{\large{{{est_lbl}}}}}$', fontsize=fs)

            ax.tick_params(labelsize=fs)
            ax.grid(True, alpha=0.3)

    u_label = ",".join([o.split('_')[0] for o in plot_order]) if len(plot_order) < 10 else f'{n_row}x{n_col}'
    filename = folder.joinpath(f'{img_outlbl}{",".join(products)}_{sensor}_{n_pts}test_{u_label}.png')
    plt.tight_layout()
    plt.savefig(filename.as_posix(), dpi=100, bbox_inches='tight', pad_inches=0.1,)
    plt.show()
def draw_map(*lonlats, scale=0.2, world=False, us=True, eu=False, labels=[], ax=None, gray=False, res='i', **scatter_kws)
Definition: plot_utils.py:128
def default_dd(d={}, f=lambda k:k)
Definition: plot_utils.py:269
def plot_scatter(y_test, benchmarks, bands, labels, products, sensor, title=None, methods=None, n_col=3, img_outlbl='')
Definition: plot_utils.py:283
def add_stats_box(ax, y_true, y_est, metrics=[mdsa, sspb, slope], bottom_right=False, bottom=False, right=False, x=0.025, y=0.97, fontsize=16, label=None)
Definition: plot_utils.py:80
def get_sensor_label(sensor)
Definition: meta.py:28
def add_identity(ax, *line_args, **line_kwargs)
Definition: plot_utils.py:9
def closest_wavelength(k, waves, validate=True, tol=5, squeeze=False)
Definition: utils.py:24
#define abs(a)
Definition: misc.h:90