from typing import List
from typing import Optional
from typing import Union

from grpclib.client import Channel
from syncer import sync
from unacatlib.query.data_reference_list_plugin import wrap_data_references
from unacatlib.unacast.unatype import FilterClause
from unacatlib.unacast.catalog.v3 import (
CatalogQueryStub,
Options,
GetDataReferenceRequest,
ListDataReferencesRequest,
ListRecordsRequest,
ListRecordsResponse,
OrderBy,
OrderDirection,
SearchFieldValuesRequest,
SearchFieldValuesResponse,
TrimToPeriodOfExistence,
Field,
BoundingBox,
)
from unacatlib.unacast.catalog.v3alpha import (
CatalogQueryStub as CatalogQueryAlphaStub,
FileDownloadRequest,
FileDownloadResponse,
)
from unacatlib.query.proto_wrappers import (
PaginatedListRecordsResponse,
PaginatedSearchFieldValuesResponse,
GetDataReferenceResponse,
ListDataReferencesResponse,
ListRecordsStatisticsResponse,
)
import os
# Default host and port the client connects to when none are passed in.
SERVER_ADDRESS = "catalog.unacastapis.com"
PORT = 443

# Tags identifying this library. NOTE(review): not referenced in the visible
# part of this module - confirm whether anything still consumes it.
REQUEST_TAGS = dict(source="unacat-py")

# Upper bound on the page_size sent with each paginated request.
LIST_RECORDS_PAGE_SIZE = 1_000
SEARCH_FIELD_VALUES_PAGE_SIZE = 1_000

# Hard cap on how many records/values a single client call may ask for.
ABSOLUTE_MAX_NUMBER_OF_RECORDS = 100_000

# Shared tail appended to every "too many records/values" error message.
OVER_ABSOLUTE_MAX_NUMBER_OF_RECORDS_MESSAGE = (
    f"If you want to retrieve more than {ABSOLUTE_MAX_NUMBER_OF_RECORDS:,}, "
    "we suggest using the file export endpoint or setting up a file delivery. "
    "Please contact Unacast support if you need assistance at support@unacast.com."
)
# Public API of this module: the client itself plus the request-building
# types imported above from the unacatlib.unacast packages, re-exported so
# callers can import them from one place.
__all__ = [
    "QueryClient",
    "FilterClause",
    "TrimToPeriodOfExistence",
    "OrderBy",
    "OrderDirection",
    "BoundingBox",
]
[docs]
class QueryClient(object):
    """
    A client for querying the Unacast Catalog API.

    If no API key is provided, it will attempt to read the API key from the
    UNACAST_API_KEY environment variable. Construction raises ValueError when
    neither source yields a key. All requests go through a single channel
    opened with SSL enabled and carry the key as a bearer token.

    Example:
        >>> from unacatlib import QueryClient
        >>> client = QueryClient()  # API key can be provided as an argument or set via UNACAST_API_KEY env variable
        >>> response = client.list_records("ml_visitation.foot_traffic_month", limit=100)
        >>> df = response.records.to_df()
        >>> print(df.head())
    """
[docs]
def __init__(
    self,
    api_key: str = "",
    billing_context: str = "",
    server_address: str = SERVER_ADDRESS,
    port: int = PORT,
):
    """
    Set up the channel and service stubs used by every query method.

    Args:
        api_key: API key to authenticate with. When empty, the value of the
            UNACAST_API_KEY environment variable is used instead.
        billing_context: Billing context attached to each request this
            client builds.
        server_address: Host of the catalog service.
        port: Port of the catalog service.

    Raises:
        ValueError: If no API key is given and UNACAST_API_KEY is unset or
            empty.
    """
    # An explicit key wins; the environment is only the fallback.
    self.api_key = (
        api_key if api_key else os.environ.get("UNACAST_API_KEY", "")
    )
    if not self.api_key:
        raise ValueError(
            "No API key provided. Please provide an API key or set the UNACAST_API_KEY environment variable."
        )

    self.port = port
    self.server_address = server_address
    self.billing_context = billing_context

    # Both stubs share one channel and the same bearer-token metadata.
    self.channel = Channel(host=self.server_address, port=self.port, ssl=True)
    auth_metadata = [("authorization", "Bearer " + self.api_key)]
    self.query_service = CatalogQueryStub(self.channel, metadata=auth_metadata)
    self.query_service_alpha = CatalogQueryAlphaStub(
        self.channel, metadata=auth_metadata
    )
    # TODO: do we need properties for ssl as the old client has?
[docs]
def get_data_reference(self, data_reference_name: str) -> GetDataReferenceResponse:
    """
    Fetch a single data reference by its name.

    Args:
        data_reference_name: The name of the data reference to get

    Returns:
        GetDataReferenceResponse with data_reference.fields as FieldDefinitionList
    """
    lookup = GetDataReferenceRequest(
        billing_context=self.billing_context,
        data_reference_name=data_reference_name,
    )
    # The stub call is a coroutine; sync() runs it to completion.
    return sync(self.query_service.get_data_reference(lookup))
[docs]
def list_data_references(self) -> ListDataReferencesResponse:
    """
    List the data references available under this client's billing context.

    Returns:
        ListDataReferencesResponse with data_references as DataReferenceList
    """
    listing: ListDataReferencesResponse = sync(
        self.query_service.list_data_references(
            ListDataReferencesRequest(billing_context=self.billing_context)
        )
    )
    # The helper's return value is unused; presumably it adjusts the
    # response in place - confirm against wrap_data_references.
    wrap_data_references(listing)
    return listing
[docs]
def list_records(
    self,
    data_reference_name: str,
    fields: Optional[List[str]] = None,
    filters: Optional[List[FilterClause]] = None,
    options: Optional[Options] = None,
    limit: Optional[int] = None,
) -> PaginatedListRecordsResponse:
    """
    List records from a data reference with automatic pagination handling.

    Args:
        data_reference_name: The name of the data reference to query. Use `client.list_data_references().data_references.to_df()` to get the available `data_reference_name`.
        fields: Optional list of field names to include in the response. If None, all default fields are returned. Use `client.get_data_reference(data_reference_name).data_reference.fields.to_df()` to get the available `fields`.
        filters: Optional list of filters to apply. Use `client.get_data_reference(data_reference_name).data_reference.fields.to_df()` to get the available `field` and `operator` to use in the `FilterClause`. Use the `client.search_field_values(data_reference_name, field).values` endpoint to get the available `values` to use in the `FilterClause`.
        options: Optional query options. A default `Options()` is sent when None.
        limit: Optional limit on the number of records to return. None or 0 means "no limit".

    Returns:
        PaginatedListRecordsResponse containing all records and field definitions,
        inheriting all ListRecordsResponse functionality but without exposing pagination details

    Raises:
        ValueError: If `limit` is negative, if `limit` exceeds the absolute
            maximum number of records, or if no limit is given and the
            service reports a total size at or above that maximum.

    Example:
        >>> from unacatlib import QueryClient, FilterClause
        >>> client = QueryClient()  # API key can be provided as an argument or set via UNACAST_API_KEY env variable
        >>> response = client.list_records(
        ...     "ml_visitation.foot_traffic_month",
        ...     fields=["location_id", "brands", "street_address", "observation_start_date", "observation_end_date", "visits_sum"],  # Optional: specify fields to return
        ...     filters=[
        ...         FilterClause(
        ...             field_name="brands",
        ...             operator="==",
        ...             value="Target"
        ...         )
        ...     ],
        ...     limit=500
        ... )
        >>> df = response.records.to_df()
        >>> print(df.head())
    """
    # A negative limit would be sent as a negative page_size and would then
    # slice records off the *end* of the result, so reject it up front.
    if limit is not None and limit < 0:
        raise ValueError(f"The limit must not be negative, got {limit:,}.")
    if limit and limit > ABSOLUTE_MAX_NUMBER_OF_RECORDS:
        raise ValueError(
            f"The limit is {limit:,}, which is more than the max number of records: {ABSOLUTE_MAX_NUMBER_OF_RECORDS:,}. Please use a smaller limit. {OVER_ABSOLUTE_MAX_NUMBER_OF_RECORDS_MESSAGE}"
        )

    # Everything below is identical for every page, so build it once.
    if options is None:
        options = Options()
    field_objects = [Field(name=field) for field in fields] if fields else None
    page_size = (
        min(limit, LIST_RECORDS_PAGE_SIZE) if limit else LIST_RECORDS_PAGE_SIZE
    )

    all_records = []
    page_token = None
    field_definitions = None
    total_size = 0
    while True:
        request = ListRecordsRequest(
            data_reference_name=data_reference_name,
            fields=field_objects,
            filters=filters,
            options=options,
            billing_context=self.billing_context,
            page_token=page_token,
            page_size=page_size,
        )
        response: ListRecordsResponse = sync(
            self.query_service.list_records(request)
        )

        # Store field definitions and total size from first response
        if field_definitions is None:
            field_definitions = response.field_definitions
            total_size = response.total_size

        # Without a limit the whole result set would be downloaded, so
        # refuse before accumulating anything.
        # NOTE(review): this also rejects a total of exactly the maximum,
        # while the message says "more than" and limit == maximum is
        # accepted above - confirm which boundary is intended.
        if not limit and response.total_size >= ABSOLUTE_MAX_NUMBER_OF_RECORDS:
            raise ValueError(
                f"The response contains {response.total_size:,} records, which is more than the max number of records: {ABSOLUTE_MAX_NUMBER_OF_RECORDS:,}. Please use limit, more specific filters or options to reduce the number of records. {OVER_ABSOLUTE_MAX_NUMBER_OF_RECORDS_MESSAGE}"
            )

        # Accumulate records from this page
        if response.records:
            all_records.extend(response.records)

        # The last page may overshoot the limit; trim the surplus.
        if limit and len(all_records) >= limit:
            all_records = all_records[:limit]
            break

        # Check if there are more pages
        if not response.next_page_token:
            break
        page_token = response.next_page_token

    # Create a new response without pagination details
    return PaginatedListRecordsResponse(
        records=all_records,
        field_definitions=field_definitions,
        total_size=total_size,
    )


# Create an alias for list_records
query = list_records
"""Alias for list_records method. Both methods have identical functionality."""
[docs]
def search_field_values(
    self,
    data_reference_name: str,
    field: str,
    term: Optional[str] = None,
    filters: Optional[List[FilterClause]] = None,
    options: Optional[Options] = None,
    limit: Optional[int] = None,
) -> PaginatedSearchFieldValuesResponse:
    """
    Search for distinct values of a field, following pagination automatically.

    Args:
        data_reference_name: The name of the data reference to query. Use `client.list_data_references().data_references.to_df()` to get the available `data_reference_name`.
        field: The field to search values for. Use `client.get_data_reference(data_reference_name).data_reference.fields.to_df()` to get the available `field`.
        term: Optional search term to filter values
        filters: Optional list of filters to apply. Use `client.get_data_reference(data_reference_name).data_reference.fields.to_df()` to get the available `field` and `operator` to use in the `FilterClause`. Use the `client.search_field_values(data_reference_name, field).values` endpoint to get the available `values` to use in the `FilterClause`.
        options: Optional query options
        limit: Optional limit on the number of values to return. None or 0 means "no limit".

    Returns:
        PaginatedSearchFieldValuesResponse holding the accumulated values,
        the field definition taken from the responses, and `total_size`
        set to the number of values returned.

    Raises:
        ValueError: If `limit` is negative, if `limit` exceeds the absolute
            maximum number of values, or if the accumulated values exceed
            that maximum.

    Example:
        >>> from unacatlib import QueryClient
        >>> client = QueryClient()  # API key can be provided as an argument or set via UNACAST_API_KEY env variable
        >>> response = client.search_field_values("ml_visitation.foot_traffic_month", "brands", "Walm")
        >>> print(response)
    """
    # A negative limit would be sent as a negative page_size and would then
    # slice values off the *end* of the result, so reject it up front.
    if limit is not None and limit < 0:
        raise ValueError(f"The limit must not be negative, got {limit:,}.")
    if limit and limit > ABSOLUTE_MAX_NUMBER_OF_RECORDS:
        raise ValueError(
            f"The limit is {limit:,}, which is more than the max number of values: {ABSOLUTE_MAX_NUMBER_OF_RECORDS:,}. Please use a smaller limit. {OVER_ABSOLUTE_MAX_NUMBER_OF_RECORDS_MESSAGE}"
        )

    page_size = (
        min(limit, SEARCH_FIELD_VALUES_PAGE_SIZE)
        if limit
        else SEARCH_FIELD_VALUES_PAGE_SIZE
    )
    field_object = Field(name=field)

    all_values = []
    page_token = None
    field_definition = None
    while True:
        request = SearchFieldValuesRequest(
            data_reference_name=data_reference_name,
            field=field_object,
            term=term,
            filters=filters,
            options=options,
            billing_context=self.billing_context,
            page_token=page_token,
            page_size=page_size,
        )
        response: SearchFieldValuesResponse = sync(
            self.query_service.search_field_values(request)
        )

        # Store the field definition from the first response that has one
        if field_definition is None:
            field_definition = response.field_definition

        # Accumulate values from this page
        if response.values:
            all_values.extend(response.values)

        # The last page may overshoot the limit; trim the surplus. This runs
        # before the cap check so a valid limit never trips the cap.
        if limit and len(all_values) >= limit:
            all_values = all_values[:limit]
            break

        # Enforce the cap on what has actually been accumulated, including
        # the page just received, so an oversized final page cannot slip
        # through and no further request is issued once the cap is passed.
        # TODO: change this to be using total_size instead. When changing remember to also check if limit is set before raising this error.
        if len(all_values) > ABSOLUTE_MAX_NUMBER_OF_RECORDS:
            raise ValueError(
                f"The response contains {len(all_values):,} values, which is more than the max number of values: {ABSOLUTE_MAX_NUMBER_OF_RECORDS:,}. Please use limit, more specific filters or options to reduce the number of values. {OVER_ABSOLUTE_MAX_NUMBER_OF_RECORDS_MESSAGE}"
            )

        # Check if there are more pages
        if not response.next_page_token:
            break
        page_token = response.next_page_token

    return PaginatedSearchFieldValuesResponse(
        values=all_values,
        field_definition=field_definition,
        total_size=len(all_values),
    )
[docs]
def file_download(
    self,
    data_reference_name: str,
    fields: Optional[List[Union[str, Field]]] = None,
    filters: Optional[List[FilterClause]] = None,
    options: Optional[Options] = None,
) -> FileDownloadResponse:
    """
    Download a file from a data reference.

    Args:
        data_reference_name: The name of the data reference to query. Use `client.list_data_references().data_references.to_df()` to get the available `data_reference_name`.
        fields: Optional list of fields to include in the download, given either as field names or as `Field` objects; names are wrapped as `Field(name=...)`, the same way `list_records` does. If None, all default fields are returned. Use `client.get_data_reference(data_reference_name).data_reference.fields.to_df()` to get the available `fields`.
        filters: Optional list of filters to apply. Use `client.get_data_reference(data_reference_name).data_reference.fields.to_df()` to get the available `field` and `operator` to use in the `FilterClause`. Use the `client.search_field_values(data_reference_name, field).values` endpoint to get the available `values` to use in the `FilterClause`.
        options: Optional query options

    Returns:
        FileDownloadResponse containing the download information

    Note:
        This method is currently in alpha and may change.
    """
    print("Warning: Endpoint is currently in alpha and may change.")

    # Accept plain names as the docstring promises, while still passing
    # through Field objects supplied by existing callers. None stays None.
    field_objects = (
        None
        if fields is None
        else [
            Field(name=entry) if isinstance(entry, str) else entry
            for entry in fields
        ]
    )

    request = FileDownloadRequest(
        data_reference_name=data_reference_name,
        options=options,
        billing_context=self.billing_context,
        filters=filters,
        fields=field_objects,
    )
    response: FileDownloadResponse = sync(
        self.query_service_alpha.file_download(request)
    )
    return response
[docs]
def list_records_statistics(
    self,
    data_reference_name: str,
    filters: Optional[List[FilterClause]] = None,
    options: Optional[Options] = None,
) -> ListRecordsStatisticsResponse:
    """
    Fetch statistics for the records of a data reference.

    Args:
        data_reference_name: The name of the data reference to query. Use `client.list_data_references().data_references.to_df()` to get the available `data_reference_name`.
        filters: Optional list of filters to apply to the records before computing statistics. Use `client.get_data_reference(data_reference_name).data_reference.fields.to_df()` to get the available `field` and `operator` to use in the `FilterClause`. Use the `client.search_field_values(data_reference_name, field).values` endpoint to get the available `values` to use in the `FilterClause`.
        options: Optional query options

    Returns:
        ListRecordsStatisticsResponse containing statistical information about the records

    Example:
        >>> client = QueryClient()
        >>> stats = client.list_records_statistics("ml_visitation.foot_traffic_month")
        >>> print(stats)

    Note:
        This method is currently in alpha and may change.
    """
    print("Warning: Endpoint is currently in alpha and may change.")
    # The statistics call takes the same request message as list_records,
    # but is sent through the alpha service stub.
    stats_request = ListRecordsRequest(
        billing_context=self.billing_context,
        data_reference_name=data_reference_name,
        options=options,
        filters=filters,
    )
    return sync(self.query_service_alpha.list_records_statistics(stats_request))


# Create an alias for list_records_statistics
query_statistics = list_records_statistics
"""Alias for list_records_statistics method. Both methods have identical functionality."""