# Source code for cloudvision.Connector.grpc_client.grpcConnectionPool

# Copyright (c) 2025 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the COPYING file.

import logging
import threading
from typing import Any, Dict, List, Optional
from cloudvision.Connector.gen import notification_pb2 as ntf
from cloudvision.Connector.gen import router_pb2 as rtr
from cloudvision.Connector.grpc_client.grpcClient import GRPCClient, TIME_TYPE, UPDATE_TYPE, \
    DATASET_TYPE_DEVICE


class _StreamAwareGRPCClient(GRPCClient):
    """GRPCClient extended with a per-connection count of reserved streams.

    Callers reserve a slot with try_reserve_stream() before opening a
    long-lived stream on this client, and give it back with release_stream()
    once the stream ends. The counter is protected by a per-client lock.

    grpcAddr: apiserver address, forwarded to GRPCClient.
    max_streams: maximum number of simultaneously reserved streams.
    kwargs: remaining GRPCClient constructor arguments, forwarded unchanged.
    """

    def __init__(self, grpcAddr, max_streams, **kwargs):
        super().__init__(grpcAddr=grpcAddr, **kwargs)
        self.active_streams = 0
        self.max_streams = max_streams
        self._lock = threading.Lock()
        # Identifier used only to correlate log lines belonging to this connection.
        self._id = id(self)
        logging.debug("%s | create new connection", self._id)

    def try_reserve_stream(self):
        """Try to reserve one stream slot on this client.

        Returns True (and increments active_streams) when the client is below
        max_streams; otherwise returns False and leaves the count unchanged.
        """
        with self._lock:
            if self.active_streams >= self.max_streams:
                return False
            self.active_streams += 1
            logging.debug(
                "%s | add stream(%s/%s)", self._id, self.active_streams, self.max_streams
            )
            return True

    def release_stream(self):
        """Return a slot previously obtained from try_reserve_stream().

        An unmatched release is logged and ignored rather than driving the
        counter negative, which would let this client exceed max_streams.
        """
        with self._lock:
            if self.active_streams <= 0:
                logging.warning(
                    "%s | release_stream called with no active streams", self._id
                )
                return
            self.active_streams -= 1
            logging.debug(
                "%s | release stream(%s/%s)", self._id, self.active_streams, self.max_streams
            )

    def print_client_info(self):
        """Log the current stream usage of this client at INFO level."""
        with self._lock:
            logging.info("%s | streams(%s/%s)", self._id, self.active_streams, self.max_streams)


class _ReleasingStream:
    """Iterator over a pooled stream that releases its reserved slot exactly once.

    The slot is released when the underlying stream is exhausted, when it
    raises, when close() is called, or when this object is garbage-collected.
    A plain generator with try/finally misses the last case if iteration was
    never started, which would leak the reservation permanently.
    """

    def __init__(self, stream, release):
        self._stream = stream
        self._iterator = None
        # Callable that frees the slot; set to None once it has been called.
        self._release = release

    def __iter__(self):
        return self

    def __next__(self):
        if self._release is None:
            # Already exhausted, failed, or closed: behave like a finished generator.
            raise StopIteration
        try:
            if self._iterator is None:
                self._iterator = iter(self._stream)
            return next(self._iterator)
        except BaseException:
            # Covers StopIteration (normal end) as well as errors from the stream.
            self._finish()
            raise

    def close(self):
        """Stop consuming the stream and release its slot.

        Subsequent next() calls raise StopIteration.
        """
        self._finish()

    def _finish(self):
        # Swap out the callback first so the slot can only be released once.
        release, self._release = self._release, None
        self._stream = None
        self._iterator = None
        if release is not None:
            release()

    def __del__(self):
        # getattr guards against a partially constructed instance.
        if getattr(self, "_release", None) is not None:
            self._finish()


class PooledGRPCClient:
    """
    PooledGRPCClient manages a connection pool of GRPCClient instances.

    Long-lived subscription streams are balanced across the pooled connections
    using round-robin distribution, with at most max_streams_per_connection
    streams per connection. `get` and `publish` calls are routed to a pooled
    connection without consuming a stream slot. When every pooled connection
    has reached its stream limit, a new GRPCClient connection is created and
    added to the pool, up to max_connections.

    If a single GRPCClient instance would otherwise carry 100+ subscriptions,
    it is recommended to use PooledGRPCClient instead.

    grpcAddr: must be a valid apiserver address in the format <ADDRESS>:<PORT>.
    max_streams_per_connection: maximum number of concurrent subscription
        streams per pooled connection. Must be at least 1.
    max_connections: maximum number of pooled connections. Must be at least 1.
        The default is a very large value, i.e. effectively unlimited.
    certs: if present, must be the path to the cert file.
    key: if present, must be the path to a .pem key file.
    ca: if present, must be the path to a root certificate authority file.
    token: if present, must be the path to a .tok user access token.
    tokenValue: if present, is the actual token in string form. Cannot be set with token
    certsValue: if present, is the actual certs in string form. Cannot be set with certs
    keyValue: if present, is the actual key in string form. Cannot be set with key
    caValue: if present, is the actual ca in string form. Cannot be set with ca
    channel_options: if present, passed as channel_options to every GRPCClient
        the pool creates.

    Raises ValueError if max_streams_per_connection or max_connections is below 1.
    """

    def __init__(
        self,
        grpcAddr,
        max_streams_per_connection=100,
        max_connections=21474837,
        token: Optional[str] = None,
        certs: Optional[str] = None,
        key: Optional[str] = None,
        ca: Optional[str] = None,
        tokenValue: Optional[str] = None,
        certsValue: Optional[str] = None,
        keyValue: Optional[str] = None,
        caValue: Optional[str] = None,
        channel_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        # A limit below 1 would let the pool hand out a connection whose stream
        # reservation failed (or bypass the connection limit in _get_any_client).
        if max_streams_per_connection < 1:
            raise ValueError("max_streams_per_connection must be at least 1")
        if max_connections < 1:
            raise ValueError("max_connections must be at least 1")
        self.grpcAddr = grpcAddr
        self._max_streams = max_streams_per_connection
        self._max_connections = max_connections
        self.token = token
        self.certs = certs
        self.key = key
        self.ca = ca
        self.tokenValue = tokenValue
        self.certsValue = certsValue
        self.keyValue = keyValue
        self.caValue = caValue
        # Fresh dict per pool: a shared mutable default would be shared by
        # every pool instance and every client it creates.
        self.channel_options = channel_options if channel_options is not None else {}
        self._pool: List[_StreamAwareGRPCClient] = []
        self._lock = threading.Lock()
        self._rr_index = 0  # round-robin cursor for subscribe
        self._unary_rr_index = 0  # round-robin cursor for get/publish

    def _print_connection_state(self):
        """Log the stream usage of every pooled connection."""
        with self._lock:
            for client in self._pool:
                client.print_client_info()

    def _create_new_client(self):
        """Create a new pooled connection and append it to the pool.

        Callers in this class invoke it while holding self._lock.
        """
        client = _StreamAwareGRPCClient(
            self.grpcAddr,
            self._max_streams,
            token=self.token,
            certs=self.certs,
            key=self.key,
            ca=self.ca,
            tokenValue=self.tokenValue,
            certsValue=self.certsValue,
            keyValue=self.keyValue,
            caValue=self.caValue,
            channel_options=self.channel_options,
        )
        self._pool.append(client)
        return client

    def _get_or_create_client(self):
        """
        _get_or_create_client returns a connection from the connection pool on
        a round-robin basis, with one stream slot already reserved on it.

        If all connections have reached the max_streams_per_connection limit,
        it creates a new connection and adds it to the connection pool.

        Raises a RuntimeError if a new connection can't be created because the
        max_connections limit is reached.
        """
        with self._lock:
            pool_size = len(self._pool)
            if pool_size:
                start = self._rr_index % pool_size
                for offset in range(pool_size):
                    idx = (start + offset) % pool_size
                    client = self._pool[idx]
                    if client.try_reserve_stream():
                        self._rr_index = (idx + 1) % pool_size
                        return client
            if pool_size >= self._max_connections:
                raise RuntimeError("Maximum number of gRPC connections reached")
            new_client = self._create_new_client()
            # Always succeeds: the client is new and max_streams >= 1.
            new_client.try_reserve_stream()
            # Continue round-robin after the newly added connection.
            self._rr_index = len(self._pool)
            return new_client

    def _get_any_client(self):
        """
        Select a connection for get/publish calls without reserving a stream.

        Connections are tried round-robin; one that still has stream capacity
        is returned. If all connections have reached the
        max_streams_per_connection limit, a new connection is created and added
        to the connection pool.

        Raises a RuntimeError if a new connection can't be created because the
        max_connections limit is reached.
        """
        with self._lock:
            pool_size = len(self._pool)
            if pool_size == 0:
                return self._create_new_client()
            start = self._unary_rr_index % pool_size
            for offset in range(pool_size):
                idx = (start + offset) % pool_size
                client = self._pool[idx]
                # Read without the client's own lock; only used as a routing hint.
                if client.active_streams < client.max_streams:
                    self._unary_rr_index = (idx + 1) % pool_size
                    return client
            if pool_size >= self._max_connections:
                raise RuntimeError("Maximum number of gRPC connections reached")
            return self._create_new_client()

    def subscribe(self, queries, sharding=None):
        """
        Subscribe creates and executes a Subscribe protobuf message,
        returning a stream of notificationBatch.

        queries must be a list of query protobuf messages.
        sharding, if present, must be a protobuf sharding message.

        The returned iterator holds one stream slot on a pooled connection. The
        slot is released when the iterator is exhausted, raises, is closed with
        close(), or is garbage-collected.
        """
        client = self._get_or_create_client()
        try:
            stream = client.subscribe(queries, sharding)
        except BaseException:
            # No stream was handed out, so give the reserved slot back.
            client.release_stream()
            raise
        return _ReleasingStream(stream, client.release_stream)

    def getAndSubscribe(
        self,
        queries: List[rtr.Query],
        start: Optional[TIME_TYPE] = None,
        versions=0,
        sharding=None,
        exact_range=False,
        timeout: Optional[float] = None,
    ):
        """
        GetAndSubscribe creates and executes a GetAndSubscribe protobuf message,
        returning a stream of notificationBatch.

        This will initially consist of notifications of the current state in
        CloudVision (i.e. a Get), after which it will transition to a
        subscription, where update notifications will be received for changes
        occurring to the queried paths.

        queries must be a list of query protobuf messages.
        start, if present, must be a google.protobuf.timestamp_pb2.Timestamp or a datetime object.
        versions, if present, specifies the maximum number of versions to retrieve.
        sharding, if present, must be a protobuf sharding message.
        exact_range, if present, specifies whether to return the initial state at time `start`.
        timeout: if present, sets the GRPC timeout in seconds. Default is None (no timeout).

        The returned iterator holds one stream slot on a pooled connection. The
        slot is released when the iterator is exhausted, raises, is closed with
        close(), or is garbage-collected.
        """
        client = self._get_or_create_client()
        try:
            stream = client.getAndSubscribe(
                queries=queries,
                start=start,
                versions=versions,
                sharding=sharding,
                exact_range=exact_range,
                timeout=timeout,
            )
        except BaseException:
            # No stream was handed out, so give the reserved slot back.
            client.release_stream()
            raise
        return _ReleasingStream(stream, client.release_stream)

    def get(
        self,
        queries: List[rtr.Query],
        start: Optional[TIME_TYPE] = None,
        end: Optional[TIME_TYPE] = None,
        versions=0,
        sharding=None,
        exact_range=False,
    ):
        """
        Get creates and executes a Get protobuf message, returning a stream of
        notificationBatch.

        queries must be a list of query protobuf messages.
        start and end, if present, must be nanoseconds timestamps (uint64).
        versions, if present, specifies the maximum number of versions to retrieve.
        sharding, if present, must be a protobuf sharding message.

        Uses any pooled connection and does not reserve a stream slot.
        """
        client = self._get_any_client()
        return client.get(
            queries=queries,
            start=start,
            end=end,
            versions=versions,
            sharding=sharding,
            exact_range=exact_range,
        )

    def publish(
        self,
        dId,
        notifs: List[ntf.Notification],
        dtype: str = DATASET_TYPE_DEVICE,
        sync: bool = True,
        compare: Optional[UPDATE_TYPE] = None,
    ):
        """
        Publish creates and executes a Publish protobuf message.

        Refer to cloudvision/Connector/protobufs/router.proto:124.
        sync defaults to True so that changes are reflected before the call returns.

        Uses any pooled connection and does not reserve a stream slot.
        """
        client = self._get_any_client()
        return client.publish(
            dId=dId, notifs=notifs, dtype=dtype, sync=sync, compare=compare
        )