fix clean start

This commit is contained in:
2024-12-30 20:01:55 +01:00
parent 4619255a32
commit 689fe55522
2 changed files with 74 additions and 38 deletions

View File

@@ -47,7 +47,8 @@ def init_db():
site TEXT NOT NULL, site TEXT NOT NULL,
username TEXT NOT NULL, username TEXT NOT NULL,
password TEXT NOT NULL, password TEXT NOT NULL,
client_id TEXT NULL, client_id TEXT NOT NULL,
topic TEXT NOT NULL,
created_at timestamptz DEFAULT CURRENT_TIMESTAMP, created_at timestamptz DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT site_user_clientid_unique UNIQUE(site, username, client_id) CONSTRAINT site_user_clientid_unique UNIQUE(site, username, client_id)
) )
@@ -92,14 +93,14 @@ def authenticate(master_password):
return auth_success return auth_success
# Aggiungi una password al database # Aggiungi una password al database
def add_password(site, username, password, client_id, cipher): def add_password(site, username, password, client_id, topic, cipher):
conn = get_db_connection() conn = get_db_connection()
cursor = conn.cursor() cursor = conn.cursor()
encrypted_password = cipher.encrypt(password.encode()).decode() encrypted_password = cipher.encrypt(password.encode()).decode()
try: try:
cursor.execute( cursor.execute(
f"INSERT INTO {DB_CONFIG['dbschema']}.{DB_CONFIG['dbtable']} (site, username, password, client_id) VALUES (%s, %s, %s, %s)", f"INSERT INTO {DB_CONFIG['dbschema']}.{DB_CONFIG['dbtable']} (site, username, password, client_id, topic) VALUES (%s, %s, %s, %s, %s)",
(site, username, encrypted_password, client_id)) (site, username, encrypted_password, client_id, topic))
conn.commit() conn.commit()
logging.info(f"Password aggiunta per il sito: {site}.") logging.info(f"Password aggiunta per il sito: {site}.")
except psycopg2.Error as e: except psycopg2.Error as e:
@@ -112,18 +113,18 @@ def get_password(site, cipher):
conn = get_db_connection() conn = get_db_connection()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
cursor.execute(f"SELECT username, password, client_id FROM {DB_CONFIG['dbschema']}.{DB_CONFIG['dbtable']} WHERE site = %s", (site,)) cursor.execute(f"SELECT username, password, client_id, topic FROM {DB_CONFIG['dbschema']}.{DB_CONFIG['dbtable']} WHERE site = %s", (site,))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
username, encrypted_password, client_id = row username, encrypted_password, client_id, topic = row
decrypted_password = cipher.decrypt(encrypted_password.encode()).decode() decrypted_password = cipher.decrypt(encrypted_password.encode()).decode()
logging.info(f"Password recuperata per il sito: {site}.") logging.info(f"Password recuperata per il sito: {site}.")
return username, decrypted_password, client_id return username, decrypted_password, client_id, topic
logging.warning(f"Sito non trovato: {site}.") logging.warning(f"Sito non trovato: {site}.")
return None, None, None return None, None, None, None
except psycopg2.Error as e: except psycopg2.Error as e:
logging.error(f"Errore durante il recupero della password: {e}") logging.error(f"Errore durante il recupero della password: {e}")
return None, None, None return None, None, None, None
finally: finally:
conn.close() conn.close()
@@ -166,6 +167,7 @@ def add_password_api():
username = request.json.get('username') username = request.json.get('username')
password = request.json.get('password') password = request.json.get('password')
client_id = request.json.get('client_id') client_id = request.json.get('client_id')
topic = request.json.get('topic')
if not authenticate(master_password): if not authenticate(master_password):
logging.warning("Tentativo di aggiungere una password con master password errata.") logging.warning("Tentativo di aggiungere una password con master password errata.")
@@ -173,7 +175,7 @@ def add_password_api():
key = derive_key(master_password) key = derive_key(master_password)
cipher = Fernet(key) cipher = Fernet(key)
add_password(site, username, password, client_id, cipher) add_password(site, username, password, client_id, topic, cipher)
return jsonify({"message": "Password aggiunta con successo"}) return jsonify({"message": "Password aggiunta con successo"})
# Endpoint per recuperare una password # Endpoint per recuperare una password
@@ -188,12 +190,12 @@ def get_password_api():
key = derive_key(master_password) key = derive_key(master_password)
cipher = Fernet(key) cipher = Fernet(key)
username, password, client_id = get_password(site, cipher) username, password, client_id, topic = get_password(site, cipher)
if username is None: if username is None:
return jsonify({"error": "Sito non trovato"}), 404 return jsonify({"error": "Sito non trovato"}), 404
return jsonify({"site": site, "username": username, "password": password, "client_id": client_id}) return jsonify({"site": site, "username": username, "password": password, "client_id": client_id, "topic": topic})
# Endpoint per cancellare una password # Endpoint per cancellare una password
@app.route('/delete', methods=['POST']) @app.route('/delete', methods=['POST'])

View File

@@ -1,14 +1,15 @@
import paho.mqtt.subscribe as subscribe import paho.mqtt.client as mqtt
from paho.mqtt.properties import Properties
from paho.mqtt.packettypes import PacketTypes
import argparse import argparse
import requests import requests
import psycopg2 import psycopg2
import json import json
import sys
import os import os
import logging import logging
# Configurazione Logging # Configurazione Logging
logging.basicConfig(level=logging.INFO, format='- PID: %(process)d %(levelname)8s: %(message)s', stream=sys.stderr) logging.basicConfig(level=logging.INFO, format='%(asctime)s - PID: %(process)d %(levelname)8s: %(message)s', filename="/var/log/ase_receiver.log")
logger = logging.getLogger() logger = logging.getLogger()
# Configurazione connessione PostgreSQL # Configurazione connessione PostgreSQL
@@ -30,10 +31,10 @@ def get_credentials(args):
} }
response = requests.post(url, json=data) response = requests.post(url, json=data)
if response.status_code != 200: if response.status_code != 200:
logger.error(f"Error to get pwd from wallet.") logger.error("Error to get pwd from wallet.")
exit(1) exit(1)
return response.json().get('password') return response.json().get('password'), response.json().get('client_id'), response.json().get('topic')
def get_db_connection(): def get_db_connection():
return psycopg2.connect( return psycopg2.connect(
@@ -44,8 +45,7 @@ def get_db_connection():
port=DB_CONFIG["port"] port=DB_CONFIG["port"]
) )
# Inizializza il database def init_db(args, main_topic):
def init_db(args):
try: try:
conn = get_db_connection() conn = get_db_connection()
cursor = conn.cursor() cursor = conn.cursor()
@@ -54,7 +54,8 @@ def init_db(args):
id bigserial NOT NULL, id bigserial NOT NULL,
main_topic text NOT NULL, main_topic text NOT NULL,
tt_data jsonb NULL, tt_data jsonb NULL,
created_at timestamp DEFAULT CURRENT_TIMESTAMP NULL created_at timestamp DEFAULT CURRENT_TIMESTAMP NULL,
CONSTRAINT {DB_CONFIG['dbtable']}_pkey PRIMARY KEY (id, main_topic)
) )
PARTITION BY LIST (main_topic); PARTITION BY LIST (main_topic);
""") """)
@@ -65,8 +66,8 @@ def init_db(args):
""") """)
conn.commit() conn.commit()
cursor.execute(f""" cursor.execute(f"""
CREATE TABLE IF NOT EXISTS {DB_CONFIG['dbschema']}.{DB_CONFIG['dbtable']}_{args.client.removesuffix("_ase")} PARTITION OF {DB_CONFIG['dbschema']}.{DB_CONFIG['dbtable']} CREATE TABLE IF NOT EXISTS {DB_CONFIG['dbschema']}.{DB_CONFIG['dbtable']}_{main_topic} PARTITION OF {DB_CONFIG['dbschema']}.{DB_CONFIG['dbtable']}
FOR VALUES IN ('{args.client.removesuffix("_ase")}') FOR VALUES IN ('{main_topic}')
""") """)
conn.commit() conn.commit()
except Exception as e: except Exception as e:
@@ -77,15 +78,22 @@ def init_db(args):
logger.info("Database inizializzato.") logger.info("Database inizializzato.")
def create_nested_json(path, data): def create_nested_json(path, data):
main_topic = path.split('/')[0]
keys = path.split('/')[1:] keys = path.split('/')[1:]
nested_json = data nested_json = data
for key in reversed(keys): for key in reversed(keys):
nested_json = {key: nested_json} nested_json = {key: nested_json}
return nested_json return main_topic, nested_json
def receive_data(client, userdata, message): def on_connect(client, userdata, flags, rc, properties):
if rc == 0:
logger.info("Connesso al broker MQTT")
else:
logger.error(f"Errore di connessione, codice: {rc}")
def on_message(client, userdata, message):
datastore = json.loads(message.payload) datastore = json.loads(message.payload)
json_data = create_nested_json(message.topic, datastore) main_topic, json_data = create_nested_json(message.topic, datastore)
try: try:
conn = get_db_connection() conn = get_db_connection()
@@ -94,7 +102,7 @@ def receive_data(client, userdata, message):
INSERT INTO {DB_CONFIG['dbschema']}.{DB_CONFIG['dbtable']} INSERT INTO {DB_CONFIG['dbschema']}.{DB_CONFIG['dbtable']}
(main_topic, tt_data) (main_topic, tt_data)
VALUES VALUES
('{userdata['args'].client.removesuffix("_ase")}', '{json.dumps(json_data)}'::jsonb); ('{main_topic}', '{json.dumps(json_data)}'::jsonb);
""") """)
conn.commit() conn.commit()
except Exception as e: except Exception as e:
@@ -102,31 +110,57 @@ def receive_data(client, userdata, message):
finally: finally:
conn.close() conn.close()
def on_disconnect(client, userdata, rc, properties=None):
if rc != 0:
logger.warning(f"Disconnesso dal broker con codice {rc}. Riconnessione...")
client.reconnect()
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-H', '--host', default="mqtt") parser.add_argument('-H', '--host', default="mqtt")
parser.add_argument('-q', '--qos', type=int,default=2) parser.add_argument('-q', '--qos', type=int, default=1)
parser.add_argument('-P', '--port', type=int, default=1883) parser.add_argument('-P', '--port', type=int, default=1883)
parser.add_argument('-c', '--client') parser.add_argument('-c', '--client')
parser.add_argument('-w', '--wallet', default="http://mqtt:5000/") parser.add_argument('-w', '--wallet', default="http://mqtt:5000/")
args = parser.parse_args() args = parser.parse_args()
init_db(args)
auth = {'username': args.client, 'password': get_credentials(args)} password, client_id, topic = get_credentials(args)
main_topic = topic.split('/')[0]
init_db(args, main_topic)
userdata = {'args': args} userdata = {'args': args}
properties=Properties(PacketTypes.CONNECT)
properties.SessionExpiryInterval=3600
client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id=client_id, protocol=mqtt.MQTTv5)
client.username_pw_set(username=args.client, password=password)
client.user_data_set(userdata)
#client.logger = logger
client.on_connect = on_connect
client.on_message = on_message
client.on_disconnect = on_disconnect
client.reconnect_delay_set(min_delay=1, max_delay=120)
client.connect(args.host, args.port, clean_start=False) #, properties=properties)
client.subscribe(topic, qos=args.qos)
try: try:
subscribe.callback(receive_data, hostname=args.host, port=args.port, logger.info("Avvio del loop MQTT.")
topics=f'{args.client.removesuffix("_ase")}/#', client.loop_forever()
qos=args.qos, clean_session=False, except KeyboardInterrupt:
auth=auth, client_id=f'{args.client.removesuffix("_ase")}_client_ase', logger.info("Terminazione manuale.")
userdata=userdata) except Exception as e:
except (KeyboardInterrupt, Exception) as e: logger.error(f"Errore durante il ciclo MQTT: {e}")
logger.info(f"Terminating: ....{e}") finally:
client.disconnect()
logger.info("Disconnesso dal broker.")
if __name__ == "__main__": if __name__ == "__main__":
main() main()