import os
import platform
import shlex
from enum import Enum
from pathlib import Path
from typing import Mapping

from loguru import logger

from imbue_core.agents.data_types.ids import ProjectID
from imbue_core.agents.data_types.ids import TaskID
from imbue_core.async_monkey_patches import log_exception
from imbue_core.common import generate_id
from imbue_core.concurrency_group import ConcurrencyGroup
from imbue_core.constants import ExceptionPriority
from imbue_core.processes.local_process import run_blocking
from imbue_core.subprocess_utils import ProcessError
from sculptor import version
from sculptor.cli.ssh_utils import ensure_local_sculptor_ssh_configured
from sculptor.database.models import TaskID
from sculptor.interfaces.environments.base import LocalDockerImage
from sculptor.interfaces.environments.errors import ImageConfigError
from sculptor.interfaces.environments.errors import ProviderError
from sculptor.interfaces.environments.provider_status import OkStatus
from sculptor.primitives.constants import USER_FACING_LOG_TYPE
from sculptor.primitives.ids import DockerContainerID
from sculptor.primitives.ids import DockerImageID
from sculptor.services.environment_service.api import TaskImageCleanupData
from sculptor.services.environment_service.environments.utils import get_docker_status
from sculptor.utils.timeout import log_runtime_decorator


@log_runtime_decorator()
def build_docker_image(
    dockerfile_path: Path,
    project_id: ProjectID,
    concurrency_group: ConcurrencyGroup,
    cached_repo_tarball_parent_directory: Path | None = None,
    tag: str | None = None,
    disable_cache: bool = False,
    secrets: Mapping[str, str] | None = None,
    build_path: Path | None = None,
    base_image_tag: str | None = None,
    container_user: str | None = None,
) -> LocalDockerImage:
    """Build a Docker image from a Dockerfile with `docker buildx build` and return a handle to it.

    build_path is a synonym for Docker's build context, which is an unnamed argument to docker build.
    container_user is the user from devcontainer.json's containerUser field, if any.

    Args:
        dockerfile_path: Dockerfile to build; must exist.
        project_id: Recorded on the returned LocalDockerImage.
        concurrency_group: Used to run the docker CLI commands.
        cached_repo_tarball_parent_directory: If set, exposed to the build as the `imbue_user_repo`
            named build context.
        tag: Image tag; a random `sculptor-image:<8 chars>` tag is generated when omitted.
        disable_cache: If True, pass `--no-cache` to the build.
        secrets: Extra environment variables merged over `os.environ` for the build process.
        build_path: Build context directory; defaults to the Dockerfile's parent directory.
        base_image_tag: If set, passed as the `BASE_IMAGE` build arg.
        container_user: If set, passed as the `CONTAINER_USER` build arg.

    Returns:
        A LocalDockerImage carrying the built image's full ID, its tag, and `project_id`.

    Raises:
        FileNotFoundError: If `dockerfile_path` does not exist.
        ImageConfigError: If the build fails with "ERROR: failed to solve" in its stderr.
        ProviderError: For any other build failure, or if the built image cannot be inspected.
    """
    if not dockerfile_path.exists():
        raise FileNotFoundError(f"Dockerfile not found at {dockerfile_path}")

    if secrets is None:
        secrets = {}

    # Generate a unique tag if not provided
    if tag is None:
        tag = f"sculptor-image:{generate_id()[:8]}"

    # Build the Docker image
    build_command = [
        *("docker", "buildx", "build"),
        "--progress=plain",
        "--output=type=docker,compression=uncompressed",
        *("-f", str(dockerfile_path)),
        *("-t", tag),
        *("--build-arg", f"BUILT_FROM_SCULPTOR_VERSION={version.__version__}"),
        *("--build-arg", f"BUILT_FROM_SCULPTOR_GIT_HASH={version.__git_sha__}"),
        # TODO: Get rid of these when we can.
        *("--build-arg", f"USER_UID={os.getuid()}"),
        *("--build-arg", f"GROUP_GID={os.getgid()}"),
    ]
    if cached_repo_tarball_parent_directory:
        build_command.extend(("--build-context", f"imbue_user_repo={cached_repo_tarball_parent_directory}"))

    # The local sculptor SSH keypair directory is always made available as a named build context.
    ssh_keypair_dir = ensure_local_sculptor_ssh_configured()
    build_command.extend(("--build-context", f"ssh_keypair_dir={ssh_keypair_dir}"))

    if base_image_tag:
        build_command.extend(("--build-arg", f"BASE_IMAGE={base_image_tag}"))

    if container_user:
        build_command.extend(("--build-arg", f"CONTAINER_USER={container_user}"))

    if disable_cache:
        build_command.append("--no-cache")

    build_path = build_path or dockerfile_path.parent
    build_command.append(str(build_path))

    logger.info("Building Docker image with tag {}", tag)

    # Shell-quoted rendering used only for logs and error messages.
    build_command_string = " ".join(shlex.quote(arg) for arg in build_command)
    logger.debug("Building Docker image with build_path={}:\n{}", build_path, build_command_string)

    try:
        concurrency_group.run_process_to_completion(
            command=build_command,
            on_output=lambda line, is_stderr: logger.debug(line.strip()),
            cwd=build_path,
            trace_log_context={
                "sandbox_path": str(build_path),
                "log_type": USER_FACING_LOG_TYPE,
            },
            # NOTE(review): secrets only reach the docker CLI's environment; no `--secret` flag is added
            # here, so presumably something else maps them into the build — confirm.
            env={**os.environ, **secrets},
        )
    except ProcessError as e:
        error_msg = f"Docker build failed - is your Docker up-to-date? Exit code {e.returncode}: {build_command_string}\nstdout=\n{e.stdout}\nstderr=\n{e.stderr}"
        # Guard against a missing stderr so classification never raises TypeError.
        if "ERROR: failed to solve" in (e.stderr or ""):
            # NOTE: this might not be the best way to distinguish between image config errors and other errors
            # but it's the best we can do for now
            raise ImageConfigError(error_msg) from e
        raise ProviderError(error_msg) from e

    # Resolve the tag to the full image ID. `docker image inspect` only looks at images, unlike plain
    # `docker inspect`, which searches every object type and could match e.g. a same-named container.
    inspect_result = concurrency_group.run_process_to_completion(
        command=["docker", "image", "inspect", "-f", "{{.Id}}", tag],
        is_checked=False,
    )

    if inspect_result.returncode != 0:
        raise ProviderError(f"Failed to inspect built image: {inspect_result.stderr}")

    full_id = DockerImageID(inspect_result.stdout.strip())

    logger.info("Built Docker image {} with tag {}", full_id, tag)
    return LocalDockerImage(image_id=full_id, docker_image_tag=tag, project_id=project_id)


def delete_docker_image_and_any_stopped_containers(
    image_id: str, concurrency_group: ConcurrencyGroup
) -> tuple[bool, list[DockerContainerID]]:
    """Remove exited containers created from `image_id`, then try to remove the image itself.

    Args:
        image_id: Image to remove; passed directly to `docker ps -f ancestor=...` and `docker rmi`.
        concurrency_group: Used to run the docker CLI commands.

    Returns:
        `(success, removed_containers)`. `success` is True when the image is gone afterwards,
        either because `docker rmi` succeeded or because `docker inspect` no longer finds it.
        `removed_containers` lists the exited containers removed before the failure or success,
        so it is populated even when `success` is False.
    """
    removed: list[DockerContainerID] = []

    # Step 1: find the *exited* containers whose ancestor is this image.
    ps_command = ["docker", "ps", "-a", "-q", "-f", "status=exited", "-f", f"ancestor={image_id}"]
    try:
        ps_result = concurrency_group.run_process_to_completion(command=ps_command)
    # TODO: probably need some better error handling here
    except ProcessError as e:
        log_exception(
            e, "Failed to list containers for {image_id}", priority=ExceptionPriority.LOW_PRIORITY, image_id=image_id
        )
        return False, removed
    exited_container_ids = ps_result.stdout.strip().splitlines(keepends=False)

    # Step 2: remove them one at a time, giving up on the first failure.
    for exited_id in exited_container_ids:
        try:
            concurrency_group.run_process_to_completion(command=["docker", "rm", exited_id])
            removed.append(DockerContainerID(exited_id))
            logger.debug("Successfully deleted stopped container {} for image {}", exited_id, image_id)
        except ProcessError as e:
            log_exception(
                e,
                "Failed to delete stopped containers for image {image_id}",
                priority=ExceptionPriority.LOW_PRIORITY,
                image_id=image_id,
            )
            return False, removed

    # Step 3: remove the image. We only want to delete genuinely unused images. `docker rmi`
    # refuses to delete an image used by a currently running container, yet still lets outdated
    # snapshots of running containers be deleted.
    try:
        concurrency_group.run_process_to_completion(command=["docker", "rmi", image_id])
        logger.debug("Successfully deleted Docker image: {}", image_id)
        return True, removed
    except ProcessError as e:
        probe = concurrency_group.run_process_to_completion(command=["docker", "inspect", image_id], is_checked=False)
        if probe.returncode != 0:
            # The image cannot be inspected any more, so it is gone despite the rmi error.
            return True, removed
        # Being in use by a running container is an expected refusal; anything else gets logged.
        if "image is being used by running container" not in e.stderr:
            log_exception(e, "Failed to delete Docker image {image_id}", image_id=image_id)
        return False, removed
    except Exception as e:
        log_exception(e, "Error deleting Docker image {image_id}", image_id=image_id)
        return False, removed


def _get_image_ids_with_running_containers(concurrency_group: ConcurrencyGroup) -> tuple[str, ...]:
    """Return the distinct image IDs (as reported by `docker inspect --format={{.Image}}`) of running containers.

    Returns an empty tuple when nothing is running. If a docker command fails, docker's status is
    checked: a non-OK status raises ProviderError. An OK status logs the failure at low priority
    and returns an empty tuple.
    """
    try:
        ps_stdout = concurrency_group.run_process_to_completion(command=("docker", "ps", "--quiet")).stdout
        running_container_ids = ps_stdout.strip().splitlines()
        if not running_container_ids:
            return ()
        inspect_stdout = concurrency_group.run_process_to_completion(
            command=("docker", "inspect", "--format={{.Image}}", *running_container_ids)
        ).stdout
    except ProcessError as e:
        health_status = get_docker_status(concurrency_group)
        if isinstance(health_status, OkStatus):
            log_exception(
                e, "Error getting image IDs with running containers", priority=ExceptionPriority.LOW_PRIORITY
            )
            return ()
        logger.debug("Docker seems to be down, cannot list running containers")
        details_msg = f" (details: {health_status.details})" if health_status.details else ""
        raise ProviderError(f"Provider is unavailable: {health_status.message}{details_msg}") from e

    # Deduplicate: several containers may share one image.
    return tuple({stripped for raw in inspect_stdout.strip().splitlines() if (stripped := raw.strip())})


class DeletionTier(Enum):
    """How eagerly an image may be garbage-collected; higher values are deleted first.

    When several tasks reference one image, the image takes the lowest (most protective) tier
    among those tasks.
    """

    # Images used by running containers, or the latest image of a task.
    NEVER_DELETE = 0
    # Historical images of active tasks that no running container is using.
    RARELY_DELETE = 1
    # Historical images of archived tasks that no running container is using.
    SOMETIMES_DELETE = 2
    # Images belonging to deleted tasks.
    ALWAYS_DELETE = 3


def _classify_image_tier(image_id: str, associated_task_metadata: TaskImageCleanupData) -> DeletionTier:
    """Decide the deletion tier of one image from the perspective of a single task.

    Precedence: a deleted task makes the image ALWAYS_DELETE. Otherwise the task's latest image is
    NEVER_DELETE. Other images are SOMETIMES_DELETE for archived tasks and RARELY_DELETE for the rest.
    """
    task = associated_task_metadata
    if task.is_deleted:
        return DeletionTier.ALWAYS_DELETE
    if image_id == task.last_image_id:
        return DeletionTier.NEVER_DELETE
    return DeletionTier.SOMETIMES_DELETE if task.is_archived else DeletionTier.RARELY_DELETE


def _get_task_ids_by_image_id(
    task_metadata_by_task_id: Mapping[TaskID, TaskImageCleanupData],
) -> dict[str, list[TaskID]]:
    """Invert the task -> images relation into image ID -> task IDs that reference it.

    Each list follows the iteration order of `task_metadata_by_task_id`. A task listing the same
    image more than once appears that many times in the image's list.
    """
    tasks_for_image: dict[str, list[TaskID]] = {}
    for task_id, metadata in task_metadata_by_task_id.items():
        for image_id in metadata.all_image_ids:
            bucket = tasks_for_image.get(image_id)
            if bucket is None:
                bucket = tasks_for_image[image_id] = []
            bucket.append(task_id)
    return tasks_for_image


def _get_tier_by_image_id(
    task_metadata_by_task_id: Mapping[TaskID, TaskImageCleanupData],
    active_image_ids: tuple[str, ...],
) -> dict[str, DeletionTier]:
    """Assign a DeletionTier to every image referenced by any task.

    Images listed in `active_image_ids` are always NEVER_DELETE. Otherwise each referencing task
    classifies the image via `_classify_image_tier`, and the most protective (lowest-valued) tier wins.
    """
    # Build the set once so each membership check is O(1) instead of a tuple scan.
    active = frozenset(active_image_ids)
    tier_by_image_id: dict[str, DeletionTier] = {}

    for image_id, task_ids in _get_task_ids_by_image_id(task_metadata_by_task_id).items():
        if image_id in active:
            logger.debug("Image {} is in active image IDs - never delete", image_id)
            tier_by_image_id[image_id] = DeletionTier.NEVER_DELETE
            continue
        # task_ids is never empty: every key in the mapping was created by appending a task ID.
        tier_by_image_id[image_id] = min(
            (
                _classify_image_tier(image_id=image_id, associated_task_metadata=task_metadata_by_task_id[task_id])
                for task_id in task_ids
            ),
            key=lambda tier: tier.value,
        )
        logger.debug("Image {} has been assigned tier {}", image_id, tier_by_image_id[image_id])
    return tier_by_image_id


def _get_current_image_ids(concurrency_group: ConcurrencyGroup) -> tuple[str, ...]:
    """Return the distinct, untruncated IDs listed by `docker images --quiet --no-trunc`.

    If the listing fails and docker's status is not OK, raises ProviderError. If docker reports OK,
    the original ProcessError is re-raised unchanged.
    """
    try:
        listing = concurrency_group.run_process_to_completion(command=("docker", "images", "--quiet", "--no-trunc"))
    except ProcessError as e:
        health_status = get_docker_status(concurrency_group)
        if isinstance(health_status, OkStatus):
            raise
        logger.debug("Docker seems to be down, cannot list images")
        details_msg = f" (details: {health_status.details})" if health_status.details else ""
        raise ProviderError(f"Provider is unavailable: {health_status.message}{details_msg}") from e

    # Deduplicate: an image with several tags is listed once per tag.
    return tuple({stripped for raw in listing.stdout.strip().splitlines() if (stripped := raw.strip())})


def _extend_image_ids_with_similar_hashes(image_ids: tuple[str, ...]) -> tuple[str, ...]:
    return tuple({*image_ids, *(image_id.split(":", 1)[-1] for image_id in image_ids if ":" in image_id)})


def get_image_ids_to_delete(
    task_metadata_by_task_id: Mapping[TaskID, TaskImageCleanupData],
    minimum_deletion_tier: DeletionTier,
    concurrency_group: ConcurrencyGroup,
) -> tuple[str, ...]:
    """Select image IDs to delete, using the task metadata and docker's current images and running containers.

    Only images whose tier is strictly above `minimum_deletion_tier` and that docker still lists are
    returned. Errors from the docker listing helpers (e.g. ProviderError when docker is down) propagate.
    """
    # TODO(sam): The task metadata pulls the image IDs from agent logs which means we may need to
    # support image IDs without the "sha256:" prefix.
    # Right now, we try to handle both cases. But we should have better hash-matching logic.
    present_ids = _get_current_image_ids(concurrency_group)
    running_ids = _get_image_ids_with_running_containers(concurrency_group)
    return _calculate_image_ids_to_delete(
        task_metadata_by_task_id,
        active_image_ids=_extend_image_ids_with_similar_hashes(running_ids),
        existing_image_ids=_extend_image_ids_with_similar_hashes(present_ids),
        minimum_deletion_tier=minimum_deletion_tier,
    )


def _calculate_image_ids_to_delete(
    task_metadata_by_task_id: Mapping[TaskID, TaskImageCleanupData],
    active_image_ids: tuple[str, ...],
    existing_image_ids: tuple[str, ...],
    minimum_deletion_tier: DeletionTier,
) -> tuple[str, ...]:
    """Choose deletable image IDs from already-gathered docker state.

    An image is selected when its tier value is strictly greater than `minimum_deletion_tier`'s and
    it appears in `existing_image_ids`. The result is deduplicated and unordered.
    """
    tier_by_image_id = _get_tier_by_image_id(task_metadata_by_task_id, active_image_ids)
    still_present = set(existing_image_ids)
    chosen: set[str] = set()
    for image_id, tier in tier_by_image_id.items():
        # Skip images at or below the minimum tier, and images that no longer exist in the system.
        if tier.value <= minimum_deletion_tier.value or image_id not in still_present:
            continue
        logger.debug("Adding image {} to deletion list", image_id)
        chosen.add(image_id)
    return tuple(chosen)


def get_platform_architecture() -> str:
    """
    Determine the platform architecture for Docker images.

    The architecture reported by `docker info` is preferred. `platform.machine()` is the fallback
    when docker cannot be run or reports nothing. Both sources are lowercased and common aliases are
    normalized: x86_64/amd64 -> "amd64" and aarch64/arm64 -> "arm64". Anything else is logged and
    treated as "amd64".

    Returns:
        Platform name ("amd64" or "arm64")

    Examples:
        >>> get_platform_architecture() in ["amd64", "arm64"]
        True
    """
    # TODO: Add unit test that exercises the docker info command path when docker is available
    # NOTE(bowei): use the docker info, in case somehow it's different from the host
    # Fall back to platform.machine() if docker is not available
    host_arch = platform.machine().lower()
    try:
        docker_arch = run_blocking(["docker", "info", "--format", "{{.Architecture}}"]).stdout.strip().lower()
    except (ProcessError, OSError):
        # The docker command failed, or (OSError, e.g. FileNotFoundError) the binary could not be
        # launched at all. The latter is presumably how a missing docker CLI surfaces from run_blocking;
        # TODO confirm. Either way, use the host fallback.
        docker_arch = ""
    arch = docker_arch or host_arch

    if arch in ("x86_64", "amd64"):
        return "amd64"
    if arch in ("aarch64", "arm64"):
        return "arm64"
    logger.info("Unknown architecture {}, defaulting to amd64", arch)
    return "amd64"
