mirror of
https://github.com/ParisNeo/ollama_proxy_server.git
synced 2025-09-06 05:12:14 +00:00
Update main.py
This commit is contained in:
@@ -21,7 +21,19 @@ import datetime
|
|||||||
def get_config(filename):
|
def get_config(filename):
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read(filename)
|
config.read(filename)
|
||||||
return [(name, {'url': config[name]['url'], 'queue': Queue()}) for name in config.sections()]
|
return [
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
{
|
||||||
|
'url': config[name]['url'],
|
||||||
|
'max_parallel_connections': int(config[name].get('max_parallel_connections', 10)),
|
||||||
|
'queue_size': int(config[name].get('queue_size', 100)), # Default queue size of 100
|
||||||
|
'queue': Queue(maxsize=int(config[name].get('queue_size', 100))),
|
||||||
|
'active_requests': 0
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for name in config.sections()
|
||||||
|
]
|
||||||
|
|
||||||
# Read the authorized users and their keys from a file
|
# Read the authorized users and their keys from a file
|
||||||
def get_authorized_users(filename):
|
def get_authorized_users(filename):
|
||||||
@@ -29,17 +41,15 @@ def get_authorized_users(filename):
|
|||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
authorized_users = {}
|
authorized_users = {}
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if line=="":
|
if line == "":
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
user, key = line.strip().split(':')
|
user, key = line.strip().split(':')
|
||||||
authorized_users[user] = key
|
authorized_users[user] = key
|
||||||
except:
|
except:
|
||||||
ASCIIColors.red(f"User entry broken:{line.strip()}")
|
ASCIIColors.red(f"User entry broken: {line.strip()}")
|
||||||
return authorized_users
|
return authorized_users
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--config', default="config.ini", help='Path to the authorized users list')
|
parser.add_argument('--config', default="config.ini", help='Path to the authorized users list')
|
||||||
@@ -48,7 +58,7 @@ def main():
|
|||||||
parser.add_argument('--port', type=int, default=8000, help='Port number for the server')
|
parser.add_argument('--port', type=int, default=8000, help='Port number for the server')
|
||||||
parser.add_argument('-d', '--deactivate_security', action='store_true', help='Deactivates security')
|
parser.add_argument('-d', '--deactivate_security', action='store_true', help='Deactivates security')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
servers = get_config(args.config)
|
servers = get_config(args.config)
|
||||||
authorized_users = get_authorized_users(args.users_list)
|
authorized_users = get_authorized_users(args.users_list)
|
||||||
deactivate_security = args.deactivate_security
|
deactivate_security = args.deactivate_security
|
||||||
ASCIIColors.red("Ollama Proxy server")
|
ASCIIColors.red("Ollama Proxy server")
|
||||||
@@ -57,17 +67,17 @@ def main():
|
|||||||
class RequestHandler(BaseHTTPRequestHandler):
|
class RequestHandler(BaseHTTPRequestHandler):
|
||||||
def add_access_log_entry(self, event, user, ip_address, access, server, nb_queued_requests_on_server, error=""):
|
def add_access_log_entry(self, event, user, ip_address, access, server, nb_queued_requests_on_server, error=""):
|
||||||
log_file_path = Path(args.log_path)
|
log_file_path = Path(args.log_path)
|
||||||
|
|
||||||
if not log_file_path.exists():
|
if not log_file_path.exists():
|
||||||
with open(log_file_path, mode='w', newline='') as csvfile:
|
with open(log_file_path, mode='w', newline='') as csvfile:
|
||||||
fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error']
|
fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error']
|
||||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
|
|
||||||
with open(log_file_path, mode='a', newline='') as csvfile:
|
with open(log_file_path, mode='a', newline='') as csvfile:
|
||||||
fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error']
|
fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error']
|
||||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||||
row = {'time_stamp': str(datetime.datetime.now()), 'event':event, 'user_name': user, 'ip_address': ip_address, 'access': access, 'server': server, 'nb_queued_requests_on_server': nb_queued_requests_on_server, 'error': error}
|
row = {'time_stamp': str(datetime.datetime.now()), 'event': event, 'user_name': user, 'ip_address': ip_address, 'access': access, 'server': server, 'nb_queued_requests_on_server': nb_queued_requests_on_server, 'error': error}
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
|
|
||||||
def _send_response(self, response):
|
def _send_response(self, response):
|
||||||
@@ -105,7 +115,7 @@ def main():
|
|||||||
return False
|
return False
|
||||||
token = auth_header.split(' ')[1]
|
token = auth_header.split(' ')[1]
|
||||||
user, key = token.split(':')
|
user, key = token.split(':')
|
||||||
|
|
||||||
# Check if the user and key are in the list of authorized users
|
# Check if the user and key are in the list of authorized users
|
||||||
if authorized_users.get(user) == key:
|
if authorized_users.get(user) == key:
|
||||||
self.user = user
|
self.user = user
|
||||||
@@ -115,7 +125,7 @@ def main():
|
|||||||
return False
|
return False
|
||||||
except:
|
except:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def proxy(self):
|
def proxy(self):
|
||||||
self.user = "unknown"
|
self.user = "unknown"
|
||||||
if not deactivate_security and not self._validate_user_and_key():
|
if not deactivate_security and not self._validate_user_and_key():
|
||||||
@@ -126,16 +136,15 @@ def main():
|
|||||||
if not auth_header or not auth_header.startswith('Bearer '):
|
if not auth_header or not auth_header.startswith('Bearer '):
|
||||||
self.add_access_log_entry(event='rejected', user="unknown", ip_address=client_ip, access="Denied", server="None", nb_queued_requests_on_server=-1, error="Authentication failed")
|
self.add_access_log_entry(event='rejected', user="unknown", ip_address=client_ip, access="Denied", server="None", nb_queued_requests_on_server=-1, error="Authentication failed")
|
||||||
else:
|
else:
|
||||||
token = auth_header.split(' ')[1]
|
token = auth_header.split(' ')[1]
|
||||||
self.add_access_log_entry(event='rejected', user=token, ip_address=client_ip, access="Denied", server="None", nb_queued_requests_on_server=-1, error="Authentication failed")
|
self.add_access_log_entry(event='rejected', user=token, ip_address=client_ip, access="Denied", server="None", nb_queued_requests_on_server=-1, error="Authentication failed")
|
||||||
self.send_response(403)
|
self.send_response(403)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
return
|
return
|
||||||
url = urlparse(self.path)
|
url = urlparse(self.path)
|
||||||
path = url.path
|
path = url.path
|
||||||
get_params = parse_qs(url.query) or {}
|
get_params = parse_qs(url.query) or {}
|
||||||
|
|
||||||
|
|
||||||
if self.command == "POST":
|
if self.command == "POST":
|
||||||
content_length = int(self.headers['Content-Length'])
|
content_length = int(self.headers['Content-Length'])
|
||||||
post_data = self.rfile.read(content_length)
|
post_data = self.rfile.read(content_length)
|
||||||
@@ -143,20 +152,28 @@ def main():
|
|||||||
else:
|
else:
|
||||||
post_params = {}
|
post_params = {}
|
||||||
|
|
||||||
|
# Find the server with the lowest number of active requests.
|
||||||
# Find the server with the lowest number of queue entries.
|
min_active_server = servers[0]
|
||||||
min_queued_server = servers[0]
|
|
||||||
for server in servers:
|
for server in servers:
|
||||||
cs = server[1]
|
cs = server[1]
|
||||||
if cs['queue'].qsize() < min_queued_server[1]['queue'].qsize():
|
if cs['active_requests'] < min_active_server[1]['active_requests']:
|
||||||
min_queued_server = server
|
min_active_server = server
|
||||||
|
|
||||||
# Apply the queuing mechanism only for a specific endpoint.
|
# Apply the queuing mechanism only for a specific endpoint.
|
||||||
if path == '/api/generate' or path == '/api/chat' or path == '/v1/chat/completions':
|
if path == '/api/generate' or path == '/api/chat' or path == '/v1/chat/completions':
|
||||||
que = min_queued_server[1]['queue']
|
cs = min_active_server[1]
|
||||||
client_ip, client_port = self.client_address
|
client_ip, client_port = self.client_address
|
||||||
self.add_access_log_entry(event="gen_request", user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize())
|
try:
|
||||||
que.put_nowait(1)
|
# Try to acquire the queue slot for this request.
|
||||||
|
cs['queue'].put_nowait(1)
|
||||||
|
self.add_access_log_entry(event="gen_request", user=self.user, ip_address=client_ip, access="Authorized", server=min_active_server[0], nb_queued_requests_on_server=cs['active_requests'])
|
||||||
|
except Queue.Full:
|
||||||
|
# If the queue is full, log and return a 503 Service Unavailable response.
|
||||||
|
self.add_access_log_entry(event="gen_error", user=self.user, ip_address=client_ip, access="Authorized", server=min_active_server[0], nb_queued_requests_on_server=cs['active_requests'], error="Queue is full")
|
||||||
|
self.send_response(503)
|
||||||
|
self.end_headers()
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
post_data_dict = {}
|
post_data_dict = {}
|
||||||
|
|
||||||
@@ -164,22 +181,19 @@ def main():
|
|||||||
post_data_str = post_data.decode('utf-8')
|
post_data_str = post_data.decode('utf-8')
|
||||||
post_data_dict = json.loads(post_data_str)
|
post_data_dict = json.loads(post_data_str)
|
||||||
|
|
||||||
response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params, stream=post_data_dict.get("stream", False))
|
response = requests.request(self.command, cs['url'] + path, params=get_params, data=post_params, stream=post_data_dict.get("stream", False))
|
||||||
self._send_response(response)
|
self._send_response(response)
|
||||||
except Exception as ex:
|
|
||||||
self.add_access_log_entry(event="gen_error",user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize(),error=ex)
|
|
||||||
finally:
|
finally:
|
||||||
que.get_nowait()
|
cs['queue'].get_nowait()
|
||||||
self.add_access_log_entry(event="gen_done",user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize())
|
self.add_access_log_entry(event="gen_done", user=self.user, ip_address=client_ip, access="Authorized", server=min_active_server[0], nb_queued_requests_on_server=cs['active_requests'])
|
||||||
else:
|
else:
|
||||||
# For other endpoints, just mirror the request.
|
# For other endpoints, just mirror the request.
|
||||||
response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params)
|
response = requests.request(self.command, min_active_server[1]['url'] + path, params=get_params, data=post_params)
|
||||||
self._send_response(response)
|
self._send_response(response)
|
||||||
|
|
||||||
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
|
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
print('Starting server')
|
print('Starting server')
|
||||||
server = ThreadedHTTPServer(('', args.port), RequestHandler) # Set the entry port here.
|
server = ThreadedHTTPServer(('', args.port), RequestHandler) # Set the entry port here.
|
||||||
print(f'Running server on port {args.port}')
|
print(f'Running server on port {args.port}')
|
||||||
|
Reference in New Issue
Block a user