diff --git a/README.md b/README.md index d8e53f9..2562888 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,8 @@ > Integrates any Linux machine into your Home Assistant ecosystem. ## Features -- [ ] Triggers - - [ ] shutdown, reboot - - [ ] Custom commands +- [ ] Fallback MQTT broker address +- [x] Command buttons - [ ] Notifications - [ ] Actions - [ ] System stats diff --git a/src/config.rs b/src/config.rs index 6b4ca81..227be26 100644 --- a/src/config.rs +++ b/src/config.rs @@ -11,6 +11,8 @@ use regex::Regex; use serde::{Deserialize, Serialize}; use validator::{Validate, ValidationError}; +use crate::modules; + #[derive(Serialize, Deserialize, Validate)] pub struct Mqtt { #[validate(length(min = 1))] @@ -35,20 +37,26 @@ pub struct Internal { } #[derive(Serialize, Deserialize, Validate)] -pub struct Config { +pub(crate) struct Config { #[validate(custom = "validate_unique_id")] pub unique_id: String, + #[validate(length(min = 1))] pub display_name: String, pub announce_mac_address: bool, + #[validate] pub mqtt: Mqtt, + #[serde(rename = "DO_NOT_CHANGE")] #[validate] pub internal: Internal, + + #[validate] + pub command_buttons: Option, } -fn validate_unique_id(value: &str) -> Result<(), ValidationError> { +pub(crate) fn validate_unique_id(value: &str) -> Result<(), ValidationError> { if Regex::new(r"^[a-zA-Z0-9]+(_[a-zA-Z0-9]+)*$").unwrap().is_match(value) { Ok(()) } else { @@ -73,6 +81,10 @@ fn create_example_config() -> Config { internal: Internal { stable_id: generate_unique_id(), }, + command_buttons: Some(modules::command_buttons::Config { + enabled: false, + buttons: Vec::new(), + }) } } diff --git a/src/main.rs b/src/main.rs index 17ed216..26994cb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -71,9 +71,5 @@ async fn main() -> Result<()> { modules::init_all(&mut module_context).await?; - owned_topics_service - .clear_old_and_save_new(module_context.mqtt.client, &module_context.mqtt.owned_topics) - .await?; - - mqtt::start_communication(&module_context, event_loop).await + mqtt::start_communication(&module_context, event_loop, owned_topics_service).await } diff --git a/src/modules/power.rs b/src/modules/command_buttons.rs similarity index 51% rename from src/modules/power.rs rename to src/modules/command_buttons.rs index cb5b3b0..2f0e1cb 100644 --- a/src/modules/power.rs +++ b/src/modules/command_buttons.rs @@ -1,24 +1,55 @@ use anyhow::Result; +use serde::{Deserialize, Serialize}; use tokio::process::Command; +use validator::Validate; + +use crate::config::validate_unique_id; use super::ModuleContext; const MODULE_ID: &str = "power"; const BUTTON_TRIGGER_TEXT: &str = "press"; -pub(crate) async fn init(context: &mut ModuleContext<'_>) -> Result<()> { - log::info!("Initializing…"); +#[derive(Serialize, Deserialize, Validate, Clone)] +pub(crate) struct ButtonConfig { + #[validate(custom = "validate_unique_id")] + pub id: String, - init_command_button(context, "shutdown", "Shutdown", "shutdown -h now").await?; - init_command_button(context, "reboot", "Reboot", "shutdown -r now").await?; + #[validate(length(min = 1))] + pub name: String, + + #[validate(length(min = 1))] + pub command: String, + + #[serde(default)] + pub run_in_shell: bool, +} + +#[derive(Serialize, Deserialize, Validate)] +pub(crate) struct Config { + #[serde(default)] + pub enabled: bool, + + #[serde(default)] + pub buttons: Vec, +} + +pub(crate) async fn init(context: &mut ModuleContext<'_>) -> Result<()> { + let config = match &context.config.command_buttons { + Some(c) if c.enabled => c, + _ => return Ok(()) + }; + + log::info!("Initializing…"); + for button in config.buttons.iter() { + init_command_button(context, button.clone()).await?; + } Ok(()) } -async fn init_command_button(context: &mut ModuleContext<'_>, sub_id: &str, name: &str, command: impl Into) -> Result<()> { - let command = command.into(); - - let entity_id = context.get_entity_id(MODULE_ID, sub_id); +async fn init_command_button(context: &mut ModuleContext<'_>, config: ButtonConfig) -> Result<()> { + let entity_id = context.get_entity_id(MODULE_ID, &config.id); let command_topic = context.mqtt.get_topic("button", &entity_id, "trigger"); context @@ -30,7 +61,7 @@ async fn init_command_button(context: &mut ModuleContext<'_>, sub_id: &str, name "command_topic": command_topic.as_str(), "device": context.mqtt.discovery_device_object.clone(), "icon": "mdi:power", - "name": name, + "name": config.name, "payload_press": BUTTON_TRIGGER_TEXT, "object_id": entity_id.as_str(), "unique_id": entity_id.as_str() @@ -40,7 +71,9 @@ async fn init_command_button(context: &mut ModuleContext<'_>, sub_id: &str, name context.mqtt.subscribe(command_topic, move |text| { if text == BUTTON_TRIGGER_TEXT { - run_command(command.clone()); + run_command(config.command.clone(), config.run_in_shell); + } else { + log::warn!("Received invalid trigger text for button {}", config.id) } Ok(()) @@ -49,7 +82,7 @@ async fn init_command_button(context: &mut ModuleContext<'_>, sub_id: &str, name Ok(()) } -fn run_command(command: String) { +fn run_command(command: String, in_shell: bool) { tokio::spawn(async move { let is_dry_run = cfg!(feature = "dry_run"); @@ -57,8 +90,16 @@ fn run_command(command: String) { log::info!("Executing command{}: {}", if is_dry_run { " (dry run)" } else { "" }, command); let mut command_parts = command.split(' ').collect::>(); - let mut actual_command = Command::new(command_parts[0]); - command_parts.remove(0); + let mut actual_command = if in_shell { + let mut c = Command::new("/bin/sh"); + c.arg("-lc"); + c + } else { + let c = Command::new(command_parts[0]); + command_parts.remove(0); + c + }; + actual_command.args(command_parts); if is_dry_run { diff --git a/src/modules/mod.rs b/src/modules/mod.rs index 8b8f980..270e702 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -4,7 +4,7 @@ use anyhow::Result; use json::JsonValue; use rumqttc::{AsyncClient as MqttClient, ClientError, QoS}; -pub mod power; +pub mod command_buttons; type MqttMessageHandler<'a> = dyn Fn(&str) -> Result<()> + 'a; @@ -18,7 +18,7 @@ pub struct ModuleContextMqtt<'a> { pub owned_topics: HashSet, } -pub struct ModuleContext<'a> { +pub(crate) struct ModuleContext<'a> { pub config: &'a super::config::Config, pub mqtt: ModuleContextMqtt<'a>, } @@ -47,7 +47,7 @@ impl<'a> ModuleContextMqtt<'a> { } } -pub async fn init_all(context: &mut ModuleContext<'_>) -> Result<()> { - power::init(context).await?; +pub(crate) async fn init_all(context: &mut ModuleContext<'_>) -> Result<()> { + command_buttons::init(context).await?; Ok(()) } diff --git a/src/mqtt.rs b/src/mqtt.rs index 8e02c84..d5d3751 100644 --- a/src/mqtt.rs +++ b/src/mqtt.rs @@ -16,7 +16,7 @@ use crate::modules::ModuleContext; use super::config; -pub async fn create_client(config: &config::Config, availability_topic: &str) -> Result<(MqttClient, EventLoop)> { +pub(crate) async fn create_client(config: &config::Config, availability_topic: &str) -> Result<(MqttClient, EventLoop)> { let mut options = MqttOptions::new(&config.internal.stable_id, config.mqtt.host.to_owned(), config.mqtt.port); options.set_clean_session(true); options.set_keep_alive(Duration::from_secs(5)); @@ -28,7 +28,7 @@ pub async fn create_client(config: &config::Config, availability_topic: &str) -> Ok((mqtt_client, event_loop)) } -pub fn create_discovery_device_object(config: &config::Config) -> JsonValue { +pub(crate) fn create_discovery_device_object(config: &config::Config) -> JsonValue { json::object! { "connections": if config.announce_mac_address { mac_address::get_mac_address().unwrap_or(None).map(|a| json::array![["mac", a.to_string()]]).unwrap_or(json::array![]) @@ -45,7 +45,7 @@ pub struct OwnedTopicsService { } impl OwnedTopicsService { - pub async fn new(data_directory_path: &Path) -> Result { + pub(crate) async fn new(data_directory_path: &Path) -> Result { let path = data_directory_path.join("owned_topics"); let mut file = OpenOptions::new().write(true).read(true).create(true).open(path).await?; @@ -57,7 +57,7 @@ impl OwnedTopicsService { Ok(OwnedTopicsService { file, old_topics }) } - pub async fn clear_old_and_save_new(mut self, mqtt_client: &MqttClient, new_topics: &HashSet) -> Result<()> { + pub(crate) async fn clear_old_and_save_new(mut self, mqtt_client: &MqttClient, new_topics: &HashSet) -> Result<()> { let unused_topics = self.old_topics.difference(new_topics).map(|s| s.to_owned()).collect::>(); log::info!( @@ -100,22 +100,25 @@ const FAST_RETRYING_INTERVAL_MS: u64 = 500; const FAST_RETRYING_LIMIT_SECONDS: u64 = 15; const SLOW_RETRYING_INTERVAL_SECONDS: u64 = 5; -pub async fn start_communication(context: &ModuleContext<'_>, mut event_loop: EventLoop) -> Result<()> { +pub(crate) async fn start_communication(context: &ModuleContext<'_>, mut event_loop: EventLoop, owned_topics_service: OwnedTopicsService) -> Result<()> { log::info!("Connecting to MQTT broker at {}:{}", context.config.mqtt.host, context.config.mqtt.port); - context - .mqtt - .client - .subscribe_many( - context - .mqtt - .message_handler_by_topic - .keys() - .map(|k| SubscribeFilter::new(k.to_owned(), QoS::AtLeastOnce)), - ) - .await?; + if !context.mqtt.message_handler_by_topic.is_empty() { + context + .mqtt + .client + .subscribe_many( + context + .mqtt + .message_handler_by_topic + .keys() + .map(|k| SubscribeFilter::new(k.to_owned(), QoS::AtLeastOnce)), + ) + .await?; + } let mut connection_state = ConnectionState::NotConnected; + let mut owned_topics_service = Some(owned_topics_service); loop { match connection_state { @@ -128,6 +131,7 @@ pub async fn start_communication(context: &ModuleContext<'_>, mut event_loop: Ev FAST_RETRYING_LIMIT_SECONDS, SLOW_RETRYING_INTERVAL_SECONDS ); + connection_state = ConnectionState::SlowRetrying; } } @@ -168,6 +172,11 @@ pub async fn start_communication(context: &ModuleContext<'_>, mut event_loop: Ev log::info!("Connection restored") } + if let Some(service) = owned_topics_service.take() { + service.clear_old_and_save_new(context.mqtt.client, &context.mqtt.owned_topics) + .await?; + } + connection_state = ConnectionState::Connected; }