From f6b2d1ada3a9743024e146bbae32c57413e92aca Mon Sep 17 00:00:00 2001 From: Michael Butler Date: Sat, 19 Dec 2020 14:44:35 -0800 Subject: [PATCH] Relocate ExecutionBurst* classes to NN util code The only changes when copying these files were .clang-format differences and correcting a typo in a comment. Bug: 177267324 Test: mma Change-Id: I96cc2402642e1e3076ac7e78e06163c1d3d41701 Merged-In: I96cc2402642e1e3076ac7e78e06163c1d3d41701 (cherry picked from commit 87e83068784b65ab851e4ff65a1099de4e777c9e) --- neuralnetworks/1.2/utils/Android.bp | 1 + .../nnapi/hal/1.2/ExecutionBurstController.h | 345 ++++++++++ .../nnapi/hal/1.2/ExecutionBurstServer.h | 343 ++++++++++ .../utils/src/ExecutionBurstController.cpp | 631 +++++++++++++++++ .../1.2/utils/src/ExecutionBurstServer.cpp | 646 ++++++++++++++++++ 5 files changed, 1966 insertions(+) create mode 100644 neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstController.h create mode 100644 neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstServer.h create mode 100644 neuralnetworks/1.2/utils/src/ExecutionBurstController.cpp create mode 100644 neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp diff --git a/neuralnetworks/1.2/utils/Android.bp b/neuralnetworks/1.2/utils/Android.bp index 0fec41c240..695905690e 100644 --- a/neuralnetworks/1.2/utils/Android.bp +++ b/neuralnetworks/1.2/utils/Android.bp @@ -18,6 +18,7 @@ cc_library_static { name: "neuralnetworks_utils_hal_1_2", defaults: ["neuralnetworks_utils_defaults"], srcs: ["src/*"], + exclude_srcs: ["src/ExecutionBurst*"], local_include_dirs: ["include/nnapi/hal/1.2/"], export_include_dirs: ["include"], cflags: ["-Wthread-safety"], diff --git a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstController.h b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstController.h new file mode 100644 index 0000000000..e00ab82d69 --- /dev/null +++ b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstController.h @@ -0,0 +1,345 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_CONTROLLER_H +#define ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_CONTROLLER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace android::nn { + +/** + * Number of elements in the FMQ. + */ +constexpr const size_t kExecutionBurstChannelLength = 1024; + +/** + * Function to serialize a request. + * + * Prefer calling RequestChannelSender::send. + * + * @param request Request object without the pool information. + * @param measure Whether to collect timing information for the execution. + * @param memoryIds Slot identifiers corresponding to memory resources for the + * request. + * @return Serialized FMQ request data. + */ +std::vector serialize( + const hardware::neuralnetworks::V1_0::Request& request, + hardware::neuralnetworks::V1_2::MeasureTiming measure, const std::vector& slots); + +/** + * Deserialize the FMQ result data. + * + * The three resulting fields are the status of the execution, the dynamic + * shapes of the output tensors, and the timing information of the execution. + * + * @param data Serialized FMQ result data. + * @return Result object if successfully deserialized, std::nullopt otherwise. + */ +std::optional, + hardware::neuralnetworks::V1_2::Timing>> +deserialize(const std::vector& data); + +/** + * Convert result code to error status. + * + * @param resultCode Result code to be converted. + * @return ErrorStatus Resultant error status. + */ +hardware::neuralnetworks::V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode); + +/** + * ResultChannelReceiver is responsible for waiting on the channel until the + * packet is available, extracting the packet from the channel, and + * deserializing the packet. + * + * Because the receiver can wait on a packet that may never come (e.g., because + * the sending side of the packet has been closed), this object can be + * invalidated, unblocking the receiver. + */ +class ResultChannelReceiver { + using FmqResultDescriptor = + hardware::MQDescriptorSync; + using FmqResultChannel = hardware::MessageQueue; + + public: + /** + * Create the receiving end of a result channel. + * + * Prefer this call over the constructor. + * + * @param channelLength Number of elements in the FMQ. + * @param pollingTimeWindow How much time (in microseconds) the + * ResultChannelReceiver is allowed to poll the FMQ before waiting on + * the blocking futex. Polling may result in lower latencies at the + * potential cost of more power usage. + * @return A pair of ResultChannelReceiver and the FMQ descriptor on + * successful creation, both nullptr otherwise. + */ + static std::pair, const FmqResultDescriptor*> create( + size_t channelLength, std::chrono::microseconds pollingTimeWindow); + + /** + * Get the result from the channel. + * + * This method will block until either: + * 1) The packet has been retrieved, or + * 2) The receiver has been invalidated + * + * @return Result object if successfully received, std::nullopt if error or + * if the receiver object was invalidated. + */ + std::optional, + hardware::neuralnetworks::V1_2::Timing>> + getBlocking(); + + /** + * Method to mark the channel as invalid, unblocking any current or future + * calls to ResultChannelReceiver::getBlocking. + */ + void invalidate(); + + // prefer calling ResultChannelReceiver::getBlocking + std::optional> getPacketBlocking(); + + ResultChannelReceiver(std::unique_ptr fmqResultChannel, + std::chrono::microseconds pollingTimeWindow); + + private: + const std::unique_ptr mFmqResultChannel; + std::atomic mValid{true}; + const std::chrono::microseconds kPollingTimeWindow; +}; + +/** + * RequestChannelSender is responsible for serializing the result packet of + * information, sending it on the result channel, and signaling that the data is + * available. + */ +class RequestChannelSender { + using FmqRequestDescriptor = + hardware::MQDescriptorSync; + using FmqRequestChannel = + hardware::MessageQueue; + + public: + /** + * Create the sending end of a request channel. + * + * Prefer this call over the constructor. + * + * @param channelLength Number of elements in the FMQ. + * @return A pair of ResultChannelReceiver and the FMQ descriptor on + * successful creation, both nullptr otherwise. + */ + static std::pair, const FmqRequestDescriptor*> create( + size_t channelLength); + + /** + * Send the request to the channel. + * + * @param request Request object without the pool information. + * @param measure Whether to collect timing information for the execution. + * @param memoryIds Slot identifiers corresponding to memory resources for + * the request. + * @return 'true' on successful send, 'false' otherwise. + */ + bool send(const hardware::neuralnetworks::V1_0::Request& request, + hardware::neuralnetworks::V1_2::MeasureTiming measure, + const std::vector& slots); + + /** + * Method to mark the channel as invalid, causing all future calls to + * RequestChannelSender::send to immediately return false without attempting + * to send a message across the FMQ. + */ + void invalidate(); + + // prefer calling RequestChannelSender::send + bool sendPacket(const std::vector& packet); + + RequestChannelSender(std::unique_ptr fmqRequestChannel); + + private: + const std::unique_ptr mFmqRequestChannel; + std::atomic mValid{true}; +}; + +/** + * The ExecutionBurstController class manages both the serialization and + * deserialization of data across FMQ, making it appear to the runtime as a + * regular synchronous inference. Additionally, this class manages the burst's + * memory cache. + */ +class ExecutionBurstController { + DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstController); + + public: + /** + * NN runtime burst callback object and memory cache. + * + * ExecutionBurstCallback associates a hidl_memory object with a slot number + * to be passed across FMQ. The ExecutionBurstServer can use this callback + * to retrieve this hidl_memory corresponding to the slot via HIDL. + * + * Whenever a hidl_memory object is copied, it will duplicate the underlying + * file descriptor. Because the NN runtime currently copies the hidl_memory + * on each execution, it is difficult to associate hidl_memory objects with + * previously cached hidl_memory objects. For this reason, callers of this + * class must pair each hidl_memory object with an associated key. For + * efficiency, if two hidl_memory objects represent the same underlying + * buffer, they must use the same key. + */ + class ExecutionBurstCallback : public hardware::neuralnetworks::V1_2::IBurstCallback { + DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback); + + public: + ExecutionBurstCallback() = default; + + hardware::Return getMemories(const hardware::hidl_vec& slots, + getMemories_cb cb) override; + + /** + * This function performs one of two different actions: + * 1) If a key corresponding to a memory resource is unrecognized by the + * ExecutionBurstCallback object, the ExecutionBurstCallback object + * will allocate a slot, bind the memory to the slot, and return the + * slot identifier. + * 2) If a key corresponding to a memory resource is recognized by the + * ExecutionBurstCallback object, the ExecutionBurstCallback object + * will return the existing slot identifier. + * + * @param memories Memory resources used in an inference. + * @param keys Unique identifiers where each element corresponds to a + * memory resource element in "memories". + * @return Unique slot identifiers where each returned slot element + * corresponds to a memory resource element in "memories". + */ + std::vector getSlots(const hardware::hidl_vec& memories, + const std::vector& keys); + + /* + * This function performs two different actions: + * 1) Removes an entry from the cache (if present), including the local + * storage of the hidl_memory object. Note that this call does not + * free any corresponding hidl_memory object in ExecutionBurstServer, + * which is separately freed via IBurstContext::freeMemory. + * 2) Return whether a cache entry was removed and which slot was removed if + * found. If the key did not to correspond to any entry in the cache, a + * slot number of 0 is returned. The slot number and whether the entry + * existed is useful so the same slot can be freed in the + * ExecutionBurstServer's cache via IBurstContext::freeMemory. + */ + std::pair freeMemory(intptr_t key); + + private: + int32_t getSlotLocked(const hardware::hidl_memory& memory, intptr_t key); + int32_t allocateSlotLocked(); + + std::mutex mMutex; + std::stack> mFreeSlots; + std::map mMemoryIdToSlot; + std::vector mMemoryCache; + }; + + /** + * Creates a burst controller on a prepared model. + * + * Prefer this over ExecutionBurstController's constructor. + * + * @param preparedModel Model prepared for execution to execute on. + * @param pollingTimeWindow How much time (in microseconds) the + * ExecutionBurstController is allowed to poll the FMQ before waiting on + * the blocking futex. Polling may result in lower latencies at the + * potential cost of more power usage. + * @return ExecutionBurstController Execution burst controller object. + */ + static std::unique_ptr create( + const sp& preparedModel, + std::chrono::microseconds pollingTimeWindow); + + // prefer calling ExecutionBurstController::create + ExecutionBurstController(const std::shared_ptr& requestChannelSender, + const std::shared_ptr& resultChannelReceiver, + const sp& burstContext, + const sp& callback, + const sp& deathHandler = nullptr); + + // explicit destructor to unregister the death recipient + ~ExecutionBurstController(); + + /** + * Execute a request on a model. + * + * @param request Arguments to be executed on a model. + * @param measure Whether to collect timing measurements, either YES or NO + * @param memoryIds Identifiers corresponding to each memory object in the + * request's pools. + * @return A tuple of: + * - result code of the execution + * - dynamic output shapes from the execution + * - any execution time measurements of the execution + * - whether or not a failed burst execution should be re-run using a + * different path (e.g., IPreparedModel::executeSynchronously) + */ + std::tuple, + hardware::neuralnetworks::V1_2::Timing, bool> + compute(const hardware::neuralnetworks::V1_0::Request& request, + hardware::neuralnetworks::V1_2::MeasureTiming measure, + const std::vector& memoryIds); + + /** + * Propagate a user's freeing of memory to the service. + * + * @param key Key corresponding to the memory object. + */ + void freeMemory(intptr_t key); + + private: + std::mutex mMutex; + const std::shared_ptr mRequestChannelSender; + const std::shared_ptr mResultChannelReceiver; + const sp mBurstContext; + const sp mMemoryCache; + const sp mDeathHandler; +}; + +} // namespace android::nn + +#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_CONTROLLER_H diff --git a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstServer.h b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstServer.h new file mode 100644 index 0000000000..2c7d6540de --- /dev/null +++ b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstServer.h @@ -0,0 +1,343 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_SERVER_H +#define ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_SERVER_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace android::nn { + +using FmqRequestDescriptor = + hardware::MQDescriptorSync; +using FmqResultDescriptor = + hardware::MQDescriptorSync; + +/** + * Function to serialize results. + * + * Prefer calling ResultChannelSender::send. + * + * @param errorStatus Status of the execution. + * @param outputShapes Dynamic shapes of the output tensors. + * @param timing Timing information of the execution. + * @return Serialized FMQ result data. + */ +std::vector serialize( + hardware::neuralnetworks::V1_0::ErrorStatus errorStatus, + const std::vector& outputShapes, + hardware::neuralnetworks::V1_2::Timing timing); + +/** + * Deserialize the FMQ request data. + * + * The three resulting fields are the Request object (where Request::pools is + * empty), slot identifiers (which are stand-ins for Request::pools), and + * whether timing information must be collected for the run. + * + * @param data Serialized FMQ request data. + * @return Request object if successfully deserialized, std::nullopt otherwise. + */ +std::optional, + hardware::neuralnetworks::V1_2::MeasureTiming>> +deserialize(const std::vector& data); + +/** + * RequestChannelReceiver is responsible for waiting on the channel until the + * packet is available, extracting the packet from the channel, and + * deserializing the packet. + * + * Because the receiver can wait on a packet that may never come (e.g., because + * the sending side of the packet has been closed), this object can be + * invalidated, unblocking the receiver. + */ +class RequestChannelReceiver { + using FmqRequestChannel = + hardware::MessageQueue; + + public: + /** + * Create the receiving end of a request channel. + * + * Prefer this call over the constructor. + * + * @param requestChannel Descriptor for the request channel. + * @param pollingTimeWindow How much time (in microseconds) the + * RequestChannelReceiver is allowed to poll the FMQ before waiting on + * the blocking futex. Polling may result in lower latencies at the + * potential cost of more power usage. + * @return RequestChannelReceiver on successful creation, nullptr otherwise. + */ + static std::unique_ptr create( + const FmqRequestDescriptor& requestChannel, + std::chrono::microseconds pollingTimeWindow); + + /** + * Get the request from the channel. + * + * This method will block until either: + * 1) The packet has been retrieved, or + * 2) The receiver has been invalidated + * + * @return Request object if successfully received, std::nullopt if error or + * if the receiver object was invalidated. + */ + std::optional, + hardware::neuralnetworks::V1_2::MeasureTiming>> + getBlocking(); + + /** + * Method to mark the channel as invalid, unblocking any current or future + * calls to RequestChannelReceiver::getBlocking. + */ + void invalidate(); + + RequestChannelReceiver(std::unique_ptr fmqRequestChannel, + std::chrono::microseconds pollingTimeWindow); + + private: + std::optional> getPacketBlocking(); + + const std::unique_ptr mFmqRequestChannel; + std::atomic mTeardown{false}; + const std::chrono::microseconds kPollingTimeWindow; +}; + +/** + * ResultChannelSender is responsible for serializing the result packet of + * information, sending it on the result channel, and signaling that the data is + * available. + */ +class ResultChannelSender { + using FmqResultChannel = hardware::MessageQueue; + + public: + /** + * Create the sending end of a result channel. + * + * Prefer this call over the constructor. + * + * @param resultChannel Descriptor for the result channel. + * @return ResultChannelSender on successful creation, nullptr otherwise. + */ + static std::unique_ptr create(const FmqResultDescriptor& resultChannel); + + /** + * Send the result to the channel. + * + * @param errorStatus Status of the execution. + * @param outputShapes Dynamic shapes of the output tensors. + * @param timing Timing information of the execution. + * @return 'true' on successful send, 'false' otherwise. + */ + bool send(hardware::neuralnetworks::V1_0::ErrorStatus errorStatus, + const std::vector& outputShapes, + hardware::neuralnetworks::V1_2::Timing timing); + + // prefer calling ResultChannelSender::send + bool sendPacket(const std::vector& packet); + + ResultChannelSender(std::unique_ptr fmqResultChannel); + + private: + const std::unique_ptr mFmqResultChannel; +}; + +/** + * The ExecutionBurstServer class is responsible for waiting for and + * deserializing a request object from a FMQ, performing the inference, and + * serializing the result back across another FMQ. + */ +class ExecutionBurstServer : public hardware::neuralnetworks::V1_2::IBurstContext { + DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstServer); + + public: + /** + * IBurstExecutorWithCache is a callback object passed to + * ExecutionBurstServer's factory function that is used to perform an + * execution. Because some memory resources are needed across multiple + * executions, this object also contains a local cache that can directly be + * used in the execution. + * + * ExecutionBurstServer will never access its IBurstExecutorWithCache object + * with concurrent calls. + */ + class IBurstExecutorWithCache { + DISALLOW_COPY_AND_ASSIGN(IBurstExecutorWithCache); + + public: + IBurstExecutorWithCache() = default; + virtual ~IBurstExecutorWithCache() = default; + + /** + * Checks if a cache entry specified by a slot is present in the cache. + * + * @param slot Identifier of the cache entry. + * @return 'true' if the cache entry is present in the cache, 'false' + * otherwise. + */ + virtual bool isCacheEntryPresent(int32_t slot) const = 0; + + /** + * Adds an entry specified by a slot to the cache. + * + * The caller of this function must ensure that the cache entry that is + * being added is not already present in the cache. This can be checked + * via isCacheEntryPresent. + * + * @param memory Memory resource to be cached. + * @param slot Slot identifier corresponding to the memory resource. + */ + virtual void addCacheEntry(const hardware::hidl_memory& memory, int32_t slot) = 0; + + /** + * Removes an entry specified by a slot from the cache. + * + * If the cache entry corresponding to the slot number does not exist, + * the call does nothing. + * + * @param slot Slot identifier corresponding to the memory resource. + */ + virtual void removeCacheEntry(int32_t slot) = 0; + + /** + * Perform an execution. + * + * @param request Request object with inputs and outputs specified. + * Request::pools is empty, and DataLocation::poolIndex instead + * refers to the 'slots' argument as if it were Request::pools. + * @param slots Slots corresponding to the cached memory entries to be + * used. + * @param measure Whether timing information is requested for the + * execution. + * @return Result of the execution, including the status of the + * execution, dynamic output shapes, and any timing information. + */ + virtual std::tuple, + hardware::neuralnetworks::V1_2::Timing> + execute(const hardware::neuralnetworks::V1_0::Request& request, + const std::vector& slots, + hardware::neuralnetworks::V1_2::MeasureTiming measure) = 0; + }; + + /** + * Create automated context to manage FMQ-based executions. + * + * This function is intended to be used by a service to automatically: + * 1) Receive data from a provided FMQ + * 2) Execute a model with the given information + * 3) Send the result to the created FMQ + * + * @param callback Callback used to retrieve memories corresponding to + * unrecognized slots. + * @param requestChannel Input FMQ channel through which the client passes the + * request to the service. + * @param resultChannel Output FMQ channel from which the client can retrieve + * the result of the execution. + * @param executorWithCache Object which maintains a local cache of the + * memory pools and executes using the cached memory pools. + * @param pollingTimeWindow How much time (in microseconds) the + * ExecutionBurstServer is allowed to poll the FMQ before waiting on + * the blocking futex. Polling may result in lower latencies at the + * potential cost of more power usage. + * @result IBurstContext Handle to the burst context. + */ + static sp create( + const sp& callback, + const FmqRequestDescriptor& requestChannel, const FmqResultDescriptor& resultChannel, + std::shared_ptr executorWithCache, + std::chrono::microseconds pollingTimeWindow = std::chrono::microseconds{0}); + + /** + * Create automated context to manage FMQ-based executions. + * + * This function is intended to be used by a service to automatically: + * 1) Receive data from a provided FMQ + * 2) Execute a model with the given information + * 3) Send the result to the created FMQ + * + * @param callback Callback used to retrieve memories corresponding to + * unrecognized slots. + * @param requestChannel Input FMQ channel through which the client passes the + * request to the service. + * @param resultChannel Output FMQ channel from which the client can retrieve + * the result of the execution. + * @param preparedModel PreparedModel that the burst object was created from. + * IPreparedModel::executeSynchronously will be used to perform the + * execution. + * @param pollingTimeWindow How much time (in microseconds) the + * ExecutionBurstServer is allowed to poll the FMQ before waiting on + * the blocking futex. Polling may result in lower latencies at the + * potential cost of more power usage. + * @result IBurstContext Handle to the burst context. + */ + static sp create( + const sp& callback, + const FmqRequestDescriptor& requestChannel, const FmqResultDescriptor& resultChannel, + hardware::neuralnetworks::V1_2::IPreparedModel* preparedModel, + std::chrono::microseconds pollingTimeWindow = std::chrono::microseconds{0}); + + ExecutionBurstServer(const sp& callback, + std::unique_ptr requestChannel, + std::unique_ptr resultChannel, + std::shared_ptr cachedExecutor); + ~ExecutionBurstServer(); + + // Used by the NN runtime to preemptively remove any stored memory. + hardware::Return freeMemory(int32_t slot) override; + + private: + // Ensures all cache entries contained in mExecutorWithCache are present in + // the cache. If they are not present, they are retrieved (via + // IBurstCallback::getMemories) and added to mExecutorWithCache. + // + // This method is locked via mMutex when it is called. + void ensureCacheEntriesArePresentLocked(const std::vector& slots); + + // Work loop that will continue processing execution requests until the + // ExecutionBurstServer object is freed. + void task(); + + std::thread mWorker; + std::mutex mMutex; + std::atomic mTeardown{false}; + const sp mCallback; + const std::unique_ptr mRequestChannelReceiver; + const std::unique_ptr mResultChannelSender; + const std::shared_ptr mExecutorWithCache; +}; + +} // namespace android::nn + +#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_SERVER_H diff --git a/neuralnetworks/1.2/utils/src/ExecutionBurstController.cpp b/neuralnetworks/1.2/utils/src/ExecutionBurstController.cpp new file mode 100644 index 0000000000..212863e183 --- /dev/null +++ b/neuralnetworks/1.2/utils/src/ExecutionBurstController.cpp @@ -0,0 +1,631 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define LOG_TAG "ExecutionBurstController" + +#include "ExecutionBurstController.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "HalInterfaces.h" +#include "Tracing.h" +#include "Utils.h" + +namespace android::nn { +namespace { + +using V1_2::FmqRequestDatum; +using V1_2::FmqResultDatum; +using V1_2::IBurstCallback; +using V1_2::IBurstContext; +using FmqRequestDescriptor = hardware::MQDescriptorSync; +using FmqResultDescriptor = hardware::MQDescriptorSync; + +constexpr V1_2::Timing kNoTiming12 = {std::numeric_limits::max(), + std::numeric_limits::max()}; + +class BurstContextDeathHandler : public hardware::hidl_death_recipient { + public: + using Callback = std::function; + + BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) { + CHECK(onDeathCallback != nullptr); + } + + void serviceDied(uint64_t /*cookie*/, const wp& /*who*/) override { + LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!"; + mOnDeathCallback(); + } + + private: + const Callback mOnDeathCallback; +}; + +} // anonymous namespace + +// serialize a request into a packet +std::vector serialize(const V1_0::Request& request, V1_2::MeasureTiming measure, + const std::vector& slots) { + // count how many elements need to be sent for a request + size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size(); + for (const auto& input : request.inputs) { + count += input.dimensions.size(); + } + for (const auto& output : request.outputs) { + count += output.dimensions.size(); + } + + // create buffer to temporarily store elements + std::vector data; + data.reserve(count); + + // package packetInfo + { + FmqRequestDatum datum; + datum.packetInformation( + {/*.packetSize=*/static_cast(count), + /*.numberOfInputOperands=*/static_cast(request.inputs.size()), + /*.numberOfOutputOperands=*/static_cast(request.outputs.size()), + /*.numberOfPools=*/static_cast(request.pools.size())}); + data.push_back(datum); + } + + // package input data + for (const auto& input : request.inputs) { + // package operand information + FmqRequestDatum datum; + datum.inputOperandInformation( + {/*.hasNoValue=*/input.hasNoValue, + /*.location=*/input.location, + /*.numberOfDimensions=*/static_cast(input.dimensions.size())}); + data.push_back(datum); + + // package operand dimensions + for (uint32_t dimension : input.dimensions) { + FmqRequestDatum datum; + datum.inputOperandDimensionValue(dimension); + data.push_back(datum); + } + } + + // package output data + for (const auto& output : request.outputs) { + // package operand information + FmqRequestDatum datum; + datum.outputOperandInformation( + {/*.hasNoValue=*/output.hasNoValue, + /*.location=*/output.location, + /*.numberOfDimensions=*/static_cast(output.dimensions.size())}); + data.push_back(datum); + + // package operand dimensions + for (uint32_t dimension : output.dimensions) { + FmqRequestDatum datum; + datum.outputOperandDimensionValue(dimension); + data.push_back(datum); + } + } + + // package pool identifier + for (int32_t slot : slots) { + FmqRequestDatum datum; + datum.poolIdentifier(slot); + data.push_back(datum); + } + + // package measureTiming + { + FmqRequestDatum datum; + datum.measureTiming(measure); + data.push_back(datum); + } + + // return packet + return data; +} + +// deserialize a packet into the result +std::optional, V1_2::Timing>> +deserialize(const std::vector& data) { + using discriminator = FmqResultDatum::hidl_discriminator; + + std::vector outputShapes; + size_t index = 0; + + // validate packet information + if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) { + LOG(ERROR) << "FMQ Result packet ill-formed"; + return std::nullopt; + } + + // unpackage packet information + const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation(); + index++; + const uint32_t packetSize = packetInfo.packetSize; + const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus; + const uint32_t numberOfOperands = packetInfo.numberOfOperands; + + // verify packet size + if (data.size() != packetSize) { + LOG(ERROR) << "FMQ Result packet ill-formed"; + return std::nullopt; + } + + // unpackage operands + for (size_t operand = 0; operand < numberOfOperands; ++operand) { + // validate operand information + if (data[index].getDiscriminator() != discriminator::operandInformation) { + LOG(ERROR) << "FMQ Result packet ill-formed"; + return std::nullopt; + } + + // unpackage operand information + const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation(); + index++; + const bool isSufficient = operandInfo.isSufficient; + const uint32_t numberOfDimensions = operandInfo.numberOfDimensions; + + // unpackage operand dimensions + std::vector dimensions; + dimensions.reserve(numberOfDimensions); + for (size_t i = 0; i < numberOfDimensions; ++i) { + // validate dimension + if (data[index].getDiscriminator() != discriminator::operandDimensionValue) { + LOG(ERROR) << "FMQ Result packet ill-formed"; + return std::nullopt; + } + + // unpackage dimension + const uint32_t dimension = data[index].operandDimensionValue(); + index++; + + // store result + dimensions.push_back(dimension); + } + + // store result + outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient}); + } + + // validate execution timing + if (data[index].getDiscriminator() != discriminator::executionTiming) { + LOG(ERROR) << "FMQ Result packet ill-formed"; + return std::nullopt; + } + + // unpackage execution timing + const V1_2::Timing timing = data[index].executionTiming(); + index++; + + // validate packet information + if (index != packetSize) { + LOG(ERROR) << "FMQ Result packet ill-formed"; + return std::nullopt; + } + + // return result + return std::make_tuple(errorStatus, std::move(outputShapes), timing); +} + +V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode) { + return convertToV1_0(convertResultCodeToErrorStatus(resultCode)); +} + +std::pair, const FmqResultDescriptor*> +ResultChannelReceiver::create(size_t channelLength, std::chrono::microseconds pollingTimeWindow) { + std::unique_ptr fmqResultChannel = + std::make_unique(channelLength, /*confEventFlag=*/true); + if (!fmqResultChannel->isValid()) { + LOG(ERROR) << "Unable to create ResultChannelReceiver"; + return {nullptr, nullptr}; + } + + const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc(); + return std::make_pair( + std::make_unique(std::move(fmqResultChannel), pollingTimeWindow), + descriptor); +} + +ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr fmqResultChannel, + std::chrono::microseconds pollingTimeWindow) + : mFmqResultChannel(std::move(fmqResultChannel)), kPollingTimeWindow(pollingTimeWindow) {} + +std::optional, V1_2::Timing>> +ResultChannelReceiver::getBlocking() { + const auto packet = getPacketBlocking(); + if (!packet) { + return std::nullopt; + } + + return deserialize(*packet); +} + +void ResultChannelReceiver::invalidate() { + mValid = false; + + // force unblock + // ExecutionBurstController waits on a result packet after sending a + // request. If the driver containing ExecutionBurstServer crashes, the + // controller may be waiting on the futex. This force unblock wakes up any + // thread waiting on the futex. + // TODO: look for a different/better way to signal/notify the futex to + // wake up any thread waiting on it + FmqResultDatum datum; + datum.packetInformation({/*.packetSize=*/0, + /*.errorStatus=*/V1_0::ErrorStatus::GENERAL_FAILURE, + /*.numberOfOperands=*/0}); + mFmqResultChannel->writeBlocking(&datum, 1); +} + +std::optional> ResultChannelReceiver::getPacketBlocking() { + if (!mValid) { + return std::nullopt; + } + + // First spend time polling if results are available in FMQ instead of + // waiting on the futex. Polling is more responsive (yielding lower + // latencies), but can take up more power, so only poll for a limited period + // of time. + + auto& getCurrentTime = std::chrono::high_resolution_clock::now; + const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow; + + while (getCurrentTime() < timeToStopPolling) { + // if class is being torn down, immediately return + if (!mValid.load(std::memory_order_relaxed)) { + return std::nullopt; + } + + // Check if data is available. If it is, immediately retrieve it and + // return. + const size_t available = mFmqResultChannel->availableToRead(); + if (available > 0) { + std::vector packet(available); + const bool success = mFmqResultChannel->read(packet.data(), available); + if (!success) { + LOG(ERROR) << "Error receiving packet"; + return std::nullopt; + } + return std::make_optional(std::move(packet)); + } + } + + // If we get to this point, we either stopped polling because it was taking + // too long or polling was not allowed. Instead, perform a blocking call + // which uses a futex to save power. + + // wait for result packet and read first element of result packet + FmqResultDatum datum; + bool success = mFmqResultChannel->readBlocking(&datum, 1); + + // retrieve remaining elements + // NOTE: all of the data is already available at this point, so there's no + // need to do a blocking wait to wait for more data. This is known because + // in FMQ, all writes are published (made available) atomically. Currently, + // the producer always publishes the entire packet in one function call, so + // if the first element of the packet is available, the remaining elements + // are also available. + const size_t count = mFmqResultChannel->availableToRead(); + std::vector packet(count + 1); + std::memcpy(&packet.front(), &datum, sizeof(datum)); + success &= mFmqResultChannel->read(packet.data() + 1, count); + + if (!mValid) { + return std::nullopt; + } + + // ensure packet was successfully received + if (!success) { + LOG(ERROR) << "Error receiving packet"; + return std::nullopt; + } + + return std::make_optional(std::move(packet)); +} + +std::pair, const FmqRequestDescriptor*> +RequestChannelSender::create(size_t channelLength) { + std::unique_ptr fmqRequestChannel = + std::make_unique(channelLength, /*confEventFlag=*/true); + if (!fmqRequestChannel->isValid()) { + LOG(ERROR) << "Unable to create RequestChannelSender"; + return {nullptr, nullptr}; + } + + const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc(); + return std::make_pair(std::make_unique(std::move(fmqRequestChannel)), + descriptor); +} + +RequestChannelSender::RequestChannelSender(std::unique_ptr fmqRequestChannel) + : mFmqRequestChannel(std::move(fmqRequestChannel)) {} + +bool RequestChannelSender::send(const V1_0::Request& request, V1_2::MeasureTiming measure, + const std::vector& slots) { + const std::vector serialized = serialize(request, measure, slots); + return sendPacket(serialized); +} + +bool RequestChannelSender::sendPacket(const std::vector& packet) { + if (!mValid) { + return false; + } + + if (packet.size() > mFmqRequestChannel->availableToWrite()) { + LOG(ERROR) + << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ"; + return false; + } + + // Always send the packet with "blocking" because this signals the futex and + // unblocks the consumer if it is waiting on the futex. + return mFmqRequestChannel->writeBlocking(packet.data(), packet.size()); +} + +void RequestChannelSender::invalidate() { + mValid = false; +} + +hardware::Return ExecutionBurstController::ExecutionBurstCallback::getMemories( + const hardware::hidl_vec& slots, getMemories_cb cb) { + std::lock_guard guard(mMutex); + + // get all memories + hardware::hidl_vec memories(slots.size()); + std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) { + return slot < mMemoryCache.size() ? mMemoryCache[slot] : hardware::hidl_memory{}; + }); + + // ensure all memories are valid + if (!std::all_of(memories.begin(), memories.end(), + [](const hardware::hidl_memory& memory) { return memory.valid(); })) { + cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {}); + return hardware::Void(); + } + + // return successful + cb(V1_0::ErrorStatus::NONE, std::move(memories)); + return hardware::Void(); +} + +std::vector ExecutionBurstController::ExecutionBurstCallback::getSlots( + const hardware::hidl_vec& memories, + const std::vector& keys) { + std::lock_guard guard(mMutex); + + // retrieve (or bind) all slots corresponding to memories + std::vector slots; + slots.reserve(memories.size()); + for (size_t i = 0; i < memories.size(); ++i) { + slots.push_back(getSlotLocked(memories[i], keys[i])); + } + return slots; +} + +std::pair ExecutionBurstController::ExecutionBurstCallback::freeMemory( + intptr_t key) { + std::lock_guard guard(mMutex); + + auto iter = mMemoryIdToSlot.find(key); + if (iter == mMemoryIdToSlot.end()) { + return {false, 0}; + } + const int32_t slot = iter->second; + mMemoryIdToSlot.erase(key); + mMemoryCache[slot] = {}; + mFreeSlots.push(slot); + return {true, slot}; +} + +int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked( + const hardware::hidl_memory& memory, intptr_t key) { + auto iter = mMemoryIdToSlot.find(key); + if (iter == mMemoryIdToSlot.end()) { + const int32_t slot = allocateSlotLocked(); + mMemoryIdToSlot[key] = slot; + mMemoryCache[slot] = memory; + return slot; + } else { + const int32_t slot = iter->second; + return slot; + } +} + +int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() { + constexpr size_t kMaxNumberOfSlots = std::numeric_limits::max(); + + // if there is a free slot, use it + if (mFreeSlots.size() > 0) { + const int32_t slot = mFreeSlots.top(); + mFreeSlots.pop(); + return slot; + } + + // otherwise use a slot for the first time + CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!"; + const int32_t slot = static_cast(mMemoryCache.size()); + mMemoryCache.emplace_back(); + + return slot; +} + +std::unique_ptr ExecutionBurstController::create( + const sp& preparedModel, + std::chrono::microseconds pollingTimeWindow) { + // check inputs + if (preparedModel == nullptr) { + LOG(ERROR) << "ExecutionBurstController::create passed a nullptr"; + return nullptr; + } + + // create callback object + sp callback = new ExecutionBurstCallback(); + + // create FMQ objects + auto [requestChannelSenderTemp, requestChannelDescriptor] = + RequestChannelSender::create(kExecutionBurstChannelLength); + auto [resultChannelReceiverTemp, resultChannelDescriptor] = + ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow); + std::shared_ptr requestChannelSender = + std::move(requestChannelSenderTemp); + std::shared_ptr resultChannelReceiver = + std::move(resultChannelReceiverTemp); + + // check FMQ objects + if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor || + !resultChannelDescriptor) { + LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue"; + return nullptr; + } + + // configure burst + V1_0::ErrorStatus errorStatus; + sp burstContext; + const hardware::Return ret = preparedModel->configureExecutionBurst( + callback, *requestChannelDescriptor, *resultChannelDescriptor, + [&errorStatus, &burstContext](V1_0::ErrorStatus status, + const sp& context) { + errorStatus = status; + burstContext = context; + }); + + // check burst + if (!ret.isOk()) { + LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description " + << ret.description(); + return nullptr; + } + if (errorStatus != V1_0::ErrorStatus::NONE) { + LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status " + << toString(errorStatus); + return nullptr; + } + if (burstContext == nullptr) { + LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst"; + return nullptr; + } + + // create death handler object + BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender, + resultChannelReceiver] { + requestChannelSender->invalidate(); + resultChannelReceiver->invalidate(); + }; + const sp deathHandler = new BurstContextDeathHandler(onDeathCallback); + + // linkToDeath registers a callback that will be invoked on service death to + // proactively handle service crashes. If the linkToDeath call fails, + // asynchronous calls are susceptible to hangs if the service crashes before + // providing the response. + const hardware::Return deathHandlerRet = burstContext->linkToDeath(deathHandler, 0); + if (!deathHandlerRet.isOk() || deathHandlerRet != true) { + LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient " + "for the IBurstContext object."; + return nullptr; + } + + // make and return controller + return std::make_unique(requestChannelSender, resultChannelReceiver, + burstContext, callback, deathHandler); +} + +ExecutionBurstController::ExecutionBurstController( + const std::shared_ptr& requestChannelSender, + const std::shared_ptr& resultChannelReceiver, + const sp& burstContext, const sp& callback, + const sp& deathHandler) + : mRequestChannelSender(requestChannelSender), + mResultChannelReceiver(resultChannelReceiver), + mBurstContext(burstContext), + mMemoryCache(callback), + mDeathHandler(deathHandler) {} + +ExecutionBurstController::~ExecutionBurstController() { + // It is safe to ignore any errors resulting from this unlinkToDeath call + // because the ExecutionBurstController object is already being destroyed + // and its underlying IBurstContext object is no longer being used by the NN + // runtime. + if (mDeathHandler) { + mBurstContext->unlinkToDeath(mDeathHandler).isOk(); + } +} + +static std::tuple, V1_2::Timing, bool> getExecutionResult( + V1_0::ErrorStatus status, std::vector outputShapes, V1_2::Timing timing, + bool fallback) { + auto [n, checkedOutputShapes, checkedTiming] = + getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing); + return {n, convertToV1_2(checkedOutputShapes), convertToV1_2(checkedTiming), fallback}; +} + +std::tuple, V1_2::Timing, bool> +ExecutionBurstController::compute(const V1_0::Request& request, V1_2::MeasureTiming measure, + const std::vector& memoryIds) { + // This is the first point when we know an execution is occurring, so begin + // to collect systraces. Note that the first point we can begin collecting + // systraces in ExecutionBurstServer is when the RequestChannelReceiver + // realizes there is data in the FMQ, so ExecutionBurstServer collects + // systraces at different points in the code. + NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute"); + + std::lock_guard guard(mMutex); + + // send request packet + const std::vector slots = mMemoryCache->getSlots(request.pools, memoryIds); + const bool success = mRequestChannelSender->send(request, measure, slots); + if (!success) { + LOG(ERROR) << "Error sending FMQ packet"; + // only use fallback execution path if the packet could not be sent + return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12, + /*fallback=*/true); + } + + // get result packet + const auto result = mResultChannelReceiver->getBlocking(); + if (!result) { + LOG(ERROR) << "Error retrieving FMQ packet"; + // only use fallback execution path if the packet could not be sent + return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12, + /*fallback=*/false); + } + + // unpack results and return (only use fallback execution path if the + // packet could not be sent) + auto [status, outputShapes, timing] = std::move(*result); + return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false); +} + +void ExecutionBurstController::freeMemory(intptr_t key) { + std::lock_guard guard(mMutex); + + bool valid; + int32_t slot; + std::tie(valid, slot) = mMemoryCache->freeMemory(key); + if (valid) { + mBurstContext->freeMemory(slot).isOk(); + } +} + +} // namespace android::nn diff --git a/neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp b/neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp new file mode 100644 index 0000000000..848c77b284 --- /dev/null +++ b/neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp @@ -0,0 +1,646 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define LOG_TAG "ExecutionBurstServer" + +#include "ExecutionBurstServer.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "HalInterfaces.h" +#include "Tracing.h" + +namespace android::nn { +namespace { + +using hardware::MQDescriptorSync; +using V1_2::FmqRequestDatum; +using V1_2::FmqResultDatum; +using V1_2::IBurstCallback; +using V1_2::IBurstContext; + +constexpr V1_2::Timing kNoTiming = {std::numeric_limits::max(), + std::numeric_limits::max()}; + +// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be +// used as an IBurstExecutorWithCache. Specifically, the cache simply stores the +// hidl_memory object, and the execution forwards calls to the provided +// IPreparedModel's "executeSynchronously" method. With this class, hidl_memory +// must be mapped and unmapped for each execution. +class DefaultBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache { + public: + DefaultBurstExecutorWithCache(V1_2::IPreparedModel* preparedModel) + : mpPreparedModel(preparedModel) {} + + bool isCacheEntryPresent(int32_t slot) const override { + const auto it = mMemoryCache.find(slot); + return (it != mMemoryCache.end()) && it->second.valid(); + } + + void addCacheEntry(const hardware::hidl_memory& memory, int32_t slot) override { + mMemoryCache[slot] = memory; + } + + void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); } + + std::tuple, V1_2::Timing> execute( + const V1_0::Request& request, const std::vector& slots, + V1_2::MeasureTiming measure) override { + // convert slots to pools + hardware::hidl_vec pools(slots.size()); + std::transform(slots.begin(), slots.end(), pools.begin(), + [this](int32_t slot) { return mMemoryCache[slot]; }); + + // create full request + V1_0::Request fullRequest = request; + fullRequest.pools = std::move(pools); + + // setup execution + V1_0::ErrorStatus returnedStatus = V1_0::ErrorStatus::GENERAL_FAILURE; + hardware::hidl_vec returnedOutputShapes; + V1_2::Timing returnedTiming; + auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming]( + V1_0::ErrorStatus status, + const hardware::hidl_vec& outputShapes, + const V1_2::Timing& timing) { + returnedStatus = status; + returnedOutputShapes = outputShapes; + returnedTiming = timing; + }; + + // execute + const hardware::Return ret = + mpPreparedModel->executeSynchronously(fullRequest, measure, cb); + if (!ret.isOk() || returnedStatus != V1_0::ErrorStatus::NONE) { + LOG(ERROR) << "IPreparedModelAdapter::execute -- Error executing"; + return {returnedStatus, std::move(returnedOutputShapes), kNoTiming}; + } + + return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming); + } + + private: + V1_2::IPreparedModel* const mpPreparedModel; + std::map mMemoryCache; +}; + +} // anonymous namespace + +// serialize result +std::vector serialize(V1_0::ErrorStatus errorStatus, + const std::vector& outputShapes, + V1_2::Timing timing) { + // count how many elements need to be sent for a request + size_t count = 2 + outputShapes.size(); + for (const auto& outputShape : outputShapes) { + count += outputShape.dimensions.size(); + } + + // create buffer to temporarily store elements + std::vector data; + data.reserve(count); + + // package packetInfo + { + FmqResultDatum datum; + datum.packetInformation({/*.packetSize=*/static_cast(count), + /*.errorStatus=*/errorStatus, + /*.numberOfOperands=*/static_cast(outputShapes.size())}); + data.push_back(datum); + } + + // package output shape data + for (const auto& operand : outputShapes) { + // package operand information + FmqResultDatum::OperandInformation info{}; + info.isSufficient = operand.isSufficient; + info.numberOfDimensions = static_cast(operand.dimensions.size()); + + FmqResultDatum datum; + datum.operandInformation(info); + data.push_back(datum); + + // package operand dimensions + for (uint32_t dimension : operand.dimensions) { + FmqResultDatum datum; + datum.operandDimensionValue(dimension); + data.push_back(datum); + } + } + + // package executionTiming + { + FmqResultDatum datum; + datum.executionTiming(timing); + data.push_back(datum); + } + + // return result + return data; +} + +// deserialize request +std::optional, V1_2::MeasureTiming>> deserialize( + const std::vector& data) { + using discriminator = FmqRequestDatum::hidl_discriminator; + + size_t index = 0; + + // validate packet information + if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) { + LOG(ERROR) << "FMQ Request packet ill-formed"; + return std::nullopt; + } + + // unpackage packet information + const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation(); + index++; + const uint32_t packetSize = packetInfo.packetSize; + const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands; + const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands; + const uint32_t numberOfPools = packetInfo.numberOfPools; + + // verify packet size + if (data.size() != packetSize) { + LOG(ERROR) << "FMQ Request packet ill-formed"; + return std::nullopt; + } + + // unpackage input operands + std::vector inputs; + inputs.reserve(numberOfInputOperands); + for (size_t operand = 0; operand < numberOfInputOperands; ++operand) { + // validate input operand information + if (data[index].getDiscriminator() != discriminator::inputOperandInformation) { + LOG(ERROR) << "FMQ Request packet ill-formed"; + return std::nullopt; + } + + // unpackage operand information + const FmqRequestDatum::OperandInformation& operandInfo = + data[index].inputOperandInformation(); + index++; + const bool hasNoValue = operandInfo.hasNoValue; + const V1_0::DataLocation location = operandInfo.location; + const uint32_t numberOfDimensions = operandInfo.numberOfDimensions; + + // unpackage operand dimensions + std::vector dimensions; + dimensions.reserve(numberOfDimensions); + for (size_t i = 0; i < numberOfDimensions; ++i) { + // validate dimension + if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) { + LOG(ERROR) << "FMQ Request packet ill-formed"; + return std::nullopt; + } + + // unpackage dimension + const uint32_t dimension = data[index].inputOperandDimensionValue(); + index++; + + // store result + dimensions.push_back(dimension); + } + + // store result + inputs.push_back( + {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions}); + } + + // unpackage output operands + std::vector outputs; + outputs.reserve(numberOfOutputOperands); + for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) { + // validate output operand information + if (data[index].getDiscriminator() != discriminator::outputOperandInformation) { + LOG(ERROR) << "FMQ Request packet ill-formed"; + return std::nullopt; + } + + // unpackage operand information + const FmqRequestDatum::OperandInformation& operandInfo = + data[index].outputOperandInformation(); + index++; + const bool hasNoValue = operandInfo.hasNoValue; + const V1_0::DataLocation location = operandInfo.location; + const uint32_t numberOfDimensions = operandInfo.numberOfDimensions; + + // unpackage operand dimensions + std::vector dimensions; + dimensions.reserve(numberOfDimensions); + for (size_t i = 0; i < numberOfDimensions; ++i) { + // validate dimension + if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) { + LOG(ERROR) << "FMQ Request packet ill-formed"; + return std::nullopt; + } + + // unpackage dimension + const uint32_t dimension = data[index].outputOperandDimensionValue(); + index++; + + // store result + dimensions.push_back(dimension); + } + + // store result + outputs.push_back( + {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions}); + } + + // unpackage pools + std::vector slots; + slots.reserve(numberOfPools); + for (size_t pool = 0; pool < numberOfPools; ++pool) { + // validate input operand information + if (data[index].getDiscriminator() != discriminator::poolIdentifier) { + LOG(ERROR) << "FMQ Request packet ill-formed"; + return std::nullopt; + } + + // unpackage operand information + const int32_t poolId = data[index].poolIdentifier(); + index++; + + // store result + slots.push_back(poolId); + } + + // validate measureTiming + if (data[index].getDiscriminator() != discriminator::measureTiming) { + LOG(ERROR) << "FMQ Request packet ill-formed"; + return std::nullopt; + } + + // unpackage measureTiming + const V1_2::MeasureTiming measure = data[index].measureTiming(); + index++; + + // validate packet information + if (index != packetSize) { + LOG(ERROR) << "FMQ Result packet ill-formed"; + return std::nullopt; + } + + // return request + V1_0::Request request = {/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}}; + return std::make_tuple(std::move(request), std::move(slots), measure); +} + +// RequestChannelReceiver methods + +std::unique_ptr RequestChannelReceiver::create( + const FmqRequestDescriptor& requestChannel, std::chrono::microseconds pollingTimeWindow) { + std::unique_ptr fmqRequestChannel = + std::make_unique(requestChannel); + + if (!fmqRequestChannel->isValid()) { + LOG(ERROR) << "Unable to create RequestChannelReceiver"; + return nullptr; + } + if (fmqRequestChannel->getEventFlagWord() == nullptr) { + LOG(ERROR) + << "RequestChannelReceiver::create was passed an MQDescriptor without an EventFlag"; + return nullptr; + } + + return std::make_unique(std::move(fmqRequestChannel), + pollingTimeWindow); +} + +RequestChannelReceiver::RequestChannelReceiver(std::unique_ptr fmqRequestChannel, + std::chrono::microseconds pollingTimeWindow) + : mFmqRequestChannel(std::move(fmqRequestChannel)), kPollingTimeWindow(pollingTimeWindow) {} + +std::optional, V1_2::MeasureTiming>> +RequestChannelReceiver::getBlocking() { + const auto packet = getPacketBlocking(); + if (!packet) { + return std::nullopt; + } + + return deserialize(*packet); +} + +void RequestChannelReceiver::invalidate() { + mTeardown = true; + + // force unblock + // ExecutionBurstServer is by default waiting on a request packet. If the + // client process destroys its burst object, the server may still be waiting + // on the futex. This force unblock wakes up any thread waiting on the + // futex. + // TODO: look for a different/better way to signal/notify the futex to wake + // up any thread waiting on it + FmqRequestDatum datum; + datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0, + /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0}); + mFmqRequestChannel->writeBlocking(&datum, 1); +} + +std::optional> RequestChannelReceiver::getPacketBlocking() { + if (mTeardown) { + return std::nullopt; + } + + // First spend time polling if results are available in FMQ instead of + // waiting on the futex. Polling is more responsive (yielding lower + // latencies), but can take up more power, so only poll for a limited period + // of time. + + auto& getCurrentTime = std::chrono::high_resolution_clock::now; + const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow; + + while (getCurrentTime() < timeToStopPolling) { + // if class is being torn down, immediately return + if (mTeardown.load(std::memory_order_relaxed)) { + return std::nullopt; + } + + // Check if data is available. If it is, immediately retrieve it and + // return. + const size_t available = mFmqRequestChannel->availableToRead(); + if (available > 0) { + // This is the first point when we know an execution is occurring, + // so begin to collect systraces. Note that a similar systrace does + // not exist at the corresponding point in + // ResultChannelReceiver::getPacketBlocking because the execution is + // already in flight. + NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, + "ExecutionBurstServer getting packet"); + std::vector packet(available); + const bool success = mFmqRequestChannel->read(packet.data(), available); + if (!success) { + LOG(ERROR) << "Error receiving packet"; + return std::nullopt; + } + return std::make_optional(std::move(packet)); + } + } + + // If we get to this point, we either stopped polling because it was taking + // too long or polling was not allowed. Instead, perform a blocking call + // which uses a futex to save power. + + // wait for request packet and read first element of request packet + FmqRequestDatum datum; + bool success = mFmqRequestChannel->readBlocking(&datum, 1); + + // This is the first point when we know an execution is occurring, so begin + // to collect systraces. Note that a similar systrace does not exist at the + // corresponding point in ResultChannelReceiver::getPacketBlocking because + // the execution is already in flight. + NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet"); + + // retrieve remaining elements + // NOTE: all of the data is already available at this point, so there's no + // need to do a blocking wait to wait for more data. This is known because + // in FMQ, all writes are published (made available) atomically. Currently, + // the producer always publishes the entire packet in one function call, so + // if the first element of the packet is available, the remaining elements + // are also available. + const size_t count = mFmqRequestChannel->availableToRead(); + std::vector packet(count + 1); + std::memcpy(&packet.front(), &datum, sizeof(datum)); + success &= mFmqRequestChannel->read(packet.data() + 1, count); + + // terminate loop + if (mTeardown) { + return std::nullopt; + } + + // ensure packet was successfully received + if (!success) { + LOG(ERROR) << "Error receiving packet"; + return std::nullopt; + } + + return std::make_optional(std::move(packet)); +} + +// ResultChannelSender methods + +std::unique_ptr ResultChannelSender::create( + const FmqResultDescriptor& resultChannel) { + std::unique_ptr fmqResultChannel = + std::make_unique(resultChannel); + + if (!fmqResultChannel->isValid()) { + LOG(ERROR) << "Unable to create RequestChannelSender"; + return nullptr; + } + if (fmqResultChannel->getEventFlagWord() == nullptr) { + LOG(ERROR) << "ResultChannelSender::create was passed an MQDescriptor without an EventFlag"; + return nullptr; + } + + return std::make_unique(std::move(fmqResultChannel)); +} + +ResultChannelSender::ResultChannelSender(std::unique_ptr fmqResultChannel) + : mFmqResultChannel(std::move(fmqResultChannel)) {} + +bool ResultChannelSender::send(V1_0::ErrorStatus errorStatus, + const std::vector& outputShapes, + V1_2::Timing timing) { + const std::vector serialized = serialize(errorStatus, outputShapes, timing); + return sendPacket(serialized); +} + +bool ResultChannelSender::sendPacket(const std::vector& packet) { + if (packet.size() > mFmqResultChannel->availableToWrite()) { + LOG(ERROR) + << "ResultChannelSender::sendPacket -- packet size exceeds size available in FMQ"; + const std::vector errorPacket = + serialize(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming); + + // Always send the packet with "blocking" because this signals the futex + // and unblocks the consumer if it is waiting on the futex. + return mFmqResultChannel->writeBlocking(errorPacket.data(), errorPacket.size()); + } + + // Always send the packet with "blocking" because this signals the futex and + // unblocks the consumer if it is waiting on the futex. + return mFmqResultChannel->writeBlocking(packet.data(), packet.size()); +} + +// ExecutionBurstServer methods + +sp ExecutionBurstServer::create( + const sp& callback, const MQDescriptorSync& requestChannel, + const MQDescriptorSync& resultChannel, + std::shared_ptr executorWithCache, + std::chrono::microseconds pollingTimeWindow) { + // check inputs + if (callback == nullptr || executorWithCache == nullptr) { + LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr"; + return nullptr; + } + + // create FMQ objects + std::unique_ptr requestChannelReceiver = + RequestChannelReceiver::create(requestChannel, pollingTimeWindow); + std::unique_ptr resultChannelSender = + ResultChannelSender::create(resultChannel); + + // check FMQ objects + if (!requestChannelReceiver || !resultChannelSender) { + LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue"; + return nullptr; + } + + // make and return context + return new ExecutionBurstServer(callback, std::move(requestChannelReceiver), + std::move(resultChannelSender), std::move(executorWithCache)); +} + +sp ExecutionBurstServer::create( + const sp& callback, const MQDescriptorSync& requestChannel, + const MQDescriptorSync& resultChannel, V1_2::IPreparedModel* preparedModel, + std::chrono::microseconds pollingTimeWindow) { + // check relevant input + if (preparedModel == nullptr) { + LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr"; + return nullptr; + } + + // adapt IPreparedModel to have caching + const std::shared_ptr preparedModelAdapter = + std::make_shared(preparedModel); + + // make and return context + return ExecutionBurstServer::create(callback, requestChannel, resultChannel, + preparedModelAdapter, pollingTimeWindow); +} + +ExecutionBurstServer::ExecutionBurstServer( + const sp& callback, std::unique_ptr requestChannel, + std::unique_ptr resultChannel, + std::shared_ptr executorWithCache) + : mCallback(callback), + mRequestChannelReceiver(std::move(requestChannel)), + mResultChannelSender(std::move(resultChannel)), + mExecutorWithCache(std::move(executorWithCache)) { + // TODO: highly document the threading behavior of this class + mWorker = std::thread([this] { task(); }); +} + +ExecutionBurstServer::~ExecutionBurstServer() { + // set teardown flag + mTeardown = true; + mRequestChannelReceiver->invalidate(); + + // wait for task thread to end + mWorker.join(); +} + +hardware::Return ExecutionBurstServer::freeMemory(int32_t slot) { + std::lock_guard hold(mMutex); + mExecutorWithCache->removeCacheEntry(slot); + return hardware::Void(); +} + +void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector& slots) { + const auto slotIsKnown = [this](int32_t slot) { + return mExecutorWithCache->isCacheEntryPresent(slot); + }; + + // find unique unknown slots + std::vector unknownSlots = slots; + auto unknownSlotsEnd = unknownSlots.end(); + std::sort(unknownSlots.begin(), unknownSlotsEnd); + unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd); + unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown); + unknownSlots.erase(unknownSlotsEnd, unknownSlots.end()); + + // quick-exit if all slots are known + if (unknownSlots.empty()) { + return; + } + + V1_0::ErrorStatus errorStatus = V1_0::ErrorStatus::GENERAL_FAILURE; + std::vector returnedMemories; + auto cb = [&errorStatus, &returnedMemories]( + V1_0::ErrorStatus status, + const hardware::hidl_vec& memories) { + errorStatus = status; + returnedMemories = memories; + }; + + const hardware::Return ret = mCallback->getMemories(unknownSlots, cb); + + if (!ret.isOk() || errorStatus != V1_0::ErrorStatus::NONE || + returnedMemories.size() != unknownSlots.size()) { + LOG(ERROR) << "Error retrieving memories"; + return; + } + + // add memories to unknown slots + for (size_t i = 0; i < unknownSlots.size(); ++i) { + mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]); + } +} + +void ExecutionBurstServer::task() { + // loop until the burst object is being destroyed + while (!mTeardown) { + // receive request + auto arguments = mRequestChannelReceiver->getBlocking(); + + // if the request packet was not properly received, return a generic + // error and skip the execution + // + // if the burst is being torn down, skip the execution so the "task" + // function can end + if (!arguments) { + if (!mTeardown) { + mResultChannelSender->send(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming); + } + continue; + } + + // otherwise begin tracing execution + NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, + "ExecutionBurstServer getting memory, executing, and returning results"); + + // unpack the arguments; types are Request, std::vector, and + // MeasureTiming, respectively + const auto [requestWithoutPools, slotsOfPools, measure] = std::move(*arguments); + + // ensure executor with cache has required memory + std::lock_guard hold(mMutex); + ensureCacheEntriesArePresentLocked(slotsOfPools); + + // perform computation; types are ErrorStatus, hidl_vec, + // and Timing, respectively + const auto [errorStatus, outputShapes, returnedTiming] = + mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure); + + // return result + mResultChannelSender->send(errorStatus, outputShapes, returnedTiming); + } +} + +} // namespace android::nn