mirror of
https://github.com/ParisNeo/ollama_proxy_server.git
synced 2025-09-06 05:12:14 +00:00
Update main.py
This commit is contained in:
@@ -21,7 +21,19 @@ import datetime
|
||||
def get_config(filename):
    """Load the backend server definitions from an INI file.

    Every section of the file describes one backend server; the section
    name is used as the server name.

    Args:
        filename: Path of the INI file, handed to ``ConfigParser.read``.
            A file that cannot be opened is ignored by ``read``, so the
            result is then an empty list.

    Returns:
        A list of ``(name, settings)`` tuples, one per section, in file
        order. ``settings`` is a dict with the keys:

        * ``'url'``: the section's ``url`` option (required).
        * ``'max_parallel_connections'``: int, 10 when not configured.
        * ``'queue_size'``: int, 100 when not configured.
        * ``'queue'``: a ``Queue`` whose ``maxsize`` is ``queue_size``.
        * ``'active_requests'``: counter initialised to 0.

    Raises:
        KeyError: If a section has no ``url`` option.
        ValueError: If a numeric option is not a valid integer.
    """
    config = configparser.ConfigParser()
    config.read(filename)

    servers = []
    for name in config.sections():
        section = config[name]
        # Parse the queue size once so the stored value and the queue's
        # capacity can never disagree.
        queue_size = section.getint('queue_size', fallback=100)
        settings = {
            'url': section['url'],
            'max_parallel_connections': section.getint(
                'max_parallel_connections', fallback=10
            ),
            'queue_size': queue_size,
            'queue': Queue(maxsize=queue_size),
            'active_requests': 0,
        }
        servers.append((name, settings))
    return servers
|
||||
|
||||
# Read the authorized users and their keys from a file
|
||||
def get_authorized_users(filename):
|
||||
@@ -38,8 +50,6 @@ def get_authorized_users(filename):
|
||||
ASCIIColors.red(f"User entry broken: {line.strip()}")
|
||||
return authorized_users
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', default="config.ini", help='Path to the authorized users list')
|
||||
@@ -135,7 +145,6 @@ def main():
|
||||
path = url.path
|
||||
get_params = parse_qs(url.query) or {}
|
||||
|
||||
|
||||
if self.command == "POST":
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
post_data = self.rfile.read(content_length)
|
||||
@@ -143,20 +152,28 @@ def main():
|
||||
else:
|
||||
post_params = {}
|
||||
|
||||
|
||||
# Find the server with the lowest number of queue entries.
|
||||
min_queued_server = servers[0]
|
||||
# Find the server with the lowest number of active requests.
|
||||
min_active_server = servers[0]
|
||||
for server in servers:
|
||||
cs = server[1]
|
||||
if cs['queue'].qsize() < min_queued_server[1]['queue'].qsize():
|
||||
min_queued_server = server
|
||||
if cs['active_requests'] < min_active_server[1]['active_requests']:
|
||||
min_active_server = server
|
||||
|
||||
# Apply the queuing mechanism only for a specific endpoint.
|
||||
if path == '/api/generate' or path == '/api/chat' or path == '/v1/chat/completions':
|
||||
que = min_queued_server[1]['queue']
|
||||
cs = min_active_server[1]
|
||||
client_ip, client_port = self.client_address
|
||||
self.add_access_log_entry(event="gen_request", user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize())
|
||||
que.put_nowait(1)
|
||||
try:
|
||||
# Try to acquire the queue slot for this request.
|
||||
cs['queue'].put_nowait(1)
|
||||
self.add_access_log_entry(event="gen_request", user=self.user, ip_address=client_ip, access="Authorized", server=min_active_server[0], nb_queued_requests_on_server=cs['active_requests'])
|
||||
except Queue.Full:
|
||||
# If the queue is full, log and return a 503 Service Unavailable response.
|
||||
self.add_access_log_entry(event="gen_error", user=self.user, ip_address=client_ip, access="Authorized", server=min_active_server[0], nb_queued_requests_on_server=cs['active_requests'], error="Queue is full")
|
||||
self.send_response(503)
|
||||
self.end_headers()
|
||||
return
|
||||
|
||||
try:
|
||||
post_data_dict = {}
|
||||
|
||||
@@ -164,22 +181,19 @@ def main():
|
||||
post_data_str = post_data.decode('utf-8')
|
||||
post_data_dict = json.loads(post_data_str)
|
||||
|
||||
response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params, stream=post_data_dict.get("stream", False))
|
||||
response = requests.request(self.command, cs['url'] + path, params=get_params, data=post_params, stream=post_data_dict.get("stream", False))
|
||||
self._send_response(response)
|
||||
except Exception as ex:
|
||||
self.add_access_log_entry(event="gen_error",user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize(),error=ex)
|
||||
finally:
|
||||
que.get_nowait()
|
||||
self.add_access_log_entry(event="gen_done",user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize())
|
||||
cs['queue'].get_nowait()
|
||||
self.add_access_log_entry(event="gen_done", user=self.user, ip_address=client_ip, access="Authorized", server=min_active_server[0], nb_queued_requests_on_server=cs['active_requests'])
|
||||
else:
|
||||
# For other endpoints, just mirror the request.
|
||||
response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params)
|
||||
response = requests.request(self.command, min_active_server[1]['url'] + path, params=get_params, data=post_params)
|
||||
self._send_response(response)
|
||||
|
||||
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
|
||||
pass
|
||||
|
||||
|
||||
print('Starting server')
|
||||
server = ThreadedHTTPServer(('', args.port), RequestHandler) # Set the entry port here.
|
||||
print(f'Running server on port {args.port}')
|
||||
|
Reference in New Issue
Block a user