딥러닝 모델의 성능을 최적화하고 배포하는 것은 현대 AI 애플리케이션의 핵심 요소 중 하나입니다. 그 중에서도 모델을 효율적으로 관리하고 운영하는 것은 매우 중요합니다. Triton Inference Server는 이러한 요구에 부응하기 위해 설계된 강력한 도구입니다. 이번 시리즈에서는 Triton Inference Server에 대해 샅샅히 파헤처보는 글을 작성해보려고 합니다.
Triton Inference Server란?
Triton Inference Server는 NVIDIA에서 개발한 고성능 딥러닝 모델 추론(예측)을 위한 오픈 소스 솔루션입니다. 이 서버는 다양한 딥러닝 프레임워크에서 학습된 모델을 추론할 수 있도록 설계되었으며, 모델 배포와 관리를 위한 편리한 기능을 제공합니다.
이 서버는 TensorFlow, PyTorch, ONNX, TensorRT 등 다양한 프레임워크로 학습된 모델을 지원합니다. 또한 여러 가지 디바이스와 백엔드 엔진을 활용하여 최적의 추론 성능을 제공합니다. Triton은 CPU, GPU, TensorRT 및 TensorRT Inference Server, NVIDIA Deep Learning Accelerator (NVDLA) 및 NVIDIA TensorRT를 지원하여 사용자가 특정 환경 또는 요구 사항에 따라 최적의 추론 설정을 선택할 수 있습니다.
또한 Triton은 다양한 배포 환경에서 확장성을 제공합니다. Kubernetes, Docker 및 NVIDIA GPU Cloud (NGC)와 같은 컨테이너 오케스트레이션 플랫폼과 통합되어 있어 모델을 쉽게 배포하고 관리할 수 있습니다. 이를 통해 개발자 및 연구자들은 모델의 배포 및 관리 과정에서 생산성을 향상시키고, 효율성을 극대화할 수 있습니다.
Triton Architecture
위 그림은 Triton 추론 서버의 아키텍처를 보여줍니다. Model Repository는 Triton이 추론에 사용할 모델들의 파일 시스템 기반 저장소입니다. Inference Request는 HTTP/REST 또는 GRPC를 통해 서버로 전송되거나 C API를 통해 도착하고, 그런 다음 적절한 모델별 스케줄러로 라우팅됩니다.
Triton은 모델별로 구성할 수 있는 여러 가지 스케줄링 및 배치 알고리즘이 구현되어 있습니다. 각 모델의 스케줄러는 선택적으로 추론 요청을 배치 처리하고, 그런 다음 모델 유형에 해당하는 백엔드로 요청을 전달합니다. 백엔드는 배치 처리된 요청에서 제공된 입력을 사용하여 요청된 출력을 생성합니다. 그런 다음 출력이 반환됩니다.
Triton Inference Server repo 살펴보기
그러면 이번에는 실제로 Triton Inference Server가 어떻게 구현되어 있는지 확인해봅니다. Triton Inference Server는 core, server, backends 의 여러개의 repo로 분리되어 있습니다.
core repo에는 Triton의 핵심 기능을 구현하는 라이브러리의 소스 코드와 헤더가 들어 있습니다. core 라이브러리는 C API를 통해 직접 사용할 수 있습니다. 유용하게 사용하려면 핵심 라이브러리가 하나 이상의 백엔드와 쌍을 이루어야 합니다.
일반적으로 core 라이브러리를 자체적으로 구축하거나 사용하지 않고 tritonserver 실행 파일의 일부로 사용합니다. tritonserver 실행 파일은 server repo에 빌드됩니다.
다른 많은 소스들이 있지만 먼저 tritonserver의 실행을 담당하고 있는 server repo에서 main.cc의 main함수를 확인해보면 아래와 같이 구현되어 있는 것을 확인할 수 있습니다.
Triton Server의 옵션 파라미터를 파싱합니다.
triton::server::TritonParser tp;
try {
auto res = tp.Parse(argc, argv);
g_triton_params = res.first;
g_triton_params.CheckPortCollision();
}
catch (const triton::server::ParseException& pe) {
std::cerr << pe.what() << std::endl;
std::cerr << "Usage: tritonserver [options]" << std::endl;
std::cerr << tp.Usage() << std::endl;
exit(1);
}
triton::server::TritonServerParameters::ManagedTritonServerOptionPtr
triton_options(nullptr, TRITONSERVER_ServerOptionsDelete);
try {
triton_options = g_triton_params.BuildTritonServerOptions();
}
catch (const triton::server::ParseException& pe) {
std::cerr << "Failed to build Triton option:" << std::endl;
std::cerr << pe.what() << std::endl;
exit(1);
}
다음으로는 위에서 파싱한 parameter값을 참고하여 로깅에 대한 설정을 진행합니다.
#ifdef TRITON_ENABLE_LOGGING
// Initialize our own logging instance since it is used by GRPC and
// HTTP endpoints. This logging instance is separate from the one in
// libtritonserver so we must initialize explicitly.
LOG_ENABLE_INFO(g_triton_params.log_info_);
LOG_ENABLE_WARNING(g_triton_params.log_warn_);
LOG_ENABLE_ERROR(g_triton_params.log_error_);
LOG_SET_VERBOSE(g_triton_params.log_verbose_);
LOG_SET_FORMAT(g_triton_params.log_format_);
LOG_SET_OUT_FILE(g_triton_params.log_file_);
#endif // TRITON_ENABLE_LOGGING
이 후, Trace와 Shared Memory를 관리하는 Manager들을 선언하거나 초기화합니다.
// Trace manager.
triton::server::TraceManager* trace_manager;
// Manager for shared memory blocks.
auto shm_manager = std::make_shared<triton::server::SharedMemoryManager>();
다음으로는 API 서버들을 초기화 합니다.
// Create the server...
TRITONSERVER_Server* server_ptr = nullptr;
FAIL_IF_ERR(
TRITONSERVER_ServerNew(&server_ptr, triton_options.get()),
"creating server");
std::shared_ptr<TRITONSERVER_Server> server(
server_ptr, TRITONSERVER_ServerDelete);
API 서버들이 초기화 됐다면, Trace를 시작합니다. 이 때 StartTracing 함수에서 TraceManager를 초기화합니다.
// Configure and start tracing if specified on the command line.
if (!StartTracing(&trace_manager)) {
exit(1);
}
bool
StartTracing(triton::server::TraceManager** trace_manager)
{
*trace_manager = nullptr;
#ifdef TRITON_ENABLE_TRACING
TRITONSERVER_Error* err = triton::server::TraceManager::Create(
trace_manager, g_triton_params.trace_level_, g_triton_params.trace_rate_,
g_triton_params.trace_count_, g_triton_params.trace_log_frequency_,
g_triton_params.trace_filepath_, g_triton_params.trace_mode_,
g_triton_params.trace_config_map_);
if (err != nullptr) {
LOG_TRITONSERVER_ERROR(err, "failed to configure tracing");
if (*trace_manager != nullptr) {
delete (*trace_manager);
}
*trace_manager = nullptr;
return false;
}
#endif // TRITON_ENABLE_TRACING
return true;
}
이 후 SIGINT와 SIGTERM 사인을 핸들링 할 수 있도록 signal handler를 등록합니다. 이 기능은 SIGINT 또는 SIGTERM이 들어왔을 때, 서버를 gracefully 셧다운하도록 작성되어 있습니다.
// Trap SIGINT and SIGTERM to allow server to exit gracefully
TRITONSERVER_Error* signal_err = triton::server::RegisterSignalHandler();
if (signal_err != nullptr) {
LOG_TRITONSERVER_ERROR(signal_err, "failed to register signal handler");
exit(1);
}
다음으로는 위에서 초기화한 API 서버들과 trace manager, shm manager를 활용하여 Endpoint를 실행합니다.
// Start the HTTP, GRPC, and metrics endpoints.
if (!StartEndpoints(server, trace_manager, shm_manager)) {
exit(1);
}
그 중에서 HTTP와 관련된 Service만 살펴보면, StartHttpService 함수를 통해 HTTPAPIServer를 생성하고, 실행합니다.
#ifdef TRITON_ENABLE_HTTP
// Enable HTTP endpoints if requested...
if (g_triton_params.allow_http_) {
TRITONSERVER_Error* err =
StartHttpService(&g_http_service, server, trace_manager, shm_manager);
if (err != nullptr) {
LOG_TRITONSERVER_ERROR(err, "failed to start HTTP service");
return false;
}
}
#endif // TRITON_ENABLE_HTTP
#ifdef TRITON_ENABLE_HTTP
TRITONSERVER_Error*
StartHttpService(
std::unique_ptr<triton::server::HTTPServer>* service,
const std::shared_ptr<TRITONSERVER_Server>& server,
triton::server::TraceManager* trace_manager,
const std::shared_ptr<triton::server::SharedMemoryManager>& shm_manager)
{
TRITONSERVER_Error* err = triton::server::HTTPAPIServer::Create(
server, trace_manager, shm_manager, g_triton_params.http_port_,
g_triton_params.reuse_http_port_, g_triton_params.http_address_,
g_triton_params.http_forward_header_pattern_,
g_triton_params.http_thread_cnt_, g_triton_params.http_restricted_apis_,
service);
if (err == nullptr) {
err = (*service)->Start();
}
if (err != nullptr) {
service->reset();
}
return err;
}
#endif // TRITON_ENABLE_HTTP
HTTPServer::Start()
{
if (!worker_.joinable()) {
evbase_ = event_base_new();
htp_ = evhtp_new(evbase_, NULL);
evhtp_enable_flag(htp_, EVHTP_FLAG_ENABLE_NODELAY);
if (reuse_port_) {
evhtp_enable_flag(htp_, EVHTP_FLAG_ENABLE_REUSEPORT);
}
evhtp_set_gencb(htp_, HTTPServer::Dispatch, this);
evhtp_set_pre_accept_cb(htp_, HTTPServer::NewConnection, this);
evhtp_use_threads_wexit(htp_, NULL, NULL, thread_cnt_, NULL);
if (evhtp_bind_socket(htp_, address_.c_str(), port_, 1024) != 0) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNAVAILABLE,
(std::string("Socket '") + address_ + ":" + std::to_string(port_) +
"' already in use ")
.c_str());
}
// Set listening event for breaking event loop
evutil_socketpair(AF_UNIX, SOCK_STREAM, 0, fds_);
break_ev_ = event_new(evbase_, fds_[0], EV_READ, StopCallback, evbase_);
event_add(break_ev_, NULL);
worker_ = std::thread(event_base_loop, evbase_, 0);
return nullptr;
}
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_ALREADY_EXISTS, "HTTP server is already running.");
}
이 후에는 exit signal이 발생할 때까지 서버로직을 수행합니다. Triton Inference Server는 Model Managment mode가 POLL일 경우, model repository를 주기적으로 Polling하여 동기화합니다. 해당 로직을 아래 loop에서 처리합니다.
// Wait until a signal terminates the server...
while (!triton::server::signal_exiting_) {
// If enabled, poll the model repository to see if there have been
// any changes.
if (g_triton_params.repository_poll_secs_ > 0) {
LOG_TRITONSERVER_ERROR(
TRITONSERVER_ServerPollModelRepository(server_ptr),
"failed to poll model repository");
}
// Wait for the polling interval (or a long time if polling is not
// enabled). Will be woken if the server is exiting.
std::unique_lock<std::mutex> lock(triton::server::signal_exit_mu_);
std::chrono::seconds wait_timeout(
(g_triton_params.repository_poll_secs_ == 0)
? 3600
: g_triton_params.repository_poll_secs_);
triton::server::signal_exit_cv_.wait_for(lock, wait_timeout);
}
위 loop를 빠져나오기 위해서는 서버가 종료되어야 합니다. 따라서 이 후 로직들은 서버를 정리하는 코드들로 이루어져 있습니다.
// Stop the HTTP[, gRPC, and metrics] endpoints, and update exit timeout.
uint32_t exit_timeout_secs = g_triton_params.exit_timeout_secs_;
StopEndpoints(&exit_timeout_secs);
TRITONSERVER_ServerSetExitTimeout(server_ptr, exit_timeout_secs);
TRITONSERVER_Error* stop_err = TRITONSERVER_ServerStop(server_ptr);
// If unable to gracefully stop the server then Triton threads and
// state are potentially in an invalid state, so just exit
// immediately.
if (stop_err != nullptr) {
LOG_TRITONSERVER_ERROR(stop_err, "failed to stop server");
exit(1);
}
// Stop gRPC and metrics endpoints that do not yet support exit timeout.
StopEndpoints();
// Stop tracing.
StopTracing(&trace_manager);
#ifdef TRITON_ENABLE_ASAN
// Can invoke ASAN before exit though this is typically not very
// useful since there are many objects that are not yet destructed.
// __lsan_do_leak_check();
#endif // TRITON_ENABLE_ASAN
return 0;
Triton API Server ↔️ Model Backend
API 서버가 뜨는 방식은 알았지만, 이제 Model Backend랑 어떻게 통신을 하는지 궁금해집니다. 그러기 위해서 HTTPAPIServer의 Handle 함수를 파악해봅니다. HTTPAPIServer::Handle 함수는 url path에 따라 여러 역할을 가진 함수를 수행하도록 구성되어 있습니다. 여기서는 generate를 중심으로 파악해 보겠습니다.
void
HTTPAPIServer::Handle(evhtp_request_t* req)
{
...
std::string model_name, version, kind;
if (RE2::FullMatch(
std::string(req->uri->path->full), model_regex_, &model_name,
&version, &kind)) {
if (kind == "ready") {
// model ready
HandleModelReady(req, model_name, version);
return;
} else if (kind == "infer") {
// model infer
HandleInfer(req, model_name, version);
return;
} else if (kind == "generate") {
// text generation
HandleGenerate(req, model_name, version, false /* streaming */);
return;
} else if (kind == "generate_stream") {
// text generation (streaming)
HandleGenerate(req, model_name, version, true /* streaming */);
return;
} else if (kind == "config") {
// model configuration
HandleModelConfig(req, model_name, version);
return;
} else if (kind == "stats") {
// model statistics
HandleModelStats(req, model_name, version);
return;
} else if (kind == "trace/setting") {
// Trace with specific model, there is no specification on versioning
// so fall out and return bad request error if version is specified
if (version.empty()) {
HandleTrace(req, model_name);
return;
}
} else if (kind == "") {
// model metadata
HandleModelMetadata(req, model_name, version);
return;
}
}
HTTPAPIServer:HandleGenerate 함수는 아주 간략하게 요약해서 아래와 같은 처리를 수행합니다.
- trace를 초기화 합니다.
- request 정보를 request buffer에 기록합니다.
- TRITONSERVER_ServerInferAsync를 실행하여 Triton에서 Inference request를 수행하도록 요청합니다.
void
HTTPAPIServer::HandleGenerate(
evhtp_request_t* req, const std::string& model_name,
const std::string& model_version_str, bool streaming)
{
...
// If tracing is enabled see if this request should be traced.
TRITONSERVER_InferenceTrace* triton_trace = nullptr;
std::shared_ptr<TraceManager::Trace> trace =
StartTrace(req, model_name, &triton_trace);
...
triton::common::TritonJson::Value request;
RETURN_AND_CALLBACK_IF_ERR(EVRequestToJson(req, &request), error_callback);
RETURN_AND_CALLBACK_IF_ERR(
generate_request->ConvertGenerateRequest(
input_metadata, generate_request->RequestSchema(), request),
error_callback);
...
RETURN_AND_CALLBACK_IF_ERR(
TRITONSERVER_ServerInferAsync(server_.get(), irequest, triton_trace),
error_callback);
...
}
server repo에서 TRITONSERVER_ServerInferAsync로 요청된 것은 core repo의 아래 내용에 의해 처리됩니다.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONSERVER_ServerInferAsync(
TRITONSERVER_Server* server,
TRITONSERVER_InferenceRequest* inference_request,
TRITONSERVER_InferenceTrace* trace)
{
tc::InferenceServer* lserver = reinterpret_cast<tc::InferenceServer*>(server);
tc::InferenceRequest* lrequest =
reinterpret_cast<tc::InferenceRequest*>(inference_request);
RETURN_IF_STATUS_ERROR(lrequest->PrepareForInference());
// Set the trace object in the request so that activity associated
// with the request can be recorded as the request flows through
// Triton.
if (trace != nullptr) {
#ifdef TRITON_ENABLE_TRACING
tc::InferenceTrace* ltrace = reinterpret_cast<tc::InferenceTrace*>(trace);
ltrace->SetModelName(lrequest->ModelName());
ltrace->SetModelVersion(lrequest->ActualModelVersion());
ltrace->SetRequestId(lrequest->Id());
lrequest->SetTrace(std::make_shared<tc::InferenceTraceProxy>(ltrace));
#else
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNSUPPORTED, "inference tracing not supported");
#endif // TRITON_ENABLE_TRACING
}
// We wrap the request in a unique pointer to ensure that it flows
// through inferencing with clear ownership.
std::unique_ptr<tc::InferenceRequest> ureq(lrequest);
// Run inference...
tc::Status status = lserver->InferAsync(ureq);
// If there is an error then must explicitly release any trace
// object associated with the inference request above.
#ifdef TRITON_ENABLE_TRACING
if (!status.IsOk()) {
ureq->ReleaseTrace();
}
#endif // TRITON_ENABLE_TRACING
// If there is an error then ureq will still have 'lrequest' and we
// must release it from unique_ptr since the caller should retain
// ownership when there is error. If there is not an error then ureq
// == nullptr and so this release is a nop.
ureq.release();
RETURN_IF_STATUS_ERROR(status);
return nullptr; // Success
}
여기서 lserver->InferAsync는 끝까지 타고가다보면 InferenceRequest::Run을 호출하며, 이 함수에서는 request를 model에 Enqueue하는 동작을 수행합니다.
Status
InferenceRequest::Run(std::unique_ptr<InferenceRequest>& request)
{
RETURN_IF_ERROR(request->SetState(InferenceRequest::State::PENDING));
auto status = request->model_raw_->Enqueue(request);
if (!status.IsOk()) {
LOG_STATUS_ERROR(
request->SetState(InferenceRequest::State::FAILED_ENQUEUE),
"Failed to set failed_enqueue state");
}
return status;
}
'AI > MLOps' 카테고리의 다른 글
LitServe 리뷰 (1) | 2024.08.31 |
---|---|
Triton Inference Server #5. Python Backend (0) | 2024.05.12 |
Triton Inference Server #4. Model Configuration (0) | 2024.05.12 |
Triton Inference Server #3. Model Management & Repository (0) | 2024.05.12 |
Triton Inference Server #2. 모델 스케쥴링 (0) | 2024.04.23 |