Improve ergonomics of FlowManager.async_show_progress (#107668)
* Improve ergonomics of FlowManager.async_show_progress * Don't include progress coroutine in web response * Unconditionally reset progress task when show_progress finished * Fix race * Tweak, add tests * Address review comments * Improve error handling * Allow progress jobs to return anything * Add comment * Remove unneeded check * Change API according to discussion * Adjust typing
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Test the flow classes."""
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
@@ -7,7 +8,7 @@ import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
from .common import (
|
||||
@@ -342,6 +343,169 @@ async def test_external_step(hass: HomeAssistant, manager) -> None:
|
||||
async def test_show_progress(hass: HomeAssistant, manager) -> None:
|
||||
"""Test show progress logic."""
|
||||
manager.hass = hass
|
||||
events = []
|
||||
task_one_evt = asyncio.Event()
|
||||
task_two_evt = asyncio.Event()
|
||||
event_received_evt = asyncio.Event()
|
||||
|
||||
@callback
|
||||
def capture_events(event: Event) -> None:
|
||||
events.append(event)
|
||||
event_received_evt.set()
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
VERSION = 5
|
||||
data = None
|
||||
start_task_two = False
|
||||
progress_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
async def long_running_task_one() -> None:
|
||||
await task_one_evt.wait()
|
||||
self.start_task_two = True
|
||||
|
||||
async def long_running_task_two() -> None:
|
||||
await task_two_evt.wait()
|
||||
self.data = {"title": "Hello"}
|
||||
|
||||
if not task_one_evt.is_set():
|
||||
progress_action = "task_one"
|
||||
if not self.progress_task:
|
||||
self.progress_task = hass.async_create_task(long_running_task_one())
|
||||
elif not task_two_evt.is_set():
|
||||
progress_action = "task_two"
|
||||
if self.start_task_two:
|
||||
self.progress_task = hass.async_create_task(long_running_task_two())
|
||||
self.start_task_two = False
|
||||
if not task_one_evt.is_set() or not task_two_evt.is_set():
|
||||
return self.async_show_progress(
|
||||
step_id="init",
|
||||
progress_action=progress_action,
|
||||
progress_task=self.progress_task,
|
||||
)
|
||||
|
||||
return self.async_show_progress_done(next_step_id="finish")
|
||||
|
||||
async def async_step_finish(self, user_input=None):
|
||||
return self.async_create_entry(title=self.data["title"], data=self.data)
|
||||
|
||||
hass.bus.async_listen(
|
||||
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
|
||||
capture_events,
|
||||
run_immediately=True,
|
||||
)
|
||||
|
||||
result = await manager.async_init("test")
|
||||
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||
assert result["progress_action"] == "task_one"
|
||||
assert len(manager.async_progress()) == 1
|
||||
assert len(manager.async_progress_by_handler("test")) == 1
|
||||
assert manager.async_get(result["flow_id"])["handler"] == "test"
|
||||
|
||||
# Set task one done and wait for event
|
||||
task_one_evt.set()
|
||||
await event_received_evt.wait()
|
||||
event_received_evt.clear()
|
||||
assert len(events) == 1
|
||||
assert events[0].data == {
|
||||
"handler": "test",
|
||||
"flow_id": result["flow_id"],
|
||||
"refresh": True,
|
||||
}
|
||||
|
||||
# Frontend refreshes the flow
|
||||
result = await manager.async_configure(result["flow_id"])
|
||||
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||
assert result["progress_action"] == "task_two"
|
||||
|
||||
# Set task two done and wait for event
|
||||
task_two_evt.set()
|
||||
await event_received_evt.wait()
|
||||
event_received_evt.clear()
|
||||
assert len(events) == 2 # 1 for task one and 1 for task two
|
||||
assert events[1].data == {
|
||||
"handler": "test",
|
||||
"flow_id": result["flow_id"],
|
||||
"refresh": True,
|
||||
}
|
||||
|
||||
# Frontend refreshes the flow
|
||||
result = await manager.async_configure(result["flow_id"])
|
||||
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == "Hello"
|
||||
|
||||
|
||||
async def test_show_progress_error(hass: HomeAssistant, manager) -> None:
|
||||
"""Test show progress logic."""
|
||||
manager.hass = hass
|
||||
events = []
|
||||
event_received_evt = asyncio.Event()
|
||||
|
||||
@callback
|
||||
def capture_events(event: Event) -> None:
|
||||
events.append(event)
|
||||
event_received_evt.set()
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
VERSION = 5
|
||||
data = None
|
||||
progress_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
async def long_running_task() -> None:
|
||||
raise TypeError
|
||||
|
||||
if not self.progress_task:
|
||||
self.progress_task = hass.async_create_task(long_running_task())
|
||||
if self.progress_task and self.progress_task.done():
|
||||
if self.progress_task.exception():
|
||||
return self.async_show_progress_done(next_step_id="error")
|
||||
return self.async_show_progress_done(next_step_id="no_error")
|
||||
return self.async_show_progress(
|
||||
step_id="init", progress_action="task", progress_task=self.progress_task
|
||||
)
|
||||
|
||||
async def async_step_error(self, user_input=None):
|
||||
return self.async_abort(reason="error")
|
||||
|
||||
hass.bus.async_listen(
|
||||
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
|
||||
capture_events,
|
||||
run_immediately=True,
|
||||
)
|
||||
|
||||
result = await manager.async_init("test")
|
||||
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||
assert result["progress_action"] == "task"
|
||||
assert len(manager.async_progress()) == 1
|
||||
assert len(manager.async_progress_by_handler("test")) == 1
|
||||
assert manager.async_get(result["flow_id"])["handler"] == "test"
|
||||
|
||||
# Set task one done and wait for event
|
||||
await event_received_evt.wait()
|
||||
event_received_evt.clear()
|
||||
assert len(events) == 1
|
||||
assert events[0].data == {
|
||||
"handler": "test",
|
||||
"flow_id": result["flow_id"],
|
||||
"refresh": True,
|
||||
}
|
||||
|
||||
# Frontend refreshes the flow
|
||||
result = await manager.async_configure(result["flow_id"])
|
||||
assert result["type"] == data_entry_flow.FlowResultType.ABORT
|
||||
assert result["reason"] == "error"
|
||||
|
||||
|
||||
async def test_show_progress_legacy(hass: HomeAssistant, manager) -> None:
|
||||
"""Test show progress logic.
|
||||
|
||||
This tests the deprecated version where the config flow is responsible for
|
||||
resuming the flow.
|
||||
"""
|
||||
manager.hass = hass
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
|
||||
Reference in New Issue
Block a user