Sunday, April 5, 2020

Appending structured data to a class attribute in Python

I have a couple of classes that I use to run numerical simulations. The minimal example below has two of them: 1) an Environment object with two state variables (x and y) that it simulates stochastically through time; and 2) a Simulation object that manages the simulation and saves the state of the Environment as it runs.

Within the Simulation object, I want to save the state of the Environment both 1) through time and 2) across multiple simulations. Within a single simulation I can use a defaultdict to save the state variables through time, but it's not clear to me how best to save the defaultdicts generated across simulations. If I append the defaultdict to a list without copying it, the list ends up containing the same defaultdict repeated, because the list stores references to the mutable defaultdict rather than snapshots of it. In the example below I use copy.copy, as the answer here suggests.
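
A small, self-contained illustration of the aliasing described above (the variable names are illustrative, not from the original code): appending one defaultdict and then mutating it in place makes every list element the same object, while rebinding the name to a fresh defaultdict for each run keeps the runs separate. It also shows that copy.copy is shallow, so the copy shares its inner lists with the original, whereas copy.deepcopy does not. Note that in the code below, self.data_states is rebound to a new defaultdict right after the append, which corresponds to the second case.

```python
import copy
from collections import defaultdict

# Case 1: one defaultdict is reused and mutated in place for every run.
shared = defaultdict(list)
aliased = []
for sim in range(3):
    shared.clear()                  # same object, emptied for the next run
    shared["x"].append(sim)
    aliased.append(shared)          # appends a reference, not a snapshot
print([dict(d) for d in aliased])   # [{'x': [2]}, {'x': [2]}, {'x': [2]}]

# Case 2: the name is rebound to a fresh defaultdict for every run.
fresh = []
for sim in range(3):
    run_states = defaultdict(list)  # new object each time
    run_states["x"].append(sim)
    fresh.append(run_states)
print([dict(d) for d in fresh])     # [{'x': [0]}, {'x': [1]}, {'x': [2]}]

# Case 3: copy.copy is shallow, so the copy shares the inner lists.
original = defaultdict(list)
original["x"].append(1)
shallow = copy.copy(original)
deep = copy.deepcopy(original)
original["x"].append(2)             # mutates the list that `shallow` also holds
print(shallow["x"], deep["x"])      # [1, 2] [1]
```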

Are there approaches that are more "Pythonic"? Would it be better to use an immutable type to store the defaultdicts for each simulation?
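
On the immutable-type question, one option is to freeze each finished run into a read-only snapshot. This is a sketch that assumes read-only data is what you want; freeze is a hypothetical helper, while types.MappingProxyType is the standard library's read-only mapping view.

```python
from collections import defaultdict
from types import MappingProxyType


def freeze(states):
    """Hypothetical helper: read-only snapshot of a dict of lists."""
    # Lists become tuples, and the new dict is only reachable through the proxy.
    return MappingProxyType({key: tuple(values) for key, values in states.items()})


states = defaultdict(list)
states["x"].extend([1.0, 1.1])
snapshot = freeze(states)

states["x"].append(1.2)  # later changes to the working dict...
print(snapshot["x"])     # (1.0, 1.1)  ...do not reach the snapshot

try:
    snapshot["x"] = ()   # the proxy rejects item assignment
except TypeError as err:
    print("read-only:", err)
```

Freezing mainly guards against accidental mutation later on. What keeps runs from sharing data is creating a new container for each run, not immutability as such.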

import copy
from collections import defaultdict
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt


class Environment(object):
    """
    Class representing a random walk of two variables x and y

    Methods
    -------
    start_simulation:   draw initial values of the state variables from their priors
    step:               add random noise to state variables
    current_state:      return current state of x and y in a dict

    """
    def __init__(self, mu1, sigma1, mu2, sigma2):
        self.mu1 = mu1
        self.mu2 = mu2
        self.sigma1 = sigma1
        self.sigma2 = sigma2

    def start_simulation(self):
        self.x = self.mu1 + self.sigma1 * np.random.randn()
        self.y = self.mu2 + self.sigma2 * np.random.randn()

    def step(self):
        self.x += self.sigma1 * np.random.randn()
        self.y += self.sigma2 * np.random.randn()

    def current_state(self):
        return {"x": self.x, "y": self.y}


class Simulation(object):
    """
    Class representing a simulation object for handling the Environment object
     and storing data

    Methods
    -------

    start_simulation:   start the simulation; initialise state of the environment
    simulate:           generate n_simulations simulations of n_timesteps time steps each
    save_state:         append the current state of the environment to data_states
    """
    def __init__(self, env, n_timesteps):
        self.env = env
        self.n_timesteps = n_timesteps

        self.data_all = []
        self.data_states = defaultdict(list)

    def start_simulation(self):
        self.timestep = 0
        self.env.start_simulation()

        # Append current data (if non empty)
        if self.data_states:
            self.data_all.append(copy.copy(self.data_states))  # <---------- this step
            # without copy.copy, every element of data_all would be the same
            # defaultdict at the end of all simulations, because the list stores
            # a reference to the (mutable) defaultdict rather than a copy of it

        # Reset data_states for the next simulation
        self.data_states = defaultdict(list)

    def simulate(self, n_simulations):
        """
        Run n_simulations simulations of n_timesteps time steps each
        """
        self.start_simulation()

        for self.simulation in range(n_simulations):

            self.timestep = 0

            while self.timestep < self.n_timesteps:
                self.env.step()
                self.save_state(self.env.current_state())
                self.timestep += 1

            self.start_simulation()


    def save_state(self, state):
        """
        Append each value in state to the corresponding list in data_states
        """
        for key, value in state.items():
            self.data_states[key].append(value)


if __name__ == "__main__":

    # Run 7 simulations, each for 20 time steps
    N_TIME = 20
    N_SIM = 7

    e = Environment(
        mu1=1.4, sigma1=0.1,
        mu2=2.6, sigma2=0.05)

    s = Simulation(env=e, n_timesteps=N_TIME)
    s.simulate(N_SIM)

    # Plot output
    fig, ax = plt.subplots()
    for var, c in zip(["x", "y"], ["#D55E00", "#009E73"]):
        for d in s.data_all:
            ax.plot(pd.DataFrame(d)[var], label=var, color=c)
    ax.set_xlabel("Time")
    ax.set_ylabel("Value")
    plt.show()
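
One possible restructuring, sketched as an illustration rather than a definitive answer: each run builds its own local defaultdict inside a function and returns it, so every element of data_all is a separate object and no copying is needed. The runs can then be combined into a single pandas DataFrame indexed by (simulation, time). run_once, the free function simulate and the rng parameter are hypothetical additions. Environment is a condensed copy of the class in the question, switched to a seeded numpy.random.Generator, an assumption made for reproducibility.

```python
from collections import defaultdict

import numpy as np
import pandas as pd


class Environment:
    """Condensed copy of the question's random walk of x and y."""

    def __init__(self, mu1, sigma1, mu2, sigma2, rng=None):
        self.mu1, self.sigma1 = mu1, sigma1
        self.mu2, self.sigma2 = mu2, sigma2
        # Assumption: a numpy Generator is injected so results are reproducible.
        self.rng = rng if rng is not None else np.random.default_rng()

    def start_simulation(self):
        self.x = self.mu1 + self.sigma1 * self.rng.standard_normal()
        self.y = self.mu2 + self.sigma2 * self.rng.standard_normal()

    def step(self):
        self.x += self.sigma1 * self.rng.standard_normal()
        self.y += self.sigma2 * self.rng.standard_normal()

    def current_state(self):
        return {"x": self.x, "y": self.y}


def run_once(env, n_timesteps):
    """Run one simulation and return a brand-new dict of lists for it."""
    env.start_simulation()
    states = defaultdict(list)  # local, so each call gets its own object
    for _ in range(n_timesteps):
        env.step()
        for key, value in env.current_state().items():
            states[key].append(value)
    return states


def simulate(env, n_timesteps, n_simulations):
    """Return one independent dict of lists per simulation; no copying needed."""
    return [run_once(env, n_timesteps) for _ in range(n_simulations)]


env = Environment(mu1=1.4, sigma1=0.1, mu2=2.6, sigma2=0.05,
                  rng=np.random.default_rng(0))
data_all = simulate(env, n_timesteps=20, n_simulations=7)

# Combine all runs into one long table indexed by (simulation, time).
tidy = pd.concat([pd.DataFrame(run) for run in data_all],
                 keys=list(range(len(data_all))), names=["simulation", "time"])
print(tidy.shape)  # (140, 2)
```

A single run can then be selected with tidy.loc[3], and one variable across all runs with tidy["x"].unstack("simulation"), which gives one column per simulation.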
