Keras provides accuracy, precision and recall metrics that you can use to evaluate your model, but these metrics can only evaluate the entire y_true
and y_pred
. I want it to evaluate only the subset of the data. y_true[..., 0:20]
in my data contain binary values that I want to evaluate, but y_true[..., 20:40]
contain another kind of data.
So I modified the precision and recall classes to evaluate only on the first 20 channels of my data. I did that by subclassing these metrics and ask them to slice the data before evaluation.
from tensorflow import keras as kr
class SliceBinaryAccuracy(kr.metrics.BinaryAccuracy):
"""Slice data before evaluating accuracy. To be used as Keras metric"""
def __init__(self, channels, *args, **kwargs):
self.channels = channels
super().__init__(*args, **kwargs)
def _slice(self, y):
return y[..., : self.channels]
def __call__(self, y_true, y_pred, *args, **kwargs):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
return super().__call__(y_true, y_pred, *args, **kwargs)
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
super().update_state(y_true, y_pred, sample_weight=sample_weight)
class SlicePrecision(kr.metrics.Precision):
"""Slice data before evaluating precision. To be used as Keras metric"""
def __init__(self, channels, *args, **kwargs):
self.channels = channels
super().__init__(*args, **kwargs)
def _slice(self, y):
return y[..., : self.channels]
def __call__(self, y_true, y_pred, *args, **kwargs):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
return super().__call__(y_true, y_pred, *args, **kwargs)
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
super().update_state(y_true, y_pred, sample_weight=sample_weight)
class SliceRecall(kr.metrics.Recall):
"""Slice data before evaluating recall. To be used as Keras metric"""
def __init__(self, channels, *args, **kwargs):
self.channels = channels
super().__init__(*args, **kwargs)
def _slice(self, y):
return y[..., : self.channels]
def __call__(self, y_true, y_pred, *args, **kwargs):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
return super().__call__(y_true, y_pred, *args, **kwargs)
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
super().update_state(y_true, y_pred, sample_weight=sample_weight)
The way to use the above classes is like this:
model.compile('adam', loss='mse', metrics=[SliceBinaryAccuracy(20), SlicePrecision(20), SliceRecall(20)])
The code works but I found that the code is quite long. I see lots of duplications from these 3 metrics, how do I generalize these classes into a single class or whatever that is the better design? Please give an example code if possible.
Aucun commentaire:
Enregistrer un commentaire