AI Task integration (#145128)
* Add AI Task integration * Remove GenTextTaskType * Add AI Task prefs * Add action to LLM task * Remove WS command * Rename result to text for GenTextTaskResult * Apply suggestions from code review Co-authored-by: Allen Porter <allen.porter@gmail.com> * Add supported feature for generate text * Update const.py Co-authored-by: HarvsG <11440490+HarvsG@users.noreply.github.com> * Update homeassistant/components/ai_task/services.yaml Co-authored-by: HarvsG <11440490+HarvsG@users.noreply.github.com> * Use WS API to set preferences * Simplify pref storage * Simplify pref test * Update homeassistant/components/ai_task/services.yaml Co-authored-by: Allen Porter <allen.porter@gmail.com> --------- Co-authored-by: Allen Porter <allen.porter@gmail.com> Co-authored-by: HarvsG <11440490+HarvsG@users.noreply.github.com>
This commit is contained in:
1
tests/components/ai_task/__init__.py
Normal file
1
tests/components/ai_task/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for the AI Task integration."""
|
||||
127
tests/components/ai_task/conftest.py
Normal file
127
tests/components/ai_task/conftest.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Test helpers for AI Task integration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.ai_task import (
|
||||
DOMAIN,
|
||||
AITaskEntity,
|
||||
AITaskEntityFeature,
|
||||
GenTextTask,
|
||||
GenTextTaskResult,
|
||||
)
|
||||
from homeassistant.components.conversation import AssistantContent, ChatLog
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockModule,
|
||||
MockPlatform,
|
||||
mock_config_flow,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
|
||||
TEST_DOMAIN = "test"
|
||||
TEST_ENTITY_ID = "ai_task.test_task_entity"
|
||||
|
||||
|
||||
class MockAITaskEntity(AITaskEntity):
|
||||
"""Mock AI Task entity for testing."""
|
||||
|
||||
_attr_name = "Test Task Entity"
|
||||
_attr_supported_features = AITaskEntityFeature.GENERATE_TEXT
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock entity."""
|
||||
super().__init__()
|
||||
self.mock_generate_text_tasks = []
|
||||
|
||||
async def _async_generate_text(
|
||||
self, task: GenTextTask, chat_log: ChatLog
|
||||
) -> GenTextTaskResult:
|
||||
"""Mock handling of generate text task."""
|
||||
self.mock_generate_text_tasks.append(task)
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(self.entity_id, "Mock result")
|
||||
)
|
||||
return GenTextTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
text="Mock result",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
"""Mock a configuration entry for AI Task."""
|
||||
entry = MockConfigEntry(domain=TEST_DOMAIN, entry_id="mock-test-entry")
|
||||
entry.add_to_hass(hass)
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_task_entity(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> MockAITaskEntity:
|
||||
"""Mock AI Task entity."""
|
||||
return MockAITaskEntity()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_components(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
):
|
||||
"""Initialize the AI Task integration with a mock entity."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
async def async_setup_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(
|
||||
config_entry, [Platform.AI_TASK]
|
||||
)
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload test config entry."""
|
||||
await hass.config_entries.async_forward_entry_unload(
|
||||
config_entry, Platform.AI_TASK
|
||||
)
|
||||
return True
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
TEST_DOMAIN,
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
|
||||
async def async_setup_entry_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up test tts platform via config entry."""
|
||||
async_add_entities([mock_ai_task_entity])
|
||||
|
||||
mock_platform(
|
||||
hass,
|
||||
f"{TEST_DOMAIN}.{DOMAIN}",
|
||||
MockPlatform(async_setup_entry=async_setup_entry_platform),
|
||||
)
|
||||
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, ConfigFlow):
|
||||
assert await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
22
tests/components/ai_task/snapshots/test_task.ambr
Normal file
22
tests/components/ai_task/snapshots/test_task.ambr
Normal file
@@ -0,0 +1,22 @@
|
||||
# serializer version: 1
|
||||
# name: test_run_text_task_updates_chat_log
|
||||
list([
|
||||
dict({
|
||||
'content': '''
|
||||
You are a Home Assistant expert and help users with their tasks.
|
||||
Current time is 15:59:00. Today's date is 2025-06-14.
|
||||
''',
|
||||
'role': 'system',
|
||||
}),
|
||||
dict({
|
||||
'content': 'Test prompt',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'ai_task.test_task_entity',
|
||||
'content': 'Mock result',
|
||||
'role': 'assistant',
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
39
tests/components/ai_task/test_entity.py
Normal file
39
tests/components/ai_task/test_entity.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Tests for the AI Task entity model."""
|
||||
|
||||
from freezegun import freeze_time
|
||||
|
||||
from homeassistant.components.ai_task import async_generate_text
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@freeze_time("2025-06-08 16:28:13")
|
||||
async def test_state_generate_text(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
) -> None:
|
||||
"""Test the state of the AI Task entity is updated when generating text."""
|
||||
entity = hass.states.get(TEST_ENTITY_ID)
|
||||
assert entity is not None
|
||||
assert entity.state == STATE_UNKNOWN
|
||||
|
||||
result = await async_generate_text(
|
||||
hass,
|
||||
task_name="Test task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
assert result.text == "Mock result"
|
||||
|
||||
entity = hass.states.get(TEST_ENTITY_ID)
|
||||
assert entity.state == "2025-06-08T16:28:13+00:00"
|
||||
|
||||
assert mock_ai_task_entity.mock_generate_text_tasks
|
||||
task = mock_ai_task_entity.mock_generate_text_tasks[0]
|
||||
assert task.instructions == "Test prompt"
|
||||
84
tests/components/ai_task/test_http.py
Normal file
84
tests/components/ai_task/test_http.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Test the HTTP API for AI Task integration."""
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
async def test_ws_preferences(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components: None,
|
||||
) -> None:
|
||||
"""Test preferences via the WebSocket API."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
# Get initial preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": None,
|
||||
}
|
||||
|
||||
# Set preferences
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_entity_id": "ai_task.summary_1",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_1",
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_1",
|
||||
}
|
||||
|
||||
# Update an existing preference
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
|
||||
# No preferences set will preserve existing preferences
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
84
tests/components/ai_task/test_init.py
Normal file
84
tests/components/ai_task/test_init.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Test initialization of the AI Task component."""
|
||||
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.ai_task import AITaskPreferences
|
||||
from homeassistant.components.ai_task.const import DATA_PREFERENCES
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .conftest import TEST_ENTITY_ID
|
||||
|
||||
from tests.common import flush_store
|
||||
|
||||
|
||||
async def test_preferences_storage_load(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that AITaskPreferences are stored and loaded correctly."""
|
||||
preferences = AITaskPreferences(hass)
|
||||
await preferences.async_load()
|
||||
|
||||
# Initial state should be None for entity IDs
|
||||
for key in AITaskPreferences.KEYS:
|
||||
assert getattr(preferences, key) is None, f"Initial {key} should be None"
|
||||
|
||||
new_values = {key: f"ai_task.test_{key}" for key in AITaskPreferences.KEYS}
|
||||
|
||||
preferences.async_set_preferences(**new_values)
|
||||
|
||||
# Verify that current preferences object is updated
|
||||
for key, value in new_values.items():
|
||||
assert getattr(preferences, key) == value, (
|
||||
f"Current {key} should match set value"
|
||||
)
|
||||
|
||||
await flush_store(preferences._store)
|
||||
|
||||
# Create a new preferences instance to test loading from store
|
||||
new_preferences_instance = AITaskPreferences(hass)
|
||||
await new_preferences_instance.async_load()
|
||||
|
||||
for key in AITaskPreferences.KEYS:
|
||||
assert getattr(preferences, key) == getattr(new_preferences_instance, key), (
|
||||
f"Loaded {key} should match saved value"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("set_preferences", "msg_extra"),
|
||||
[
|
||||
(
|
||||
{"gen_text_entity_id": TEST_ENTITY_ID},
|
||||
{},
|
||||
),
|
||||
(
|
||||
{},
|
||||
{"entity_id": TEST_ENTITY_ID},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_generate_text_service(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
set_preferences: dict[str, str | None],
|
||||
msg_extra: dict[str, str],
|
||||
) -> None:
|
||||
"""Test the generate text service."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
preferences.async_set_preferences(**set_preferences)
|
||||
|
||||
result = await hass.services.async_call(
|
||||
"ai_task",
|
||||
"generate_text",
|
||||
{
|
||||
"task_name": "Test Name",
|
||||
"instructions": "Test prompt",
|
||||
}
|
||||
| msg_extra,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
assert result["text"] == "Mock result"
|
||||
123
tests/components/ai_task/test_task.py
Normal file
123
tests/components/ai_task/test_task.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Test tasks for the AI Task integration."""
|
||||
|
||||
from freezegun import freeze_time
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_text
|
||||
from homeassistant.components.conversation import async_get_chat_log
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import chat_session
|
||||
|
||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
async def test_run_task_preferred_entity(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test running a task with an unknown entity."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="No entity_id provided and no preferred entity set"
|
||||
):
|
||||
await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_entity_id": "ai_task.unknown",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
with pytest.raises(ValueError, match="AI Task entity ai_task.unknown not found"):
|
||||
await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_entity_id": TEST_ENTITY_ID,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
state = hass.states.get(TEST_ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
result = await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
)
|
||||
assert result.text == "Mock result"
|
||||
state = hass.states.get(TEST_ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state != STATE_UNKNOWN
|
||||
|
||||
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="AI Task entity ai_task.test_task_entity does not support generating text",
|
||||
):
|
||||
await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
|
||||
async def test_run_text_task_unknown_entity(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
) -> None:
|
||||
"""Test running a text task with an unknown entity."""
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="AI Task entity ai_task.unknown_entity not found"
|
||||
):
|
||||
await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.unknown_entity",
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
|
||||
@freeze_time("2025-06-14 22:59:00")
|
||||
async def test_run_text_task_updates_chat_log(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that running a text task updates the chat log."""
|
||||
result = await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
assert result.text == "Mock result"
|
||||
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass, result.conversation_id) as session,
|
||||
async_get_chat_log(hass, session) as chat_log,
|
||||
):
|
||||
assert chat_log.content == snapshot
|
||||
Reference in New Issue
Block a user