OB.DAAC Logo
NASA Logo
Ocean Color Science Software

ocssw V2022
DatasetMembershipTransformer.py
Go to the documentation of this file.
1 from ._CustomTransformer import _CustomTransformer
2 
3 from sklearn import preprocessing
4 import numpy as np
5 
6 
8  ''' Appends a one-hot vector of features to each sample, indicating dataset membership '''
9 
def __init__(self, datasets, *args, **kwargs):
    """Fit a one-hot encoder on the per-sample dataset labels.

    Parameters
    ----------
    datasets : array-like
        Dataset label of each sample. It is flattened into a single column,
        so any shape holding one label per sample is accepted.
    *args, **kwargs
        Accepted for interface compatibility; not used here.
    """
    # `flatten` copies, so later changes to the caller's array do not leak
    # into the stored labels; `[:, None]` gives the (n_samples, 1) layout
    # that OneHotEncoder expects.
    self.locs = np.asarray(datasets).flatten()[:, None]
    try:
        # scikit-learn >= 1.2 names the dense-output flag `sparse_output`;
        # the old `sparse` keyword was removed in 1.4.
        encoder = preprocessing.OneHotEncoder(sparse_output=False)
    except TypeError:
        # Older scikit-learn releases only understand `sparse`.
        encoder = preprocessing.OneHotEncoder(sparse=False)
    self.ohe = encoder.fit(self.locs)
    # Width of the appended one-hot block. It is read from the fitted encoder
    # so it always matches the number of columns `ohe.transform` produces.
    self.n_sets = len(self.ohe.categories_[0])
14 
15  def _transform(self, X, *args, idx=None, zeros=False, **kwargs):
16  # if X is not the same shape as locs, and idx=None, we just append zeros
17  # otherwise, we append the appropriate one-hot vector corresponding to a
18  # sample's dataset membership
19  if ((X.shape[0] == len(self.locs)) or (idx is not None)) and not zeros:
20  idx = idx or slice(None)
21  loc = self.locs[idx]
22  return np.append(X, self.ohe.transform(loc), 1)
23  return np.append(X, np.zeros((len(X), self.n_sets)), 1)
24 
25  def _inverse_transform(self, X, *args, **kwargs):
26  return X[:, :-self.n_sets]