diff --git a/neuralnetworks/aidl/vts/functional/CompilationCachingTests.cpp b/neuralnetworks/aidl/vts/functional/CompilationCachingTests.cpp index e0b529f280..94ce5c1130 100644 --- a/neuralnetworks/aidl/vts/functional/CompilationCachingTests.cpp +++ b/neuralnetworks/aidl/vts/functional/CompilationCachingTests.cpp @@ -357,16 +357,40 @@ class CompilationCachingTestBase : public testing::Test { return false; } + // If fallbackModel is not provided, call prepareModelFromCache. + // If fallbackModel is provided, and prepareModelFromCache returns GENERAL_FAILURE, + // then prepareModel(fallbackModel) will be called. + // This replicates the behaviour of the runtime when loading a model from cache. + // NNAPI Shim depends on this behaviour and may try to load the model from cache in + // prepareModel (shim needs model information when loading from cache). void prepareModelFromCache(const std::vector& modelCache, const std::vector& dataCache, - std::shared_ptr* preparedModel, - ErrorStatus* status) { + std::shared_ptr* preparedModel, ErrorStatus* status, + const Model* fallbackModel = nullptr) { // Launch prepare model from cache. std::shared_ptr preparedModelCallback = ndk::SharedRefBase::make(); std::vector cacheToken(std::begin(mToken), std::end(mToken)); - const auto prepareLaunchStatus = kDevice->prepareModelFromCache( + auto prepareLaunchStatus = kDevice->prepareModelFromCache( kNoDeadline, modelCache, dataCache, cacheToken, preparedModelCallback); + + // The shim does not support prepareModelFromCache() properly, but it + // will still attempt to create a model from cache when modelCache or + // dataCache is provided in prepareModel(). Instead of failing straight + // away, we try to utilize that other code path when fallbackModel is + // set. Note that we cannot verify whether the returned model was + // actually prepared from cache in that case. + if (!prepareLaunchStatus.isOk() && + prepareLaunchStatus.getExceptionCode() == EX_SERVICE_SPECIFIC && + static_cast(prepareLaunchStatus.getServiceSpecificError()) == + ErrorStatus::GENERAL_FAILURE && + mIsCachingSupported && fallbackModel != nullptr) { + preparedModelCallback = ndk::SharedRefBase::make(); + prepareLaunchStatus = kDevice->prepareModel( + *fallbackModel, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, + kNoDeadline, modelCache, dataCache, cacheToken, preparedModelCallback); + } + ASSERT_TRUE(prepareLaunchStatus.isOk() || prepareLaunchStatus.getExceptionCode() == EX_SERVICE_SPECIFIC) << "prepareLaunchStatus: " << prepareLaunchStatus.getDescription(); @@ -382,6 +406,42 @@ class CompilationCachingTestBase : public testing::Test { *preparedModel = preparedModelCallback->getPreparedModel(); } + // Replicate behaviour of runtime when loading model from cache. + // Test if prepareModelFromCache behaves correctly when faced with bad + // arguments. If prepareModelFromCache is not supported (GENERAL_FAILURE), + // it attempts to call prepareModel with same arguments, which is expected either + // to not support the model (GENERAL_FAILURE) or return a valid model. + void verifyModelPreparationBehaviour(const std::vector& modelCache, + const std::vector& dataCache, + const Model* model, const TestModel& testModel) { + std::shared_ptr preparedModel; + ErrorStatus status; + + // Verify that prepareModelFromCache fails either due to bad + // arguments (INVALID_ARGUMENT) or GENERAL_FAILURE if not supported. + prepareModelFromCache(modelCache, dataCache, &preparedModel, &status, + /*fallbackModel=*/nullptr); + if (status != ErrorStatus::INVALID_ARGUMENT) { + ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE); + } + ASSERT_EQ(preparedModel, nullptr); + + // If caching is not supported, attempt calling prepareModel. + if (status == ErrorStatus::GENERAL_FAILURE) { + // Fallback with prepareModel should succeed regardless of cache files + prepareModelFromCache(modelCache, dataCache, &preparedModel, &status, + /*fallbackModel=*/model); + // Unless caching is not supported? + if (status != ErrorStatus::GENERAL_FAILURE) { + // But if it is, we should see a valid model. + ASSERT_EQ(status, ErrorStatus::NONE); + ASSERT_NE(preparedModel, nullptr); + EvaluatePreparedModel(kDevice, preparedModel, testModel, + /*testKind=*/TestKind::GENERAL); + } + } + } + // Absolute path to the temporary cache directory. std::string mCacheDir; @@ -397,7 +457,7 @@ class CompilationCachingTestBase : public testing::Test { uint8_t mToken[static_cast(IDevice::BYTE_SIZE_OF_CACHE_TOKEN)] = {}; uint32_t mNumModelCache; uint32_t mNumDataCache; - uint32_t mIsCachingSupported; + bool mIsCachingSupported; const std::shared_ptr kDevice; // The primary data type of the testModel. @@ -438,7 +498,8 @@ TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) { std::vector modelCache, dataCache; createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache); createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache); - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); + prepareModelFromCache(modelCache, dataCache, &preparedModel, &status, + /*fallbackModel=*/&model); if (!mIsCachingSupported) { ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE); ASSERT_EQ(preparedModel, nullptr); @@ -498,7 +559,8 @@ TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) { for (uint32_t i = 0; i < dataCache.size(); i++) { ASSERT_GE(read(dataCache[i].get(), &placeholderByte, 1), 0); } - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); + prepareModelFromCache(modelCache, dataCache, &preparedModel, &status, + /*fallbackModel=*/&model); if (!mIsCachingSupported) { ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE); ASSERT_EQ(preparedModel, nullptr); @@ -536,13 +598,7 @@ TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) { // Execute and verify results. EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL); // Check if prepareModelFromCache fails. - preparedModel = nullptr; - ErrorStatus status; - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - if (status != ErrorStatus::INVALID_ARGUMENT) { - ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE); - } - ASSERT_EQ(preparedModel, nullptr); + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } // Test with number of model cache files smaller than mNumModelCache. @@ -560,13 +616,7 @@ TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) { // Execute and verify results. EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL); // Check if prepareModelFromCache fails. - preparedModel = nullptr; - ErrorStatus status; - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - if (status != ErrorStatus::INVALID_ARGUMENT) { - ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE); - } - ASSERT_EQ(preparedModel, nullptr); + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } // Test with number of data cache files greater than mNumDataCache. @@ -583,13 +633,7 @@ TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) { // Execute and verify results. EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL); // Check if prepareModelFromCache fails. - preparedModel = nullptr; - ErrorStatus status; - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - if (status != ErrorStatus::INVALID_ARGUMENT) { - ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE); - } - ASSERT_EQ(preparedModel, nullptr); + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } // Test with number of data cache files smaller than mNumDataCache. @@ -607,13 +651,7 @@ TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) { // Execute and verify results. EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL); // Check if prepareModelFromCache fails. - preparedModel = nullptr; - ErrorStatus status; - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - if (status != ErrorStatus::INVALID_ARGUMENT) { - ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE); - } - ASSERT_EQ(preparedModel, nullptr); + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } } @@ -633,68 +671,48 @@ TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) { // Test with number of model cache files greater than mNumModelCache. { - std::shared_ptr preparedModel = nullptr; - ErrorStatus status; std::vector modelCache, dataCache; mModelCache.push_back({mTmpCache}); createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache); createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache); mModelCache.pop_back(); - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - if (status != ErrorStatus::GENERAL_FAILURE) { - ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT); - } - ASSERT_EQ(preparedModel, nullptr); + + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } // Test with number of model cache files smaller than mNumModelCache. if (mModelCache.size() > 0) { - std::shared_ptr preparedModel = nullptr; - ErrorStatus status; std::vector modelCache, dataCache; auto tmp = mModelCache.back(); mModelCache.pop_back(); createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache); createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache); mModelCache.push_back(tmp); - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - if (status != ErrorStatus::GENERAL_FAILURE) { - ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT); - } - ASSERT_EQ(preparedModel, nullptr); + + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } // Test with number of data cache files greater than mNumDataCache. { - std::shared_ptr preparedModel = nullptr; - ErrorStatus status; std::vector modelCache, dataCache; mDataCache.push_back({mTmpCache}); createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache); createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache); mDataCache.pop_back(); - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - if (status != ErrorStatus::GENERAL_FAILURE) { - ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT); - } - ASSERT_EQ(preparedModel, nullptr); + + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } // Test with number of data cache files smaller than mNumDataCache. if (mDataCache.size() > 0) { - std::shared_ptr preparedModel = nullptr; - ErrorStatus status; std::vector modelCache, dataCache; auto tmp = mDataCache.back(); mDataCache.pop_back(); createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache); createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache); mDataCache.push_back(tmp); - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - if (status != ErrorStatus::GENERAL_FAILURE) { - ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT); - } - ASSERT_EQ(preparedModel, nullptr); + + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } } @@ -719,13 +737,7 @@ TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) { // Execute and verify results. EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL); // Check if prepareModelFromCache fails. - preparedModel = nullptr; - ErrorStatus status; - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - if (status != ErrorStatus::INVALID_ARGUMENT) { - ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE); - } - ASSERT_EQ(preparedModel, nullptr); + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } // Go through each handle in data cache, test with invalid access mode. @@ -741,13 +753,7 @@ TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) { // Execute and verify results. EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL); // Check if prepareModelFromCache fails. - preparedModel = nullptr; - ErrorStatus status; - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - if (status != ErrorStatus::INVALID_ARGUMENT) { - ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE); - } - ASSERT_EQ(preparedModel, nullptr); + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } } @@ -769,30 +775,23 @@ TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) { // Go through each handle in model cache, test with invalid access mode. for (uint32_t i = 0; i < mNumModelCache; i++) { - std::shared_ptr preparedModel = nullptr; - ErrorStatus status; std::vector modelCache, dataCache; modelCacheMode[i] = AccessMode::WRITE_ONLY; createCacheFds(mModelCache, modelCacheMode, &modelCache); createCacheFds(mDataCache, dataCacheMode, &dataCache); modelCacheMode[i] = AccessMode::READ_WRITE; - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE); - ASSERT_EQ(preparedModel, nullptr); + + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } // Go through each handle in data cache, test with invalid access mode. for (uint32_t i = 0; i < mNumDataCache; i++) { - std::shared_ptr preparedModel = nullptr; - ErrorStatus status; std::vector modelCache, dataCache; dataCacheMode[i] = AccessMode::WRITE_ONLY; createCacheFds(mModelCache, modelCacheMode, &modelCache); createCacheFds(mDataCache, dataCacheMode, &dataCache); dataCacheMode[i] = AccessMode::READ_WRITE; - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); - ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE); - ASSERT_EQ(preparedModel, nullptr); + verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel); } } @@ -872,7 +871,8 @@ TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) { std::vector modelCache, dataCache; createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache); createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache); - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); + prepareModelFromCache(modelCache, dataCache, &preparedModel, &status, + /*fallbackModel=*/nullptr); // The preparation may fail or succeed, but must not crash. If the preparation succeeds, // the prepared model must be executed with the correct result and not crash. @@ -933,7 +933,8 @@ TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) { // Spawn a thread to copy the cache content concurrently while preparing from cache. std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache)); - prepareModelFromCache(modelCache, dataCache, &preparedModel, &status); + prepareModelFromCache(modelCache, dataCache, &preparedModel, &status, + /*fallbackModel=*/nullptr); thread.join(); // The preparation may fail or succeed, but must not crash. If the preparation succeeds,