diff --git a/openevsehttp/websocket.py b/openevsehttp/websocket.py index 96a3571..3fc83ae 100644 --- a/openevsehttp/websocket.py +++ b/openevsehttp/websocket.py @@ -51,8 +51,26 @@ def state(self): return self._state @state.setter - async def state(self, value): - """Set the state.""" + def state(self, value): + """Setter that schedules the callback.""" + self._state = value + _LOGGER.debug("Websocket %s", value) + # Schedule the callback asynchronously without awaiting here. + try: + asyncio.create_task( + self.callback(SIGNAL_CONNECTION_STATE, value, self._error_reason) + ) + except RuntimeError: + # If there's no running loop, schedule safely on the event loop. + loop = asyncio.get_event_loop() + loop.call_soon_threadsafe( + asyncio.create_task, + self.callback(SIGNAL_CONNECTION_STATE, value, self._error_reason), + ) + self._error_reason = None + + async def _set_state(self, value): + """Async helper to set the state and await the callback.""" self._state = value _LOGGER.debug("Websocket %s", value) await self.callback(SIGNAL_CONNECTION_STATE, value, self._error_reason) @@ -65,7 +83,7 @@ def _get_uri(server): async def running(self): """Open a persistent websocket connection and act on events.""" - await OpenEVSEWebsocket.state.fset(self, STATE_STARTING) + await self._set_state(STATE_STARTING) auth = None if self._user and self._password: @@ -77,7 +95,7 @@ async def running(self): heartbeat=15, auth=auth, ) as ws_client: - await OpenEVSEWebsocket.state.fset(self, STATE_CONNECTED) + await self._set_state(STATE_CONNECTED) self.failed_attempts = 0 self._client = ws_client @@ -107,11 +125,11 @@ async def running(self): else: _LOGGER.error("Unexpected response received: %s", error) self._error_reason = error - await OpenEVSEWebsocket.state.fset(self, STATE_STOPPED) + await self._set_state(STATE_STOPPED) except (aiohttp.ClientConnectionError, asyncio.TimeoutError) as error: if self.failed_attempts > MAX_FAILED_ATTEMPTS: self._error_reason = ERROR_TOO_MANY_RETRIES - await OpenEVSEWebsocket.state.fset(self, STATE_STOPPED) + await self._set_state(STATE_STOPPED) elif self.state != STATE_STOPPED: retry_delay = min(2 ** (self.failed_attempts - 1) * 30, 300) self.failed_attempts += 1 @@ -120,16 +138,16 @@ async def running(self): retry_delay, error, ) - await OpenEVSEWebsocket.state.fset(self, STATE_DISCONNECTED) + await self._set_state(STATE_DISCONNECTED) await asyncio.sleep(retry_delay) except Exception as error: # pylint: disable=broad-except if self.state != STATE_STOPPED: _LOGGER.exception("Unexpected exception occurred: %s", error) self._error_reason = error - await OpenEVSEWebsocket.state.fset(self, STATE_STOPPED) + await self._set_state(STATE_STOPPED) else: if self.state != STATE_STOPPED: - await OpenEVSEWebsocket.state.fset(self, STATE_DISCONNECTED) + await self._set_state(STATE_DISCONNECTED) await asyncio.sleep(5) async def listen(self): @@ -140,7 +158,7 @@ async def listen(self): async def close(self): """Close the listening websocket.""" - await OpenEVSEWebsocket.state.fset(self, STATE_STOPPED) + await self._set_state(STATE_STOPPED) await self.session.close() async def keepalive(self): @@ -151,7 +169,7 @@ async def keepalive(self): # Negitive time should indicate no pong reply so consider the # websocket disconnected. self._error_reason = ERROR_PING_TIMEOUT - await OpenEVSEWebsocket.state.fset(self, STATE_DISCONNECTED) + await self._set_state(STATE_DISCONNECTED) data = {"ping": 1} _LOGGER.debug("Sending message: %s to websocket.", data) @@ -168,7 +186,7 @@ async def keepalive(self): _LOGGER.error("Error parsing data: %s", err) except RuntimeError as err: _LOGGER.debug("Websocket connection issue: %s", err) - await OpenEVSEWebsocket.state.fset(self, STATE_DISCONNECTED) + await self._set_state(STATE_DISCONNECTED) except Exception as err: # pylint: disable=broad-exception-caught _LOGGER.debug("Problem sending ping request: %s", err) - await OpenEVSEWebsocket.state.fset(self, STATE_DISCONNECTED) + await self._set_state(STATE_DISCONNECTED) diff --git a/tests/conftest.py b/tests/conftest.py index 3ec4795..7108967 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,5 @@ """Provide common pytest fixtures.""" -import json - import pytest from aioresponses import aioresponses diff --git a/tests/test_main.py b/tests/test_main.py index bb801eb..e807beb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,5 @@ """Library tests.""" -import asyncio import json import logging from unittest import mock @@ -12,7 +11,6 @@ from freezegun import freeze_time from aiohttp.client_exceptions import ContentTypeError, ServerTimeoutError from aiohttp.client_reqrep import ConnectionKey -from awesomeversion.exceptions import AwesomeVersionCompareException import openevsehttp.__main__ as main from openevsehttp.__main__ import OpenEVSE @@ -23,8 +21,6 @@ UnsupportedFeature, ) from openevsehttp.websocket import ( - SIGNAL_CONNECTION_STATE, - STATE_CONNECTED, STATE_DISCONNECTED, ) from tests.common import load_fixture @@ -1118,7 +1114,7 @@ async def test_firmware_check( body="", ) firmware = await test_charger.firmware_check() - assert firmware == None + assert firmware is None mock_aioclient.get( TEST_URL_GITHUB_v4, @@ -2177,7 +2173,7 @@ async def test_get_shaper_updated(fixture, expected, request): await charger.ws_disconnect() -async def test_get_status(test_charger_timeout, caplog): +async def test_get_status_error(test_charger_timeout, caplog): """Test v4 Status reply.""" with caplog.at_level(logging.DEBUG): with pytest.raises(TimeoutError): diff --git a/tests/test_websocket.py b/tests/test_websocket.py index e80ef5c..d97d7aa 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -280,3 +280,27 @@ async def test_keepalive_send_exceptions(ws_client_auth): ws_client_auth._client.send_json.side_effect = Exception("Generic err") await ws_client_auth.keepalive() assert ws_client_auth.state == STATE_DISCONNECTED + + +@pytest.mark.asyncio +async def test_state_setter_threadsafe_fallback(ws_client): + """Test state setter falls back to call_soon_threadsafe on RuntimeError.""" + mock_loop = MagicMock() + ws_client._error_reason = "Previous Error" + + with ( + patch( + "asyncio.create_task", side_effect=RuntimeError("No running loop") + ) as mock_create_task, + patch("asyncio.get_event_loop", return_value=mock_loop), + ): + + ws_client.state = STATE_CONNECTED + assert ws_client.state == STATE_CONNECTED + + mock_loop.call_soon_threadsafe.assert_called_once() + + args, _ = mock_loop.call_soon_threadsafe.call_args + assert args[0] is mock_create_task + + assert ws_client._error_reason is None