From 20f28a24e908d54f4708ad17943154fb61a4c770 Mon Sep 17 00:00:00 2001 From: Michael Butler Date: Fri, 26 Apr 2019 17:46:08 -0700 Subject: [PATCH] Add validation tests for NNAPI Burst serialized format This CL adds the following two types of validation tests on the NNAPI Burst serialized format: (1) it directly modifies the serialized data (invalidating it) to ensure that vendor driver services properly validates the serialized request (2) it ensures that vendor driver services properly fail when the result channel is not large enough to return the data This CL additionally includes miscellaneous cleanups: (1) having a generic "validateEverything" function (2) moving the "prepareModel" function that's common across validateRequest and validateBurst to a common area Fixes: 129779280 Bug: 129157135 Test: mma Test: VtsHalNeuralnetworksV1_2TargetTest (with sample-all) Change-Id: Ib90fe7f662824de17db5a254a8c501855e45f6bd --- .../vts/functional/VtsHalNeuralnetworks.cpp | 5 + .../1.0/vts/functional/VtsHalNeuralnetworks.h | 7 +- .../vts/functional/VtsHalNeuralnetworks.cpp | 5 + .../1.1/vts/functional/VtsHalNeuralnetworks.h | 7 +- neuralnetworks/1.2/vts/functional/Android.bp | 4 + .../1.2/vts/functional/ValidateBurst.cpp | 333 ++++++++++++++++++ .../1.2/vts/functional/ValidateRequest.cpp | 61 +--- .../vts/functional/VtsHalNeuralnetworks.cpp | 73 ++++ .../1.2/vts/functional/VtsHalNeuralnetworks.h | 10 +- 9 files changed, 440 insertions(+), 65 deletions(-) create mode 100644 neuralnetworks/1.2/vts/functional/ValidateBurst.cpp diff --git a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp index 88830574da..31638c425f 100644 --- a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp +++ b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp @@ -68,6 +68,11 @@ void NeuralnetworksHidlTest::TearDown() { ::testing::VtsHalHidlTargetTestBase::TearDown(); } +void ValidationTest::validateEverything(const Model& model, const std::vector& request) { + validateModel(model); + validateRequests(model, request); +} + } // namespace functional } // namespace vts } // namespace V1_0 diff --git a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h index d4c114d3a2..559d678ea1 100644 --- a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h +++ b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h @@ -63,8 +63,11 @@ class NeuralnetworksHidlTest : public ::testing::VtsHalHidlTargetTestBase { // Tag for the validation tests class ValidationTest : public NeuralnetworksHidlTest { protected: - void validateModel(const Model& model); - void validateRequests(const Model& model, const std::vector& request); + void validateEverything(const Model& model, const std::vector& request); + + private: + void validateModel(const Model& model); + void validateRequests(const Model& model, const std::vector& request); }; // Tag for the generated tests diff --git a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp index 224a51d149..11fa693ddc 100644 --- a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp +++ b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp @@ -68,6 +68,11 @@ void NeuralnetworksHidlTest::TearDown() { ::testing::VtsHalHidlTargetTestBase::TearDown(); } +void ValidationTest::validateEverything(const Model& model, const std::vector& request) { + validateModel(model); + validateRequests(model, request); +} + } // namespace functional } // namespace vts } // namespace V1_1 diff --git a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h index 1c8c0e18cb..cea2b54c2d 100644 --- a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h +++ b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h @@ -72,8 +72,11 @@ class NeuralnetworksHidlTest : public ::testing::VtsHalHidlTargetTestBase { // Tag for the validation tests class ValidationTest : public NeuralnetworksHidlTest { protected: - void validateModel(const Model& model); - void validateRequests(const Model& model, const std::vector& request); + void validateEverything(const Model& model, const std::vector& request); + + private: + void validateModel(const Model& model); + void validateRequests(const Model& model, const std::vector& request); }; // Tag for the generated tests diff --git a/neuralnetworks/1.2/vts/functional/Android.bp b/neuralnetworks/1.2/vts/functional/Android.bp index 891b414480..6c26820b27 100644 --- a/neuralnetworks/1.2/vts/functional/Android.bp +++ b/neuralnetworks/1.2/vts/functional/Android.bp @@ -20,6 +20,7 @@ cc_test { defaults: ["VtsHalNeuralNetworksTargetTestDefaults"], srcs: [ "GeneratedTestsV1_0.cpp", + "ValidateBurst.cpp", ], cflags: [ "-DNN_TEST_DYNAMIC_OUTPUT_SHAPE" @@ -32,6 +33,7 @@ cc_test { defaults: ["VtsHalNeuralNetworksTargetTestDefaults"], srcs: [ "GeneratedTestsV1_1.cpp", + "ValidateBurst.cpp", ], cflags: [ "-DNN_TEST_DYNAMIC_OUTPUT_SHAPE" @@ -46,6 +48,7 @@ cc_test { "BasicTests.cpp", "CompilationCachingTests.cpp", "GeneratedTests.cpp", + "ValidateBurst.cpp", ], cflags: [ "-DNN_TEST_DYNAMIC_OUTPUT_SHAPE" @@ -58,6 +61,7 @@ cc_test { srcs: [ "BasicTests.cpp", "GeneratedTests.cpp", + "ValidateBurst.cpp", ], cflags: [ "-DNN_TEST_DYNAMIC_OUTPUT_SHAPE", diff --git a/neuralnetworks/1.2/vts/functional/ValidateBurst.cpp b/neuralnetworks/1.2/vts/functional/ValidateBurst.cpp new file mode 100644 index 0000000000..386c141f80 --- /dev/null +++ b/neuralnetworks/1.2/vts/functional/ValidateBurst.cpp @@ -0,0 +1,333 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define LOG_TAG "neuralnetworks_hidl_hal_test" + +#include "VtsHalNeuralnetworks.h" + +#include "Callbacks.h" +#include "ExecutionBurstController.h" +#include "ExecutionBurstServer.h" +#include "TestHarness.h" +#include "Utils.h" + +#include + +namespace android { +namespace hardware { +namespace neuralnetworks { +namespace V1_2 { +namespace vts { +namespace functional { + +using ::android::nn::ExecutionBurstController; +using ::android::nn::RequestChannelSender; +using ::android::nn::ResultChannelReceiver; +using ExecutionBurstCallback = ::android::nn::ExecutionBurstController::ExecutionBurstCallback; + +constexpr size_t kExecutionBurstChannelLength = 1024; +constexpr size_t kExecutionBurstChannelSmallLength = 8; + +///////////////////////// UTILITY FUNCTIONS ///////////////////////// + +static bool badTiming(Timing timing) { + return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX; +} + +static void createBurst(const sp& preparedModel, const sp& callback, + std::unique_ptr* sender, + std::unique_ptr* receiver, + sp* context) { + ASSERT_NE(nullptr, preparedModel.get()); + ASSERT_NE(nullptr, sender); + ASSERT_NE(nullptr, receiver); + ASSERT_NE(nullptr, context); + + // create FMQ objects + auto [fmqRequestChannel, fmqRequestDescriptor] = + RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true); + auto [fmqResultChannel, fmqResultDescriptor] = + ResultChannelReceiver::create(kExecutionBurstChannelLength, /*blocking=*/true); + ASSERT_NE(nullptr, fmqRequestChannel.get()); + ASSERT_NE(nullptr, fmqResultChannel.get()); + ASSERT_NE(nullptr, fmqRequestDescriptor); + ASSERT_NE(nullptr, fmqResultDescriptor); + + // configure burst + ErrorStatus errorStatus; + sp burstContext; + const Return ret = preparedModel->configureExecutionBurst( + callback, *fmqRequestDescriptor, *fmqResultDescriptor, + [&errorStatus, &burstContext](ErrorStatus status, const sp& context) { + errorStatus = status; + burstContext = context; + }); + ASSERT_TRUE(ret.isOk()); + ASSERT_EQ(ErrorStatus::NONE, errorStatus); + ASSERT_NE(nullptr, burstContext.get()); + + // return values + *sender = std::move(fmqRequestChannel); + *receiver = std::move(fmqResultChannel); + *context = burstContext; +} + +static void createBurstWithResultChannelLength( + const sp& preparedModel, + std::shared_ptr* controller, size_t resultChannelLength) { + ASSERT_NE(nullptr, preparedModel.get()); + ASSERT_NE(nullptr, controller); + + // create FMQ objects + auto [fmqRequestChannel, fmqRequestDescriptor] = + RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true); + auto [fmqResultChannel, fmqResultDescriptor] = + ResultChannelReceiver::create(resultChannelLength, /*blocking=*/true); + ASSERT_NE(nullptr, fmqRequestChannel.get()); + ASSERT_NE(nullptr, fmqResultChannel.get()); + ASSERT_NE(nullptr, fmqRequestDescriptor); + ASSERT_NE(nullptr, fmqResultDescriptor); + + // configure burst + sp callback = new ExecutionBurstCallback(); + ErrorStatus errorStatus; + sp burstContext; + const Return ret = preparedModel->configureExecutionBurst( + callback, *fmqRequestDescriptor, *fmqResultDescriptor, + [&errorStatus, &burstContext](ErrorStatus status, const sp& context) { + errorStatus = status; + burstContext = context; + }); + ASSERT_TRUE(ret.isOk()); + ASSERT_EQ(ErrorStatus::NONE, errorStatus); + ASSERT_NE(nullptr, burstContext.get()); + + // return values + *controller = std::make_shared( + std::move(fmqRequestChannel), std::move(fmqResultChannel), burstContext, callback); +} + +// Primary validation function. This function will take a valid serialized +// request, apply a mutation to it to invalidate the serialized request, then +// pass it to interface calls that use the serialized request. Note that the +// serialized request here is passed by value, and any mutation to the +// serialized request does not leave this function. +static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver, + const std::string& message, std::vector serialized, + const std::function*)>& mutation) { + mutation(&serialized); + + // skip if packet is too large to send + if (serialized.size() > kExecutionBurstChannelLength) { + return; + } + + SCOPED_TRACE(message); + + // send invalid packet + sender->sendPacket(serialized); + + // receive error + auto results = receiver->getBlocking(); + ASSERT_TRUE(results.has_value()); + const auto [status, outputShapes, timing] = std::move(*results); + EXPECT_NE(ErrorStatus::NONE, status); + EXPECT_EQ(0u, outputShapes.size()); + EXPECT_TRUE(badTiming(timing)); +} + +static std::vector createUniqueDatum() { + const FmqRequestDatum::PacketInformation packetInformation = { + /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10, + /*.numberOfPools=*/10}; + const FmqRequestDatum::OperandInformation operandInformation = { + /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10}; + const int32_t invalidPoolIdentifier = std::numeric_limits::max(); + std::vector unique(7); + unique[0].packetInformation(packetInformation); + unique[1].inputOperandInformation(operandInformation); + unique[2].inputOperandDimensionValue(0); + unique[3].outputOperandInformation(operandInformation); + unique[4].outputOperandDimensionValue(0); + unique[5].poolIdentifier(invalidPoolIdentifier); + unique[6].measureTiming(MeasureTiming::YES); + return unique; +} + +static const std::vector& getUniqueDatum() { + static const std::vector unique = createUniqueDatum(); + return unique; +} + +///////////////////////// REMOVE DATUM //////////////////////////////////// + +static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, + const std::vector& serialized) { + for (size_t index = 0; index < serialized.size(); ++index) { + const std::string message = "removeDatum: removed datum at index " + std::to_string(index); + validate(sender, receiver, message, serialized, + [index](std::vector* serialized) { + serialized->erase(serialized->begin() + index); + }); + } +} + +///////////////////////// ADD DATUM //////////////////////////////////// + +static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, + const std::vector& serialized) { + const std::vector& extra = getUniqueDatum(); + for (size_t index = 0; index <= serialized.size(); ++index) { + for (size_t type = 0; type < extra.size(); ++type) { + const std::string message = "addDatum: added datum type " + std::to_string(type) + + " at index " + std::to_string(index); + validate(sender, receiver, message, serialized, + [index, type, &extra](std::vector* serialized) { + serialized->insert(serialized->begin() + index, extra[type]); + }); + } + } +} + +///////////////////////// MUTATE DATUM //////////////////////////////////// + +static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) { + using Discriminator = FmqRequestDatum::hidl_discriminator; + + const bool differentValues = (lhs != rhs); + const bool sameSumType = (lhs.getDiscriminator() == rhs.getDiscriminator()); + const auto discriminator = rhs.getDiscriminator(); + const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue || + discriminator == Discriminator::outputOperandDimensionValue); + + return differentValues && !(sameSumType && isDimensionValue); +} + +static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, + const std::vector& serialized) { + const std::vector& change = getUniqueDatum(); + for (size_t index = 0; index < serialized.size(); ++index) { + for (size_t type = 0; type < change.size(); ++type) { + if (interestingCase(serialized[index], change[type])) { + const std::string message = "mutateDatum: changed datum at index " + + std::to_string(index) + " to datum type " + + std::to_string(type); + validate(sender, receiver, message, serialized, + [index, type, &change](std::vector* serialized) { + (*serialized)[index] = change[type]; + }); + } + } + } +} + +///////////////////////// BURST VALIATION TESTS //////////////////////////////////// + +static void validateBurstSerialization(const sp& preparedModel, + const std::vector& requests) { + // create burst + std::unique_ptr sender; + std::unique_ptr receiver; + sp callback = new ExecutionBurstCallback(); + sp context; + ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context)); + ASSERT_NE(nullptr, sender.get()); + ASSERT_NE(nullptr, receiver.get()); + ASSERT_NE(nullptr, context.get()); + + // validate each request + for (const Request& request : requests) { + // load memory into callback slots + std::vector keys(request.pools.size()); + for (size_t i = 0; i < keys.size(); ++i) { + keys[i] = reinterpret_cast(&request.pools[i]); + } + const std::vector slots = callback->getSlots(request.pools, keys); + + // ensure slot std::numeric_limits::max() doesn't exist (for + // subsequent slot validation testing) + const auto maxElement = std::max_element(slots.begin(), slots.end()); + ASSERT_NE(slots.end(), maxElement); + ASSERT_NE(std::numeric_limits::max(), *maxElement); + + // serialize the request + const auto serialized = ::android::nn::serialize(request, MeasureTiming::YES, slots); + + // validations + removeDatumTest(sender.get(), receiver.get(), serialized); + addDatumTest(sender.get(), receiver.get(), serialized); + mutateDatumTest(sender.get(), receiver.get(), serialized); + } +} + +static void validateBurstFmqLength(const sp& preparedModel, + const std::vector& requests) { + // create regular burst + std::shared_ptr controllerRegular; + ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(preparedModel, &controllerRegular, + kExecutionBurstChannelLength)); + ASSERT_NE(nullptr, controllerRegular.get()); + + // create burst with small output channel + std::shared_ptr controllerSmall; + ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(preparedModel, &controllerSmall, + kExecutionBurstChannelSmallLength)); + ASSERT_NE(nullptr, controllerSmall.get()); + + // validate each request + for (const Request& request : requests) { + // load memory into callback slots + std::vector keys(request.pools.size()); + for (size_t i = 0; i < keys.size(); ++i) { + keys[i] = reinterpret_cast(&request.pools[i]); + } + + // collect serialized result by running regular burst + const auto [status1, outputShapes1, timing1] = + controllerRegular->compute(request, MeasureTiming::NO, keys); + + // skip test if synchronous output isn't useful + const std::vector serialized = + ::android::nn::serialize(status1, outputShapes1, timing1); + if (status1 != ErrorStatus::NONE || + serialized.size() <= kExecutionBurstChannelSmallLength) { + continue; + } + + // by this point, execution should fail because the result channel isn't + // large enough to return the serialized result + const auto [status2, outputShapes2, timing2] = + controllerSmall->compute(request, MeasureTiming::NO, keys); + EXPECT_NE(ErrorStatus::NONE, status2); + EXPECT_EQ(0u, outputShapes2.size()); + EXPECT_TRUE(badTiming(timing2)); + } +} + +///////////////////////////// ENTRY POINT ////////////////////////////////// + +void ValidationTest::validateBurst(const sp& preparedModel, + const std::vector& requests) { + ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, requests)); + ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, requests)); +} + +} // namespace functional +} // namespace vts +} // namespace V1_2 +} // namespace neuralnetworks +} // namespace hardware +} // namespace android diff --git a/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp b/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp index 870d01748a..9703c2d765 100644 --- a/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp +++ b/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp @@ -35,9 +35,7 @@ namespace vts { namespace functional { using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback; -using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback; using ::android::hidl::memory::V1_0::IMemory; -using HidlToken = hidl_array(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>; using test_helper::for_all; using test_helper::MixedTyped; using test_helper::MixedTypedExample; @@ -48,55 +46,6 @@ static bool badTiming(Timing timing) { return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX; } -static void createPreparedModel(const sp& device, const Model& model, - sp* preparedModel) { - ASSERT_NE(nullptr, preparedModel); - - // see if service can handle model - bool fullySupportsModel = false; - Return supportedOpsLaunchStatus = device->getSupportedOperations_1_2( - model, [&fullySupportsModel](ErrorStatus status, const hidl_vec& supported) { - ASSERT_EQ(ErrorStatus::NONE, status); - ASSERT_NE(0ul, supported.size()); - fullySupportsModel = - std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; }); - }); - ASSERT_TRUE(supportedOpsLaunchStatus.isOk()); - - // launch prepare model - sp preparedModelCallback = new PreparedModelCallback(); - ASSERT_NE(nullptr, preparedModelCallback.get()); - Return prepareLaunchStatus = device->prepareModel_1_2( - model, ExecutionPreference::FAST_SINGLE_ANSWER, hidl_vec(), - hidl_vec(), HidlToken(), preparedModelCallback); - ASSERT_TRUE(prepareLaunchStatus.isOk()); - ASSERT_EQ(ErrorStatus::NONE, static_cast(prepareLaunchStatus)); - - // retrieve prepared model - preparedModelCallback->wait(); - ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus(); - *preparedModel = getPreparedModel_1_2(preparedModelCallback); - - // The getSupportedOperations_1_2 call returns a list of operations that are - // guaranteed not to fail if prepareModel_1_2 is called, and - // 'fullySupportsModel' is true i.f.f. the entire model is guaranteed. - // If a driver has any doubt that it can prepare an operation, it must - // return false. So here, if a driver isn't sure if it can support an - // operation, but reports that it successfully prepared the model, the test - // can continue. - if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) { - ASSERT_EQ(nullptr, preparedModel->get()); - LOG(INFO) << "NN VTS: Unable to test Request validation because vendor service cannot " - "prepare model that it does not support."; - std::cout << "[ ] Unable to test Request validation because vendor service " - "cannot prepare model that it does not support." - << std::endl; - return; - } - ASSERT_EQ(ErrorStatus::NONE, prepareReturnStatus); - ASSERT_NE(nullptr, preparedModel->get()); -} - // Primary validation function. This function will take a valid request, apply a // mutation to it to invalidate the request, then pass it to interface calls // that use the request. Note that the request here is passed by value, and any @@ -316,14 +265,8 @@ std::vector createRequests(const std::vector& exampl return requests; } -void ValidationTest::validateRequests(const Model& model, const std::vector& requests) { - // create IPreparedModel - sp preparedModel; - ASSERT_NO_FATAL_FAILURE(createPreparedModel(device, model, &preparedModel)); - if (preparedModel == nullptr) { - return; - } - +void ValidationTest::validateRequests(const sp& preparedModel, + const std::vector& requests) { // validate each request for (const Request& request : requests) { removeInputTest(preparedModel, request); diff --git a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp index 4728c28e87..93182f1da2 100644 --- a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp +++ b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp @@ -18,6 +18,10 @@ #include "VtsHalNeuralnetworks.h" +#include + +#include "Callbacks.h" + namespace android { namespace hardware { namespace neuralnetworks { @@ -25,6 +29,61 @@ namespace V1_2 { namespace vts { namespace functional { +using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback; +using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback; +using HidlToken = hidl_array(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>; +using V1_1::ExecutionPreference; + +// internal helper function +static void createPreparedModel(const sp& device, const Model& model, + sp* preparedModel) { + ASSERT_NE(nullptr, preparedModel); + + // see if service can handle model + bool fullySupportsModel = false; + Return supportedOpsLaunchStatus = device->getSupportedOperations_1_2( + model, [&fullySupportsModel](ErrorStatus status, const hidl_vec& supported) { + ASSERT_EQ(ErrorStatus::NONE, status); + ASSERT_NE(0ul, supported.size()); + fullySupportsModel = std::all_of(supported.begin(), supported.end(), + [](bool valid) { return valid; }); + }); + ASSERT_TRUE(supportedOpsLaunchStatus.isOk()); + + // launch prepare model + sp preparedModelCallback = new PreparedModelCallback(); + ASSERT_NE(nullptr, preparedModelCallback.get()); + Return prepareLaunchStatus = device->prepareModel_1_2( + model, ExecutionPreference::FAST_SINGLE_ANSWER, hidl_vec(), + hidl_vec(), HidlToken(), preparedModelCallback); + ASSERT_TRUE(prepareLaunchStatus.isOk()); + ASSERT_EQ(ErrorStatus::NONE, static_cast(prepareLaunchStatus)); + + // retrieve prepared model + preparedModelCallback->wait(); + ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus(); + *preparedModel = getPreparedModel_1_2(preparedModelCallback); + + // The getSupportedOperations_1_2 call returns a list of operations that are + // guaranteed not to fail if prepareModel_1_2 is called, and + // 'fullySupportsModel' is true i.f.f. the entire model is guaranteed. + // If a driver has any doubt that it can prepare an operation, it must + // return false. So here, if a driver isn't sure if it can support an + // operation, but reports that it successfully prepared the model, the test + // can continue. + if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) { + ASSERT_EQ(nullptr, preparedModel->get()); + LOG(INFO) << "NN VTS: Unable to test Request validation because vendor service cannot " + "prepare model that it does not support."; + std::cout << "[ ] Unable to test Request validation because vendor service " + "cannot prepare model that it does not support." + << std::endl; + return; + } + ASSERT_EQ(ErrorStatus::NONE, prepareReturnStatus); + ASSERT_NE(nullptr, preparedModel->get()); +} + // A class for test environment setup NeuralnetworksHidlEnvironment::NeuralnetworksHidlEnvironment() {} @@ -68,6 +127,20 @@ void NeuralnetworksHidlTest::TearDown() { ::testing::VtsHalHidlTargetTestBase::TearDown(); } +void ValidationTest::validateEverything(const Model& model, const std::vector& request) { + validateModel(model); + + // create IPreparedModel + sp preparedModel; + ASSERT_NO_FATAL_FAILURE(createPreparedModel(device, model, &preparedModel)); + if (preparedModel == nullptr) { + return; + } + + validateRequests(preparedModel, request); + validateBurst(preparedModel, request); +} + sp getPreparedModel_1_2( const sp& callback) { sp preparedModelV1_0 = callback->getPreparedModel(); diff --git a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h index 404eec06db..36e73a4fb0 100644 --- a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h +++ b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h @@ -72,8 +72,14 @@ class NeuralnetworksHidlTest : public ::testing::VtsHalHidlTargetTestBase { // Tag for the validation tests class ValidationTest : public NeuralnetworksHidlTest { protected: - void validateModel(const Model& model); - void validateRequests(const Model& model, const std::vector& request); + void validateEverything(const Model& model, const std::vector& request); + + private: + void validateModel(const Model& model); + void validateRequests(const sp& preparedModel, + const std::vector& requests); + void validateBurst(const sp& preparedModel, + const std::vector& requests); }; // Tag for the generated tests