Wednesday, February 5, 2020

How do I remove the duplication in these Keras metric definitions?

Keras provides accuracy, precision and recall metrics that you can use to evaluate your model, but these metrics can only evaluate the entire y_true and y_pred. I want them to evaluate only a subset of the data: in my data, y_true[..., 0:20] contains the binary values I want to evaluate, while y_true[..., 20:40] contains another kind of data.
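As a quick illustration of the slicing involved, here is a sketch assuming NumPy arrays (TensorFlow tensors accept the same [..., :n] indexing). The shape (2, 4, 4, 40) and the filler value 7.5 are made up; the point is that "..." keeps every leading axis and only the last (channel) axis is trimmed:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical labels: batch of 2, a 4x4 spatial grid, 40 channels on the last axis.
y_true = np.empty((2, 4, 4, 40), dtype="float32")
y_true[..., :20] = rng.integers(0, 2, size=(2, 4, 4, 20))  # channels 0-19: binary labels
y_true[..., 20:] = 7.5                                     # channels 20-39: other data

# "..." stands for all leading axes, so only the channel axis is sliced.
binary_part = y_true[..., 0:20]
other_part = y_true[..., 20:40]

print(binary_part.shape, other_part.shape)
```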

So I modified the accuracy, precision and recall classes to evaluate only the first 20 channels of my data. I did that by subclassing these metrics and having them slice the data before evaluation.

from tensorflow import keras as kr

class SliceBinaryAccuracy(kr.metrics.BinaryAccuracy):
    """Slice data before evaluating accuracy. To be used as Keras metric"""

    def __init__(self, channels, *args, **kwargs):
        self.channels = channels
        super().__init__(*args, **kwargs)

    def _slice(self, y):
        return y[..., : self.channels]

    def __call__(self, y_true, y_pred, *args, **kwargs):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        return super().__call__(y_true, y_pred, *args, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        super().update_state(y_true, y_pred, sample_weight=sample_weight)


class SlicePrecision(kr.metrics.Precision):
    """Slice data before evaluating precision. To be used as Keras metric"""

    def __init__(self, channels, *args, **kwargs):
        self.channels = channels
        super().__init__(*args, **kwargs)

    def _slice(self, y):
        return y[..., : self.channels]

    def __call__(self, y_true, y_pred, *args, **kwargs):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        return super().__call__(y_true, y_pred, *args, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        super().update_state(y_true, y_pred, sample_weight=sample_weight)


class SliceRecall(kr.metrics.Recall):
    """Slice data before evaluating recall. To be used as Keras metric"""

    def __init__(self, channels, *args, **kwargs):
        self.channels = channels
        super().__init__(*args, **kwargs)

    def _slice(self, y):
        return y[..., : self.channels]

    def __call__(self, y_true, y_pred, *args, **kwargs):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        return super().__call__(y_true, y_pred, *args, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        super().update_state(y_true, y_pred, sample_weight=sample_weight)

The classes above are used like this:

model.compile('adam', loss='mse', metrics=[SliceBinaryAccuracy(20), SlicePrecision(20), SliceRecall(20)])

The code works, but it is quite long, and the three metrics are almost entirely duplicated. How can I generalize these classes into a single class, or what would be a better design? Example code would be appreciated.
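One possible refactoring, offered as a sketch rather than a definitive answer: move the shared logic into a mixin (the name SliceMixin is hypothetical) and list it before the Keras metric in each subclass's bases. Python's MRO then runs the mixin's update_state first, and super() hands the sliced tensors to the real metric. This assumes that Metric.__call__ routes through update_state, which, as far as I know, is how both tf.keras and Keras 3 metrics behave, so the separate __call__ override should not be needed. The get_config override is an addition so that channels survives serialization, and the 4-channel toy data at the end is made up purely as a check.

```python
import numpy as np
from tensorflow import keras as kr


class SliceMixin:
    """Slice the last axis of y_true and y_pred before the wrapped metric sees them.

    Must come *before* the Keras metric class in the bases so that its
    update_state runs first and then delegates via super().
    """

    def __init__(self, channels, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.channels = channels

    def _slice(self, y):
        # Keep only the first `channels` entries of the last (channel) axis.
        return y[..., : self.channels]

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Metric.__call__ calls update_state, so this single override also
        # covers direct calls such as metric(y_true, y_pred).
        return super().update_state(
            self._slice(y_true), self._slice(y_pred), sample_weight=sample_weight
        )

    def get_config(self):
        # Record `channels` so the metric can be rebuilt from its config.
        config = super().get_config()
        config["channels"] = self.channels
        return config


class SliceBinaryAccuracy(SliceMixin, kr.metrics.BinaryAccuracy):
    """BinaryAccuracy on the first `channels` channels."""


class SlicePrecision(SliceMixin, kr.metrics.Precision):
    """Precision on the first `channels` channels."""


class SliceRecall(SliceMixin, kr.metrics.Recall):
    """Recall on the first `channels` channels."""


# Toy check: 4 channels, only the first 2 are the binary channels of interest.
y_true = np.array([[1.0, 0.0, 1.0, 0.0]], dtype="float32")
y_pred = np.array([[1.0, 0.0, 0.0, 1.0]], dtype="float32")

sliced_precision = SlicePrecision(2)
sliced_precision.update_state(y_true, y_pred)

full_precision = kr.metrics.Precision()
full_precision.update_state(y_true, y_pred)

# First 2 channels: 1 TP, 0 FP -> 1.0. All 4 channels: 1 TP, 1 FP -> 0.5.
print(float(sliced_precision.result()), float(full_precision.result()))
```

The call site stays the same: model.compile('adam', loss='mse', metrics=[SliceBinaryAccuracy(20), SlicePrecision(20), SliceRecall(20)]). Each additional sliced metric then costs a two-line class. If there are many of them, something like type(f"Slice{cls.__name__}", (SliceMixin, cls), {}) could generate the subclasses, although explicitly named classes are easier to pass through custom_objects when reloading a saved model.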
