# Copyright (c) 2025 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the COPYING file.
import logging
import threading
from typing import Any, Dict, List, Optional
from cloudvision.Connector.gen import notification_pb2 as ntf
from cloudvision.Connector.gen import router_pb2 as rtr
from cloudvision.Connector.grpc_client.grpcClient import GRPCClient, TIME_TYPE, UPDATE_TYPE, \
DATASET_TYPE_DEVICE
class _StreamAwareGRPCClient(GRPCClient):
    """
    GRPCClient that additionally keeps a count of the streams opened on it.

    The count is only bookkeeping: callers reserve a slot with
    try_reserve_stream before opening a stream and hand it back with
    release_stream afterwards. Every access to the counter made by this
    class happens under a per-client lock.
    """

    def __init__(self, grpcAddr, max_streams, **kwargs):
        """
        grpcAddr: forwarded to GRPCClient as the `grpcAddr` keyword.
        max_streams: number of slots try_reserve_stream may hand out at once.
        kwargs: forwarded unchanged to GRPCClient.
        """
        super().__init__(grpcAddr=grpcAddr, **kwargs)
        self.active_streams = 0
        self.max_streams = max_streams
        self._lock = threading.Lock()
        # id() of this object, used as a prefix to tell connections apart in logs.
        self._id = id(self)
        logging.debug(f"{self._id} | create new connection")

    def try_reserve_stream(self):
        """
        Claim one stream slot.

        Returns True and increments the counter when a slot is free, or
        False (leaving the counter untouched) when the client is full.
        """
        with self._lock:
            if self.active_streams >= self.max_streams:
                return False
            self.active_streams += 1
            logging.debug(
                f"{self._id} | add stream({self.active_streams}/{self.max_streams})"
            )
            return True

    def release_stream(self):
        """Give back one stream slot by decrementing the counter."""
        with self._lock:
            self.active_streams -= 1
            logging.debug(
                f"{self._id} | release stream({self.active_streams}/{self.max_streams})"
            )

    def print_client_info(self):
        """Log the current slot usage of this client at INFO level."""
        with self._lock:
            logging.info(
                f"{self._id} | streams({self.active_streams}/{self.max_streams})"
            )
[docs]
class PooledGRPCClient:
    """
    PooledGRPCClient manages a connection pool of GRPCClient instances.

    It balances long-lived subscription streams using round-robin distribution,
    and routes unary `get` and `publish` calls without consuming stream slots.
    It also creates a new GRPCClient instance when every client in the pool has
    maxed out its stream usage.

    If a single instance of GRPCClient is being used to hold 100+ subscriptions,
    it's recommended to use PooledGRPCClient instead.

    grpcAddr: must be a valid apiserver address in the format <ADDRESS>:<PORT>.
    max_streams_per_connection: number of concurrent streams reserved on one
        connection before another connection is used or created.
    max_connections: upper bound on the number of connections in the pool.
    certs: if present, must be the path to the cert file.
    key: if present, must be the path to a .pem key file.
    ca: if present, must be the path to a root certificate authority file.
    token: if present, must be the path a .tok user access token.
    tokenValue: if present, is the actual token in string form. Cannot be set with token
    certsValue: if present, is the actual certs in string form. Cannot be set with certs
    keyValue: if present, is the actual key in string form. Cannot be set with key
    caValue: if present, is the actual ca in string form. Cannot be set with ca
    channel_options: if present, is forwarded to every GRPCClient the pool creates.
        Defaults to an empty dict.
    """

    def __init__(
        self,
        grpcAddr,
        max_streams_per_connection=100,
        max_connections=21474837,
        token: Optional[str] = None,
        certs: Optional[str] = None,
        key: Optional[str] = None,
        ca: Optional[str] = None,
        tokenValue: Optional[str] = None,
        certsValue: Optional[str] = None,
        keyValue: Optional[str] = None,
        caValue: Optional[str] = None,
        channel_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.grpcAddr = grpcAddr
        self._max_streams = max_streams_per_connection
        self._max_connections = max_connections
        self.token = token
        self.certs = certs
        self.key = key
        self.ca = ca
        self.tokenValue = tokenValue
        self.certsValue = certsValue
        self.keyValue = keyValue
        self.caValue = caValue
        # A `{}` default in the signature would be one dict object shared by
        # every pool; build a fresh one per instance instead.
        self.channel_options = {} if channel_options is None else channel_options
        self._pool: List[_StreamAwareGRPCClient] = []
        # Guards _pool and both round-robin cursors.
        self._lock = threading.Lock()
        self._rr_index = 0  # round-robin cursor for subscribe
        self._unary_rr_index = 0  # round-robin cursor for get/publish

    def _print_connection_state(self):
        """Log the stream usage of every connection currently in the pool."""
        with self._lock:
            for client in self._pool:
                client.print_client_info()

    def _create_new_client(self):
        """
        Build a client from the stored connection settings, append it to the
        pool and return it.

        Does not take the pool lock and does not check max_connections; the
        callers in this class do both before calling it.
        """
        client = _StreamAwareGRPCClient(
            self.grpcAddr,
            self._max_streams,
            token=self.token,
            certs=self.certs,
            key=self.key,
            ca=self.ca,
            tokenValue=self.tokenValue,
            certsValue=self.certsValue,
            keyValue=self.keyValue,
            caValue=self.caValue,
            channel_options=self.channel_options,
        )
        self._pool.append(client)
        return client

    def _get_or_create_client(self):
        """
        _get_or_create_client returns a connection from the connection pool on a
        round-robin basis, with one stream slot already reserved on it.

        If all connections have reached the max_streams_per_connection limit, it
        creates a new connection and adds it to the connection pool.

        Raises a RuntimeError if a new connection can't be created because the
        max_connections limit is reached.
        """
        with self._lock:
            pool_size = len(self._pool)
            if pool_size > 0:
                start = self._rr_index % pool_size
                for offset in range(pool_size):
                    idx = (start + offset) % pool_size
                    client = self._pool[idx]
                    if client.try_reserve_stream():
                        # Next search starts right after the client just used.
                        self._rr_index = (idx + 1) % pool_size
                        return client
            if len(self._pool) >= self._max_connections:
                raise RuntimeError("Maximum number of gRPC connections reached")
            new_client = self._create_new_client()
            # A fresh client has no active streams, so this succeeds whenever
            # max_streams_per_connection is at least 1.
            new_client.try_reserve_stream()
            # Park the cursor just past the new client; the `% pool_size` of the
            # next lookup wraps it back to the start of the pool (as long as the
            # pool has not grown in between).
            self._rr_index = len(self._pool) % self._max_connections
            return new_client

    def _get_any_client(self):
        """
        Select a connection for unary RPCs (get, publish) without stream reservation.

        Connections that still have a free stream slot are preferred, visited
        round-robin. If all connections have reached the
        max_streams_per_connection limit, it creates a new connection and adds
        it to the connection pool.

        Raises a RuntimeError if a new connection can't be created because the
        max_connections limit is reached.
        """
        with self._lock:
            pool_size = len(self._pool)
            if pool_size == 0:
                return self._create_new_client()
            start = self._unary_rr_index % pool_size
            for offset in range(pool_size):
                idx = (start + offset) % pool_size
                client = self._pool[idx]
                # Read-only peek at the counter; nothing is reserved here.
                if client.active_streams < client.max_streams:
                    self._unary_rr_index = (idx + 1) % pool_size
                    return client
            if len(self._pool) >= self._max_connections:
                raise RuntimeError("Maximum number of gRPC connections reached")
            return self._create_new_client()

    def _open_tracked_stream(self, open_stream):
        """
        Reserve a stream slot, open a stream on that connection and return a
        generator that gives the slot back when it finishes.

        open_stream: callable taking the selected client and returning an
            iterable of results.

        If open_stream raises, the reserved slot is released before the
        exception propagates.

        The slot is otherwise released from the generator's `finally`, i.e. once
        the returned generator is exhausted, fails, is closed or is garbage
        collected after iteration has begun. A generator that is never iterated
        does not run its body, so its slot stays reserved.

        Raises a RuntimeError if no connection is available and the
        max_connections limit is reached.
        """
        client = self._get_or_create_client()
        try:
            stream = open_stream(client)
        except BaseException:
            # The wrapper below was never created, so nothing else would hand
            # the reserved slot back.
            client.release_stream()
            raise

        def wrapped_stream():
            try:
                for item in stream:
                    yield item
            finally:
                client.release_stream()

        return wrapped_stream()

    def subscribe(self, queries, sharding=None):
        """
        Subscribe creates and executes a Subscribe protobuf message,
        returning a stream of notificationBatch.

        queries must be a list of query protobuf messages.
        sharding, if present must be a protobuf sharding message.

        The stream occupies one stream slot of a pooled connection until the
        returned generator finishes or is closed, so it should be iterated.
        """
        return self._open_tracked_stream(
            lambda client: client.subscribe(queries, sharding)
        )

    def getAndSubscribe(
        self,
        queries: List[rtr.Query],
        start: Optional[TIME_TYPE] = None,
        versions=0,
        sharding=None,
        exact_range=False,
        timeout: Optional[float] = None,
    ):
        """
        GetAndSubscribe creates and executes a GetAndSubscribe protobuf message,
        returning a stream of notificationBatch. This will initially consist of notifications
        of the current state in CloudVision (i.e. a Get), after which it will transition to a
        subscription method, where update notifications will be received for changes occurring to
        the queried paths.

        queries must be a list of query protobuf messages.
        start, if present, must be a google.protobuf.timestamp_pb2.Timestamp or a datetime object.
        versions, if present, specifies the maximum number of versions to retrieve.
        sharding, if present, must be a protobuf sharding message.
        exact_range, if present, specifies whether to return the initial state at time `start`.
        timeout: if present, sets the GRPC timeout in seconds. Default is None (no timeout).

        The stream occupies one stream slot of a pooled connection until the
        returned generator finishes or is closed, so it should be iterated.
        """
        return self._open_tracked_stream(
            lambda client: client.getAndSubscribe(
                queries=queries,
                start=start,
                versions=versions,
                sharding=sharding,
                exact_range=exact_range,
                timeout=timeout,
            )
        )

    def get(
        self,
        queries: List[rtr.Query],
        start: Optional[TIME_TYPE] = None,
        end: Optional[TIME_TYPE] = None,
        versions=0,
        sharding=None,
        exact_range=False
    ):
        """
        Get creates and executes a Get protobuf message, returning a stream of
        notificationBatch.

        queries must be a list of query protobuf messages.
        start and end, if present, must be nanoseconds timestamps (uint64).
        versions, if present, specifies the maximum number of versions to retrieve.
        sharding, if present, must be a protobuf sharding message.

        Unary get request — uses any client, does not reserve a stream.
        """
        client = self._get_any_client()
        return client.get(
            queries=queries,
            start=start,
            end=end,
            versions=versions,
            sharding=sharding,
            exact_range=exact_range
        )

    def publish(
        self,
        dId,
        notifs: List[ntf.Notification],
        dtype: str = DATASET_TYPE_DEVICE,
        sync: bool = True,
        compare: Optional[UPDATE_TYPE] = None
    ):
        """
        Publish creates and executes a Publish protobuf message.
        refer to cloudvision/Connector/protobufs/router.proto:124
        default to sync publish being true so that changes are reflected

        Unary publish request — uses any client, does not reserve a stream.
        """
        client = self._get_any_client()
        return client.publish(
            dId=dId,
            notifs=notifs,
            dtype=dtype,
            sync=sync,
            compare=compare
        )