Source code for pyvrs.filter

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from bisect import bisect
from dataclasses import dataclass
from typing import Any, List, Mapping, Optional, Set, Union

from . import ImageConversion, RecordType

from .base import BaseVRSReader
from .record import VRSRecord
from .slice import VRSReaderSlice
from .utils import (
    get_recordable_type_id_name,
    string_of_set,
    tags_to_justified_table_str,
)

# Names exported by `from pyvrs.filter import *`. AsyncFilteredVRSReader is
# defined in this module but deliberately left out; its constructor error
# message says it should only be built via AsyncVRSReader.filtered_by_fields.
__all__ = ["FilteredVRSReader", "SyncFilteredVRSReader", "RecordFilter"]


@dataclass()
class RecordFilter:
    """Bundle of selection criteria that a reader applies to a VRS file.

    The reader turns an instance into the list of matching record indices via
    its ``_generate_filtered_indices`` method. Whether the timestamp bounds are
    inclusive or exclusive is decided there, not here.
    """

    # Record type names to keep.
    record_types: Set[str]
    # Stream ids (``<recordable_type_id>-<instance_id>``) to keep.
    stream_ids: Set[str]
    # Lower timestamp bound requested by the user.
    min_timestamp: float
    # Upper timestamp bound requested by the user.
    max_timestamp: float
class FilteredVRSReader(BaseVRSReader, ABC):
    """FilteredVRSReader represents a subset of a VRSReader after applying a filter.

    This class exposes essentially the same methods as VRSReader but operates
    against the subset of records selected by a ``RecordFilter``. Unless a
    method says otherwise, record indices passed to or returned from this class
    are positions within that filtered subset, not absolute indices in the file.

    Note that you can't `re-filter` an already filtered VRSReader.
    """

    def __init__(self, reader: BaseVRSReader, record_filter: RecordFilter):
        """
        Args:
            reader: reader object for the whole VRS file (i.e. without any filters)
            record_filter: filter that's applied to the VRS file
        """
        self._reader = reader
        self._record_filter = record_filter
        # Absolute (whole-file) indices of every record that passes the filter,
        # computed once by the underlying reader.
        self._filtered_indices = self._generate_filtered_indices()
        # Cache the timestamps of the first and last selected records; both are
        # 0 when the filter selects nothing.
        # NOTE(review): len() is kept instead of truthiness so this also works
        # if the indices come back as an array-like — confirm the actual type.
        if len(self._filtered_indices) == 0:
            self._min_timestamp = 0
            self._max_timestamp = 0
        else:
            self._min_timestamp = self._reader.get_timestamp_for_index(
                self._filtered_indices[0]
            )
            self._max_timestamp = self._reader.get_timestamp_for_index(
                self._filtered_indices[-1]
            )

    def __getitem__(self, i: Union[int, slice]) -> Union[VRSRecord, VRSReaderSlice]:
        raise NotImplementedError()

    def __len__(self) -> int:
        return self.n_records

    def __repr__(self) -> str:
        raise NotImplementedError()

    def __str__(self) -> str:
        raise NotImplementedError()

    @property
    def file_tags(self) -> Mapping[str, str]:
        """
        Return a dict of all file tags present in this VRS file.
        This is not affected by the filter.

        Returns:
            Dictionary of all file tags: {<tag>: <value>}
        """
        return self._reader.file_tags

    @property
    def stream_tags(self) -> Mapping[str, Mapping[str, Any]]:
        """
        Return the per-stream tags present in this VRS file.
        This is not affected by the filter.

        Returns:
            Nested mapping of stream tags as provided by the underlying reader
            (outer key presumably the stream id — confirm against the reader).
        """
        return self._reader.stream_tags

    @property
    def n_records(self) -> int:
        """The number of records that this reader is configured to read, given current filters."""
        return len(self._filtered_indices)

    @property
    def record_types(self) -> Set[str]:
        """The set of record types that this reader is configured to read, given current filters."""
        return self._record_filter.record_types

    @property
    def stream_ids(self) -> Set[str]:
        """The set of stream ids that this reader is configured to read, given current filters."""
        return self._record_filter.stream_ids

    @property
    def min_timestamp(self) -> float:
        """The timestamp of the first record that this reader is configured to read,
        given current filters. Returns 0 if there are no records.
        """
        return self._min_timestamp

    @property
    def max_timestamp(self) -> float:
        """The timestamp of the last record that this reader is configured to read,
        given current filters. Returns 0 if there are no records.
        """
        return self._max_timestamp

    @property
    def min_filter_timestamp(self) -> float:
        """The value of min_timestamp when user called filter method on VRSReader."""
        return self._record_filter.min_timestamp

    @property
    def max_filter_timestamp(self) -> float:
        """The value of max_timestamp when user called filter method on VRSReader."""
        return self._record_filter.max_timestamp

    def find_stream(
        self, recordable_type_id: int, tag_name: str, tag_value: str
    ) -> str:
        """
        Find stream matching recordable type and tag, and return its stream id.
        This call isn't affected by the filter.

        Args:
            recordable_type_id: stream_id is `<recordable_type_id>-<instance_id>`
            tag_name: tag name that you are interested in
            tag_value: tag value that you are interested in

        Returns:
            Stream ID that starts with recordable_type_id and has a given tag pair.
        """
        return self._reader.find_stream(recordable_type_id, tag_name, tag_value)

    def find_streams(self, recordable_type_id: int, flavor: str = "") -> List[str]:
        """
        Find streams matching recordable type and flavor, and return their stream ids.
        This call isn't affected by the filter.

        Args:
            recordable_type_id: stream_id is `<recordable_type_id>-<instance_id>`
            flavor: flavor that you are interested in (defaults to "")

        Returns:
            A list of stream IDs that start with recordable_type_id and have the given flavor.
        """
        return self._reader.find_streams(recordable_type_id, flavor)

    def get_stream_info(self, stream_id: str) -> Mapping[str, str]:
        """
        Get details about a stream. This call isn't affected by the filter.

        Args:
            stream_id: stream_id you are interested in.

        Returns:
            Information about the stream in a dictionary.
        """
        return self._reader.get_stream_info(stream_id)

    def get_records_count(self, stream_id: str, record_type: RecordType) -> int:
        """
        Get the number of records for the stream_id & record_type.
        This call is delegated to the underlying reader and isn't affected by the filter.

        Args:
            stream_id: stream_id you are interested in.
            record_type: record type you are interested in.

        Returns:
            The number of records for stream_id & record type
        """
        return self._reader.get_records_count(stream_id, record_type)

    def get_timestamp_list(self) -> List[float]:
        """
        Get the timestamps of every record selected by the filter.

        Returns:
            A list of timestamps, one per filtered record, in filtered-index order.
        """
        return self._reader.get_timestamp_list(self._filtered_indices)

    def get_timestamp_for_index(self, index: int) -> float:
        """
        Get the timestamp corresponding to the given index.

        Args:
            index: position of the record within the filtered records.

        Returns:
            The timestamp of that record.

        Raises:
            IndexError: if index is out of range for the filtered records.
        """
        if index >= len(self._filtered_indices):
            raise IndexError("Index {} is out of range.".format(index))
        return self._reader.get_timestamp_for_index(self._filtered_indices[index])

    def set_image_conversion(self, conversion: ImageConversion) -> None:
        """
        Set default image conversion policy, and clears any stream specific setting.
        Not supported on a filtered reader; always raises.

        Args:
            conversion: The image conversion you want to apply for all streams.

        Raises:
            NotImplementedError: always; configure this on SyncVRSReader before filtering.
        """
        raise NotImplementedError(
            "This should be called in SyncVRSReader before applying filter."
        )

    def set_stream_image_conversion(
        self, stream_id: str, conversion: ImageConversion
    ) -> None:
        """
        Set image conversion policy for a specific stream.
        Not supported on a filtered reader; always raises.

        Args:
            stream_id: The stream_id you want to apply image conversion to.
            conversion: The image conversion you want to apply for a specific stream.

        Raises:
            NotImplementedError: always; configure this on SyncVRSReader before filtering.
        """
        raise NotImplementedError(
            "This should be called in SyncVRSReader before applying filter."
        )

    def set_stream_type_image_conversion(
        self, recordable_type_id: str, conversion: ImageConversion
    ) -> int:
        """
        Set image conversion policy for streams of a specific device type.
        Not supported on a filtered reader; always raises.

        Args:
            recordable_type_id: The recordable_type_id you want to apply image conversion to.
                If you specify 1000, streams with id 1000-* are the targets.
            conversion: The image conversion you want to apply for a specific stream.

        Returns:
            The number of streams affected (never reached here).

        Raises:
            NotImplementedError: always; configure this on SyncVRSReader before filtering.
        """
        raise NotImplementedError(
            "This should be called in SyncVRSReader before applying filter."
        )

    def might_contain_images(self, stream_id: str) -> bool:
        """
        Check if the given stream_id contains image data.

        Args:
            stream_id: stream_id that you are interested in.

        Returns:
            Based on the config record, whether the stream contains image data.
        """
        return self._reader.might_contain_images(stream_id)

    def might_contain_audio(self, stream_id: str) -> bool:
        """
        Check if the given stream_id contains audio data.

        Args:
            stream_id: stream_id that you are interested in.

        Returns:
            Based on the config record, whether the stream contains audio data.
        """
        return self._reader.might_contain_audio(stream_id)

    def get_estimated_frame_rate(self, stream_id: str) -> float:
        """
        Get the estimated frame rate for the given stream_id.

        Args:
            stream_id: stream_id that you are interested in.

        Returns:
            The estimated frame rate.
        """
        return self._reader.get_estimated_frame_rate(stream_id)

    def get_record_index_by_time(
        self,
        stream_id: str,
        timestamp: float,
        epsilon: Optional[float] = None,
        record_type: Optional[RecordType] = None,
    ) -> int:
        """
        Get index in filtered records by timestamp.

        Args:
            stream_id: stream_id that you are interested in. Must be one of the
                stream ids enabled by the filter.
            timestamp: timestamp that you are interested in.
            epsilon: Optional argument. If specified we search for record in range of
                (timestamp-epsilon)-(timestamp+epsilon) and returns the nearest record.
            record_type: Optional argument. If specified we search for record with the
                record_type.

        Returns:
            Position, within the filtered records, of the last filtered record whose
            absolute index is <= the absolute index found by the underlying reader,
            clamped to 0. If the found record itself was filtered out, this is an
            earlier filtered record (or 0).

        Raises:
            AssertionError: if stream_id isn't enabled by the filter (only when
                assertions are active).
        """
        # NOTE(review): `assert` is stripped under `python -O`; an explicit
        # exception would make this check unconditional.
        assert stream_id in self.stream_ids
        index = self._reader.get_record_index_by_time(
            stream_id, timestamp, epsilon, record_type
        )
        # bisect() is bisect_right: it returns the insertion point after any equal
        # entry, so subtracting 1 lands on the last filtered index <= `index`.
        # TODO: Fix this logic
        return max(bisect(self._filtered_indices, index) - 1, 0)

    def read_record_by_time(
        self,
        stream_id: str,
        timestamp: float,
        epsilon: Optional[float] = None,
        record_type: Optional[RecordType] = None,
    ) -> VRSRecord:
        """
        Read record by timestamp.

        NOTE: this delegates straight to the underlying reader; the result is
        not restricted to the filtered records.

        Args:
            stream_id: stream_id that you are interested in.
            timestamp: timestamp that you are interested in.
            epsilon: Optional argument. If specified we search for record in range of
                (timestamp-epsilon)-(timestamp+epsilon) and returns the nearest record.
            record_type: Optional argument. If specified we search for record with the
                record_type.

        Returns:
            VRSRecord corresponds to the stream_id & timestamp.

        Raises:
            TimestampNotFoundError: If epsilon is not None and the record doesn't exist
                within the time range.
            ValueError: If epsilon is None and the record isn't found using lower_bound.
        """
        return self._reader.read_record_by_time(
            stream_id, timestamp, epsilon, record_type
        )

    def read_prev_record(
        self, stream_id: str, record_type: str, index: int
    ) -> Optional[VRSRecord]:
        """
        Read the last record that matches stream_id and record_type and whose index is
        smaller than or equal to the given index.

        Args:
            stream_id: stream_id that you are interested in.
            record_type: record_type that you are interested in.
            index: position within the filtered records. It is translated to the
                absolute index in the file before searching backwards from there.

        Returns:
            VRSRecord if there is a record, otherwise None

        Raises:
            IndexError: if index is out of range for the filtered records.
        """
        index_in_all = self._filtered_indices[index]
        return self._reader.read_prev_record(stream_id, record_type, index_in_all)

    def read_next_record(
        self, stream_id: str, record_type: str, index: int
    ) -> Optional[VRSRecord]:
        """
        Read the first record that matches stream_id and record_type and whose index is
        greater than or equal to the given index.

        Args:
            stream_id: stream_id that you are interested in.
            record_type: record_type that you are interested in.
            index: position within the filtered records. It is translated to the
                absolute index in the file before searching forwards from there.

        Returns:
            VRSRecord if there is a record, otherwise None (also None when index
            is out of range for the filtered records).
        """
        try:
            index_in_all = self._filtered_indices[index]
        except IndexError as e:
            # Out-of-range positions yield None instead of raising; the error
            # message is echoed to stdout.
            print(e)
            return None
        return self._reader.read_next_record(stream_id, record_type, index_in_all)

    @abstractmethod
    def _read_record(
        self, indices: List[int], i: Union[int, slice]
    ) -> Union[VRSRecord, VRSReaderSlice]:
        """Read record(s) at position(s) `i` of `indices`; implemented by subclasses."""
        raise NotImplementedError()

    def _record_count_by_type_from_stream_id(self, stream_id: str) -> Mapping[str, int]:
        # Delegated to the underlying (unfiltered) reader.
        return self._reader._record_count_by_type_from_stream_id(stream_id)

    def _generate_filtered_indices(self) -> List[int]:
        # The underlying reader applies self._record_filter and returns the
        # absolute indices of the matching records.
        return self._reader._generate_filtered_indices(self._record_filter)
class SyncFilteredVRSReader(FilteredVRSReader):
    """
    Synchronous version of FilteredVRSReader.

    Indexing reads records through the underlying reader's ``_read_record``,
    restricted to the filtered record indices.
    """

    def __getitem__(self, i: Union[int, slice]) -> Union[VRSRecord, VRSReaderSlice]:
        return self._read_record(self._filtered_indices, i)

    def __repr__(self) -> str:
        # Every argument sits inside the parentheses, separated by ", ".
        return (
            f"SyncFilteredVRSReader({self._reader._path!r}, "
            f"auto_read_configuration_records={self._reader._auto_read_configuration_records!r}, "
            f"filter={self._record_filter!r})"
        )

    def __str__(self) -> str:
        """Human-readable summary of the file, the active filter, and enabled streams."""
        lines = [
            self._reader._path,
            tags_to_justified_table_str(self.file_tags),
            f"{len(self)}/{len(self._reader)} records are enabled (based on filters)",
            "Automatic configuration record reading is {}".format(
                "enabled"
                if self._reader._auto_read_configuration_records
                else "disabled"
            ),
            "Available Stream IDs: {}".format(string_of_set(self._reader.stream_ids)),
            " Enabled Stream IDs: {}".format(string_of_set(self.stream_ids)),
            "Available Record Types: {}".format(
                string_of_set(self._reader.record_types)
            ),
            " Enabled Record Types: {}".format(string_of_set(self.record_types)),
        ]
        # Time ranges are only meaningful when at least one record is selected.
        if len(self) > 0:
            lines.append(
                "{:.2f}s of available records: {:.2f}s - {:.2f}s".format(
                    self._reader.time_range,
                    self._reader.min_timestamp,
                    self._reader.max_timestamp,
                )
            )
            lines.append(
                "{:.2f}s of enabled records: {:.2f}s - {:.2f}s".format(
                    self.time_range,
                    self.min_timestamp,
                    self.max_timestamp,
                )
            )
        s = "\n".join(lines)
        # NOTE(review): the heading says "Available" but only the streams
        # enabled by the filter are listed.
        s += "\n\nAvailable Streams in VRS: \n"
        s += "".join(
            " Stream ID: {} ({} no. {})\n No. of records {}\n".format(
                stream_id,
                get_recordable_type_id_name(stream_id),
                stream_id.split("-")[1],
                self._record_count_by_type_from_stream_id(stream_id),
            )
            for stream_id in self.stream_ids
        )
        return s

    def _read_record(
        self, indices: List[int], i: Union[int, slice]
    ) -> Union[VRSRecord, VRSReaderSlice]:
        # Synchronous read delegated to the underlying reader.
        return self._reader._read_record(indices, i)
class AsyncFilteredVRSReader(FilteredVRSReader):
    """
    Asynchronous version of FilteredVRSReader.

    Indexing returns a coroutine that reads through the underlying reader's
    ``_async_read_record``, restricted to the filtered record indices.
    """

    def __init__(self, reader: BaseVRSReader, record_filter: RecordFilter):
        """
        Args:
            reader: reader object for the whole VRS file; it must provide a callable
                ``_async_read_record``.
            record_filter: filter that's applied to the VRS file

        Raises:
            NotImplementedError: if `reader` has no callable ``_async_read_record``.
        """
        super().__init__(reader, record_filter)
        _async_read_record_op = getattr(self._reader, "_async_read_record", None)
        if not callable(_async_read_record_op):
            raise NotImplementedError(
                "AsyncFilteredVRSReader._reader doesn't have method _async_read_record."
                " You should only construct AsyncFilteredVRSReader via AsyncVRSReader.filtered_by_fields method."
            )

    def __aiter__(self):
        # NOTE: the iteration cursor lives on the reader itself, so concurrent
        # `async for` loops over the same instance share (and reset) position.
        self._index = 0
        return self

    async def __anext__(self):
        if self._index >= len(self):
            raise StopAsyncIteration
        result = await self[self._index]
        self._index += 1
        return result

    async def __getitem__(
        self, i: Union[int, slice]
    ) -> Union[VRSRecord, VRSReaderSlice]:
        return await self._reader._async_read_record(self._filtered_indices, i)

    def __repr__(self) -> str:
        # Every argument sits inside the parentheses, separated by ", ".
        return (
            f"AsyncFilteredVRSReader({self._reader._path!r}, "
            f"auto_read_configuration_records={self._reader._auto_read_configuration_records!r}, "
            f"filter={self._record_filter!r})"
        )

    def __str__(self) -> str:
        """Human-readable summary of the file, the active filter, and enabled streams."""
        lines = [
            self._reader._path,
            tags_to_justified_table_str(self.file_tags),
            f"{len(self)}/{len(self._reader)} records are enabled (based on filters)",
            # The setting lives on the wrapped reader; FilteredVRSReader.__init__
            # never sets _auto_read_configuration_records on this object.
            "Automatic configuration record reading is {}".format(
                "enabled"
                if self._reader._auto_read_configuration_records
                else "disabled"
            ),
            "Available Stream IDs: {}".format(string_of_set(self._reader.stream_ids)),
            " Enabled Stream IDs: {}".format(string_of_set(self.stream_ids)),
            "Available Record Types: {}".format(
                string_of_set(self._reader.record_types)
            ),
            " Enabled Record Types: {}".format(string_of_set(self.record_types)),
        ]
        # Time ranges are only meaningful when at least one record is selected.
        if len(self) > 0:
            lines.append(
                "{:.2f}s of available records: {:.2f}s - {:.2f}s".format(
                    self._reader.time_range,
                    self._reader.min_timestamp,
                    self._reader.max_timestamp,
                )
            )
            lines.append(
                "{:.2f}s of enabled records: {:.2f}s - {:.2f}s".format(
                    self.time_range,
                    self.min_timestamp,
                    self.max_timestamp,
                )
            )
        s = "\n".join(lines)
        # NOTE(review): the heading says "Available" but only the streams
        # enabled by the filter are listed.
        s += "\n\nAvailable Streams in VRS: \n"
        s += "".join(
            " Stream ID: {} ({} no. {})\n No. of records {}\n".format(
                stream_id,
                get_recordable_type_id_name(stream_id),
                stream_id.split("-")[1],
                self._record_count_by_type_from_stream_id(stream_id),
            )
            for stream_id in self.stream_ids
        )
        return s

    def _read_record(self, indices: List[int], i: Union[int, slice]):
        # Synchronous reads are not supported; use `await reader[i]` instead.
        raise NotImplementedError()