import tempfile
from pathlib import Path
from typing import Generator
from typing import cast

import pytest

from imbue_core.agents.data_types.ids import ProjectID
from imbue_core.agents.data_types.ids import TaskID
from imbue_core.concurrency_group import ConcurrencyGroup
from imbue_core.sculptor.state.messages import Message
from sculptor.config.settings import SculptorSettings
from sculptor.database.models import AgentTaskInputsV1
from sculptor.database.models import Project
from sculptor.database.models import Task
from sculptor.interfaces.agents.agent import HelloAgentConfig
from sculptor.interfaces.environments.base import LocalEnvironmentConfig
from sculptor.interfaces.environments.base import LocalImageConfig
from sculptor.primitives.ids import OrganizationReference
from sculptor.primitives.ids import RequestID
from sculptor.primitives.ids import UserReference
from sculptor.service_collections.service_collection import CompleteServiceCollection
from sculptor.services.environment_service.environments.image_tags import ImageMetadataV1
from sculptor.services.environment_service.environments.local_environment import LocalEnvironment
from sculptor.services.task_service.data_types import ServiceCollectionForTask


@pytest.fixture
def environment_config() -> LocalEnvironmentConfig:
    """Provide a default-constructed ``LocalEnvironmentConfig``."""
    default_config = LocalEnvironmentConfig()
    return default_config


@pytest.fixture
def project() -> Project:
    """Provide a ``Project`` named "Test Project".

    The project gets a newly constructed ``ProjectID`` and is attached to the
    organization reference ``"org_123"``.
    """
    project_fields = {
        "object_id": ProjectID(),
        "name": "Test Project",
        "organization_reference": OrganizationReference("org_123"),
    }
    return Project(**project_fields)


@pytest.fixture
def local_task(project: Project, environment_config: LocalEnvironmentConfig, tmp_path: Path) -> Task:
    """Provide a parentless ``Task`` belonging to the ``project`` fixture.

    The task inputs use a ``HelloAgentConfig``, a ``LocalImageConfig`` whose
    code directory is pytest's ``tmp_path``, and the shared
    ``environment_config`` fixture.
    """
    task_id = TaskID()
    owner = UserReference("usr_123")
    task_inputs = AgentTaskInputsV1(
        agent_config=HelloAgentConfig(),
        image_config=LocalImageConfig(code_directory=tmp_path),
        environment_config=environment_config,
        git_hash="initialhash",
        initial_branch="main",
        is_git_state_clean=False,
    )
    return Task(
        object_id=task_id,
        organization_reference=project.organization_reference,
        user_reference=owner,
        project_id=project.object_id,
        input_data=task_inputs,
        parent_task_id=None,
    )


# Override the test_settings fixture used in other fixtures (e.g. test_service_collection) in the higher-level conftest.py
@pytest.fixture
def test_settings(test_settings: SculptorSettings) -> SculptorSettings:
    """Return a copy of the inherited ``test_settings`` with two flags overridden.

    The copy has ``IS_CHECKS_ENABLED`` set to ``True`` and
    ``DOCKER_PROVIDER_ENABLED`` set to ``False``.
    """
    overrides = {
        "IS_CHECKS_ENABLED": True,
        "DOCKER_PROVIDER_ENABLED": False,
    }
    return test_settings.model_copy(update=overrides)


@pytest.fixture
def services(
    test_service_collection: CompleteServiceCollection,
    test_root_concurrency_group: ConcurrencyGroup,
    local_task: Task,
    project: Project,
) -> Generator[ServiceCollectionForTask, None, None]:
    """Yield the test service collection after registering the project and task.

    Inside a single transaction the ``project`` fixture is upserted and the
    ``local_task`` fixture is created through the task service. The collection
    is then yielded, cast to ``ServiceCollectionForTask``.
    """
    data_model_service = test_service_collection.data_model_service
    with data_model_service.open_transaction(RequestID()) as txn:
        txn.upsert_project(project)
        test_service_collection.task_service.create_task(local_task, txn)

    task_services = cast(ServiceCollectionForTask, test_service_collection)
    yield task_services


@pytest.fixture
def environment(
    tmp_path: Path,
    environment_config: LocalEnvironmentConfig,
    services: ServiceCollectionForTask,
    project: Project,
    initial_commit_repo: tuple[Path, str],
    test_root_concurrency_group: ConcurrencyGroup,
) -> Generator[LocalEnvironment, None, None]:
    """Yield a ``LocalEnvironment`` built from the ``initial_commit_repo`` code directory.

    An image is ensured for the repository using a temporary directory as the
    fifth ``ensure_image`` argument. An environment is then generated from that
    image and yielded while both context managers remain open.
    """
    repo_dir, _unused = initial_commit_repo
    repo_image_config = LocalImageConfig(code_directory=repo_dir)
    environment_service = services.environment_service
    with tempfile.TemporaryDirectory() as scratch_dir:
        built_image = environment_service.ensure_image(
            repo_image_config,
            project.object_id,
            {},
            repo_dir,
            Path(scratch_dir),
            image_metadata=ImageMetadataV1.from_testing(),
        )
        environment_context = services.environment_service.generate_environment(
            image=built_image,
            project_id=project.object_id,
            concurrency_group=test_root_concurrency_group,
            config=environment_config,
        )
        with environment_context as local_env:
            # The annotated yield type is LocalEnvironment; fail fast on anything else.
            assert isinstance(local_env, LocalEnvironment)
            yield local_env


def get_all_messages_for_task(task_id: TaskID, services: ServiceCollectionForTask) -> list[Message]:
    """Return the messages currently queued for ``task_id``, minus the first one.

    Subscribes to the task, drains whatever is queued at that moment without
    blocking, and drops the first entry (the initial task state message).
    Raises ``IndexError`` if nothing was queued.
    """
    drained: list[Message] = []
    with services.task_service.subscribe_to_task(task_id) as message_queue:
        while message_queue.qsize() > 0:
            next_message = message_queue.get_nowait()
            drained.append(next_message)
    # Discard the initial task state message that heads the queue.
    drained.pop(0)
    return drained
