Welcome back to our "Model Serving with FastAPI" course! In this second lesson, we're advancing our journey to build a robust diamond price prediction API. In our previous lesson, we established the foundation by creating a basic FastAPI application with a root endpoint and a health check. Today, we'll take a significant step forward by integrating our machine learning model into the API and creating an endpoint for making predictions.
By the end of this lesson, you'll have a functional prediction endpoint that validates input data, processes it through your machine learning model, and returns diamond price predictions to users. This represents the core functionality of our model serving application. Let's begin by understanding how to properly handle and validate input data for our machine learning model!
Before diving into code, let's explore an essential component for building robust APIs: data validation. When serving machine learning models, ensuring that input data meets your expectations is crucial for preventing errors, providing clear feedback, and maintaining data integrity.
FastAPI leverages Pydantic for data validation, which validates data using Python type annotations. Here's how we can define the expected structure of diamond feature inputs:
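A sketch of such a model, using field names from the classic diamonds dataset (carat, cut, color, clarity, depth, table, and the x/y/z dimensions); the exact fields and descriptions in your course code may differ:

```python
from pydantic import BaseModel, Field


class DiamondFeatures(BaseModel):
    """Validated input features for a diamond price prediction."""

    carat: float = Field(..., gt=0, description="Weight of the diamond in carats")
    cut: str = Field(..., description="Cut quality (e.g. Ideal, Premium, Good)")
    color: str = Field(..., description="Color grade, from D (best) to J (worst)")
    clarity: str = Field(..., description="Clarity grade (e.g. IF, VS1, SI2)")
    depth: float = Field(..., gt=0, description="Total depth percentage")
    table: float = Field(..., gt=0, description="Width of the top facet relative to the widest point")
    x: float = Field(..., gt=0, description="Length in mm")
    y: float = Field(..., gt=0, description="Width in mm")
    z: float = Field(..., gt=0, description="Depth in mm")
```

With this in place, a request body missing `carat`, or supplying a non-positive value for it, is rejected before your endpoint code ever runs.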
In this code, we're creating a `DiamondFeatures` class that inherits from Pydantic's `BaseModel`. Each attribute has a type annotation and additional validation rules. For example, `carat: float = Field(..., gt=0)` specifies that `carat` must be a positive floating-point number, and the ellipsis (`...`) indicates the field is required. The `description` parameter documents what each field represents in the auto-generated API docs.
We'll also create a model for our prediction response:
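A minimal sketch of that response model (the input model is abbreviated here to a single field so the snippet stands on its own; in the real app it's the full `DiamondFeatures` class):

```python
from pydantic import BaseModel, Field


# Abbreviated stand-in for the full input model defined earlier.
class DiamondFeatures(BaseModel):
    carat: float = Field(..., gt=0)


class PredictionResponse(BaseModel):
    """Schema for the prediction endpoint's response."""

    predicted_price: float = Field(..., description="Predicted price of the diamond")
    features: DiamondFeatures  # echo of the validated input, for reference
```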
This response model ensures a consistent output format that includes both the prediction and the original input features for reference. When you use these Pydantic models, FastAPI automatically validates incoming requests, converts data to the correct types, and generates helpful error messages when validation fails.
Now that you have your data validation set up, let's tackle one of the key challenges of model serving: efficiently loading and managing machine learning models. Since models can be memory-intensive, you need a strategy to handle them effectively in your API.
Let's implement a dependency function that loads our model and preprocessor:
This function performs several critical tasks: it verifies that the model files exist, loads them using the `load_model_with_metadata` helper function (which we implemented in the previous course), and handles any errors by converting them into appropriate HTTP responses. When your API receives a prediction request, this function provides the model resources needed to process it.
With our model loading functions in place, let's look at how to efficiently make these models available to your API endpoints using FastAPI's powerful dependency injection system. Dependency injection allows you to:
- Share resources between endpoints
- Manage expensive resources efficiently
- Simplify endpoint function signatures
- Facilitate testing through dependency overrides
Here's how you'll use it for your prediction endpoint:
Notice the `model_data: tuple = Depends(get_model)` parameter in our endpoint function. This tells FastAPI to call our `get_model()` function before executing the endpoint and pass its return value into the function. FastAPI also caches the dependency's result within a single request, so even if several dependencies in that request rely on `get_model`, the model is loaded only once.
This approach creates a clean separation of concerns: your model loading logic stays independent from your prediction logic, making your code more maintainable and testable.
Let's complete the implementation of the prediction endpoint:
This endpoint function:
- Receives validated diamond features through our Pydantic model.
- Gets the model and preprocessor through dependency injection.
- Converts the input to a `pandas` DataFrame (the format expected by most scikit-learn preprocessors).
- Applies the same preprocessing steps used during training.
- Makes a prediction using the model.
- Returns the predicted price along with the original features.
The `response_model=PredictionResponse` parameter ensures that FastAPI validates and formats our response according to the defined schema, keeping your API responses consistent.
A robust API needs comprehensive error handling to gracefully manage the various issues that can arise during model serving. Let's examine the error handling strategies implemented in your diamond price prediction API:
This multi-layered approach addresses three key types of errors:
- **Input validation errors:** FastAPI and Pydantic handle these automatically, returning detailed 422 Unprocessable Entity responses when diamond features don't meet your specifications.
- **Model availability errors:** Your dependency function checks whether the model files exist and returns a clear 404 Not Found response when they don't, helping API users understand that they need to train a model first.
- **Processing errors:** Your prediction endpoint catches any exceptions raised during preprocessing or prediction and returns a 500 Internal Server Error with details about what went wrong.
Congratulations! You've successfully transformed a basic FastAPI application into a functional machine learning service capable of predicting diamond prices. You've implemented data validation with Pydantic, created efficient model loading strategies through dependency injection, built a prediction endpoint that processes inputs and returns results, and established robust error handling to make your API reliable and user-friendly.
Now it's time to test what you learned with some hands-on practice. Happy coding!
