mirror of
https://github.com/ParisNeo/ollama_proxy_server.git
synced 2025-09-06 05:12:14 +00:00
Compare commits
8 Commits
v17.1.0
...
boosted_ve
Author | SHA1 | Date | |
---|---|---|---|
![]() |
04ba02154e | ||
![]() |
71960644c3 | ||
![]() |
f0ee16e49b | ||
![]() |
db0e80bb9e | ||
![]() |
6eb7aaeec9 | ||
![]() |
98ae367b28 | ||
![]() |
86dcbee930 | ||
![]() |
47ac57b3ed |
@@ -1,6 +1,14 @@
|
||||
"""
|
||||
project: ollama_proxy_server
|
||||
file: add_user.py
|
||||
author: ParisNeo (Saifeddine ALOUI)
|
||||
description: A utility to add users to the authorized_users.txt file for the Ollama Proxy Server.
|
||||
license: Apache 2.0
|
||||
repository: https://github.com/ParisNeo/ollama_proxy_server
|
||||
"""
|
||||
|
||||
import sys
|
||||
import random
|
||||
from getpass import getuser
|
||||
from pathlib import Path
|
||||
|
||||
def generate_key(length=10):
|
||||
|
11
ollama_proxy_server/authorized_users.txt.example
Normal file
11
ollama_proxy_server/authorized_users.txt.example
Normal file
@@ -0,0 +1,11 @@
|
||||
# Example authorized users file for Ollama Proxy Server
|
||||
# Project: ollama_proxy_server
|
||||
# Author: ParisNeo (Saifeddine ALOUI)
|
||||
# License: Apache 2.0
|
||||
# Repository: https://github.com/ParisNeo/ollama_proxy_server
|
||||
# Copy this file to authorized_users.txt and edit the entries to add your users and keys.
|
||||
# Format: username:key
|
||||
|
||||
# Example user entries:
|
||||
alice:abc123!@#XYZ
|
||||
bob:K9$mPq&*vL
|
22
ollama_proxy_server/config.ini.example
Normal file
22
ollama_proxy_server/config.ini.example
Normal file
@@ -0,0 +1,22 @@
|
||||
# Example configuration file for Ollama Proxy Server
|
||||
# Copy this file to config.ini and edit the values to match your environment.
|
||||
|
||||
# Section for backend server URLs
|
||||
# Each server should have its own section, e.g., [server0], [server1], etc.
|
||||
# The 'url' key specifies the URL of the backend Ollama server.
|
||||
[server0]
|
||||
url = http://localhost:11434
|
||||
|
||||
# Add additional servers as needed, e.g.:
|
||||
# [server1]
|
||||
# url = http://another-server:11434
|
||||
|
||||
# Section for logging configuration
|
||||
[Logging]
|
||||
# log_path: the path to the access log file (ensure the application has write permissions)
|
||||
log_path = access_log.txt
|
||||
|
||||
# Section for user management
|
||||
[Users]
|
||||
# users_list: the path to the file containing authorized users and their keys
|
||||
users_list = authorized_users.txt
|
@@ -1,8 +1,10 @@
|
||||
"""
|
||||
project: ollama_proxy_server
|
||||
file: main.py
|
||||
author: ParisNeo
|
||||
description: This is a proxy server that adds a security layer to one or multiple ollama servers and routes the requests to the right server in order to minimize the charge of the server.
|
||||
author: ParisNeo (Saifeddine ALOUI)
|
||||
description: A proxy server adding a security layer to one or multiple Ollama servers, routing requests to minimize server load.
|
||||
license: Apache 2.0
|
||||
repository: https://github.com/ParisNeo/ollama_proxy_server
|
||||
"""
|
||||
|
||||
import configparser
|
||||
@@ -17,70 +19,95 @@ from ascii_colors import ASCIIColors
|
||||
from pathlib import Path
|
||||
import csv
|
||||
import datetime
|
||||
import threading
|
||||
import shutil
|
||||
|
||||
def get_config(filename):
    """Load the backend server definitions from an INI configuration file.

    Only sections that define a ``url`` option are treated as backend
    servers. Other sections (for instance ``[Logging]`` or ``[Users]``) are
    skipped; reading ``config[name]['url']`` on them would raise ``KeyError``.

    Args:
        filename: Path of the INI file to read. ``ConfigParser.read`` ignores
            files it cannot open, in which case the result is an empty list.

    Returns:
        A list of ``(section_name, state)`` tuples, in file order. ``state``
        is a dict with the keys ``'url'`` (backend base URL), ``'session'``
        (a ``requests.Session`` dedicated to that backend),
        ``'ongoing_requests'`` (counter initialised to 0) and ``'lock'``
        (a ``threading.Lock`` meant to guard the counter).
    """
    config = configparser.ConfigParser()
    config.read(filename)

    servers = []
    for name in config.sections():
        section = config[name]
        # Non-server sections carry no backend URL; do not crash on them.
        if 'url' not in section:
            continue
        servers.append((name, {
            'url': section['url'],
            'session': requests.Session(),
            'ongoing_requests': 0,
            'lock': threading.Lock(),
        }))
    return servers
|
||||
|
||||
# Read the authorized users and their keys from a file
|
||||
def get_authorized_users(filename):
    """Read the ``username:key`` pairs from the authorized users file.

    Blank lines and lines starting with ``#`` are ignored, so a file that
    keeps explanatory comments does not turn them into bogus credentials
    (``"# Example user entries:"`` would otherwise become a user with an
    empty key). Any other line that does not split into exactly two
    ``:``-separated fields is reported in red and skipped.

    Args:
        filename: Path of the users file, one ``username:key`` per line.

    Returns:
        A dict mapping each user name to its key. When a user name appears
        more than once, the last entry wins.
    """
    authorized_users = {}
    with open(filename, 'r') as f:
        for line in f:
            entry = line.strip()
            if not entry or entry.startswith('#'):
                continue
            try:
                user, key = entry.split(':')
            except ValueError:
                # Too few or too many ':'-separated fields on the line.
                ASCIIColors.red(f"User entry broken: {entry}")
                continue
            authorized_users[user] = key
    return authorized_users
|
||||
|
||||
|
||||
def log_writer(log_queue, log_file_path):
    """Drain *log_queue* into the CSV access log at *log_file_path*.

    The file is opened in append mode and the header row is written only
    when the write position is 0. Entries are then taken from the queue one
    at a time, written as a CSV row and flushed immediately. Receiving
    ``None`` from the queue ends the loop and closes the file.

    Args:
        log_queue: Queue yielding dicts keyed by the CSV column names, or
            ``None`` as the stop signal.
        log_file_path: Path of the CSV file to append to.
    """
    columns = [
        'time_stamp', 'event', 'user_name', 'ip_address', 'access',
        'server', 'nb_queued_requests_on_server', 'error',
    ]
    with open(log_file_path, mode='a', newline='') as log_file:
        out = csv.DictWriter(log_file, fieldnames=columns)
        if log_file.tell() == 0:
            out.writeheader()
        # iter(callable, sentinel) keeps calling get() until it returns None.
        for entry in iter(log_queue.get, None):
            out.writerow(entry)
            log_file.flush()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', default="config.ini", help='Path to the authorized users list')
|
||||
parser = argparse.ArgumentParser(description="Ollama Proxy Server by ParisNeo")
|
||||
parser.add_argument('--config', default="config.ini", help='Path to the config file')
|
||||
parser.add_argument('--log_path', default="access_log.txt", help='Path to the access log file')
|
||||
parser.add_argument('--users_list', default="authorized_users.txt", help='Path to the config file')
|
||||
parser.add_argument('--users_list', default="authorized_users.txt", help='Path to the authorized users list')
|
||||
parser.add_argument('--port', type=int, default=8000, help='Port number for the server')
|
||||
parser.add_argument('-d', '--deactivate_security', action='store_true', help='Deactivates security')
|
||||
args = parser.parse_args()
|
||||
servers = get_config(args.config)
|
||||
servers = get_config(args.config)
|
||||
authorized_users = get_authorized_users(args.users_list)
|
||||
deactivate_security = args.deactivate_security
|
||||
ASCIIColors.red("Ollama Proxy server")
|
||||
ASCIIColors.red("Author: ParisNeo")
|
||||
ASCIIColors.red("Ollama Proxy Server")
|
||||
ASCIIColors.red("Author: ParisNeo (Saifeddine ALOUI)")
|
||||
ASCIIColors.red("License: Apache 2.0")
|
||||
ASCIIColors.red("Repository: https://github.com/ParisNeo/ollama_proxy_server")
|
||||
|
||||
global log_queue
|
||||
log_queue = Queue()
|
||||
log_file_path = Path(args.log_path)
|
||||
if not log_file_path.exists() or log_file_path.stat().st_size == 0:
|
||||
with open(log_file_path, mode='w', newline='') as csvfile:
|
||||
fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
log_writer_thread = threading.Thread(target=log_writer, args=(log_queue, log_file_path))
|
||||
log_writer_thread.daemon = True
|
||||
log_writer_thread.start()
|
||||
|
||||
class RequestHandler(BaseHTTPRequestHandler):
|
||||
def add_access_log_entry(self, event, user, ip_address, access, server, nb_queued_requests_on_server, error=""):
|
||||
log_file_path = Path(args.log_path)
|
||||
|
||||
if not log_file_path.exists():
|
||||
with open(log_file_path, mode='w', newline='') as csvfile:
|
||||
fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
|
||||
with open(log_file_path, mode='a', newline='') as csvfile:
|
||||
fieldnames = ['time_stamp', 'event', 'user_name', 'ip_address', 'access', 'server', 'nb_queued_requests_on_server', 'error']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
row = {'time_stamp': str(datetime.datetime.now()), 'event':event, 'user_name': user, 'ip_address': ip_address, 'access': access, 'server': server, 'nb_queued_requests_on_server': nb_queued_requests_on_server, 'error': error}
|
||||
writer.writerow(row)
|
||||
log_entry = {
|
||||
'time_stamp': str(datetime.datetime.now()),
|
||||
'event': event,
|
||||
'user_name': user,
|
||||
'ip_address': ip_address,
|
||||
'access': access,
|
||||
'server': server,
|
||||
'nb_queued_requests_on_server': nb_queued_requests_on_server,
|
||||
'error': error
|
||||
}
|
||||
log_queue.put(log_entry)
|
||||
|
||||
def _send_response(self, response):
|
||||
self.send_response(response.status_code)
|
||||
for key, value in response.headers.items():
|
||||
if key.lower() not in ['content-length', 'transfer-encoding', 'content-encoding']:
|
||||
self.send_header(key, value)
|
||||
self.send_header(key, value)
|
||||
self.end_headers()
|
||||
|
||||
try:
|
||||
# Read the full content to avoid chunking issues
|
||||
content = response.content
|
||||
self.wfile.write(content)
|
||||
shutil.copyfileobj(response.raw, self.wfile)
|
||||
self.wfile.flush()
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
@@ -99,14 +126,11 @@ def main():
|
||||
|
||||
def _validate_user_and_key(self):
|
||||
try:
|
||||
# Extract the bearer token from the headers
|
||||
auth_header = self.headers.get('Authorization')
|
||||
if not auth_header or not auth_header.startswith('Bearer '):
|
||||
return False
|
||||
token = auth_header.split(' ')[1]
|
||||
user, key = token.split(':')
|
||||
|
||||
# Check if the user and key are in the list of authorized users
|
||||
token = auth_header.split(' ')[1]
|
||||
user, key = token.split(':')
|
||||
if authorized_users.get(user) == key:
|
||||
self.user = user
|
||||
return True
|
||||
@@ -115,75 +139,73 @@ def main():
|
||||
return False
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def proxy(self):
|
||||
self.user = "unknown"
|
||||
if not deactivate_security and not self._validate_user_and_key():
|
||||
ASCIIColors.red(f'User is not authorized')
|
||||
client_ip, client_port = self.client_address
|
||||
# Extract the bearer token from the headers
|
||||
ASCIIColors.red('User is not authorized')
|
||||
client_ip, _ = self.client_address
|
||||
auth_header = self.headers.get('Authorization')
|
||||
if not auth_header or not auth_header.startswith('Bearer '):
|
||||
self.add_access_log_entry(event='rejected', user="unknown", ip_address=client_ip, access="Denied", server="None", nb_queued_requests_on_server=-1, error="Authentication failed")
|
||||
else:
|
||||
token = auth_header.split(' ')[1]
|
||||
token = auth_header.split(' ')[1]
|
||||
self.add_access_log_entry(event='rejected', user=token, ip_address=client_ip, access="Denied", server="None", nb_queued_requests_on_server=-1, error="Authentication failed")
|
||||
self.send_response(403)
|
||||
self.end_headers()
|
||||
return
|
||||
return
|
||||
|
||||
url = urlparse(self.path)
|
||||
path = url.path
|
||||
get_params = parse_qs(url.query) or {}
|
||||
|
||||
|
||||
post_params = {}
|
||||
if self.command == "POST":
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
post_data = self.rfile.read(content_length)
|
||||
post_params = post_data# parse_qs(post_data.decode('utf-8'))
|
||||
else:
|
||||
post_params = {}
|
||||
post_params = self.rfile.read(content_length)
|
||||
|
||||
min_queued_server = min(servers, key=lambda s: s[1]['ongoing_requests'])
|
||||
|
||||
# Find the server with the lowest number of queue entries.
|
||||
min_queued_server = servers[0]
|
||||
for server in servers:
|
||||
cs = server[1]
|
||||
if cs['queue'].qsize() < min_queued_server[1]['queue'].qsize():
|
||||
min_queued_server = server
|
||||
|
||||
# Apply the queuing mechanism only for a specific endpoint.
|
||||
if path == '/api/generate' or path == '/api/chat' or path == '/v1/chat/completions':
|
||||
que = min_queued_server[1]['queue']
|
||||
client_ip, client_port = self.client_address
|
||||
self.add_access_log_entry(event="gen_request", user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize())
|
||||
que.put_nowait(1)
|
||||
if path in ['/api/generate', '/api/chat', '/v1/chat/completions']:
|
||||
with min_queued_server[1]['lock']:
|
||||
min_queued_server[1]['ongoing_requests'] += 1
|
||||
client_ip, _ = self.client_address
|
||||
self.add_access_log_entry(event="gen_request", user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=min_queued_server[1]['ongoing_requests'])
|
||||
try:
|
||||
post_data_dict = {}
|
||||
|
||||
if isinstance(post_data, bytes):
|
||||
post_data_str = post_data.decode('utf-8')
|
||||
post_data_dict = json.loads(post_data_str)
|
||||
|
||||
response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params, stream=post_data_dict.get("stream", False))
|
||||
post_data_dict = json.loads(post_params.decode('utf-8')) if isinstance(post_params, bytes) else {}
|
||||
response = min_queued_server[1]['session'].request(
|
||||
self.command,
|
||||
min_queued_server[1]['url'] + path,
|
||||
params=get_params,
|
||||
data=post_params,
|
||||
stream=post_data_dict.get("stream", False)
|
||||
)
|
||||
self._send_response(response)
|
||||
except Exception as ex:
|
||||
self.add_access_log_entry(event="gen_error",user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize(),error=ex)
|
||||
self.add_access_log_entry(event="gen_error", user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=min_queued_server[1]['ongoing_requests'], error=str(ex))
|
||||
finally:
|
||||
que.get_nowait()
|
||||
self.add_access_log_entry(event="gen_done",user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=que.qsize())
|
||||
with min_queued_server[1]['lock']:
|
||||
min_queued_server[1]['ongoing_requests'] -= 1
|
||||
self.add_access_log_entry(event="gen_done", user=self.user, ip_address=client_ip, access="Authorized", server=min_queued_server[0], nb_queued_requests_on_server=min_queued_server[1]['ongoing_requests'])
|
||||
else:
|
||||
# For other endpoints, just mirror the request.
|
||||
response = requests.request(self.command, min_queued_server[1]['url'] + path, params=get_params, data=post_params)
|
||||
response = min_queued_server[1]['session'].request(
|
||||
self.command,
|
||||
min_queued_server[1]['url'] + path,
|
||||
params=get_params,
|
||||
data=post_params
|
||||
)
|
||||
self._send_response(response)
|
||||
|
||||
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
|
||||
pass
|
||||
|
||||
|
||||
print('Starting server')
|
||||
server = ThreadedHTTPServer(('', args.port), RequestHandler) # Set the entry port here.
|
||||
server = ThreadedHTTPServer(('', args.port), RequestHandler)
|
||||
print(f'Running server on port {args.port}')
|
||||
server.serve_forever()
|
||||
try:
|
||||
server.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
log_queue.put(None) # Signal log_writer to exit
|
||||
server.server_close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@@ -1,8 +1,2 @@
|
||||
ascii-colors==0.2.2
|
||||
certifi==2024.7.4
|
||||
charset-normalizer==3.3.2
|
||||
configparser==6.0.1
|
||||
idna==3.6
|
||||
queues==0.6.3
|
||||
requests==2.31.0
|
||||
urllib3==2.2.1
|
||||
requests>=2.31.0
|
||||
ascii_colors>=0.5.2
|
||||
|
7
setup.py
7
setup.py
@@ -6,26 +6,23 @@ import setuptools
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
|
||||
def read_requirements(path: Union[str, Path]):
    """Return the lines of the text file at *path* without line endings."""
    return Path(path).read_text().splitlines()
|
||||
|
||||
|
||||
requirements = read_requirements("requirements.txt")
|
||||
requirements_dev = read_requirements("requirements_dev.txt")
|
||||
|
||||
|
||||
setuptools.setup(
|
||||
name="ollama_proxy_server",
|
||||
version="7.1.0",
|
||||
author="Saifeddine ALOUI (ParisNeo)",
|
||||
author_email="aloui.saifeddine@gmail.com",
|
||||
description="A fastapi server for petals decentralized text generation",
|
||||
description="A proxy server adding a security layer to Ollama servers, routing requests to minimize server load",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/ParisNeo/ollama_proxy_server",
|
||||
packages=setuptools.find_packages(),
|
||||
packages=setuptools.find_packages(),
|
||||
include_package_data=True,
|
||||
install_requires=requirements,
|
||||
entry_points={
|
||||
|
Reference in New Issue
Block a user