8 Commits

Author SHA1 Message Date
Saifeddine ALOUI
04ba02154e Update setup.py 2025-03-26 21:46:35 +01:00
Saifeddine ALOUI
71960644c3 Update requirements.txt 2025-03-26 21:45:12 +01:00
Saifeddine ALOUI
f0ee16e49b Update authorized_users.txt.example 2025-03-26 21:42:25 +01:00
Saifeddine ALOUI
db0e80bb9e Update add_user.py 2025-03-26 21:41:29 +01:00
Saifeddine ALOUI
6eb7aaeec9 Update main.py 2025-03-26 21:41:03 +01:00
Saifeddine ALOUI
98ae367b28 Create authorized_users.txt.example 2025-03-26 21:34:45 +01:00
Saifeddine ALOUI
86dcbee930 Create config.ini.example 2025-03-26 19:01:51 +01:00
Saifeddine ALOUI
47ac57b3ed Update main.py 2025-03-26 18:55:36 +01:00
6 changed files with 145 additions and 91 deletions

View File

@@ -1,6 +1,14 @@
"""
project: ollama_proxy_server
file: add_user.py
author: ParisNeo (Saifeddine ALOUI)
description: A utility to add users to the authorized_users.txt file for the Ollama Proxy Server.
license: Apache 2.0
repository: https://github.com/ParisNeo/ollama_proxy_server
"""
import sys
import random
from getpass import getuser
from pathlib import Path
def generate_key(length=10):

View File

@@ -0,0 +1,11 @@
# Example authorized users file for Ollama Proxy Server
# Project: ollama_proxy_server
# Author: ParisNeo (Saifeddine ALOUI)
# License: Apache 2.0
# Repository: https://github.com/ParisNeo/ollama_proxy_server
# Copy this file to authorized_users.txt and edit the entries to add your users and keys.
# Format: username:key
# Example user entries:
alice:abc123!@#XYZ
bob:K9$mPq&*vL

View File

@@ -0,0 +1,22 @@
# Example configuration file for Ollama Proxy Server
# Copy this file to config.ini and edit the values to match your environment.
# Section for backend server URLs
# Each server should have its own section, e.g., [server0], [server1], etc.
# The 'url' key specifies the URL of the backend Ollama server.
[server0]
url = http://localhost:11434
# Add additional servers as needed, e.g.:
# [server1]
# url = http://another-server:11434
# Section for logging configuration
[Logging]
# log_path: the path to the access log file (ensure the application has write permissions)
log_path = access_log.txt
# Section for user management
[Users]
# users_list: the path to the file containing authorized users and their keys
users_list = authorized_users.txt

View File

@@ -1,8 +1,10 @@
"""
project: ollama_proxy_server
file: main.py
author: ParisNeo
description: This is a proxy server that adds a security layer to one or multiple ollama servers and routes the requests to the right server in order to minimize the charge of the server.
author: ParisNeo (Saifeddine ALOUI)
description: A proxy server adding a security layer to one or multiple Ollama servers, routing requests to minimize server load.
license: Apache 2.0
repository: https://github.com/ParisNeo/ollama_proxy_server
"""
import configparser
@@ -17,70 +19,95 @@ from ascii_colors import ASCIIColors
from pathlib import Path
import csv
import datetime
import threading
import shutil
def get_config(filename):
    """Load the backend server definitions from an INI file.

    Every section that provides a ``url`` option describes one backend
    server. Sections without a ``url`` option (for instance ``[Logging]``
    or ``[Users]``) are skipped instead of raising ``KeyError``.

    Args:
        filename: Path of the INI file. ``ConfigParser.read`` ignores a
            file that does not exist, in which case the result is empty.

    Returns:
        A list of ``(section_name, state)`` tuples. ``state`` is a dict
        with the keys ``url``, ``session`` (a ``requests.Session``),
        ``ongoing_requests`` (a counter starting at 0) and ``lock`` (a
        ``threading.Lock`` created together with that counter).
    """
    config = configparser.ConfigParser()
    config.read(filename)
    return [
        (name, {
            'url': config[name]['url'],
            'session': requests.Session(),
            'ongoing_requests': 0,
            'lock': threading.Lock(),
        })
        for name in config.sections()
        # Only sections that carry a backend URL are servers.
        if config.has_option(name, 'url')
    ]
# Read the authorized users and their keys from a file
def get_authorized_users(filename):
    """Load the ``username:key`` pairs from the authorized users file.

    Blank lines and lines whose first non-blank character is ``#`` are
    ignored, so a commented file can be used as-is. Any other line that
    does not split into exactly two ``:``-separated fields is reported
    in red and skipped.

    Args:
        filename: Path of the users file, one ``username:key`` per line.

    Returns:
        A dict mapping each user name to its key. When a user name
        appears more than once, the last entry wins.
    """
    authorized_users = {}
    with open(filename, 'r') as f:
        for line in f:
            entry = line.strip()
            if not entry or entry.startswith('#'):
                continue
            try:
                user, key = entry.split(':')
            except ValueError:
                # Raised when the entry has no ':' or more than one.
                ASCIIColors.red(f"User entry broken: {entry}")
                continue
            authorized_users[user] = key
    return authorized_users
def log_writer(log_queue, log_file_path):
    """Drain ``log_queue`` into a CSV access log until ``None`` arrives.

    Args:
        log_queue: Queue of dicts keyed by the access-log column names.
            A ``None`` item tells the writer to stop.
        log_file_path: Path of the CSV file; it is opened in append mode.
    """
    columns = [
        'time_stamp',
        'event',
        'user_name',
        'ip_address',
        'access',
        'server',
        'nb_queued_requests_on_server',
        'error',
    ]
    with open(log_file_path, mode='a', newline='') as handle:
        sink = csv.DictWriter(handle, fieldnames=columns)
        # A position of 0 right after opening is taken to mean the file
        # is empty and still needs its header row.
        if handle.tell() == 0:
            sink.writeheader()
        # Blocks on the queue; the None sentinel ends the iteration.
        for entry in iter(log_queue.get, None):
            sink.writerow(entry)
            # Flush per row so the file is current after every entry.
            handle.flush()
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', default="config.ini", help='Path to the authorized users list')
parser = argparse.ArgumentParser(description="Ollama Proxy Server by ParisNeo")
parser.add_argument('--config', default="config.ini", help='Path to the config file')
parser.add_argument('--log_path', default="access_log.txt", help='Path to the access log file')
parser.add_argument('--users_list', default="authorized_users.txt", help='Path to the config file')
parser.add_argument('--users_list', default="authorized_users.txt", help='Path to the authorized users list')
parser.add_argument('--port', type=int, default=8000, help='Port number for the server')
parser.add_argument('-d', '--deactivate_security', action='store_true', help='Deactivates security')
args = parser.parse_args()
servers = get_config(args.config)
servers = get_config(args.config)
authorized_users = get_authorized_users(args.users_list)
deactivate_security = args.deactivate_security
ASCIIColors.red("Ollama Proxy server")
ASCIIColors.red("Author: ParisNeo")
ASCIIColors.red("Ollama Proxy Server")
ASCIIColors.red("Author: ParisNeo (Saifeddine ALOUI)")
ASCIIColors.red("License: Apache 2.0")
ASCIIColors.red("Repository: https://github.com/ParisNeo/ollama_proxy_server")
global log_queue
log_queue = Queue()
log_file_path = Path(args.log_path)
if not log_file_path.exists() or log_file_path.stat().st_size == 0:
with open(log_file_path, mode='w', newline='') as csvfile:
fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error']
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
log_writer_thread = threading.Thread(target=log_writer, args=(log_queue, log_file_path))
log_writer_thread.daemon = True
log_writer_thread.start()
class RequestHandler(BaseHTTPRequestHandler):
def add_access_log_entry(self, event, user, ip_address, access, server, nb_queued_requests_on_server, error=""):
    """Queue one access-log row for the background log writer.

    The row is only handed to ``log_queue``; the file itself is written
    by the ``log_writer`` thread. Handler threads therefore never open
    the log file themselves, which avoids concurrent appends and
    duplicate rows.

    Args:
        event: Event label, e.g. ``"gen_request"`` or ``"rejected"``.
        user: User name associated with the request.
        ip_address: Client IP address.
        access: Access decision, e.g. ``"Authorized"`` or ``"Denied"``.
        server: Name of the backend server section, or ``"None"``.
        nb_queued_requests_on_server: Load figure to record for the server.
        error: Optional error description.
    """
    log_queue.put({
        'time_stamp': str(datetime.datetime.now()),
        'event': event,
        'user_name': user,
        'ip_address': ip_address,
        'access': access,
        'server': server,
        'nb_queued_requests_on_server': nb_queued_requests_on_server,
        'error': error,
    })
def _send_response(self, response):
    """Relay a backend ``requests`` response to the connected client.

    The status code and headers are forwarded, except for
    ``Content-Length``, ``Transfer-Encoding`` and ``Content-Encoding``:
    the body is re-sent decoded and without the backend's framing, so
    those values no longer describe what the client receives.

    The body is read through ``response.iter_content`` rather than
    ``response.raw``. For a request made with ``stream=False`` requests
    has already consumed the raw stream, so copying from ``raw`` would
    send an empty body; ``iter_content`` works in both modes.

    Args:
        response: The ``requests.Response`` obtained from the backend.
    """
    skipped_headers = ('content-length', 'transfer-encoding', 'content-encoding')
    self.send_response(response.status_code)
    for key, value in response.headers.items():
        if key.lower() not in skipped_headers:
            self.send_header(key, value)
    self.end_headers()
    try:
        # chunk_size=None yields streamed data as it arrives and a
        # preloaded body as a single chunk.
        for chunk in response.iter_content(chunk_size=None):
            if chunk:
                self.wfile.write(chunk)
                self.wfile.flush()
    except BrokenPipeError:
        # The client went away; there is nobody left to send to.
        pass
    finally:
        # Release the backend connection even when the relay was cut short.
        response.close()
@@ -99,14 +126,11 @@ def main():
def _validate_user_and_key(self):
try:
# Extract the bearer token from the headers
auth_header = self.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
return False
token = auth_header.split(' ')[1]
user, key = token.split(':')
# Check if the user and key are in the list of authorized users
token = auth_header.split(' ')[1]
user, key = token.split(':')
if authorized_users.get(user) == key:
self.user = user
return True
@@ -115,75 +139,73 @@ def main():
return False
except:
return False
def proxy(self):
    """Authenticate the request and forward it to the least-loaded backend.

    Unless security is deactivated, a request whose bearer token fails
    validation is logged as ``rejected`` and answered with 403. Only the
    user part of a ``user:key`` token is logged, so keys never reach the
    access log.

    Accepted requests go to the server with the fewest ongoing requests.
    For the generation endpoints the server's counter is incremented for
    the duration of the call and ``gen_request`` / ``gen_error`` /
    ``gen_done`` rows are logged; every other path is simply mirrored.
    """
    self.user = "unknown"
    if not deactivate_security and not self._validate_user_and_key():
        ASCIIColors.red('User is not authorized')
        client_ip, _ = self.client_address
        auth_header = self.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            rejected_user = "unknown"
        else:
            # The token looks like "user:key"; drop everything after the
            # first ':' so the key is not written to the log.
            token = auth_header.split(' ')[1]
            rejected_user = token.split(':', 1)[0] or "unknown"
        self.add_access_log_entry(event='rejected', user=rejected_user, ip_address=client_ip, access="Denied", server="None", nb_queued_requests_on_server=-1, error="Authentication failed")
        self.send_response(403)
        self.end_headers()
        return

    url = urlparse(self.path)
    path = url.path
    get_params = parse_qs(url.query) or {}

    post_params = {}
    if self.command == "POST":
        # A POST without a Content-Length header is read as an empty body.
        content_length = int(self.headers.get('Content-Length') or 0)
        post_params = self.rfile.read(content_length)

    server_name, backend = min(servers, key=lambda s: s[1]['ongoing_requests'])
    target_url = backend['url'] + path

    if path in ['/api/generate', '/api/chat', '/v1/chat/completions']:
        # Decide on streaming from the JSON body. A body that is not a
        # JSON object is still forwarded so the backend can answer it.
        stream = False
        if isinstance(post_params, bytes) and post_params:
            try:
                stream = bool(json.loads(post_params.decode('utf-8')).get("stream", False))
            except (ValueError, AttributeError):
                stream = False

        with backend['lock']:
            backend['ongoing_requests'] += 1
            load = backend['ongoing_requests']
        client_ip, _ = self.client_address
        self.add_access_log_entry(event="gen_request", user=self.user, ip_address=client_ip, access="Authorized", server=server_name, nb_queued_requests_on_server=load)
        try:
            response = backend['session'].request(
                self.command,
                target_url,
                params=get_params,
                data=post_params,
                stream=stream
            )
            self._send_response(response)
        except Exception as ex:
            self.add_access_log_entry(event="gen_error", user=self.user, ip_address=client_ip, access="Authorized", server=server_name, nb_queued_requests_on_server=backend['ongoing_requests'], error=str(ex))
        finally:
            # Always give the slot back, whether the call succeeded or not.
            with backend['lock']:
                backend['ongoing_requests'] -= 1
                load = backend['ongoing_requests']
            self.add_access_log_entry(event="gen_done", user=self.user, ip_address=client_ip, access="Authorized", server=server_name, nb_queued_requests_on_server=load)
    else:
        # Other endpoints are mirrored without load accounting.
        response = backend['session'].request(
            self.command,
            target_url,
            params=get_params,
            data=post_params
        )
        self._send_response(response)
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
pass
print('Starting server')
server = ThreadedHTTPServer(('', args.port), RequestHandler) # Set the entry port here.
server = ThreadedHTTPServer(('', args.port), RequestHandler)
print(f'Running server on port {args.port}')
server.serve_forever()
try:
server.serve_forever()
except KeyboardInterrupt:
log_queue.put(None) # Signal log_writer to exit
server.server_close()
if __name__ == "__main__":
main()

View File

@@ -1,8 +1,2 @@
ascii-colors==0.2.2
certifi==2024.7.4
charset-normalizer==3.3.2
configparser==6.0.1
idna==3.6
queues==0.6.3
requests==2.31.0
urllib3==2.2.1
requests>=2.31.0
ascii_colors>=0.5.2

View File

@@ -6,26 +6,23 @@ import setuptools
with open("README.md", "r") as fh:
long_description = fh.read()
def read_requirements(path: Union[str, Path]):
    """Return the requirement specifiers listed in the file at *path*.

    Surrounding whitespace is stripped, and blank lines and lines
    starting with ``#`` are dropped so that only real requirement
    strings are returned.

    Args:
        path: Location of a requirements file with one entry per line.

    Returns:
        A list of requirement strings in file order.
    """
    with open(path, "r") as file:
        stripped = [line.strip() for line in file.read().splitlines()]
    return [line for line in stripped if line and not line.startswith("#")]
requirements = read_requirements("requirements.txt")
requirements_dev = read_requirements("requirements_dev.txt")
setuptools.setup(
name="ollama_proxy_server",
version="7.1.0",
author="Saifeddine ALOUI (ParisNeo)",
author_email="aloui.saifeddine@gmail.com",
description="A fastapi server for petals decentralized text generation",
description="A proxy server adding a security layer to Ollama servers, routing requests to minimize server load",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/ParisNeo/ollama_proxy_server",
packages=setuptools.find_packages(),
packages=setuptools.find_packages(),
include_package_data=True,
install_requires=requirements,
entry_points={