jeudi 9 novembre 2023

State pattern - determining the next state from multiple options

I am implementing a noise gate in Python, using the state design pattern.

My implementation takes an array of audio samples and, using the parameters of the noise gate, the audio sample magnitude values, and the state of the noise gate, determines a coefficient value in the range [0, 1] which should be multiplied with the current audio sample value.

The states I have defined are OpenState, ClosedState, OpeningState and ClosingState. I believe the image below contains all of the state transitions I need to consider.

State transitions

When the gate is in ClosingState, there are two possible transitions:

  1. ClosingState -> ClosedState - this occurs if the release period elapses without another peak exceeding the threshold during that time.
  2. ClosingState -> OpenState - this occurs if a peak exceeds the threshold at some point during the release period.

The part of my code that decides which state to transition to is this method inside the ClosingState class.

def check_if_state_transition_is_due(self, sample_mag: float=0) -> bool:
    '''
    Decide whether ClosingState should hand over to another state.

    There are two possible successors, checked in priority order:

      1. OpenState   - the sample magnitude exceeds the gate threshold.
      2. ClosedState - the release period has elapsed.

    When a transition is due, the successor is stored in self.new_state and
    self.transition_pending is set.

    Returns True if a transition is due, otherwise False (never None, so the
    result can be stored directly as a boolean flag).
    '''
    # Re-opening is tested first, so a peak on the very last release sample
    # still re-opens the gate instead of closing it.
    if sample_mag > self.context.lin_thresh:
        self.transition_pending = True
        self.new_state = OpenState()
        return True

    # The counter is zero-based, hence the "- 1" on the release period.
    if self.sample_counter >= self.context.release_period_in_samples-1:
        self.transition_pending = True
        self.new_state = ClosedState()
        return True

    return False

My question is simply whether it is OK to use these conditionals to determine which state to transition to. It feels like re-introducing the type of code that using the state pattern gets rid of, but an alternative is not obvious to me.


Below is a minimal example. This may not be needed for my conceptual question above, but I am including it in case it is. The sample audio file can be found here.

SO_ramp_functions.py

import numpy as np

def ramp_linear_increase(num_points):
    ''' Return num_points evenly spaced coefficients rising from 0 to 1 (both ends included) '''
    rising_ramp = np.linspace(start=0, stop=1, num=num_points)
    return rising_ramp

def ramp_linear_decrease(num_points):
    ''' Return num_points evenly spaced coefficients falling from 1 to 0 (both ends included) '''
    falling_ramp = np.linspace(start=1, stop=0, num=num_points)
    return falling_ramp

def ramp_poly_increase(num_points):
    '''
    Generate an array of coefficient values for the attack period.

    The curve is 1 - (x/num_points)**4 with x counting down from num_points
    to 1, so it rises quickly at first and flattens out towards 1.

    num_points: number of coefficients to generate.
    Returns a float array of length num_points; an empty array when
    num_points is zero or negative (there is no ramp to build).
    '''
    # Without this guard the endpoint assignments below raise IndexError
    # on an empty array.
    if num_points <= 0:
        return np.zeros(0)

    x = np.arange(num_points, 0, -1)
    attack_coef_arr = 1 - (x/num_points)**4

    # Make sure the start and end are 0 and 1, respectively
    # (the polynomial alone stops just short of 1 on the last sample)
    attack_coef_arr[0] = 0
    attack_coef_arr[-1] = 1

    return attack_coef_arr


def ramp_poly_decrease(num_points):
    '''
    Generate an array of coefficient values for the release period.

    The curve is 1 - (x/num_points)**4 with x counting up from 0 to
    num_points - 1, so it stays near 1 at first and then drops off.

    num_points: number of coefficients to generate.
    Returns a float array of length num_points; an empty array when
    num_points is zero or negative (there is no ramp to build).
    '''
    # Without this guard the endpoint assignments below raise IndexError
    # on an empty array.
    if num_points <= 0:
        return np.zeros(0)

    x = np.arange(num_points)
    release_coef_arr = 1 - (x/num_points)**4

    # Make sure the start and end are 1 and 0, respectively
    # (the polynomial alone does not reach 0 on the last sample)
    release_coef_arr[0] = 1
    release_coef_arr[-1] = 0

    return release_coef_arr

SO_gate_states.py

from abc import ABC, abstractmethod


class State(ABC):
    """
    The base State class declares methods that all concrete States should
    implement and also provides a backreference to the Context object,
    associated with the State. This backreference can be used by States to
    transition the Context to another State.

    The Context drives a state once per audio sample by calling
    get_sample_coefficient() and then handle_state_transition(), so both are
    part of the required interface.
    """

    @property
    def context(self):
        """The Context this state is attached to (set by the Context itself)."""
        return self._context


    @context.setter
    def context(self, context) -> None:
        self._context = context


    @abstractmethod
    def get_sample_coefficient(self, sample_mag: float) -> float:
        """Return the coefficient to multiply with the current audio sample."""
        pass


    @abstractmethod
    def check_if_state_transition_is_due(self, sample_mag: float=None) -> bool:
        """Return True if this state should be left after the current sample."""
        pass


    @abstractmethod
    def handle_state_transition(self) -> None:
        """Switch the Context to the next state if a transition is pending."""
        pass


    @abstractmethod
    def on_entry(self):
        """Hook for work to do when the state is entered."""
        pass


    @abstractmethod
    def on_exit(self):
        """Hook for work to do when the state is left."""
        pass


"""
Concrete States implement various behaviors, associated with a state of the
Context.
"""

class ClosedState(State):
    '''
    - In ClosedState, the coefficient is always 0.0.
    - The only state we can transition to from ClosedState is OpeningState.
    '''

    def __init__(self):
        # Samples processed since this state was entered (incremented by the Context)
        self.sample_counter = 0
        # Set while processing a sample; acted on in handle_state_transition()
        self.transition_pending = False


    def get_sample_coefficient(self, sample_mag: float=0) -> float:
        '''
        Get the appropriate coefficient value to multiply with the current
        audio sample value.

        In the closed state, the coefficient is always 0.0. As a side effect,
        the transition check is run and its result stored in
        self.transition_pending.
        '''
        self.transition_pending = self.check_if_state_transition_is_due(sample_mag)
        return 0.0


    def check_if_state_transition_is_due(self, sample_mag: float=0) -> bool:
        '''
        Check if a condition is met that initiates a transition.
        For ClosedState, that is the sample magnitude exceeding the threshold.
        '''
        return sample_mag > self.context.lin_thresh


    def on_entry(self):
        pass


    def on_exit(self):
        pass


    def handle_state_transition(self):
        ''' Move the gate to OpeningState if a transition is pending. '''
        if self.transition_pending:
            self.context.transition_to(OpeningState())
            # Run the exit hook like the other states do; it is a no-op here
            self.on_exit()


class OpeningState(State):
    '''
    - In OpeningState, the coefficient is determined by the shape of the
        specified attack ramp.

    - The only state we can transition to from OpeningState is OpenState.
    '''

    def __init__(self):
        # Samples processed since this state was entered (incremented by the Context)
        self.sample_counter = 0
        # Set while processing a sample; acted on in handle_state_transition()
        self.transition_pending = False


    def get_sample_coefficient(self, sample_mag: float=0) -> float:
        '''
        Return the attack ramp value for the current sample, or 1.0 once the
        attack period has elapsed. sample_mag is not used in this state.

        As a side effect, the transition check is run and its result stored
        in self.transition_pending.
        '''
        self.transition_pending = self.check_if_state_transition_is_due()
        if self.transition_pending:
            return 1.0

        # Get a value from the gate's attack ramp
        return self.context.attack_ramp[self.sample_counter]


    def check_if_state_transition_is_due(self, sample_mag: float=0) -> bool:
        '''
        Transition to OpenState occurs once the attack period has elapsed.
        '''
        # The attack period can be changed on the Context after the ramp was
        # built, so also stop at the end of the ramp; that way the ramp lookup
        # in get_sample_coefficient() can never run past the last element.
        attack_length = min(self.context.attack_period_in_samples,
                            len(self.context.attack_ramp))
        return self.sample_counter >= attack_length


    def handle_state_transition(self):
        ''' Move the gate to OpenState if a transition is pending. '''
        if self.transition_pending:
            self.context.transition_to(OpenState())
            self.on_exit()


    def on_entry(self):
        pass


    def on_exit(self):
        # This may not be needed, since we construct a new instance when
        # transitioning, but it may make it more robust
        self.sample_counter = 0
    

class OpenState(State):
    '''
    In OpenState, the coefficient is always 1.0.
    The only state we can transition to from OpenState is ClosingState.
    '''

    def __init__(self):
        # Samples processed since this state was entered (incremented by the Context)
        self.sample_counter = 0
        # Set while processing a sample; acted on in handle_state_transition()
        self.transition_pending = False


    def get_sample_coefficient(self, sample_mag: float=0) -> float:
        '''
        Return 1.0 (the gate is fully open). As a side effect, the transition
        check is run and its result stored in self.transition_pending.
        '''
        self.transition_pending = self.check_if_state_transition_is_due(sample_mag)
        return 1.0


    def check_if_state_transition_is_due(self, sample_mag: float=0) -> bool:
        '''
        Return True when the gate should start closing: the hold period has
        elapsed AND the signal magnitude is below the threshold.
        '''
        # The gate can't transition before its hold period has elapsed
        hold_elapsed = self.sample_counter >= self.context.hold_period_in_samples
        return hold_elapsed and sample_mag < self.context.lin_thresh


    def on_entry(self):
        pass


    def on_exit(self):
        # This may not be needed, since we construct a new instance when
        # transitioning, but it may make it more robust
        self.sample_counter = 0


    def handle_state_transition(self):
        ''' Move the gate to ClosingState if a transition is pending. '''
        if self.transition_pending:
            self.context.transition_to(ClosingState())
            self.on_exit()
    

class ClosingState(State):
    '''
    - The coefficient is determined by the shape of the specified release ramp.
    - The state can transition to either ClosedState or OpenState.
    '''

    def __init__(self):
        # Samples processed since this state was entered (incremented by the Context)
        self.sample_counter = 0
        # Set while processing a sample; acted on in handle_state_transition()
        self.transition_pending = False
        # Successor chosen by check_if_state_transition_is_due()
        self.new_state = None


    def get_sample_coefficient(self, sample_mag: float=0) -> float:
        '''
        Return the release ramp value for the current sample. As a side
        effect, the transition check is run and its result stored in
        self.transition_pending.
        '''
        self.transition_pending = self.check_if_state_transition_is_due(sample_mag)

        release_ramp = self.context.release_ramp
        if self.sample_counter < len(release_ramp):
            return release_ramp[self.sample_counter]

        # Ramp exhausted (e.g. an empty ramp): treat the gate as fully
        # closed rather than indexing past the end.
        return 0.0


    def check_if_state_transition_is_due(self, sample_mag: float=0) -> bool:
        '''
        There are two possible states that we can transition to from
        ClosingState, checked in priority order:

          1. OpenState   - the sample magnitude exceeds the threshold.
          2. ClosedState - the release period has elapsed.

        The chosen successor is stored in self.new_state.
        Returns True if a transition is due, otherwise False.
        '''
        # Re-opening is tested first, so a peak on the last release sample
        # still re-opens the gate.
        if sample_mag > self.context.lin_thresh:
            self.transition_pending = True
            self.new_state = OpenState()
            return True

        # The release period can be changed on the Context after the ramp was
        # built, so never count past the end of the ramp. The counter is
        # zero-based, hence the "- 1".
        release_length = min(self.context.release_period_in_samples,
                             len(self.context.release_ramp))
        if self.sample_counter >= release_length - 1:
            self.transition_pending = True
            self.new_state = ClosedState()
            return True

        return False


    def handle_state_transition(self):
        ''' Move the gate to the chosen successor if a transition is pending. '''
        if self.transition_pending:
            self.context.transition_to(self.new_state)
            self.on_exit()


    def on_entry(self):
        pass


    def on_exit(self):
        # This may not be needed, since we construct a new instance when
        # transitioning, but it may make it more robust
        self.sample_counter = 0


SO_noise_gate_state_pattern.py

import numpy as np
import SO_ramp_functions as rf

'''
The original template code is found here:
    https://refactoring.guru/design-patterns/state/python/example
'''

class AudioConfig:
    '''
    Values that configure audio playback, so they can be set independently
    of, and shared between, different objects that need them.

    fs is stored unchanged on the instance; elsewhere in this file it is
    multiplied by a time in seconds to obtain a number of samples.
    '''
    def __init__(self, fs):
        self.fs = fs


class Context:
    """
    This class represents the noise gate.

    The Context defines the interface of interest to clients. It also maintains
    a reference to an instance of a State subclass, which represents the current
    state of the Context.

    After process_audio_block() has run, the results are available as
    coef_array (gate coefficients), mag_array (input magnitudes) and
    processed_array (gated audio).
    """

    def __init__(self, audio_config, state) -> None:
        self.audio_config = audio_config
        self.transition_to(state)

        # Specify an initial threshold value in dBFS
        self.thresh = -20

        # Specify attack, hold, release, and lookahead periods in seconds
        self.attack_time = 0.005  # seconds
        self.hold_time = 0.05  # seconds
        self.release_time = 0.1  # seconds
        self.lookahead_time = 0.005 # seconds

        # Calculate attack, hold, and release periods in samples
        fs = self.audio_config.fs
        self.attack_period_in_samples = self.seconds_to_samples(fs, self.attack_time)
        self.hold_period_in_samples = self.seconds_to_samples(fs, self.hold_time)
        self.release_period_in_samples = self.seconds_to_samples(fs, self.release_time)
        self.lookahead_period_in_samples = self.seconds_to_samples(fs, self.lookahead_time)

        # Define the attack and release multiplier ramps - use strategy pattern?
        self.attack_ramp = rf.ramp_poly_increase(num_points=self.attack_period_in_samples)
        self.release_ramp = rf.ramp_poly_decrease(num_points=self.release_period_in_samples)

        # Initialise attributes to store the processed result
        self.processed_array = None
        self.coef_array = None
        self.mag_array = None

        # Padding to enable lookahead (a bit of a hack)
        self.lookahead_pad_samples = self.lookahead_period_in_samples#2000

        # Attributes for debugging
        self.text_output = []


    def transition_to(self, state):
        """
        The Context allows changing the State object at runtime.
        """

        ##print(f"Context: Transition to {type(state).__name__}")
        self._state = state
        self._state.context = self


    # Setters for gate parameters
    def set_attack_time(self, new_attack_time: float) -> None:
        """ Set the attack time (seconds) and rebuild the attack ramp to match. """
        self.attack_time = new_attack_time
        self.attack_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.attack_time)
        # The ramp length must follow the period, otherwise the states would
        # index a ramp of the old length.
        self.attack_ramp = rf.ramp_poly_increase(num_points=self.attack_period_in_samples)


    def set_hold_time(self, new_hold_time: float) -> None:
        """ Set the hold time (seconds) and update the hold period in samples. """
        self.hold_time = new_hold_time
        self.hold_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.hold_time)


    def set_release_time(self, new_release_time: float) -> None:
        """ Set the release time (seconds) and rebuild the release ramp to match. """
        self.release_time = new_release_time
        self.release_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.release_time)
        # The ramp length must follow the period, otherwise the states would
        # index a ramp of the old length.
        self.release_ramp = rf.ramp_poly_decrease(num_points=self.release_period_in_samples)


    @property
    def thresh(self) -> int:
        """ The gate threshold in dBFS. """
        return self._thresh


    @thresh.setter
    def thresh(self, new_thresh: int) -> None:
        self._thresh = new_thresh
        # The states read lin_thresh for every sample, so convert once here
        # instead of on every access.
        self._lin_thresh = self.dBFS_to_lin(new_thresh)


    @property
    def lin_thresh(self) -> float:
        """ The gate threshold as a linear value (derived from thresh). """
        return self._lin_thresh


    # These staticmethods could equally be defined outside the class
    @staticmethod
    def dBFS_to_lin(dBFS_val):
        ''' Helper method to convert a dBFS value to a linear value [0, 1] '''
        return 10 ** (dBFS_val / 20)


    @staticmethod
    def seconds_to_samples(fs, seconds_val):
        ''' Helper method to convert a time (seconds) value to a number of samples '''
        # int() truncates towards zero, so fractional samples are dropped
        return int(fs * seconds_val)


    def process_audio_block(self, audio_array=None):
        '''
        Process an array of audio samples according to the gate's parameters,
        current state, and the sample values in the audio array.
        This implementation includes lookahead logic.

        audio_array: array of audio samples whose last lookahead_pad_samples
            entries are padding. They are only read ahead of time and produce
            no output sample.

        Results are stored in coef_array, mag_array and processed_array; the
        output arrays are lookahead_pad_samples shorter than audio_array.

        Raises TypeError if audio_array is None, and ValueError if the
        lookahead period is longer than the padding.
        '''
        if audio_array is None:
            raise TypeError("process_audio_block() requires an array of audio samples")

        pad = self.lookahead_pad_samples
        lookahead = self.lookahead_period_in_samples
        if lookahead > pad:
            # Reading ahead further than the padding would index past the
            # end of the array on the final samples.
            raise ValueError("lookahead_period_in_samples must not exceed lookahead_pad_samples")

        # Number of real (non-padding) samples. Computed explicitly because a
        # slice of the form [:-pad] is empty when pad is 0.
        num_samples = max(len(audio_array) - pad, 0)

        # Initialise an array of coefficient values, one per output sample
        # Set initial coefficient values outside valid range [0, 1] for easier debugging
        self.coef_array = np.ones(num_samples) * 2
        # Get the magnitude values of the audio array
        self.mag_array = np.abs(audio_array)

        # Iterate through the samples, updating coef_array values
        for i in range(num_samples):
            # Get the coefficient value for the current sample, considering a lookahead period
            self.coef_array[i] = self._state.get_sample_coefficient(self.mag_array[i + lookahead])
            # Increment the counter for tracking the samples elapsed in the current state
            self._state.sample_counter += 1
            # Create a log of the state and samples elapsed, for debugging
            self.text_output.append(f"{type(self._state).__name__}. {self._state.sample_counter}. {self.coef_array[i]:.3f}")
            # After processing the current sample, check if a transition is due
            self._state.handle_state_transition()

        self.processed_array = self.coef_array * audio_array[:num_samples]

main.py

'''
Driver code for the noise gate using the state pattern.

'''

from SO_noise_gate_state_pattern import AudioConfig, Context
from SO_gate_states import ClosedState
import numpy as np
import audiofile
import matplotlib.pyplot as plt
import time


# Define some helper/test functions
def load_audio(fpath):
    """
    Read the audio file at fpath and return its samples as an array.

    The sample rate returned by audiofile.read() is discarded. The array is
    transposed and, if it is two-dimensional, only its first column is kept
    (the author's intent: convert to mono).
    """
    samples, _sample_rate = audiofile.read(fpath)
    samples = samples.T
    if samples.ndim != 2:
        return samples
    # convert to mono
    return samples[:, 0]


def test_gate_coef_values_are_valid(coef_arr):
    """
    Print a banner, then assert that every value in coef_arr lies in the
    closed interval [0, 1].
    """
    print("Testing gate coef_array values")
    within_range = [0 <= coef <= 1 for coef in coef_arr]
    assert np.all(within_range)


# Script entry point: gate a test recording, time the processing, sanity-check
# the coefficients and plot the result.
if __name__ == "__main__":
    
    # The client code.
    # Configure some audio properties
    # NOTE(review): fs is hard-coded here, while load_audio() discards the
    # sample rate read from the file - confirm the file really is 44.1 kHz.
    audio_config = AudioConfig(fs=44100)
    
    # Create a "context" instance (this is like the NoiseGate class)
    # The gate starts in the closed state.
    context = Context(audio_config, ClosedState())
    
    # Load audio from file
    sig = load_audio(fpath="./snare_test.wav")
    # Zero-pad the audio array to enable lookahead (experimental)
    # process_audio_block() reads ahead into this padding and trims it from its output.
    sig = np.concatenate((sig, np.zeros(context.lookahead_pad_samples)))
    
    # Process the whole array and time it
    start_time = time.perf_counter()
    context.process_audio_block(sig)
    end_time = time.perf_counter()
    # len(sig) includes the lookahead padding, so the reported audio duration does too
    print(f"Time taken to process {len(sig)/audio_config.fs:.2f} seconds of audio: {end_time - start_time:.2f} seconds")
    
    # Some testing on the result
    # Raises AssertionError if any coefficient is outside [0, 1]
    test_gate_coef_values_are_valid(context.coef_array)
    
    # Plot the result
    # mag_array still contains the padding, so it is longer than the other two traces
    plt.plot(context.mag_array, color='blue', linewidth=1, label='signal magnitude')
    plt.plot(context.coef_array, color='green', label='gate coefficient')
    plt.plot(np.abs(context.processed_array), color='orange', label='gate output')
    plt.axhline(context.lin_thresh, color='black', linewidth=1, label='gate threshold')
    plt.legend()
    plt.show()

Aucun commentaire:

Enregistrer un commentaire