NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Indentation and whitespace-only changes are best-effort; check them against
ollama_proxy_server/main.py before applying. The "index" hash below predates
the review fixes made in the last two hunks (queue.Full handling and
queue-size based server selection).

diff --git a/ollama_proxy_server/main.py b/ollama_proxy_server/main.py
index e6378b0..36353f9 100644
--- a/ollama_proxy_server/main.py
+++ b/ollama_proxy_server/main.py
@@ -21,7 +21,19 @@ import datetime
 def get_config(filename):
     config = configparser.ConfigParser()
     config.read(filename)
-    return [(name, {'url': config[name]['url'], 'queue': Queue()}) for name in config.sections()]
+    return [
+        (
+            name,
+            {
+                'url': config[name]['url'],
+                'max_parallel_connections': int(config[name].get('max_parallel_connections', 10)),
+                'queue_size': int(config[name].get('queue_size', 100)), # Default queue size of 100
+                'queue': Queue(maxsize=int(config[name].get('queue_size', 100))),
+                'active_requests': 0
+            }
+        )
+        for name in config.sections()
+    ]
 
 # Read the authorized users and their keys from a file
 def get_authorized_users(filename):
@@ -29,17 +41,15 @@ def get_authorized_users(filename):
         lines = f.readlines()
     authorized_users = {}
     for line in lines:
-        if line=="":
+        if line == "":
             continue
         try:
             user, key = line.strip().split(':')
             authorized_users[user] = key
         except:
-            ASCIIColors.red(f"User entry broken:{line.strip()}")
+            ASCIIColors.red(f"User entry broken: {line.strip()}")
     return authorized_users
 
-
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', default="config.ini", help='Path to the authorized users list')
@@ -48,7 +58,7 @@ def main():
     parser.add_argument('--port', type=int, default=8000, help='Port number for the server')
     parser.add_argument('-d', '--deactivate_security', action='store_true', help='Deactivates security')
     args = parser.parse_args()
-    servers = get_config(args.config) 
+    servers = get_config(args.config)
     authorized_users = get_authorized_users(args.users_list)
     deactivate_security = args.deactivate_security
     ASCIIColors.red("Ollama Proxy server")
@@ -57,17 +67,17 @@ def main():
     class RequestHandler(BaseHTTPRequestHandler):
         def add_access_log_entry(self, event, user, ip_address, access, server, nb_queued_requests_on_server, error=""):
             log_file_path = Path(args.log_path)
-            
+
             if not log_file_path.exists():
                 with open(log_file_path, mode='w', newline='') as csvfile:
                     fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error']
                     writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                     writer.writeheader()
-            
+
             with open(log_file_path, mode='a', newline='') as csvfile:
                 fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error']
                 writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
-                row = {'time_stamp': str(datetime.datetime.now()), 'event':event, 'user_name': user, 'ip_address': ip_address, 'access': access, 'server': server, 'nb_queued_requests_on_server': nb_queued_requests_on_server, 'error': error}
+                row = {'time_stamp': str(datetime.datetime.now()), 'event': event, 'user_name': user, 'ip_address': ip_address, 'access': access, 'server': server, 'nb_queued_requests_on_server': nb_queued_requests_on_server, 'error': error}
                 writer.writerow(row)
 
         def _send_response(self, response):
@@ -105,7 +115,7 @@ def main():
                         return False
                 token = auth_header.split(' ')[1]
                 user, key = token.split(':')
-                
+
                 # Check if the user and key are in the list of authorized users
                 if authorized_users.get(user) == key:
                     self.user = user
@@ -115,7 +125,7 @@ def main():
                     return False
             except:
                 return False
-        
+
         def proxy(self):
             self.user = "unknown"
             if not deactivate_security and not self._validate_user_and_key():
@@ -126,16 +136,15 @@ def main():
                 if not auth_header or not auth_header.startswith('Bearer '):
                     self.add_access_log_entry(event='rejected', user="unknown", ip_address=client_ip, access="Denied", server="None", nb_queued_requests_on_server=-1, error="Authentication failed")
                 else:
-                    token = auth_header.split(' ')[1] 
+                    token = auth_header.split(' ')[1]
                     self.add_access_log_entry(event='rejected', user=token, ip_address=client_ip, access="Denied", server="None", nb_queued_requests_on_server=-1, error="Authentication failed")
                 self.send_response(403)
                 self.end_headers()
-                return 
+                return
             url = urlparse(self.path)
             path = url.path
             get_params = parse_qs(url.query) or {}
 
-
             if self.command == "POST":
                 content_length = int(self.headers['Content-Length'])
                 post_data = self.rfile.read(content_length)
@@ -143,20 +152,33 @@ def main():
             else:
                 post_params = {}
 
-
-            # Find the server with the lowest number of queue entries.
-            min_queued_server = servers[0]
+            # Find the server with the fewest in-flight requests. Each in-flight
+            # request holds one token in the server's queue, so qsize() is the
+            # live load ('active_requests' is only ever set to 0 in get_config).
+            min_active_server = servers[0]
             for server in servers:
                 cs = server[1]
-                if cs['queue'].qsize() < min_queued_server[1]['queue'].qsize():
-                    min_queued_server = server
+                if cs['queue'].qsize() < min_active_server[1]['queue'].qsize():
+                    min_active_server = server
 
             # Apply the queuing mechanism only for a specific endpoint.
             if path == '/api/generate' or path == '/api/chat' or path == '/v1/chat/completions':
-                que = min_queued_server[1]['queue']
+                # put_nowait() signals a full queue with queue.Full; the Queue class
+                # itself has no 'Full' attribute (assumes the stdlib queue.Queue).
+                from queue import Full as QueueFull
+                cs = min_active_server[1]
                 client_ip, client_port = self.client_address
-                self.add_access_log_entry(event="gen_request", user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize())
-                que.put_nowait(1)
+                try:
+                    # Try to acquire the queue slot for this request.
+                    cs['queue'].put_nowait(1)
+                    self.add_access_log_entry(event="gen_request", user=self.user, ip_address=client_ip, access="Authorized", server=min_active_server[0], nb_queued_requests_on_server=cs['queue'].qsize())
+                except QueueFull:
+                    # If the queue is full, log and return a 503 Service Unavailable response.
+                    self.add_access_log_entry(event="gen_error", user=self.user, ip_address=client_ip, access="Authorized", server=min_active_server[0], nb_queued_requests_on_server=cs['queue'].qsize(), error="Queue is full")
+                    self.send_response(503)
+                    self.end_headers()
+                    return
+
                 try:
                     post_data_dict = {}
 
@@ -164,22 +186,19 @@ def main():
                         post_data_str = post_data.decode('utf-8')
                         post_data_dict = json.loads(post_data_str)
 
-                    response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params, stream=post_data_dict.get("stream", False))
+                    response = requests.request(self.command, cs['url'] + path, params=get_params, data=post_params, stream=post_data_dict.get("stream", False))
                     self._send_response(response)
-                except Exception as ex:
-                    self.add_access_log_entry(event="gen_error",user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize(),error=ex)
                 finally:
-                    que.get_nowait()
-                    self.add_access_log_entry(event="gen_done",user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize())
+                    cs['queue'].get_nowait()
+                    self.add_access_log_entry(event="gen_done", user=self.user, ip_address=client_ip, access="Authorized", server=min_active_server[0], nb_queued_requests_on_server=cs['queue'].qsize())
             else:
                 # For other endpoints, just mirror the request.
-                response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params)
+                response = requests.request(self.command, min_active_server[1]['url'] + path, params=get_params, data=post_params)
                 self._send_response(response)
 
     class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
         pass
 
-
     print('Starting server')
     server = ThreadedHTTPServer(('', args.port), RequestHandler) # Set the entry port here.
     print(f'Running server on port {args.port}')