Files
ASE/src/utils/servers/sftp_server.py
2025-11-03 18:54:49 +01:00

241 lines
8.2 KiB
Python

"""
SFTP Server implementation using asyncssh.
Shares the same authentication system and file handling logic as the FTP server.
"""
import asyncio
import inspect
import logging
from pathlib import Path
from types import SimpleNamespace

import asyncssh

from utils.connect import file_management
from utils.database.connection import connetti_db_async
logger = logging.getLogger(__name__)
class ASESFTPServer(asyncssh.SFTPServer):
    """Custom SFTP server that handles file uploads with the same logic as FTP server.

    Reads the configuration attached to the SSH connection as ``_cfg`` and
    logs when the SFTP session ends.
    """

    def __init__(self, chan):
        """Initialize the SFTP session bound to ``chan``.

        Args:
            chan: The SSH server channel carrying this SFTP session.
        """
        super().__init__(chan)
        # Config is attached to the connection as ``_cfg`` when it is established.
        self.cfg = chan.get_connection()._cfg

    def _session_username(self):
        """Return the username recorded on this session's SSH connection."""
        return self._chan.get_connection().get_extra_info('username')

    async def close(self, file_obj=None):
        """Close a client file handle, or only log the session end when called bare.

        ``SFTPServer.close`` receives the file object being closed. The former
        zero-argument override made every such call fail with TypeError, and its
        ``super().close()`` call was itself missing ``file_obj``.

        Args:
            file_obj: The file object to close. When omitted, nothing is closed
                and only the session-closed message is logged.
        """
        if file_obj is None:
            logger.info(f"SFTP session closed for user: {self._session_username()}")
            return
        result = super().close(file_obj)
        # Only await when the base returned an awaitable; awaiting None raises TypeError.
        if inspect.isawaitable(result):
            await result

    async def exit(self):
        """Log the end of the SFTP session, then run the base shutdown hook."""
        logger.info(f"SFTP session closed for user: {self._session_username()}")
        result = super().exit()
        # Only await when the base returned an awaitable; awaiting None raises TypeError.
        if inspect.isawaitable(result):
            await result
class ASESSHServer(asyncssh.SSHServer):
    """Custom SSH server for SFTP authentication using database.

    Credentials are checked against the configured admin user or the FTP user
    table. Each successful login records the user's home directory in
    ``user_home_dirs`` so the SFTP handler can chroot into it.
    """

    def __init__(self, cfg):
        """Store the configuration and prepare the per-user home-directory map.

        Args:
            cfg: Configuration object. This class reads ``adminuser``,
                ``dbname`` and ``dbusertable`` and passes it to
                ``connetti_db_async``.
        """
        self.cfg = cfg
        # username -> home directory, filled by validate_password on success.
        self.user_home_dirs = {}
        super().__init__()

    def connection_made(self, conn):
        """Attach the config and this server to a newly established connection.

        SFTP handlers later read ``conn._cfg`` and ``conn._ssh_server``.
        """
        conn._cfg = self.cfg
        conn._ssh_server = self  # gives SFTP handlers access to user_home_dirs
        logger.info(f"SSH connection from {conn.get_extra_info('peername')[0]}")

    def connection_lost(self, exc):
        """Log the error, if any, that ended the connection."""
        if exc:
            logger.error(f"SSH connection lost: {exc}")

    @staticmethod
    def _hash_matches(expected, actual):
        """Compare a stored hash with a computed hex digest in constant time.

        A missing (None) stored hash never matches.
        """
        import hmac

        if expected is None:
            return False
        # compare_digest avoids leaking how many leading characters matched.
        return hmac.compare_digest(str(expected).encode("utf-8"), actual.encode("utf-8"))

    async def _fetch_user_row(self, username):
        """Return the (ftpuser, hash, virtpath, perm, disabled_at) row for ``username``.

        Returns None when no row exists. The cursor and the connection are
        always closed, even if the query raises.
        """
        conn = await connetti_db_async(self.cfg)
        try:
            cur = await conn.cursor()
            try:
                await cur.execute(
                    "SELECT ftpuser, hash, virtpath, perm, disabled_at "
                    f"FROM {self.cfg.dbname}.{self.cfg.dbusertable} WHERE ftpuser = %s",
                    (username,),
                )
                return await cur.fetchone()
            finally:
                await cur.close()
        finally:
            conn.close()

    async def validate_password(self, username, password):
        """
        Validate user credentials against database.
        Same logic as DatabaseAuthorizer for FTP.

        The password is hashed with SHA-256 and its hex digest is compared with
        the stored hash. On success the user's home directory is recorded in
        ``user_home_dirs``. For database users the directory is also created
        if needed.

        Args:
            username: Login name supplied by the client.
            password: Plain-text password supplied by the client.

        Returns:
            True on successful authentication. False for unknown, disabled or
            wrongly-authenticated users, and on any database/directory error.
        """
        from hashlib import sha256

        password_hash = sha256(password.encode("UTF-8")).hexdigest()

        # The admin account comes from configuration, not from the database.
        if username == self.cfg.adminuser[0]:
            if self._hash_matches(self.cfg.adminuser[1], password_hash):
                self.user_home_dirs[username] = self.cfg.adminuser[2]
                logger.info(f"Admin user '{username}' authenticated successfully (home: {self.cfg.adminuser[2]})")
                return True
            logger.warning(f"Failed admin login attempt for user: {username}")
            return False

        # For regular users, check database
        try:
            result = await self._fetch_user_row(username)
            if not result:
                logger.warning(f"SFTP login attempt for non-existent user: {username}")
                return False

            _, stored_hash, virtpath, _, disabled_at = result

            if disabled_at is not None:
                logger.warning(f"SFTP login attempt for disabled user: {username}")
                return False

            if not self._hash_matches(stored_hash, password_hash):
                logger.warning(f"Invalid password for SFTP user: {username}")
                return False

            # Authentication successful - ensure user directory exists
            try:
                Path(virtpath).mkdir(parents=True, exist_ok=True)
            except Exception as e:
                logger.error(f"Failed to create directory for user {username}: {e}")
                return False

            # Store the user's home directory for chroot
            self.user_home_dirs[username] = virtpath
            logger.info(f"Successful SFTP login for user: {username} (home: {virtpath})")
            return True
        except Exception as e:
            logger.error(f"Database error during SFTP authentication for user {username}: {e}", exc_info=True)
            return False

    def password_auth_supported(self):
        """Enable password authentication."""
        return True

    def begin_auth(self, username):
        """Log the attempt and return True so that authentication is required."""
        logger.debug(f"Authentication attempt for user: {username}")
        return True
class SFTPFileHandler(asyncssh.SFTPServer):
    """Extended SFTP server with file upload handling.

    Each session is chrooted to the authenticated user's home directory.
    Files opened for writing are tracked. When such a file is closed and its
    name ends in ``.csv``, it is passed to
    ``file_management.on_file_received_async``.
    """

    # SFTP protocol open flag requesting write access (FXF_WRITE).
    _FXF_WRITE = 0x02

    def __init__(self, chan):
        """Start an SFTP session confined to the authenticated user's home.

        Args:
            chan: The SSH server channel carrying this SFTP session.

        Raises:
            PermissionError: If no (non-empty) home directory is known for the
                user. Without one, the base class would run with no chroot at
                all, exposing the whole filesystem, so the session is refused.
        """
        home = self._get_user_home(chan)
        if not home:
            username = chan.get_connection().get_extra_info('username')
            logger.error(f"No home directory known for SFTP user {username}; refusing session")
            raise PermissionError(f"No SFTP home directory for user: {username}")
        super().__init__(chan, chroot=home)
        self.cfg = chan.get_connection()._cfg
        # file object -> real filesystem path, for files opened for writing.
        self._open_files = {}

    @staticmethod
    def _get_user_home(chan):
        """Return the home directory recorded for the authenticated user.

        Looks up the directory stored by the SSH server at login. It falls back
        to the configured admin home for the admin user, and returns None if
        neither is available.
        """
        conn = chan.get_connection()
        username = conn.get_extra_info('username')

        ssh_server = getattr(conn, '_ssh_server', None)
        if ssh_server and username in ssh_server.user_home_dirs:
            return ssh_server.user_home_dirs[username]

        # Fallback for admin user
        if hasattr(conn, '_cfg') and username == conn._cfg.adminuser[0]:
            return conn._cfg.adminuser[2]

        return None

    def open(self, path, pflags, attrs):
        """Open ``path`` via the base server, remembering it if opened for writing.

        Args:
            path: Client-supplied path, mapped through the chroot by ``map_path``.
            pflags: SFTP open flags.
            attrs: SFTP file attributes for the open request.

        Returns:
            The file object returned by the base implementation.
        """
        file_obj = super().open(path, pflags, attrs)

        if pflags & self._FXF_WRITE:
            real_path = self.map_path(path)
            if isinstance(real_path, bytes):
                # surrogateescape keeps non-UTF-8 names representable. A strict
                # decode would raise here, after the file was already opened,
                # leaking the handle.
                real_path = real_path.decode('utf-8', 'surrogateescape')
            self._open_files[file_obj] = real_path
            logger.debug(f"File opened for writing: {real_path}")

        return file_obj

    async def close(self, file_obj):
        """Close ``file_obj``, then process it if it was an uploaded CSV file.

        Processing errors are logged and do not fail the client's close request.

        Returns:
            None.
        """
        # Untrack first so a failing close cannot leave a stale entry behind.
        filepath = self._open_files.pop(file_obj, None)

        result = super().close(file_obj)
        # Only await when the base returned an awaitable; awaiting None raises TypeError.
        if inspect.isawaitable(result):
            await result

        if filepath is None or not filepath.lower().endswith('.csv'):
            return None

        try:
            logger.info(f"CSV file closed after upload via SFTP: {filepath}")
            username = self._chan.get_connection().get_extra_info('username')
            # on_file_received_async expects an FTP-handler-like object; only
            # ``cfg`` and ``username`` are provided here.
            handler = SimpleNamespace(cfg=self.cfg, username=username)
            await file_management.on_file_received_async(handler, filepath)
        except Exception as e:
            logger.error(f"Error processing SFTP file on close: {e}", exc_info=True)
        return None

    async def exit(self):
        """Handle session close.

        File processing happens in ``close()``, not here. This avoids
        double-processing when both close and rename are called.
        """
        result = super().exit()
        # The base exit() returns None; awaiting it unconditionally raised TypeError.
        if inspect.isawaitable(result):
            await result
async def start_sftp_server(cfg, host='0.0.0.0', port=22, server_host_keys=None):
    """
    Start SFTP server.

    Args:
        cfg: Configuration object (also used to log the database target).
        host: Host to bind to.
        port: Port to bind to.
        server_host_keys: Host key file path(s) passed to
            ``asyncssh.create_server``. Defaults to ``['/app/ssh_host_key']``.
            The key must already exist (e.g. generated with ssh-keygen).

    Returns:
        asyncssh server object
    """
    # None sentinel avoids a shared mutable default list.
    if server_host_keys is None:
        server_host_keys = ['/app/ssh_host_key']

    logger.info(f"Starting SFTP server on {host}:{port}")

    # The factory returns this single instance for every connection, so its
    # user_home_dirs map is shared across connections.
    ssh_server = ASESSHServer(cfg)

    server = await asyncssh.create_server(
        lambda: ssh_server,
        host,
        port,
        server_host_keys=server_host_keys,
        sftp_factory=SFTPFileHandler,
    )

    logger.info(f"SFTP server started successfully on {host}:{port}")
    logger.info(f"Database connection: {cfg.dbuser}@{cfg.dbhost}:{cfg.dbport}/{cfg.dbname}")
    return server