241 lines
8.2 KiB
Python
241 lines
8.2 KiB
Python
"""
|
|
SFTP Server implementation using asyncssh.
|
|
Shares the same authentication system and file handling logic as the FTP server.
|
|
"""
|
|
|
|
import asyncio
import hmac
import inspect
import logging
from hashlib import sha256
from pathlib import Path
from types import SimpleNamespace

import asyncssh

from utils.connect import file_management
from utils.database.connection import connetti_db_async
|
|
|
|
# Logger for this module, named after the module via __name__.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ASESFTPServer(asyncssh.SFTPServer):
    """Basic SFTP session handler that exposes the server configuration.

    Per-file operations (``open``, ``close(file_obj)``, ...) are inherited
    unchanged from ``asyncssh.SFTPServer``. This class only attaches ``cfg``
    and logs when the session ends. Upload post-processing is implemented by
    ``SFTPFileHandler``, which is the factory ``start_sftp_server`` installs.
    """

    def __init__(self, chan):
        """Initialize the handler for ``chan`` and pick up the server config."""
        super().__init__(chan)
        self._conn = chan.get_connection()
        # ``_cfg`` is attached to the connection by ASESSHServer.connection_made().
        self.cfg = self._conn._cfg

    def exit(self):
        """Log the end of the SFTP session, then run the base-class shutdown.

        Session-level work belongs here rather than in ``close()``: asyncssh
        calls ``close(file_obj)`` once per file handle with the file object.

        Returns:
            Whatever the base ``exit()`` returns, so the caller can await it
            if it is awaitable.
        """
        logger.info("SFTP session closed for user: %s", self._conn.get_extra_info('username'))
        return super().exit()
|
|
|
|
|
|
class ASESSHServer(asyncssh.SSHServer):
    """SSH server callbacks that authenticate SFTP users against the database.

    Authentication mirrors the FTP server's DatabaseAuthorizer:
    - The configured admin account (``cfg.adminuser``) is checked first.
    - Otherwise the user table in the database is consulted.

    On success the user's home directory is recorded in ``user_home_dirs``,
    where the SFTP session handler looks it up to chroot the session.
    """

    def __init__(self, cfg):
        """Store the configuration used for authentication and DB access.

        Args:
            cfg: Configuration object. This class reads ``adminuser``,
                ``dbname`` and ``dbusertable``, and passes ``cfg`` to
                ``connetti_db_async``.
        """
        self.cfg = cfg
        # username -> home directory, filled in by validate_password().
        self.user_home_dirs = {}
        super().__init__()

    def connection_made(self, conn):
        """Attach config and this server to the new connection and log the peer."""
        # SFTP handlers only receive the channel, so expose what they need on
        # the connection object they can reach via chan.get_connection().
        conn._cfg = self.cfg
        conn._ssh_server = self
        peername = conn.get_extra_info('peername')
        peer_host = peername[0] if peername else 'unknown'
        logger.info("SSH connection from %s", peer_host)

    def connection_lost(self, exc):
        """Log abnormal connection termination (``exc`` is falsy on clean close)."""
        if exc:
            logger.error("SSH connection lost: %s", exc)

    @staticmethod
    def _hash_matches(stored_hash, password_hash):
        """Compare a stored hex digest with a computed one in constant time.

        Args:
            stored_hash: Digest from config or database. Any non-``str`` value
                (e.g. None or bytes) is treated as a mismatch, since it can
                never equal the computed ``str`` digest.
            password_hash: Hex digest computed from the supplied password.

        Returns:
            True if the two digests are identical.
        """
        if not isinstance(stored_hash, str):
            return False
        # compare_digest avoids leaking, via timing, how many leading
        # characters of the hash matched.
        return hmac.compare_digest(stored_hash.encode("utf-8"), password_hash.encode("utf-8"))

    async def _fetch_user(self, username):
        """Fetch ``(ftpuser, hash, virtpath, perm, disabled_at)`` for ``username``.

        Returns:
            The row from ``cursor.fetchone()``, which is falsy when no user
            matches.

        The cursor and the connection are closed even if the query fails.
        """
        conn = await connetti_db_async(self.cfg)
        try:
            cur = await conn.cursor()
            try:
                # Database/table names come from configuration. The
                # client-supplied username is passed as a bound parameter.
                await cur.execute(
                    "SELECT ftpuser, hash, virtpath, perm, disabled_at "
                    f"FROM {self.cfg.dbname}.{self.cfg.dbusertable} WHERE ftpuser = %s",
                    (username,),
                )
                return await cur.fetchone()
            finally:
                await cur.close()
        finally:
            conn.close()

    async def validate_password(self, username, password):
        """
        Validate user credentials against the admin config and the database.

        Same logic as DatabaseAuthorizer for FTP: SHA-256 of the password is
        compared with the stored hex digest; disabled users are rejected; the
        user's home directory is created if missing.

        Args:
            username: Login name supplied by the client.
            password: Plain-text password supplied by the client.

        Returns:
            True on successful authentication, False otherwise. Database and
            directory-creation errors are logged and reported as a failed login.
        """
        password_hash = sha256(password.encode("UTF-8")).hexdigest()

        # Configured admin account: checked without touching the database.
        if username == self.cfg.adminuser[0]:
            if self._hash_matches(self.cfg.adminuser[1], password_hash):
                self.user_home_dirs[username] = self.cfg.adminuser[2]
                logger.info(
                    "Admin user '%s' authenticated successfully (home: %s)",
                    username, self.cfg.adminuser[2],
                )
                return True
            logger.warning("Failed admin login attempt for user: %s", username)
            return False

        # Regular users: look them up in the database.
        try:
            result = await self._fetch_user(username)
            if not result:
                logger.warning("SFTP login attempt for non-existent user: %s", username)
                return False
            _ftpuser, stored_hash, virtpath, _perm, disabled_at = result
        except Exception as e:
            logger.error(
                "Database error during SFTP authentication for user %s: %s",
                username, e, exc_info=True,
            )
            return False

        if disabled_at is not None:
            logger.warning("SFTP login attempt for disabled user: %s", username)
            return False

        if not self._hash_matches(stored_hash, password_hash):
            logger.warning("Invalid password for SFTP user: %s", username)
            return False

        # Authentication successful - ensure the user directory exists,
        # since it becomes the chroot of the SFTP session.
        try:
            Path(virtpath).mkdir(parents=True, exist_ok=True)
        except Exception as e:
            logger.error("Failed to create directory for user %s: %s", username, e)
            return False

        self.user_home_dirs[username] = virtpath
        logger.info("Successful SFTP login for user: %s (home: %s)", username, virtpath)
        return True

    def password_auth_supported(self):
        """Enable password authentication."""
        return True

    def begin_auth(self, username):
        """Called when authentication begins; always require authentication."""
        logger.debug("Authentication attempt for user: %s", username)
        return True
|
|
|
|
|
|
class SFTPFileHandler(asyncssh.SFTPServer):
    """SFTP session handler that chroots users and processes uploaded CSVs.

    Each session is confined to the authenticated user's home directory.
    Files opened for writing are tracked. When a tracked ``.csv`` file is
    closed, it is handed to ``file_management.on_file_received_async``, the
    same processing function used for FTP uploads.
    """

    # SFTP open-flag bit for write access (SSH_FXF_WRITE).
    _FXF_WRITE = 0x00000002

    def __init__(self, chan):
        """Create the session handler for ``chan``, chrooted to the user's home.

        Raises:
            PermissionError: if no home directory is known for the user.
                Without a chroot, client paths would map directly onto the
                server filesystem, so the session is refused instead.
        """
        home = self._get_user_home(chan)
        if not home:
            username = chan.get_connection().get_extra_info('username')
            logger.error("No home directory for SFTP user %s; refusing unrestricted session", username)
            raise PermissionError(f"no SFTP home directory for user {username!r}")
        super().__init__(chan, chroot=home)
        self._conn = chan.get_connection()
        # ``_cfg`` is attached to the connection by ASESSHServer.connection_made().
        self.cfg = self._conn._cfg
        # file object -> local path (str) of files opened for writing.
        self._open_files = {}

    @staticmethod
    def _get_user_home(chan):
        """Get the home directory for the authenticated user, or None if unknown."""
        conn = chan.get_connection()
        username = conn.get_extra_info('username')
        ssh_server = getattr(conn, '_ssh_server', None)

        # Recorded by ASESSHServer.validate_password() on successful login.
        if ssh_server and username in ssh_server.user_home_dirs:
            return ssh_server.user_home_dirs[username]

        # Fallback for admin user
        if hasattr(conn, '_cfg') and username == conn._cfg.adminuser[0]:
            return conn._cfg.adminuser[2]

        return None

    def open(self, path, pflags, attrs):
        """Open ``path`` via asyncssh and remember it if opened for writing."""
        file_obj = super().open(path, pflags, attrs)

        if pflags & self._FXF_WRITE:
            real_path = self.map_path(path)
            if isinstance(real_path, bytes):
                # surrogateescape keeps non-UTF-8 file names representable
                # instead of raising after the file has already been opened.
                real_path = real_path.decode('utf-8', errors='surrogateescape')
            self._open_files[file_obj] = real_path
            logger.debug("File opened for writing: %s", real_path)

        return file_obj

    async def close(self, file_obj):
        """Close ``file_obj`` and, if it was an uploaded CSV, process it.

        Errors raised while closing propagate to the caller. Errors raised
        while processing the CSV are logged and do not fail the close.
        """
        # Stop tracking first so the entry is dropped even if closing fails.
        filepath = self._open_files.pop(file_obj, None)

        result = super().close(file_obj)
        if inspect.isawaitable(result):
            result = await result

        if filepath is not None and filepath.lower().endswith('.csv'):
            await self._process_upload(filepath)

        return result

    async def _process_upload(self, filepath):
        """Pass a closed CSV upload to the FTP/SFTP processing function."""
        try:
            logger.info("CSV file closed after upload via SFTP: %s", filepath)
            # on_file_received_async expects an FTP-handler-like object; only
            # ``cfg`` and ``username`` are provided here.
            handler = SimpleNamespace(
                cfg=self.cfg,
                username=self._conn.get_extra_info('username'),
            )
            await file_management.on_file_received_async(handler, filepath)
        except Exception as e:
            logger.error("Error processing SFTP file on close: %s", e, exc_info=True)

    async def exit(self):
        """Handle session close.

        The base ``exit()`` may return a plain value (not an awaitable), so
        its result is awaited only when it is awaitable. File processing is
        done in ``close()``, not here, so an upload is never processed twice.
        """
        result = super().exit()
        if inspect.isawaitable(result):
            await result
|
|
|
|
|
|
async def start_sftp_server(cfg, host='0.0.0.0', port=22, host_keys=None):
    """
    Start SFTP server.

    Args:
        cfg: Configuration object. It is given to every ASESSHServer, and its
            ``dbuser``/``dbhost``/``dbport``/``dbname`` are logged.
        host: Host to bind to.
        port: Port to bind to.
        host_keys: SSH host key file(s), passed to asyncssh as
            ``server_host_keys``. Defaults to ``['/app/ssh_host_key']``, which
            must be generated beforehand.

    Returns:
        The server object returned by ``asyncssh.create_server()``.
    """
    if host_keys is None:
        host_keys = ['/app/ssh_host_key']

    logger.info("Starting SFTP server on %s:%s", host, port)

    server = await asyncssh.create_server(
        # asyncssh calls the factory once per incoming connection. Building a
        # fresh ASESSHServer each time keeps per-login state (user_home_dirs)
        # from being shared between clients.
        lambda: ASESSHServer(cfg),
        host,
        port,
        server_host_keys=host_keys,
        sftp_factory=SFTPFileHandler,
    )

    logger.info("SFTP server started successfully on %s:%s", host, port)
    logger.info("Database connection: %s@%s:%s/%s", cfg.dbuser, cfg.dbhost, cfg.dbport, cfg.dbname)

    return server
|