This commit is contained in:
Moritz Ruth 2024-03-05 23:35:51 +01:00
parent 13f307d387
commit 3e129bd3d6
Signed by: moritzruth
GPG key ID: C9BBAB79405EE56D
20 changed files with 1245 additions and 41 deletions

View file

@ -0,0 +1,20 @@
[package]
name = "home_assistant"
version = "0.1.0"
edition = "2021"
[dependencies]
deckster_mode = { path = "../../crates/deckster_mode" }
clap = { version = "4.4.18", features = ["derive"] }
color-eyre = "0.6.2"
env_logger = "0.11.1"
log = "0.4.20"
tokio = { version = "1.35.1", features = ["macros", "parking_lot", "rt", "sync"] }
serde = { version = "1.0.196", features = ["derive"] }
serde_json = "1.0.114"
reqwest = "0.11.24"
url = { version = "2.5.0", features = ["serde"] }
tokio-tungstenite = { version = "0.21.0", features = ["native-tls"] }
tokio-stream = "0.1.14"
futures-util = "0.3.30"
native-tls = "0.2.11"

View file

@ -0,0 +1,51 @@
use deckster_mode::shared::state::KeyStyleByStateMap;
use serde::Deserialize;
use url::Url;
#[derive(Debug, Clone, Deserialize)]
pub struct GlobalConfig {
pub base_url: Url,
pub token: Box<str>,
#[serde(default)]
pub accept_invalid_certs: bool,
}
#[derive(Debug, Clone, Deserialize)]
pub struct KeyConfig {
pub disconnected_state: Option<Box<str>>,
#[serde(flatten)]
pub mode: KeyMode,
pub style: KeyStyleByStateMap<Box<str>>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "mode", rename_all = "kebab-case")]
pub enum KeyMode {
Toggle { entity_id: Box<str> },
Button { state_entity_id: Box<str>, button_entity_id: Box<str> },
}
impl KeyMode {
pub fn state_entity_id(&self) -> &Box<str> {
match &self {
KeyMode::Toggle { entity_id, .. } => entity_id,
KeyMode::Button { state_entity_id, .. } => state_entity_id,
}
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct KnobConfig {
pub(crate) entity_id: Box<str>,
pub disconnected_state: Option<Box<str>>,
#[serde(flatten)]
pub mode: KnobMode,
pub style: KeyStyleByStateMap<Box<str>>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "mode", rename_all = "kebab-case")]
pub enum KnobMode {
Select { states: Box<[Box<str>]>, wrap_around: bool },
Range,
}

View file

@ -0,0 +1,289 @@
use futures_util::SinkExt;
use native_tls::TlsConnector;
use reqwest::header::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, RwLock};
use tokio_stream::StreamExt;
use tokio_tungstenite::{tungstenite, Connector};
use url::Url;
#[derive(Debug, Clone)]
pub enum StateUpdate {
Disconnected,
Actual(Arc<ActualStateUpdate>),
}
#[derive(Debug)]
pub struct ActualStateUpdate {
pub entity_id: Box<str>,
pub state: Box<str>,
pub timestamp: Box<str>,
}
#[derive(Debug, Clone)]
pub struct HaClient {
state_updates_sender: broadcast::Sender<StateUpdate>,
http_client: reqwest::Client,
base_url: Url,
}
impl HaClient {
pub async fn new(base_url: Url, token: Box<str>, accept_invalid_certs: bool, subscribed_entity_ids: Vec<Box<str>>) -> Self {
let http_client = reqwest::ClientBuilder::new()
.connect_timeout(Duration::from_secs(10))
.default_headers({
let mut map = HeaderMap::new();
map.insert(
"Authorization",
HeaderValue::from_str(&format!("Bearer {token}")).expect("the token generated by Home Assistant only contains valid characters"),
);
map
})
.danger_accept_invalid_certs(accept_invalid_certs)
.user_agent(format!("home_assistant deckster handler (v{})", env!("CARGO_PKG_VERSION")))
.build()
.unwrap(); // The HTTP client being available is essential.
let state_updates_sender = broadcast::Sender::<StateUpdate>::new(min(subscribed_entity_ids.len(), 16));
let state_timestamp_by_entity_id = subscribed_entity_ids.iter().map(|i| (i.clone(), "".to_owned().into_boxed_str())).collect();
let tls_connector = TlsConnector::builder().danger_accept_invalid_certs(accept_invalid_certs).build().unwrap();
tokio::spawn(do_work(
base_url.clone(),
token,
tls_connector,
state_updates_sender.clone(),
http_client.clone(),
state_timestamp_by_entity_id,
));
if log::max_level() <= log::Level::Debug {
let mut updates = state_updates_sender.subscribe();
tokio::spawn(async move {
while let Ok(u) = updates.recv().await {
log::debug!("State update: {u:?}")
}
});
}
HaClient {
state_updates_sender,
http_client,
base_url,
}
}
pub fn subscribe_to_state_updates(&self) -> broadcast::Receiver<StateUpdate> {
self.state_updates_sender.subscribe()
}
pub async fn toggle_entity(&self, entity_id: &str) {
let (domain, _) = entity_id.split_once('.').expect("entity IDs must contain exactly one dot");
let result = self
.http_client
.post(self.base_url.join(&format!("/api/services/{domain}/toggle")).unwrap())
.body(format!("{{\"entity_id\":\"{entity_id}\"}}"))
.send()
.await
.and_then(|a| a.error_for_status());
if let Err(error) = result {
log::error!(
"POST request to {} failed: {error}",
error.url().map(|u| u.to_string()).unwrap_or("?".to_owned())
)
}
}
}
async fn do_work(
base_url: Url,
token: Box<str>,
tls_connector: TlsConnector,
state_updates_sender: broadcast::Sender<StateUpdate>,
http_client: reqwest::Client,
state_timestamp_by_entity_id: HashMap<Box<str>, Box<str>>,
) {
let states_url = base_url.join("/api/states/").unwrap();
let websocket_url = {
let mut u = base_url.clone();
u.set_scheme(&u.scheme().replace("http", "ws")).unwrap();
u.set_path("api/websocket");
u.to_string()
};
let mut is_first_connection_attempt = true;
let state_timestamp_by_entity_id = Arc::new(RwLock::new(state_timestamp_by_entity_id));
loop {
let connection_result =
tokio_tungstenite::connect_async_tls_with_config(&websocket_url, None, false, Some(Connector::NativeTls(tls_connector.clone()))).await;
match connection_result {
Err(tungstenite::Error::Io(error)) => {
if is_first_connection_attempt {
log::warn!("Establishing a WebSocket connection failed: {error}");
log::info!("Retrying every 5 seconds…")
}
is_first_connection_attempt = false;
tokio::time::sleep(Duration::from_secs(5)).await;
}
Err(error) => panic!("WebSocket error: {}", error),
Ok((mut socket, _)) => {
log::info!("WebSocket connection successfully established.");
while let Some(event) = socket.next().await {
match event {
Err(error) => {
log::error!("The WebSocket connection failed: {error}");
break;
}
Ok(message) => match message {
tungstenite::Message::Ping(data) => socket.send(tungstenite::Message::Pong(data)).await.unwrap(),
tungstenite::Message::Text(data) => {
log::trace!("Received WebSocket message: {data}");
match serde_json::from_str::<HaIncomingWsMessage>(&data) {
Err(error) => log::error!("Deserializing WebSocket message failed: {error}"),
Ok(message) => match message {
HaIncomingWsMessage::AuthRequired { .. } => socket
.send(tungstenite::Message::Text(
serde_json::to_string(&HaOutgoingWsMessage::Auth { access_token: token.clone() }).unwrap(),
))
.await
.unwrap(),
HaIncomingWsMessage::AuthInvalid { .. } => panic!("Invalid access token."),
HaIncomingWsMessage::AuthOk { .. } => {
let subscription_message = serde_json::to_string_pretty(&HaOutgoingWsMessage::SubscribeTrigger {
// ID may not be zero (that one took me a while)
id: 1,
trigger: HaTrigger::State {
entity_id: state_timestamp_by_entity_id.read().await.keys().cloned().collect(),
// Setting from to null prevents events being sent when only attributes have changed.
from: serde_json::Value::Null,
},
})
.unwrap();
socket.send(tungstenite::Message::Text(subscription_message)).await.unwrap();
}
HaIncomingWsMessage::Result { id, success } => {
if !success {
panic!("A command ({id}) failed.");
}
for entity_id in state_timestamp_by_entity_id.read().await.keys() {
tokio::spawn(request_entity_state(
states_url.join(entity_id).unwrap(),
http_client.clone(),
Arc::clone(&state_timestamp_by_entity_id),
state_updates_sender.clone(),
));
}
}
HaIncomingWsMessage::Event { event, .. } => match extract_state_update_from_event(&event) {
None => log::error!("Invalid state change event message: {data}"),
Some(update) => {
// LOCK START
let mut state_timestamp_by_entity_id = state_timestamp_by_entity_id.write().await;
match state_timestamp_by_entity_id.get(&update.entity_id) {
None => log::warn!("Received unwanted state change event for entity '{}'", update.entity_id),
Some(last_timestamp) => {
if last_timestamp < &update.timestamp {
state_timestamp_by_entity_id.insert(update.entity_id.clone(), update.timestamp.clone());
state_updates_sender.send(StateUpdate::Actual(Arc::new(update))).unwrap();
}
}
}
// LOCK END
}
},
},
}
}
_ => log::error!("Received unsupported WebSocket message: {message:?}"),
},
}
}
}
};
}
}
async fn request_entity_state(
url: Url,
http_client: reqwest::Client,
state_timestamp_by_entity_id: Arc<RwLock<HashMap<Box<str>, Box<str>>>>,
state_updates_sender: broadcast::Sender<StateUpdate>,
) {
match http_client.get(url).send().await.and_then(|a| a.error_for_status()) {
Err(error) => log::error!(
"A GET request to {} failed: {error}",
error.url().map(|u| u.to_string()).unwrap_or("?".to_owned())
),
Ok(response) => match serde_json::from_str(&response.text().await.unwrap()) {
Ok(object) => match extract_state_update_from_state(&object) {
None => log::error!("Invalid entity state object: {object}"),
Some(update) => {
// LOCK START
let mut state_timestamp_by_entity_id = state_timestamp_by_entity_id.write().await;
let last_timestamp = state_timestamp_by_entity_id
.get(&update.entity_id)
.expect("Home Assistant responds with the state of the requested entity.");
if last_timestamp < &update.timestamp {
state_timestamp_by_entity_id.insert(update.entity_id.clone(), update.timestamp.clone());
state_updates_sender.send(StateUpdate::Actual(Arc::new(update))).unwrap();
}
// LOCK END
}
},
Err(error) => {
log::error!("Failed to deserialize state object: {error}");
}
},
}
}
fn extract_state_update_from_event(object: &serde_json::Value) -> Option<ActualStateUpdate> {
extract_state_update_from_state(object.get("variables")?.get("trigger")?.get("to_state")?)
}
fn extract_state_update_from_state(object: &serde_json::Value) -> Option<ActualStateUpdate> {
Some(ActualStateUpdate {
state: object.get("state")?.as_str()?.to_owned().into_boxed_str(),
entity_id: object.get("entity_id")?.as_str()?.to_owned().into_boxed_str(),
timestamp: object.get("last_changed")?.as_str()?.to_owned().into_boxed_str(),
})
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum HaIncomingWsMessage {
AuthRequired { ha_version: Box<str> },
AuthOk { ha_version: Box<str> },
AuthInvalid { message: Box<str> },
Result { id: usize, success: bool },
Event { id: usize, event: serde_json::Value },
}
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum HaOutgoingWsMessage {
Auth { access_token: Box<str> },
SubscribeTrigger { id: usize, trigger: HaTrigger },
}
#[derive(Debug, Serialize)]
#[serde(tag = "platform", rename_all = "snake_case")]
pub enum HaTrigger {
State { entity_id: Box<[Box<str>]>, from: serde_json::Value },
}

View file

@ -0,0 +1,116 @@
use crate::config::{GlobalConfig, KeyConfig, KeyMode, KnobConfig, KnobMode};
use crate::ha_client::{HaClient, StateUpdate};
use deckster_mode::shared::handler_communication::{HandlerCommand, HandlerEvent, HandlerInitializationError, InitialHandlerMessage, KeyEvent};
use deckster_mode::shared::path::KeyPath;
use deckster_mode::{send_command, DecksterHandler};
use std::thread;
use tokio::select;
use tokio::sync::broadcast;
use tokio::task::LocalSet;
pub struct Handler {
events_sender: broadcast::Sender<HandlerEvent>,
}
impl Handler {
pub fn new(data: InitialHandlerMessage<GlobalConfig, KeyConfig, KnobConfig>) -> Result<Self, HandlerInitializationError> {
let events_sender = broadcast::Sender::<HandlerEvent>::new(5);
let mut subscribed_entity_ids = Vec::new();
for c in data.key_configs.values() {
subscribed_entity_ids.push(c.mode.state_entity_id().clone())
}
for c in data.knob_configs.values() {
subscribed_entity_ids.push(c.entity_id.clone())
}
thread::spawn({
let events_sender = events_sender.clone();
move || {
let runtime = tokio::runtime::Builder::new_current_thread().enable_time().enable_io().build().unwrap();
let task_set = LocalSet::new();
let ha_client = task_set.block_on(
&runtime,
HaClient::new(
data.global_config.base_url,
data.global_config.token,
data.global_config.accept_invalid_certs,
subscribed_entity_ids,
),
);
for (path, config) in data.key_configs {
task_set.spawn_local(manage_key(events_sender.subscribe(), ha_client.clone(), path, config));
}
runtime.block_on(task_set)
}
});
Ok(Handler { events_sender })
}
}
impl DecksterHandler for Handler {
fn handle(&mut self, event: HandlerEvent) {
// No receivers being available can be ignored.
_ = self.events_sender.send(event);
}
}
async fn manage_key(mut events: broadcast::Receiver<HandlerEvent>, ha_client: HaClient, path: KeyPath, config: KeyConfig) {
let state_entity_id = config.mode.state_entity_id();
if let Some(state) = &config.disconnected_state {
send_command(HandlerCommand::SetKeyStyle {
path: path.clone(),
value: config.style.get(state).cloned(),
})
}
let mut state_updates = ha_client.subscribe_to_state_updates();
loop {
select! {
Ok(update) = state_updates.recv() => {
match update {
StateUpdate::Disconnected => {
if let Some(state) = &config.disconnected_state {
send_command(HandlerCommand::SetKeyStyle {
path: path.clone(),
value: config.style.get(state).cloned()
})
}
}
StateUpdate::Actual(update) => {
if &update.entity_id == state_entity_id {
send_command(HandlerCommand::SetKeyStyle {
path: path.clone(),
value: config.style.get(&update.state).cloned()
})
}
}
}
}
Ok(HandlerEvent::Key { path: p, event }) = events.recv() => {
if p != path {
continue
}
if let KeyEvent::Press = event {
match &config.mode {
KeyMode::Toggle { entity_id } => {
ha_client.toggle_entity(entity_id).await;
}
KeyMode::Button { .. } => {
todo!()
}
}
}
}
}
}
}

View file

@ -0,0 +1,29 @@
use clap::Parser;
use color_eyre::Result;
use crate::handler::Handler;
mod config;
mod ha_client;
mod handler;
mod util;
#[derive(Debug, Parser)]
#[command(name = "home_assistant")]
enum CliCommand {
#[command(name = "deckster-run", hide = true)]
Run,
}
fn main() -> Result<()> {
env_logger::init();
let command = CliCommand::parse();
match command {
CliCommand::Run => {
deckster_mode::run(Handler::new)?;
}
}
Ok(())
}

View file

@ -0,0 +1,54 @@
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::time::timeout;
/// Sends a message into the output channel after a message in the input channel was received, with a delay of `duration`.
/// The delay is reset when a new message is reset.
pub fn spawn_debouncer(duration: Duration) -> (Sender<()>, Receiver<()>) {
let (input_sender, mut input_receiver) = mpsc::channel::<()>(1);
let (output_sender, output_receiver) = mpsc::channel::<()>(1);
tokio::spawn(async move {
'outer: loop {
if input_receiver.recv().await.is_none() {
break 'outer;
}
'inner: loop {
match timeout(duration, input_receiver.recv()).await {
Ok(None) => break 'outer,
Ok(Some(_)) => continue 'inner,
Err(_) => {
if let Err(TrySendError::Closed(_)) = output_sender.try_send(()) {
break 'outer;
} else {
break 'inner;
}
}
}
}
}
});
(input_sender, output_receiver)
}
pub fn format_duration(duration: Duration) -> String {
let full_seconds = duration.as_secs();
let full_minutes = full_seconds / 60;
let hours = full_minutes / 60;
let minutes = full_minutes % 60;
let seconds = full_seconds % 60;
if hours == 0 {
format!("{:0>2}:{:0>2}", minutes, seconds)
} else {
format!("{:0>2}:{:0>2}:{:0>2}", hours, minutes, seconds)
}
}
pub fn get_far_future() -> Instant {
Instant::now() + Duration::from_secs(60 * 60 * 24 * 365 * 30) // 30 years
}