What’s the fuss about?
The Scikit-learn, or sklearn, library is perhaps the primary reason I use Python. With a common API, adopted by many other libraries, it is possible to build complex machine learning systems that can be integrated into cross-validations, grid searches, learning curves and many other diagnostics. Yet, occasionally an sklearn class does something we don’t like, and it is useful to redefine one of its methods to behave more conveniently. In this example, we will look at a typical sklearn transformer, the StandardScaler. Since this class outputs a numpy ndarray no matter what it is given, the original datatype is lost whenever it is put in a pipeline. This can be quite frustrating if you want to do some feature importance analysis but all your feature names have been thrown out. To fix this, we will change the transform method of the StandardScaler so that it returns a pandas DataFrame. Let’s set up a simple DataFrame and see what the StandardScaler does to it.
import numpy as np
from pandas import DataFrame
from sklearn.preprocessing import StandardScaler
df = DataFrame(np.array([[1, 0], [2, 1], [0, 1]]),
               columns=['a', 'b'],
               dtype='float')
df
| | a | b |
|---|---|---|
| 0 | 1.0 | 0.0 |
| 1 | 2.0 | 1.0 |
| 2 | 0.0 | 1.0 |
Note the type of df:
type(df)
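pandas.core.frame.DataFrame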
And here’s what the StandardScaler will return:
nd = StandardScaler().fit_transform(df)
nd
array([[ 0.        , -1.41421356],
       [ 1.22474487,  0.70710678],
       [-1.22474487,  0.70710678]])
Annoying indeed. To prevent this from happening, the manual solution would be to assign the transformed data back onto df as new values, like so:
dft = df.copy()
dft.loc[:, :] = StandardScaler().fit_transform(df)
dft
| | a | b |
|---|---|---|
| 0 | 0.000000 | -1.414214 |
| 1 | 1.224745 | 0.707107 |
| 2 | -1.224745 | 0.707107 |
Voilà! Our data is scaled, and rests in a DataFrame:
type(dft)
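pandas.core.frame.DataFrame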
But… pipelining
Of course, if we want to stick the StandardScaler into a pipeline, we need to build this trick into the StandardScaler class itself. Not wanting to change the sklearn source code, we build our own class as a child of the StandardScaler class.
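To see the problem concretely in a pipeline context, here is a minimal sketch of the StandardScaler inside sklearn’s Pipeline (the step name 'scale' is just an arbitrary label):

from sklearn.pipeline import Pipeline

# Even wrapped in a Pipeline, the scaler swallows the DataFrame and
# emits a bare ndarray, so the column names 'a' and 'b' are lost
pipe = Pipeline([('scale', StandardScaler())])
type(pipe.fit_transform(df))
# numpy.ndarray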
Inheritance
The first thing we need to accomplish is to simply rebuild the StandardScaler class, preferably without having to wade through the source code. Turns out this is embarrassingly simple. All we have to do is create a class that inherits its methods from the StandardScaler; essentially, it will be our class that walks the walk of the StandardScaler. If you are unfamiliar with Python classes, it is probably a good idea to have a look at the documentation.
class foo(StandardScaler):
    pass

fnd = foo().fit_transform(df)
fnd
array([[ 0.        , -1.41421356],
       [ 1.22474487,  0.70710678],
       [-1.22474487,  0.70710678]])
Override
Now, what we want to do is change the transform method of the StandardScaler. This in itself is pretty easy, if we just want it to do something we define fully:
class foo(StandardScaler):
    def transform(self, x):
        print(type(x))

foo().fit_transform(df)
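<class 'pandas.core.frame.DataFrame'>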
Recursive override
But when we want to run the usual transform method and then do something at the end, we run into a problem: if we override transform with a method that itself calls transform, we create an infinite loop:
class foo(StandardScaler):
    def transform(self, x):
        # We are calling the very method we are defining:
        # fit_transform calls transform, which calls fit_transform, ...
        z = self.fit_transform(x)
        return DataFrame(z, index=x.index, columns=x.columns)

try:
    foo().fit_transform(df)
except Exception as e:
    print('Error: {}'.format(e))
That’s Python warning us that it is descending into an infinite loop and aborting: once the calls exceed the interpreter’s maximum recursion depth, a RecursionError is raised.
Hence, what we need to do is call the transform method as defined on the parent class, not the one defined on our own class. For this, we need the built-in super() function, which delegates method calls to a class in the instance’s ancestor tree. For our purposes, think of super() as handing us a generic instance of our parent class. For the interested reader, here’s an accessible intro.
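As a toy illustration of how super() delegates (hypothetical Base and Child classes, unrelated to sklearn):

class Base:
    def greet(self):
        return 'hello'

class Child(Base):
    def greet(self):
        # super() hands the call to Base.greet, letting us extend
        # the parent's behaviour instead of replacing it
        return super().greet() + ' from Child'

Child().greet()  # 'hello from Child'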
class foo(StandardScaler):
    def transform(self, x):
        z = super(foo, self).transform(x)
        return DataFrame(z, index=x.index, columns=x.columns)
dff = foo().fit_transform(df)
dff
| | a | b |
|---|---|---|
| 0 | 0.000000 | -1.414214 |
| 1 | 1.224745 | 0.707107 |
| 2 | -1.224745 | 0.707107 |
And we have our DataFrame again!
type(dff)
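pandas.core.frame.DataFrame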
Customizing parameters
Because we inherit from the StandardScaler, our foo class behaves in every other respect like the StandardScaler class. That means we are free to set any parameters that the StandardScaler accepts in our foo class.
dfs = foo(with_mean=False).fit_transform(df)
dfs
| | a | b |
|---|---|---|
| 0 | 1.224745 | 0.00000 |
| 1 | 2.449490 | 2.12132 |
| 2 | 0.000000 | 2.12132 |
The one thing the above class prevents us from doing is defining new parameters that we might need for our modification. For instance, we might want to control the dtype of the returned DataFrame. To achieve this, we need to modify the constructor (the __init__ method).
class foo(StandardScaler):
    def __init__(self, dtype='float', **kwargs):
        # Here, we call __init__ on the StandardScaler
        super().__init__(**kwargs)
        # Here, we can make our own modifications
        self.dtype = dtype

    def transform(self, X, y=None):
        z = super(foo, self).transform(X.values)
        return DataFrame(z, index=X.index, columns=X.columns, dtype=self.dtype)
dfi = foo(dtype='int').fit_transform(df)
dfi
| | a | b |
|---|---|---|
| 0 | 0 | -1 |
| 1 | 1 | 0 |
| 2 | -1 | 0 |
Note that the returned DataFrame now consists of integers, as opposed to floating points. The above code is the most parsimonious way to define a wrapper, but Scikit-learn does not take kindly to **kwargs in __init__: parameters passed through **kwargs are not picked up by get_params, so they are silently lost whenever an instance is cloned. If you plan on using your wrapper with other sklearn objects such as Pipeline or any grid search object (anything that will clone an instance of your wrapper), you need to remove **kwargs from the constructor:
class StandardScalerDf(StandardScaler):
    """DataFrame Wrapper around StandardScaler"""

    def __init__(self, copy=True, with_mean=True, with_std=True):
        super(StandardScalerDf, self).__init__(copy=copy,
                                               with_mean=with_mean,
                                               with_std=with_std)

    def transform(self, X, y=None):
        z = super(StandardScalerDf, self).transform(X.values)
        return DataFrame(z, index=X.index, columns=X.columns)
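As a quick sanity check (a sketch; the step name 'scale' is again an arbitrary label), we can verify that the wrapper survives cloning, which is exactly what Pipeline and the grid search objects do internally:

from sklearn.base import clone
from sklearn.pipeline import Pipeline

# clone() rebuilds an estimator from its get_params(), so explicit
# constructor arguments survive while **kwargs-style ones would not
scaler = clone(StandardScalerDf(with_mean=False))
print(scaler.with_mean)  # False: the parameter survives the clone

pipe = Pipeline([('scale', StandardScalerDf())])
print(type(pipe.fit_transform(df)))  # <class 'pandas.core.frame.DataFrame'>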
The StandardScalerDf works precisely like the original StandardScaler, with the one difference that it operates on pandas DataFrame objects instead. Of course, there are many ways to improve this class: ideally, we would want a class that accepts both an ndarray and a DataFrame, as well as finer control over the returned DataFrame. I leave that to you.
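If you want a head start on the first improvement, a minimal sketch might look like the following (the class name StandardScalerFlex is just a placeholder, and the imports from above are reused; one possible approach, not a full solution):

class StandardScalerFlex(StandardScaler):
    """Sketch: return a DataFrame for DataFrame input, an ndarray otherwise."""

    def transform(self, X, y=None):
        z = super().transform(np.asarray(X))
        if isinstance(X, DataFrame):
            return DataFrame(z, index=X.index, columns=X.columns)
        return z  # plain ndarray in, plain ndarray out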