import os
import platform
import shlex
from enum import Enum
from pathlib import Path
from typing import Literal
from typing import Mapping
from typing import Sequence

import humanfriendly
from humanfriendly import InvalidSize
from loguru import logger
from pydantic import BaseModel

from imbue_core.agents.data_types.ids import ProjectID
from imbue_core.agents.data_types.ids import TaskID
from imbue_core.async_monkey_patches import log_exception
from imbue_core.common import generate_id
from imbue_core.concurrency_group import ConcurrencyGroup
from imbue_core.constants import ExceptionPriority
from imbue_core.processes.local_process import run_blocking
from imbue_core.sculptor.telemetry import PosthogEventModel
from imbue_core.sculptor.telemetry import PosthogEventPayload
from imbue_core.sculptor.telemetry import emit_posthog_event
from imbue_core.sculptor.telemetry import with_consent
from imbue_core.sculptor.telemetry_constants import ConsentLevel
from imbue_core.sculptor.telemetry_constants import ProductComponent
from imbue_core.sculptor.telemetry_constants import SculptorPosthogEvent
from imbue_core.subprocess_utils import ProcessError
from sculptor import version
from sculptor.cli.ssh_utils import ensure_local_sculptor_ssh_configured
from sculptor.config.settings import SculptorSettings
from sculptor.database.models import TaskID
from sculptor.interfaces.environments.base import LocalDockerImage
from sculptor.interfaces.environments.errors import ImageConfigError
from sculptor.interfaces.environments.errors import ProviderError
from sculptor.interfaces.environments.provider_status import OkStatus
from sculptor.primitives.constants import USER_FACING_LOG_TYPE
from sculptor.primitives.executor import ObservableThreadPoolExecutor
from sculptor.primitives.ids import DockerContainerID
from sculptor.primitives.ids import DockerImageID
from sculptor.services.environment_service.api import TaskImageCleanupData
from sculptor.services.environment_service.environments.constants import SNAPSHOT_SUFFIX
from sculptor.services.environment_service.environments.constants import USER_IMAGE_SUFFIX
from sculptor.services.environment_service.environments.constants import get_standard_environment_prefix
from sculptor.services.environment_service.environments.docker_environment import get_unique_snapshot_size_bytes
from sculptor.services.environment_service.environments.utils import get_docker_status
from sculptor.utils.timeout import log_runtime_decorator


@log_runtime_decorator()
def build_docker_image(
    dockerfile_path: Path,
    project_id: ProjectID,
    concurrency_group: ConcurrencyGroup,
    cached_repo_tarball_parent_directory: Path | None = None,
    tag: str | None = None,
    disable_cache: bool = False,
    secrets: Mapping[str, str] | None = None,
    build_path: Path | None = None,
    base_image_tag: str | None = None,
    container_user: str | None = None,
) -> LocalDockerImage:
    """Build a Docker image from a Dockerfile with `docker buildx build` and return a handle to it.

    Args:
        dockerfile_path: Path to the Dockerfile; must exist.
        project_id: Project recorded on the returned LocalDockerImage.
        concurrency_group: Group used to run the docker subprocesses.
        cached_repo_tarball_parent_directory: If set, exposed to the build as the named build
            context `imbue_user_repo`.
        tag: Image tag; defaults to "sculptor-image:<first 8 chars of a generated id>".
        disable_cache: If True, pass `--no-cache` to the build.
        secrets: Extra environment variables merged over os.environ for the docker build process.
            NOTE(review): no `--secret` flag is added here, so how the build consumes these
            variables is not visible in this function -- confirm.
        build_path: Docker's build context, i.e. the unnamed positional argument to docker build.
            Defaults to the Dockerfile's parent directory; also used as the process cwd.
        base_image_tag: If set, passed as build arg BASE_IMAGE.
        container_user: The user from devcontainer.json's containerUser field, if any; passed as
            build arg CONTAINER_USER.

    Returns:
        A LocalDockerImage whose image_id is the ID that `docker inspect` reports for the new tag.

    Raises:
        FileNotFoundError: dockerfile_path does not exist.
        ImageConfigError: the build failed and its stderr contains "ERROR: failed to solve".
        ProviderError: any other build failure, or `docker inspect` of the new tag failed.
    """
    if not dockerfile_path.exists():
        raise FileNotFoundError(f"Dockerfile not found at {dockerfile_path}")

    if secrets is None:
        secrets = {}

    # Generate a unique tag if not provided
    if tag is None:
        tag = f"sculptor-image:{generate_id()[:8]}"

    # Build the Docker image. Version/git-hash build args let the image record which sculptor built it.
    build_command = [
        *("docker", "buildx", "build"),
        "--progress=plain",
        # type=docker exports the result as a Docker image (buildx docs: loaded into the local image store).
        "--output=type=docker,compression=uncompressed",
        *("-f", str(dockerfile_path)),
        *("-t", tag),
        *("--build-arg", f"BUILT_FROM_SCULPTOR_VERSION={version.__version__}"),
        *("--build-arg", f"BUILT_FROM_SCULPTOR_GIT_HASH={version.__git_sha__}"),
    ]
    if cached_repo_tarball_parent_directory:
        build_command.extend(("--build-context", f"imbue_user_repo={cached_repo_tarball_parent_directory}"))

    # Always expose the local sculptor SSH key directory as a named build context.
    ssh_keypair_dir = ensure_local_sculptor_ssh_configured()
    build_command.extend(("--build-context", f"ssh_keypair_dir={ssh_keypair_dir}"))

    if base_image_tag:
        build_command.extend(("--build-arg", f"BASE_IMAGE={base_image_tag}"))

    if container_user:
        build_command.extend(("--build-arg", f"CONTAINER_USER={container_user}"))

    if disable_cache:
        build_command.append("--no-cache")

    # The build context must be the final positional argument.
    build_path = build_path or dockerfile_path.parent
    build_command.append(str(build_path))

    logger.info("Building Docker image with tag {}", tag)

    # Shell-quoted rendering of the command, used only for logging and error messages.
    build_command_string = " ".join(shlex.quote(arg) for arg in build_command)
    logger.debug("Building Docker image with build_path={}:\n{}", build_path, build_command_string)

    try:
        concurrency_group.run_process_to_completion(
            command=build_command,
            on_output=lambda line, is_stderr: logger.debug(line.strip()),
            cwd=build_path,
            trace_log_context={
                "sandbox_path": str(build_path),
                "log_type": USER_FACING_LOG_TYPE,
            },
            env={**os.environ, **secrets},
        )
    except ProcessError as e:
        error_msg = f"Docker build failed - is your Docker up-to-date? Exit code {e.returncode}: {build_command_string}\nstdout=\n{e.stdout}\nstderr=\n{e.stderr}"
        if "ERROR: failed to solve" in e.stderr:
            # NOTE: this might not be the best way to distinguish between image config errors and other errors
            # but it's the best we can do for now
            raise ImageConfigError(error_msg) from e
        raise ProviderError(error_msg) from e

    # Get the image ID (unchecked so we can raise a ProviderError with docker's stderr ourselves)
    inspect_result = concurrency_group.run_process_to_completion(
        command=["docker", "inspect", "-f", "{{.Id}}", tag],
        is_checked=False,
    )

    if inspect_result.returncode != 0:
        raise ProviderError(f"Failed to inspect built image: {inspect_result.stderr}")

    docker_image_id = inspect_result.stdout.strip()

    # Wrap the raw ID in the typed DockerImageID; nothing is persisted in this function.
    full_id = DockerImageID(docker_image_id)

    logger.info("Built Docker image {} with tag {}", full_id, tag)
    return LocalDockerImage(image_id=full_id, project_id=project_id)


@log_runtime_decorator("initializeCommands")
def run_initialize_command(
    initialize_command: str | list[str] | Mapping[str, str | list[str]],
    devcontainer_path: Path,
    concurrency_group: ConcurrencyGroup,
) -> None:
    """Execute a devcontainer.json `initializeCommand` from the repository root.

    Every named command (or the single unnamed one) is started as a background process inside a
    child concurrency group named "initializeCommands". Each output line is logged at debug level,
    prefixed with the command's name.

    Args:
        initialize_command: The raw `initializeCommand` value: a shell-style string, an argv list,
            or a mapping of command name -> string/argv.
        devcontainer_path: Location of the devcontainer.json; the repository root derived from it
            becomes the working directory.
        concurrency_group: Parent group in which the child group is created.

    Raises:
        ImageConfigError: a command failed with ProcessError (wrapping exit code and output), or
            devcontainer_path is not a recognized layout.
        TypeError: a command is neither a string nor a list of strings.
    """
    commands = _preprocess_initialize_command(initialize_command)
    cwd = _repository_root_from_devcontainer_path(devcontainer_path)

    def _output_logger_for(name: str):
        # One callback per command, so each output line is tagged with its own command's name.
        def _log_output(line, is_stderr):
            logger.debug(f"#[{name:>15}]: {line.strip()}")

        return _log_output

    try:
        with concurrency_group.make_concurrency_group("initializeCommands") as child_group:
            for name, argv in commands:
                child_group.run_process_in_background(
                    command=argv,
                    on_output=_output_logger_for(name),
                    cwd=cwd,
                    trace_log_context={"sandbox_path": str(cwd), "log_type": USER_FACING_LOG_TYPE},
                    env=dict(os.environ),
                )
    except ProcessError as e:
        error_msg = f"running `initializeCommand` hooks failed? Path {cwd}, Exit code {e.returncode}: {commands}\nstdout=\n{e.stdout}\nstderr=\n{e.stderr}"
        raise ImageConfigError(error_msg) from e


def _cmd_as_list(cmd) -> list[str]:
    if isinstance(cmd, str):
        return shlex.split(cmd)
    if isinstance(cmd, list) and all(isinstance(i, str) for i in cmd):
        return cmd
    raise TypeError(f"Command {cmd} is not a list or string")


def _preprocess_initialize_command(
    initialize_command: str | list[str] | Mapping[str, str | list[str]],
) -> list[tuple[str, Sequence[str]]]:
    """Normalize a devcontainer.json `initializeCommand` into (name, argv) pairs.

    Accepted shapes (per the type annotation):
      * a string  -> one command named "initializeCommand", tokenized with shell-like rules
      * a list    -> one command named "initializeCommand", used as argv
      * a mapping -> one command per entry, named by its key, in the mapping's iteration order

    Raises:
        TypeError: (from _cmd_as_list) a command is neither a string nor a list of strings.
    """
    # Test against the abstract Mapping rather than dict. Any mapping type that satisfies the
    # annotation (e.g. a read-only types.MappingProxyType) must be treated as named commands.
    # Otherwise it falls through to _cmd_as_list and is rejected as "not a list or string".
    if isinstance(initialize_command, Mapping):
        return [
            (command_name, _cmd_as_list(command))
            for command_name, command in initialize_command.items()
        ]
    return [("initializeCommand", _cmd_as_list(initialize_command))]


def _repository_root_from_devcontainer_path(devcontainer_path: Path) -> Path:
    """
    devcontainer_path has one of these forms:.
       .devcontainer/devcontainer.json
       .devcontainer.json
       .devcontainer/<folder>/devcontainer.json (where <folder> is a sub-folder, one level deep)

    we want to find out the folder containing each of these forms.
    """
    # If the file is devcontainer.json, go up until we're above .devcontainer
    if devcontainer_path.name == "devcontainer.json":
        if devcontainer_path.parent.name == ".devcontainer":
            return devcontainer_path.parent.parent
            # Cases: <root>/.devcontainer/devcontainer.json
            # or <root>/.devcontainer/<folder>/devcontainer.json
        if devcontainer_path.parent.parent.name == ".devcontainer":
            return devcontainer_path.parent.parent.parent

    if devcontainer_path.name == ".devcontainer.json":
        # Case: <root>/.devcontainer.json
        return devcontainer_path.parent

    raise ImageConfigError(f"devcontainer.json path is invalid {devcontainer_path}")


def delete_docker_image_and_any_stopped_containers(
    image_id: str, concurrency_group: ConcurrencyGroup
) -> tuple[bool, list[DockerContainerID]]:
    """Delete a Docker image by image ID, first removing stopped containers created from it.

    Steps:
      1. List exited containers filtered by `ancestor=<image_id>` and `docker rm` each one,
         stopping at the first removal failure.
      2. `docker rmi` the image. docker refuses to remove an image used by a running container,
         so such images survive this call.

    Returns:
        (success, deleted_container_ids).
        - success is True when `docker rmi` succeeded, or when it failed but `docker inspect`
          afterwards shows the image no longer exists.
        - success is False when listing or removing containers failed, or when the image is still
          present after `docker rmi`.
        - deleted_container_ids lists the containers actually removed, even when success is False.

    Errors are logged via log_exception rather than raised. The one exception is an rmi failure
    whose stderr says the image is in use by a running container; that failure is not logged.
    """
    deleted_container_ids: list[DockerContainerID] = []
    # first delete all *stopped* docker containers that were created from this image
    try:
        container_ids = (
            concurrency_group.run_process_to_completion(
                command=["docker", "ps", "-a", "-q", "-f", "status=exited", "-f", f"ancestor={image_id}"],
            )
            .stdout.strip()
            .splitlines(keepends=False)
        )
    # TODO: probably need some better error handling here
    except ProcessError as e:
        log_exception(
            e, "Failed to list containers for {image_id}", priority=ExceptionPriority.LOW_PRIORITY, image_id=image_id
        )
        return False, deleted_container_ids

    for container_id in container_ids:
        try:
            concurrency_group.run_process_to_completion(command=["docker", "rm", container_id])
            deleted_container_ids.append(DockerContainerID(container_id))
            logger.debug("Successfully deleted stopped container {} for image {}", container_id, image_id)
        except ProcessError as e:
            # Bail out on the first failure; containers removed so far are still reported.
            log_exception(
                e,
                "Failed to delete stopped containers for image {image_id}",
                priority=ExceptionPriority.LOW_PRIORITY,
                image_id=image_id,
            )
            return False, deleted_container_ids

    try:
        # The only time we want to delete an image is when it is genuinely unused; i.e.
        # not being used by a current running container. The docker rmi command fails when
        # it is asked to delete an image used by a currently running container, while allowing
        # you to delete outdated snapshots for currently running containers.

        concurrency_group.run_process_to_completion(command=["docker", "rmi", image_id])
        logger.debug("Successfully deleted Docker image: {}", image_id)
        return True, deleted_container_ids
    except ProcessError as e:
        # rmi failed: check (unchecked, so no raise) whether the image is actually gone anyway.
        image_still_exists_against_our_wishes = concurrency_group.run_process_to_completion(
            command=["docker", "inspect", image_id], is_checked=False
        )
        if image_still_exists_against_our_wishes.returncode != 0:
            # inspect failed -> the image no longer exists, so treat the deletion as successful.
            return True, deleted_container_ids
        else:
            if "image is being used by running container" in e.stderr:
                # Expected: images backing running containers are intentionally kept; don't log.
                pass
            else:
                log_exception(e, "Failed to delete Docker image {image_id}", image_id=image_id)
            return False, deleted_container_ids
    except Exception as e:
        log_exception(e, "Error deleting Docker image {image_id}", image_id=image_id)
        return False, deleted_container_ids


def get_image_ids_with_running_containers(concurrency_group: ConcurrencyGroup) -> tuple[str, ...]:
    """Return the distinct image IDs used by currently running containers.

    Runs `docker ps --quiet`, then `docker inspect --format={{.Image}}` on the listed containers.
    Returns an empty tuple when nothing is running, or when a docker command fails while Docker's
    health check still reports OK (that failure is logged at low priority).

    Raises:
        ProviderError: a docker command failed and Docker's health check is not OK.
    """
    try:
        ps_output = concurrency_group.run_process_to_completion(command=("docker", "ps", "--quiet"))
        running_container_ids = ps_output.stdout.strip().splitlines()
        if not running_container_ids:
            return ()
        inspect_output = concurrency_group.run_process_to_completion(
            command=("docker", "inspect", "--format={{.Image}}", *running_container_ids)
        )
    except ProcessError as e:
        docker_status = get_docker_status(concurrency_group)
        if isinstance(docker_status, OkStatus):
            # Docker itself looks healthy, so treat this as a transient, low-priority failure.
            log_exception(
                e, "Error getting image IDs with running containers", priority=ExceptionPriority.LOW_PRIORITY
            )
            return ()
        logger.debug("Docker seems to be down, cannot list running containers")
        details_suffix = f" (details: {docker_status.details})" if docker_status.details else ""
        raise ProviderError(f"Provider is unavailable: {docker_status.message}{details_suffix}") from e

    stripped_lines = (raw_line.strip() for raw_line in inspect_output.stdout.strip().splitlines())
    return tuple({image_id for image_id in stripped_lines if image_id})


class DeletionTier(Enum):
    """How readily an image may be cleaned up; a higher value means more deletable.

    When an image belongs to several tasks, the lowest tier among those tasks applies.
    """

    # Used by a running container, or the latest image of a task.
    NEVER_DELETE = 0
    # Historical image of an active task, not used by any running container.
    RARELY_DELETE = 1
    # Historical image of an archived task, not used by any running container.
    SOMETIMES_DELETE = 2
    # Image belonging to a deleted task.
    ALWAYS_DELETE = 3


def _classify_image_tier(image_id: str, associated_task_metadata: TaskImageCleanupData) -> DeletionTier:
    """Classify one image from the point of view of one task that references it.

    Checks are applied in this order:
      1. the task is deleted                   -> ALWAYS_DELETE (even for its latest image)
      2. the image is the task's latest image  -> NEVER_DELETE
      3. the task is archived                  -> SOMETIMES_DELETE
      4. otherwise                             -> RARELY_DELETE
    """
    metadata = associated_task_metadata
    if metadata.is_deleted:
        tier = DeletionTier.ALWAYS_DELETE
    elif image_id == metadata.last_image_id:
        tier = DeletionTier.NEVER_DELETE
    elif metadata.is_archived:
        tier = DeletionTier.SOMETIMES_DELETE
    else:
        tier = DeletionTier.RARELY_DELETE
    return tier


def _get_task_ids_by_image_id(
    task_metadata_by_task_id: Mapping[TaskID, TaskImageCleanupData],
) -> dict[str, list[TaskID]]:
    """Invert task -> images into image -> tasks.

    Every image ID in any task's all_image_ids becomes a key. Its value lists the tasks that
    reference it, in iteration order; the list is never empty.
    """
    grouped: dict[str, list[TaskID]] = {}
    for task_id, metadata in task_metadata_by_task_id.items():
        for image_id in metadata.all_image_ids:
            if image_id not in grouped:
                grouped[image_id] = []
            grouped[image_id].append(task_id)
    return grouped


def _get_tier_by_image_id(
    task_metadata_by_task_id: Mapping[TaskID, TaskImageCleanupData],
    active_image_ids: tuple[str, ...],
) -> dict[str, DeletionTier]:
    """Assign a DeletionTier to every image referenced by any task.

    An image in active_image_ids (used by a running container) is always NEVER_DELETE. Otherwise
    each referencing task classifies the image independently, and the most protective tier (the
    one with the lowest value) wins.
    """
    running_image_ids = set(active_image_ids)
    owners_by_image_id = _get_task_ids_by_image_id(task_metadata_by_task_id)
    assigned: dict[str, DeletionTier] = {}
    for image_id, owner_task_ids in owners_by_image_id.items():
        if image_id in running_image_ids:
            logger.debug("Image {} is in active image IDs - never delete", image_id)
            assigned[image_id] = DeletionTier.NEVER_DELETE
            continue
        per_task_tiers = [
            _classify_image_tier(image_id=image_id, associated_task_metadata=task_metadata_by_task_id[task_id])
            for task_id in owner_task_ids
        ]
        # owner_task_ids is never empty: every key is created by appending a task ID to it.
        assigned[image_id] = min(per_task_tiers, key=lambda tier: tier.value)
        logger.debug("Image {} has been assigned tier {}", image_id, assigned[image_id])
    return assigned


class ImageInfo(BaseModel):
    """One image row parsed from `docker images` output."""

    repository: str
    tag: str
    id: str
    created_at: str

    @property
    def category(self) -> Literal["USER", "WRAPPED", "SNAPSHOT"]:
        """Classify the image by naming convention.

        A tag ending in USER_IMAGE_SUFFIX means "USER" (checked first). Otherwise a repository
        ending in SNAPSHOT_SUFFIX means "SNAPSHOT". Anything else is "WRAPPED".
        """
        if self.tag.endswith(USER_IMAGE_SUFFIX):
            return "USER"
        return "SNAPSHOT" if self.repository.endswith(SNAPSHOT_SUFFIX) else "WRAPPED"


def get_images_disk_usage_bytes(concurrency_group: ConcurrencyGroup) -> int | None:
    """Return the total size Docker reports for all local images, in bytes.

    Reads the "Images" row of `docker system df`. This covers every Docker image, not only
    sculptor-created ones.

    Returns:
        The parsed size, or None if there is no usable "Images" row or its size cannot be parsed.

    Raises:
        ProviderError: `docker system df` failed; the ProcessError is chained as the cause.
    """
    try:
        result = concurrency_group.run_process_to_completion(
            ["docker", "system", "df", "--format={{.Type}} {{.Size}}"],
        )
    except ProcessError as e:
        # Chain the cause so docker's exit code and stderr stay visible in tracebacks and logs.
        raise ProviderError(f"Failed to run docker system df: {e.stderr}") from e

    for line in result.stdout.splitlines():
        fields = line.split()
        # Expected row shape: "Images <size>". Skip other types and rows missing a size field
        # (a bare "Images" row would otherwise raise IndexError).
        if len(fields) < 2 or fields[0] != "Images":
            continue
        try:
            return humanfriendly.parse_size(fields[1])
        except InvalidSize:
            return None

    return None


def get_current_sculptor_images_info(
    concurrency_group: ConcurrencyGroup, settings: SculptorSettings
) -> tuple[ImageInfo, ...]:
    """List local Docker images whose reference starts with this sculptor's image prefix.

    The prefix is settings.TESTING.CONTAINER_PREFIX when set, otherwise
    get_standard_environment_prefix().

    Returns:
        One ImageInfo per non-blank row of `docker images` output, in Docker's output order.

    Raises:
        ProviderError: `docker images` failed and Docker's health check is not OK.
        ProcessError: `docker images` failed while Docker reports healthy (re-raised unchanged).
    """
    # TODO: when we implement garbage collection style image cleanup logic, we want to only get images with a prefix
    #  indicating that it was generated by this instance of sculptor (as opposed to the other sculptor in SoS).
    #  There should be a system for generating sculptor-instance-specific IDs and adding those to image repository names
    #  We also need to include control planes eventually.
    if settings.TESTING.CONTAINER_PREFIX is not None:
        sculptor_image_prefix = settings.TESTING.CONTAINER_PREFIX
    else:
        sculptor_image_prefix = get_standard_environment_prefix()
    try:
        # NOTE(review): `--quiet` normally prints IDs only; parsing below relies on the custom
        # --format taking precedence -- confirm on the Docker versions we support.
        result = concurrency_group.run_process_to_completion(
            command=(
                "docker",
                "images",
                "--quiet",
                "--no-trunc",
                "--filter",
                f"reference={sculptor_image_prefix}*",
                "--format={{.Repository}} {{.Tag}} {{.ID}} {{.CreatedAt}}",
            )
        )
    except ProcessError as e:
        health_status = get_docker_status(concurrency_group)
        if not isinstance(health_status, OkStatus):
            logger.debug("Docker seems to be down, cannot list images")
            details_msg = f" (details: {health_status.details})" if health_status.details else ""
            raise ProviderError(f"Provider is unavailable: {health_status.message}{details_msg}") from e
        else:
            raise

    image_infos: list[ImageInfo] = []
    for raw_line in result.stdout.splitlines():
        line = raw_line.strip()
        # Skip blank lines *before* unpacking: splitting an empty line yields no fields and the
        # 4-way unpack below would raise ValueError.
        if not line:
            continue
        # maxsplit=3 keeps the trailing CreatedAt field (which may contain spaces) intact.
        repository, tag, image_id, created_at = line.split(maxsplit=3)
        image_infos.append(
            ImageInfo(
                repository=repository,
                tag=tag,
                id=image_id,
                created_at=created_at,
            )
        )
    return tuple(image_infos)


def extend_image_ids_with_similar_hashes(image_ids: Sequence[str]) -> tuple[str, ...]:
    """Return image_ids plus the part after the first ":" of each ID that contains one, de-duplicated.

    For example "sha256:abc" also contributes "abc", so callers can match either spelling.

    The order is deterministic: the input IDs first (first occurrence wins), then the stripped
    hashes. The previous set-based construction produced a different tuple order from run to run,
    because str hashing is randomized per process.
    """
    stripped_hashes = (image_id.split(":", 1)[1] for image_id in image_ids if ":" in image_id)
    # dict.fromkeys de-duplicates while preserving insertion order.
    return tuple(dict.fromkeys((*image_ids, *stripped_hashes)))


class ImageInfoPayload(PosthogEventPayload):
    """Telemetry payload summarizing Docker image usage.

    Every field is declared via with_consent at ConsentLevel.PRODUCT_ANALYTICS.
    """

    # Number of images whose ImageInfo.category is "SNAPSHOT".
    snapshot_count: int = with_consent(
        ConsentLevel.PRODUCT_ANALYTICS, description="Number of sculptor-created snapshot images"
    )
    # Sum of per-snapshot sizes, in bytes.
    total_snapshot_bytes: int = with_consent(
        ConsentLevel.PRODUCT_ANALYTICS, description="Space used by sculptor-created snapshot images"
    )
    # TODO: add dangling_image_count, which uses garbage collection logic to find images not associated with active tasks
    # Optional because the size reported by docker may be missing or unparseable.
    total_image_bytes: int | None = with_consent(
        ConsentLevel.PRODUCT_ANALYTICS,
        description="Space used by all docker images, not just sculptor-created images",
    )


def record_images_to_posthog(concurrency_group: ConcurrencyGroup, image_infos: Sequence[ImageInfo]) -> None:
    """Emit an IMAGE_INFORMATION telemetry event summarizing sculptor image usage.

    Snapshot sizes and total Docker image disk usage are gathered concurrently on a thread pool
    (up to 16 workers). Any exception raised while gathering them propagates to the caller.
    """
    snapshots = tuple(info for info in image_infos if info.category == "SNAPSHOT")

    def _snapshot_size(info):
        return get_unique_snapshot_size_bytes(concurrency_group, info.id)

    with ObservableThreadPoolExecutor(
        concurrency_group, max_workers=16, thread_name_prefix="ImageInspector"
    ) as executor:
        # Queue all snapshot measurements first, then the whole-disk query, before blocking on any result.
        size_results = executor.map(_snapshot_size, snapshots)
        disk_usage_future = executor.submit(get_images_disk_usage_bytes, concurrency_group)

        snapshot_total = sum(size_results)
        payload = ImageInfoPayload(
            total_snapshot_bytes=snapshot_total,
            snapshot_count=len(snapshots),
            total_image_bytes=disk_usage_future.result(),
        )
        event = PosthogEventModel(
            name=SculptorPosthogEvent.IMAGE_INFORMATION,
            component=ProductComponent.CROSS_COMPONENT,
            payload=payload,
        )
        emit_posthog_event(event)


def calculate_image_ids_to_delete(
    task_metadata_by_task_id: Mapping[TaskID, TaskImageCleanupData],
    active_image_ids: tuple[str, ...],
    existing_image_ids: tuple[str, ...],
    minimum_deletion_tier: DeletionTier,
) -> tuple[str, ...]:
    """Pick the task images that should be deleted.

    An image qualifies when its tier value is strictly greater than minimum_deletion_tier's value
    and it is still present in existing_image_ids. The result is de-duplicated, in no particular
    order.
    """
    tier_by_image_id = _get_tier_by_image_id(task_metadata_by_task_id, active_image_ids)
    still_present = set(existing_image_ids)
    threshold = minimum_deletion_tier.value
    to_delete: set[str] = set()
    for image_id, tier in tier_by_image_id.items():
        if tier.value <= threshold or image_id not in still_present:
            continue
        logger.debug("Adding image {} to deletion list", image_id)
        to_delete.add(image_id)
    return tuple(to_delete)


def get_platform_architecture() -> str:
    """
    Determine the platform architecture for Docker images.

    The architecture reported by `docker info` is preferred. It falls back to platform.machine()
    when docker fails, cannot be launched, or reports an empty value. Matching is
    case-insensitive.

    Returns:
        Platform name ("amd64" or "arm64"); unrecognized architectures default to "amd64".

    Examples:
        >>> get_platform_architecture() in ["amd64", "arm64"]
        True
    """
    # TODO: Add unit test that exercises the docker info command path when docker is available
    # NOTE(bowei): use the docker info, in case somehow it's different from the host
    # Fall back to platform.machine() if docker is not available
    arch = platform.machine()
    try:
        arch = run_blocking(["docker", "info", "--format", "{{.Architecture}}"]).stdout.strip() or arch
    except (ProcessError, OSError):
        # ProcessError: the docker command failed.
        # OSError: presumably raised when the `docker` executable cannot be launched (e.g. not on
        # PATH) -- depends on run_blocking's implementation; TODO confirm.
        pass

    # Normalize case for both sources (e.g. Windows reports "AMD64"/"ARM64").
    arch = arch.lower()
    if arch in ("x86_64", "amd64"):
        return "amd64"
    if arch in ("aarch64", "arm64"):
        return "arm64"
    logger.info("Unknown architecture {}, defaulting to amd64", arch)
    return "amd64"
