"""add project id to images

Revision ID: 748938f056cb
Revises: 3ca4c45c9b2c
Create Date: 2025-09-21 12:52:33.275349

"""

import json
from typing import Callable, Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
# `revision` is this migration's id; `down_revision` is the parent revision it applies on top of
# (the same id as the "Revises:" line in the module docstring).
revision: str = "748938f056cb"
down_revision: str | None = "3ca4c45c9b2c"
# No branch labels and no cross-branch dependencies for this revision.
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def _tag_task_states_with_project_id(conn: "sa.engine.Connection") -> dict[object, str]:
    """Copy each agent task's ``project_id`` column into ``current_state["image"]``.

    Rewrites every ``AgentTaskStateV1`` row in both the ``task`` and ``task_latest`` tables.

    Returns:
        Mapping of task id (the ``object_id`` column) to its stringified project id,
        accumulated across *both* tables so later steps can look up any task.
    """
    project_id_by_task_id: dict[object, str] = {}
    for table_name in ("task", "task_latest"):
        # table_name comes from the fixed tuple above, so interpolating it into SQL is safe.
        rows = conn.execute(
            sa.text(f"""
                SELECT rowid, object_id, project_id, current_state
                FROM {table_name}
                WHERE json_extract(current_state, '$.object_type') = 'AgentTaskStateV1'
            """)
        ).fetchall()

        updates = []
        for row_id, task_id, project_id, current_state_json in rows:
            # NOTE(review): a NULL project_id column would become the string "None" here;
            # presumably the column is NOT NULL for agent tasks - confirm.
            project_id_str = str(project_id)
            current_state = json.loads(current_state_json)
            current_state["image"]["project_id"] = project_id_str
            project_id_by_task_id[task_id] = project_id_str
            updates.append({"rowid": row_id, "current_state": json.dumps(current_state)})

        if updates:
            conn.execute(
                sa.text(f"UPDATE {table_name} SET current_state = :current_state WHERE rowid = :rowid"),
                updates,
            )
    return project_id_by_task_id


def _rewrite_runner_messages(
    conn: "sa.engine.Connection",
    object_type: str,
    project_id_by_task_id: dict[object, str],
    rewrite: Callable[[dict, str], None],
) -> None:
    """Apply ``rewrite`` in place to every ``saved_agent_message`` row of ``object_type``.

    Args:
        conn: Connection the migration runs on.
        object_type: Value of the message JSON's top-level ``object_type`` to match.
        project_id_by_task_id: Task id -> project id, as built from the task tables.
        rewrite: Called as ``rewrite(message_dict, project_id)``; mutates the dict in place.

    Raises:
        RuntimeError: If a matching message belongs to a task with no ``AgentTaskStateV1``
            row, so its project id is unknown.
    """
    # Filter in SQL rather than loading every message of every task and checking in Python.
    # Rows with a NULL task_id are excluded (the previous per-task lookup never matched them).
    rows = conn.execute(
        sa.text("""
            SELECT snapshot_id, task_id, message
            FROM saved_agent_message
            WHERE json_extract(message, '$.object_type') = :object_type
              AND task_id IS NOT NULL
        """),
        {"object_type": object_type},
    ).fetchall()

    updates = []
    for snapshot_id, task_id, message_json in rows:
        project_id = project_id_by_task_id.get(task_id)
        if project_id is None:
            raise RuntimeError(
                f"No AgentTaskStateV1 row found for task {task_id!r}; cannot set project_id on its {object_type}"
            )
        message = json.loads(message_json)
        rewrite(message, project_id)
        updates.append({"snapshot_id": snapshot_id, "message": json.dumps(message)})

    if updates:
        conn.execute(
            sa.text("UPDATE saved_agent_message SET message = :message WHERE snapshot_id = :snapshot_id"),
            updates,
        )


def _tag_snapshot_message(message: dict, project_id: str) -> None:
    """Set ``image.project_id`` and ensure ``image.image_id`` carries the ``sha256:`` prefix."""
    image = message["image"]
    image["project_id"] = project_id
    # Guard so an id that already has the prefix is not turned into "sha256:sha256:...".
    if not image["image_id"].startswith("sha256:"):
        image["image_id"] = "sha256:" + image["image_id"]


def _tag_environment_message(message: dict, project_id: str) -> None:
    """Set ``environment.project_id`` on an environment-created message."""
    message["environment"]["project_id"] = project_id


def upgrade() -> None:
    """Upgrade schema.

    Data-only migration (no DDL):

    1. Copies each agent task's ``project_id`` column into ``current_state["image"]``
       in both ``task`` and ``task_latest``.
    2. For ``AgentSnapshotRunnerMessage`` rows in ``saved_agent_message``, sets
       ``image.project_id`` and prefixes ``image.image_id`` with ``sha256:``.
    3. For ``EnvironmentCreatedRunnerMessage`` rows, sets ``environment.project_id``.
    """
    conn = op.get_bind()
    project_id_by_task_id = _tag_task_states_with_project_id(conn)
    _rewrite_runner_messages(conn, "AgentSnapshotRunnerMessage", project_id_by_task_id, _tag_snapshot_message)
    _rewrite_runner_messages(conn, "EnvironmentCreatedRunnerMessage", project_id_by_task_id, _tag_environment_message)


def downgrade() -> None:
    """Downgrade schema.

    This is a no-op. The JSON rewrites made by ``upgrade`` are not reverted: the added
    ``project_id`` keys and the ``sha256:`` image-id prefix are left in place. Rolling
    back past this revision therefore leaves task state and saved messages in the
    upgraded shape.
    """
