Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions openevsehttp/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,26 @@ def state(self):
return self._state

@state.setter
async def state(self, value):
"""Set the state."""
def state(self, value):
"""Setter that schedules the callback."""
self._state = value
_LOGGER.debug("Websocket %s", value)
# Schedule the callback asynchronously without awaiting here.
try:
asyncio.create_task(
self.callback(SIGNAL_CONNECTION_STATE, value, self._error_reason)
)
except RuntimeError:
# If there's no running loop, schedule safely on the event loop.
loop = asyncio.get_event_loop()
loop.call_soon_threadsafe(
asyncio.create_task,
self.callback(SIGNAL_CONNECTION_STATE, value, self._error_reason),
)
self._error_reason = None

async def _set_state(self, value):
"""Async helper to set the state and await the callback."""
self._state = value
_LOGGER.debug("Websocket %s", value)
await self.callback(SIGNAL_CONNECTION_STATE, value, self._error_reason)
Expand All @@ -65,7 +83,7 @@ def _get_uri(server):

async def running(self):
"""Open a persistent websocket connection and act on events."""
await OpenEVSEWebsocket.state.fset(self, STATE_STARTING)
await self._set_state(STATE_STARTING)
auth = None

if self._user and self._password:
Expand All @@ -77,7 +95,7 @@ async def running(self):
heartbeat=15,
auth=auth,
) as ws_client:
await OpenEVSEWebsocket.state.fset(self, STATE_CONNECTED)
await self._set_state(STATE_CONNECTED)
self.failed_attempts = 0
self._client = ws_client

Expand Down Expand Up @@ -107,11 +125,11 @@ async def running(self):
else:
_LOGGER.error("Unexpected response received: %s", error)
self._error_reason = error
await OpenEVSEWebsocket.state.fset(self, STATE_STOPPED)
await self._set_state(STATE_STOPPED)
except (aiohttp.ClientConnectionError, asyncio.TimeoutError) as error:
if self.failed_attempts > MAX_FAILED_ATTEMPTS:
self._error_reason = ERROR_TOO_MANY_RETRIES
await OpenEVSEWebsocket.state.fset(self, STATE_STOPPED)
await self._set_state(STATE_STOPPED)
elif self.state != STATE_STOPPED:
retry_delay = min(2 ** (self.failed_attempts - 1) * 30, 300)
self.failed_attempts += 1
Expand All @@ -120,16 +138,16 @@ async def running(self):
retry_delay,
error,
)
await OpenEVSEWebsocket.state.fset(self, STATE_DISCONNECTED)
await self._set_state(STATE_DISCONNECTED)
await asyncio.sleep(retry_delay)
except Exception as error: # pylint: disable=broad-except
if self.state != STATE_STOPPED:
_LOGGER.exception("Unexpected exception occurred: %s", error)
self._error_reason = error
await OpenEVSEWebsocket.state.fset(self, STATE_STOPPED)
await self._set_state(STATE_STOPPED)
else:
if self.state != STATE_STOPPED:
await OpenEVSEWebsocket.state.fset(self, STATE_DISCONNECTED)
await self._set_state(STATE_DISCONNECTED)
await asyncio.sleep(5)

async def listen(self):
Expand All @@ -140,7 +158,7 @@ async def listen(self):

async def close(self):
"""Close the listening websocket."""
await OpenEVSEWebsocket.state.fset(self, STATE_STOPPED)
await self._set_state(STATE_STOPPED)
await self.session.close()

async def keepalive(self):
Expand All @@ -151,7 +169,7 @@ async def keepalive(self):
# Negitive time should indicate no pong reply so consider the
# websocket disconnected.
self._error_reason = ERROR_PING_TIMEOUT
await OpenEVSEWebsocket.state.fset(self, STATE_DISCONNECTED)
await self._set_state(STATE_DISCONNECTED)

data = {"ping": 1}
_LOGGER.debug("Sending message: %s to websocket.", data)
Expand All @@ -168,7 +186,7 @@ async def keepalive(self):
_LOGGER.error("Error parsing data: %s", err)
except RuntimeError as err:
_LOGGER.debug("Websocket connection issue: %s", err)
await OpenEVSEWebsocket.state.fset(self, STATE_DISCONNECTED)
await self._set_state(STATE_DISCONNECTED)
except Exception as err: # pylint: disable=broad-exception-caught
_LOGGER.debug("Problem sending ping request: %s", err)
await OpenEVSEWebsocket.state.fset(self, STATE_DISCONNECTED)
await self._set_state(STATE_DISCONNECTED)
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Provide common pytest fixtures."""

import json

import pytest
from aioresponses import aioresponses

Expand Down
8 changes: 2 additions & 6 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Library tests."""

import asyncio
import json
import logging
from unittest import mock
Expand All @@ -12,7 +11,6 @@
from freezegun import freeze_time
from aiohttp.client_exceptions import ContentTypeError, ServerTimeoutError
from aiohttp.client_reqrep import ConnectionKey
from awesomeversion.exceptions import AwesomeVersionCompareException

import openevsehttp.__main__ as main
from openevsehttp.__main__ import OpenEVSE
Expand All @@ -23,8 +21,6 @@
UnsupportedFeature,
)
from openevsehttp.websocket import (
SIGNAL_CONNECTION_STATE,
STATE_CONNECTED,
STATE_DISCONNECTED,
)
from tests.common import load_fixture
Expand Down Expand Up @@ -1118,7 +1114,7 @@ async def test_firmware_check(
body="",
)
firmware = await test_charger.firmware_check()
assert firmware == None
assert firmware is None

mock_aioclient.get(
TEST_URL_GITHUB_v4,
Expand Down Expand Up @@ -2177,7 +2173,7 @@ async def test_get_shaper_updated(fixture, expected, request):
await charger.ws_disconnect()


async def test_get_status(test_charger_timeout, caplog):
async def test_get_status_error(test_charger_timeout, caplog):
"""Test v4 Status reply."""
with caplog.at_level(logging.DEBUG):
with pytest.raises(TimeoutError):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,27 @@ async def test_keepalive_send_exceptions(ws_client_auth):
ws_client_auth._client.send_json.side_effect = Exception("Generic err")
await ws_client_auth.keepalive()
assert ws_client_auth.state == STATE_DISCONNECTED


@pytest.mark.asyncio
async def test_state_setter_threadsafe_fallback(ws_client):
"""Test state setter falls back to call_soon_threadsafe on RuntimeError."""
mock_loop = MagicMock()
ws_client._error_reason = "Previous Error"

with (
patch(
"asyncio.create_task", side_effect=RuntimeError("No running loop")
) as mock_create_task,
patch("asyncio.get_event_loop", return_value=mock_loop),
):

ws_client.state = STATE_CONNECTED
assert ws_client.state == STATE_CONNECTED

mock_loop.call_soon_threadsafe.assert_called_once()

args, _ = mock_loop.call_soon_threadsafe.call_args
assert args[0] is mock_create_task

assert ws_client._error_reason is None