OB.DAAC Logo
NASA Logo
Ocean Color Science Software

ocssw V2022
_CustomTransformer.py
Go to the documentation of this file.
1 from sklearn.base import TransformerMixin
2 
3 
class _CustomTransformer(TransformerMixin):
    ''' Data transformer class which validates data shapes.
    Child classes should override _fit, _transform, _inverse_transform

    Shape tracking (feature count = ``X.shape[1]``):
      - ``fit`` records the input feature count in ``_input_shape`` and
        clears any output feature count learned from a previous fit.
      - The first ``transform`` after a fit records the output feature count
        in ``_output_shape``; later calls must match it.
      - ``transform`` / ``inverse_transform`` check their inputs and outputs
        against these recorded counts and raise ``AssertionError`` on mismatch.

    ``X`` is expected to be a 2-D array-like exposing ``.shape`` and
    ``.copy()`` (presumably a numpy array -- confirm against callers).
    '''
    _input_shape = None   # Feature count on the "input" side; set by fit().
    _output_shape = None  # Feature count produced by transform(); learned on first call.

    def fit(self, X, *args, **kwargs):
        ''' Record the input feature count and delegate fitting to ``_fit``.

        Extra positional/keyword arguments are forwarded to ``_fit``.
        A copy of ``X`` is passed, so ``_fit`` may modify it freely.
        Returns ``self`` so calls can be chained.
        '''
        self._input_shape = X.shape[1]
        # Forget the output width learned under any previous fit. Otherwise
        # refitting on data with a different feature count makes the next
        # transform() fail validation against the stale value. The reset
        # happens before _fit so a child class may still set it there.
        self._output_shape = None
        self._fit(X.copy(), *args, **kwargs)
        return self

    def transform(self, X, *args, **kwargs):
        ''' Validate ``X`` against the fitted input width, then apply ``_transform``.

        The output width is checked against (or, the first time, recorded
        as) ``_output_shape``. Returns the transformed data.
        '''
        self._validate_shape(X, self._input_shape)
        X = self._transform(X.copy(), *args, **kwargs)
        self._validate_shape(X, self._output_shape)
        self._output_shape = X.shape[1]
        return X

    def inverse_transform(self, X, *args, **kwargs):
        ''' Validate ``X`` against the output width, then apply ``_inverse_transform``.

        The result is checked against (or, if unset, recorded as)
        ``_input_shape``. Returns the inverse-transformed data.
        '''
        self._validate_shape(X, self._output_shape)
        X = self._inverse_transform(X.copy(), *args, **kwargs)
        self._validate_shape(X, self._input_shape)
        self._input_shape = X.shape[1]
        return X

    def fit_transform(self, X, *args, **kwargs):
        ''' Fit on ``X`` and return ``transform(X)``.

        The same extra arguments are forwarded to both ``fit`` and
        ``transform``.
        '''
        self.fit(X, *args, **kwargs)
        return self.transform(X, *args, **kwargs)

    @staticmethod
    def config_info(*args, **kwargs):
        ''' Return any additional info needed to construct the model config
        (empty string by default). '''
        return ''

    def _fit(self, X, *args, **kwargs):
        ''' Hook for child classes; the default fit does nothing. '''
        pass

    def _transform(self, X, *args, **kwargs):
        ''' Hook for child classes: return the transformed copy of ``X``. '''
        # NotImplementedError is the exception class. The original raised the
        # NotImplemented constant, which itself triggers a confusing TypeError.
        raise NotImplementedError(f'{type(self).__name__} must implement _transform')

    def _inverse_transform(self, X, *args, **kwargs):
        ''' Hook for child classes: return the inverse-transformed copy of ``X``. '''
        raise NotImplementedError(f'{type(self).__name__} must implement _inverse_transform')

    def _validate_shape(self, X, shape):
        ''' Raise AssertionError if ``X`` has a feature count other than ``shape``.

        A ``shape`` of None means "not yet known" and skips the check.
        '''
        # Raise explicitly instead of using `assert`, which is stripped under
        # `python -O`. AssertionError is kept for callers that already catch it.
        if shape is not None and X.shape[1] != shape:
            raise AssertionError(
                f'Number of data features changed: expected {shape}, found {X.shape[1]}')