# Module: cloudvision.Connector.grpc_client.grpcConnectionPool

# Copyright (c) 2025 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the COPYING file.

import logging
import threading
from typing import Any, Dict, List, Optional
from cloudvision.Connector.gen import notification_pb2 as ntf
from cloudvision.Connector.gen import router_pb2 as rtr
from cloudvision.Connector.grpc_client.grpcClient import GRPCClient, TIME_TYPE, UPDATE_TYPE, \
    DATASET_TYPE_DEVICE


class _StreamAwareGRPCClient(GRPCClient):
    """GRPCClient subclass that tracks how many subscription streams it carries.

    The connection pool uses the counter to decide whether this connection can
    accept another long-lived stream or whether a new connection is needed.
    """

    def __init__(self, grpcAddr, max_streams, **kwargs):
        super().__init__(grpcAddr=grpcAddr, **kwargs)
        # Number of currently reserved subscription streams on this connection.
        self.active_streams = 0
        # Upper bound of concurrent streams this connection may carry.
        self.max_streams = max_streams
        # Guards active_streams against concurrent reserve/release calls.
        self._lock = threading.Lock()
        self._id = id(self)
        # Lazy %-style args: the message is only formatted when DEBUG is enabled.
        logging.debug("%s | create new connection", self._id)

    def try_reserve_stream(self):
        """Atomically reserve one stream slot.

        Returns:
            True if a slot was available and has been reserved,
            False if the connection is already at max_streams.
        """
        with self._lock:
            if self.active_streams < self.max_streams:
                self.active_streams += 1
                logging.debug("%s | add stream(%s/%s)",
                              self._id, self.active_streams, self.max_streams)
                return True
            return False

    def release_stream(self):
        """Release one previously reserved stream slot."""
        with self._lock:
            self.active_streams -= 1
            logging.debug("%s | release stream(%s/%s)",
                          self._id, self.active_streams, self.max_streams)

    def print_client_info(self):
        """Log this connection's current stream usage at INFO level."""
        with self._lock:
            logging.info("%s | streams(%s/%s)",
                         self._id, self.active_streams, self.max_streams)


class PooledGRPCClient:
    """
    PooledGRPCClient manages a connection pool of GRPCClient instances.

    It balances long-lived subscription streams using round-robin distribution,
    and routes unary `get` and `publish` calls without consuming stream slots.
    It also creates a new GRPCClient instance when every client in the pool has
    maxed out its stream usage.

    If a single instance of GRPCClient would be used for 100+ subscriptions,
    it is recommended to use PooledGRPCClient instead.

    grpcAddr: must be a valid apiserver address in the format <ADDRESS>:<PORT>.
    certs: if present, must be the path to the cert file.
    key: if present, must be the path to a .pem key file.
    ca: if present, must be the path to a root certificate authority file.
    token: if present, must be the path to a .tok user access token.
    tokenValue: if present, is the actual token in string form.
        Cannot be set with token
    certsValue: if present, is the actual certs in string form.
        Cannot be set with certs
    keyValue: if present, is the actual key in string form.
        Cannot be set with key
    caValue: if present, is the actual ca in string form.
        Cannot be set with ca
    """

    def __init__(
        self,
        grpcAddr,
        max_streams_per_connection=100,
        max_connections=21474837,
        token: Optional[str] = None,
        certs: Optional[str] = None,
        key: Optional[str] = None,
        ca: Optional[str] = None,
        tokenValue: Optional[str] = None,
        certsValue: Optional[str] = None,
        keyValue: Optional[str] = None,
        caValue: Optional[str] = None,
        channel_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.grpcAddr = grpcAddr
        self._max_streams = max_streams_per_connection
        self._max_connections = max_connections
        self.token = token
        self.certs = certs
        self.key = key
        self.ca = ca
        self.tokenValue = tokenValue
        self.certsValue = certsValue
        self.keyValue = keyValue
        self.caValue = caValue
        # Default to a fresh dict per instance; a `{}` default argument would be
        # shared across all instances (mutable-default pitfall).
        self.channel_options = channel_options if channel_options is not None else {}
        self._pool: List[_StreamAwareGRPCClient] = []
        # Guards _pool and both round-robin cursors.
        self._lock = threading.Lock()
        self._rr_index = 0  # round-robin cursor for subscribe
        self._unary_rr_index = 0  # round-robin cursor for get/publish

    def _print_connection_state(self):
        """Log the stream usage of every connection currently in the pool."""
        with self._lock:
            for client in self._pool:
                client.print_client_info()

    def _create_new_client(self):
        """Create a new stream-aware client, append it to the pool, return it.

        Caller must hold self._lock.
        """
        client = _StreamAwareGRPCClient(
            self.grpcAddr,
            self._max_streams,
            token=self.token,
            certs=self.certs,
            key=self.key,
            ca=self.ca,
            tokenValue=self.tokenValue,
            certsValue=self.certsValue,
            keyValue=self.keyValue,
            caValue=self.caValue,
            channel_options=self.channel_options,
        )
        self._pool.append(client)
        return client

    def _get_or_create_client(self):
        """
        Return a connection from the pool on a round-robin basis, reserving a
        stream slot on it. If all connections have reached the
        max_streams_per_connection limit, create a new connection and add it
        to the pool.

        Raises a RuntimeError if a new connection can't be created because the
        max_connections limit has been reached.
        """
        with self._lock:
            pool_size = len(self._pool)
            if pool_size > 0:
                # Scan at most one full lap starting from the cursor.
                start = self._rr_index % pool_size
                for i in range(pool_size):
                    idx = (start + i) % pool_size
                    client = self._pool[idx]
                    if client.try_reserve_stream():
                        self._rr_index = (idx + 1) % pool_size
                        return client
            if len(self._pool) >= self._max_connections:
                raise RuntimeError("Maximum number of gRPC connections reached")
            new_client = self._create_new_client()
            # Fresh client: reservation always succeeds, return value ignored.
            new_client.try_reserve_stream()
            self._rr_index = len(self._pool) % self._max_connections
            return new_client

    def _get_any_client(self):
        """
        Select a connection for unary RPCs (get, publish) without reserving a
        stream slot. If all connections have reached the
        max_streams_per_connection limit, create a new connection and add it
        to the pool.

        Raises a RuntimeError if a new connection can't be created because the
        max_connections limit has been reached.
        """
        with self._lock:
            pool_size = len(self._pool)
            if pool_size == 0:
                return self._create_new_client()
            start = self._unary_rr_index % pool_size
            for i in range(pool_size):
                idx = (start + i) % pool_size
                client = self._pool[idx]
                # Prefer a connection that still has stream headroom.
                if client.active_streams < client.max_streams:
                    self._unary_rr_index = (idx + 1) % pool_size
                    return client
            if len(self._pool) >= self._max_connections:
                raise RuntimeError("Maximum number of gRPC connections reached")
            return self._create_new_client()

    def subscribe(self, queries, sharding=None):
        """
        Subscribe creates and executes a Subscribe protobuf message,
        returning a stream of notificationBatch.
        queries must be a list of query protobuf messages.
        sharding, if present, must be a protobuf sharding message.
        """
        client = self._get_or_create_client()
        stream = client.subscribe(queries, sharding)

        def wrapped_stream():
            # Release the reserved slot when the consumer finishes or abandons
            # the stream, so the connection can take new subscriptions.
            try:
                for item in stream:
                    yield item
            finally:
                client.release_stream()

        return wrapped_stream()

    def get(
        self,
        queries: List[rtr.Query],
        start: Optional[TIME_TYPE] = None,
        end: Optional[TIME_TYPE] = None,
        versions=0,
        sharding=None,
        exact_range=False
    ):
        """
        Get creates and executes a Get protobuf message,
        returning a stream of notificationBatch.
        queries must be a list of query protobuf messages.
        start and end, if present, must be nanoseconds timestamps (uint64).
        sharding, if present, must be a protobuf sharding message.

        Unary get request — uses any client, does not reserve a stream.
        """
        client = self._get_any_client()
        return client.get(
            queries=queries,
            start=start,
            end=end,
            versions=versions,
            sharding=sharding,
            exact_range=exact_range
        )

    def publish(
        self,
        dId,
        notifs: List[ntf.Notification],
        dtype: str = DATASET_TYPE_DEVICE,
        sync: bool = True,
        compare: Optional[UPDATE_TYPE] = None
    ):
        """
        Publish creates and executes a Publish protobuf message.
        Refer to cloudvision/Connector/protobufs/router.proto:124.
        Defaults to sync publish being true so that changes are reflected.

        Unary publish request — uses any client, does not reserve a stream.
        """
        client = self._get_any_client()
        return client.publish(
            dId=dId,
            notifs=notifs,
            dtype=dtype,
            sync=sync,
            compare=compare
        )