HAL interface for compilation and execution hints

The following AIDL types are added:
 - TokenValuePair
 - PrepareModelConfig
 - ExecutionConfig

The following AIDL methods are added:
 - IDevice::prepareModelWithConfig
 - IPreparedModel::executeSynchronouslyWithConfig
 - IPreparedModel::executeFencedWithConfig
 - IBurst::executeSynchronouslyWithConfig

The compilation and execution hints are being stored as a list of
token-value pairs as part of the PrepareModelConfig / ExecutionConfig.
And the PrepareModelConfig / ExecutionConfig parcelables are created in
order to make future extensions to the execution related interfaces
easier.

It is the drivers responsibility to verify the hints, and it is allowed
for the driver to ignore them.

Bug: 203248587
Test: neuralnetworks_utils_hal_aidl_test
Change-Id: I98240fd75089fc85cdfcaa0be28aab8a6f0dfca5
This commit is contained in:
Miao Wang
2021-10-26 20:03:05 +00:00
parent a42956fbe8
commit 0e671f3edb
90 changed files with 1913 additions and 473 deletions

View File

@@ -45,12 +45,15 @@ class Burst final : public nn::IBurst {
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
private:
const nn::SharedPreparedModel kPreparedModel;

View File

@@ -65,8 +65,9 @@ class Device final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache,
const nn::CacheToken& token) const override;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -49,18 +49,23 @@ class PreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const override;
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;

View File

@@ -50,15 +50,20 @@ Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& /*memory*/)
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration);
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration, hints,
extensionNameToPrefix);
}
nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
return kPreparedModel->createReusableExecution(request, measure, loopTimeoutDuration);
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
return kPreparedModel->createReusableExecution(request, measure, loopTimeoutDuration, hints,
extensionNameToPrefix);
}
} // namespace android::hardware::neuralnetworks::V1_0::utils

View File

@@ -143,7 +143,9 @@ nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Mo
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const nn::Model& model, nn::ExecutionPreference /*preference*/, nn::Priority /*priority*/,
nn::OptionalTimePoint /*deadline*/, const std::vector<nn::SharedHandle>& /*modelCache*/,
const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/) const {
const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that model is ready for IPC.
std::optional<nn::Model> maybeModelInShared;
const nn::Model& modelInShared =

View File

@@ -59,7 +59,9 @@ PreparedModel::PreparedModel(PrivateConstructorTag /*tag*/, sp<V1_0::IPreparedMo
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> PreparedModel::execute(
const nn::Request& request, nn::MeasureTiming /*measure*/,
const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -94,19 +96,22 @@ PreparedModel::executeInternal(const V1_0::Request& request,
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
PreparedModel::executeFenced(const nn::Request& /*request*/,
const std::vector<nn::SyncFence>& /*waitFor*/,
nn::MeasureTiming /*measure*/,
const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
PreparedModel::executeFenced(
const nn::Request& /*request*/, const std::vector<nn::SyncFence>& /*waitFor*/,
nn::MeasureTiming /*measure*/, const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const nn::OptionalDuration& /*timeoutDurationAfterFence*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
<< "IPreparedModel::executeFenced is not supported on 1.0 HAL service";
}
nn::GeneralResult<nn::SharedExecution> PreparedModel::createReusableExecution(
const nn::Request& request, nn::MeasureTiming /*measure*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;

View File

@@ -380,7 +380,7 @@ TEST(DeviceTest, prepareModel) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -399,7 +399,7 @@ TEST(DeviceTest, prepareModelLaunchError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -417,7 +417,7 @@ TEST(DeviceTest, prepareModelReturnError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -435,7 +435,7 @@ TEST(DeviceTest, prepareModelNullptrError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -452,7 +452,7 @@ TEST(DeviceTest, prepareModelTransportFailure) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -469,7 +469,7 @@ TEST(DeviceTest, prepareModelDeadObject) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -488,7 +488,7 @@ TEST(DeviceTest, prepareModelAsyncCrash) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -121,7 +121,7 @@ TEST(PreparedModelTest, execute) {
.WillOnce(Invoke(makeExecute(V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::NONE)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -138,7 +138,7 @@ TEST(PreparedModelTest, executeLaunchError) {
V1_0::ErrorStatus::GENERAL_FAILURE)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -155,7 +155,7 @@ TEST(PreparedModelTest, executeReturnError) {
makeExecute(V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::GENERAL_FAILURE)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -171,7 +171,7 @@ TEST(PreparedModelTest, executeTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -187,7 +187,7 @@ TEST(PreparedModelTest, executeDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -205,7 +205,7 @@ TEST(PreparedModelTest, executeCrash) {
EXPECT_CALL(*mockPreparedModel, execute(_, _)).Times(1).WillOnce(InvokeWithoutArgs(ret));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -218,7 +218,7 @@ TEST(PreparedModelTest, executeFencedNotSupported) {
const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -235,7 +235,7 @@ TEST(PreparedModelTest, reusableExecute) {
.WillRepeatedly(Invoke(makeExecute(V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::NONE)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -258,7 +258,7 @@ TEST(PreparedModelTest, reusableExecuteLaunchError) {
V1_0::ErrorStatus::GENERAL_FAILURE)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -279,7 +279,7 @@ TEST(PreparedModelTest, reusableExecuteReturnError) {
makeExecute(V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::GENERAL_FAILURE)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -299,7 +299,7 @@ TEST(PreparedModelTest, reusableExecuteTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -319,7 +319,7 @@ TEST(PreparedModelTest, reusableExecuteDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -341,7 +341,7 @@ TEST(PreparedModelTest, reusableExecuteCrash) {
EXPECT_CALL(*mockPreparedModel, execute(_, _)).Times(1).WillOnce(InvokeWithoutArgs(ret));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -358,7 +358,7 @@ TEST(PreparedModelTest, reusableExecuteFencedNotSupported) {
const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);

View File

@@ -64,8 +64,9 @@ class Device final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache,
const nn::CacheToken& token) const override;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -143,7 +143,9 @@ nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Mo
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority /*priority*/,
nn::OptionalTimePoint /*deadline*/, const std::vector<nn::SharedHandle>& /*modelCache*/,
const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/) const {
const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that model is ready for IPC.
std::optional<nn::Model> maybeModelInShared;
const nn::Model& modelInShared =

View File

@@ -390,7 +390,7 @@ TEST(DeviceTest, prepareModel) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -409,7 +409,7 @@ TEST(DeviceTest, prepareModelLaunchError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -427,7 +427,7 @@ TEST(DeviceTest, prepareModelReturnError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -445,7 +445,7 @@ TEST(DeviceTest, prepareModelNullptrError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -462,7 +462,7 @@ TEST(DeviceTest, prepareModelTransportFailure) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -479,7 +479,7 @@ TEST(DeviceTest, prepareModelDeadObject) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -498,7 +498,7 @@ TEST(DeviceTest, prepareModelAsyncCrash) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -170,13 +170,16 @@ class Burst final : public nn::IBurst, public std::enable_shared_from_this<Burst
// See IBurst::execute for information on this method.
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
// See IBurst::createReusableExecution for information on this method.
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
// If fallback is not nullptr, this method will invoke the fallback function to try another
// execution path if the packet could not be sent. Otherwise, failing to send the packet will

View File

@@ -37,7 +37,7 @@ GeneralResult<Operand> unvalidatedConvert(const hal::V1_2::Operand& operand);
GeneralResult<Operand::ExtraParams> unvalidatedConvert(
const hal::V1_2::Operand::ExtraParams& extraParams);
GeneralResult<Model> unvalidatedConvert(const hal::V1_2::Model& model);
GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const hal::V1_2::Model::ExtensionNameAndPrefix& extensionNameAndPrefix);
GeneralResult<OutputShape> unvalidatedConvert(const hal::V1_2::OutputShape& outputShape);
GeneralResult<MeasureTiming> unvalidatedConvert(const hal::V1_2::MeasureTiming& measureTiming);
@@ -78,7 +78,7 @@ nn::GeneralResult<Operand::ExtraParams> unvalidatedConvert(
const nn::Operand::ExtraParams& extraParams);
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model);
nn::GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
const nn::Model::ExtensionNameAndPrefix& extensionNameAndPrefix);
const nn::ExtensionNameAndPrefix& extensionNameAndPrefix);
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape);
nn::GeneralResult<MeasureTiming> unvalidatedConvert(const nn::MeasureTiming& measureTiming);
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing);

View File

@@ -83,8 +83,9 @@ class Device final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache,
const nn::CacheToken& token) const override;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -49,18 +49,23 @@ class PreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const override;
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;

View File

@@ -305,8 +305,9 @@ Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& memory) cons
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// This is the first point when we know an execution is occurring, so begin to collect
// systraces. Note that the first point we can begin collecting systraces in
// ExecutionBurstServer is when the RequestChannelReceiver realizes there is data in the FMQ, so
@@ -317,7 +318,7 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
// fall back to another execution path
if (!compliantVersion(request).ok()) {
// fallback to another execution path if the packet could not be sent
return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration);
return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration, {}, {});
}
// ensure that request is ready for IPC
@@ -346,7 +347,7 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
// send request packet
const auto requestPacket = serialize(hidlRequest, hidlMeasure, slots);
const auto fallback = [this, &request, measure, &deadline, &loopTimeoutDuration] {
return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration);
return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration, {}, {});
};
return executeInternal(requestPacket, relocation, fallback);
}
@@ -354,14 +355,17 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
// See IBurst::createReusableExecution for information on this method.
nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "Burst::createReusableExecution");
// if the request is valid but of a higher version than what's supported in burst execution,
// fall back to another execution path
if (!compliantVersion(request).ok()) {
// fallback to another execution path if the packet could not be sent
return kPreparedModel->createReusableExecution(request, measure, loopTimeoutDuration);
return kPreparedModel->createReusableExecution(request, measure, loopTimeoutDuration, {},
{});
}
// ensure that request is ready for IPC

View File

@@ -212,9 +212,9 @@ GeneralResult<Model> unvalidatedConvert(const hal::V1_2::Model& model) {
};
}
GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const hal::V1_2::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
return Model::ExtensionNameAndPrefix{
return ExtensionNameAndPrefix{
.name = extensionNameAndPrefix.name,
.prefix = extensionNameAndPrefix.prefix,
};
@@ -495,7 +495,7 @@ nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
}
nn::GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
const nn::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
const nn::ExtensionNameAndPrefix& extensionNameAndPrefix) {
return Model::ExtensionNameAndPrefix{
.name = extensionNameAndPrefix.name,
.prefix = extensionNameAndPrefix.prefix,

View File

@@ -236,7 +236,9 @@ nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Mo
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority /*priority*/,
nn::OptionalTimePoint /*deadline*/, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that model is ready for IPC.
std::optional<nn::Model> maybeModelInShared;
const nn::Model& modelInShared =

View File

@@ -91,7 +91,9 @@ PreparedModel::executeAsynchronously(const V1_0::Request& request, MeasureTiming
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> PreparedModel::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -123,19 +125,22 @@ PreparedModel::executeInternal(const V1_0::Request& request, MeasureTiming measu
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
PreparedModel::executeFenced(const nn::Request& /*request*/,
const std::vector<nn::SyncFence>& /*waitFor*/,
nn::MeasureTiming /*measure*/,
const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
PreparedModel::executeFenced(
const nn::Request& /*request*/, const std::vector<nn::SyncFence>& /*waitFor*/,
nn::MeasureTiming /*measure*/, const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const nn::OptionalDuration& /*timeoutDurationAfterFence*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
<< "IPreparedModel::executeFenced is not supported on 1.2 HAL service";
}
nn::GeneralResult<nn::SharedExecution> PreparedModel::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;

View File

@@ -636,7 +636,7 @@ TEST(DeviceTest, prepareModel) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -655,7 +655,7 @@ TEST(DeviceTest, prepareModelLaunchError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -673,7 +673,7 @@ TEST(DeviceTest, prepareModelReturnError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -691,7 +691,7 @@ TEST(DeviceTest, prepareModelNullptrError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -708,7 +708,7 @@ TEST(DeviceTest, prepareModelTransportFailure) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -725,7 +725,7 @@ TEST(DeviceTest, prepareModelDeadObject) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -746,7 +746,7 @@ TEST(DeviceTest, prepareModelAsyncCrash) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -154,7 +154,7 @@ TEST(PreparedModelTest, executeSync) {
.WillOnce(Invoke(makeExecuteSynchronously(V1_0::ErrorStatus::NONE, {}, kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -172,7 +172,7 @@ TEST(PreparedModelTest, executeSyncError) {
makeExecuteSynchronously(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -189,7 +189,7 @@ TEST(PreparedModelTest, executeSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -206,7 +206,7 @@ TEST(PreparedModelTest, executeSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -224,7 +224,7 @@ TEST(PreparedModelTest, executeAsync) {
V1_0::ErrorStatus::NONE, {}, kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -243,7 +243,7 @@ TEST(PreparedModelTest, executeAsyncLaunchError) {
kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -261,7 +261,7 @@ TEST(PreparedModelTest, executeAsyncReturnError) {
V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -278,7 +278,7 @@ TEST(PreparedModelTest, executeAsyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -295,7 +295,7 @@ TEST(PreparedModelTest, executeAsyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -314,7 +314,7 @@ TEST(PreparedModelTest, executeAsyncCrash) {
EXPECT_CALL(*mockPreparedModel, execute_1_2(_, _, _)).Times(1).WillOnce(InvokeWithoutArgs(ret));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -328,7 +328,7 @@ TEST(PreparedModelTest, executeFencedNotSupported) {
PreparedModel::create(mockPreparedModel, /*executeSynchronously=*/true).value();
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -347,7 +347,7 @@ TEST(PreparedModelTest, reusableExecuteSync) {
Invoke(makeExecuteSynchronously(V1_0::ErrorStatus::NONE, {}, kNoTiming)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -371,7 +371,7 @@ TEST(PreparedModelTest, reusableExecuteSyncError) {
makeExecuteSynchronously(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -392,7 +392,7 @@ TEST(PreparedModelTest, reusableExecuteSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -413,7 +413,7 @@ TEST(PreparedModelTest, reusableExecuteSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -436,7 +436,7 @@ TEST(PreparedModelTest, reusableExecuteAsync) {
V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::NONE, {}, kNoTiming)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -461,7 +461,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncLaunchError) {
kNoTiming)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -483,7 +483,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncReturnError) {
V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -504,7 +504,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -525,7 +525,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -548,7 +548,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncCrash) {
EXPECT_CALL(*mockPreparedModel, execute_1_2(_, _, _)).Times(1).WillOnce(InvokeWithoutArgs(ret));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -566,7 +566,7 @@ TEST(PreparedModelTest, reusableExecuteFencedNotSupported) {
PreparedModel::create(mockPreparedModel, /*executeSynchronously=*/true).value();
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);

View File

@@ -66,8 +66,9 @@ class Device final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache,
const nn::CacheToken& token) const override;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -48,18 +48,23 @@ class PreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const override;
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;

View File

@@ -396,7 +396,7 @@ nn::GeneralResult<V1_2::Operand::ExtraParams> unvalidatedConvert(
}
nn::GeneralResult<V1_2::Model::ExtensionNameAndPrefix> unvalidatedConvert(
const nn::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
const nn::ExtensionNameAndPrefix& extensionNameAndPrefix) {
return V1_2::utils::unvalidatedConvert(extensionNameAndPrefix);
}

View File

@@ -187,7 +187,9 @@ nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Mo
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that model is ready for IPC.
std::optional<nn::Model> maybeModelInShared;
const nn::Model& modelInShared =

View File

@@ -135,8 +135,9 @@ PreparedModel::executeAsynchronously(const Request& request, V1_2::MeasureTiming
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> PreparedModel::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -174,10 +175,13 @@ PreparedModel::executeInternal(const Request& request, V1_2::MeasureTiming measu
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
PreparedModel::executeFenced(const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const {
PreparedModel::executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -230,7 +234,9 @@ PreparedModel::executeFencedInternal(const Request& request, const hidl_vec<hidl
nn::GeneralResult<nn::SharedExecution> PreparedModel::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;

View File

@@ -658,7 +658,7 @@ TEST(DeviceTest, prepareModel) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -677,7 +677,7 @@ TEST(DeviceTest, prepareModelLaunchError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -695,7 +695,7 @@ TEST(DeviceTest, prepareModelReturnError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -713,7 +713,7 @@ TEST(DeviceTest, prepareModelNullptrError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -730,7 +730,7 @@ TEST(DeviceTest, prepareModelTransportFailure) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -747,7 +747,7 @@ TEST(DeviceTest, prepareModelDeadObject) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -768,7 +768,7 @@ TEST(DeviceTest, prepareModelAsyncCrash) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -182,7 +182,7 @@ TEST(PreparedModelTest, executeSync) {
.WillOnce(Invoke(makeExecuteSynchronously(V1_3::ErrorStatus::NONE, {}, kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -200,7 +200,7 @@ TEST(PreparedModelTest, executeSyncError) {
makeExecuteSynchronously(V1_3::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -217,7 +217,7 @@ TEST(PreparedModelTest, executeSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -234,7 +234,7 @@ TEST(PreparedModelTest, executeSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -252,7 +252,7 @@ TEST(PreparedModelTest, executeAsync) {
V1_3::ErrorStatus::NONE, {}, kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -271,7 +271,7 @@ TEST(PreparedModelTest, executeAsyncLaunchError) {
kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -289,7 +289,7 @@ TEST(PreparedModelTest, executeAsyncReturnError) {
V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -306,7 +306,7 @@ TEST(PreparedModelTest, executeAsyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -323,7 +323,7 @@ TEST(PreparedModelTest, executeAsyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -344,7 +344,7 @@ TEST(PreparedModelTest, executeAsyncCrash) {
.WillOnce(InvokeWithoutArgs(ret));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -366,7 +366,7 @@ TEST(PreparedModelTest, executeFenced) {
.WillOnce(Invoke(makeExecuteFencedReturn(V1_3::ErrorStatus::NONE, {}, mockCallback)));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -396,7 +396,7 @@ TEST(PreparedModelTest, executeFencedCallbackError) {
.WillOnce(Invoke(makeExecuteFencedReturn(V1_3::ErrorStatus::NONE, {}, mockCallback)));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -422,7 +422,7 @@ TEST(PreparedModelTest, executeFencedError) {
makeExecuteFencedReturn(V1_3::ErrorStatus::GENERAL_FAILURE, {}, nullptr)));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -439,7 +439,7 @@ TEST(PreparedModelTest, executeFencedTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -456,7 +456,7 @@ TEST(PreparedModelTest, executeFencedDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -475,7 +475,7 @@ TEST(PreparedModelTest, reusableExecuteSync) {
Invoke(makeExecuteSynchronously(V1_3::ErrorStatus::NONE, {}, kNoTiming)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -499,7 +499,7 @@ TEST(PreparedModelTest, reusableExecuteSyncError) {
makeExecuteSynchronously(V1_3::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -520,7 +520,7 @@ TEST(PreparedModelTest, reusableExecuteSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -541,7 +541,7 @@ TEST(PreparedModelTest, reusableExecuteSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -564,7 +564,7 @@ TEST(PreparedModelTest, reusableExecuteAsync) {
V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::NONE, {}, kNoTiming)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -589,7 +589,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncLaunchError) {
kNoTiming)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -611,7 +611,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncReturnError) {
V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -628,7 +628,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -649,7 +649,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -674,7 +674,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncCrash) {
.WillOnce(InvokeWithoutArgs(ret));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -702,7 +702,7 @@ TEST(PreparedModelTest, reusableExecuteFenced) {
Invoke(makeExecuteFencedReturn(V1_3::ErrorStatus::NONE, {}, mockCallback)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -738,7 +738,7 @@ TEST(PreparedModelTest, reusableExecuteFencedCallbackError) {
.WillOnce(Invoke(makeExecuteFencedReturn(V1_3::ErrorStatus::NONE, {}, mockCallback)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -768,7 +768,7 @@ TEST(PreparedModelTest, reusableExecuteFencedError) {
makeExecuteFencedReturn(V1_3::ErrorStatus::GENERAL_FAILURE, {}, nullptr)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -789,7 +789,7 @@ TEST(PreparedModelTest, reusableExecuteFencedTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -810,7 +810,7 @@ TEST(PreparedModelTest, reusableExecuteFencedDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);

View File

@@ -0,0 +1,41 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
///////////////////////////////////////////////////////////////////////////////
// THIS FILE IS IMMUTABLE. DO NOT EDIT IN ANY CASE. //
///////////////////////////////////////////////////////////////////////////////
// This file is a snapshot of an AIDL file. Do not edit it manually. There are
// two cases:
// 1). this is a frozen version file - do not edit this in any case.
// 2). this is a 'current' file. If you make a backwards compatible change to
// the interface (from the latest frozen version), the build system will
// prompt you to update this file with `m <name>-update-api`.
//
// You must not make a backward incompatible change to any AIDL file built
// with the aidl_interface module type with versions property set. The module
// type is used to build AIDL files in a way that they can be used across
// independently updatable components of the system. If a device is shipped
// with such a backward incompatible change, it has a high risk of breaking
// later when a module using the interface is updated, e.g., Mainline modules.
package android.hardware.neuralnetworks;
@VintfStability
parcelable ExecutionConfig {
boolean measureTiming;
long loopTimeoutDurationNs;
android.hardware.neuralnetworks.TokenValuePair[] executionHints;
android.hardware.neuralnetworks.ExtensionNameAndPrefix[] extensionNameToPrefix;
}

View File

@@ -36,4 +36,5 @@ package android.hardware.neuralnetworks;
interface IBurst {
android.hardware.neuralnetworks.ExecutionResult executeSynchronously(in android.hardware.neuralnetworks.Request request, in long[] memoryIdentifierTokens, in boolean measureTiming, in long deadlineNs, in long loopTimeoutDurationNs);
void releaseMemoryResource(in long memoryIdentifierToken);
android.hardware.neuralnetworks.ExecutionResult executeSynchronouslyWithConfig(in android.hardware.neuralnetworks.Request request, in long[] memoryIdentifierTokens, in android.hardware.neuralnetworks.ExecutionConfig config, in long deadlineNs);
}

View File

@@ -43,6 +43,7 @@ interface IDevice {
String getVersionString();
void prepareModel(in android.hardware.neuralnetworks.Model model, in android.hardware.neuralnetworks.ExecutionPreference preference, in android.hardware.neuralnetworks.Priority priority, in long deadlineNs, in ParcelFileDescriptor[] modelCache, in ParcelFileDescriptor[] dataCache, in byte[] token, in android.hardware.neuralnetworks.IPreparedModelCallback callback);
void prepareModelFromCache(in long deadlineNs, in ParcelFileDescriptor[] modelCache, in ParcelFileDescriptor[] dataCache, in byte[] token, in android.hardware.neuralnetworks.IPreparedModelCallback callback);
void prepareModelWithConfig(in android.hardware.neuralnetworks.Model model, in android.hardware.neuralnetworks.PrepareModelConfig config, in android.hardware.neuralnetworks.IPreparedModelCallback callback);
const int BYTE_SIZE_OF_CACHE_TOKEN = 32;
const int MAX_NUMBER_OF_CACHE_FILES = 32;
const int EXTENSION_TYPE_HIGH_BITS_PREFIX = 15;

View File

@@ -37,7 +37,9 @@ interface IPreparedModel {
android.hardware.neuralnetworks.ExecutionResult executeSynchronously(in android.hardware.neuralnetworks.Request request, in boolean measureTiming, in long deadlineNs, in long loopTimeoutDurationNs);
android.hardware.neuralnetworks.FencedExecutionResult executeFenced(in android.hardware.neuralnetworks.Request request, in ParcelFileDescriptor[] waitFor, in boolean measureTiming, in long deadlineNs, in long loopTimeoutDurationNs, in long durationNs);
android.hardware.neuralnetworks.IBurst configureExecutionBurst();
android.hardware.neuralnetworks.IExecution createReusableExecution(in android.hardware.neuralnetworks.Request request, in boolean measureTiming, in long loopTimeoutDurationNs);
android.hardware.neuralnetworks.IExecution createReusableExecution(in android.hardware.neuralnetworks.Request request, in android.hardware.neuralnetworks.ExecutionConfig config);
android.hardware.neuralnetworks.ExecutionResult executeSynchronouslyWithConfig(in android.hardware.neuralnetworks.Request request, in android.hardware.neuralnetworks.ExecutionConfig config, in long deadlineNs);
android.hardware.neuralnetworks.FencedExecutionResult executeFencedWithConfig(in android.hardware.neuralnetworks.Request request, in ParcelFileDescriptor[] waitFor, in android.hardware.neuralnetworks.ExecutionConfig config, in long deadlineNs, in long durationNs);
const long DEFAULT_LOOP_TIMEOUT_DURATION_NS = 2000000000;
const long MAXIMUM_LOOP_TIMEOUT_DURATION_NS = 15000000000;
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
///////////////////////////////////////////////////////////////////////////////
// THIS FILE IS IMMUTABLE. DO NOT EDIT IN ANY CASE. //
///////////////////////////////////////////////////////////////////////////////
// This file is a snapshot of an AIDL file. Do not edit it manually. There are
// two cases:
// 1). this is a frozen version file - do not edit this in any case.
// 2). this is a 'current' file. If you make a backwards compatible change to
// the interface (from the latest frozen version), the build system will
// prompt you to update this file with `m <name>-update-api`.
//
// You must not make a backward incompatible change to any AIDL file built
// with the aidl_interface module type with versions property set. The module
// type is used to build AIDL files in a way that they can be used across
// independently updatable components of the system. If a device is shipped
// with such a backward incompatible change, it has a high risk of breaking
// later when a module using the interface is updated, e.g., Mainline modules.
package android.hardware.neuralnetworks;
@VintfStability
parcelable PrepareModelConfig {
android.hardware.neuralnetworks.ExecutionPreference preference;
android.hardware.neuralnetworks.Priority priority;
long deadlineNs;
ParcelFileDescriptor[] modelCache;
ParcelFileDescriptor[] dataCache;
byte[] cacheToken;
android.hardware.neuralnetworks.TokenValuePair[] compilationHints;
android.hardware.neuralnetworks.ExtensionNameAndPrefix[] extensionNameToPrefix;
}

View File

@@ -0,0 +1,39 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
///////////////////////////////////////////////////////////////////////////////
// THIS FILE IS IMMUTABLE. DO NOT EDIT IN ANY CASE. //
///////////////////////////////////////////////////////////////////////////////
// This file is a snapshot of an AIDL file. Do not edit it manually. There are
// two cases:
// 1). this is a frozen version file - do not edit this in any case.
// 2). this is a 'current' file. If you make a backwards compatible change to
// the interface (from the latest frozen version), the build system will
// prompt you to update this file with `m <name>-update-api`.
//
// You must not make a backward incompatible change to any AIDL file built
// with the aidl_interface module type with versions property set. The module
// type is used to build AIDL files in a way that they can be used across
// independently updatable components of the system. If a device is shipped
// with such a backward incompatible change, it has a high risk of breaking
// later when a module using the interface is updated, e.g., Mainline modules.
package android.hardware.neuralnetworks;
@VintfStability
parcelable TokenValuePair {
int token;
byte[] value;
}

View File

@@ -0,0 +1,60 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.hardware.neuralnetworks;
import android.hardware.neuralnetworks.ExtensionNameAndPrefix;
import android.hardware.neuralnetworks.TokenValuePair;
/**
* A type that is used to represent all configuration related to
* an Execution.
*/
@VintfStability
parcelable ExecutionConfig {
/**
* Specifies whether or not to measure duration of the execution.
* For {@link IPreparedModel::executeSynchronouslyWithConfig}, the duration runs from the time
* the driver sees the corresponding call to the execute function to the time the driver returns
* from the function. For {@link IPreparedModel::executeFencedWithConfig}, please refer to
* {@link IPreparedModelCallback} for details.
*/
boolean measureTiming;
/**
* The maximum amount of time in nanoseconds that should be spent
* executing a {@link OperationType::WHILE} operation. If a loop
* condition model does not output false within this duration,
* the execution must be aborted. If -1 is provided, the maximum
* amount of time is {@link DEFAULT_LOOP_TIMEOUT_DURATION_NS}.
* Other negative values are invalid. When provided, the duration
* must not exceed {@link MAXIMUM_LOOP_TIMEOUT_DURATION_NS}.
*/
long loopTimeoutDurationNs;
/**
* A vector of token / value pairs represent vendor specific
* execution hints or metadata. The provided TokenValuePairs must not
* contain the same token twice. The driver must validate the
* data and ignore invalid hints. It is up to the driver to
* decide whether to respect the provided hints or not.
*/
TokenValuePair[] executionHints;
/**
* The mapping between extension names and prefixes of token values.
* The driver must ignore the corresponding execution hint, if
* the extension is not supported.
*/
ExtensionNameAndPrefix[] extensionNameToPrefix;
}

View File

@@ -20,6 +20,10 @@ import android.hardware.neuralnetworks.ExtensionOperandTypeInformation;
/**
* Information about an extension.
*
* The extension can provide zero or more operation types (which are not enumerated), zero or more
* operand types (which are enumerated in {@link Extension::operandTypes}, and compilation and
* execution hints (which are not enumerated).
*/
@VintfStability
parcelable Extension {

View File

@@ -17,7 +17,8 @@
package android.hardware.neuralnetworks;
/**
* The mapping between extension names and prefixes of operand and operation type values.
* The mapping between extension names and prefixes of values like operand and operation type, and
* token in {@link TokenValuePair}.
*
* An operand or operation whose numeric type value is above {@link IDevice::OPERAND_TYPE_BASE_MAX}
* or {@link IDevice::OPERATION_TYPE_BASE_MAX} respectively should be interpreted as an extension

View File

@@ -17,6 +17,7 @@
package android.hardware.neuralnetworks;
import android.hardware.neuralnetworks.ErrorStatus;
import android.hardware.neuralnetworks.ExecutionConfig;
import android.hardware.neuralnetworks.ExecutionResult;
import android.hardware.neuralnetworks.Request;
@@ -68,6 +69,8 @@ interface IBurst {
*
* Only a single execution on a given burst object may be active at any time.
*
* Also see {@link IBurst::executeSynchronouslyWithConfig}.
*
* @param request The input and output information on which the prepared model is to be
* executed.
* @param memoryIdentifierTokens A list of tokens where each token is a non-negative number
@@ -117,4 +120,13 @@ interface IBurst {
* - INVALID_ARGUMENT if one of the input arguments is invalid
*/
void releaseMemoryResource(in long memoryIdentifierToken);
/**
* For detailed specification, please refer to {@link IBurst::executeSynchronously}. The
* difference between the two methods is that executeSynchronouslyWithConfig takes {@link
* ExecutionConfig} instead of a list of configuration parameters, and ExecutionConfig contains
* more configuration parameters than are passed to executeSynchronously.
*/
ExecutionResult executeSynchronouslyWithConfig(in Request request,
in long[] memoryIdentifierTokens, in ExecutionConfig config, in long deadlineNs);
}

View File

@@ -28,6 +28,7 @@ import android.hardware.neuralnetworks.IPreparedModelCallback;
import android.hardware.neuralnetworks.IPreparedModelParcel;
import android.hardware.neuralnetworks.Model;
import android.hardware.neuralnetworks.NumberOfCacheFiles;
import android.hardware.neuralnetworks.PrepareModelConfig;
import android.hardware.neuralnetworks.Priority;
/**
@@ -148,7 +149,7 @@ interface IDevice {
*
* If the device reports that caching is not supported, the user may avoid calling
* IDevice::prepareModelFromCache or providing cache file descriptors to
* IDevice::prepareModel.
* IDevice::prepareModel or IDevice::prepareModelWithConfig.
*
* @return NumberOfCacheFiles structure indicating how many files for model and data cache the
* driver needs to cache a single prepared model. It must be less than or equal to
@@ -302,6 +303,8 @@ interface IDevice {
*
* Multiple threads may call prepareModel on the same model concurrently.
*
* Also see {@link IDevice::prepareModelWithConfig}.
*
* @param model The model to be prepared for execution.
* @param preference Indicates the intended execution behavior of a prepared model.
* @param priority The priority of the prepared model relative to other prepared models owned by
@@ -403,17 +406,17 @@ interface IDevice {
* @param modelCache A vector of file descriptors for the security-sensitive cache. The length
* of the vector must match the numModelCache returned from
* getNumberOfCacheFilesNeeded. The cache file descriptors will be provided in
* the same order as with prepareModel.
* the same order as with prepareModel or prepareModelWithConfig.
* @param dataCache A vector of file descriptors for the constants' cache. The length of the
* vector must match the numDataCache returned from
* getNumberOfCacheFilesNeeded. The cache file descriptors will be provided in
* the same order as with prepareModel.
* the same order as with prepareModel or prepareModelWithConfig.
* @param token A caching token of length BYTE_SIZE_OF_CACHE_TOKEN identifying the prepared
* model. It is the same token provided when saving the cache files with
* prepareModel. Tokens should be chosen to have a low rate of collision for a
* particular application. The driver cannot detect a collision; a collision will
* result in a failed execution or in a successful execution that produces
* incorrect output values.
* prepareModel or prepareModelWithConfig. Tokens should be chosen to have a low
* rate of collision for a particular application. The driver cannot detect a
* collision; a collision will result in a failed execution or in a successful
* execution that produces incorrect output values.
* @param callback A callback object used to return the error status of preparing the model for
* execution and the prepared model if successful, nullptr otherwise. The
* callback object's notify function must be called exactly once, even if the
@@ -429,4 +432,28 @@ interface IDevice {
void prepareModelFromCache(in long deadlineNs, in ParcelFileDescriptor[] modelCache,
in ParcelFileDescriptor[] dataCache, in byte[] token,
in IPreparedModelCallback callback);
/**
* For detailed specification, please refer to {@link IDevice::prepareModel}. The only
* difference between the two methods is that prepareModelWithConfig takes {@link
* PrepareModelConfig} instead of standalone configuration parameters, which allows vendor
* specific compilation metadata to be passed.
*
* @param model The model to be prepared for execution.
* @param config Configuration parameters to prepare the model.
* @param callback A callback object used to return the error status of preparing the model for
* execution and the prepared model if successful, nullptr otherwise. The
* callback object's notify function must be called exactly once, even if the
* model could not be prepared.
* @throws ServiceSpecificException with one of the following ErrorStatus values:
* - DEVICE_UNAVAILABLE if driver is offline or busy
* - GENERAL_FAILURE if there is an unspecified error
* - INVALID_ARGUMENT if one of the input arguments related to preparing the model is
* invalid
* - MISSED_DEADLINE_* if the preparation is aborted because the model cannot be prepared by
* the deadline
* - RESOURCE_EXHAUSTED_* if the task was aborted by the driver
*/
void prepareModelWithConfig(
in Model model, in PrepareModelConfig config, in IPreparedModelCallback callback);
}

View File

@@ -18,6 +18,7 @@ package android.hardware.neuralnetworks;
import android.hardware.common.NativeHandle;
import android.hardware.neuralnetworks.ErrorStatus;
import android.hardware.neuralnetworks.ExecutionConfig;
import android.hardware.neuralnetworks.ExecutionResult;
import android.hardware.neuralnetworks.FencedExecutionResult;
import android.hardware.neuralnetworks.IBurst;
@@ -68,6 +69,8 @@ interface IPreparedModel {
* Any number of calls to the execute* functions, in any combination, may be made concurrently,
* even on the same IPreparedModel object.
*
* Also see {@link IPreparedModel::executeSynchronouslyWithConfig}.
*
* @param request The input and output information on which the prepared model is to be
* executed.
* @param measure Specifies whether or not to measure duration of the execution. The duration
@@ -134,6 +137,8 @@ interface IPreparedModel {
* Any number of calls to the execute* functions, in any combination, may be made concurrently,
* even on the same IPreparedModel object.
*
* Also see {@link IPreparedModel::executeFencedWithConfig}.
*
* @param request The input and output information on which the prepared model is to be
* executed. The outputs in the request must have fully specified dimensions.
* @param waitFor A vector of sync fence file descriptors. Execution must not start until all
@@ -201,15 +206,7 @@ interface IPreparedModel {
*
* @param request The input and output information on which the prepared model is to be
* executed.
* @param measure Specifies whether or not to measure duration of the execution.
* @param loopTimeoutDurationNs The maximum amount of time in nanoseconds that should be spent
* executing a {@link OperationType::WHILE} operation. If a loop
* condition model does not output false within this duration, the
* computation performed on the returned reusable execution object
* must be aborted. If -1 is provided, the maximum amount
* of time is {@link DEFAULT_LOOP_TIMEOUT_DURATION_NS}. Other
* negative values are invalid. When provided, the duration must
* not exceed {@link MAXIMUM_LOOP_TIMEOUT_DURATION_NS}.
* @param config Specifies the execution configuration parameters.
* @return execution An IExecution object representing a reusable execution that has been
* specialized for a fixed request.
* @throws ServiceSpecificException with one of the following ErrorStatus values:
@@ -218,6 +215,64 @@ interface IPreparedModel {
* - INVALID_ARGUMENT if one of the input arguments is invalid
* - RESOURCE_EXHAUSTED_* if the task was aborted by the driver
*/
IExecution createReusableExecution(
in Request request, in boolean measureTiming, in long loopTimeoutDurationNs);
IExecution createReusableExecution(in Request request, in ExecutionConfig config);
/**
* For detailed specification, please refer to {@link IPreparedModel::executeSynchronously}. The
* difference between the two methods is that executeSynchronouslyWithConfig takes {@link
* ExecutionConfig} instead of a list of configuration parameters, and ExecutionConfig contains
* more configuration parameters than are passed to executeSynchronously.
*
* @param request The input and output information on which the prepared model is to be
* executed.
* @param config Specifies the execution configuration parameters.
* @param deadlineNs The time by which the execution is expected to complete. The time is
* measured in nanoseconds since boot (as from clock_gettime(CLOCK_BOOTTIME,
* &ts) or ::android::base::boot_clock). If the execution cannot be finished
* by the deadline, the execution may be aborted. Passing -1 means the
* deadline is omitted. Other negative valueggs are invalid.
* @return ExecutionResult parcelable, containing the status of the execution, output shapes and
* timing information.
* - MISSED_DEADLINE_* if the execution is aborted because it cannot be completed by the
* deadline
* - RESOURCE_EXHAUSTED_* if the task was aborted by the driver
*/
ExecutionResult executeSynchronouslyWithConfig(
in Request request, in ExecutionConfig config, in long deadlineNs);
/**
* For detailed specification, please refer to {@link IPreparedModel::executeFenced}. The
* difference between the two methods is that executeFencedWithConfig takes {@link
* ExecutionConfig} instead of a list of configuration parameters, and ExecutionConfig contains
* more configuration parameters than are passed to executeFenced.
*
* @param request The input and output information on which the prepared model is to be
* executed. The outputs in the request must have fully specified dimensions.
* @param waitFor A vector of sync fence file descriptors. Execution must not start until all
* sync fences have been signaled.
* @param config Specifies the execution configuration parameters.
* @param deadlineNs The time by which the execution is expected to complete. The time is
* measured in nanoseconds since boot (as from clock_gettime(CLOCK_BOOTTIME,
* &ts) or ::android::base::boot_clock). If the execution cannot be finished
* by the deadline, the execution may be aborted. Passing -1 means the
* deadline is omitted. Other negative values are invalid.
* @param durationNs The length of time in nanoseconds within which the execution is expected to
* complete after all sync fences in waitFor are signaled. If the execution
* cannot be finished within the duration, the execution may be aborted.
* Passing -1 means the duration is omitted. Other negative values are
* invalid.
* @return The FencedExecutionResult parcelable, containing IFencedExecutionCallback and the
* sync fence.
* @throws ServiceSpecificException with one of the following ErrorStatus values:
* - DEVICE_UNAVAILABLE if driver is offline or busy
* - GENERAL_FAILURE if there is an unspecified error
* - INVALID_ARGUMENT if one of the input arguments is invalid, including fences in error
* states.
* - MISSED_DEADLINE_* if the execution is aborted because it cannot be completed by the
* deadline
* - RESOURCE_EXHAUSTED_* if the task was aborted by the driver
*/
FencedExecutionResult executeFencedWithConfig(in Request request,
in ParcelFileDescriptor[] waitFor, in ExecutionConfig config, in long deadlineNs,
in long durationNs);
}

View File

@@ -0,0 +1,95 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.hardware.neuralnetworks;
import android.hardware.neuralnetworks.ExecutionPreference;
import android.hardware.neuralnetworks.ExtensionNameAndPrefix;
import android.hardware.neuralnetworks.Priority;
import android.hardware.neuralnetworks.TokenValuePair;
/**
* A type that is used to represent all configuration needed to
* prepare a model.
*/
@VintfStability
parcelable PrepareModelConfig {
/**
* Indicates the intended execution behavior of a prepared model.
*/
ExecutionPreference preference;
/**
* The priority of the prepared model relative to other prepared
* models owned by the client.
*/
Priority priority;
/**
* The time by which the model is expected to be prepared. The
* time is measured in nanoseconds since boot (as from
* clock_gettime(CLOCK_BOOTTIME, &ts) or
* ::android::base::boot_clock). If the model cannot be prepared
* by the deadline, the preparation may be aborted. Passing -1
* means the deadline is omitted. Other negative values are
* invalid.
*/
long deadlineNs;
/**
* A vector of file descriptors for the security-sensitive cache.
* The length of the vector must either be 0 indicating that
* caching information is not provided, or match the
* numModelCache returned from IDevice::getNumberOfCacheFilesNeeded. The
* cache file descriptors will be provided in the same order when
* retrieving the preparedModel from cache files with
* IDevice::prepareModelFromCache.
*/
ParcelFileDescriptor[] modelCache;
/**
* A vector of file descriptors for the constants' cache. The
* length of the vector must either be 0 indicating that caching
* information is not provided, or match the numDataCache
* returned from IDevice::getNumberOfCacheFilesNeeded. The cache file
* descriptors will be provided in the same order when retrieving
* the preparedModel from cache files with IDevice::prepareModelFromCache.
*/
ParcelFileDescriptor[] dataCache;
/**
* A caching token of length IDevice::BYTE_SIZE_OF_CACHE_TOKEN identifying
* the prepared model. The same token will be provided when
* retrieving the prepared model from the cache files with
* IDevice::prepareModelFromCache. Tokens should be chosen to have a low
* rate of collision for a particular application. The driver
* cannot detect a collision; a collision will result in a failed
* execution or in a successful execution that produces incorrect
* output values. If both modelCache and dataCache are empty
* indicating that caching information is not provided, this
* token must be ignored.
*/
byte[] cacheToken;
/**
* A vector of token / value pairs represent vendor specific
* compilation hints or metadata. The provided TokenValuePairs must not
* contain the same token twice. The driver must validate the
* data and ignore invalid hints. It is up to the driver to
* decide whether to respect the provided hints or not.
*/
TokenValuePair[] compilationHints;
/**
* The mapping between extension names and prefixes of token values.
* The driver must ignore the corresponding compilation hint, if
* the extension is not supported.
*/
ExtensionNameAndPrefix[] extensionNameToPrefix;
}

View File

@@ -0,0 +1,42 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.hardware.neuralnetworks;
/**
* A type that is used to represent a token / byte array data pair.
*/
@VintfStability
parcelable TokenValuePair {
/**
* A 32bit integer token. The token is created by combining the
* extension prefix and enum defined within the extension.
* The low {@link IDevice::EXTENSION_TYPE_LOW_BITS_TYPE} bits of the value
* correspond to the hint within the extension and the high
* {@link IDevice::EXTENSION_TYPE_HIGH_BITS_PREFIX} bits encode the "prefix", which maps
* uniquely to the extension name. The sign bit is always 0.
*
* For example, if a token value is 0x7AAA000B and the corresponding
* {@link ExtensionNameAndPrefix} contains an entry with prefix=0x7AAA and
* name="vendor.test.test_extension", then the token should be interpreted as the hint
* 0x000B of the extension named vendor.test.test_extension.
*/
int token;
/**
* A byte array containing the raw data.
*/
byte[] value;
}

View File

@@ -86,10 +86,12 @@ class Burst final : public nn::IBurst, public std::enable_shared_from_this<Burst
GUARDED_BY(mMutex);
};
// featureLevel is for testing purposes.
static nn::GeneralResult<std::shared_ptr<const Burst>> create(
std::shared_ptr<aidl_hal::IBurst> burst);
std::shared_ptr<aidl_hal::IBurst> burst, nn::Version featureLevel);
Burst(PrivateConstructorTag tag, std::shared_ptr<aidl_hal::IBurst> burst);
Burst(PrivateConstructorTag tag, std::shared_ptr<aidl_hal::IBurst> burst,
nn::Version featureLevel);
// See IBurst::cacheMemory for information.
OptionalCacheHold cacheMemory(const nn::SharedMemory& memory) const override;
@@ -97,23 +99,29 @@ class Burst final : public nn::IBurst, public std::enable_shared_from_this<Burst
// See IBurst::execute for information.
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
// See IBurst::createReusableExecution for information.
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> executeInternal(
const aidl_hal::Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
bool measure, int64_t deadline, int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const;
private:
mutable std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
const std::shared_ptr<aidl_hal::IBurst> kBurst;
const std::shared_ptr<MemoryCache> kMemoryCache;
const nn::Version kFeatureLevel;
};
} // namespace aidl::android::hardware::neuralnetworks::utils

View File

@@ -46,6 +46,10 @@
#include <aidl/android/hardware/neuralnetworks/SymmPerChannelQuantParams.h>
#include <aidl/android/hardware/neuralnetworks/Timing.h>
#ifdef NN_AIDL_V4_OR_ABOVE
#include <aidl/android/hardware/neuralnetworks/TokenValuePair.h>
#endif // NN_AIDL_V4_OR_ABOVE
#include <android/binder_auto_utils.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
@@ -74,7 +78,7 @@ GeneralResult<Operand::SymmPerChannelQuantParams> unvalidatedConvert(
const aidl_hal::SymmPerChannelQuantParams& symmPerChannelQuantParams);
GeneralResult<Operation> unvalidatedConvert(const aidl_hal::Operation& operation);
GeneralResult<Model> unvalidatedConvert(const aidl_hal::Model& model);
GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const aidl_hal::ExtensionNameAndPrefix& extensionNameAndPrefix);
GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t>& operandValues);
GeneralResult<Model::Subgraph> unvalidatedConvert(const aidl_hal::Subgraph& subgraph);
@@ -97,6 +101,10 @@ GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
const aidl_hal::ExtensionOperandTypeInformation& operandTypeInformation);
GeneralResult<SharedHandle> unvalidatedConvert(const ndk::ScopedFileDescriptor& handle);
#ifdef NN_AIDL_V4_OR_ABOVE
GeneralResult<TokenValuePair> unvalidatedConvert(const aidl_hal::TokenValuePair& tokenValuePair);
#endif // NN_AIDL_V4_OR_ABOVE
GeneralResult<std::vector<Operation>> unvalidatedConvert(
const std::vector<aidl_hal::Operation>& operations);
@@ -116,6 +124,14 @@ GeneralResult<BufferDesc> convert(const aidl_hal::BufferDesc& bufferDesc);
GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension);
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories);
GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
const std::vector<aidl_hal::ExtensionNameAndPrefix>& extensionNameAndPrefix);
#ifdef NN_AIDL_V4_OR_ABOVE
GeneralResult<std::vector<TokenValuePair>> convert(
const std::vector<aidl_hal::TokenValuePair>& metaData);
#endif // NN_AIDL_V4_OR_ABOVE
GeneralResult<std::vector<OutputShape>> convert(
const std::vector<aidl_hal::OutputShape>& outputShapes);
GeneralResult<std::vector<SharedHandle>> convert(
@@ -152,7 +168,7 @@ nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgra
nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
const nn::Model::OperandValues& operandValues);
nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix);
const nn::ExtensionNameAndPrefix& extensionNameToPrefix);
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model);
nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority);
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request);
@@ -166,6 +182,10 @@ nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvert(const nn::Shared
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities);
nn::GeneralResult<Extension> unvalidatedConvert(const nn::Extension& extension);
#ifdef NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<TokenValuePair> unvalidatedConvert(const nn::TokenValuePair& tokenValuePair);
#endif // NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken);
nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc);
nn::GeneralResult<DeviceType> convert(const nn::DeviceType& deviceType);
@@ -190,6 +210,13 @@ nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
const std::vector<nn::SyncFence>& syncFences);
nn::GeneralResult<std::vector<Extension>> convert(const std::vector<nn::Extension>& extensions);
nn::GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix);
#ifdef NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<TokenValuePair>> convert(
const std::vector<nn::TokenValuePair>& metaData);
#endif // NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec);

View File

@@ -42,6 +42,7 @@ class Device final : public nn::IDevice {
struct PrivateConstructorTag {};
public:
// featureLevel is for testing purposes.
static nn::GeneralResult<std::shared_ptr<const Device>> create(
std::string name, std::shared_ptr<aidl_hal::IDevice> device, nn::Version featureLevel);
@@ -67,8 +68,9 @@ class Device final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache,
const nn::CacheToken& token) const override;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -63,7 +63,9 @@
#ifdef NN_AIDL_V4_OR_ABOVE
#include <aidl/android/hardware/neuralnetworks/BnExecution.h>
#include <aidl/android/hardware/neuralnetworks/ExecutionConfig.h>
#include <aidl/android/hardware/neuralnetworks/IExecution.h>
#include <aidl/android/hardware/neuralnetworks/PrepareModelConfig.h>
#endif // NN_AIDL_V4_OR_ABOVE
namespace android::nn {

View File

@@ -53,6 +53,9 @@ class InvalidDevice : public BnDevice {
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback) override;
ndk::ScopedAStatus prepareModelWithConfig(
const Model& model, const PrepareModelConfig& config,
const std::shared_ptr<IPreparedModelCallback>& callback) override;
ndk::ScopedAStatus prepareModelFromCache(
int64_t deadline, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
const std::vector<ndk::ScopedFileDescriptor>& dataCache,

View File

@@ -40,6 +40,7 @@ class PreparedModel final : public nn::IPreparedModel,
struct PrivateConstructorTag {};
public:
// featureLevel is for testing purposes.
static nn::GeneralResult<std::shared_ptr<const PreparedModel>> create(
std::shared_ptr<aidl_hal::IPreparedModel> preparedModel, nn::Version featureLevel);
@@ -49,18 +50,23 @@ class PreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const override;
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;
@@ -68,6 +74,8 @@ class PreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> executeInternal(
const Request& request, bool measure, int64_t deadline, int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
@@ -75,6 +83,8 @@ class PreparedModel final : public nn::IPreparedModel,
const std::vector<ndk::ScopedFileDescriptor>& waitFor, bool measure,
int64_t deadline, int64_t loopTimeoutDuration,
int64_t timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const;
private:

View File

@@ -43,12 +43,16 @@ class BurstExecution final : public nn::IExecution,
static nn::GeneralResult<std::shared_ptr<const BurstExecution>> create(
std::shared_ptr<const Burst> burst, Request request,
std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds);
BurstExecution(PrivateConstructorTag tag, std::shared_ptr<const Burst> burst, Request request,
std::vector<int64_t> memoryIdentifierTokens, bool measure,
int64_t loopTimeoutDuration, hal::utils::RequestRelocation relocation,
int64_t loopTimeoutDuration, const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds);
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> compute(
@@ -64,6 +68,8 @@ class BurstExecution final : public nn::IExecution,
const std::vector<int64_t> kMemoryIdentifierTokens;
const bool kMeasure;
const int64_t kLoopTimeoutDuration;
const std::vector<nn::TokenValuePair> kHints;
const std::vector<nn::ExtensionNameAndPrefix> kExtensionNameToPrefix;
const hal::utils::RequestRelocation kRelocation;
const std::vector<Burst::OptionalCacheHold> kCacheHolds;
};
@@ -149,17 +155,20 @@ void Burst::MemoryCache::tryFreeMemory(const nn::SharedMemory& memory, int64_t i
}
nn::GeneralResult<std::shared_ptr<const Burst>> Burst::create(
std::shared_ptr<aidl_hal::IBurst> burst) {
std::shared_ptr<aidl_hal::IBurst> burst, nn::Version featureLevel) {
if (burst == nullptr) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
<< "aidl_hal::utils::Burst::create must have non-null burst";
}
return std::make_shared<const Burst>(PrivateConstructorTag{}, std::move(burst));
return std::make_shared<const Burst>(PrivateConstructorTag{}, std::move(burst), featureLevel);
}
Burst::Burst(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBurst> burst)
: kBurst(std::move(burst)), kMemoryCache(std::make_shared<MemoryCache>(kBurst)) {
Burst::Burst(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBurst> burst,
nn::Version featureLevel)
: kBurst(std::move(burst)),
kMemoryCache(std::make_shared<MemoryCache>(kBurst)),
kFeatureLevel(featureLevel) {
CHECK(kBurst != nullptr);
}
@@ -170,8 +179,9 @@ Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& memory) cons
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -200,14 +210,14 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
memoryIdentifierTokens.push_back(-1);
}
CHECK_EQ(requestInShared.pools.size(), memoryIdentifierTokens.size());
return executeInternal(aidlRequest, memoryIdentifierTokens, aidlMeasure, aidlDeadline,
aidlLoopTimeoutDuration, relocation);
aidlLoopTimeoutDuration, hints, extensionNameToPrefix, relocation);
}
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::executeInternal(
const Request& request, const std::vector<int64_t>& memoryIdentifierTokens, bool measure,
int64_t deadline, int64_t loopTimeoutDuration,
int64_t deadline, int64_t loopTimeoutDuration, const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const {
// Ensure that at most one execution is in flight at any given time.
const bool alreadyInFlight = mExecutionInFlight.test_and_set();
@@ -221,9 +231,21 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
}
ExecutionResult executionResult;
const auto ret = kBurst->executeSynchronously(request, memoryIdentifierTokens, measure,
deadline, loopTimeoutDuration, &executionResult);
HANDLE_ASTATUS(ret) << "execute failed";
if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
auto aidlHints = NN_TRY(convert(hints));
auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
const auto ret = kBurst->executeSynchronouslyWithConfig(
request, memoryIdentifierTokens,
{measure, loopTimeoutDuration, std::move(aidlHints),
std::move(aidlExtensionPrefix)},
deadline, &executionResult);
HANDLE_ASTATUS(ret) << "execute failed";
} else {
const auto ret =
kBurst->executeSynchronously(request, memoryIdentifierTokens, measure, deadline,
loopTimeoutDuration, &executionResult);
HANDLE_ASTATUS(ret) << "execute failed";
}
if (!executionResult.outputSufficientSize) {
auto canonicalOutputShapes =
nn::convert(executionResult.outputShapes).value_or(std::vector<nn::OutputShape>{});
@@ -241,7 +263,9 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -272,12 +296,15 @@ nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
return BurstExecution::create(shared_from_this(), std::move(aidlRequest),
std::move(memoryIdentifierTokens), aidlMeasure,
aidlLoopTimeoutDuration, std::move(relocation), std::move(holds));
aidlLoopTimeoutDuration, hints, extensionNameToPrefix,
std::move(relocation), std::move(holds));
}
nn::GeneralResult<std::shared_ptr<const BurstExecution>> BurstExecution::create(
std::shared_ptr<const Burst> burst, Request request,
std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds) {
if (burst == nullptr) {
@@ -286,13 +313,15 @@ nn::GeneralResult<std::shared_ptr<const BurstExecution>> BurstExecution::create(
return std::make_shared<const BurstExecution>(
PrivateConstructorTag{}, std::move(burst), std::move(request),
std::move(memoryIdentifierTokens), measure, loopTimeoutDuration, std::move(relocation),
std::move(cacheHolds));
std::move(memoryIdentifierTokens), measure, loopTimeoutDuration, hints,
extensionNameToPrefix, std::move(relocation), std::move(cacheHolds));
}
BurstExecution::BurstExecution(PrivateConstructorTag /*tag*/, std::shared_ptr<const Burst> burst,
Request request, std::vector<int64_t> memoryIdentifierTokens,
bool measure, int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds)
: kBurst(std::move(burst)),
@@ -300,6 +329,8 @@ BurstExecution::BurstExecution(PrivateConstructorTag /*tag*/, std::shared_ptr<co
kMemoryIdentifierTokens(std::move(memoryIdentifierTokens)),
kMeasure(measure),
kLoopTimeoutDuration(loopTimeoutDuration),
kHints(hints),
kExtensionNameToPrefix(extensionNameToPrefix),
kRelocation(std::move(relocation)),
kCacheHolds(std::move(cacheHolds)) {}
@@ -307,7 +338,8 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> BurstEx
const nn::OptionalTimePoint& deadline) const {
const auto aidlDeadline = NN_TRY(convert(deadline));
return kBurst->executeInternal(kRequest, kMemoryIdentifierTokens, kMeasure, aidlDeadline,
kLoopTimeoutDuration, kRelocation);
kLoopTimeoutDuration, kHints, kExtensionNameToPrefix,
kRelocation);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>

View File

@@ -302,9 +302,9 @@ GeneralResult<Model::Subgraph> unvalidatedConvert(const aidl_hal::Subgraph& subg
};
}
GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const aidl_hal::ExtensionNameAndPrefix& extensionNameAndPrefix) {
return Model::ExtensionNameAndPrefix{
return ExtensionNameAndPrefix{
.name = extensionNameAndPrefix.name,
.prefix = extensionNameAndPrefix.prefix,
};
@@ -506,6 +506,12 @@ GeneralResult<SharedHandle> unvalidatedConvert(const ndk::ScopedFileDescriptor&
return std::make_shared<const Handle>(std::move(duplicatedFd));
}
#ifdef NN_AIDL_V4_OR_ABOVE
GeneralResult<TokenValuePair> unvalidatedConvert(const aidl_hal::TokenValuePair& tokenValuePair) {
return TokenValuePair{.token = tokenValuePair.token, .value = tokenValuePair.value};
}
#endif // NN_AIDL_V4_OR_ABOVE
GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities) {
return validatedConvert(capabilities);
}
@@ -562,6 +568,17 @@ GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extens
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) {
return validatedConvert(memories);
}
GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
const std::vector<aidl_hal::ExtensionNameAndPrefix>& extensionNameAndPrefix) {
return unvalidatedConvert(extensionNameAndPrefix);
}
#ifdef NN_AIDL_V4_OR_ABOVE
GeneralResult<std::vector<TokenValuePair>> convert(
const std::vector<aidl_hal::TokenValuePair>& metaData) {
return validatedConvert(metaData);
}
#endif // NN_AIDL_V4_OR_ABOVE
GeneralResult<std::vector<OutputShape>> convert(
const std::vector<aidl_hal::OutputShape>& outputShapes) {
@@ -942,7 +959,7 @@ nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
}
nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix) {
const nn::ExtensionNameAndPrefix& extensionNameToPrefix) {
return ExtensionNameAndPrefix{
.name = extensionNameToPrefix.name,
.prefix = extensionNameToPrefix.prefix,
@@ -1055,6 +1072,11 @@ nn::GeneralResult<Extension> unvalidatedConvert(const nn::Extension& extension)
return Extension{.name = extension.name,
.operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes))};
}
#ifdef NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<TokenValuePair> unvalidatedConvert(const nn::TokenValuePair& tokenValuePair) {
return TokenValuePair{.token = tokenValuePair.token, .value = tokenValuePair.value};
}
#endif // NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken) {
return validatedConvert(cacheToken);
@@ -1134,6 +1156,17 @@ nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
const std::vector<nn::SyncFence>& syncFences) {
return validatedConvert(syncFences);
}
nn::GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) {
return unvalidatedConvert(extensionNameToPrefix);
}
#ifdef NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<TokenValuePair>> convert(
const std::vector<nn::TokenValuePair>& metaData) {
return validatedConvert(metaData);
}
#endif // NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<Extension>> convert(const std::vector<nn::Extension>& extensions) {
return validatedConvert(extensions);

View File

@@ -215,7 +215,9 @@ nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Mo
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that model is ready for IPC.
std::optional<nn::Model> maybeModelInShared;
const nn::Model& modelInShared =
@@ -225,17 +227,28 @@ nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const auto aidlPreference = NN_TRY(convert(preference));
const auto aidlPriority = NN_TRY(convert(priority));
const auto aidlDeadline = NN_TRY(convert(deadline));
const auto aidlModelCache = NN_TRY(convert(modelCache));
const auto aidlDataCache = NN_TRY(convert(dataCache));
auto aidlModelCache = NN_TRY(convert(modelCache));
auto aidlDataCache = NN_TRY(convert(dataCache));
const auto aidlToken = NN_TRY(convert(token));
const auto cb = ndk::SharedRefBase::make<PreparedModelCallback>(kFeatureLevel);
const auto scoped = kDeathHandler.protectCallback(cb.get());
if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
auto aidlHints = NN_TRY(convert(hints));
auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
const auto ret = kDevice->prepareModelWithConfig(
aidlModel,
{aidlPreference, aidlPriority, aidlDeadline, std::move(aidlModelCache),
std::move(aidlDataCache), aidlToken, std::move(aidlHints),
std::move(aidlExtensionPrefix)},
cb);
HANDLE_ASTATUS(ret) << "prepareModel failed";
return cb->get();
}
const auto ret = kDevice->prepareModel(aidlModel, aidlPreference, aidlPriority, aidlDeadline,
aidlModelCache, aidlDataCache, aidlToken, cb);
HANDLE_ASTATUS(ret) << "prepareModel failed";
return cb->get();
}

View File

@@ -63,7 +63,7 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
ExecutionWithCachedRequest::compute(const nn::OptionalTimePoint& deadline) const {
const auto aidlDeadline = NN_TRY(convert(deadline));
return kPreparedModel->executeInternal(kRequest, kMeasure, aidlDeadline, kLoopTimeoutDuration,
kRelocation);
{}, {}, kRelocation);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
@@ -73,9 +73,9 @@ ExecutionWithCachedRequest::computeFenced(
const auto aidlWaitFor = NN_TRY(convert(waitFor));
const auto aidlDeadline = NN_TRY(convert(deadline));
const auto aidlTimeoutDurationAfterFence = NN_TRY(convert(timeoutDurationAfterFence));
return kPreparedModel->executeFencedInternal(kRequest, aidlWaitFor, kMeasure, aidlDeadline,
kLoopTimeoutDuration,
aidlTimeoutDurationAfterFence, kRelocation);
return kPreparedModel->executeFencedInternal(
kRequest, aidlWaitFor, kMeasure, aidlDeadline, kLoopTimeoutDuration,
aidlTimeoutDurationAfterFence, {}, {}, kRelocation);
}
nn::GeneralResult<std::shared_ptr<const Execution>> Execution::create(

View File

@@ -167,6 +167,31 @@ ndk::ScopedAStatus InvalidDevice::prepareModel(
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus InvalidDevice::prepareModelWithConfig(
const Model& model, const PrepareModelConfig& config,
const std::shared_ptr<IPreparedModelCallback>& callback) {
if (!utils::valid(config.extensionNameToPrefix)) {
callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid extensionNameToPrefix");
}
for (const auto& hint : config.compilationHints) {
auto result = std::find_if(config.extensionNameToPrefix.begin(),
config.extensionNameToPrefix.end(),
[&hint](const ExtensionNameAndPrefix& extension) {
uint16_t prefix = static_cast<uint32_t>(hint.token) >>
IDevice::EXTENSION_TYPE_LOW_BITS_TYPE;
return prefix == extension.prefix;
});
if (result == config.extensionNameToPrefix.end()) {
callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
return toAStatus(ErrorStatus::INVALID_ARGUMENT,
"Invalid token for compilation hints: " + std::to_string(hint.token));
}
}
return prepareModel(model, config.preference, config.priority, config.deadlineNs,
config.modelCache, config.dataCache, config.cacheToken, callback);
}
ndk::ScopedAStatus InvalidDevice::prepareModelFromCache(
int64_t /*deadline*/, const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,

View File

@@ -128,8 +128,9 @@ PreparedModel::PreparedModel(PrivateConstructorTag /*tag*/,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> PreparedModel::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -141,30 +142,46 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Prepare
const auto aidlMeasure = NN_TRY(convert(measure));
const auto aidlDeadline = NN_TRY(convert(deadline));
const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));
return executeInternal(aidlRequest, aidlMeasure, aidlDeadline, aidlLoopTimeoutDuration,
relocation);
return executeInternal(aidlRequest, aidlMeasure, aidlDeadline, aidlLoopTimeoutDuration, hints,
extensionNameToPrefix, relocation);
}
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
PreparedModel::executeInternal(const Request& request, bool measure, int64_t deadline,
int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const {
if (relocation.input) {
relocation.input->flush();
}
ExecutionResult executionResult;
const auto ret = kPreparedModel->executeSynchronously(request, measure, deadline,
loopTimeoutDuration, &executionResult);
HANDLE_ASTATUS(ret) << "executeSynchronously failed";
if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
auto aidlHints = NN_TRY(convert(hints));
auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
const auto ret = kPreparedModel->executeSynchronouslyWithConfig(
request,
{measure, loopTimeoutDuration, std::move(aidlHints),
std::move(aidlExtensionPrefix)},
deadline, &executionResult);
HANDLE_ASTATUS(ret) << "executeSynchronouslyWithConfig failed";
} else {
const auto ret = kPreparedModel->executeSynchronously(
request, measure, deadline, loopTimeoutDuration, &executionResult);
HANDLE_ASTATUS(ret) << "executeSynchronously failed";
}
return handleExecutionResult(executionResult, relocation);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
PreparedModel::executeFenced(const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const {
PreparedModel::executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -179,31 +196,45 @@ PreparedModel::executeFenced(const nn::Request& request, const std::vector<nn::S
const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));
const auto aidlTimeoutDurationAfterFence = NN_TRY(convert(timeoutDurationAfterFence));
return executeFencedInternal(aidlRequest, aidlWaitFor, aidlMeasure, aidlDeadline,
aidlLoopTimeoutDuration, aidlTimeoutDurationAfterFence,
relocation);
aidlLoopTimeoutDuration, aidlTimeoutDurationAfterFence, hints,
extensionNameToPrefix, relocation);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
PreparedModel::executeFencedInternal(const Request& request,
const std::vector<ndk::ScopedFileDescriptor>& waitFor,
bool measure, int64_t deadline, int64_t loopTimeoutDuration,
int64_t timeoutDurationAfterFence,
const hal::utils::RequestRelocation& relocation) const {
PreparedModel::executeFencedInternal(
const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor, bool measure,
int64_t deadline, int64_t loopTimeoutDuration, int64_t timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const {
if (relocation.input) {
relocation.input->flush();
}
FencedExecutionResult result;
const auto ret =
kPreparedModel->executeFenced(request, waitFor, measure, deadline, loopTimeoutDuration,
timeoutDurationAfterFence, &result);
HANDLE_ASTATUS(ret) << "executeFenced failed";
if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
auto aidlHints = NN_TRY(convert(hints));
auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
const auto ret = kPreparedModel->executeFencedWithConfig(
request, waitFor,
{measure, loopTimeoutDuration, std::move(aidlHints),
std::move(aidlExtensionPrefix)},
deadline, timeoutDurationAfterFence, &result);
HANDLE_ASTATUS(ret) << "executeFencedWithConfig failed";
} else {
const auto ret = kPreparedModel->executeFenced(request, waitFor, measure, deadline,
loopTimeoutDuration,
timeoutDurationAfterFence, &result);
HANDLE_ASTATUS(ret) << "executeFenced failed";
}
return handleFencedExecutionResult(result, relocation);
}
nn::GeneralResult<nn::SharedExecution> PreparedModel::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -217,8 +248,14 @@ nn::GeneralResult<nn::SharedExecution> PreparedModel::createReusableExecution(
if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
std::shared_ptr<IExecution> execution;
auto aidlHints = NN_TRY(convert(hints));
auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
const auto ret = kPreparedModel->createReusableExecution(
aidlRequest, aidlMeasure, aidlLoopTimeoutDuration, &execution);
aidlRequest,
{aidlMeasure, aidlLoopTimeoutDuration, std::move(aidlHints),
std::move(aidlExtensionPrefix)},
&execution);
HANDLE_ASTATUS(ret) << "createReusableExecution failed";
return Execution::create(std::move(execution), std::move(relocation));
}
@@ -232,7 +269,7 @@ nn::GeneralResult<nn::SharedBurst> PreparedModel::configureExecutionBurst() cons
std::shared_ptr<IBurst> burst;
const auto ret = kPreparedModel->configureExecutionBurst(&burst);
HANDLE_ASTATUS(ret) << "configureExecutionBurst failed";
return Burst::create(std::move(burst));
return Burst::create(std::move(burst), kFeatureLevel);
}
std::any PreparedModel::getUnderlyingResource() const {

View File

@@ -61,7 +61,6 @@ constexpr PerformanceInfo kNoPerformanceInfo = {.execTime = std::numeric_limits<
.powerUsage = std::numeric_limits<float>::max()};
constexpr NumberOfCacheFiles kNumberOfCacheFiles = {.numModelCache = nn::kMaxNumberOfCacheFiles - 1,
.numDataCache = nn::kMaxNumberOfCacheFiles};
constexpr auto makeStatusOk = [] { return ndk::ScopedAStatus::ok(); };
std::shared_ptr<MockDevice> createMockDevice() {
@@ -124,6 +123,18 @@ auto makePreparedModelReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
};
}
const std::vector<nn::TokenValuePair> kHints = {nn::TokenValuePair{.token = 0, .value = {1}}};
const std::vector<nn::ExtensionNameAndPrefix> kExtensionNameToPrefix = {
nn::ExtensionNameAndPrefix{.name = "com.android.nn_test", .prefix = 1}};
auto makePreparedModelWithConfigReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
const std::shared_ptr<MockPreparedModel>& preparedModel) {
return [launchStatus, returnStatus, preparedModel](
const Model& /*model*/, const PrepareModelConfig& /*config*/,
const std::shared_ptr<IPreparedModelCallback>& cb) -> ndk::ScopedAStatus {
return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
};
}
auto makePreparedModelFromCacheReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
const std::shared_ptr<MockPreparedModel>& preparedModel) {
return [launchStatus, returnStatus, preparedModel](
@@ -560,6 +571,8 @@ TEST_P(DeviceTest, getSupportedOperationsDeadObject) {
}
TEST_P(DeviceTest, prepareModel) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -571,7 +584,7 @@ TEST_P(DeviceTest, prepareModel) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -580,6 +593,8 @@ TEST_P(DeviceTest, prepareModel) {
}
TEST_P(DeviceTest, prepareModelLaunchError) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -590,7 +605,7 @@ TEST_P(DeviceTest, prepareModelLaunchError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -598,6 +613,8 @@ TEST_P(DeviceTest, prepareModelLaunchError) {
}
TEST_P(DeviceTest, prepareModelReturnError) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -608,7 +625,7 @@ TEST_P(DeviceTest, prepareModelReturnError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -616,6 +633,8 @@ TEST_P(DeviceTest, prepareModelReturnError) {
}
TEST_P(DeviceTest, prepareModelNullptrError) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -626,7 +645,7 @@ TEST_P(DeviceTest, prepareModelNullptrError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -634,6 +653,8 @@ TEST_P(DeviceTest, prepareModelNullptrError) {
}
TEST_P(DeviceTest, prepareModelTransportFailure) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -643,7 +664,7 @@ TEST_P(DeviceTest, prepareModelTransportFailure) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -651,6 +672,8 @@ TEST_P(DeviceTest, prepareModelTransportFailure) {
}
TEST_P(DeviceTest, prepareModelDeadObject) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -660,7 +683,7 @@ TEST_P(DeviceTest, prepareModelDeadObject) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -668,6 +691,8 @@ TEST_P(DeviceTest, prepareModelDeadObject) {
}
TEST_P(DeviceTest, prepareModelAsyncCrash) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup test
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -681,7 +706,157 @@ TEST_P(DeviceTest, prepareModelAsyncCrash) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
TEST_P(DeviceTest, prepareModelWithConfig) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
const auto mockPreparedModel = MockPreparedModel::create();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(Invoke(makePreparedModelWithConfigReturn(ErrorStatus::NONE, ErrorStatus::NONE,
mockPreparedModel)));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_TRUE(result.has_value())
<< "Failed with " << result.error().code << ": " << result.error().message;
EXPECT_NE(result.value(), nullptr);
}
TEST_P(DeviceTest, prepareModelWithConfigLaunchError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(Invoke(makePreparedModelWithConfigReturn(
ErrorStatus::GENERAL_FAILURE, ErrorStatus::GENERAL_FAILURE, nullptr)));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(DeviceTest, prepareModelWithConfigReturnError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(Invoke(makePreparedModelWithConfigReturn(
ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE, nullptr)));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(DeviceTest, prepareModelWithConfigNullptrError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(Invoke(makePreparedModelWithConfigReturn(ErrorStatus::NONE, ErrorStatus::NONE,
nullptr)));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(DeviceTest, prepareModelWithConfigTransportFailure) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(DeviceTest, prepareModelWithConfigDeadObject) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
TEST_P(DeviceTest, prepareModelWithConfigAsyncCrash) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
const auto ret = [&device]() {
DeathMonitor::serviceDied(device->getDeathMonitor());
return ndk::ScopedAStatus::ok();
};
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(ret));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -32,6 +32,10 @@ class MockBurst final : public BnBurst {
bool measureTiming, int64_t deadline, int64_t loopTimeoutDuration,
ExecutionResult* executionResult),
(override));
MOCK_METHOD(ndk::ScopedAStatus, executeSynchronouslyWithConfig,
(const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
const ExecutionConfig& config, int64_t deadline, ExecutionResult* executionResult),
(override));
MOCK_METHOD(ndk::ScopedAStatus, releaseMemoryResource, (int64_t memoryIdentifierToken),
(override));
};

View File

@@ -50,6 +50,10 @@ class MockDevice final : public BnDevice {
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback),
(override));
MOCK_METHOD(ndk::ScopedAStatus, prepareModelWithConfig,
(const Model& model, const PrepareModelConfig& config,
const std::shared_ptr<IPreparedModelCallback>& callback),
(override));
MOCK_METHOD(ndk::ScopedAStatus, prepareModelFromCache,
(int64_t deadline, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
const std::vector<ndk::ScopedFileDescriptor>& dataCache,

View File

@@ -40,10 +40,19 @@ class MockPreparedModel final : public BnPreparedModel {
bool measureTiming, int64_t deadline, int64_t loopTimeoutDuration,
int64_t duration, FencedExecutionResult* fencedExecutionResult),
(override));
MOCK_METHOD(ndk::ScopedAStatus, executeSynchronouslyWithConfig,
(const Request& request, const ExecutionConfig& config, int64_t deadline,
ExecutionResult* executionResult),
(override));
MOCK_METHOD(ndk::ScopedAStatus, executeFencedWithConfig,
(const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
const ExecutionConfig& config, int64_t deadline, int64_t duration,
FencedExecutionResult* fencedExecutionResult),
(override));
MOCK_METHOD(ndk::ScopedAStatus, configureExecutionBurst, (std::shared_ptr<IBurst> * burst),
(override));
MOCK_METHOD(ndk::ScopedAStatus, createReusableExecution,
(const Request& request, bool measureTiming, int64_t loopTimeoutDuration,
(const Request& request, const ExecutionConfig& config,
std::shared_ptr<IExecution>* execution),
(override));
};

View File

@@ -70,6 +70,21 @@ auto makeFencedExecutionResult(const std::shared_ptr<MockFencedExecutionCallback
class PreparedModelTest : public VersionedAidlUtilsTestBase {};
const std::vector<nn::TokenValuePair> kHints = {nn::TokenValuePair{.token = 0, .value = {1}}};
const std::vector<nn::ExtensionNameAndPrefix> kExtensionNameToPrefix = {
nn::ExtensionNameAndPrefix{.name = "com.android.nn_test", .prefix = 1}};
auto makeFencedExecutionWithConfigResult(
const std::shared_ptr<MockFencedExecutionCallback>& callback) {
return [callback](const Request& /*request*/,
const std::vector<ndk::ScopedFileDescriptor>& /*waitFor*/,
const ExecutionConfig& /*config*/, int64_t /*deadline*/, int64_t /*duration*/,
FencedExecutionResult* fencedExecutionResult) {
*fencedExecutionResult = FencedExecutionResult{.callback = callback,
.syncFence = ndk::ScopedFileDescriptor(-1)};
return ndk::ScopedAStatus::ok();
};
}
} // namespace
TEST_P(PreparedModelTest, invalidPreparedModel) {
@@ -82,6 +97,8 @@ TEST_P(PreparedModelTest, invalidPreparedModel) {
}
TEST_P(PreparedModelTest, executeSync) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -96,7 +113,7 @@ TEST_P(PreparedModelTest, executeSync) {
DoAll(SetArgPointee<4>(mockExecutionResult), InvokeWithoutArgs(makeStatusOk)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -104,6 +121,8 @@ TEST_P(PreparedModelTest, executeSync) {
}
TEST_P(PreparedModelTest, executeSyncError) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -112,7 +131,7 @@ TEST_P(PreparedModelTest, executeSyncError) {
.WillOnce(Invoke(makeGeneralFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -120,6 +139,8 @@ TEST_P(PreparedModelTest, executeSyncError) {
}
TEST_P(PreparedModelTest, executeSyncTransportFailure) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -128,7 +149,7 @@ TEST_P(PreparedModelTest, executeSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -136,6 +157,8 @@ TEST_P(PreparedModelTest, executeSyncTransportFailure) {
}
TEST_P(PreparedModelTest, executeSyncDeadObject) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -144,7 +167,7 @@ TEST_P(PreparedModelTest, executeSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -152,6 +175,8 @@ TEST_P(PreparedModelTest, executeSyncDeadObject) {
}
TEST_P(PreparedModelTest, executeFenced) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -165,7 +190,7 @@ TEST_P(PreparedModelTest, executeFenced) {
.WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -181,6 +206,8 @@ TEST_P(PreparedModelTest, executeFenced) {
}
TEST_P(PreparedModelTest, executeFencedCallbackError) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -195,7 +222,7 @@ TEST_P(PreparedModelTest, executeFencedCallbackError) {
.WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -211,6 +238,8 @@ TEST_P(PreparedModelTest, executeFencedCallbackError) {
}
TEST_P(PreparedModelTest, executeFencedError) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -219,7 +248,7 @@ TEST_P(PreparedModelTest, executeFencedError) {
.WillOnce(InvokeWithoutArgs(makeGeneralFailure));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -227,6 +256,8 @@ TEST_P(PreparedModelTest, executeFencedError) {
}
TEST_P(PreparedModelTest, executeFencedTransportFailure) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -235,7 +266,7 @@ TEST_P(PreparedModelTest, executeFencedTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -243,6 +274,8 @@ TEST_P(PreparedModelTest, executeFencedTransportFailure) {
}
TEST_P(PreparedModelTest, executeFencedDeadObject) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -251,7 +284,7 @@ TEST_P(PreparedModelTest, executeFencedDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -276,7 +309,7 @@ TEST_P(PreparedModelTest, reusableExecuteSync) {
DoAll(SetArgPointee<4>(mockExecutionResult), InvokeWithoutArgs(makeStatusOk)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -300,7 +333,7 @@ TEST_P(PreparedModelTest, reusableExecuteSyncError) {
.WillOnce(Invoke(makeGeneralFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -322,7 +355,7 @@ TEST_P(PreparedModelTest, reusableExecuteSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -344,7 +377,7 @@ TEST_P(PreparedModelTest, reusableExecuteSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -372,7 +405,7 @@ TEST_P(PreparedModelTest, reusableExecuteFenced) {
.WillRepeatedly(Invoke(makeFencedExecutionResult(mockCallback)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -410,7 +443,7 @@ TEST_P(PreparedModelTest, reusableExecuteFencedCallbackError) {
.WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -440,7 +473,7 @@ TEST_P(PreparedModelTest, reusableExecuteFencedError) {
.WillOnce(InvokeWithoutArgs(makeGeneralFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -462,7 +495,7 @@ TEST_P(PreparedModelTest, reusableExecuteFencedTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -484,7 +517,7 @@ TEST_P(PreparedModelTest, reusableExecuteFencedDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -495,6 +528,206 @@ TEST_P(PreparedModelTest, reusableExecuteFencedDeadObject) {
EXPECT_EQ(computeResult.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
TEST_P(PreparedModelTest, executeSyncWithConfig) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
const auto mockExecutionResult = ExecutionResult{
.outputSufficientSize = true,
.outputShapes = {},
.timing = kNoTiming,
};
EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
.Times(1)
.WillOnce(
DoAll(SetArgPointee<3>(mockExecutionResult), InvokeWithoutArgs(makeStatusOk)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
EXPECT_TRUE(result.has_value())
<< "Failed with " << result.error().code << ": " << result.error().message;
}
TEST_P(PreparedModelTest, executeSyncWithConfigError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
.Times(1)
.WillOnce(Invoke(makeGeneralFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(PreparedModelTest, executeSyncWithConfigTransportFailure) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(PreparedModelTest, executeSyncWithConfigDeadObject) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
TEST_P(PreparedModelTest, executeFencedWithConfig) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
const auto mockCallback = MockFencedExecutionCallback::create();
EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
.Times(1)
.WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
SetArgPointee<2>(ErrorStatus::NONE), Invoke(makeStatusOk)));
EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
.Times(1)
.WillOnce(Invoke(makeFencedExecutionWithConfigResult(mockCallback)));
// run test
const auto result =
preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_TRUE(result.has_value())
<< "Failed with " << result.error().code << ": " << result.error().message;
const auto& [syncFence, callback] = result.value();
EXPECT_EQ(syncFence.syncWait({}), nn::SyncFence::FenceState::SIGNALED);
ASSERT_NE(callback, nullptr);
// get results from callback
const auto callbackResult = callback();
ASSERT_TRUE(callbackResult.has_value()) << "Failed with " << callbackResult.error().code << ": "
<< callbackResult.error().message;
}
TEST_P(PreparedModelTest, executeFencedWithConfigCallbackError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
const auto mockCallback = MockFencedExecutionCallback::create();
EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
.Times(1)
.WillOnce(Invoke(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
SetArgPointee<2>(ErrorStatus::GENERAL_FAILURE),
Invoke(makeStatusOk))));
EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
.Times(1)
.WillOnce(Invoke(makeFencedExecutionWithConfigResult(mockCallback)));
// run test
const auto result =
preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_TRUE(result.has_value())
<< "Failed with " << result.error().code << ": " << result.error().message;
const auto& [syncFence, callback] = result.value();
EXPECT_NE(syncFence.syncWait({}), nn::SyncFence::FenceState::ACTIVE);
ASSERT_NE(callback, nullptr);
// verify callback failure
const auto callbackResult = callback();
ASSERT_FALSE(callbackResult.has_value());
EXPECT_EQ(callbackResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(PreparedModelTest, executeFencedWithConfigError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralFailure));
// run test
const auto result =
preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(PreparedModelTest, executeFencedWithConfigTransportFailure) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result =
preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(PreparedModelTest, executeFencedWithConfigDeadObject) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result =
preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
TEST_P(PreparedModelTest, configureExecutionBurst) {
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
@@ -567,13 +800,13 @@ TEST_P(PreparedModelTest, createReusableExecution) {
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto mockExecution = ndk::SharedRefBase::make<MockExecution>();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
.WillOnce(DoAll(SetArgPointee<3>(mockExecution), Invoke(makeStatusOk)));
.WillOnce(DoAll(SetArgPointee<2>(mockExecution), Invoke(makeStatusOk)));
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -586,13 +819,13 @@ TEST_P(PreparedModelTest, createReusableExecutionError) {
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralFailure));
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -604,13 +837,13 @@ TEST_P(PreparedModelTest, createReusableExecutionTransportFailure) {
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -622,13 +855,13 @@ TEST_P(PreparedModelTest, createReusableExecutionDeadObject) {
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -63,6 +63,8 @@ struct TestConfig {
// it is skipped. The field is set to true by default and is set to false in
// quantization coupling tests to suppress skipping a test
bool reportSkipping;
// `useConfig` indicates if a test should use execute*WithConfig functions for the execution.
bool useConfig;
TestConfig(Executor executor, bool measureTiming, OutputType outputType, MemoryType memoryType,
bool reusable)
: executor(executor),
@@ -70,7 +72,8 @@ struct TestConfig {
outputType(outputType),
memoryType(memoryType),
reusable(reusable),
reportSkipping(true) {}
reportSkipping(true),
useConfig(false) {}
TestConfig(Executor executor, bool measureTiming, OutputType outputType, MemoryType memoryType,
bool reusable, bool reportSkipping)
: executor(executor),
@@ -78,7 +81,17 @@ struct TestConfig {
outputType(outputType),
memoryType(memoryType),
reusable(reusable),
reportSkipping(reportSkipping) {}
reportSkipping(reportSkipping),
useConfig(false) {}
TestConfig(Executor executor, bool measureTiming, OutputType outputType, MemoryType memoryType,
bool reusable, bool reportSkipping, bool useConfig)
: executor(executor),
measureTiming(measureTiming),
outputType(outputType),
memoryType(memoryType),
reusable(reusable),
reportSkipping(reportSkipping),
useConfig(useConfig) {}
};
std::string toString(OutputType type) {
@@ -100,7 +113,8 @@ std::string toString(const TestConfig& config) {
<< ", .measureTiming=" << (config.measureTiming ? "true" : "false")
<< ", .outputType=" << toString(config.outputType)
<< ", .memoryType=" << toString(config.memoryType)
<< ", .reusable=" << (config.reusable ? "true" : "false") << "}";
<< ", .reusable=" << (config.reusable ? "true" : "false")
<< ", .useConfig=" << (config.useConfig ? "true" : "false") << "}";
return ss.str();
}
@@ -587,8 +601,8 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
std::shared_ptr<IExecution> execution;
if (testConfig.reusable) {
const auto ret = preparedModel->createReusableExecution(request, testConfig.measureTiming,
loopTimeoutDurationNs, &execution);
const auto ret = preparedModel->createReusableExecution(
request, {testConfig.measureTiming, loopTimeoutDurationNs, {}, {}}, &execution);
ASSERT_TRUE(ret.isOk()) << static_cast<nn::ErrorStatus>(ret.getServiceSpecificError());
ASSERT_NE(nullptr, execution.get());
}
@@ -607,6 +621,10 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
::ndk::ScopedAStatus ret;
if (testConfig.reusable) {
ret = execution->executeSynchronously(kNoDeadline, &executionResult);
} else if (testConfig.useConfig) {
ret = preparedModel->executeSynchronouslyWithConfig(
request, {testConfig.measureTiming, loopTimeoutDurationNs, {}, {}},
kNoDeadline, &executionResult);
} else {
ret = preparedModel->executeSynchronously(request, testConfig.measureTiming,
kNoDeadline, loopTimeoutDurationNs,
@@ -649,9 +667,16 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
ExecutionResult executionResult;
// execute
ret = burst->executeSynchronously(request, slots, testConfig.measureTiming,
kNoDeadline, loopTimeoutDurationNs,
&executionResult);
if (testConfig.useConfig) {
ret = burst->executeSynchronouslyWithConfig(
request, slots,
{testConfig.measureTiming, loopTimeoutDurationNs, {}, {}}, kNoDeadline,
&executionResult);
} else {
ret = burst->executeSynchronously(request, slots, testConfig.measureTiming,
kNoDeadline, loopTimeoutDurationNs,
&executionResult);
}
ASSERT_TRUE(ret.isOk() || ret.getExceptionCode() == EX_SERVICE_SPECIFIC)
<< ret.getDescription();
if (ret.isOk()) {
@@ -680,6 +705,10 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
::ndk::ScopedAStatus ret;
if (testConfig.reusable) {
ret = execution->executeFenced({}, kNoDeadline, kNoDuration, &executionResult);
} else if (testConfig.useConfig) {
ret = preparedModel->executeFencedWithConfig(
request, {}, {testConfig.measureTiming, loopTimeoutDurationNs, {}, {}},
kNoDeadline, kNoDuration, &executionResult);
} else {
ret = preparedModel->executeFenced(request, {}, testConfig.measureTiming,
kNoDeadline, loopTimeoutDurationNs,
@@ -697,9 +726,19 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
waitFor.emplace_back(dupFd);
// If a sync fence is returned, try start another run waiting for the sync
// fence.
ret = preparedModel->executeFenced(request, waitFor, testConfig.measureTiming,
kNoDeadline, loopTimeoutDurationNs,
kNoDuration, &executionResult);
if (testConfig.reusable) {
ret = execution->executeFenced(waitFor, kNoDeadline, kNoDuration,
&executionResult);
} else if (testConfig.useConfig) {
ret = preparedModel->executeFencedWithConfig(
request, waitFor,
{testConfig.measureTiming, loopTimeoutDurationNs, {}, {}},
kNoDeadline, kNoDuration, &executionResult);
} else {
ret = preparedModel->executeFenced(
request, waitFor, testConfig.measureTiming, kNoDeadline,
loopTimeoutDurationNs, kNoDuration, &executionResult);
}
ASSERT_TRUE(ret.isOk());
waitForSyncFence(executionResult.syncFence.get());
}
@@ -830,11 +869,13 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
std::vector<Executor> executorList;
std::vector<MemoryType> memoryTypeList;
std::vector<bool> reusableList = {false};
std::vector<bool> useConfigList = {false};
int deviceVersion;
ASSERT_TRUE(device->getInterfaceVersion(&deviceVersion).isOk());
if (deviceVersion >= kMinAidlLevelForFL8) {
reusableList.push_back(true);
useConfigList.push_back(true);
}
switch (testKind) {
@@ -879,11 +920,14 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
for (const Executor executor : executorList) {
for (const MemoryType memoryType : memoryTypeList) {
for (const bool reusable : reusableList) {
if (executor == Executor::BURST && reusable) continue;
const TestConfig testConfig(executor, measureTiming, outputType, memoryType,
reusable);
SCOPED_TRACE(toString(testConfig));
EvaluatePreparedModel(device, preparedModel, testModel, testConfig);
for (const bool useConfig : useConfigList) {
if ((useConfig || executor == Executor::BURST) && reusable) continue;
const TestConfig testConfig(executor, measureTiming, outputType,
memoryType, reusable,
/*reportSkipping=*/true, useConfig);
SCOPED_TRACE(toString(testConfig));
EvaluatePreparedModel(device, preparedModel, testModel, testConfig);
}
}
}
}
@@ -942,6 +986,13 @@ void Execute(const std::shared_ptr<IDevice>& device, const TestModel& testModel,
createPreparedModel(device, model, &preparedModel);
if (preparedModel == nullptr) return;
EvaluatePreparedModel(device, preparedModel, testModel, testKind);
int32_t deviceVersion;
ASSERT_TRUE(device->getInterfaceVersion(&deviceVersion).isOk());
if (deviceVersion >= kMinAidlLevelForFL8) {
createPreparedModel(device, model, &preparedModel, /*reportSkipping*/ true,
/*useConfig*/ true);
EvaluatePreparedModel(device, preparedModel, testModel, testKind);
}
} break;
case TestKind::QUANTIZATION_COUPLING: {
ASSERT_TRUE(testModel.hasQuant8CoupledOperands());

View File

@@ -204,11 +204,23 @@ class InvalidPreparedModel : public BnPreparedModel {
return ndk::ScopedAStatus::fromServiceSpecificError(
static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
}
ndk::ScopedAStatus executeSynchronouslyWithConfig(const Request&, const ExecutionConfig&,
int64_t, ExecutionResult*) override {
return ndk::ScopedAStatus::fromServiceSpecificError(
static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
}
ndk::ScopedAStatus executeFencedWithConfig(const Request&,
const std::vector<ndk::ScopedFileDescriptor>&,
const ExecutionConfig&, int64_t, int64_t,
FencedExecutionResult*) override {
return ndk::ScopedAStatus::fromServiceSpecificError(
static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
}
ndk::ScopedAStatus configureExecutionBurst(std::shared_ptr<IBurst>*) override {
return ndk::ScopedAStatus::fromServiceSpecificError(
static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
}
ndk::ScopedAStatus createReusableExecution(const aidl_hal::Request&, bool, int64_t,
ndk::ScopedAStatus createReusableExecution(const aidl_hal::Request&, const ExecutionConfig&,
std::shared_ptr<aidl_hal::IExecution>*) override {
return ndk::ScopedAStatus::fromServiceSpecificError(
static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));

View File

@@ -77,6 +77,28 @@ static void validatePrepareModel(const std::shared_ptr<IDevice>& device, const s
ASSERT_EQ(nullptr, preparedModel.get());
}
static void validatePrepareModelWithConfig(const std::shared_ptr<IDevice>& device,
const std::string& message, const Model& model,
ExecutionPreference preference, Priority priority) {
SCOPED_TRACE(message + " [prepareModelWithConfig]");
std::shared_ptr<PreparedModelCallback> preparedModelCallback =
ndk::SharedRefBase::make<PreparedModelCallback>();
const auto prepareLaunchStatus = device->prepareModelWithConfig(
model, {preference, priority, kNoDeadline, {}, {}, kEmptyCacheToken, {}, {}},
preparedModelCallback);
ASSERT_FALSE(prepareLaunchStatus.isOk());
ASSERT_EQ(prepareLaunchStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus.getServiceSpecificError()),
ErrorStatus::INVALID_ARGUMENT);
preparedModelCallback->wait();
ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
std::shared_ptr<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
ASSERT_EQ(nullptr, preparedModel.get());
}
static bool validExecutionPreference(ExecutionPreference preference) {
return preference == ExecutionPreference::LOW_POWER ||
preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
@@ -103,6 +125,13 @@ static void validate(const std::shared_ptr<IDevice>& device, const std::string&
}
validatePrepareModel(device, message, model, preference, priority);
int32_t aidlVersion;
ASSERT_TRUE(device->getInterfaceVersion(&aidlVersion).isOk());
if (aidlVersion >= kMinAidlLevelForFL8) {
// prepareModelWithConfig must satisfy all requirements enforced by prepareModel.
validatePrepareModelWithConfig(device, message, model, preference, priority);
}
}
static uint32_t addOperand(Model* model) {

View File

@@ -45,7 +45,7 @@ static void validateReusableExecution(const std::shared_ptr<IPreparedModel>& pre
{
SCOPED_TRACE(message + " [createReusableExecution]");
const auto createStatus = preparedModel->createReusableExecution(
request, measure, kOmittedTimeoutDuration, &execution);
request, {measure, kOmittedTimeoutDuration, {}, {}}, &execution);
if (!createStatus.isOk()) {
ASSERT_EQ(createStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
ASSERT_EQ(static_cast<ErrorStatus>(createStatus.getServiceSpecificError()),
@@ -149,10 +149,59 @@ static void validate(const std::shared_ptr<IPreparedModel>& preparedModel,
int32_t aidlVersion;
ASSERT_TRUE(preparedModel->getInterfaceVersion(&aidlVersion).isOk());
if (aidlVersion < kMinAidlLevelForFL8) {
return;
}
// validate reusable execution
if (aidlVersion >= kMinAidlLevelForFL8) {
validateReusableExecution(preparedModel, message, request, measure);
validateReusableExecution(preparedModel, message, request, measure);
// synchronous with empty hints
{
SCOPED_TRACE(message + " [executeSynchronouslyWithConfig]");
ExecutionResult executionResult;
const auto executeStatus = preparedModel->executeSynchronouslyWithConfig(
request, {measure, kOmittedTimeoutDuration, {}, {}}, kNoDeadline, &executionResult);
ASSERT_FALSE(executeStatus.isOk());
ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
ErrorStatus::INVALID_ARGUMENT);
}
// fenced with empty hints
{
SCOPED_TRACE(message + " [executeFencedWithConfig]");
FencedExecutionResult executionResult;
const auto executeStatus = preparedModel->executeFencedWithConfig(
request, {}, {false, kOmittedTimeoutDuration, {}, {}}, kNoDeadline, kNoDuration,
&executionResult);
ASSERT_FALSE(executeStatus.isOk());
ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
ErrorStatus::INVALID_ARGUMENT);
}
// burst with empty hints
{
SCOPED_TRACE(message + " [burst executeSynchronouslyWithConfig]");
// create burst
std::shared_ptr<IBurst> burst;
auto ret = preparedModel->configureExecutionBurst(&burst);
ASSERT_TRUE(ret.isOk()) << ret.getDescription();
ASSERT_NE(nullptr, burst.get());
// use -1 for all memory identifier tokens
const std::vector<int64_t> slots(request.pools.size(), -1);
ExecutionResult executionResult;
const auto executeStatus = burst->executeSynchronouslyWithConfig(
request, slots, {measure, kOmittedTimeoutDuration, {}, {}}, kNoDeadline,
&executionResult);
ASSERT_FALSE(executeStatus.isOk());
ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
ErrorStatus::INVALID_ARGUMENT);
}
}

View File

@@ -41,7 +41,8 @@ using implementation::PreparedModelCallback;
// internal helper function
void createPreparedModel(const std::shared_ptr<IDevice>& device, const Model& model,
std::shared_ptr<IPreparedModel>* preparedModel, bool reportSkipping) {
std::shared_ptr<IPreparedModel>* preparedModel, bool reportSkipping,
bool useConfig) {
ASSERT_NE(nullptr, preparedModel);
*preparedModel = nullptr;
@@ -56,11 +57,25 @@ void createPreparedModel(const std::shared_ptr<IDevice>& device, const Model& mo
// launch prepare model
const std::shared_ptr<PreparedModelCallback> preparedModelCallback =
ndk::SharedRefBase::make<PreparedModelCallback>();
const auto prepareLaunchStatus =
device->prepareModel(model, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority,
kNoDeadline, {}, {}, kEmptyCacheToken, preparedModelCallback);
ASSERT_TRUE(prepareLaunchStatus.isOk()) << prepareLaunchStatus.getDescription();
if (useConfig) {
const auto prepareLaunchStatus =
device->prepareModelWithConfig(model,
{ExecutionPreference::FAST_SINGLE_ANSWER,
kDefaultPriority,
kNoDeadline,
{},
{},
kEmptyCacheToken,
{},
{}},
preparedModelCallback);
ASSERT_TRUE(prepareLaunchStatus.isOk()) << prepareLaunchStatus.getDescription();
} else {
const auto prepareLaunchStatus = device->prepareModel(
model, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, kNoDeadline, {},
{}, kEmptyCacheToken, preparedModelCallback);
ASSERT_TRUE(prepareLaunchStatus.isOk()) << prepareLaunchStatus.getDescription();
}
// retrieve prepared model
preparedModelCallback->wait();
const ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();

View File

@@ -51,8 +51,8 @@ std::string printNeuralNetworksAidlTest(
// Create an IPreparedModel object. If the model cannot be prepared,
// "preparedModel" will be nullptr instead.
void createPreparedModel(const std::shared_ptr<IDevice>& device, const Model& model,
std::shared_ptr<IPreparedModel>* preparedModel,
bool reportSkipping = true);
std::shared_ptr<IPreparedModel>* preparedModel, bool reportSkipping = true,
bool useConfig = false);
enum class Executor { SYNC, BURST, FENCED };

View File

@@ -46,6 +46,11 @@ class Burst : public BnBurst {
bool measureTiming, int64_t deadlineNs,
int64_t loopTimeoutDurationNs,
ExecutionResult* executionResult) override;
ndk::ScopedAStatus executeSynchronouslyWithConfig(
const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
const ExecutionConfig& config, int64_t deadlineNs,
ExecutionResult* executionResult) override;
ndk::ScopedAStatus releaseMemoryResource(int64_t memoryIdentifierToken) override;
class ThreadSafeMemoryCache {

View File

@@ -31,6 +31,7 @@
#include <aidl/android/hardware/neuralnetworks/IPreparedModelParcel.h>
#include <aidl/android/hardware/neuralnetworks/Model.h>
#include <aidl/android/hardware/neuralnetworks/NumberOfCacheFiles.h>
#include <aidl/android/hardware/neuralnetworks/PrepareModelConfig.h>
#include <aidl/android/hardware/neuralnetworks/Priority.h>
#include <android/binder_auto_utils.h>
#include <nnapi/IDevice.h>
@@ -72,6 +73,9 @@ class Device : public BnDevice {
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback) override;
ndk::ScopedAStatus prepareModelWithConfig(
const Model& model, const PrepareModelConfig& config,
const std::shared_ptr<IPreparedModelCallback>& callback) override;
protected:
const ::android::nn::SharedDevice kDevice;

View File

@@ -51,9 +51,17 @@ class PreparedModel : public BnPreparedModel {
int64_t loopTimeoutDurationNs, int64_t durationNs,
FencedExecutionResult* executionResult) override;
ndk::ScopedAStatus configureExecutionBurst(std::shared_ptr<IBurst>* burst) override;
ndk::ScopedAStatus createReusableExecution(const Request& request, bool measureTiming,
int64_t loopTimeoutDurationNs,
ndk::ScopedAStatus createReusableExecution(const Request& request,
const ExecutionConfig& config,
std::shared_ptr<IExecution>* execution) override;
ndk::ScopedAStatus executeSynchronouslyWithConfig(const Request& request,
const ExecutionConfig& config,
int64_t deadlineNs,
ExecutionResult* executionResult) override;
ndk::ScopedAStatus executeFencedWithConfig(
const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
const ExecutionConfig& config, int64_t deadlineNs, int64_t durationNs,
FencedExecutionResult* executionResult) override;
::android::nn::SharedPreparedModel getUnderlyingPreparedModel() const;

View File

@@ -93,7 +93,8 @@ std::vector<nn::IBurst::OptionalCacheHold> ensureAllMemoriesAreCached(
nn::ExecutionResult<ExecutionResult> executeSynchronously(
const nn::IBurst& burst, const Burst::ThreadSafeMemoryCache& cache, const Request& request,
const std::vector<int64_t>& memoryIdentifierTokens, bool measureTiming, int64_t deadlineNs,
int64_t loopTimeoutDurationNs) {
int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints,
const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
if (request.pools.size() != memoryIdentifierTokens.size()) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
<< "request.pools.size() != memoryIdentifierTokens.size()";
@@ -107,11 +108,13 @@ nn::ExecutionResult<ExecutionResult> executeSynchronously(
const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
auto nnHints = NN_TRY(convertInput(hints));
auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
const auto hold = ensureAllMemoriesAreCached(&nnRequest, memoryIdentifierTokens, burst, cache);
const auto result =
burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration);
const auto result = burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration,
nnHints, nnExtensionNameToPrefix);
if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
const auto& [message, code, outputShapes] = result.error();
@@ -155,7 +158,24 @@ ndk::ScopedAStatus Burst::executeSynchronously(const Request& request,
ExecutionResult* executionResult) {
auto result =
adapter::executeSynchronously(*kBurst, kMemoryCache, request, memoryIdentifierTokens,
measureTiming, deadlineNs, loopTimeoutDurationNs);
measureTiming, deadlineNs, loopTimeoutDurationNs, {}, {});
if (!result.has_value()) {
auto [message, code, _] = std::move(result).error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
static_cast<int32_t>(aidlCode), message.c_str());
}
*executionResult = std::move(result).value();
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus Burst::executeSynchronouslyWithConfig(
const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
const ExecutionConfig& config, int64_t deadlineNs, ExecutionResult* executionResult) {
auto result = adapter::executeSynchronously(
*kBurst, kMemoryCache, request, memoryIdentifierTokens, config.measureTiming,
deadlineNs, config.loopTimeoutDurationNs, config.executionHints,
config.extensionNameToPrefix);
if (!result.has_value()) {
auto [message, code, _] = std::move(result).error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);

View File

@@ -148,13 +148,14 @@ void notify(IPreparedModelCallback* callback, PrepareModelResult result) {
}
}
nn::GeneralResult<void> prepareModel(const nn::SharedDevice& device, const Executor& executor,
const Model& model, ExecutionPreference preference,
Priority priority, int64_t deadlineNs,
const std::vector<ndk::ScopedFileDescriptor>& modelCache,
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback) {
nn::GeneralResult<void> prepareModel(
const nn::SharedDevice& device, const Executor& executor, const Model& model,
ExecutionPreference preference, Priority priority, int64_t deadlineNs,
const std::vector<ndk::ScopedFileDescriptor>& modelCache,
const std::vector<ndk::ScopedFileDescriptor>& dataCache, const std::vector<uint8_t>& token,
const std::vector<TokenValuePair>& hints,
const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
const std::shared_ptr<IPreparedModelCallback>& callback) {
if (callback.get() == nullptr) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
}
@@ -166,12 +167,16 @@ nn::GeneralResult<void> prepareModel(const nn::SharedDevice& device, const Execu
auto nnModelCache = NN_TRY(convertInput(modelCache));
auto nnDataCache = NN_TRY(convertInput(dataCache));
const auto nnToken = NN_TRY(convertCacheToken(token));
auto nnHints = NN_TRY(convertInput(hints));
auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
Task task = [device, nnModel = std::move(nnModel), nnPreference, nnPriority, nnDeadline,
nnModelCache = std::move(nnModelCache), nnDataCache = std::move(nnDataCache),
nnToken, callback] {
auto result = device->prepareModel(nnModel, nnPreference, nnPriority, nnDeadline,
nnModelCache, nnDataCache, nnToken);
nnToken, nnHints = std::move(nnHints),
nnExtensionNameToPrefix = std::move(nnExtensionNameToPrefix), callback] {
auto result =
device->prepareModel(nnModel, nnPreference, nnPriority, nnDeadline, nnModelCache,
nnDataCache, nnToken, nnHints, nnExtensionNameToPrefix);
notify(callback.get(), std::move(result));
};
executor(std::move(task), nnDeadline);
@@ -273,8 +278,9 @@ ndk::ScopedAStatus Device::prepareModel(const Model& model, ExecutionPreference
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback) {
const auto result = adapter::prepareModel(kDevice, kExecutor, model, preference, priority,
deadlineNs, modelCache, dataCache, token, callback);
const auto result =
adapter::prepareModel(kDevice, kExecutor, model, preference, priority, deadlineNs,
modelCache, dataCache, token, {}, {}, callback);
if (!result.has_value()) {
const auto& [message, code] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
@@ -301,4 +307,21 @@ ndk::ScopedAStatus Device::prepareModelFromCache(
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus Device::prepareModelWithConfig(
const Model& model, const PrepareModelConfig& config,
const std::shared_ptr<IPreparedModelCallback>& callback) {
const auto result = adapter::prepareModel(
kDevice, kExecutor, model, config.preference, config.priority, config.deadlineNs,
config.modelCache, config.dataCache, config.cacheToken, config.compilationHints,
config.extensionNameToPrefix, callback);
if (!result.has_value()) {
const auto& [message, code] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
callback->notify(aidlCode, nullptr);
return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
static_cast<int32_t>(aidlCode), message.c_str());
}
return ndk::ScopedAStatus::ok();
}
} // namespace aidl::android::hardware::neuralnetworks::adapter

View File

@@ -118,17 +118,20 @@ nn::GeneralResult<nn::OptionalTimePoint> makeOptionalTimePoint(int64_t durationN
return durationNs < 0 ? nn::OptionalTimePoint{} : nn::TimePoint(makeDuration(durationNs));
}
nn::ExecutionResult<ExecutionResult> executeSynchronously(const nn::IPreparedModel& preparedModel,
const Request& request,
bool measureTiming, int64_t deadlineNs,
int64_t loopTimeoutDurationNs) {
nn::ExecutionResult<ExecutionResult> executeSynchronously(
const nn::IPreparedModel& preparedModel, const Request& request, bool measureTiming,
int64_t deadlineNs, int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints,
const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
const auto nnRequest = NN_TRY(convertInput(request));
const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
auto nnHints = NN_TRY(convertInput(hints));
auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
const auto result =
preparedModel.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration);
preparedModel.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration,
nnHints, nnExtensionNameToPrefix);
if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
const auto& [message, code, outputShapes] = result.error();
@@ -147,16 +150,21 @@ nn::ExecutionResult<ExecutionResult> executeSynchronously(const nn::IPreparedMod
nn::GeneralResult<FencedExecutionResult> executeFenced(
const nn::IPreparedModel& preparedModel, const Request& request,
const std::vector<ndk::ScopedFileDescriptor>& waitFor, bool measureTiming,
int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs) {
int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
const std::vector<TokenValuePair>& hints,
const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
const auto nnRequest = NN_TRY(convertInput(request));
const auto nnWaitFor = NN_TRY(convertSyncFences(waitFor));
const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
const auto nnDuration = NN_TRY(makeOptionalDuration(durationNs));
auto nnHints = NN_TRY(convertInput(hints));
auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
auto [syncFence, executeFencedInfoCallback] = NN_TRY(preparedModel.executeFenced(
nnRequest, nnWaitFor, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration, nnDuration));
nnRequest, nnWaitFor, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration, nnDuration,
nnHints, nnExtensionNameToPrefix));
ndk::ScopedFileDescriptor fileDescriptor;
if (syncFence.hasFd()) {
@@ -171,11 +179,16 @@ nn::GeneralResult<FencedExecutionResult> executeFenced(
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::IPreparedModel& preparedModel, const Request& request, bool measureTiming,
int64_t loopTimeoutDurationNs) {
int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints,
const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
const auto nnRequest = NN_TRY(convertInput(request));
const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
return preparedModel.createReusableExecution(nnRequest, nnMeasureTiming, nnLoopTimeoutDuration);
auto nnHints = NN_TRY(convertInput(hints));
auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
return preparedModel.createReusableExecution(nnRequest, nnMeasureTiming, nnLoopTimeoutDuration,
nnHints, nnExtensionNameToPrefix);
}
nn::ExecutionResult<ExecutionResult> executeSynchronously(const nn::IExecution& execution,
@@ -231,7 +244,7 @@ ndk::ScopedAStatus PreparedModel::executeSynchronously(const Request& request, b
int64_t loopTimeoutDurationNs,
ExecutionResult* executionResult) {
auto result = adapter::executeSynchronously(*kPreparedModel, request, measureTiming, deadlineNs,
loopTimeoutDurationNs);
loopTimeoutDurationNs, {}, {});
if (!result.has_value()) {
const auto& [message, code, _] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
@@ -247,7 +260,41 @@ ndk::ScopedAStatus PreparedModel::executeFenced(
bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
FencedExecutionResult* executionResult) {
auto result = adapter::executeFenced(*kPreparedModel, request, waitFor, measureTiming,
deadlineNs, loopTimeoutDurationNs, durationNs);
deadlineNs, loopTimeoutDurationNs, durationNs, {}, {});
if (!result.has_value()) {
const auto& [message, code] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
static_cast<int32_t>(aidlCode), message.c_str());
}
*executionResult = std::move(result).value();
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus PreparedModel::executeSynchronouslyWithConfig(const Request& request,
const ExecutionConfig& config,
int64_t deadlineNs,
ExecutionResult* executionResult) {
auto result = adapter::executeSynchronously(
*kPreparedModel, request, config.measureTiming, deadlineNs,
config.loopTimeoutDurationNs, config.executionHints, config.extensionNameToPrefix);
if (!result.has_value()) {
const auto& [message, code, _] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
static_cast<int32_t>(aidlCode), message.c_str());
}
*executionResult = std::move(result).value();
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus PreparedModel::executeFencedWithConfig(
const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
const ExecutionConfig& config, int64_t deadlineNs, int64_t durationNs,
FencedExecutionResult* executionResult) {
auto result = adapter::executeFenced(*kPreparedModel, request, waitFor, config.measureTiming,
deadlineNs, config.loopTimeoutDurationNs, durationNs,
config.executionHints, config.extensionNameToPrefix);
if (!result.has_value()) {
const auto& [message, code] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
@@ -275,11 +322,11 @@ nn::SharedPreparedModel PreparedModel::getUnderlyingPreparedModel() const {
}
ndk::ScopedAStatus PreparedModel::createReusableExecution(const Request& request,
bool measureTiming,
int64_t loopTimeoutDurationNs,
const ExecutionConfig& config,
std::shared_ptr<IExecution>* execution) {
auto result = adapter::createReusableExecution(*kPreparedModel, request, measureTiming,
loopTimeoutDurationNs);
auto result = adapter::createReusableExecution(
*kPreparedModel, request, config.measureTiming, config.loopTimeoutDurationNs,
config.executionHints, config.extensionNameToPrefix);
if (!result.has_value()) {
const auto& [message, code] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);

View File

@@ -250,7 +250,7 @@ nn::ExecutionResult<std::pair<hidl_vec<V1_2::OutputShape>, V1_2::Timing>> Burst:
nn::MeasureTiming canonicalMeasure = NN_TRY(nn::convert(measure));
const auto [outputShapes, timing] =
NN_TRY(mBurstExecutor->execute(canonicalRequest, canonicalMeasure, {}, {}));
NN_TRY(mBurstExecutor->execute(canonicalRequest, canonicalMeasure, {}, {}, {}, {}));
return std::make_pair(NN_TRY(V1_2::utils::convert(outputShapes)),
NN_TRY(V1_2::utils::convert(timing)));

View File

@@ -135,7 +135,7 @@ nn::GeneralResult<void> prepareModel(const nn::SharedDevice& device, const Execu
Task task = [device, nnModel = std::move(nnModel), executor, callback] {
auto result = device->prepareModel(nnModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
notify(callback.get(), std::move(result), executor);
};
executor(std::move(task), {});
@@ -155,8 +155,8 @@ nn::GeneralResult<void> prepareModel_1_1(const nn::SharedDevice& device, const E
const auto nnPreference = NN_TRY(convertInput(preference));
Task task = [device, nnModel = std::move(nnModel), nnPreference, executor, callback] {
auto result =
device->prepareModel(nnModel, nnPreference, nn::Priority::DEFAULT, {}, {}, {}, {});
auto result = device->prepareModel(nnModel, nnPreference, nn::Priority::DEFAULT, {}, {}, {},
{}, {}, {});
notify(callback.get(), std::move(result), executor);
};
executor(std::move(task), {});
@@ -185,7 +185,7 @@ nn::GeneralResult<void> prepareModel_1_2(const nn::SharedDevice& device, const E
nnModelCache = std::move(nnModelCache), nnDataCache = std::move(nnDataCache),
nnToken, executor, callback] {
auto result = device->prepareModel(nnModel, nnPreference, nn::Priority::DEFAULT, {},
nnModelCache, nnDataCache, nnToken);
nnModelCache, nnDataCache, nnToken, {}, {});
notify(callback.get(), std::move(result), executor);
};
executor(std::move(task), {});
@@ -215,7 +215,7 @@ nn::GeneralResult<void> prepareModel_1_3(
nnModelCache = std::move(nnModelCache), nnDataCache = std::move(nnDataCache),
nnToken, executor, callback] {
auto result = device->prepareModel(nnModel, nnPreference, nnPriority, nnDeadline,
nnModelCache, nnDataCache, nnToken);
nnModelCache, nnDataCache, nnToken, {}, {});
notify(callback.get(), std::move(result), executor);
};
executor(std::move(task), nnDeadline);

View File

@@ -159,7 +159,7 @@ nn::GeneralResult<void> execute(const nn::SharedPreparedModel& preparedModel,
}
Task task = [preparedModel, nnRequest = std::move(nnRequest), callback] {
auto result = preparedModel->execute(nnRequest, nn::MeasureTiming::NO, {}, {});
auto result = preparedModel->execute(nnRequest, nn::MeasureTiming::NO, {}, {}, {}, {});
notify(callback.get(), std::move(result));
};
executor(std::move(task), {});
@@ -185,7 +185,7 @@ nn::GeneralResult<void> execute_1_2(const nn::SharedPreparedModel& preparedModel
}
Task task = [preparedModel, nnRequest = std::move(nnRequest), nnMeasure, callback] {
auto result = preparedModel->execute(nnRequest, nnMeasure, {}, {});
auto result = preparedModel->execute(nnRequest, nnMeasure, {}, {}, {}, {});
notify(callback.get(), std::move(result));
};
executor(std::move(task), {});
@@ -216,8 +216,8 @@ nn::GeneralResult<void> execute_1_3(const nn::SharedPreparedModel& preparedModel
Task task = [preparedModel, nnRequest = std::move(nnRequest), nnMeasure, nnDeadline,
nnLoopTimeoutDuration, callback] {
auto result =
preparedModel->execute(nnRequest, nnMeasure, nnDeadline, nnLoopTimeoutDuration);
auto result = preparedModel->execute(nnRequest, nnMeasure, nnDeadline,
nnLoopTimeoutDuration, {}, {});
notify(callback.get(), std::move(result));
};
executor(std::move(task), nnDeadline);
@@ -232,7 +232,7 @@ nn::ExecutionResult<std::pair<hidl_vec<V1_2::OutputShape>, V1_2::Timing>> execut
const auto nnMeasure = NN_TRY(convertInput(measure));
const auto [outputShapes, timing] =
NN_TRY(preparedModel->execute(nnRequest, nnMeasure, {}, {}));
NN_TRY(preparedModel->execute(nnRequest, nnMeasure, {}, {}, {}, {}));
auto hidlOutputShapes = NN_TRY(V1_2::utils::convert(outputShapes));
const auto hidlTiming = NN_TRY(V1_2::utils::convert(timing));
@@ -248,8 +248,8 @@ nn::ExecutionResult<std::pair<hidl_vec<V1_2::OutputShape>, V1_2::Timing>> execut
const auto nnDeadline = NN_TRY(convertInput(deadline));
const auto nnLoopTimeoutDuration = NN_TRY(convertInput(loopTimeoutDuration));
const auto [outputShapes, timing] =
NN_TRY(preparedModel->execute(nnRequest, nnMeasure, nnDeadline, nnLoopTimeoutDuration));
const auto [outputShapes, timing] = NN_TRY(preparedModel->execute(
nnRequest, nnMeasure, nnDeadline, nnLoopTimeoutDuration, {}, {}));
auto hidlOutputShapes = NN_TRY(V1_3::utils::convert(outputShapes));
const auto hidlTiming = NN_TRY(V1_3::utils::convert(timing));
@@ -293,8 +293,9 @@ nn::GeneralResult<std::pair<hidl_handle, sp<V1_3::IFencedExecutionCallback>>> ex
const auto nnLoopTimeoutDuration = NN_TRY(convertInput(loopTimeoutDuration));
const auto nnDuration = NN_TRY(convertInput(duration));
auto [syncFence, executeFencedCallback] = NN_TRY(preparedModel->executeFenced(
nnRequest, nnWaitFor, nnMeasure, nnDeadline, nnLoopTimeoutDuration, nnDuration));
auto [syncFence, executeFencedCallback] =
NN_TRY(preparedModel->executeFenced(nnRequest, nnWaitFor, nnMeasure, nnDeadline,
nnLoopTimeoutDuration, nnDuration, {}, {}));
auto hidlSyncFence = NN_TRY(V1_3::utils::convert(syncFence.getSharedHandle()));
auto hidlExecuteFencedCallback = sp<FencedExecutionCallback>::make(executeFencedCallback);

View File

@@ -33,12 +33,15 @@ class InvalidBurst final : public nn::IBurst {
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
};
} // namespace android::hardware::neuralnetworks::utils

View File

@@ -52,8 +52,9 @@ class InvalidDevice final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache,
const nn::CacheToken& token) const override;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -31,18 +31,23 @@ class InvalidPreparedModel final : public nn::IPreparedModel {
public:
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const override;
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;

View File

@@ -48,18 +48,23 @@ class ResilientBurst final : public nn::IBurst,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
private:
bool isValidInternal() const EXCLUDES(mMutex);
nn::GeneralResult<nn::SharedExecution> createReusableExecutionInternal(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const;
const Factory kMakeBurst;
mutable std::mutex mMutex;

View File

@@ -65,8 +65,9 @@ class ResilientDevice final : public nn::IDevice,
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache,
const nn::CacheToken& token) const override;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
@@ -83,7 +84,9 @@ class ResilientDevice final : public nn::IDevice,
nn::GeneralResult<nn::SharedPreparedModel> prepareModelInternal(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCacheInternal(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const;

View File

@@ -49,18 +49,23 @@ class ResilientPreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const override;
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;
@@ -70,7 +75,9 @@ class ResilientPreparedModel final : public nn::IPreparedModel,
bool isValidInternal() const EXCLUDES(mMutex);
nn::GeneralResult<nn::SharedExecution> createReusableExecutionInternal(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& metaData,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurstInternal() const;
const Factory kMakePreparedModel;

View File

@@ -34,13 +34,17 @@ InvalidBurst::OptionalCacheHold InvalidBurst::cacheMemory(
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> InvalidBurst::execute(
const nn::Request& /*request*/, nn::MeasureTiming /*measure*/,
const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidBurst";
}
nn::GeneralResult<nn::SharedExecution> InvalidBurst::createReusableExecution(
const nn::Request& /*request*/, nn::MeasureTiming /*measure*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidBurst";
}

View File

@@ -84,7 +84,9 @@ nn::GeneralResult<nn::SharedPreparedModel> InvalidDevice::prepareModel(
const nn::Model& /*model*/, nn::ExecutionPreference /*preference*/,
nn::Priority /*priority*/, nn::OptionalTimePoint /*deadline*/,
const std::vector<nn::SharedHandle>& /*modelCache*/,
const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/) const {
const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidDevice";
}

View File

@@ -27,9 +27,12 @@
namespace android::hardware::neuralnetworks::utils {
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
InvalidPreparedModel::execute(const nn::Request& /*request*/, nn::MeasureTiming /*measure*/,
const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
InvalidPreparedModel::execute(
const nn::Request& /*request*/, nn::MeasureTiming /*measure*/,
const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidPreparedModel";
}
@@ -38,13 +41,17 @@ InvalidPreparedModel::executeFenced(
const nn::Request& /*request*/, const std::vector<nn::SyncFence>& /*waitFor*/,
nn::MeasureTiming /*measure*/, const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
const nn::OptionalDuration& /*timeoutDurationAfterFence*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidPreparedModel";
}
nn::GeneralResult<nn::SharedExecution> InvalidPreparedModel::createReusableExecution(
const nn::Request& /*request*/, nn::MeasureTiming /*measure*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidPreparedModel";
}

View File

@@ -105,37 +105,49 @@ ResilientBurst::OptionalCacheHold ResilientBurst::cacheMemory(
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> ResilientBurst::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
const auto fn = [&request, measure, deadline, loopTimeoutDuration](const nn::IBurst& burst) {
return burst.execute(request, measure, deadline, loopTimeoutDuration);
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
const auto fn = [&request, measure, deadline, loopTimeoutDuration, &hints,
&extensionNameToPrefix](const nn::IBurst& burst) {
return burst.execute(request, measure, deadline, loopTimeoutDuration, hints,
extensionNameToPrefix);
};
return protect(*this, fn);
}
nn::GeneralResult<nn::SharedExecution> ResilientBurst::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
#if 0
auto self = shared_from_this();
ResilientExecution::Factory makeExecution =
[burst = std::move(self), request, measure, loopTimeoutDuration] {
return burst->createReusableExecutionInternal(request, measure, loopTimeoutDuration);
ResilientExecution::Factory makeExecution = [burst = std::move(self), request, measure,
loopTimeoutDuration, &hints,
&extensionNameToPrefix] {
return burst->createReusableExecutionInternal(request, measure, loopTimeoutDuration, hints,
extensionNameToPrefix);
};
return ResilientExecution::create(std::move(makeExecution));
#else
return createReusableExecutionInternal(request, measure, loopTimeoutDuration);
return createReusableExecutionInternal(request, measure, loopTimeoutDuration, hints,
extensionNameToPrefix);
#endif
}
nn::GeneralResult<nn::SharedExecution> ResilientBurst::createReusableExecutionInternal(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
if (!isValidInternal()) {
return std::make_shared<const InvalidExecution>();
}
const auto fn = [&request, measure, &loopTimeoutDuration](const nn::IBurst& burst) {
return burst.createReusableExecution(request, measure, loopTimeoutDuration);
const auto fn = [&request, measure, &loopTimeoutDuration, &hints,
&extensionNameToPrefix](const nn::IBurst& burst) {
return burst.createReusableExecution(request, measure, loopTimeoutDuration, hints,
extensionNameToPrefix);
};
return protect(*this, fn);
}

View File

@@ -179,19 +179,21 @@ nn::GeneralResult<std::vector<bool>> ResilientDevice::getSupportedOperations(
nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
#if 0
auto self = shared_from_this();
ResilientPreparedModel::Factory makePreparedModel = [device = std::move(self), model,
preference, priority, deadline, modelCache,
dataCache, token] {
dataCache, token, hints, extensionNameToPrefix] {
return device->prepareModelInternal(model, preference, priority, deadline, modelCache,
dataCache, token);
dataCache, token, hints, extensionNameToPrefix);
};
return ResilientPreparedModel::create(std::move(makePreparedModel));
#else
return prepareModelInternal(model, preference, priority, deadline, modelCache, dataCache,
token);
return prepareModelInternal(model, preference, priority, deadline, modelCache, dataCache, token,
hints, extensionNameToPrefix);
#endif
}
@@ -234,14 +236,16 @@ bool ResilientDevice::isValidInternal() const {
nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModelInternal(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
if (!isValidInternal()) {
return std::make_shared<const InvalidPreparedModel>();
}
const auto fn = [&model, preference, priority, &deadline, &modelCache, &dataCache,
&token](const nn::IDevice& device) {
const auto fn = [&model, preference, priority, &deadline, &modelCache, &dataCache, &token,
&hints, &extensionNameToPrefix](const nn::IDevice& device) {
return device.prepareModel(model, preference, priority, deadline, modelCache, dataCache,
token);
token, hints, extensionNameToPrefix);
};
return protect(*this, fn, /*blocking=*/false);
}

View File

@@ -104,43 +104,53 @@ nn::GeneralResult<nn::SharedPreparedModel> ResilientPreparedModel::recover(
}
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
ResilientPreparedModel::execute(const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
const auto fn = [&request, measure, &deadline,
&loopTimeoutDuration](const nn::IPreparedModel& preparedModel) {
return preparedModel.execute(request, measure, deadline, loopTimeoutDuration);
ResilientPreparedModel::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
const auto fn = [&request, measure, &deadline, &loopTimeoutDuration, &hints,
&extensionNameToPrefix](const nn::IPreparedModel& preparedModel) {
return preparedModel.execute(request, measure, deadline, loopTimeoutDuration, hints,
extensionNameToPrefix);
};
return protect(*this, fn);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
ResilientPreparedModel::executeFenced(const nn::Request& request,
const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const {
ResilientPreparedModel::executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
const auto fn = [&request, &waitFor, measure, &deadline, &loopTimeoutDuration,
&timeoutDurationAfterFence](const nn::IPreparedModel& preparedModel) {
&timeoutDurationAfterFence, &hints,
&extensionNameToPrefix](const nn::IPreparedModel& preparedModel) {
return preparedModel.executeFenced(request, waitFor, measure, deadline, loopTimeoutDuration,
timeoutDurationAfterFence);
timeoutDurationAfterFence, hints, extensionNameToPrefix);
};
return protect(*this, fn);
}
nn::GeneralResult<nn::SharedExecution> ResilientPreparedModel::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
#if 0
auto self = shared_from_this();
ResilientExecution::Factory makeExecution =
[preparedModel = std::move(self), request, measure, loopTimeoutDuration] {
return preparedModel->createReusableExecutionInternal(request, measure, loopTimeoutDuration);
ResilientExecution::Factory makeExecution = [preparedModel = std::move(self), request, measure,
loopTimeoutDuration, hints,
extensionNameToPrefix] {
return preparedModel->createReusableExecutionInternal(request, measure, loopTimeoutDuration,
hints, extensionNameToPrefix);
};
return ResilientExecution::create(std::move(makeExecution));
#else
return createReusableExecutionInternal(request, measure, loopTimeoutDuration);
return createReusableExecutionInternal(request, measure, loopTimeoutDuration, hints,
extensionNameToPrefix);
#endif
}
@@ -159,13 +169,16 @@ nn::GeneralResult<nn::SharedBurst> ResilientPreparedModel::configureExecutionBur
nn::GeneralResult<nn::SharedExecution> ResilientPreparedModel::createReusableExecutionInternal(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
if (!isValidInternal()) {
return std::make_shared<const InvalidExecution>();
}
const auto fn = [&request, measure,
&loopTimeoutDuration](const nn::IPreparedModel& preparedModel) {
return preparedModel.createReusableExecution(request, measure, loopTimeoutDuration);
const auto fn = [&request, measure, &loopTimeoutDuration, &hints,
&extensionNameToPrefix](const nn::IPreparedModel& preparedModel) {
return preparedModel.createReusableExecution(request, measure, loopTimeoutDuration, hints,
extensionNameToPrefix);
};
return protect(*this, fn);
}

View File

@@ -39,7 +39,9 @@ class MockDevice final : public IDevice {
MOCK_METHOD(GeneralResult<SharedPreparedModel>, prepareModel,
(const Model& model, ExecutionPreference preference, Priority priority,
OptionalTimePoint deadline, const std::vector<SharedHandle>& modelCache,
const std::vector<SharedHandle>& dataCache, const CacheToken& token),
const std::vector<SharedHandle>& dataCache, const CacheToken& token,
const std::vector<TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix),
(const, override));
MOCK_METHOD(GeneralResult<SharedPreparedModel>, prepareModelFromCache,
(OptionalTimePoint deadline, const std::vector<SharedHandle>& modelCache,

View File

@@ -27,17 +27,23 @@ class MockPreparedModel final : public IPreparedModel {
public:
MOCK_METHOD((ExecutionResult<std::pair<std::vector<OutputShape>, Timing>>), execute,
(const Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
const OptionalDuration& loopTimeoutDuration),
const OptionalDuration& loopTimeoutDuration,
const std::vector<TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix),
(const, override));
MOCK_METHOD((GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>>), executeFenced,
(const Request& request, const std::vector<SyncFence>& waitFor,
MeasureTiming measure, const OptionalTimePoint& deadline,
const OptionalDuration& loopTimeoutDuration,
const OptionalDuration& timeoutDurationAfterFence),
const OptionalDuration& timeoutDurationAfterFence,
const std::vector<TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix),
(const, override));
MOCK_METHOD((GeneralResult<SharedExecution>), createReusableExecution,
(const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration),
(const Request& request, MeasureTiming measure,
const OptionalDuration& loopTimeoutDuration,
const std::vector<TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix),
(const, override));
MOCK_METHOD(GeneralResult<SharedBurst>, configureExecutionBurst, (), (const, override));
MOCK_METHOD(std::any, getUnderlyingResource, (), (const, override));

View File

@@ -309,12 +309,12 @@ TEST(ResilientDeviceTest, prepareModel) {
// setup call
const auto [mockDevice, mockDeviceFactory, device] = setup();
const auto mockPreparedModel = std::make_shared<const nn::MockPreparedModel>();
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _))
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(Return(mockPreparedModel));
// run test
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {});
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -324,12 +324,12 @@ TEST(ResilientDeviceTest, prepareModel) {
TEST(ResilientDeviceTest, prepareModelError) {
// setup call
const auto [mockDevice, mockDeviceFactory, device] = setup();
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _))
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnGeneralFailure);
// run test
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {});
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -339,13 +339,13 @@ TEST(ResilientDeviceTest, prepareModelError) {
TEST(ResilientDeviceTest, prepareModelDeadObjectFailedRecovery) {
// setup call
const auto [mockDevice, mockDeviceFactory, device] = setup();
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _))
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnDeadObject);
EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);
// run test
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {});
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -355,18 +355,18 @@ TEST(ResilientDeviceTest, prepareModelDeadObjectFailedRecovery) {
TEST(ResilientDeviceTest, prepareModelDeadObjectSuccessfulRecovery) {
// setup call
const auto [mockDevice, mockDeviceFactory, device] = setup();
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _))
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnDeadObject);
const auto recoveredMockDevice = createConfiguredMockDevice();
const auto mockPreparedModel = std::make_shared<const nn::MockPreparedModel>();
EXPECT_CALL(*recoveredMockDevice, prepareModel(_, _, _, _, _, _, _))
EXPECT_CALL(*recoveredMockDevice, prepareModel(_, _, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(Return(mockPreparedModel));
EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
// run test
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {});
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -679,7 +679,7 @@ TEST(ResilientDeviceTest, recoverCacheMismatchInvalidPrepareModel) {
device->recover(mockDevice.get(), /*blocking=*/false);
// run test
auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {});
auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())

View File

@@ -104,12 +104,12 @@ TEST(ResilientPreparedModelTest, getPreparedModel) {
TEST(ResilientPreparedModelTest, execute) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _))
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _, _, _))
.Times(1)
.WillOnce(Return(kNoExecutionError));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -119,10 +119,12 @@ TEST(ResilientPreparedModelTest, execute) {
TEST(ResilientPreparedModelTest, executeError) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _)).Times(1).WillOnce(kReturnGeneralFailure);
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnGeneralFailure);
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -132,12 +134,12 @@ TEST(ResilientPreparedModelTest, executeError) {
TEST(ResilientPreparedModelTest, executeDeadObjectFailedRecovery) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
constexpr auto ret = [] { return nn::error(nn::ErrorStatus::GENERAL_FAILURE); };
EXPECT_CALL(*mockPreparedModelFactory, Call()).Times(1).WillOnce(ret);
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -147,9 +149,9 @@ TEST(ResilientPreparedModelTest, executeDeadObjectFailedRecovery) {
TEST(ResilientPreparedModelTest, executeDeadObjectSuccessfulRecovery) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
const auto recoveredMockPreparedModel = createConfiguredMockPreparedModel();
EXPECT_CALL(*recoveredMockPreparedModel, execute(_, _, _, _))
EXPECT_CALL(*recoveredMockPreparedModel, execute(_, _, _, _, _, _))
.Times(1)
.WillOnce(Return(kNoExecutionError));
EXPECT_CALL(*mockPreparedModelFactory, Call())
@@ -157,7 +159,7 @@ TEST(ResilientPreparedModelTest, executeDeadObjectSuccessfulRecovery) {
.WillOnce(Return(recoveredMockPreparedModel));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -167,12 +169,12 @@ TEST(ResilientPreparedModelTest, executeDeadObjectSuccessfulRecovery) {
TEST(ResilientPreparedModelTest, executeFenced) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _))
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(Return(kNoFencedExecutionError));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -182,12 +184,12 @@ TEST(ResilientPreparedModelTest, executeFenced) {
TEST(ResilientPreparedModelTest, executeFencedError) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _))
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnGeneralFailure);
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -197,13 +199,13 @@ TEST(ResilientPreparedModelTest, executeFencedError) {
TEST(ResilientPreparedModelTest, executeFencedDeadObjectFailedRecovery) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _))
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnDeadObject);
EXPECT_CALL(*mockPreparedModelFactory, Call()).Times(1).WillOnce(kReturnGeneralFailure);
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -213,11 +215,11 @@ TEST(ResilientPreparedModelTest, executeFencedDeadObjectFailedRecovery) {
TEST(ResilientPreparedModelTest, executeFencedDeadObjectSuccessfulRecovery) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _))
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnDeadObject);
const auto recoveredMockPreparedModel = createConfiguredMockPreparedModel();
EXPECT_CALL(*recoveredMockPreparedModel, executeFenced(_, _, _, _, _, _))
EXPECT_CALL(*recoveredMockPreparedModel, executeFenced(_, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(Return(kNoFencedExecutionError));
EXPECT_CALL(*mockPreparedModelFactory, Call())
@@ -225,7 +227,7 @@ TEST(ResilientPreparedModelTest, executeFencedDeadObjectSuccessfulRecovery) {
.WillOnce(Return(recoveredMockPreparedModel));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -235,12 +237,12 @@ TEST(ResilientPreparedModelTest, executeFencedDeadObjectSuccessfulRecovery) {
TEST(ResilientPreparedModelTest, createReusableExecution) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _, _))
.Times(1)
.WillOnce(Return(kNoCreateReusableExecutionError));
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -250,12 +252,12 @@ TEST(ResilientPreparedModelTest, createReusableExecution) {
TEST(ResilientPreparedModelTest, createReusableExecutionError) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _, _))
.Times(1)
.WillOnce(kReturnGeneralFailure);
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());