Change NNAPI Memory to ref-counted SharedMemory -- hal

Bug: 179906132
Test: mma
Test: NeuralNetworksTest_static
Test: presubmit
Change-Id: I6435db906a2efe4938da18149a1fcd6d24730a95
Merged-In: I6435db906a2efe4938da18149a1fcd6d24730a95
(cherry picked from commit 79a16ebb6f)
This commit is contained in:
Michael Butler
2021-02-07 00:11:13 -08:00
parent 0ace84a193
commit fadeb8a920
26 changed files with 91 additions and 69 deletions

View File

@@ -41,7 +41,7 @@ class Burst final : public nn::IBurst {
Burst(PrivateConstructorTag tag, nn::SharedPreparedModel preparedModel);
OptionalCacheHold cacheMemory(const nn::Memory& memory) const override;
OptionalCacheHold cacheMemory(const nn::SharedMemory& memory) const override;
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure) const override;

View File

@@ -36,7 +36,7 @@ GeneralResult<Operand> unvalidatedConvert(const hal::V1_0::Operand& operand);
GeneralResult<Operation> unvalidatedConvert(const hal::V1_0::Operation& operation);
GeneralResult<Model::OperandValues> unvalidatedConvert(
const hardware::hidl_vec<uint8_t>& operandValues);
GeneralResult<Memory> unvalidatedConvert(const hardware::hidl_memory& memory);
GeneralResult<SharedMemory> unvalidatedConvert(const hardware::hidl_memory& memory);
GeneralResult<Model> unvalidatedConvert(const hal::V1_0::Model& model);
GeneralResult<Request::Argument> unvalidatedConvert(
const hal::V1_0::RequestArgument& requestArgument);
@@ -65,7 +65,7 @@ nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand);
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation);
nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
const nn::Model::OperandValues& operandValues);
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Memory& memory);
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory);
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model);
nn::GeneralResult<RequestArgument> unvalidatedConvert(const nn::Request::Argument& requestArgument);
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool);

View File

@@ -43,7 +43,7 @@ Burst::Burst(PrivateConstructorTag /*tag*/, nn::SharedPreparedModel preparedMode
CHECK(kPreparedModel != nullptr);
}
Burst::OptionalCacheHold Burst::cacheMemory(const nn::Memory& /*memory*/) const {
Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& /*memory*/) const {
return nullptr;
}

View File

@@ -153,7 +153,7 @@ GeneralResult<Model::OperandValues> unvalidatedConvert(const hidl_vec<uint8_t>&
return Model::OperandValues(operandValues.data(), operandValues.size());
}
GeneralResult<Memory> unvalidatedConvert(const hidl_memory& memory) {
GeneralResult<SharedMemory> unvalidatedConvert(const hidl_memory& memory) {
return createSharedMemoryFromHidlMemory(memory);
}
@@ -346,9 +346,10 @@ nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
return hidl_vec<uint8_t>(operandValues.data(), operandValues.data() + operandValues.size());
}
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Memory& memory) {
return hidl_memory(memory.name, NN_TRY(hal::utils::hidlHandleFromSharedHandle(memory.handle)),
memory.size);
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
CHECK(memory != nullptr);
return hidl_memory(memory->name, NN_TRY(hal::utils::hidlHandleFromSharedHandle(memory->handle)),
memory->size);
}
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
@@ -392,7 +393,7 @@ nn::GeneralResult<RequestArgument> unvalidatedConvert(
}
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool) {
return unvalidatedConvert(std::get<nn::Memory>(memoryPool));
return unvalidatedConvert(std::get<nn::SharedMemory>(memoryPool));
}
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {

View File

@@ -175,7 +175,7 @@ nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
return V1_0::utils::unvalidatedConvert(operandValues);
}
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Memory& memory) {
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
return V1_0::utils::unvalidatedConvert(memory);
}

View File

@@ -365,7 +365,7 @@ nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
return V1_0::utils::unvalidatedConvert(operandValues);
}
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Memory& memory) {
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
return V1_0::utils::unvalidatedConvert(memory);
}

View File

@@ -42,8 +42,8 @@ class Buffer final : public nn::IBuffer {
nn::Request::MemoryDomainToken getToken() const override;
nn::GeneralResult<void> copyTo(const nn::Memory& dst) const override;
nn::GeneralResult<void> copyFrom(const nn::Memory& src,
nn::GeneralResult<void> copyTo(const nn::SharedMemory& dst) const override;
nn::GeneralResult<void> copyFrom(const nn::SharedMemory& src,
const nn::Dimensions& dimensions) const override;
private:

View File

@@ -59,7 +59,7 @@ GeneralResult<OptionalDuration> convert(
GeneralResult<ErrorStatus> convert(const hal::V1_3::ErrorStatus& errorStatus);
GeneralResult<SharedHandle> convert(const hardware::hidl_handle& handle);
GeneralResult<Memory> convert(const hardware::hidl_memory& memory);
GeneralResult<SharedMemory> convert(const hardware::hidl_memory& memory);
GeneralResult<std::vector<BufferRole>> convert(
const hardware::hidl_vec<hal::V1_3::BufferRole>& bufferRoles);
@@ -100,7 +100,7 @@ nn::GeneralResult<OptionalTimeoutDuration> convert(
nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus);
nn::GeneralResult<hidl_handle> convert(const nn::SharedHandle& handle);
nn::GeneralResult<hidl_memory> convert(const nn::Memory& memory);
nn::GeneralResult<hidl_memory> convert(const nn::SharedMemory& memory);
nn::GeneralResult<hidl_vec<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles);
nn::GeneralResult<V1_0::DeviceStatus> convert(const nn::DeviceStatus& deviceStatus);

View File

@@ -61,7 +61,7 @@ nn::Request::MemoryDomainToken Buffer::getToken() const {
return kToken;
}
nn::GeneralResult<void> Buffer::copyTo(const nn::Memory& dst) const {
nn::GeneralResult<void> Buffer::copyTo(const nn::SharedMemory& dst) const {
const auto hidlDst = NN_TRY(convert(dst));
const auto ret = kBuffer->copyTo(hidlDst);
@@ -71,7 +71,7 @@ nn::GeneralResult<void> Buffer::copyTo(const nn::Memory& dst) const {
return {};
}
nn::GeneralResult<void> Buffer::copyFrom(const nn::Memory& src,
nn::GeneralResult<void> Buffer::copyFrom(const nn::SharedMemory& src,
const nn::Dimensions& dimensions) const {
const auto hidlSrc = NN_TRY(convert(src));
const auto hidlDimensions = hidl_vec<uint32_t>(dimensions);

View File

@@ -352,7 +352,7 @@ GeneralResult<SharedHandle> convert(const hardware::hidl_handle& handle) {
return validatedConvert(handle);
}
GeneralResult<Memory> convert(const hardware::hidl_memory& memory) {
GeneralResult<SharedMemory> convert(const hardware::hidl_memory& memory) {
return validatedConvert(memory);
}
@@ -386,7 +386,7 @@ nn::GeneralResult<hidl_handle> unvalidatedConvert(const nn::SharedHandle& handle
return V1_2::utils::unvalidatedConvert(handle);
}
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Memory& memory) {
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
return V1_0::utils::unvalidatedConvert(memory);
}
@@ -424,7 +424,7 @@ nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
return unvalidatedConvertVec(arguments);
}
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::Memory& memory) {
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::SharedMemory& memory) {
Request::MemoryPool ret;
ret.hidlMemory(NN_TRY(unvalidatedConvert(memory)));
return ret;
@@ -677,7 +677,7 @@ nn::GeneralResult<hidl_handle> convert(const nn::SharedHandle& handle) {
return validatedConvert(handle);
}
nn::GeneralResult<hidl_memory> convert(const nn::Memory& memory) {
nn::GeneralResult<hidl_memory> convert(const nn::SharedMemory& memory) {
return validatedConvert(memory);
}

View File

@@ -79,7 +79,7 @@ GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t
GeneralResult<Model::Subgraph> unvalidatedConvert(const aidl_hal::Subgraph& subgraph);
GeneralResult<OutputShape> unvalidatedConvert(const aidl_hal::OutputShape& outputShape);
GeneralResult<MeasureTiming> unvalidatedConvert(bool measureTiming);
GeneralResult<Memory> unvalidatedConvert(const aidl_hal::Memory& memory);
GeneralResult<SharedMemory> unvalidatedConvert(const aidl_hal::Memory& memory);
GeneralResult<Timing> unvalidatedConvert(const aidl_hal::Timing& timing);
GeneralResult<BufferDesc> unvalidatedConvert(const aidl_hal::BufferDesc& bufferDesc);
GeneralResult<BufferRole> unvalidatedConvert(const aidl_hal::BufferRole& bufferRole);
@@ -99,7 +99,7 @@ GeneralResult<SharedHandle> unvalidatedConvert(
GeneralResult<ExecutionPreference> convert(
const aidl_hal::ExecutionPreference& executionPreference);
GeneralResult<Memory> convert(const aidl_hal::Memory& memory);
GeneralResult<SharedMemory> convert(const aidl_hal::Memory& memory);
GeneralResult<Model> convert(const aidl_hal::Model& model);
GeneralResult<Operand> convert(const aidl_hal::Operand& operand);
GeneralResult<OperandType> convert(const aidl_hal::OperandType& operandType);
@@ -108,7 +108,7 @@ GeneralResult<Request::MemoryPool> convert(const aidl_hal::RequestMemoryPool& me
GeneralResult<Request> convert(const aidl_hal::Request& request);
GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operation>& outputShapes);
GeneralResult<std::vector<Memory>> convert(const std::vector<aidl_hal::Memory>& memories);
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories);
GeneralResult<std::vector<uint32_t>> toUnsigned(const std::vector<int32_t>& vec);
@@ -118,11 +118,11 @@ namespace aidl::android::hardware::neuralnetworks::utils {
namespace nn = ::android::nn;
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory& memory);
nn::GeneralResult<Memory> unvalidatedConvert(const nn::SharedMemory& memory);
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape);
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus);
nn::GeneralResult<Memory> convert(const nn::Memory& memory);
nn::GeneralResult<Memory> convert(const nn::SharedMemory& memory);
nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus);
nn::GeneralResult<std::vector<OutputShape>> convert(
const std::vector<nn::OutputShape>& outputShapes);

View File

@@ -53,6 +53,8 @@ constexpr auto kVersion = android::nn::Version::ANDROID_S;
namespace android::nn {
namespace {
using ::aidl::android::hardware::common::NativeHandle;
constexpr auto validOperandType(nn::OperandType operandType) {
switch (operandType) {
case nn::OperandType::FLOAT32:
@@ -316,13 +318,13 @@ GeneralResult<MeasureTiming> unvalidatedConvert(bool measureTiming) {
return measureTiming ? MeasureTiming::YES : MeasureTiming::NO;
}
GeneralResult<Memory> unvalidatedConvert(const aidl_hal::Memory& memory) {
GeneralResult<SharedMemory> unvalidatedConvert(const aidl_hal::Memory& memory) {
VERIFY_NON_NEGATIVE(memory.size) << "Memory size must not be negative";
return Memory{
return std::make_shared<const Memory>(Memory{
.handle = NN_TRY(unvalidatedConvert(memory.handle)),
.size = static_cast<uint32_t>(memory.size),
.name = memory.name,
};
});
}
GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t>& operandValues) {
@@ -397,8 +399,7 @@ GeneralResult<ExecutionPreference> unvalidatedConvert(
return static_cast<ExecutionPreference>(executionPreference);
}
GeneralResult<SharedHandle> unvalidatedConvert(
const ::aidl::android::hardware::common::NativeHandle& aidlNativeHandle) {
GeneralResult<SharedHandle> unvalidatedConvert(const NativeHandle& aidlNativeHandle) {
std::vector<base::unique_fd> fds;
fds.reserve(aidlNativeHandle.fds.size());
for (const auto& fd : aidlNativeHandle.fds) {
@@ -422,7 +423,7 @@ GeneralResult<ExecutionPreference> convert(
return validatedConvert(executionPreference);
}
GeneralResult<Memory> convert(const aidl_hal::Memory& operand) {
GeneralResult<SharedMemory> convert(const aidl_hal::Memory& operand) {
return validatedConvert(operand);
}
@@ -454,7 +455,7 @@ GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operat
return unvalidatedConvert(operations);
}
GeneralResult<std::vector<Memory>> convert(const std::vector<aidl_hal::Memory>& memories) {
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) {
return validatedConvert(memories);
}
@@ -525,14 +526,15 @@ nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::SharedHandl
return aidlNativeHandle;
}
nn::GeneralResult<Memory> unvalidatedConvert(const nn::Memory& memory) {
if (memory.size > std::numeric_limits<int64_t>::max()) {
nn::GeneralResult<Memory> unvalidatedConvert(const nn::SharedMemory& memory) {
CHECK(memory != nullptr);
if (memory->size > std::numeric_limits<int64_t>::max()) {
return NN_ERROR() << "Memory size doesn't fit into int64_t.";
}
return Memory{
.handle = NN_TRY(unvalidatedConvert(memory.handle)),
.size = static_cast<int64_t>(memory.size),
.name = memory.name,
.handle = NN_TRY(unvalidatedConvert(memory->handle)),
.size = static_cast<int64_t>(memory->size),
.name = memory->name,
};
}
@@ -558,7 +560,7 @@ nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputS
.isSufficient = outputShape.isSufficient};
}
nn::GeneralResult<Memory> convert(const nn::Memory& memory) {
nn::GeneralResult<Memory> convert(const nn::SharedMemory& memory) {
return validatedConvert(memory);
}

View File

@@ -266,7 +266,7 @@ Model createModel(const TestModel& testModel) {
copyTestBuffers(constCopies, operandValues.data());
// Shared memory.
std::vector<nn::Memory> pools = {};
std::vector<nn::SharedMemory> pools = {};
if (constRefSize > 0) {
const auto pool = nn::createSharedMemory(constRefSize).value();
pools.push_back(pool);

View File

@@ -74,7 +74,7 @@ nn::GeneralResult<void> unflushDataFromSharedToPointer(
std::vector<uint32_t> countNumberOfConsumers(size_t numberOfOperands,
const std::vector<nn::Operation>& operations);
nn::GeneralResult<nn::Memory> createSharedMemoryFromHidlMemory(const hidl_memory& memory);
nn::GeneralResult<nn::SharedMemory> createSharedMemoryFromHidlMemory(const hidl_memory& memory);
nn::GeneralResult<hidl_handle> hidlHandleFromSharedHandle(const nn::SharedHandle& handle);
nn::GeneralResult<nn::SharedHandle> sharedHandleFromNativeHandle(const native_handle_t* handle);

View File

@@ -14,6 +14,9 @@
* limitations under the License.
*/
#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_COMMON_HANDLE_ERROR_H
#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_COMMON_HANDLE_ERROR_H
#include <android/hidl/base/1.0/IBase.h>
#include <hidl/HidlSupport.h>
#include <nnapi/Result.h>
@@ -50,7 +53,8 @@ nn::GeneralResult<Type> handleTransportError(const hardware::Return<Type>& ret)
})
template <typename Type>
nn::GeneralResult<Type> makeGeneralFailure(nn::Result<Type> result, nn::ErrorStatus status) {
nn::GeneralResult<Type> makeGeneralFailure(
nn::Result<Type> result, nn::ErrorStatus status = nn::ErrorStatus::GENERAL_FAILURE) {
if (!result.has_value()) {
return nn::error(status) << std::move(result).error();
}
@@ -75,7 +79,8 @@ nn::ExecutionResult<Type> makeExecutionFailure(nn::GeneralResult<Type> result) {
}
template <typename Type>
nn::ExecutionResult<Type> makeExecutionFailure(nn::Result<Type> result, nn::ErrorStatus status) {
nn::ExecutionResult<Type> makeExecutionFailure(
nn::Result<Type> result, nn::ErrorStatus status = nn::ErrorStatus::GENERAL_FAILURE) {
return makeExecutionFailure(makeGeneralFailure(result, status));
}
@@ -86,4 +91,6 @@ nn::ExecutionResult<Type> makeExecutionFailure(nn::Result<Type> result, nn::Erro
} else \
return NN_ERROR(canonical)
} // namespace android::hardware::neuralnetworks::utils
} // namespace android::hardware::neuralnetworks::utils
#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_COMMON_HANDLE_ERROR_H

View File

@@ -31,9 +31,9 @@ class InvalidBuffer final : public nn::IBuffer {
public:
nn::Request::MemoryDomainToken getToken() const override;
nn::GeneralResult<void> copyTo(const nn::Memory& dst) const override;
nn::GeneralResult<void> copyTo(const nn::SharedMemory& dst) const override;
nn::GeneralResult<void> copyFrom(const nn::Memory& src,
nn::GeneralResult<void> copyFrom(const nn::SharedMemory& src,
const nn::Dimensions& dimensions) const override;
};

View File

@@ -29,7 +29,7 @@ namespace android::hardware::neuralnetworks::utils {
class InvalidBurst final : public nn::IBurst {
public:
OptionalCacheHold cacheMemory(const nn::Memory& memory) const override;
OptionalCacheHold cacheMemory(const nn::SharedMemory& memory) const override;
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure) const override;

View File

@@ -46,9 +46,9 @@ class ResilientBuffer final : public nn::IBuffer {
nn::Request::MemoryDomainToken getToken() const override;
nn::GeneralResult<void> copyTo(const nn::Memory& dst) const override;
nn::GeneralResult<void> copyTo(const nn::SharedMemory& dst) const override;
nn::GeneralResult<void> copyFrom(const nn::Memory& src,
nn::GeneralResult<void> copyFrom(const nn::SharedMemory& src,
const nn::Dimensions& dimensions) const override;
private:

View File

@@ -44,7 +44,7 @@ class ResilientBurst final : public nn::IBurst,
nn::SharedBurst getBurst() const;
nn::GeneralResult<nn::SharedBurst> recover(const nn::IBurst* failingBurst) const;
OptionalCacheHold cacheMemory(const nn::Memory& memory) const override;
OptionalCacheHold cacheMemory(const nn::SharedMemory& memory) const override;
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure) const override;

View File

@@ -203,13 +203,13 @@ nn::GeneralResult<std::reference_wrapper<const nn::Request>> flushDataFromPointe
nn::GeneralResult<void> unflushDataFromSharedToPointer(
const nn::Request& request, const std::optional<nn::Request>& maybeRequestInShared) {
if (!maybeRequestInShared.has_value() || maybeRequestInShared->pools.empty() ||
!std::holds_alternative<nn::Memory>(maybeRequestInShared->pools.back())) {
!std::holds_alternative<nn::SharedMemory>(maybeRequestInShared->pools.back())) {
return {};
}
const auto& requestInShared = *maybeRequestInShared;
// Map the memory.
const auto& outputMemory = std::get<nn::Memory>(requestInShared.pools.back());
const auto& outputMemory = std::get<nn::SharedMemory>(requestInShared.pools.back());
const auto [pointer, size, context] = NN_TRY(map(outputMemory));
const uint8_t* constantPointer =
std::visit([](const auto& o) { return static_cast<const uint8_t*>(o); }, pointer);

View File

@@ -30,11 +30,11 @@ nn::Request::MemoryDomainToken InvalidBuffer::getToken() const {
return nn::Request::MemoryDomainToken{};
}
nn::GeneralResult<void> InvalidBuffer::copyTo(const nn::Memory& /*dst*/) const {
nn::GeneralResult<void> InvalidBuffer::copyTo(const nn::SharedMemory& /*dst*/) const {
return NN_ERROR() << "InvalidBuffer";
}
nn::GeneralResult<void> InvalidBuffer::copyFrom(const nn::Memory& /*src*/,
nn::GeneralResult<void> InvalidBuffer::copyFrom(const nn::SharedMemory& /*src*/,
const nn::Dimensions& /*dimensions*/) const {
return NN_ERROR() << "InvalidBuffer";
}

View File

@@ -26,7 +26,8 @@
namespace android::hardware::neuralnetworks::utils {
InvalidBurst::OptionalCacheHold InvalidBurst::cacheMemory(const nn::Memory& /*memory*/) const {
InvalidBurst::OptionalCacheHold InvalidBurst::cacheMemory(
const nn::SharedMemory& /*memory*/) const {
return nullptr;
}

View File

@@ -99,12 +99,12 @@ nn::Request::MemoryDomainToken ResilientBuffer::getToken() const {
return getBuffer()->getToken();
}
nn::GeneralResult<void> ResilientBuffer::copyTo(const nn::Memory& dst) const {
nn::GeneralResult<void> ResilientBuffer::copyTo(const nn::SharedMemory& dst) const {
const auto fn = [&dst](const nn::IBuffer& buffer) { return buffer.copyTo(dst); };
return protect(*this, fn);
}
nn::GeneralResult<void> ResilientBuffer::copyFrom(const nn::Memory& src,
nn::GeneralResult<void> ResilientBuffer::copyFrom(const nn::SharedMemory& src,
const nn::Dimensions& dimensions) const {
const auto fn = [&src, &dimensions](const nn::IBuffer& buffer) {
return buffer.copyFrom(src, dimensions);

View File

@@ -94,7 +94,8 @@ nn::GeneralResult<nn::SharedBurst> ResilientBurst::recover(const nn::IBurst* fai
return mBurst;
}
ResilientBurst::OptionalCacheHold ResilientBurst::cacheMemory(const nn::Memory& memory) const {
ResilientBurst::OptionalCacheHold ResilientBurst::cacheMemory(
const nn::SharedMemory& memory) const {
return getBurst()->cacheMemory(memory);
}

View File

@@ -27,9 +27,9 @@ namespace android::nn {
class MockBuffer final : public IBuffer {
public:
MOCK_METHOD(Request::MemoryDomainToken, getToken, (), (const, override));
MOCK_METHOD(GeneralResult<void>, copyTo, (const Memory& dst), (const, override));
MOCK_METHOD(GeneralResult<void>, copyFrom, (const Memory& src, const Dimensions& dimensions),
(const, override));
MOCK_METHOD(GeneralResult<void>, copyTo, (const SharedMemory& dst), (const, override));
MOCK_METHOD(GeneralResult<void>, copyFrom,
(const SharedMemory& src, const Dimensions& dimensions), (const, override));
};
} // namespace android::nn

View File

@@ -15,9 +15,11 @@
*/
#include <gmock/gmock.h>
#include <nnapi/SharedMemory.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/hal/ResilientBuffer.h>
#include <memory>
#include <tuple>
#include <utility>
#include "MockBuffer.h"
@@ -113,7 +115,8 @@ TEST(ResilientBufferTest, copyTo) {
EXPECT_CALL(*mockBuffer, copyTo(_)).Times(1).WillOnce(Return(kNoError));
// run test
const auto result = buffer->copyTo({});
const nn::SharedMemory memory = std::make_shared<const nn::Memory>();
const auto result = buffer->copyTo(memory);
// verify result
ASSERT_TRUE(result.has_value())
@@ -126,7 +129,8 @@ TEST(ResilientBufferTest, copyToError) {
EXPECT_CALL(*mockBuffer, copyTo(_)).Times(1).WillOnce(kReturnGeneralFailure);
// run test
const auto result = buffer->copyTo({});
const nn::SharedMemory memory = std::make_shared<const nn::Memory>();
const auto result = buffer->copyTo(memory);
// verify result
ASSERT_FALSE(result.has_value());
@@ -140,7 +144,8 @@ TEST(ResilientBufferTest, copyToDeadObjectFailedRecovery) {
EXPECT_CALL(*mockBufferFactory, Call()).Times(1).WillOnce(kReturnGeneralFailure);
// run test
const auto result = buffer->copyTo({});
const nn::SharedMemory memory = std::make_shared<const nn::Memory>();
const auto result = buffer->copyTo(memory);
// verify result
ASSERT_FALSE(result.has_value());
@@ -156,7 +161,8 @@ TEST(ResilientBufferTest, copyToDeadObjectSuccessfulRecovery) {
EXPECT_CALL(*mockBufferFactory, Call()).Times(1).WillOnce(Return(recoveredMockBuffer));
// run test
const auto result = buffer->copyTo({});
const nn::SharedMemory memory = std::make_shared<const nn::Memory>();
const auto result = buffer->copyTo(memory);
// verify result
ASSERT_TRUE(result.has_value())
@@ -169,7 +175,8 @@ TEST(ResilientBufferTest, copyFrom) {
EXPECT_CALL(*mockBuffer, copyFrom(_, _)).Times(1).WillOnce(Return(kNoError));
// run test
const auto result = buffer->copyFrom({}, {});
const nn::SharedMemory memory = std::make_shared<const nn::Memory>();
const auto result = buffer->copyFrom(memory, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -182,7 +189,8 @@ TEST(ResilientBufferTest, copyFromError) {
EXPECT_CALL(*mockBuffer, copyFrom(_, _)).Times(1).WillOnce(kReturnGeneralFailure);
// run test
const auto result = buffer->copyFrom({}, {});
const nn::SharedMemory memory = std::make_shared<const nn::Memory>();
const auto result = buffer->copyFrom(memory, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -196,7 +204,8 @@ TEST(ResilientBufferTest, copyFromDeadObjectFailedRecovery) {
EXPECT_CALL(*mockBufferFactory, Call()).Times(1).WillOnce(kReturnGeneralFailure);
// run test
const auto result = buffer->copyFrom({}, {});
const nn::SharedMemory memory = std::make_shared<const nn::Memory>();
const auto result = buffer->copyFrom(memory, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -212,7 +221,8 @@ TEST(ResilientBufferTest, copyFromDeadObjectSuccessfulRecovery) {
EXPECT_CALL(*mockBufferFactory, Call()).Times(1).WillOnce(Return(recoveredMockBuffer));
// run test
const auto result = buffer->copyFrom({}, {});
const nn::SharedMemory memory = std::make_shared<const nn::Memory>();
const auto result = buffer->copyFrom(memory, {});
// verify result
ASSERT_TRUE(result.has_value())