Merge "HAL interface for compilation and execution hints"

Miao Wang
2022-01-20 15:43:58 +00:00
committed by Android (Google) Code Review
90 changed files with 1913 additions and 473 deletions
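The change threads the same two trailing parameters — a vector of token/value execution hints and a vector of extension name-to-prefix mappings — through IDevice::prepareModel, IPreparedModel::execute/executeFenced/createReusableExecution, and IBurst::execute/createReusableExecution across the 1.0–1.3 HAL utility wrappers. A minimal sketch of a post-change call site (not taken from this commit; the TokenValuePair field layout and the kVendorToken/kVendorPrefix constants are assumptions for illustration):

    // Hypothetical caller passing one vendor hint and its extension prefix.
    std::vector<nn::TokenValuePair> hints = {{/*token=*/kVendorToken, /*value=*/{1}}};
    std::vector<nn::ExtensionNameAndPrefix> extensionNameToPrefix = {
            {.name = "com.example.vendor_ext", .prefix = kVendorPrefix}};
    const auto result = preparedModel->execute(request, nn::MeasureTiming::NO,
                                               /*deadline=*/{}, /*loopTimeoutDuration=*/{},
                                               hints, extensionNameToPrefix);

This is consistent with the HIDL wrappers below, which accept the new arguments but ignore them.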

View File

@@ -45,12 +45,15 @@ class Burst final : public nn::IBurst {
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalTimePoint& deadline,
-const nn::OptionalDuration& loopTimeoutDuration) const override;
+const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalDuration& loopTimeoutDuration) const override;
+const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
private:
const nn::SharedPreparedModel kPreparedModel;

View File

@@ -65,8 +65,9 @@ class Device final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
-const std::vector<nn::SharedHandle>& dataCache,
-const nn::CacheToken& token) const override;
+const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -49,18 +49,23 @@ class PreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalTimePoint& deadline,
-const nn::OptionalDuration& loopTimeoutDuration) const override;
+const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
-const nn::OptionalDuration& timeoutDurationAfterFence) const override;
+const nn::OptionalDuration& timeoutDurationAfterFence,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalDuration& loopTimeoutDuration) const override;
+const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;

View File

@@ -50,15 +50,20 @@ Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& /*memory*/)
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalTimePoint& deadline,
-const nn::OptionalDuration& loopTimeoutDuration) const {
-return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration);
+const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
+return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration, hints,
+extensionNameToPrefix);
}
nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalDuration& loopTimeoutDuration) const {
-return kPreparedModel->createReusableExecution(request, measure, loopTimeoutDuration);
+const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
+return kPreparedModel->createReusableExecution(request, measure, loopTimeoutDuration, hints,
+extensionNameToPrefix);
}
} // namespace android::hardware::neuralnetworks::V1_0::utils
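Note that the 1.0 burst adapter is a pure pass-through: it forwards hints and extensionNameToPrefix verbatim to the underlying prepared model rather than interpreting them. Callers with no hints get the pre-change behavior by passing empty vectors, e.g. (illustrative only):

    burst->execute(request, measure, deadline, loopTimeoutDuration,
                   /*hints=*/{}, /*extensionNameToPrefix=*/{});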

View File

@@ -143,7 +143,9 @@ nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Mo
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const nn::Model& model, nn::ExecutionPreference /*preference*/, nn::Priority /*priority*/,
nn::OptionalTimePoint /*deadline*/, const std::vector<nn::SharedHandle>& /*modelCache*/,
-const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/) const {
+const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that model is ready for IPC.
std::optional<nn::Model> maybeModelInShared;
const nn::Model& modelInShared =

View File

@@ -59,7 +59,9 @@ PreparedModel::PreparedModel(PrivateConstructorTag /*tag*/, sp<V1_0::IPreparedMo
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> PreparedModel::execute(
const nn::Request& request, nn::MeasureTiming /*measure*/,
const nn::OptionalTimePoint& /*deadline*/,
-const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
+const nn::OptionalDuration& /*loopTimeoutDuration*/,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -94,19 +96,22 @@ PreparedModel::executeInternal(const V1_0::Request& request,
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
-PreparedModel::executeFenced(const nn::Request& /*request*/,
-const std::vector<nn::SyncFence>& /*waitFor*/,
-nn::MeasureTiming /*measure*/,
-const nn::OptionalTimePoint& /*deadline*/,
-const nn::OptionalDuration& /*loopTimeoutDuration*/,
-const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
+PreparedModel::executeFenced(
+const nn::Request& /*request*/, const std::vector<nn::SyncFence>& /*waitFor*/,
+nn::MeasureTiming /*measure*/, const nn::OptionalTimePoint& /*deadline*/,
+const nn::OptionalDuration& /*loopTimeoutDuration*/,
+const nn::OptionalDuration& /*timeoutDurationAfterFence*/,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
<< "IPreparedModel::executeFenced is not supported on 1.0 HAL service";
}
nn::GeneralResult<nn::SharedExecution> PreparedModel::createReusableExecution(
const nn::Request& request, nn::MeasureTiming /*measure*/,
-const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
+const nn::OptionalDuration& /*loopTimeoutDuration*/,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;

View File

@@ -380,7 +380,7 @@ TEST(DeviceTest, prepareModel) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -399,7 +399,7 @@ TEST(DeviceTest, prepareModelLaunchError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -417,7 +417,7 @@ TEST(DeviceTest, prepareModelReturnError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -435,7 +435,7 @@ TEST(DeviceTest, prepareModelNullptrError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -452,7 +452,7 @@ TEST(DeviceTest, prepareModelTransportFailure) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -469,7 +469,7 @@ TEST(DeviceTest, prepareModelDeadObject) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -488,7 +488,7 @@ TEST(DeviceTest, prepareModelAsyncCrash) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -121,7 +121,7 @@ TEST(PreparedModelTest, execute) {
.WillOnce(Invoke(makeExecute(V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::NONE)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -138,7 +138,7 @@ TEST(PreparedModelTest, executeLaunchError) {
V1_0::ErrorStatus::GENERAL_FAILURE)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -155,7 +155,7 @@ TEST(PreparedModelTest, executeReturnError) {
makeExecute(V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::GENERAL_FAILURE)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -171,7 +171,7 @@ TEST(PreparedModelTest, executeTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -187,7 +187,7 @@ TEST(PreparedModelTest, executeDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -205,7 +205,7 @@ TEST(PreparedModelTest, executeCrash) {
EXPECT_CALL(*mockPreparedModel, execute(_, _)).Times(1).WillOnce(InvokeWithoutArgs(ret));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -218,7 +218,7 @@ TEST(PreparedModelTest, executeFencedNotSupported) {
const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
// run test
-const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -235,7 +235,7 @@ TEST(PreparedModelTest, reusableExecute) {
.WillRepeatedly(Invoke(makeExecute(V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::NONE)));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -258,7 +258,7 @@ TEST(PreparedModelTest, reusableExecuteLaunchError) {
V1_0::ErrorStatus::GENERAL_FAILURE)));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -279,7 +279,7 @@ TEST(PreparedModelTest, reusableExecuteReturnError) {
makeExecute(V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::GENERAL_FAILURE)));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -299,7 +299,7 @@ TEST(PreparedModelTest, reusableExecuteTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -319,7 +319,7 @@ TEST(PreparedModelTest, reusableExecuteDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -341,7 +341,7 @@ TEST(PreparedModelTest, reusableExecuteCrash) {
EXPECT_CALL(*mockPreparedModel, execute(_, _)).Times(1).WillOnce(InvokeWithoutArgs(ret));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -358,7 +358,7 @@ TEST(PreparedModelTest, reusableExecuteFencedNotSupported) {
const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
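In the updated tests above, each added {} value-initializes one of the new trailing vector parameters, so every call supplies empty hints and an empty extensionNameToPrefix map; for example, preparedModel->execute({}, {}, {}, {}, {}, {}) is the old four-argument call plus two empty vectors.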

View File

@@ -64,8 +64,9 @@ class Device final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
-const std::vector<nn::SharedHandle>& dataCache,
-const nn::CacheToken& token) const override;
+const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -143,7 +143,9 @@ nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Mo
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority /*priority*/,
nn::OptionalTimePoint /*deadline*/, const std::vector<nn::SharedHandle>& /*modelCache*/,
-const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/) const {
+const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that model is ready for IPC.
std::optional<nn::Model> maybeModelInShared;
const nn::Model& modelInShared =

View File

@@ -390,7 +390,7 @@ TEST(DeviceTest, prepareModel) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -409,7 +409,7 @@ TEST(DeviceTest, prepareModelLaunchError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -427,7 +427,7 @@ TEST(DeviceTest, prepareModelReturnError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -445,7 +445,7 @@ TEST(DeviceTest, prepareModelNullptrError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -462,7 +462,7 @@ TEST(DeviceTest, prepareModelTransportFailure) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -479,7 +479,7 @@ TEST(DeviceTest, prepareModelDeadObject) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -498,7 +498,7 @@ TEST(DeviceTest, prepareModelAsyncCrash) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -170,13 +170,16 @@ class Burst final : public nn::IBurst, public std::enable_shared_from_this<Burst
// See IBurst::execute for information on this method.
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalTimePoint& deadline,
-const nn::OptionalDuration& loopTimeoutDuration) const override;
+const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
// See IBurst::createReusableExecution for information on this method.
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalDuration& loopTimeoutDuration) const override;
+const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
// If fallback is not nullptr, this method will invoke the fallback function to try another
// execution path if the packet could not be sent. Otherwise, failing to send the packet will

View File

@@ -37,7 +37,7 @@ GeneralResult<Operand> unvalidatedConvert(const hal::V1_2::Operand& operand);
GeneralResult<Operand::ExtraParams> unvalidatedConvert(
const hal::V1_2::Operand::ExtraParams& extraParams);
GeneralResult<Model> unvalidatedConvert(const hal::V1_2::Model& model);
-GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
+GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const hal::V1_2::Model::ExtensionNameAndPrefix& extensionNameAndPrefix);
GeneralResult<OutputShape> unvalidatedConvert(const hal::V1_2::OutputShape& outputShape);
GeneralResult<MeasureTiming> unvalidatedConvert(const hal::V1_2::MeasureTiming& measureTiming);
@@ -78,7 +78,7 @@ nn::GeneralResult<Operand::ExtraParams> unvalidatedConvert(
const nn::Operand::ExtraParams& extraParams);
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model);
nn::GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
-const nn::Model::ExtensionNameAndPrefix& extensionNameAndPrefix);
+const nn::ExtensionNameAndPrefix& extensionNameAndPrefix);
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape);
nn::GeneralResult<MeasureTiming> unvalidatedConvert(const nn::MeasureTiming& measureTiming);
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing);

View File

@@ -83,8 +83,9 @@ class Device final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
-const std::vector<nn::SharedHandle>& dataCache,
-const nn::CacheToken& token) const override;
+const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -49,18 +49,23 @@ class PreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalTimePoint& deadline,
-const nn::OptionalDuration& loopTimeoutDuration) const override;
+const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
-const nn::OptionalDuration& timeoutDurationAfterFence) const override;
+const nn::OptionalDuration& timeoutDurationAfterFence,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalDuration& loopTimeoutDuration) const override;
+const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;

View File

@@ -305,8 +305,9 @@ Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& memory) cons
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalTimePoint& deadline,
-const nn::OptionalDuration& loopTimeoutDuration) const {
+const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// This is the first point when we know an execution is occurring, so begin to collect
// systraces. Note that the first point we can begin collecting systraces in
// ExecutionBurstServer is when the RequestChannelReceiver realizes there is data in the FMQ, so
@@ -317,7 +318,7 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
// fall back to another execution path
if (!compliantVersion(request).ok()) {
// fallback to another execution path if the packet could not be sent
-return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration);
+return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration, {}, {});
}
// ensure that request is ready for IPC
@@ -346,7 +347,7 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
// send request packet
const auto requestPacket = serialize(hidlRequest, hidlMeasure, slots);
const auto fallback = [this, &request, measure, &deadline, &loopTimeoutDuration] {
-return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration);
+return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration, {}, {});
};
return executeInternal(requestPacket, relocation, fallback);
}
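Unlike the plain prepared-model path, this burst path ignores the new arguments (note the /*hints*/ parameters above): the serialized FMQ request packet has no field to carry them, and the fallback lambda re-enters kPreparedModel->execute with empty vectors instead.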
@@ -354,14 +355,17 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
// See IBurst::createReusableExecution for information on this method.
nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalDuration& loopTimeoutDuration) const {
+const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "Burst::createReusableExecution");
// if the request is valid but of a higher version than what's supported in burst execution,
// fall back to another execution path
if (!compliantVersion(request).ok()) {
// fallback to another execution path if the packet could not be sent
-return kPreparedModel->createReusableExecution(request, measure, loopTimeoutDuration);
+return kPreparedModel->createReusableExecution(request, measure, loopTimeoutDuration, {},
+{});
}
// ensure that request is ready for IPC

View File

@@ -212,9 +212,9 @@ GeneralResult<Model> unvalidatedConvert(const hal::V1_2::Model& model) {
};
}
-GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
+GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const hal::V1_2::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
-return Model::ExtensionNameAndPrefix{
+return ExtensionNameAndPrefix{
.name = extensionNameAndPrefix.name,
.prefix = extensionNameAndPrefix.prefix,
};
@@ -495,7 +495,7 @@ nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
}
nn::GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
-const nn::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
+const nn::ExtensionNameAndPrefix& extensionNameAndPrefix) {
return Model::ExtensionNameAndPrefix{
.name = extensionNameAndPrefix.name,
.prefix = extensionNameAndPrefix.prefix,
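These conversion changes reflect a type move in the canonical NNAPI types: ExtensionNameAndPrefix is hoisted out of Model (nn::Model::ExtensionNameAndPrefix becomes the standalone nn::ExtensionNameAndPrefix), so the same struct can describe extension prefixes both inside a model and in the new per-call extensionNameToPrefix parameter; the HAL-side V1_2::Model::ExtensionNameAndPrefix is unchanged. A trivial sketch of the standalone type in use (values are illustrative):

    const nn::ExtensionNameAndPrefix ext{.name = "com.example.vendor_ext", .prefix = 0x1234};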

View File

@@ -236,7 +236,9 @@ nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Mo
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority /*priority*/,
nn::OptionalTimePoint /*deadline*/, const std::vector<nn::SharedHandle>& modelCache,
-const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
+const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that model is ready for IPC.
std::optional<nn::Model> maybeModelInShared;
const nn::Model& modelInShared =

View File

@@ -91,7 +91,9 @@ PreparedModel::executeAsynchronously(const V1_0::Request& request, MeasureTiming
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> PreparedModel::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& /*deadline*/,
-const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
+const nn::OptionalDuration& /*loopTimeoutDuration*/,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -123,19 +125,22 @@ PreparedModel::executeInternal(const V1_0::Request& request, MeasureTiming measu
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
-PreparedModel::executeFenced(const nn::Request& /*request*/,
-const std::vector<nn::SyncFence>& /*waitFor*/,
-nn::MeasureTiming /*measure*/,
-const nn::OptionalTimePoint& /*deadline*/,
-const nn::OptionalDuration& /*loopTimeoutDuration*/,
-const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
+PreparedModel::executeFenced(
+const nn::Request& /*request*/, const std::vector<nn::SyncFence>& /*waitFor*/,
+nn::MeasureTiming /*measure*/, const nn::OptionalTimePoint& /*deadline*/,
+const nn::OptionalDuration& /*loopTimeoutDuration*/,
+const nn::OptionalDuration& /*timeoutDurationAfterFence*/,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
<< "IPreparedModel::executeFenced is not supported on 1.2 HAL service";
}
nn::GeneralResult<nn::SharedExecution> PreparedModel::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
+const nn::OptionalDuration& /*loopTimeoutDuration*/,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;

View File

@@ -636,7 +636,7 @@ TEST(DeviceTest, prepareModel) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -655,7 +655,7 @@ TEST(DeviceTest, prepareModelLaunchError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -673,7 +673,7 @@ TEST(DeviceTest, prepareModelReturnError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -691,7 +691,7 @@ TEST(DeviceTest, prepareModelNullptrError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -708,7 +708,7 @@ TEST(DeviceTest, prepareModelTransportFailure) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -725,7 +725,7 @@ TEST(DeviceTest, prepareModelDeadObject) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -746,7 +746,7 @@ TEST(DeviceTest, prepareModelAsyncCrash) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -154,7 +154,7 @@ TEST(PreparedModelTest, executeSync) {
.WillOnce(Invoke(makeExecuteSynchronously(V1_0::ErrorStatus::NONE, {}, kNoTiming)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -172,7 +172,7 @@ TEST(PreparedModelTest, executeSyncError) {
makeExecuteSynchronously(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -189,7 +189,7 @@ TEST(PreparedModelTest, executeSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -206,7 +206,7 @@ TEST(PreparedModelTest, executeSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -224,7 +224,7 @@ TEST(PreparedModelTest, executeAsync) {
V1_0::ErrorStatus::NONE, {}, kNoTiming)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -243,7 +243,7 @@ TEST(PreparedModelTest, executeAsyncLaunchError) {
kNoTiming)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -261,7 +261,7 @@ TEST(PreparedModelTest, executeAsyncReturnError) {
V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -278,7 +278,7 @@ TEST(PreparedModelTest, executeAsyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -295,7 +295,7 @@ TEST(PreparedModelTest, executeAsyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -314,7 +314,7 @@ TEST(PreparedModelTest, executeAsyncCrash) {
EXPECT_CALL(*mockPreparedModel, execute_1_2(_, _, _)).Times(1).WillOnce(InvokeWithoutArgs(ret));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -328,7 +328,7 @@ TEST(PreparedModelTest, executeFencedNotSupported) {
PreparedModel::create(mockPreparedModel, /*executeSynchronously=*/true).value();
// run test
-const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -347,7 +347,7 @@ TEST(PreparedModelTest, reusableExecuteSync) {
Invoke(makeExecuteSynchronously(V1_0::ErrorStatus::NONE, {}, kNoTiming)));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -371,7 +371,7 @@ TEST(PreparedModelTest, reusableExecuteSyncError) {
makeExecuteSynchronously(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -392,7 +392,7 @@ TEST(PreparedModelTest, reusableExecuteSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -413,7 +413,7 @@ TEST(PreparedModelTest, reusableExecuteSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -436,7 +436,7 @@ TEST(PreparedModelTest, reusableExecuteAsync) {
V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::NONE, {}, kNoTiming)));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -461,7 +461,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncLaunchError) {
kNoTiming)));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -483,7 +483,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncReturnError) {
V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -504,7 +504,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -525,7 +525,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -548,7 +548,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncCrash) {
EXPECT_CALL(*mockPreparedModel, execute_1_2(_, _, _)).Times(1).WillOnce(InvokeWithoutArgs(ret));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -566,7 +566,7 @@ TEST(PreparedModelTest, reusableExecuteFencedNotSupported) {
PreparedModel::create(mockPreparedModel, /*executeSynchronously=*/true).value();
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);

View File

@@ -66,8 +66,9 @@ class Device final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
-const std::vector<nn::SharedHandle>& dataCache,
-const nn::CacheToken& token) const override;
+const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -48,18 +48,23 @@ class PreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalTimePoint& deadline,
-const nn::OptionalDuration& loopTimeoutDuration) const override;
+const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
-const nn::OptionalDuration& timeoutDurationAfterFence) const override;
+const nn::OptionalDuration& timeoutDurationAfterFence,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalDuration& loopTimeoutDuration) const override;
+const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& hints,
+const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;

View File

@@ -396,7 +396,7 @@ nn::GeneralResult<V1_2::Operand::ExtraParams> unvalidatedConvert(
}
nn::GeneralResult<V1_2::Model::ExtensionNameAndPrefix> unvalidatedConvert(
-const nn::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
+const nn::ExtensionNameAndPrefix& extensionNameAndPrefix) {
return V1_2::utils::unvalidatedConvert(extensionNameAndPrefix);
}

View File

@@ -187,7 +187,9 @@ nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Mo
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
-const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
+const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that model is ready for IPC.
std::optional<nn::Model> maybeModelInShared;
const nn::Model& modelInShared =

View File

@@ -135,8 +135,9 @@ PreparedModel::executeAsynchronously(const Request& request, V1_2::MeasureTiming
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> PreparedModel::execute(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalTimePoint& deadline,
-const nn::OptionalDuration& loopTimeoutDuration) const {
+const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -174,10 +175,13 @@ PreparedModel::executeInternal(const Request& request, V1_2::MeasureTiming measu
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
-PreparedModel::executeFenced(const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
-nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
-const nn::OptionalDuration& loopTimeoutDuration,
-const nn::OptionalDuration& timeoutDurationAfterFence) const {
+PreparedModel::executeFenced(
+const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
+nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
+const nn::OptionalDuration& loopTimeoutDuration,
+const nn::OptionalDuration& timeoutDurationAfterFence,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -230,7 +234,9 @@ PreparedModel::executeFencedInternal(const Request& request, const hidl_vec<hidl
nn::GeneralResult<nn::SharedExecution> PreparedModel::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
-const nn::OptionalDuration& loopTimeoutDuration) const {
+const nn::OptionalDuration& loopTimeoutDuration,
+const std::vector<nn::TokenValuePair>& /*hints*/,
+const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
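As with the earlier versions, the 1.3 wrapper accepts the new parameters but leaves them unused (/*hints*/, /*extensionNameToPrefix*/); the HIDL 1.3 methods it forwards to were defined before hints existed, so only the canonical-interface signature changes here.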

View File

@@ -658,7 +658,7 @@ TEST(DeviceTest, prepareModel) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -677,7 +677,7 @@ TEST(DeviceTest, prepareModelLaunchError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -695,7 +695,7 @@ TEST(DeviceTest, prepareModelReturnError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -713,7 +713,7 @@ TEST(DeviceTest, prepareModelNullptrError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -730,7 +730,7 @@ TEST(DeviceTest, prepareModelTransportFailure) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -747,7 +747,7 @@ TEST(DeviceTest, prepareModelDeadObject) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -768,7 +768,7 @@ TEST(DeviceTest, prepareModelAsyncCrash) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
-nn::Priority::DEFAULT, {}, {}, {}, {});
+nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -182,7 +182,7 @@ TEST(PreparedModelTest, executeSync) {
.WillOnce(Invoke(makeExecuteSynchronously(V1_3::ErrorStatus::NONE, {}, kNoTiming)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -200,7 +200,7 @@ TEST(PreparedModelTest, executeSyncError) {
makeExecuteSynchronously(V1_3::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -217,7 +217,7 @@ TEST(PreparedModelTest, executeSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -234,7 +234,7 @@ TEST(PreparedModelTest, executeSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -252,7 +252,7 @@ TEST(PreparedModelTest, executeAsync) {
V1_3::ErrorStatus::NONE, {}, kNoTiming)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -271,7 +271,7 @@ TEST(PreparedModelTest, executeAsyncLaunchError) {
kNoTiming)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -289,7 +289,7 @@ TEST(PreparedModelTest, executeAsyncReturnError) {
V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -306,7 +306,7 @@ TEST(PreparedModelTest, executeAsyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -323,7 +323,7 @@ TEST(PreparedModelTest, executeAsyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -344,7 +344,7 @@ TEST(PreparedModelTest, executeAsyncCrash) {
.WillOnce(InvokeWithoutArgs(ret));
// run test
-const auto result = preparedModel->execute({}, {}, {}, {});
+const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -366,7 +366,7 @@ TEST(PreparedModelTest, executeFenced) {
.WillOnce(Invoke(makeExecuteFencedReturn(V1_3::ErrorStatus::NONE, {}, mockCallback)));
// run test
-const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -396,7 +396,7 @@ TEST(PreparedModelTest, executeFencedCallbackError) {
.WillOnce(Invoke(makeExecuteFencedReturn(V1_3::ErrorStatus::NONE, {}, mockCallback)));
// run test
-const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -422,7 +422,7 @@ TEST(PreparedModelTest, executeFencedError) {
makeExecuteFencedReturn(V1_3::ErrorStatus::GENERAL_FAILURE, {}, nullptr)));
// run test
-const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -439,7 +439,7 @@ TEST(PreparedModelTest, executeFencedTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
-const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -456,7 +456,7 @@ TEST(PreparedModelTest, executeFencedDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
-const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -475,7 +475,7 @@ TEST(PreparedModelTest, reusableExecuteSync) {
Invoke(makeExecuteSynchronously(V1_3::ErrorStatus::NONE, {}, kNoTiming)));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -499,7 +499,7 @@ TEST(PreparedModelTest, reusableExecuteSyncError) {
makeExecuteSynchronously(V1_3::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -520,7 +520,7 @@ TEST(PreparedModelTest, reusableExecuteSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -541,7 +541,7 @@ TEST(PreparedModelTest, reusableExecuteSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
-const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -564,7 +564,7 @@ TEST(PreparedModelTest, reusableExecuteAsync) {
V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::NONE, {}, kNoTiming)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -589,7 +589,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncLaunchError) {
kNoTiming)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -611,7 +611,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncReturnError) {
V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -628,7 +628,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -649,7 +649,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -674,7 +674,7 @@ TEST(PreparedModelTest, reusableExecuteAsyncCrash) {
.WillOnce(InvokeWithoutArgs(ret));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -702,7 +702,7 @@ TEST(PreparedModelTest, reusableExecuteFenced) {
Invoke(makeExecuteFencedReturn(V1_3::ErrorStatus::NONE, {}, mockCallback)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -738,7 +738,7 @@ TEST(PreparedModelTest, reusableExecuteFencedCallbackError) {
.WillOnce(Invoke(makeExecuteFencedReturn(V1_3::ErrorStatus::NONE, {}, mockCallback)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -768,7 +768,7 @@ TEST(PreparedModelTest, reusableExecuteFencedError) {
makeExecuteFencedReturn(V1_3::ErrorStatus::GENERAL_FAILURE, {}, nullptr)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -789,7 +789,7 @@ TEST(PreparedModelTest, reusableExecuteFencedTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -810,7 +810,7 @@ TEST(PreparedModelTest, reusableExecuteFencedDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);

View File

@@ -0,0 +1,41 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
///////////////////////////////////////////////////////////////////////////////
// THIS FILE IS IMMUTABLE. DO NOT EDIT IN ANY CASE. //
///////////////////////////////////////////////////////////////////////////////
// This file is a snapshot of an AIDL file. Do not edit it manually. There are
// two cases:
// 1). this is a frozen version file - do not edit this in any case.
// 2). this is a 'current' file. If you make a backwards compatible change to
// the interface (from the latest frozen version), the build system will
// prompt you to update this file with `m <name>-update-api`.
//
// You must not make a backward incompatible change to any AIDL file built
// with the aidl_interface module type with versions property set. The module
// type is used to build AIDL files in a way that they can be used across
// independently updatable components of the system. If a device is shipped
// with such a backward incompatible change, it has a high risk of breaking
// later when a module using the interface is updated, e.g., Mainline modules.
package android.hardware.neuralnetworks;
@VintfStability
parcelable ExecutionConfig {
boolean measureTiming;
long loopTimeoutDurationNs;
android.hardware.neuralnetworks.TokenValuePair[] executionHints;
android.hardware.neuralnetworks.ExtensionNameAndPrefix[] extensionNameToPrefix;
}

View File

@@ -36,4 +36,5 @@ package android.hardware.neuralnetworks;
interface IBurst {
android.hardware.neuralnetworks.ExecutionResult executeSynchronously(in android.hardware.neuralnetworks.Request request, in long[] memoryIdentifierTokens, in boolean measureTiming, in long deadlineNs, in long loopTimeoutDurationNs);
void releaseMemoryResource(in long memoryIdentifierToken);
android.hardware.neuralnetworks.ExecutionResult executeSynchronouslyWithConfig(in android.hardware.neuralnetworks.Request request, in long[] memoryIdentifierTokens, in android.hardware.neuralnetworks.ExecutionConfig config, in long deadlineNs);
}

View File

@@ -43,6 +43,7 @@ interface IDevice {
String getVersionString();
void prepareModel(in android.hardware.neuralnetworks.Model model, in android.hardware.neuralnetworks.ExecutionPreference preference, in android.hardware.neuralnetworks.Priority priority, in long deadlineNs, in ParcelFileDescriptor[] modelCache, in ParcelFileDescriptor[] dataCache, in byte[] token, in android.hardware.neuralnetworks.IPreparedModelCallback callback);
void prepareModelFromCache(in long deadlineNs, in ParcelFileDescriptor[] modelCache, in ParcelFileDescriptor[] dataCache, in byte[] token, in android.hardware.neuralnetworks.IPreparedModelCallback callback);
void prepareModelWithConfig(in android.hardware.neuralnetworks.Model model, in android.hardware.neuralnetworks.PrepareModelConfig config, in android.hardware.neuralnetworks.IPreparedModelCallback callback);
const int BYTE_SIZE_OF_CACHE_TOKEN = 32;
const int MAX_NUMBER_OF_CACHE_FILES = 32;
const int EXTENSION_TYPE_HIGH_BITS_PREFIX = 15;

View File

@@ -37,7 +37,9 @@ interface IPreparedModel {
android.hardware.neuralnetworks.ExecutionResult executeSynchronously(in android.hardware.neuralnetworks.Request request, in boolean measureTiming, in long deadlineNs, in long loopTimeoutDurationNs);
android.hardware.neuralnetworks.FencedExecutionResult executeFenced(in android.hardware.neuralnetworks.Request request, in ParcelFileDescriptor[] waitFor, in boolean measureTiming, in long deadlineNs, in long loopTimeoutDurationNs, in long durationNs);
android.hardware.neuralnetworks.IBurst configureExecutionBurst();
android.hardware.neuralnetworks.IExecution createReusableExecution(in android.hardware.neuralnetworks.Request request, in boolean measureTiming, in long loopTimeoutDurationNs);
android.hardware.neuralnetworks.IExecution createReusableExecution(in android.hardware.neuralnetworks.Request request, in android.hardware.neuralnetworks.ExecutionConfig config);
android.hardware.neuralnetworks.ExecutionResult executeSynchronouslyWithConfig(in android.hardware.neuralnetworks.Request request, in android.hardware.neuralnetworks.ExecutionConfig config, in long deadlineNs);
android.hardware.neuralnetworks.FencedExecutionResult executeFencedWithConfig(in android.hardware.neuralnetworks.Request request, in ParcelFileDescriptor[] waitFor, in android.hardware.neuralnetworks.ExecutionConfig config, in long deadlineNs, in long durationNs);
const long DEFAULT_LOOP_TIMEOUT_DURATION_NS = 2000000000;
const long MAXIMUM_LOOP_TIMEOUT_DURATION_NS = 15000000000;
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
///////////////////////////////////////////////////////////////////////////////
// THIS FILE IS IMMUTABLE. DO NOT EDIT IN ANY CASE. //
///////////////////////////////////////////////////////////////////////////////
// This file is a snapshot of an AIDL file. Do not edit it manually. There are
// two cases:
// 1). this is a frozen version file - do not edit this in any case.
// 2). this is a 'current' file. If you make a backwards compatible change to
// the interface (from the latest frozen version), the build system will
// prompt you to update this file with `m <name>-update-api`.
//
// You must not make a backward incompatible change to any AIDL file built
// with the aidl_interface module type with versions property set. The module
// type is used to build AIDL files in a way that they can be used across
// independently updatable components of the system. If a device is shipped
// with such a backward incompatible change, it has a high risk of breaking
// later when a module using the interface is updated, e.g., Mainline modules.
package android.hardware.neuralnetworks;
@VintfStability
parcelable PrepareModelConfig {
android.hardware.neuralnetworks.ExecutionPreference preference;
android.hardware.neuralnetworks.Priority priority;
long deadlineNs;
ParcelFileDescriptor[] modelCache;
ParcelFileDescriptor[] dataCache;
byte[] cacheToken;
android.hardware.neuralnetworks.TokenValuePair[] compilationHints;
android.hardware.neuralnetworks.ExtensionNameAndPrefix[] extensionNameToPrefix;
}

View File

@@ -0,0 +1,39 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
///////////////////////////////////////////////////////////////////////////////
// THIS FILE IS IMMUTABLE. DO NOT EDIT IN ANY CASE. //
///////////////////////////////////////////////////////////////////////////////
// This file is a snapshot of an AIDL file. Do not edit it manually. There are
// two cases:
// 1). this is a frozen version file - do not edit this in any case.
// 2). this is a 'current' file. If you make a backwards compatible change to
// the interface (from the latest frozen version), the build system will
// prompt you to update this file with `m <name>-update-api`.
//
// You must not make a backward incompatible change to any AIDL file built
// with the aidl_interface module type with versions property set. The module
// type is used to build AIDL files in a way that they can be used across
// independently updatable components of the system. If a device is shipped
// with such a backward incompatible change, it has a high risk of breaking
// later when a module using the interface is updated, e.g., Mainline modules.
package android.hardware.neuralnetworks;
@VintfStability
parcelable TokenValuePair {
int token;
byte[] value;
}

View File

@@ -0,0 +1,60 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.hardware.neuralnetworks;
import android.hardware.neuralnetworks.ExtensionNameAndPrefix;
import android.hardware.neuralnetworks.TokenValuePair;
/**
* A type that is used to represent all configuration related to
* an Execution.
*/
@VintfStability
parcelable ExecutionConfig {
/**
* Specifies whether or not to measure duration of the execution.
* For {@link IPreparedModel::executeSynchronouslyWithConfig}, the duration runs from the time
* the driver sees the corresponding call to the execute function to the time the driver returns
* from the function. For {@link IPreparedModel::executeFencedWithConfig}, please refer to
* {@link IPreparedModelCallback} for details.
*/
boolean measureTiming;
/**
* The maximum amount of time in nanoseconds that should be spent
* executing a {@link OperationType::WHILE} operation. If a loop
* condition model does not output false within this duration,
* the execution must be aborted. If -1 is provided, the maximum
* amount of time is {@link DEFAULT_LOOP_TIMEOUT_DURATION_NS}.
* Other negative values are invalid. When provided, the duration
* must not exceed {@link MAXIMUM_LOOP_TIMEOUT_DURATION_NS}.
*/
long loopTimeoutDurationNs;
/**
* A vector of token / value pairs that represent vendor-specific
* execution hints or metadata. The provided TokenValuePairs must not
* contain the same token twice. The driver must validate the
* data and ignore invalid hints. It is up to the driver to
* decide whether to respect the provided hints.
*/
TokenValuePair[] executionHints;
/**
* The mapping between extension names and prefixes of token values.
* The driver must ignore the corresponding execution hint if
* the extension is not supported.
*/
ExtensionNameAndPrefix[] extensionNameToPrefix;
}
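
As a point of reference, a client working with the generated NDK backend types could populate an ExecutionConfig roughly as below. This is a minimal sketch, not code from this change: the token value 0x7AAA0001 and the extension name are hypothetical, and the prefix/name pairing follows the TokenValuePair documentation later in this change.

// Sketch: building an ExecutionConfig with one hypothetical vendor hint.
#include <aidl/android/hardware/neuralnetworks/ExecutionConfig.h>

namespace aidl_nn = aidl::android::hardware::neuralnetworks;

aidl_nn::ExecutionConfig makeConfig() {
    aidl_nn::ExecutionConfig config;
    config.measureTiming = true;
    config.loopTimeoutDurationNs = -1;  // use DEFAULT_LOOP_TIMEOUT_DURATION_NS

    aidl_nn::TokenValuePair hint;
    hint.token = 0x7AAA0001;  // hypothetical: prefix 0x7AAA, hint type 0x0001
    hint.value = {0x01};      // raw bytes; their meaning is defined by the driver
    config.executionHints.push_back(hint);

    aidl_nn::ExtensionNameAndPrefix mapping;
    mapping.name = "vendor.test.test_extension";  // hypothetical extension
    mapping.prefix = 0x7AAA;
    config.extensionNameToPrefix.push_back(mapping);
    return config;
}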

View File

@@ -20,6 +20,10 @@ import android.hardware.neuralnetworks.ExtensionOperandTypeInformation;
/**
* Information about an extension.
*
* The extension can provide zero or more operation types (which are not enumerated), zero or more
* operand types (which are enumerated in {@link Extension::operandTypes}), and compilation and
* execution hints (which are not enumerated).
*/
@VintfStability
parcelable Extension {

View File

@@ -17,7 +17,8 @@
package android.hardware.neuralnetworks;
/**
* The mapping between extension names and prefixes of operand and operation type values.
* The mapping between extension names and prefixes of values, such as operand and operation
* types, and tokens in {@link TokenValuePair}.
*
* An operand or operation whose numeric type value is above {@link IDevice::OPERAND_TYPE_BASE_MAX}
* or {@link IDevice::OPERATION_TYPE_BASE_MAX} respectively should be interpreted as an extension

View File

@@ -17,6 +17,7 @@
package android.hardware.neuralnetworks;
import android.hardware.neuralnetworks.ErrorStatus;
import android.hardware.neuralnetworks.ExecutionConfig;
import android.hardware.neuralnetworks.ExecutionResult;
import android.hardware.neuralnetworks.Request;
@@ -68,6 +69,8 @@ interface IBurst {
*
* Only a single execution on a given burst object may be active at any time.
*
* Also see {@link IBurst::executeSynchronouslyWithConfig}.
*
* @param request The input and output information on which the prepared model is to be
* executed.
* @param memoryIdentifierTokens A list of tokens where each token is a non-negative number
@@ -117,4 +120,13 @@ interface IBurst {
* - INVALID_ARGUMENT if one of the input arguments is invalid
*/
void releaseMemoryResource(in long memoryIdentifierToken);
/**
* For detailed specification, please refer to {@link IBurst::executeSynchronously}. The
* difference between the two methods is that executeSynchronouslyWithConfig takes {@link
* ExecutionConfig} instead of a list of configuration parameters, and ExecutionConfig contains
* more configuration parameters than are passed to executeSynchronously.
*/
ExecutionResult executeSynchronouslyWithConfig(in Request request,
in long[] memoryIdentifierTokens, in ExecutionConfig config, in long deadlineNs);
}
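
For orientation, the burst-side call mirrors the prepared-model variant, with the memory identifier tokens threaded through. A sketch, assuming `burst`, `request`, and `memoryIdentifierTokens` already exist; the inline brace-initialized config mirrors how the utility Burst wrapper later in this change invokes the method.

// Sketch: one synchronous burst execution with an inline default config.
using aidl::android::hardware::neuralnetworks::ExecutionResult;

ExecutionResult executionResult;
const ndk::ScopedAStatus status = burst->executeSynchronouslyWithConfig(
        request, memoryIdentifierTokens,
        {/*measureTiming=*/false, /*loopTimeoutDurationNs=*/-1,
         /*executionHints=*/{}, /*extensionNameToPrefix=*/{}},
        /*deadlineNs=*/-1, &executionResult);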

View File

@@ -28,6 +28,7 @@ import android.hardware.neuralnetworks.IPreparedModelCallback;
import android.hardware.neuralnetworks.IPreparedModelParcel;
import android.hardware.neuralnetworks.Model;
import android.hardware.neuralnetworks.NumberOfCacheFiles;
import android.hardware.neuralnetworks.PrepareModelConfig;
import android.hardware.neuralnetworks.Priority;
/**
@@ -148,7 +149,7 @@ interface IDevice {
*
* If the device reports that caching is not supported, the user may avoid calling
* IDevice::prepareModelFromCache or providing cache file descriptors to
* IDevice::prepareModel.
* IDevice::prepareModel or IDevice::prepareModelWithConfig.
*
* @return NumberOfCacheFiles structure indicating how many files for model and data cache the
* driver needs to cache a single prepared model. It must be less than or equal to
@@ -302,6 +303,8 @@ interface IDevice {
*
* Multiple threads may call prepareModel on the same model concurrently.
*
* Also see {@link IDevice::prepareModelWithConfig}.
*
* @param model The model to be prepared for execution.
* @param preference Indicates the intended execution behavior of a prepared model.
* @param priority The priority of the prepared model relative to other prepared models owned by
@@ -403,17 +406,17 @@ interface IDevice {
* @param modelCache A vector of file descriptors for the security-sensitive cache. The length
* of the vector must match the numModelCache returned from
* getNumberOfCacheFilesNeeded. The cache file descriptors will be provided in
* the same order as with prepareModel.
* the same order as with prepareModel or prepareModelWithConfig.
* @param dataCache A vector of file descriptors for the constants' cache. The length of the
* vector must match the numDataCache returned from
* getNumberOfCacheFilesNeeded. The cache file descriptors will be provided in
* the same order as with prepareModel.
* the same order as with prepareModel or prepareModelWithConfig.
* @param token A caching token of length BYTE_SIZE_OF_CACHE_TOKEN identifying the prepared
* model. It is the same token provided when saving the cache files with
* prepareModel. Tokens should be chosen to have a low rate of collision for a
* particular application. The driver cannot detect a collision; a collision will
* result in a failed execution or in a successful execution that produces
* incorrect output values.
* prepareModel or prepareModelWithConfig. Tokens should be chosen to have a low
* rate of collision for a particular application. The driver cannot detect a
* collision; a collision will result in a failed execution or in a successful
* execution that produces incorrect output values.
* @param callback A callback object used to return the error status of preparing the model for
* execution and the prepared model if successful, nullptr otherwise. The
* callback object's notify function must be called exactly once, even if the
@@ -429,4 +432,28 @@ interface IDevice {
void prepareModelFromCache(in long deadlineNs, in ParcelFileDescriptor[] modelCache,
in ParcelFileDescriptor[] dataCache, in byte[] token,
in IPreparedModelCallback callback);
/**
* For detailed specification, please refer to {@link IDevice::prepareModel}. The only
* difference between the two methods is that prepareModelWithConfig takes {@link
* PrepareModelConfig} instead of standalone configuration parameters, which allows
* vendor-specific compilation metadata to be passed.
*
* @param model The model to be prepared for execution.
* @param config Configuration parameters to prepare the model.
* @param callback A callback object used to return the error status of preparing the model for
* execution and the prepared model if successful, nullptr otherwise. The
* callback object's notify function must be called exactly once, even if the
* model could not be prepared.
* @throws ServiceSpecificException with one of the following ErrorStatus values:
* - DEVICE_UNAVAILABLE if driver is offline or busy
* - GENERAL_FAILURE if there is an unspecified error
* - INVALID_ARGUMENT if one of the input arguments related to preparing the model is
* invalid
* - MISSED_DEADLINE_* if the preparation is aborted because the model cannot be prepared by
* the deadline
* - RESOURCE_EXHAUSTED_* if the task was aborted by the driver
*/
void prepareModelWithConfig(
in Model model, in PrepareModelConfig config, in IPreparedModelCallback callback);
}
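
To see the new entry point end to end, a caller holding an IDevice binder might drive it as below. This is a sketch under the assumption that `device`, `model`, and `callback` are already in hand; the empty cache vectors signal that no caching information is provided, in which case the token must be ignored.

// Sketch: preparing a model with default preference/priority and no hints.
using aidl::android::hardware::neuralnetworks::ExecutionPreference;
using aidl::android::hardware::neuralnetworks::PrepareModelConfig;
using aidl::android::hardware::neuralnetworks::Priority;

PrepareModelConfig config;
config.preference = ExecutionPreference::FAST_SINGLE_ANSWER;
config.priority = Priority::MEDIUM;
config.deadlineNs = -1;                           // no deadline
config.cacheToken = std::vector<uint8_t>(32, 0);  // ignored: caches are empty
// modelCache, dataCache, compilationHints, extensionNameToPrefix stay empty.

const ndk::ScopedAStatus status =
        device->prepareModelWithConfig(model, config, callback);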

View File

@@ -18,6 +18,7 @@ package android.hardware.neuralnetworks;
import android.hardware.common.NativeHandle;
import android.hardware.neuralnetworks.ErrorStatus;
import android.hardware.neuralnetworks.ExecutionConfig;
import android.hardware.neuralnetworks.ExecutionResult;
import android.hardware.neuralnetworks.FencedExecutionResult;
import android.hardware.neuralnetworks.IBurst;
@@ -68,6 +69,8 @@ interface IPreparedModel {
* Any number of calls to the execute* functions, in any combination, may be made concurrently,
* even on the same IPreparedModel object.
*
* Also see {@link IPreparedModel::executeSynchronouslyWithConfig}.
*
* @param request The input and output information on which the prepared model is to be
* executed.
* @param measure Specifies whether or not to measure duration of the execution. The duration
@@ -134,6 +137,8 @@ interface IPreparedModel {
* Any number of calls to the execute* functions, in any combination, may be made concurrently,
* even on the same IPreparedModel object.
*
* Also see {@link IPreparedModel::executeFencedWithConfig}.
*
* @param request The input and output information on which the prepared model is to be
* executed. The outputs in the request must have fully specified dimensions.
* @param waitFor A vector of sync fence file descriptors. Execution must not start until all
@@ -201,15 +206,7 @@ interface IPreparedModel {
*
* @param request The input and output information on which the prepared model is to be
* executed.
* @param measure Specifies whether or not to measure duration of the execution.
* @param loopTimeoutDurationNs The maximum amount of time in nanoseconds that should be spent
* executing a {@link OperationType::WHILE} operation. If a loop
* condition model does not output false within this duration, the
* computation performed on the returned reusable execution object
* must be aborted. If -1 is provided, the maximum amount
* of time is {@link DEFAULT_LOOP_TIMEOUT_DURATION_NS}. Other
* negative values are invalid. When provided, the duration must
* not exceed {@link MAXIMUM_LOOP_TIMEOUT_DURATION_NS}.
* @param config Specifies the execution configuration parameters.
* @return execution An IExecution object representing a reusable execution that has been
* specialized for a fixed request.
* @throws ServiceSpecificException with one of the following ErrorStatus values:
@@ -218,6 +215,64 @@ interface IPreparedModel {
* - INVALID_ARGUMENT if one of the input arguments is invalid
* - RESOURCE_EXHAUSTED_* if the task was aborted by the driver
*/
IExecution createReusableExecution(
in Request request, in boolean measureTiming, in long loopTimeoutDurationNs);
IExecution createReusableExecution(in Request request, in ExecutionConfig config);
/**
* For detailed specification, please refer to {@link IPreparedModel::executeSynchronously}. The
* difference between the two methods is that executeSynchronouslyWithConfig takes {@link
* ExecutionConfig} instead of a list of configuration parameters, and ExecutionConfig contains
* more configuration parameters than are passed to executeSynchronously.
*
* @param request The input and output information on which the prepared model is to be
* executed.
* @param config Specifies the execution configuration parameters.
* @param deadlineNs The time by which the execution is expected to complete. The time is
* measured in nanoseconds since boot (as from clock_gettime(CLOCK_BOOTTIME,
* &ts) or ::android::base::boot_clock). If the execution cannot be finished
* by the deadline, the execution may be aborted. Passing -1 means the
* deadline is omitted. Other negative values are invalid.
* @return ExecutionResult parcelable, containing the status of the execution, output shapes and
* timing information.
* @throws ServiceSpecificException with one of the following ErrorStatus values:
* - DEVICE_UNAVAILABLE if driver is offline or busy
* - GENERAL_FAILURE if there is an unspecified error
* - INVALID_ARGUMENT if one of the input arguments is invalid
* - MISSED_DEADLINE_* if the execution is aborted because it cannot be completed by the
* deadline
* - RESOURCE_EXHAUSTED_* if the task was aborted by the driver
*/
ExecutionResult executeSynchronouslyWithConfig(
in Request request, in ExecutionConfig config, in long deadlineNs);
/**
* For detailed specification, please refer to {@link IPreparedModel::executeFenced}. The
* difference between the two methods is that executeFencedWithConfig takes {@link
* ExecutionConfig} instead of a list of configuration parameters, and ExecutionConfig contains
* more configuration parameters than are passed to executeFenced.
*
* @param request The input and output information on which the prepared model is to be
* executed. The outputs in the request must have fully specified dimensions.
* @param waitFor A vector of sync fence file descriptors. Execution must not start until all
* sync fences have been signaled.
* @param config Specifies the execution configuration parameters.
* @param deadlineNs The time by which the execution is expected to complete. The time is
* measured in nanoseconds since boot (as from clock_gettime(CLOCK_BOOTTIME,
* &ts) or ::android::base::boot_clock). If the execution cannot be finished
* by the deadline, the execution may be aborted. Passing -1 means the
* deadline is omitted. Other negative values are invalid.
* @param durationNs The length of time in nanoseconds within which the execution is expected to
* complete after all sync fences in waitFor are signaled. If the execution
* cannot be finished within the duration, the execution may be aborted.
* Passing -1 means the duration is omitted. Other negative values are
* invalid.
* @return The FencedExecutionResult parcelable, containing IFencedExecutionCallback and the
* sync fence.
* @throws ServiceSpecificException with one of the following ErrorStatus values:
* - DEVICE_UNAVAILABLE if driver is offline or busy
* - GENERAL_FAILURE if there is an unspecified error
* - INVALID_ARGUMENT if one of the input arguments is invalid, including fences in error
* states.
* - MISSED_DEADLINE_* if the execution is aborted because it cannot be completed by the
* deadline
* - RESOURCE_EXHAUSTED_* if the task was aborted by the driver
*/
FencedExecutionResult executeFencedWithConfig(in Request request,
in ParcelFileDescriptor[] waitFor, in ExecutionConfig config, in long deadlineNs,
in long durationNs);
}
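
Concretely, the relationship to the original executeSynchronously can be seen from a call site. A sketch, assuming `preparedModel` and `request` exist; the generated NDK method returns the ExecutionResult through an out parameter.

// Sketch: synchronous execution through the *WithConfig overload.
using aidl::android::hardware::neuralnetworks::ExecutionConfig;
using aidl::android::hardware::neuralnetworks::ExecutionResult;

ExecutionConfig config;
config.measureTiming = false;
config.loopTimeoutDurationNs = -1;  // DEFAULT_LOOP_TIMEOUT_DURATION_NS

ExecutionResult result;
const ndk::ScopedAStatus status = preparedModel->executeSynchronouslyWithConfig(
        request, config, /*deadlineNs=*/-1, &result);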

View File

@@ -0,0 +1,95 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.hardware.neuralnetworks;
import android.hardware.neuralnetworks.ExecutionPreference;
import android.hardware.neuralnetworks.ExtensionNameAndPrefix;
import android.hardware.neuralnetworks.Priority;
import android.hardware.neuralnetworks.TokenValuePair;
/**
* A type that is used to represent all configuration needed to
* prepare a model.
*/
@VintfStability
parcelable PrepareModelConfig {
/**
* Indicates the intended execution behavior of a prepared model.
*/
ExecutionPreference preference;
/**
* The priority of the prepared model relative to other prepared
* models owned by the client.
*/
Priority priority;
/**
* The time by which the model is expected to be prepared. The
* time is measured in nanoseconds since boot (as from
* clock_gettime(CLOCK_BOOTTIME, &ts) or
* ::android::base::boot_clock). If the model cannot be prepared
* by the deadline, the preparation may be aborted. Passing -1
* means the deadline is omitted. Other negative values are
* invalid.
*/
long deadlineNs;
/**
* A vector of file descriptors for the security-sensitive cache.
* The length of the vector must either be 0 indicating that
* caching information is not provided, or match the
* numModelCache returned from IDevice::getNumberOfCacheFilesNeeded. The
* cache file descriptors will be provided in the same order when
* retrieving the preparedModel from cache files with
* IDevice::prepareModelFromCache.
*/
ParcelFileDescriptor[] modelCache;
/**
* A vector of file descriptors for the constants' cache. The
* length of the vector must either be 0 indicating that caching
* information is not provided, or match the numDataCache
* returned from IDevice::getNumberOfCacheFilesNeeded. The cache file
* descriptors will be provided in the same order when retrieving
* the preparedModel from cache files with IDevice::prepareModelFromCache.
*/
ParcelFileDescriptor[] dataCache;
/**
* A caching token of length IDevice::BYTE_SIZE_OF_CACHE_TOKEN identifying
* the prepared model. The same token will be provided when
* retrieving the prepared model from the cache files with
* IDevice::prepareModelFromCache. Tokens should be chosen to have a low
* rate of collision for a particular application. The driver
* cannot detect a collision; a collision will result in a failed
* execution or in a successful execution that produces incorrect
* output values. If both modelCache and dataCache are empty
* indicating that caching information is not provided, this
* token must be ignored.
*/
byte[] cacheToken;
/**
* A vector of token / value pairs that represent vendor-specific
* compilation hints or metadata. The provided TokenValuePairs must not
* contain the same token twice. The driver must validate the
* data and ignore invalid hints. It is up to the driver to
* decide whether to respect the provided hints.
*/
TokenValuePair[] compilationHints;
/**
* The mapping between extension names and prefixes of token values.
* The driver must ignore the corresponding compilation hint if
* the extension is not supported.
*/
ExtensionNameAndPrefix[] extensionNameToPrefix;
}

View File

@@ -0,0 +1,42 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.hardware.neuralnetworks;
/**
* A type that is used to represent a token / byte array data pair.
*/
@VintfStability
parcelable TokenValuePair {
/**
* A 32-bit integer token. The token is created by combining the
* extension prefix and enum defined within the extension.
* The low {@link IDevice::EXTENSION_TYPE_LOW_BITS_TYPE} bits of the value
* correspond to the hint within the extension and the high
* {@link IDevice::EXTENSION_TYPE_HIGH_BITS_PREFIX} bits encode the "prefix", which maps
* uniquely to the extension name. The sign bit is always 0.
*
* For example, if a token value is 0x7AAA000B and the corresponding
* {@link ExtensionNameAndPrefix} contains an entry with prefix=0x7AAA and
* name="vendor.test.test_extension", then the token should be interpreted as the hint
* 0x000B of the extension named vendor.test.test_extension.
*/
int token;
/**
* A byte array containing the raw data.
*/
byte[] value;
}
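
The bit layout described above can be checked mechanically. A small sketch, assuming IDevice::EXTENSION_TYPE_LOW_BITS_TYPE is 16, which the 0x7AAA000B example implies given 15 prefix bits below the always-zero sign bit.

// Sketch: composing and decomposing a hint token from prefix and hint type.
#include <cstdint>

constexpr int32_t kLowBitsType = 16;  // assumed EXTENSION_TYPE_LOW_BITS_TYPE

constexpr int32_t makeToken(int32_t prefix, int32_t hint) {
    return (prefix << kLowBitsType) | hint;  // sign bit stays 0 for valid prefixes
}
constexpr int32_t prefixOf(int32_t token) { return token >> kLowBitsType; }
constexpr int32_t hintOf(int32_t token) { return token & 0xFFFF; }

static_assert(makeToken(0x7AAA, 0x000B) == 0x7AAA000B);
static_assert(prefixOf(0x7AAA000B) == 0x7AAA);
static_assert(hintOf(0x7AAA000B) == 0x000B);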

View File

@@ -86,10 +86,12 @@ class Burst final : public nn::IBurst, public std::enable_shared_from_this<Burst
GUARDED_BY(mMutex);
};
// featureLevel is for testing purposes.
static nn::GeneralResult<std::shared_ptr<const Burst>> create(
std::shared_ptr<aidl_hal::IBurst> burst);
std::shared_ptr<aidl_hal::IBurst> burst, nn::Version featureLevel);
Burst(PrivateConstructorTag tag, std::shared_ptr<aidl_hal::IBurst> burst);
Burst(PrivateConstructorTag tag, std::shared_ptr<aidl_hal::IBurst> burst,
nn::Version featureLevel);
// See IBurst::cacheMemory for information.
OptionalCacheHold cacheMemory(const nn::SharedMemory& memory) const override;
@@ -97,23 +99,29 @@ class Burst final : public nn::IBurst, public std::enable_shared_from_this<Burst
// See IBurst::execute for information.
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
// See IBurst::createReusableExecution for information.
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> executeInternal(
const aidl_hal::Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
bool measure, int64_t deadline, int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const;
private:
mutable std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
const std::shared_ptr<aidl_hal::IBurst> kBurst;
const std::shared_ptr<MemoryCache> kMemoryCache;
const nn::Version kFeatureLevel;
};
} // namespace aidl::android::hardware::neuralnetworks::utils

View File

@@ -46,6 +46,10 @@
#include <aidl/android/hardware/neuralnetworks/SymmPerChannelQuantParams.h>
#include <aidl/android/hardware/neuralnetworks/Timing.h>
#ifdef NN_AIDL_V4_OR_ABOVE
#include <aidl/android/hardware/neuralnetworks/TokenValuePair.h>
#endif // NN_AIDL_V4_OR_ABOVE
#include <android/binder_auto_utils.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
@@ -74,7 +78,7 @@ GeneralResult<Operand::SymmPerChannelQuantParams> unvalidatedConvert(
const aidl_hal::SymmPerChannelQuantParams& symmPerChannelQuantParams);
GeneralResult<Operation> unvalidatedConvert(const aidl_hal::Operation& operation);
GeneralResult<Model> unvalidatedConvert(const aidl_hal::Model& model);
GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const aidl_hal::ExtensionNameAndPrefix& extensionNameAndPrefix);
GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t>& operandValues);
GeneralResult<Model::Subgraph> unvalidatedConvert(const aidl_hal::Subgraph& subgraph);
@@ -97,6 +101,10 @@ GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
const aidl_hal::ExtensionOperandTypeInformation& operandTypeInformation);
GeneralResult<SharedHandle> unvalidatedConvert(const ndk::ScopedFileDescriptor& handle);
#ifdef NN_AIDL_V4_OR_ABOVE
GeneralResult<TokenValuePair> unvalidatedConvert(const aidl_hal::TokenValuePair& tokenValuePair);
#endif // NN_AIDL_V4_OR_ABOVE
GeneralResult<std::vector<Operation>> unvalidatedConvert(
const std::vector<aidl_hal::Operation>& operations);
@@ -116,6 +124,14 @@ GeneralResult<BufferDesc> convert(const aidl_hal::BufferDesc& bufferDesc);
GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension);
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories);
GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
const std::vector<aidl_hal::ExtensionNameAndPrefix>& extensionNameAndPrefix);
#ifdef NN_AIDL_V4_OR_ABOVE
GeneralResult<std::vector<TokenValuePair>> convert(
const std::vector<aidl_hal::TokenValuePair>& metaData);
#endif // NN_AIDL_V4_OR_ABOVE
GeneralResult<std::vector<OutputShape>> convert(
const std::vector<aidl_hal::OutputShape>& outputShapes);
GeneralResult<std::vector<SharedHandle>> convert(
@@ -152,7 +168,7 @@ nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgra
nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
const nn::Model::OperandValues& operandValues);
nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix);
const nn::ExtensionNameAndPrefix& extensionNameToPrefix);
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model);
nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority);
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request);
@@ -166,6 +182,10 @@ nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvert(const nn::Shared
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities);
nn::GeneralResult<Extension> unvalidatedConvert(const nn::Extension& extension);
#ifdef NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<TokenValuePair> unvalidatedConvert(const nn::TokenValuePair& tokenValuePair);
#endif // NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken);
nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc);
nn::GeneralResult<DeviceType> convert(const nn::DeviceType& deviceType);
@@ -190,6 +210,13 @@ nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
const std::vector<nn::SyncFence>& syncFences);
nn::GeneralResult<std::vector<Extension>> convert(const std::vector<nn::Extension>& extensions);
nn::GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix);
#ifdef NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<TokenValuePair>> convert(
const std::vector<nn::TokenValuePair>& metaData);
#endif // NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec);

View File

@@ -42,6 +42,7 @@ class Device final : public nn::IDevice {
struct PrivateConstructorTag {};
public:
// featureLevel is for testing purposes.
static nn::GeneralResult<std::shared_ptr<const Device>> create(
std::string name, std::shared_ptr<aidl_hal::IDevice> device, nn::Version featureLevel);
@@ -67,8 +68,9 @@ class Device final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache,
const nn::CacheToken& token) const override;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -63,7 +63,9 @@
#ifdef NN_AIDL_V4_OR_ABOVE
#include <aidl/android/hardware/neuralnetworks/BnExecution.h>
#include <aidl/android/hardware/neuralnetworks/ExecutionConfig.h>
#include <aidl/android/hardware/neuralnetworks/IExecution.h>
#include <aidl/android/hardware/neuralnetworks/PrepareModelConfig.h>
#endif // NN_AIDL_V4_OR_ABOVE
namespace android::nn {

View File

@@ -53,6 +53,9 @@ class InvalidDevice : public BnDevice {
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback) override;
ndk::ScopedAStatus prepareModelWithConfig(
const Model& model, const PrepareModelConfig& config,
const std::shared_ptr<IPreparedModelCallback>& callback) override;
ndk::ScopedAStatus prepareModelFromCache(
int64_t deadline, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
const std::vector<ndk::ScopedFileDescriptor>& dataCache,

View File

@@ -40,6 +40,7 @@ class PreparedModel final : public nn::IPreparedModel,
struct PrivateConstructorTag {};
public:
// featureLevel is for testing purposes.
static nn::GeneralResult<std::shared_ptr<const PreparedModel>> create(
std::shared_ptr<aidl_hal::IPreparedModel> preparedModel, nn::Version featureLevel);
@@ -49,18 +50,23 @@ class PreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const override;
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;
@@ -68,6 +74,8 @@ class PreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> executeInternal(
const Request& request, bool measure, int64_t deadline, int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
@@ -75,6 +83,8 @@ class PreparedModel final : public nn::IPreparedModel,
const std::vector<ndk::ScopedFileDescriptor>& waitFor, bool measure,
int64_t deadline, int64_t loopTimeoutDuration,
int64_t timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const;
private:

View File

@@ -43,12 +43,16 @@ class BurstExecution final : public nn::IExecution,
static nn::GeneralResult<std::shared_ptr<const BurstExecution>> create(
std::shared_ptr<const Burst> burst, Request request,
std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds);
BurstExecution(PrivateConstructorTag tag, std::shared_ptr<const Burst> burst, Request request,
std::vector<int64_t> memoryIdentifierTokens, bool measure,
int64_t loopTimeoutDuration, hal::utils::RequestRelocation relocation,
int64_t loopTimeoutDuration, const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds);
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> compute(
@@ -64,6 +68,8 @@ class BurstExecution final : public nn::IExecution,
const std::vector<int64_t> kMemoryIdentifierTokens;
const bool kMeasure;
const int64_t kLoopTimeoutDuration;
const std::vector<nn::TokenValuePair> kHints;
const std::vector<nn::ExtensionNameAndPrefix> kExtensionNameToPrefix;
const hal::utils::RequestRelocation kRelocation;
const std::vector<Burst::OptionalCacheHold> kCacheHolds;
};
@@ -149,17 +155,20 @@ void Burst::MemoryCache::tryFreeMemory(const nn::SharedMemory& memory, int64_t i
}
nn::GeneralResult<std::shared_ptr<const Burst>> Burst::create(
std::shared_ptr<aidl_hal::IBurst> burst) {
std::shared_ptr<aidl_hal::IBurst> burst, nn::Version featureLevel) {
if (burst == nullptr) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
<< "aidl_hal::utils::Burst::create must have non-null burst";
}
return std::make_shared<const Burst>(PrivateConstructorTag{}, std::move(burst));
return std::make_shared<const Burst>(PrivateConstructorTag{}, std::move(burst), featureLevel);
}
Burst::Burst(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBurst> burst)
: kBurst(std::move(burst)), kMemoryCache(std::make_shared<MemoryCache>(kBurst)) {
Burst::Burst(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBurst> burst,
nn::Version featureLevel)
: kBurst(std::move(burst)),
kMemoryCache(std::make_shared<MemoryCache>(kBurst)),
kFeatureLevel(featureLevel) {
CHECK(kBurst != nullptr);
}
@@ -170,8 +179,9 @@ Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& memory) cons
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -200,14 +210,14 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
memoryIdentifierTokens.push_back(-1);
}
CHECK_EQ(requestInShared.pools.size(), memoryIdentifierTokens.size());
return executeInternal(aidlRequest, memoryIdentifierTokens, aidlMeasure, aidlDeadline,
aidlLoopTimeoutDuration, relocation);
aidlLoopTimeoutDuration, hints, extensionNameToPrefix, relocation);
}
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::executeInternal(
const Request& request, const std::vector<int64_t>& memoryIdentifierTokens, bool measure,
int64_t deadline, int64_t loopTimeoutDuration,
int64_t deadline, int64_t loopTimeoutDuration, const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const {
// Ensure that at most one execution is in flight at any given time.
const bool alreadyInFlight = mExecutionInFlight.test_and_set();
@@ -221,9 +231,21 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
}
ExecutionResult executionResult;
const auto ret = kBurst->executeSynchronously(request, memoryIdentifierTokens, measure,
deadline, loopTimeoutDuration, &executionResult);
HANDLE_ASTATUS(ret) << "execute failed";
if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
auto aidlHints = NN_TRY(convert(hints));
auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
const auto ret = kBurst->executeSynchronouslyWithConfig(
request, memoryIdentifierTokens,
{measure, loopTimeoutDuration, std::move(aidlHints),
std::move(aidlExtensionPrefix)},
deadline, &executionResult);
HANDLE_ASTATUS(ret) << "execute failed";
} else {
const auto ret =
kBurst->executeSynchronously(request, memoryIdentifierTokens, measure, deadline,
loopTimeoutDuration, &executionResult);
HANDLE_ASTATUS(ret) << "execute failed";
}
if (!executionResult.outputSufficientSize) {
auto canonicalOutputShapes =
nn::convert(executionResult.outputShapes).value_or(std::vector<nn::OutputShape>{});
@@ -241,7 +263,9 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::
nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -272,12 +296,15 @@ nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
return BurstExecution::create(shared_from_this(), std::move(aidlRequest),
std::move(memoryIdentifierTokens), aidlMeasure,
aidlLoopTimeoutDuration, std::move(relocation), std::move(holds));
aidlLoopTimeoutDuration, hints, extensionNameToPrefix,
std::move(relocation), std::move(holds));
}
nn::GeneralResult<std::shared_ptr<const BurstExecution>> BurstExecution::create(
std::shared_ptr<const Burst> burst, Request request,
std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds) {
if (burst == nullptr) {
@@ -286,13 +313,15 @@ nn::GeneralResult<std::shared_ptr<const BurstExecution>> BurstExecution::create(
return std::make_shared<const BurstExecution>(
PrivateConstructorTag{}, std::move(burst), std::move(request),
std::move(memoryIdentifierTokens), measure, loopTimeoutDuration, std::move(relocation),
std::move(cacheHolds));
std::move(memoryIdentifierTokens), measure, loopTimeoutDuration, hints,
extensionNameToPrefix, std::move(relocation), std::move(cacheHolds));
}
BurstExecution::BurstExecution(PrivateConstructorTag /*tag*/, std::shared_ptr<const Burst> burst,
Request request, std::vector<int64_t> memoryIdentifierTokens,
bool measure, int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds)
: kBurst(std::move(burst)),
@@ -300,6 +329,8 @@ BurstExecution::BurstExecution(PrivateConstructorTag /*tag*/, std::shared_ptr<co
kMemoryIdentifierTokens(std::move(memoryIdentifierTokens)),
kMeasure(measure),
kLoopTimeoutDuration(loopTimeoutDuration),
kHints(hints),
kExtensionNameToPrefix(extensionNameToPrefix),
kRelocation(std::move(relocation)),
kCacheHolds(std::move(cacheHolds)) {}
@@ -307,7 +338,8 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> BurstEx
const nn::OptionalTimePoint& deadline) const {
const auto aidlDeadline = NN_TRY(convert(deadline));
return kBurst->executeInternal(kRequest, kMemoryIdentifierTokens, kMeasure, aidlDeadline,
kLoopTimeoutDuration, kRelocation);
kLoopTimeoutDuration, kHints, kExtensionNameToPrefix,
kRelocation);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>

View File

@@ -302,9 +302,9 @@ GeneralResult<Model::Subgraph> unvalidatedConvert(const aidl_hal::Subgraph& subg
};
}
GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const aidl_hal::ExtensionNameAndPrefix& extensionNameAndPrefix) {
return Model::ExtensionNameAndPrefix{
return ExtensionNameAndPrefix{
.name = extensionNameAndPrefix.name,
.prefix = extensionNameAndPrefix.prefix,
};
@@ -506,6 +506,12 @@ GeneralResult<SharedHandle> unvalidatedConvert(const ndk::ScopedFileDescriptor&
return std::make_shared<const Handle>(std::move(duplicatedFd));
}
#ifdef NN_AIDL_V4_OR_ABOVE
GeneralResult<TokenValuePair> unvalidatedConvert(const aidl_hal::TokenValuePair& tokenValuePair) {
return TokenValuePair{.token = tokenValuePair.token, .value = tokenValuePair.value};
}
#endif // NN_AIDL_V4_OR_ABOVE
GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities) {
return validatedConvert(capabilities);
}
@@ -562,6 +568,17 @@ GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extens
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) {
return validatedConvert(memories);
}
GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
const std::vector<aidl_hal::ExtensionNameAndPrefix>& extensionNameAndPrefix) {
return unvalidatedConvert(extensionNameAndPrefix);
}
#ifdef NN_AIDL_V4_OR_ABOVE
GeneralResult<std::vector<TokenValuePair>> convert(
const std::vector<aidl_hal::TokenValuePair>& metaData) {
return validatedConvert(metaData);
}
#endif // NN_AIDL_V4_OR_ABOVE
GeneralResult<std::vector<OutputShape>> convert(
const std::vector<aidl_hal::OutputShape>& outputShapes) {
@@ -942,7 +959,7 @@ nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
}
nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix) {
const nn::ExtensionNameAndPrefix& extensionNameToPrefix) {
return ExtensionNameAndPrefix{
.name = extensionNameToPrefix.name,
.prefix = extensionNameToPrefix.prefix,
@@ -1055,6 +1072,11 @@ nn::GeneralResult<Extension> unvalidatedConvert(const nn::Extension& extension)
return Extension{.name = extension.name,
.operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes))};
}
#ifdef NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<TokenValuePair> unvalidatedConvert(const nn::TokenValuePair& tokenValuePair) {
return TokenValuePair{.token = tokenValuePair.token, .value = tokenValuePair.value};
}
#endif // NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken) {
return validatedConvert(cacheToken);
@@ -1134,6 +1156,17 @@ nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
const std::vector<nn::SyncFence>& syncFences) {
return validatedConvert(syncFences);
}
nn::GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) {
return unvalidatedConvert(extensionNameToPrefix);
}
#ifdef NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<TokenValuePair>> convert(
const std::vector<nn::TokenValuePair>& metaData) {
return validatedConvert(metaData);
}
#endif // NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<Extension>> convert(const std::vector<nn::Extension>& extensions) {
return validatedConvert(extensions);

View File

@@ -215,7 +215,9 @@ nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Mo
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that model is ready for IPC.
std::optional<nn::Model> maybeModelInShared;
const nn::Model& modelInShared =
@@ -225,17 +227,28 @@ nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const auto aidlPreference = NN_TRY(convert(preference));
const auto aidlPriority = NN_TRY(convert(priority));
const auto aidlDeadline = NN_TRY(convert(deadline));
const auto aidlModelCache = NN_TRY(convert(modelCache));
const auto aidlDataCache = NN_TRY(convert(dataCache));
auto aidlModelCache = NN_TRY(convert(modelCache));
auto aidlDataCache = NN_TRY(convert(dataCache));
const auto aidlToken = NN_TRY(convert(token));
const auto cb = ndk::SharedRefBase::make<PreparedModelCallback>(kFeatureLevel);
const auto scoped = kDeathHandler.protectCallback(cb.get());
if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
auto aidlHints = NN_TRY(convert(hints));
auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
const auto ret = kDevice->prepareModelWithConfig(
aidlModel,
{aidlPreference, aidlPriority, aidlDeadline, std::move(aidlModelCache),
std::move(aidlDataCache), aidlToken, std::move(aidlHints),
std::move(aidlExtensionPrefix)},
cb);
HANDLE_ASTATUS(ret) << "prepareModel failed";
return cb->get();
}
const auto ret = kDevice->prepareModel(aidlModel, aidlPreference, aidlPriority, aidlDeadline,
aidlModelCache, aidlDataCache, aidlToken, cb);
HANDLE_ASTATUS(ret) << "prepareModel failed";
return cb->get();
}
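The pattern here recurs throughout this change: the caller probes the reported feature level and only uses the *WithConfig entry point on FEATURE_LEVEL_8 drivers, falling back to the legacy signature (which cannot carry hints or extension prefixes) otherwise. A hedged sketch of that gate, with stand-in types rather than the real kDevice calls:

#include <cstdint>

enum class FeatureLevel : int32_t { FL7 = 7, FL8 = 8 };

// The FL8 path can carry hints and extension prefixes; the legacy path
// cannot, so they are silently dropped when talking to older drivers.
template <typename WithConfig, typename Legacy>
auto dispatchByFeatureLevel(FeatureLevel level, WithConfig withConfig, Legacy legacy) {
    return level >= FeatureLevel::FL8 ? withConfig() : legacy();
}

int main() {
    const int used = dispatchByFeatureLevel(
            FeatureLevel::FL8, [] { return 8; }, [] { return 7; });
    return used == 8 ? 0 : 1;
}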

View File

@@ -63,7 +63,7 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
ExecutionWithCachedRequest::compute(const nn::OptionalTimePoint& deadline) const {
const auto aidlDeadline = NN_TRY(convert(deadline));
return kPreparedModel->executeInternal(kRequest, kMeasure, aidlDeadline, kLoopTimeoutDuration,
kRelocation);
{}, {}, kRelocation);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
@@ -73,9 +73,9 @@ ExecutionWithCachedRequest::computeFenced(
const auto aidlWaitFor = NN_TRY(convert(waitFor));
const auto aidlDeadline = NN_TRY(convert(deadline));
const auto aidlTimeoutDurationAfterFence = NN_TRY(convert(timeoutDurationAfterFence));
return kPreparedModel->executeFencedInternal(kRequest, aidlWaitFor, kMeasure, aidlDeadline,
kLoopTimeoutDuration,
aidlTimeoutDurationAfterFence, kRelocation);
return kPreparedModel->executeFencedInternal(
kRequest, aidlWaitFor, kMeasure, aidlDeadline, kLoopTimeoutDuration,
aidlTimeoutDurationAfterFence, {}, {}, kRelocation);
}
nn::GeneralResult<std::shared_ptr<const Execution>> Execution::create(

View File

@@ -167,6 +167,31 @@ ndk::ScopedAStatus InvalidDevice::prepareModel(
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus InvalidDevice::prepareModelWithConfig(
const Model& model, const PrepareModelConfig& config,
const std::shared_ptr<IPreparedModelCallback>& callback) {
if (!utils::valid(config.extensionNameToPrefix)) {
callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid extensionNameToPrefix");
}
for (const auto& hint : config.compilationHints) {
auto result = std::find_if(config.extensionNameToPrefix.begin(),
config.extensionNameToPrefix.end(),
[&hint](const ExtensionNameAndPrefix& extension) {
uint16_t prefix = static_cast<uint32_t>(hint.token) >>
IDevice::EXTENSION_TYPE_LOW_BITS_TYPE;
return prefix == extension.prefix;
});
if (result == config.extensionNameToPrefix.end()) {
callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
return toAStatus(ErrorStatus::INVALID_ARGUMENT,
"Invalid token for compilation hints: " + std::to_string(hint.token));
}
}
return prepareModel(model, config.preference, config.priority, config.deadlineNs,
config.modelCache, config.dataCache, config.cacheToken, callback);
}
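The find_if above recovers an extension prefix from the high bits of each hint token and requires it to match a declared extension. A self-contained sketch of that arithmetic; the low-bit count of 16 is an assumption standing in for IDevice::EXTENSION_TYPE_LOW_BITS_TYPE:

#include <cassert>
#include <cstdint>

constexpr uint32_t kExtensionTypeLowBits = 16;  // assumed value of the constant

// High bits of an extension token encode the extension's prefix.
constexpr uint16_t prefixOf(int32_t token) {
    return static_cast<uint16_t>(static_cast<uint32_t>(token) >> kExtensionTypeLowBits);
}

int main() {
    // A token whose high bits carry prefix 1 must match a declared extension
    // with prefix 1, or prepareModelWithConfig rejects it as INVALID_ARGUMENT.
    const int32_t token = static_cast<int32_t>((1u << kExtensionTypeLowBits) | 0x0005u);
    assert(prefixOf(token) == 1);
    return 0;
}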
ndk::ScopedAStatus InvalidDevice::prepareModelFromCache(
int64_t /*deadline*/, const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,

View File

@@ -128,8 +128,9 @@ PreparedModel::PreparedModel(PrivateConstructorTag /*tag*/,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> PreparedModel::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -141,30 +142,46 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Prepare
const auto aidlMeasure = NN_TRY(convert(measure));
const auto aidlDeadline = NN_TRY(convert(deadline));
const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));
return executeInternal(aidlRequest, aidlMeasure, aidlDeadline, aidlLoopTimeoutDuration,
relocation);
return executeInternal(aidlRequest, aidlMeasure, aidlDeadline, aidlLoopTimeoutDuration, hints,
extensionNameToPrefix, relocation);
}
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
PreparedModel::executeInternal(const Request& request, bool measure, int64_t deadline,
int64_t loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const {
if (relocation.input) {
relocation.input->flush();
}
ExecutionResult executionResult;
const auto ret = kPreparedModel->executeSynchronously(request, measure, deadline,
loopTimeoutDuration, &executionResult);
HANDLE_ASTATUS(ret) << "executeSynchronously failed";
if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
auto aidlHints = NN_TRY(convert(hints));
auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
const auto ret = kPreparedModel->executeSynchronouslyWithConfig(
request,
{measure, loopTimeoutDuration, std::move(aidlHints),
std::move(aidlExtensionPrefix)},
deadline, &executionResult);
HANDLE_ASTATUS(ret) << "executeSynchronouslyWithConfig failed";
} else {
const auto ret = kPreparedModel->executeSynchronously(
request, measure, deadline, loopTimeoutDuration, &executionResult);
HANDLE_ASTATUS(ret) << "executeSynchronously failed";
}
return handleExecutionResult(executionResult, relocation);
}
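On the FL8 path the per-call knobs travel in a single ExecutionConfig parcelable rather than as separate arguments. A rough, self-contained sketch of its shape; the field names are inferred from the call sites above, not copied from the generated AIDL header:

#include <cstdint>
#include <string>
#include <vector>

struct TokenValuePair {
    int32_t token;
    std::vector<uint8_t> value;
};
struct ExtensionNameAndPrefix {
    std::string name;
    uint16_t prefix;
};

// Approximate shape of the FL8 config bundle.
struct ExecutionConfig {
    bool measureTiming;
    int64_t loopTimeoutDurationNs;
    std::vector<TokenValuePair> executionHints;
    std::vector<ExtensionNameAndPrefix> extensionNameToPrefix;
};

int main() {
    // Call sites aggregate-initialize it in argument position, e.g.
    // executeSynchronouslyWithConfig(request, {measure, timeout, hints, prefixes}, ...).
    ExecutionConfig config{/*measureTiming=*/true, /*loopTimeoutDurationNs=*/1000, {}, {}};
    return config.measureTiming ? 0 : 1;
}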
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
PreparedModel::executeFenced(const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const {
PreparedModel::executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -179,31 +196,45 @@ PreparedModel::executeFenced(const nn::Request& request, const std::vector<nn::S
const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));
const auto aidlTimeoutDurationAfterFence = NN_TRY(convert(timeoutDurationAfterFence));
return executeFencedInternal(aidlRequest, aidlWaitFor, aidlMeasure, aidlDeadline,
aidlLoopTimeoutDuration, aidlTimeoutDurationAfterFence,
relocation);
aidlLoopTimeoutDuration, aidlTimeoutDurationAfterFence, hints,
extensionNameToPrefix, relocation);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
PreparedModel::executeFencedInternal(const Request& request,
const std::vector<ndk::ScopedFileDescriptor>& waitFor,
bool measure, int64_t deadline, int64_t loopTimeoutDuration,
int64_t timeoutDurationAfterFence,
const hal::utils::RequestRelocation& relocation) const {
PreparedModel::executeFencedInternal(
const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor, bool measure,
int64_t deadline, int64_t loopTimeoutDuration, int64_t timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const {
if (relocation.input) {
relocation.input->flush();
}
FencedExecutionResult result;
const auto ret =
kPreparedModel->executeFenced(request, waitFor, measure, deadline, loopTimeoutDuration,
timeoutDurationAfterFence, &result);
HANDLE_ASTATUS(ret) << "executeFenced failed";
if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
auto aidlHints = NN_TRY(convert(hints));
auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
const auto ret = kPreparedModel->executeFencedWithConfig(
request, waitFor,
{measure, loopTimeoutDuration, std::move(aidlHints),
std::move(aidlExtensionPrefix)},
deadline, timeoutDurationAfterFence, &result);
HANDLE_ASTATUS(ret) << "executeFencedWithConfig failed";
} else {
const auto ret = kPreparedModel->executeFenced(request, waitFor, measure, deadline,
loopTimeoutDuration,
timeoutDurationAfterFence, &result);
HANDLE_ASTATUS(ret) << "executeFenced failed";
}
return handleFencedExecutionResult(result, relocation);
}
nn::GeneralResult<nn::SharedExecution> PreparedModel::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -217,8 +248,14 @@ nn::GeneralResult<nn::SharedExecution> PreparedModel::createReusableExecution(
if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
std::shared_ptr<IExecution> execution;
auto aidlHints = NN_TRY(convert(hints));
auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
const auto ret = kPreparedModel->createReusableExecution(
aidlRequest, aidlMeasure, aidlLoopTimeoutDuration, &execution);
aidlRequest,
{aidlMeasure, aidlLoopTimeoutDuration, std::move(aidlHints),
std::move(aidlExtensionPrefix)},
&execution);
HANDLE_ASTATUS(ret) << "createReusableExecution failed";
return Execution::create(std::move(execution), std::move(relocation));
}
@@ -232,7 +269,7 @@ nn::GeneralResult<nn::SharedBurst> PreparedModel::configureExecutionBurst() cons
std::shared_ptr<IBurst> burst;
const auto ret = kPreparedModel->configureExecutionBurst(&burst);
HANDLE_ASTATUS(ret) << "configureExecutionBurst failed";
return Burst::create(std::move(burst));
return Burst::create(std::move(burst), kFeatureLevel);
}
std::any PreparedModel::getUnderlyingResource() const {

View File

@@ -61,7 +61,6 @@ constexpr PerformanceInfo kNoPerformanceInfo = {.execTime = std::numeric_limits<
.powerUsage = std::numeric_limits<float>::max()};
constexpr NumberOfCacheFiles kNumberOfCacheFiles = {.numModelCache = nn::kMaxNumberOfCacheFiles - 1,
.numDataCache = nn::kMaxNumberOfCacheFiles};
constexpr auto makeStatusOk = [] { return ndk::ScopedAStatus::ok(); };
std::shared_ptr<MockDevice> createMockDevice() {
@@ -124,6 +123,18 @@ auto makePreparedModelReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
};
}
const std::vector<nn::TokenValuePair> kHints = {nn::TokenValuePair{.token = 0, .value = {1}}};
const std::vector<nn::ExtensionNameAndPrefix> kExtensionNameToPrefix = {
nn::ExtensionNameAndPrefix{.name = "com.android.nn_test", .prefix = 1}};
auto makePreparedModelWithConfigReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
const std::shared_ptr<MockPreparedModel>& preparedModel) {
return [launchStatus, returnStatus, preparedModel](
const Model& /*model*/, const PrepareModelConfig& /*config*/,
const std::shared_ptr<IPreparedModelCallback>& cb) -> ndk::ScopedAStatus {
return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
};
}
auto makePreparedModelFromCacheReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
const std::shared_ptr<MockPreparedModel>& preparedModel) {
return [launchStatus, returnStatus, preparedModel](
@@ -560,6 +571,8 @@ TEST_P(DeviceTest, getSupportedOperationsDeadObject) {
}
TEST_P(DeviceTest, prepareModel) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -571,7 +584,7 @@ TEST_P(DeviceTest, prepareModel) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -580,6 +593,8 @@ TEST_P(DeviceTest, prepareModel) {
}
TEST_P(DeviceTest, prepareModelLaunchError) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -590,7 +605,7 @@ TEST_P(DeviceTest, prepareModelLaunchError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -598,6 +613,8 @@ TEST_P(DeviceTest, prepareModelLaunchError) {
}
TEST_P(DeviceTest, prepareModelReturnError) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -608,7 +625,7 @@ TEST_P(DeviceTest, prepareModelReturnError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -616,6 +633,8 @@ TEST_P(DeviceTest, prepareModelReturnError) {
}
TEST_P(DeviceTest, prepareModelNullptrError) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -626,7 +645,7 @@ TEST_P(DeviceTest, prepareModelNullptrError) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -634,6 +653,8 @@ TEST_P(DeviceTest, prepareModelNullptrError) {
}
TEST_P(DeviceTest, prepareModelTransportFailure) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -643,7 +664,7 @@ TEST_P(DeviceTest, prepareModelTransportFailure) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -651,6 +672,8 @@ TEST_P(DeviceTest, prepareModelTransportFailure) {
}
TEST_P(DeviceTest, prepareModelDeadObject) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -660,7 +683,7 @@ TEST_P(DeviceTest, prepareModelDeadObject) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -668,6 +691,8 @@ TEST_P(DeviceTest, prepareModelDeadObject) {
}
TEST_P(DeviceTest, prepareModelAsyncCrash) {
if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
// setup test
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -681,7 +706,157 @@ TEST_P(DeviceTest, prepareModelAsyncCrash) {
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
TEST_P(DeviceTest, prepareModelWithConfig) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
const auto mockPreparedModel = MockPreparedModel::create();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(Invoke(makePreparedModelWithConfigReturn(ErrorStatus::NONE, ErrorStatus::NONE,
mockPreparedModel)));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_TRUE(result.has_value())
<< "Failed with " << result.error().code << ": " << result.error().message;
EXPECT_NE(result.value(), nullptr);
}
TEST_P(DeviceTest, prepareModelWithConfigLaunchError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(Invoke(makePreparedModelWithConfigReturn(
ErrorStatus::GENERAL_FAILURE, ErrorStatus::GENERAL_FAILURE, nullptr)));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(DeviceTest, prepareModelWithConfigReturnError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(Invoke(makePreparedModelWithConfigReturn(
ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE, nullptr)));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(DeviceTest, prepareModelWithConfigNullptrError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(Invoke(makePreparedModelWithConfigReturn(ErrorStatus::NONE, ErrorStatus::NONE,
nullptr)));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(DeviceTest, prepareModelWithConfigTransportFailure) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(DeviceTest, prepareModelWithConfigDeadObject) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
TEST_P(DeviceTest, prepareModelWithConfigAsyncCrash) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
const auto ret = [&device]() {
DeathMonitor::serviceDied(device->getDeathMonitor());
return ndk::ScopedAStatus::ok();
};
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(ret));
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -32,6 +32,10 @@ class MockBurst final : public BnBurst {
bool measureTiming, int64_t deadline, int64_t loopTimeoutDuration,
ExecutionResult* executionResult),
(override));
MOCK_METHOD(ndk::ScopedAStatus, executeSynchronouslyWithConfig,
(const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
const ExecutionConfig& config, int64_t deadline, ExecutionResult* executionResult),
(override));
MOCK_METHOD(ndk::ScopedAStatus, releaseMemoryResource, (int64_t memoryIdentifierToken),
(override));
};

View File

@@ -50,6 +50,10 @@ class MockDevice final : public BnDevice {
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback),
(override));
MOCK_METHOD(ndk::ScopedAStatus, prepareModelWithConfig,
(const Model& model, const PrepareModelConfig& config,
const std::shared_ptr<IPreparedModelCallback>& callback),
(override));
MOCK_METHOD(ndk::ScopedAStatus, prepareModelFromCache,
(int64_t deadline, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
const std::vector<ndk::ScopedFileDescriptor>& dataCache,

View File

@@ -40,10 +40,19 @@ class MockPreparedModel final : public BnPreparedModel {
bool measureTiming, int64_t deadline, int64_t loopTimeoutDuration,
int64_t duration, FencedExecutionResult* fencedExecutionResult),
(override));
MOCK_METHOD(ndk::ScopedAStatus, executeSynchronouslyWithConfig,
(const Request& request, const ExecutionConfig& config, int64_t deadline,
ExecutionResult* executionResult),
(override));
MOCK_METHOD(ndk::ScopedAStatus, executeFencedWithConfig,
(const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
const ExecutionConfig& config, int64_t deadline, int64_t duration,
FencedExecutionResult* fencedExecutionResult),
(override));
MOCK_METHOD(ndk::ScopedAStatus, configureExecutionBurst, (std::shared_ptr<IBurst> * burst),
(override));
MOCK_METHOD(ndk::ScopedAStatus, createReusableExecution,
(const Request& request, bool measureTiming, int64_t loopTimeoutDuration,
(const Request& request, const ExecutionConfig& config,
std::shared_ptr<IExecution>* execution),
(override));
};
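gMock aside: MOCK_METHOD parses its arguments with the preprocessor, so any type containing a comma must be wrapped in an extra set of parentheses, which is why every argument list in the mocks above is parenthesized. A minimal compilable example (Clock is an illustrative type, not part of this change):

#include <utility>
#include <gmock/gmock.h>

class Clock {
  public:
    virtual ~Clock() = default;
    virtual std::pair<int, int> now() const = 0;
};

class MockClock : public Clock {
  public:
    // The return type contains a comma, so it must be parenthesized.
    MOCK_METHOD((std::pair<int, int>), now, (), (const, override));
};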

View File

@@ -70,6 +70,21 @@ auto makeFencedExecutionResult(const std::shared_ptr<MockFencedExecutionCallback
class PreparedModelTest : public VersionedAidlUtilsTestBase {};
const std::vector<nn::TokenValuePair> kHints = {nn::TokenValuePair{.token = 0, .value = {1}}};
const std::vector<nn::ExtensionNameAndPrefix> kExtensionNameToPrefix = {
nn::ExtensionNameAndPrefix{.name = "com.android.nn_test", .prefix = 1}};
auto makeFencedExecutionWithConfigResult(
const std::shared_ptr<MockFencedExecutionCallback>& callback) {
return [callback](const Request& /*request*/,
const std::vector<ndk::ScopedFileDescriptor>& /*waitFor*/,
const ExecutionConfig& /*config*/, int64_t /*deadline*/, int64_t /*duration*/,
FencedExecutionResult* fencedExecutionResult) {
*fencedExecutionResult = FencedExecutionResult{.callback = callback,
.syncFence = ndk::ScopedFileDescriptor(-1)};
return ndk::ScopedAStatus::ok();
};
}
} // namespace
TEST_P(PreparedModelTest, invalidPreparedModel) {
@@ -82,6 +97,8 @@ TEST_P(PreparedModelTest, invalidPreparedModel) {
}
TEST_P(PreparedModelTest, executeSync) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -96,7 +113,7 @@ TEST_P(PreparedModelTest, executeSync) {
DoAll(SetArgPointee<4>(mockExecutionResult), InvokeWithoutArgs(makeStatusOk)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -104,6 +121,8 @@ TEST_P(PreparedModelTest, executeSync) {
}
TEST_P(PreparedModelTest, executeSyncError) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -112,7 +131,7 @@ TEST_P(PreparedModelTest, executeSyncError) {
.WillOnce(Invoke(makeGeneralFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -120,6 +139,8 @@ TEST_P(PreparedModelTest, executeSyncError) {
}
TEST_P(PreparedModelTest, executeSyncTransportFailure) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -128,7 +149,7 @@ TEST_P(PreparedModelTest, executeSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -136,6 +157,8 @@ TEST_P(PreparedModelTest, executeSyncTransportFailure) {
}
TEST_P(PreparedModelTest, executeSyncDeadObject) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -144,7 +167,7 @@ TEST_P(PreparedModelTest, executeSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -152,6 +175,8 @@ TEST_P(PreparedModelTest, executeSyncDeadObject) {
}
TEST_P(PreparedModelTest, executeFenced) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -165,7 +190,7 @@ TEST_P(PreparedModelTest, executeFenced) {
.WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -181,6 +206,8 @@ TEST_P(PreparedModelTest, executeFenced) {
}
TEST_P(PreparedModelTest, executeFencedCallbackError) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -195,7 +222,7 @@ TEST_P(PreparedModelTest, executeFencedCallbackError) {
.WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -211,6 +238,8 @@ TEST_P(PreparedModelTest, executeFencedCallbackError) {
}
TEST_P(PreparedModelTest, executeFencedError) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -219,7 +248,7 @@ TEST_P(PreparedModelTest, executeFencedError) {
.WillOnce(InvokeWithoutArgs(makeGeneralFailure));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -227,6 +256,8 @@ TEST_P(PreparedModelTest, executeFencedError) {
}
TEST_P(PreparedModelTest, executeFencedTransportFailure) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -235,7 +266,7 @@ TEST_P(PreparedModelTest, executeFencedTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -243,6 +274,8 @@ TEST_P(PreparedModelTest, executeFencedTransportFailure) {
}
TEST_P(PreparedModelTest, executeFencedDeadObject) {
if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -251,7 +284,7 @@ TEST_P(PreparedModelTest, executeFencedDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -276,7 +309,7 @@ TEST_P(PreparedModelTest, reusableExecuteSync) {
DoAll(SetArgPointee<4>(mockExecutionResult), InvokeWithoutArgs(makeStatusOk)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -300,7 +333,7 @@ TEST_P(PreparedModelTest, reusableExecuteSyncError) {
.WillOnce(Invoke(makeGeneralFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -322,7 +355,7 @@ TEST_P(PreparedModelTest, reusableExecuteSyncTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -344,7 +377,7 @@ TEST_P(PreparedModelTest, reusableExecuteSyncDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -372,7 +405,7 @@ TEST_P(PreparedModelTest, reusableExecuteFenced) {
.WillRepeatedly(Invoke(makeFencedExecutionResult(mockCallback)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -410,7 +443,7 @@ TEST_P(PreparedModelTest, reusableExecuteFencedCallbackError) {
.WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -440,7 +473,7 @@ TEST_P(PreparedModelTest, reusableExecuteFencedError) {
.WillOnce(InvokeWithoutArgs(makeGeneralFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -462,7 +495,7 @@ TEST_P(PreparedModelTest, reusableExecuteFencedTransportFailure) {
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -484,7 +517,7 @@ TEST_P(PreparedModelTest, reusableExecuteFencedDeadObject) {
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
const auto createResult = preparedModel->createReusableExecution({}, {}, {});
const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -495,6 +528,206 @@ TEST_P(PreparedModelTest, reusableExecuteFencedDeadObject) {
EXPECT_EQ(computeResult.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
TEST_P(PreparedModelTest, executeSyncWithConfig) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
const auto mockExecutionResult = ExecutionResult{
.outputSufficientSize = true,
.outputShapes = {},
.timing = kNoTiming,
};
EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
.Times(1)
.WillOnce(
DoAll(SetArgPointee<3>(mockExecutionResult), InvokeWithoutArgs(makeStatusOk)));
// run test
const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
EXPECT_TRUE(result.has_value())
<< "Failed with " << result.error().code << ": " << result.error().message;
}
TEST_P(PreparedModelTest, executeSyncWithConfigError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
.Times(1)
.WillOnce(Invoke(makeGeneralFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(PreparedModelTest, executeSyncWithConfigTransportFailure) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(PreparedModelTest, executeSyncWithConfigDeadObject) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
TEST_P(PreparedModelTest, executeFencedWithConfig) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
const auto mockCallback = MockFencedExecutionCallback::create();
EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
.Times(1)
.WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
SetArgPointee<2>(ErrorStatus::NONE), Invoke(makeStatusOk)));
EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
.Times(1)
.WillOnce(Invoke(makeFencedExecutionWithConfigResult(mockCallback)));
// run test
const auto result =
preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_TRUE(result.has_value())
<< "Failed with " << result.error().code << ": " << result.error().message;
const auto& [syncFence, callback] = result.value();
EXPECT_EQ(syncFence.syncWait({}), nn::SyncFence::FenceState::SIGNALED);
ASSERT_NE(callback, nullptr);
// get results from callback
const auto callbackResult = callback();
ASSERT_TRUE(callbackResult.has_value()) << "Failed with " << callbackResult.error().code << ": "
<< callbackResult.error().message;
}
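Note that the mock returns ScopedFileDescriptor(-1), yet the test expects syncWait to report SIGNALED. That relies on the usual sync-fence convention, assumed here, that a negative fd means "no fence" and therefore nothing to wait on:

#include <cassert>

enum class FenceState { ACTIVE, SIGNALED, ERROR };

// Sketch: a negative fd stands for an already-signalled fence.
FenceState syncWait(int fd) {
    if (fd < 0) return FenceState::SIGNALED;
    // A real implementation would poll(2) the sync fd here.
    return FenceState::ACTIVE;
}

int main() {
    assert(syncWait(-1) == FenceState::SIGNALED);
    return 0;
}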
TEST_P(PreparedModelTest, executeFencedWithConfigCallbackError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
const auto mockCallback = MockFencedExecutionCallback::create();
EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
.Times(1)
.WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
                SetArgPointee<2>(ErrorStatus::GENERAL_FAILURE),
                InvokeWithoutArgs(makeStatusOk)));
EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
.Times(1)
.WillOnce(Invoke(makeFencedExecutionWithConfigResult(mockCallback)));
// run test
const auto result =
preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_TRUE(result.has_value())
<< "Failed with " << result.error().code << ": " << result.error().message;
const auto& [syncFence, callback] = result.value();
EXPECT_NE(syncFence.syncWait({}), nn::SyncFence::FenceState::ACTIVE);
ASSERT_NE(callback, nullptr);
// verify callback failure
const auto callbackResult = callback();
ASSERT_FALSE(callbackResult.has_value());
EXPECT_EQ(callbackResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(PreparedModelTest, executeFencedWithConfigError) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralFailure));
// run test
const auto result =
preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(PreparedModelTest, executeFencedWithConfigTransportFailure) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
const auto result =
preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
TEST_P(PreparedModelTest, executeFencedWithConfigDeadObject) {
if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
const auto result =
preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
TEST_P(PreparedModelTest, configureExecutionBurst) {
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
@@ -567,13 +800,13 @@ TEST_P(PreparedModelTest, createReusableExecution) {
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto mockExecution = ndk::SharedRefBase::make<MockExecution>();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
.WillOnce(DoAll(SetArgPointee<3>(mockExecution), Invoke(makeStatusOk)));
.WillOnce(DoAll(SetArgPointee<2>(mockExecution), Invoke(makeStatusOk)));
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -586,13 +819,13 @@ TEST_P(PreparedModelTest, createReusableExecutionError) {
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralFailure));
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -604,13 +837,13 @@ TEST_P(PreparedModelTest, createReusableExecutionTransportFailure) {
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -622,13 +855,13 @@ TEST_P(PreparedModelTest, createReusableExecutionDeadObject) {
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());

View File

@@ -63,6 +63,8 @@ struct TestConfig {
// it is skipped. The field is set to true by default and is set to false in
// quantization coupling tests to suppress the skip message.
bool reportSkipping;
// `useConfig` indicates whether a test should use the execute*WithConfig functions for the execution.
bool useConfig;
TestConfig(Executor executor, bool measureTiming, OutputType outputType, MemoryType memoryType,
bool reusable)
: executor(executor),
@@ -70,7 +72,8 @@ struct TestConfig {
outputType(outputType),
memoryType(memoryType),
reusable(reusable),
reportSkipping(true) {}
reportSkipping(true),
useConfig(false) {}
TestConfig(Executor executor, bool measureTiming, OutputType outputType, MemoryType memoryType,
bool reusable, bool reportSkipping)
: executor(executor),
@@ -78,7 +81,17 @@ struct TestConfig {
outputType(outputType),
memoryType(memoryType),
reusable(reusable),
reportSkipping(reportSkipping) {}
reportSkipping(reportSkipping),
useConfig(false) {}
TestConfig(Executor executor, bool measureTiming, OutputType outputType, MemoryType memoryType,
bool reusable, bool reportSkipping, bool useConfig)
: executor(executor),
measureTiming(measureTiming),
outputType(outputType),
memoryType(memoryType),
reusable(reusable),
reportSkipping(reportSkipping),
useConfig(useConfig) {}
};
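Design note: the three constructors differ only in trailing defaults, so default member initializers plus aggregate initialization could express the same thing more compactly. A sketch (the enumerators here are illustrative stand-ins, not the harness's real ones):

#include <cstdint>

enum class Executor { SYNC, BURST, FENCED };
enum class OutputType { FULLY_SPECIFIED, UNSPECIFIED, INSUFFICIENT };
enum class MemoryType { ASHMEM, BLOB_AHWB, DEVICE };

struct TestConfig {
    Executor executor;
    bool measureTiming;
    OutputType outputType;
    MemoryType memoryType;
    bool reusable;
    bool reportSkipping = true;  // defaults stand in for the two legacy ctors
    bool useConfig = false;
};

int main() {
    // Aggregate initialization covers all three constructor signatures above.
    TestConfig config{Executor::SYNC, /*measureTiming=*/true,
                      OutputType::FULLY_SPECIFIED, MemoryType::ASHMEM,
                      /*reusable=*/false};
    return config.reportSkipping && !config.useConfig ? 0 : 1;
}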
std::string toString(OutputType type) {
@@ -100,7 +113,8 @@ std::string toString(const TestConfig& config) {
<< ", .measureTiming=" << (config.measureTiming ? "true" : "false")
<< ", .outputType=" << toString(config.outputType)
<< ", .memoryType=" << toString(config.memoryType)
<< ", .reusable=" << (config.reusable ? "true" : "false") << "}";
<< ", .reusable=" << (config.reusable ? "true" : "false")
<< ", .useConfig=" << (config.useConfig ? "true" : "false") << "}";
return ss.str();
}
@@ -587,8 +601,8 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
std::shared_ptr<IExecution> execution;
if (testConfig.reusable) {
const auto ret = preparedModel->createReusableExecution(request, testConfig.measureTiming,
loopTimeoutDurationNs, &execution);
const auto ret = preparedModel->createReusableExecution(
request, {testConfig.measureTiming, loopTimeoutDurationNs, {}, {}}, &execution);
ASSERT_TRUE(ret.isOk()) << static_cast<nn::ErrorStatus>(ret.getServiceSpecificError());
ASSERT_NE(nullptr, execution.get());
}
@@ -607,6 +621,10 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
::ndk::ScopedAStatus ret;
if (testConfig.reusable) {
ret = execution->executeSynchronously(kNoDeadline, &executionResult);
} else if (testConfig.useConfig) {
ret = preparedModel->executeSynchronouslyWithConfig(
request, {testConfig.measureTiming, loopTimeoutDurationNs, {}, {}},
kNoDeadline, &executionResult);
} else {
ret = preparedModel->executeSynchronously(request, testConfig.measureTiming,
kNoDeadline, loopTimeoutDurationNs,
@@ -649,9 +667,16 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
ExecutionResult executionResult;
// execute
ret = burst->executeSynchronously(request, slots, testConfig.measureTiming,
kNoDeadline, loopTimeoutDurationNs,
&executionResult);
if (testConfig.useConfig) {
ret = burst->executeSynchronouslyWithConfig(
request, slots,
{testConfig.measureTiming, loopTimeoutDurationNs, {}, {}}, kNoDeadline,
&executionResult);
} else {
ret = burst->executeSynchronously(request, slots, testConfig.measureTiming,
kNoDeadline, loopTimeoutDurationNs,
&executionResult);
}
ASSERT_TRUE(ret.isOk() || ret.getExceptionCode() == EX_SERVICE_SPECIFIC)
<< ret.getDescription();
if (ret.isOk()) {
@@ -680,6 +705,10 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
::ndk::ScopedAStatus ret;
if (testConfig.reusable) {
ret = execution->executeFenced({}, kNoDeadline, kNoDuration, &executionResult);
} else if (testConfig.useConfig) {
ret = preparedModel->executeFencedWithConfig(
request, {}, {testConfig.measureTiming, loopTimeoutDurationNs, {}, {}},
kNoDeadline, kNoDuration, &executionResult);
} else {
ret = preparedModel->executeFenced(request, {}, testConfig.measureTiming,
kNoDeadline, loopTimeoutDurationNs,
@@ -697,9 +726,19 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
waitFor.emplace_back(dupFd);
// If a sync fence is returned, try to start another run waiting for the sync
// fence.
ret = preparedModel->executeFenced(request, waitFor, testConfig.measureTiming,
kNoDeadline, loopTimeoutDurationNs,
kNoDuration, &executionResult);
if (testConfig.reusable) {
ret = execution->executeFenced(waitFor, kNoDeadline, kNoDuration,
&executionResult);
} else if (testConfig.useConfig) {
ret = preparedModel->executeFencedWithConfig(
request, waitFor,
{testConfig.measureTiming, loopTimeoutDurationNs, {}, {}},
kNoDeadline, kNoDuration, &executionResult);
} else {
ret = preparedModel->executeFenced(
request, waitFor, testConfig.measureTiming, kNoDeadline,
loopTimeoutDurationNs, kNoDuration, &executionResult);
}
ASSERT_TRUE(ret.isOk());
waitForSyncFence(executionResult.syncFence.get());
}
@@ -830,11 +869,13 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
std::vector<Executor> executorList;
std::vector<MemoryType> memoryTypeList;
std::vector<bool> reusableList = {false};
std::vector<bool> useConfigList = {false};
int32_t deviceVersion;
ASSERT_TRUE(device->getInterfaceVersion(&deviceVersion).isOk());
if (deviceVersion >= kMinAidlLevelForFL8) {
reusableList.push_back(true);
useConfigList.push_back(true);
}
switch (testKind) {
@@ -879,11 +920,14 @@ void EvaluatePreparedModel(const std::shared_ptr<IDevice>& device,
for (const Executor executor : executorList) {
for (const MemoryType memoryType : memoryTypeList) {
for (const bool reusable : reusableList) {
if (executor == Executor::BURST && reusable) continue;
const TestConfig testConfig(executor, measureTiming, outputType, memoryType,
reusable);
SCOPED_TRACE(toString(testConfig));
EvaluatePreparedModel(device, preparedModel, testModel, testConfig);
for (const bool useConfig : useConfigList) {
if ((useConfig || executor == Executor::BURST) && reusable) continue;
const TestConfig testConfig(executor, measureTiming, outputType,
memoryType, reusable,
/*reportSkipping=*/true, useConfig);
SCOPED_TRACE(toString(testConfig));
EvaluatePreparedModel(device, preparedModel, testModel, testConfig);
}
}
}
}
@@ -942,6 +986,13 @@ void Execute(const std::shared_ptr<IDevice>& device, const TestModel& testModel,
createPreparedModel(device, model, &preparedModel);
if (preparedModel == nullptr) return;
EvaluatePreparedModel(device, preparedModel, testModel, testKind);
int32_t deviceVersion;
ASSERT_TRUE(device->getInterfaceVersion(&deviceVersion).isOk());
if (deviceVersion >= kMinAidlLevelForFL8) {
createPreparedModel(device, model, &preparedModel, /*reportSkipping*/ true,
/*useConfig*/ true);
EvaluatePreparedModel(device, preparedModel, testModel, testKind);
}
} break;
case TestKind::QUANTIZATION_COUPLING: {
ASSERT_TRUE(testModel.hasQuant8CoupledOperands());

View File

@@ -204,11 +204,23 @@ class InvalidPreparedModel : public BnPreparedModel {
return ndk::ScopedAStatus::fromServiceSpecificError(
static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
}
ndk::ScopedAStatus executeSynchronouslyWithConfig(const Request&, const ExecutionConfig&,
int64_t, ExecutionResult*) override {
return ndk::ScopedAStatus::fromServiceSpecificError(
static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
}
ndk::ScopedAStatus executeFencedWithConfig(const Request&,
const std::vector<ndk::ScopedFileDescriptor>&,
const ExecutionConfig&, int64_t, int64_t,
FencedExecutionResult*) override {
return ndk::ScopedAStatus::fromServiceSpecificError(
static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
}
ndk::ScopedAStatus configureExecutionBurst(std::shared_ptr<IBurst>*) override {
return ndk::ScopedAStatus::fromServiceSpecificError(
static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
}
ndk::ScopedAStatus createReusableExecution(const aidl_hal::Request&, bool, int64_t,
ndk::ScopedAStatus createReusableExecution(const aidl_hal::Request&, const ExecutionConfig&,
std::shared_ptr<aidl_hal::IExecution>*) override {
return ndk::ScopedAStatus::fromServiceSpecificError(
static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));

View File

@@ -77,6 +77,28 @@ static void validatePrepareModel(const std::shared_ptr<IDevice>& device, const s
ASSERT_EQ(nullptr, preparedModel.get());
}
static void validatePrepareModelWithConfig(const std::shared_ptr<IDevice>& device,
const std::string& message, const Model& model,
ExecutionPreference preference, Priority priority) {
SCOPED_TRACE(message + " [prepareModelWithConfig]");
std::shared_ptr<PreparedModelCallback> preparedModelCallback =
ndk::SharedRefBase::make<PreparedModelCallback>();
const auto prepareLaunchStatus = device->prepareModelWithConfig(
model, {preference, priority, kNoDeadline, {}, {}, kEmptyCacheToken, {}, {}},
preparedModelCallback);
ASSERT_FALSE(prepareLaunchStatus.isOk());
ASSERT_EQ(prepareLaunchStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus.getServiceSpecificError()),
ErrorStatus::INVALID_ARGUMENT);
preparedModelCallback->wait();
ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
std::shared_ptr<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
ASSERT_EQ(nullptr, preparedModel.get());
}
static bool validExecutionPreference(ExecutionPreference preference) {
return preference == ExecutionPreference::LOW_POWER ||
preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
@@ -103,6 +125,13 @@ static void validate(const std::shared_ptr<IDevice>& device, const std::string&
}
validatePrepareModel(device, message, model, preference, priority);
int32_t aidlVersion;
ASSERT_TRUE(device->getInterfaceVersion(&aidlVersion).isOk());
if (aidlVersion >= kMinAidlLevelForFL8) {
// prepareModelWithConfig must satisfy all requirements enforced by prepareModel.
validatePrepareModelWithConfig(device, message, model, preference, priority);
}
}
static uint32_t addOperand(Model* model) {

View File

@@ -45,7 +45,7 @@ static void validateReusableExecution(const std::shared_ptr<IPreparedModel>& pre
{
SCOPED_TRACE(message + " [createReusableExecution]");
const auto createStatus = preparedModel->createReusableExecution(
request, measure, kOmittedTimeoutDuration, &execution);
request, {measure, kOmittedTimeoutDuration, {}, {}}, &execution);
if (!createStatus.isOk()) {
ASSERT_EQ(createStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
ASSERT_EQ(static_cast<ErrorStatus>(createStatus.getServiceSpecificError()),
@@ -149,10 +149,59 @@ static void validate(const std::shared_ptr<IPreparedModel>& preparedModel,
int32_t aidlVersion;
ASSERT_TRUE(preparedModel->getInterfaceVersion(&aidlVersion).isOk());
if (aidlVersion < kMinAidlLevelForFL8) {
return;
}
// validate reusable execution
if (aidlVersion >= kMinAidlLevelForFL8) {
validateReusableExecution(preparedModel, message, request, measure);
validateReusableExecution(preparedModel, message, request, measure);
// synchronous with empty hints
{
SCOPED_TRACE(message + " [executeSynchronouslyWithConfig]");
ExecutionResult executionResult;
const auto executeStatus = preparedModel->executeSynchronouslyWithConfig(
request, {measure, kOmittedTimeoutDuration, {}, {}}, kNoDeadline, &executionResult);
ASSERT_FALSE(executeStatus.isOk());
ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
ErrorStatus::INVALID_ARGUMENT);
}
// fenced with empty hints
{
SCOPED_TRACE(message + " [executeFencedWithConfig]");
FencedExecutionResult executionResult;
const auto executeStatus = preparedModel->executeFencedWithConfig(
request, {}, {false, kOmittedTimeoutDuration, {}, {}}, kNoDeadline, kNoDuration,
&executionResult);
ASSERT_FALSE(executeStatus.isOk());
ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
ErrorStatus::INVALID_ARGUMENT);
}
// burst with empty hints
{
SCOPED_TRACE(message + " [burst executeSynchronouslyWithConfig]");
// create burst
std::shared_ptr<IBurst> burst;
auto ret = preparedModel->configureExecutionBurst(&burst);
ASSERT_TRUE(ret.isOk()) << ret.getDescription();
ASSERT_NE(nullptr, burst.get());
        // use -1 (no identity, i.e. the memory is not cached) for all memory identifier tokens
const std::vector<int64_t> slots(request.pools.size(), -1);
ExecutionResult executionResult;
const auto executeStatus = burst->executeSynchronouslyWithConfig(
request, slots, {measure, kOmittedTimeoutDuration, {}, {}}, kNoDeadline,
&executionResult);
ASSERT_FALSE(executeStatus.isOk());
ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
ErrorStatus::INVALID_ARGUMENT);
}
}
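For orientation, the aggregate initializers above pin down the ExecutionConfig field order: {measureTiming, loopTimeoutDurationNs, executionHints, extensionNameToPrefix}. A minimal sketch of the same call with non-empty hints follows; the TokenValuePair and ExtensionNameAndPrefix member names are taken from the AIDL parcelables rather than from this diff, and the token, payload, and extension name are purely hypothetical.

// Sketch only, not part of this change: an ExecutionConfig carrying one
// execution hint and one extension-name-to-prefix entry.
TokenValuePair hint;
hint.token = 0x10000001;         // hypothetical extension-scoped token
hint.value = {0x01};             // opaque payload interpreted by the driver
ExtensionNameAndPrefix entry;
entry.name = "com.example.ext";  // hypothetical extension name
entry.prefix = 0x1000;           // hypothetical 16-bit prefix
ExecutionResult executionResult;
const auto executeStatus = preparedModel->executeSynchronouslyWithConfig(
        request, {/*measureTiming=*/false, kOmittedTimeoutDuration, {hint}, {entry}},
        kNoDeadline, &executionResult);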

View File

@@ -41,7 +41,8 @@ using implementation::PreparedModelCallback;
// internal helper function that prepares a model via prepareModel or, when useConfig is set, prepareModelWithConfig
void createPreparedModel(const std::shared_ptr<IDevice>& device, const Model& model,
std::shared_ptr<IPreparedModel>* preparedModel, bool reportSkipping) {
std::shared_ptr<IPreparedModel>* preparedModel, bool reportSkipping,
bool useConfig) {
ASSERT_NE(nullptr, preparedModel);
*preparedModel = nullptr;
@@ -56,11 +57,25 @@ void createPreparedModel(const std::shared_ptr<IDevice>& device, const Model& mo
// launch prepare model
const std::shared_ptr<PreparedModelCallback> preparedModelCallback =
ndk::SharedRefBase::make<PreparedModelCallback>();
const auto prepareLaunchStatus =
device->prepareModel(model, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority,
kNoDeadline, {}, {}, kEmptyCacheToken, preparedModelCallback);
ASSERT_TRUE(prepareLaunchStatus.isOk()) << prepareLaunchStatus.getDescription();
if (useConfig) {
const auto prepareLaunchStatus =
device->prepareModelWithConfig(model,
{ExecutionPreference::FAST_SINGLE_ANSWER,
kDefaultPriority,
kNoDeadline,
{},
{},
kEmptyCacheToken,
{},
{}},
preparedModelCallback);
ASSERT_TRUE(prepareLaunchStatus.isOk()) << prepareLaunchStatus.getDescription();
} else {
const auto prepareLaunchStatus = device->prepareModel(
model, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, kNoDeadline, {},
{}, kEmptyCacheToken, preparedModelCallback);
ASSERT_TRUE(prepareLaunchStatus.isOk()) << prepareLaunchStatus.getDescription();
}
// retrieve prepared model
preparedModelCallback->wait();
const ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
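As a reading aid, the PrepareModelConfig aggregate in the useConfig branch bundles the six legacy prepareModel arguments with the two new hint vectors. Here is the same call sketched with each field labeled; the names match the config.* accesses in the adapter change later in this commit.

// Sketch: identical behavior to the useConfig branch above, fields labeled.
const PrepareModelConfig config{
        ExecutionPreference::FAST_SINGLE_ANSWER,  // preference
        kDefaultPriority,                         // priority
        kNoDeadline,                              // deadlineNs
        {},                                       // modelCache
        {},                                       // dataCache
        kEmptyCacheToken,                         // cacheToken
        {},                                       // compilationHints
        {}};                                      // extensionNameToPrefix
const auto prepareLaunchStatus =
        device->prepareModelWithConfig(model, config, preparedModelCallback);
ASSERT_TRUE(prepareLaunchStatus.isOk()) << prepareLaunchStatus.getDescription();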

View File

@@ -51,8 +51,8 @@ std::string printNeuralNetworksAidlTest(
// Create an IPreparedModel object. If the model cannot be prepared,
// "preparedModel" will be nullptr instead.
void createPreparedModel(const std::shared_ptr<IDevice>& device, const Model& model,
std::shared_ptr<IPreparedModel>* preparedModel,
bool reportSkipping = true);
std::shared_ptr<IPreparedModel>* preparedModel, bool reportSkipping = true,
bool useConfig = false);
enum class Executor { SYNC, BURST, FENCED };

View File

@@ -46,6 +46,11 @@ class Burst : public BnBurst {
bool measureTiming, int64_t deadlineNs,
int64_t loopTimeoutDurationNs,
ExecutionResult* executionResult) override;
ndk::ScopedAStatus executeSynchronouslyWithConfig(
const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
const ExecutionConfig& config, int64_t deadlineNs,
ExecutionResult* executionResult) override;
ndk::ScopedAStatus releaseMemoryResource(int64_t memoryIdentifierToken) override;
class ThreadSafeMemoryCache {

View File

@@ -31,6 +31,7 @@
#include <aidl/android/hardware/neuralnetworks/IPreparedModelParcel.h>
#include <aidl/android/hardware/neuralnetworks/Model.h>
#include <aidl/android/hardware/neuralnetworks/NumberOfCacheFiles.h>
#include <aidl/android/hardware/neuralnetworks/PrepareModelConfig.h>
#include <aidl/android/hardware/neuralnetworks/Priority.h>
#include <android/binder_auto_utils.h>
#include <nnapi/IDevice.h>
@@ -72,6 +73,9 @@ class Device : public BnDevice {
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback) override;
ndk::ScopedAStatus prepareModelWithConfig(
const Model& model, const PrepareModelConfig& config,
const std::shared_ptr<IPreparedModelCallback>& callback) override;
protected:
const ::android::nn::SharedDevice kDevice;

View File

@@ -51,9 +51,17 @@ class PreparedModel : public BnPreparedModel {
int64_t loopTimeoutDurationNs, int64_t durationNs,
FencedExecutionResult* executionResult) override;
ndk::ScopedAStatus configureExecutionBurst(std::shared_ptr<IBurst>* burst) override;
ndk::ScopedAStatus createReusableExecution(const Request& request, bool measureTiming,
int64_t loopTimeoutDurationNs,
ndk::ScopedAStatus createReusableExecution(const Request& request,
const ExecutionConfig& config,
std::shared_ptr<IExecution>* execution) override;
ndk::ScopedAStatus executeSynchronouslyWithConfig(const Request& request,
const ExecutionConfig& config,
int64_t deadlineNs,
ExecutionResult* executionResult) override;
ndk::ScopedAStatus executeFencedWithConfig(
const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
const ExecutionConfig& config, int64_t deadlineNs, int64_t durationNs,
FencedExecutionResult* executionResult) override;
::android::nn::SharedPreparedModel getUnderlyingPreparedModel() const;

View File

@@ -93,7 +93,8 @@ std::vector<nn::IBurst::OptionalCacheHold> ensureAllMemoriesAreCached(
nn::ExecutionResult<ExecutionResult> executeSynchronously(
const nn::IBurst& burst, const Burst::ThreadSafeMemoryCache& cache, const Request& request,
const std::vector<int64_t>& memoryIdentifierTokens, bool measureTiming, int64_t deadlineNs,
int64_t loopTimeoutDurationNs) {
int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints,
const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
if (request.pools.size() != memoryIdentifierTokens.size()) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
<< "request.pools.size() != memoryIdentifierTokens.size()";
@@ -107,11 +108,13 @@ nn::ExecutionResult<ExecutionResult> executeSynchronously(
const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
auto nnHints = NN_TRY(convertInput(hints));
auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
const auto hold = ensureAllMemoriesAreCached(&nnRequest, memoryIdentifierTokens, burst, cache);
const auto result =
burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration);
const auto result = burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration,
nnHints, nnExtensionNameToPrefix);
if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
const auto& [message, code, outputShapes] = result.error();
@@ -155,7 +158,24 @@ ndk::ScopedAStatus Burst::executeSynchronously(const Request& request,
ExecutionResult* executionResult) {
auto result =
adapter::executeSynchronously(*kBurst, kMemoryCache, request, memoryIdentifierTokens,
measureTiming, deadlineNs, loopTimeoutDurationNs);
measureTiming, deadlineNs, loopTimeoutDurationNs, {}, {});
if (!result.has_value()) {
auto [message, code, _] = std::move(result).error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
static_cast<int32_t>(aidlCode), message.c_str());
}
*executionResult = std::move(result).value();
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus Burst::executeSynchronouslyWithConfig(
const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
const ExecutionConfig& config, int64_t deadlineNs, ExecutionResult* executionResult) {
auto result = adapter::executeSynchronously(
*kBurst, kMemoryCache, request, memoryIdentifierTokens, config.measureTiming,
deadlineNs, config.loopTimeoutDurationNs, config.executionHints,
config.extensionNameToPrefix);
if (!result.has_value()) {
auto [message, code, _] = std::move(result).error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
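Both burst entry points now funnel into the same adapter::executeSynchronously helper; the legacy method is just the config-based one with empty hint vectors. A hedged caller-side sketch of that equivalence:

// Sketch: the adapter treats these two calls identically, because the legacy
// entry point forwards empty hint vectors ({}, {}).
burst->executeSynchronously(request, memoryIdentifierTokens, measureTiming,
                            deadlineNs, loopTimeoutDurationNs, &executionResult);
burst->executeSynchronouslyWithConfig(
        request, memoryIdentifierTokens,
        {measureTiming, loopTimeoutDurationNs, /*executionHints=*/{},
         /*extensionNameToPrefix=*/{}},
        deadlineNs, &executionResult);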

View File

@@ -148,13 +148,14 @@ void notify(IPreparedModelCallback* callback, PrepareModelResult result) {
}
}
nn::GeneralResult<void> prepareModel(const nn::SharedDevice& device, const Executor& executor,
const Model& model, ExecutionPreference preference,
Priority priority, int64_t deadlineNs,
const std::vector<ndk::ScopedFileDescriptor>& modelCache,
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback) {
nn::GeneralResult<void> prepareModel(
const nn::SharedDevice& device, const Executor& executor, const Model& model,
ExecutionPreference preference, Priority priority, int64_t deadlineNs,
const std::vector<ndk::ScopedFileDescriptor>& modelCache,
const std::vector<ndk::ScopedFileDescriptor>& dataCache, const std::vector<uint8_t>& token,
const std::vector<TokenValuePair>& hints,
const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
const std::shared_ptr<IPreparedModelCallback>& callback) {
if (callback.get() == nullptr) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid callback";
}
@@ -166,12 +167,16 @@ nn::GeneralResult<void> prepareModel(const nn::SharedDevice& device, const Execu
auto nnModelCache = NN_TRY(convertInput(modelCache));
auto nnDataCache = NN_TRY(convertInput(dataCache));
const auto nnToken = NN_TRY(convertCacheToken(token));
auto nnHints = NN_TRY(convertInput(hints));
auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
Task task = [device, nnModel = std::move(nnModel), nnPreference, nnPriority, nnDeadline,
nnModelCache = std::move(nnModelCache), nnDataCache = std::move(nnDataCache),
nnToken, callback] {
auto result = device->prepareModel(nnModel, nnPreference, nnPriority, nnDeadline,
nnModelCache, nnDataCache, nnToken);
nnToken, nnHints = std::move(nnHints),
nnExtensionNameToPrefix = std::move(nnExtensionNameToPrefix), callback] {
auto result =
device->prepareModel(nnModel, nnPreference, nnPriority, nnDeadline, nnModelCache,
nnDataCache, nnToken, nnHints, nnExtensionNameToPrefix);
notify(callback.get(), std::move(result));
};
executor(std::move(task), nnDeadline);
@@ -273,8 +278,9 @@ ndk::ScopedAStatus Device::prepareModel(const Model& model, ExecutionPreference
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback) {
const auto result = adapter::prepareModel(kDevice, kExecutor, model, preference, priority,
deadlineNs, modelCache, dataCache, token, callback);
const auto result =
adapter::prepareModel(kDevice, kExecutor, model, preference, priority, deadlineNs,
modelCache, dataCache, token, {}, {}, callback);
if (!result.has_value()) {
const auto& [message, code] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
@@ -301,4 +307,21 @@ ndk::ScopedAStatus Device::prepareModelFromCache(
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus Device::prepareModelWithConfig(
const Model& model, const PrepareModelConfig& config,
const std::shared_ptr<IPreparedModelCallback>& callback) {
const auto result = adapter::prepareModel(
kDevice, kExecutor, model, config.preference, config.priority, config.deadlineNs,
config.modelCache, config.dataCache, config.cacheToken, config.compilationHints,
config.extensionNameToPrefix, callback);
if (!result.has_value()) {
const auto& [message, code] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
callback->notify(aidlCode, nullptr);
return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
static_cast<int32_t>(aidlCode), message.c_str());
}
return ndk::ScopedAStatus::ok();
}
} // namespace aidl::android::hardware::neuralnetworks::adapter
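Seen from the canonical interface, the net effect of this file is that nn::IDevice::prepareModel gains two trailing vectors: the legacy AIDL path leaves them empty, while prepareModelWithConfig fills them from the parcelable. A condensed sketch of the two forwarding paths, with the argument conversions elided:

// Legacy AIDL prepareModel -> canonical call with empty hint vectors.
device->prepareModel(nnModel, nnPreference, nnPriority, nnDeadline, nnModelCache,
                     nnDataCache, nnToken, /*hints=*/{}, /*extensionNameToPrefix=*/{});
// prepareModelWithConfig -> canonical call with the converted config fields.
device->prepareModel(nnModel, nnPreference, nnPriority, nnDeadline, nnModelCache,
                     nnDataCache, nnToken, nnHints, nnExtensionNameToPrefix);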

View File

@@ -118,17 +118,20 @@ nn::GeneralResult<nn::OptionalTimePoint> makeOptionalTimePoint(int64_t durationN
return durationNs < 0 ? nn::OptionalTimePoint{} : nn::TimePoint(makeDuration(durationNs));
}
nn::ExecutionResult<ExecutionResult> executeSynchronously(const nn::IPreparedModel& preparedModel,
const Request& request,
bool measureTiming, int64_t deadlineNs,
int64_t loopTimeoutDurationNs) {
nn::ExecutionResult<ExecutionResult> executeSynchronously(
const nn::IPreparedModel& preparedModel, const Request& request, bool measureTiming,
int64_t deadlineNs, int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints,
const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
const auto nnRequest = NN_TRY(convertInput(request));
const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
auto nnHints = NN_TRY(convertInput(hints));
auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
const auto result =
preparedModel.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration);
preparedModel.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration,
nnHints, nnExtensionNameToPrefix);
if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
const auto& [message, code, outputShapes] = result.error();
@@ -147,16 +150,21 @@ nn::ExecutionResult<ExecutionResult> executeSynchronously(const nn::IPreparedMod
nn::GeneralResult<FencedExecutionResult> executeFenced(
const nn::IPreparedModel& preparedModel, const Request& request,
const std::vector<ndk::ScopedFileDescriptor>& waitFor, bool measureTiming,
int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs) {
int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
const std::vector<TokenValuePair>& hints,
const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
const auto nnRequest = NN_TRY(convertInput(request));
const auto nnWaitFor = NN_TRY(convertSyncFences(waitFor));
const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
const auto nnDuration = NN_TRY(makeOptionalDuration(durationNs));
auto nnHints = NN_TRY(convertInput(hints));
auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
auto [syncFence, executeFencedInfoCallback] = NN_TRY(preparedModel.executeFenced(
nnRequest, nnWaitFor, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration, nnDuration));
nnRequest, nnWaitFor, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration, nnDuration,
nnHints, nnExtensionNameToPrefix));
ndk::ScopedFileDescriptor fileDescriptor;
if (syncFence.hasFd()) {
@@ -171,11 +179,16 @@ nn::GeneralResult<FencedExecutionResult> executeFenced(
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::IPreparedModel& preparedModel, const Request& request, bool measureTiming,
int64_t loopTimeoutDurationNs) {
int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints,
const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
const auto nnRequest = NN_TRY(convertInput(request));
const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
return preparedModel.createReusableExecution(nnRequest, nnMeasureTiming, nnLoopTimeoutDuration);
auto nnHints = NN_TRY(convertInput(hints));
auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
return preparedModel.createReusableExecution(nnRequest, nnMeasureTiming, nnLoopTimeoutDuration,
nnHints, nnExtensionNameToPrefix);
}
nn::ExecutionResult<ExecutionResult> executeSynchronously(const nn::IExecution& execution,
@@ -231,7 +244,7 @@ ndk::ScopedAStatus PreparedModel::executeSynchronously(const Request& request, b
int64_t loopTimeoutDurationNs,
ExecutionResult* executionResult) {
auto result = adapter::executeSynchronously(*kPreparedModel, request, measureTiming, deadlineNs,
loopTimeoutDurationNs);
loopTimeoutDurationNs, {}, {});
if (!result.has_value()) {
const auto& [message, code, _] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
@@ -247,7 +260,41 @@ ndk::ScopedAStatus PreparedModel::executeFenced(
bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
FencedExecutionResult* executionResult) {
auto result = adapter::executeFenced(*kPreparedModel, request, waitFor, measureTiming,
deadlineNs, loopTimeoutDurationNs, durationNs);
deadlineNs, loopTimeoutDurationNs, durationNs, {}, {});
if (!result.has_value()) {
const auto& [message, code] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
static_cast<int32_t>(aidlCode), message.c_str());
}
*executionResult = std::move(result).value();
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus PreparedModel::executeSynchronouslyWithConfig(const Request& request,
const ExecutionConfig& config,
int64_t deadlineNs,
ExecutionResult* executionResult) {
auto result = adapter::executeSynchronously(
*kPreparedModel, request, config.measureTiming, deadlineNs,
config.loopTimeoutDurationNs, config.executionHints, config.extensionNameToPrefix);
if (!result.has_value()) {
const auto& [message, code, _] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
static_cast<int32_t>(aidlCode), message.c_str());
}
*executionResult = std::move(result).value();
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus PreparedModel::executeFencedWithConfig(
const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
const ExecutionConfig& config, int64_t deadlineNs, int64_t durationNs,
FencedExecutionResult* executionResult) {
auto result = adapter::executeFenced(*kPreparedModel, request, waitFor, config.measureTiming,
deadlineNs, config.loopTimeoutDurationNs, durationNs,
config.executionHints, config.extensionNameToPrefix);
if (!result.has_value()) {
const auto& [message, code] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
@@ -275,11 +322,11 @@ nn::SharedPreparedModel PreparedModel::getUnderlyingPreparedModel() const {
}
ndk::ScopedAStatus PreparedModel::createReusableExecution(const Request& request,
bool measureTiming,
int64_t loopTimeoutDurationNs,
const ExecutionConfig& config,
std::shared_ptr<IExecution>* execution) {
auto result = adapter::createReusableExecution(*kPreparedModel, request, measureTiming,
loopTimeoutDurationNs);
auto result = adapter::createReusableExecution(
*kPreparedModel, request, config.measureTiming, config.loopTimeoutDurationNs,
config.executionHints, config.extensionNameToPrefix);
if (!result.has_value()) {
const auto& [message, code] = result.error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
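For contrast with the adapter plumbing above, here is a client-side sketch of the new fenced entry point, reusing the VTS call shape with no wait-for fences and empty hints:

// Sketch only: fenced execution through the config-based entry point.
// ExecutionConfig order: {measureTiming, loopTimeoutDurationNs,
// executionHints, extensionNameToPrefix}.
FencedExecutionResult fencedResult;
const auto status = preparedModel->executeFencedWithConfig(
        request, /*waitFor=*/{}, {/*measureTiming=*/false, kOmittedTimeoutDuration, {}, {}},
        kNoDeadline, /*durationNs=*/kNoDuration, &fencedResult);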

View File

@@ -250,7 +250,7 @@ nn::ExecutionResult<std::pair<hidl_vec<V1_2::OutputShape>, V1_2::Timing>> Burst:
nn::MeasureTiming canonicalMeasure = NN_TRY(nn::convert(measure));
const auto [outputShapes, timing] =
NN_TRY(mBurstExecutor->execute(canonicalRequest, canonicalMeasure, {}, {}));
NN_TRY(mBurstExecutor->execute(canonicalRequest, canonicalMeasure, {}, {}, {}, {}));
return std::make_pair(NN_TRY(V1_2::utils::convert(outputShapes)),
NN_TRY(V1_2::utils::convert(timing)));

View File

@@ -135,7 +135,7 @@ nn::GeneralResult<void> prepareModel(const nn::SharedDevice& device, const Execu
Task task = [device, nnModel = std::move(nnModel), executor, callback] {
auto result = device->prepareModel(nnModel, nn::ExecutionPreference::DEFAULT,
nn::Priority::DEFAULT, {}, {}, {}, {});
nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
notify(callback.get(), std::move(result), executor);
};
executor(std::move(task), {});
@@ -155,8 +155,8 @@ nn::GeneralResult<void> prepareModel_1_1(const nn::SharedDevice& device, const E
const auto nnPreference = NN_TRY(convertInput(preference));
Task task = [device, nnModel = std::move(nnModel), nnPreference, executor, callback] {
auto result =
device->prepareModel(nnModel, nnPreference, nn::Priority::DEFAULT, {}, {}, {}, {});
auto result = device->prepareModel(nnModel, nnPreference, nn::Priority::DEFAULT, {}, {}, {},
{}, {}, {});
notify(callback.get(), std::move(result), executor);
};
executor(std::move(task), {});
@@ -185,7 +185,7 @@ nn::GeneralResult<void> prepareModel_1_2(const nn::SharedDevice& device, const E
nnModelCache = std::move(nnModelCache), nnDataCache = std::move(nnDataCache),
nnToken, executor, callback] {
auto result = device->prepareModel(nnModel, nnPreference, nn::Priority::DEFAULT, {},
nnModelCache, nnDataCache, nnToken);
nnModelCache, nnDataCache, nnToken, {}, {});
notify(callback.get(), std::move(result), executor);
};
executor(std::move(task), {});
@@ -215,7 +215,7 @@ nn::GeneralResult<void> prepareModel_1_3(
nnModelCache = std::move(nnModelCache), nnDataCache = std::move(nnDataCache),
nnToken, executor, callback] {
auto result = device->prepareModel(nnModel, nnPreference, nnPriority, nnDeadline,
nnModelCache, nnDataCache, nnToken);
nnModelCache, nnDataCache, nnToken, {}, {});
notify(callback.get(), std::move(result), executor);
};
executor(std::move(task), nnDeadline);

View File

@@ -159,7 +159,7 @@ nn::GeneralResult<void> execute(const nn::SharedPreparedModel& preparedModel,
}
Task task = [preparedModel, nnRequest = std::move(nnRequest), callback] {
auto result = preparedModel->execute(nnRequest, nn::MeasureTiming::NO, {}, {});
auto result = preparedModel->execute(nnRequest, nn::MeasureTiming::NO, {}, {}, {}, {});
notify(callback.get(), std::move(result));
};
executor(std::move(task), {});
@@ -185,7 +185,7 @@ nn::GeneralResult<void> execute_1_2(const nn::SharedPreparedModel& preparedModel
}
Task task = [preparedModel, nnRequest = std::move(nnRequest), nnMeasure, callback] {
auto result = preparedModel->execute(nnRequest, nnMeasure, {}, {});
auto result = preparedModel->execute(nnRequest, nnMeasure, {}, {}, {}, {});
notify(callback.get(), std::move(result));
};
executor(std::move(task), {});
@@ -216,8 +216,8 @@ nn::GeneralResult<void> execute_1_3(const nn::SharedPreparedModel& preparedModel
Task task = [preparedModel, nnRequest = std::move(nnRequest), nnMeasure, nnDeadline,
nnLoopTimeoutDuration, callback] {
auto result =
preparedModel->execute(nnRequest, nnMeasure, nnDeadline, nnLoopTimeoutDuration);
auto result = preparedModel->execute(nnRequest, nnMeasure, nnDeadline,
nnLoopTimeoutDuration, {}, {});
notify(callback.get(), std::move(result));
};
executor(std::move(task), nnDeadline);
@@ -232,7 +232,7 @@ nn::ExecutionResult<std::pair<hidl_vec<V1_2::OutputShape>, V1_2::Timing>> execut
const auto nnMeasure = NN_TRY(convertInput(measure));
const auto [outputShapes, timing] =
NN_TRY(preparedModel->execute(nnRequest, nnMeasure, {}, {}));
NN_TRY(preparedModel->execute(nnRequest, nnMeasure, {}, {}, {}, {}));
auto hidlOutputShapes = NN_TRY(V1_2::utils::convert(outputShapes));
const auto hidlTiming = NN_TRY(V1_2::utils::convert(timing));
@@ -248,8 +248,8 @@ nn::ExecutionResult<std::pair<hidl_vec<V1_2::OutputShape>, V1_2::Timing>> execut
const auto nnDeadline = NN_TRY(convertInput(deadline));
const auto nnLoopTimeoutDuration = NN_TRY(convertInput(loopTimeoutDuration));
const auto [outputShapes, timing] =
NN_TRY(preparedModel->execute(nnRequest, nnMeasure, nnDeadline, nnLoopTimeoutDuration));
const auto [outputShapes, timing] = NN_TRY(preparedModel->execute(
nnRequest, nnMeasure, nnDeadline, nnLoopTimeoutDuration, {}, {}));
auto hidlOutputShapes = NN_TRY(V1_3::utils::convert(outputShapes));
const auto hidlTiming = NN_TRY(V1_3::utils::convert(timing));
@@ -293,8 +293,9 @@ nn::GeneralResult<std::pair<hidl_handle, sp<V1_3::IFencedExecutionCallback>>> ex
const auto nnLoopTimeoutDuration = NN_TRY(convertInput(loopTimeoutDuration));
const auto nnDuration = NN_TRY(convertInput(duration));
auto [syncFence, executeFencedCallback] = NN_TRY(preparedModel->executeFenced(
nnRequest, nnWaitFor, nnMeasure, nnDeadline, nnLoopTimeoutDuration, nnDuration));
auto [syncFence, executeFencedCallback] =
NN_TRY(preparedModel->executeFenced(nnRequest, nnWaitFor, nnMeasure, nnDeadline,
nnLoopTimeoutDuration, nnDuration, {}, {}));
auto hidlSyncFence = NN_TRY(V1_3::utils::convert(syncFence.getSharedHandle()));
auto hidlExecuteFencedCallback = sp<FencedExecutionCallback>::make(executeFencedCallback);

View File

@@ -33,12 +33,15 @@ class InvalidBurst final : public nn::IBurst {
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
};
} // namespace android::hardware::neuralnetworks::utils

View File

@@ -52,8 +52,9 @@ class InvalidDevice final : public nn::IDevice {
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache,
const nn::CacheToken& token) const override;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,

View File

@@ -31,18 +31,23 @@ class InvalidPreparedModel final : public nn::IPreparedModel {
public:
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const override;
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;

View File

@@ -48,18 +48,23 @@ class ResilientBurst final : public nn::IBurst,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
private:
bool isValidInternal() const EXCLUDES(mMutex);
nn::GeneralResult<nn::SharedExecution> createReusableExecutionInternal(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const;
const Factory kMakeBurst;
mutable std::mutex mMutex;

View File

@@ -65,8 +65,9 @@ class ResilientDevice final : public nn::IDevice,
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache,
const nn::CacheToken& token) const override;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
@@ -83,7 +84,9 @@ class ResilientDevice final : public nn::IDevice,
nn::GeneralResult<nn::SharedPreparedModel> prepareModelInternal(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const;
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCacheInternal(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const;

View File

@@ -49,18 +49,23 @@ class ResilientPreparedModel final : public nn::IPreparedModel,
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const override;
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const override;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;
@@ -70,7 +75,9 @@ class ResilientPreparedModel final : public nn::IPreparedModel,
bool isValidInternal() const EXCLUDES(mMutex);
nn::GeneralResult<nn::SharedExecution> createReusableExecutionInternal(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const;
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& metaData,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurstInternal() const;
const Factory kMakePreparedModel;

View File

@@ -34,13 +34,17 @@ InvalidBurst::OptionalCacheHold InvalidBurst::cacheMemory(
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> InvalidBurst::execute(
const nn::Request& /*request*/, nn::MeasureTiming /*measure*/,
const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidBurst";
}
nn::GeneralResult<nn::SharedExecution> InvalidBurst::createReusableExecution(
const nn::Request& /*request*/, nn::MeasureTiming /*measure*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidBurst";
}

View File

@@ -84,7 +84,9 @@ nn::GeneralResult<nn::SharedPreparedModel> InvalidDevice::prepareModel(
const nn::Model& /*model*/, nn::ExecutionPreference /*preference*/,
nn::Priority /*priority*/, nn::OptionalTimePoint /*deadline*/,
const std::vector<nn::SharedHandle>& /*modelCache*/,
const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/) const {
const std::vector<nn::SharedHandle>& /*dataCache*/, const nn::CacheToken& /*token*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidDevice";
}

View File

@@ -27,9 +27,12 @@
namespace android::hardware::neuralnetworks::utils {
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
InvalidPreparedModel::execute(const nn::Request& /*request*/, nn::MeasureTiming /*measure*/,
const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
InvalidPreparedModel::execute(
const nn::Request& /*request*/, nn::MeasureTiming /*measure*/,
const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidPreparedModel";
}
@@ -38,13 +41,17 @@ InvalidPreparedModel::executeFenced(
const nn::Request& /*request*/, const std::vector<nn::SyncFence>& /*waitFor*/,
nn::MeasureTiming /*measure*/, const nn::OptionalTimePoint& /*deadline*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
const nn::OptionalDuration& /*timeoutDurationAfterFence*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidPreparedModel";
}
nn::GeneralResult<nn::SharedExecution> InvalidPreparedModel::createReusableExecution(
const nn::Request& /*request*/, nn::MeasureTiming /*measure*/,
const nn::OptionalDuration& /*loopTimeoutDuration*/) const {
const nn::OptionalDuration& /*loopTimeoutDuration*/,
const std::vector<nn::TokenValuePair>& /*hints*/,
const std::vector<nn::ExtensionNameAndPrefix>& /*extensionNameToPrefix*/) const {
return NN_ERROR() << "InvalidPreparedModel";
}
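The invalid-object classes keep the widened canonical interface total: every overload, including the three extended here, fails unconditionally. A sketch of what a caller of the new six-argument execute sees:

// Sketch: canonical execute now takes six arguments; an InvalidPreparedModel
// rejects any combination of them.
const auto result = preparedModel->execute(request, nn::MeasureTiming::NO,
                                           /*deadline=*/{}, /*loopTimeoutDuration=*/{},
                                           /*hints=*/{}, /*extensionNameToPrefix=*/{});
if (!result.ok()) {
    // e.g. result.error().message contains "InvalidPreparedModel"
}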

View File

@@ -105,37 +105,49 @@ ResilientBurst::OptionalCacheHold ResilientBurst::cacheMemory(
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> ResilientBurst::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
const auto fn = [&request, measure, deadline, loopTimeoutDuration](const nn::IBurst& burst) {
return burst.execute(request, measure, deadline, loopTimeoutDuration);
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
const auto fn = [&request, measure, deadline, loopTimeoutDuration, &hints,
&extensionNameToPrefix](const nn::IBurst& burst) {
return burst.execute(request, measure, deadline, loopTimeoutDuration, hints,
extensionNameToPrefix);
};
return protect(*this, fn);
}
nn::GeneralResult<nn::SharedExecution> ResilientBurst::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
#if 0
auto self = shared_from_this();
ResilientExecution::Factory makeExecution =
[burst = std::move(self), request, measure, loopTimeoutDuration] {
return burst->createReusableExecutionInternal(request, measure, loopTimeoutDuration);
ResilientExecution::Factory makeExecution = [burst = std::move(self), request, measure,
loopTimeoutDuration, &hints,
&extensionNameToPrefix] {
return burst->createReusableExecutionInternal(request, measure, loopTimeoutDuration, hints,
extensionNameToPrefix);
};
return ResilientExecution::create(std::move(makeExecution));
#else
return createReusableExecutionInternal(request, measure, loopTimeoutDuration);
return createReusableExecutionInternal(request, measure, loopTimeoutDuration, hints,
extensionNameToPrefix);
#endif
}
nn::GeneralResult<nn::SharedExecution> ResilientBurst::createReusableExecutionInternal(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
if (!isValidInternal()) {
return std::make_shared<const InvalidExecution>();
}
const auto fn = [&request, measure, &loopTimeoutDuration](const nn::IBurst& burst) {
return burst.createReusableExecution(request, measure, loopTimeoutDuration);
const auto fn = [&request, measure, &loopTimeoutDuration, &hints,
&extensionNameToPrefix](const nn::IBurst& burst) {
return burst.createReusableExecution(request, measure, loopTimeoutDuration, hints,
extensionNameToPrefix);
};
return protect(*this, fn);
}
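One pattern worth noting in these wrappers: the hint vectors are captured by reference into fn, which is safe on the assumption, consistent with how protect() is used throughout these files, that the callable runs synchronously (with at most one retry against a recreated interface object on a dead-object error) before the enclosing call returns. A condensed sketch:

// Condensed sketch of the resilient call pattern above. References cannot
// dangle because protect() invokes fn before execute() returns.
const auto fn = [&](const nn::IBurst& burst) {
    return burst.execute(request, measure, deadline, loopTimeoutDuration, hints,
                         extensionNameToPrefix);
};
return protect(*this, fn);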

View File

@@ -179,19 +179,21 @@ nn::GeneralResult<std::vector<bool>> ResilientDevice::getSupportedOperations(
nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
#if 0
auto self = shared_from_this();
ResilientPreparedModel::Factory makePreparedModel = [device = std::move(self), model,
preference, priority, deadline, modelCache,
dataCache, token] {
dataCache, token, hints, extensionNameToPrefix] {
return device->prepareModelInternal(model, preference, priority, deadline, modelCache,
dataCache, token);
dataCache, token, hints, extensionNameToPrefix);
};
return ResilientPreparedModel::create(std::move(makePreparedModel));
#else
return prepareModelInternal(model, preference, priority, deadline, modelCache, dataCache,
token);
return prepareModelInternal(model, preference, priority, deadline, modelCache, dataCache, token,
hints, extensionNameToPrefix);
#endif
}
@@ -234,14 +236,16 @@ bool ResilientDevice::isValidInternal() const {
nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModelInternal(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
if (!isValidInternal()) {
return std::make_shared<const InvalidPreparedModel>();
}
const auto fn = [&model, preference, priority, &deadline, &modelCache, &dataCache,
&token](const nn::IDevice& device) {
const auto fn = [&model, preference, priority, &deadline, &modelCache, &dataCache, &token,
&hints, &extensionNameToPrefix](const nn::IDevice& device) {
return device.prepareModel(model, preference, priority, deadline, modelCache, dataCache,
token);
token, hints, extensionNameToPrefix);
};
return protect(*this, fn, /*blocking=*/false);
}

View File

@@ -104,43 +104,53 @@ nn::GeneralResult<nn::SharedPreparedModel> ResilientPreparedModel::recover(
}
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
ResilientPreparedModel::execute(const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration) const {
const auto fn = [&request, measure, &deadline,
&loopTimeoutDuration](const nn::IPreparedModel& preparedModel) {
return preparedModel.execute(request, measure, deadline, loopTimeoutDuration);
ResilientPreparedModel::execute(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
const auto fn = [&request, measure, &deadline, &loopTimeoutDuration, &hints,
&extensionNameToPrefix](const nn::IPreparedModel& preparedModel) {
return preparedModel.execute(request, measure, deadline, loopTimeoutDuration, hints,
extensionNameToPrefix);
};
return protect(*this, fn);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
ResilientPreparedModel::executeFenced(const nn::Request& request,
const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure,
const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence) const {
ResilientPreparedModel::executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
const nn::OptionalDuration& timeoutDurationAfterFence,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
const auto fn = [&request, &waitFor, measure, &deadline, &loopTimeoutDuration,
&timeoutDurationAfterFence](const nn::IPreparedModel& preparedModel) {
&timeoutDurationAfterFence, &hints,
&extensionNameToPrefix](const nn::IPreparedModel& preparedModel) {
return preparedModel.executeFenced(request, waitFor, measure, deadline, loopTimeoutDuration,
timeoutDurationAfterFence);
timeoutDurationAfterFence, hints, extensionNameToPrefix);
};
return protect(*this, fn);
}
nn::GeneralResult<nn::SharedExecution> ResilientPreparedModel::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
#if 0
auto self = shared_from_this();
ResilientExecution::Factory makeExecution =
[preparedModel = std::move(self), request, measure, loopTimeoutDuration] {
return preparedModel->createReusableExecutionInternal(request, measure, loopTimeoutDuration);
ResilientExecution::Factory makeExecution = [preparedModel = std::move(self), request, measure,
loopTimeoutDuration, hints,
extensionNameToPrefix] {
return preparedModel->createReusableExecutionInternal(request, measure, loopTimeoutDuration,
hints, extensionNameToPrefix);
};
return ResilientExecution::create(std::move(makeExecution));
#else
return createReusableExecutionInternal(request, measure, loopTimeoutDuration);
return createReusableExecutionInternal(request, measure, loopTimeoutDuration, hints,
extensionNameToPrefix);
#endif
}
@@ -159,13 +169,16 @@ nn::GeneralResult<nn::SharedBurst> ResilientPreparedModel::configureExecutionBur
nn::GeneralResult<nn::SharedExecution> ResilientPreparedModel::createReusableExecutionInternal(
const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration) const {
const nn::OptionalDuration& loopTimeoutDuration,
const std::vector<nn::TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
if (!isValidInternal()) {
return std::make_shared<const InvalidExecution>();
}
const auto fn = [&request, measure,
&loopTimeoutDuration](const nn::IPreparedModel& preparedModel) {
return preparedModel.createReusableExecution(request, measure, loopTimeoutDuration);
const auto fn = [&request, measure, &loopTimeoutDuration, &hints,
&extensionNameToPrefix](const nn::IPreparedModel& preparedModel) {
return preparedModel.createReusableExecution(request, measure, loopTimeoutDuration, hints,
extensionNameToPrefix);
};
return protect(*this, fn);
}

View File

@@ -39,7 +39,9 @@ class MockDevice final : public IDevice {
MOCK_METHOD(GeneralResult<SharedPreparedModel>, prepareModel,
(const Model& model, ExecutionPreference preference, Priority priority,
OptionalTimePoint deadline, const std::vector<SharedHandle>& modelCache,
const std::vector<SharedHandle>& dataCache, const CacheToken& token),
const std::vector<SharedHandle>& dataCache, const CacheToken& token,
const std::vector<TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix),
(const, override));
MOCK_METHOD(GeneralResult<SharedPreparedModel>, prepareModelFromCache,
(OptionalTimePoint deadline, const std::vector<SharedHandle>& modelCache,

View File

@@ -27,17 +27,23 @@ class MockPreparedModel final : public IPreparedModel {
public:
MOCK_METHOD((ExecutionResult<std::pair<std::vector<OutputShape>, Timing>>), execute,
(const Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
const OptionalDuration& loopTimeoutDuration),
const OptionalDuration& loopTimeoutDuration,
const std::vector<TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix),
(const, override));
MOCK_METHOD((GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>>), executeFenced,
(const Request& request, const std::vector<SyncFence>& waitFor,
MeasureTiming measure, const OptionalTimePoint& deadline,
const OptionalDuration& loopTimeoutDuration,
const OptionalDuration& timeoutDurationAfterFence),
const OptionalDuration& timeoutDurationAfterFence,
const std::vector<TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix),
(const, override));
MOCK_METHOD((GeneralResult<SharedExecution>), createReusableExecution,
(const nn::Request& request, nn::MeasureTiming measure,
const nn::OptionalDuration& loopTimeoutDuration),
(const Request& request, MeasureTiming measure,
const OptionalDuration& loopTimeoutDuration,
const std::vector<TokenValuePair>& hints,
const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix),
(const, override));
MOCK_METHOD(GeneralResult<SharedBurst>, configureExecutionBurst, (), (const, override));
MOCK_METHOD(std::any, getUnderlyingResource, (), (const, override));

View File

@@ -309,12 +309,12 @@ TEST(ResilientDeviceTest, prepareModel) {
// setup call
const auto [mockDevice, mockDeviceFactory, device] = setup();
const auto mockPreparedModel = std::make_shared<const nn::MockPreparedModel>();
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _))
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(Return(mockPreparedModel));
// run test
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {});
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -324,12 +324,12 @@ TEST(ResilientDeviceTest, prepareModel) {
TEST(ResilientDeviceTest, prepareModelError) {
// setup call
const auto [mockDevice, mockDeviceFactory, device] = setup();
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _))
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnGeneralFailure);
// run test
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {});
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -339,13 +339,13 @@ TEST(ResilientDeviceTest, prepareModelError) {
TEST(ResilientDeviceTest, prepareModelDeadObjectFailedRecovery) {
// setup call
const auto [mockDevice, mockDeviceFactory, device] = setup();
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _))
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnDeadObject);
EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(kReturnGeneralFailure);
// run test
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {});
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -355,18 +355,18 @@ TEST(ResilientDeviceTest, prepareModelDeadObjectFailedRecovery) {
TEST(ResilientDeviceTest, prepareModelDeadObjectSuccessfulRecovery) {
// setup call
const auto [mockDevice, mockDeviceFactory, device] = setup();
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _))
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnDeadObject);
const auto recoveredMockDevice = createConfiguredMockDevice();
const auto mockPreparedModel = std::make_shared<const nn::MockPreparedModel>();
EXPECT_CALL(*recoveredMockDevice, prepareModel(_, _, _, _, _, _, _))
EXPECT_CALL(*recoveredMockDevice, prepareModel(_, _, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(Return(mockPreparedModel));
EXPECT_CALL(*mockDeviceFactory, Call(false)).Times(1).WillOnce(Return(recoveredMockDevice));
// run test
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {});
const auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -679,7 +679,7 @@ TEST(ResilientDeviceTest, recoverCacheMismatchInvalidPrepareModel) {
device->recover(mockDevice.get(), /*blocking=*/false);
// run test
auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {});
auto result = device->prepareModel({}, {}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())

View File

@@ -104,12 +104,12 @@ TEST(ResilientPreparedModelTest, getPreparedModel) {
TEST(ResilientPreparedModelTest, execute) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _))
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _, _, _))
.Times(1)
.WillOnce(Return(kNoExecutionError));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -119,10 +119,12 @@ TEST(ResilientPreparedModelTest, execute) {
TEST(ResilientPreparedModelTest, executeError) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _)).Times(1).WillOnce(kReturnGeneralFailure);
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnGeneralFailure);
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -132,12 +134,12 @@ TEST(ResilientPreparedModelTest, executeError) {
TEST(ResilientPreparedModelTest, executeDeadObjectFailedRecovery) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
constexpr auto ret = [] { return nn::error(nn::ErrorStatus::GENERAL_FAILURE); };
EXPECT_CALL(*mockPreparedModelFactory, Call()).Times(1).WillOnce(ret);
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -147,9 +149,9 @@ TEST(ResilientPreparedModelTest, executeDeadObjectFailedRecovery) {
TEST(ResilientPreparedModelTest, executeDeadObjectSuccessfulRecovery) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
EXPECT_CALL(*mockPreparedModel, execute(_, _, _, _, _, _)).Times(1).WillOnce(kReturnDeadObject);
const auto recoveredMockPreparedModel = createConfiguredMockPreparedModel();
EXPECT_CALL(*recoveredMockPreparedModel, execute(_, _, _, _))
EXPECT_CALL(*recoveredMockPreparedModel, execute(_, _, _, _, _, _))
.Times(1)
.WillOnce(Return(kNoExecutionError));
EXPECT_CALL(*mockPreparedModelFactory, Call())
@@ -157,7 +159,7 @@ TEST(ResilientPreparedModelTest, executeDeadObjectSuccessfulRecovery) {
.WillOnce(Return(recoveredMockPreparedModel));
// run test
const auto result = preparedModel->execute({}, {}, {}, {});
const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -167,12 +169,12 @@ TEST(ResilientPreparedModelTest, executeDeadObjectSuccessfulRecovery) {
TEST(ResilientPreparedModelTest, executeFenced) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _))
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(Return(kNoFencedExecutionError));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
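For executeFenced the new parameters land after timeoutDurationAfterFence, so the call grows from six to eight arguments. A sketch making the ordering explicit (argument order taken from the updated IPreparedModel declaration earlier in this change; the empty braces mirror the defaults the tests pass):

    const auto result = preparedModel->executeFenced(
            /*request=*/{}, /*waitFor=*/{}, /*measure=*/{}, /*deadline=*/{},
            /*loopTimeoutDuration=*/{}, /*timeoutDurationAfterFence=*/{},
            /*hints=*/{}, /*extensionNameToPrefix=*/{});
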
@@ -182,12 +184,12 @@ TEST(ResilientPreparedModelTest, executeFenced) {
TEST(ResilientPreparedModelTest, executeFencedError) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _))
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnGeneralFailure);
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -197,13 +199,13 @@ TEST(ResilientPreparedModelTest, executeFencedError) {
TEST(ResilientPreparedModelTest, executeFencedDeadObjectFailedRecovery) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _))
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnDeadObject);
EXPECT_CALL(*mockPreparedModelFactory, Call()).Times(1).WillOnce(kReturnGeneralFailure);
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -213,11 +215,11 @@ TEST(ResilientPreparedModelTest, executeFencedDeadObjectFailedRecovery) {
TEST(ResilientPreparedModelTest, executeFencedDeadObjectSuccessfulRecovery) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _))
EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(kReturnDeadObject);
const auto recoveredMockPreparedModel = createConfiguredMockPreparedModel();
EXPECT_CALL(*recoveredMockPreparedModel, executeFenced(_, _, _, _, _, _))
EXPECT_CALL(*recoveredMockPreparedModel, executeFenced(_, _, _, _, _, _, _, _))
.Times(1)
.WillOnce(Return(kNoFencedExecutionError));
EXPECT_CALL(*mockPreparedModelFactory, Call())
@@ -225,7 +227,7 @@ TEST(ResilientPreparedModelTest, executeFencedDeadObjectSuccessfulRecovery) {
.WillOnce(Return(recoveredMockPreparedModel));
// run test
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -235,12 +237,12 @@ TEST(ResilientPreparedModelTest, executeFencedDeadObjectSuccessfulRecovery) {
TEST(ResilientPreparedModelTest, createReusableExecution) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _, _))
.Times(1)
.WillOnce(Return(kNoCreateReusableExecutionError));
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -250,12 +252,12 @@ TEST(ResilientPreparedModelTest, createReusableExecution) {
TEST(ResilientPreparedModelTest, createReusableExecutionError) {
// setup call
const auto [mockPreparedModel, mockPreparedModelFactory, preparedModel] = setup();
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _, _))
.Times(1)
.WillOnce(kReturnGeneralFailure);
// run test
const auto result = preparedModel->createReusableExecution({}, {}, {});
const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
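Finally, a sketch of what non-empty values for the two new vectors might look like in a createReusableExecution call. The aggregate field order (token then value; name then prefix) reflects this editor's reading of the canonical nn types and should be verified against the headers in this change; the concrete token and extension name are invented for illustration:

    // Hypothetical values, not taken from this change.
    const std::vector<nn::TokenValuePair> hints = {
            {/*token=*/65536, /*value=*/{0x01}},
    };
    const std::vector<nn::ExtensionNameAndPrefix> extensionNameToPrefix = {
            {/*name=*/"com.example.hints", /*prefix=*/0x1234},
    };
    const auto result =
            preparedModel->createReusableExecution({}, {}, {}, hints, extensionNameToPrefix);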