import json
from queue import Queue

from loguru import logger

from sculptor.agents.default.claude_code_sdk.artifact_creation import get_file_artifact_messages
from sculptor.agents.default.claude_code_sdk.utils import get_state_file_contents
from sculptor.agents.default.constants import TOKEN_AND_COST_STATE_FILE
from sculptor.database.models import TaskID
from sculptor.interfaces.agents.agent import Message
from sculptor.interfaces.agents.agent import UpdatedArtifactAgentMessage
from sculptor.interfaces.agents.agent import WarningAgentMessage
from sculptor.interfaces.agents.artifacts import ArtifactType
from sculptor.interfaces.environments.base import Environment


def stream_token_and_cost_info(
    environment: Environment,
    source_branch: str,
    output_message_queue: Queue[Message],
    task_id: TaskID,
) -> None:
    """Put the messages for the USAGE artifact onto the output queue.

    Args:
        environment: Passed through to ``get_file_artifact_messages``.
        source_branch: Passed through to ``get_file_artifact_messages``.
        output_message_queue: Queue that receives every non-``None`` message.
        task_id: Passed through to ``get_file_artifact_messages``.
    """
    # Collect every message up front so nothing is enqueued unless the whole
    # lookup succeeded.
    pending: list[UpdatedArtifactAgentMessage | WarningAgentMessage] = list(
        get_file_artifact_messages(
            artifact_name=ArtifactType.USAGE,
            environment=environment,
            source_branch=source_branch,
            task_id=task_id,
        )
    )

    for message in pending:
        if message is None:
            # Defensive: skip empty slots rather than enqueueing them.
            continue
        output_message_queue.put(message)

    logger.debug("Stream ended")


def _sum_usage_tokens_from_last_entry(jsonl_content: str | bytes) -> int:
    """Return the summed usage token fields of the last non-blank JSONL line.

    Returns 0 when there is no non-blank line, or when the last entry has no
    ``message.usage`` object. Usage fields that are missing or not integers
    count as 0.

    Raises:
        json.JSONDecodeError: If the last non-blank line is not valid JSON.
    """
    lines = [line for line in jsonl_content.splitlines() if line.strip()]
    if not lines:
        # Empty transcript: nothing to count (indexing [-1] would raise IndexError).
        return 0

    entry = json.loads(lines[-1])
    message = entry.get("message") if isinstance(entry, dict) else None
    usage = message.get("usage") if isinstance(message, dict) else None
    if not isinstance(usage, dict):
        return 0

    total = 0
    for field in (
        "input_tokens",
        "output_tokens",
        "cache_creation_input_tokens",
        "cache_read_input_tokens",
    ):
        value = usage.get(field)
        # bool is a subclass of int; exclude it so True/False are not counted.
        if isinstance(value, int) and not isinstance(value, bool):
            total += value
    return total


def update_token_and_cost_state(
    environment: Environment,
    source_branch: str,
    output_message_queue: Queue[Message],
    session_id: str,
    cost_usd: float,
    task_id: TaskID,
) -> None:
    """Update the token count and cumulative cost, persisting both to the state file.

    The cost is accumulated: ``cost_usd`` is added to the ``cost_usd`` value
    already stored in the state file, if any. The token count is NOT
    accumulated across calls; it is the sum of the usage fields on the last
    entry of ``<session_id>.jsonl``. Whenever an input cannot be read or
    parsed, the corresponding value falls back to zero (for the cost: to just
    ``cost_usd``) and a warning is logged, so the state file is still written
    and the usage artifact is still streamed.

    Args:
        environment: Used to read the session transcript and write the state file.
        source_branch: Forwarded to ``stream_token_and_cost_info``.
        output_message_queue: Forwarded to ``stream_token_and_cost_info``.
        session_id: Stem of the ``.jsonl`` file under ``environment.get_claude_jsonl_path()``.
        cost_usd: Cost to add to the previously persisted cost.
        task_id: Forwarded to ``stream_token_and_cost_info``.
    """
    cumulative_tokens = 0
    cumulative_cost_usd = cost_usd

    token_state_content = get_state_file_contents(environment, TOKEN_AND_COST_STATE_FILE)
    if token_state_content:
        try:
            previous_state = json.loads(token_state_content)
        except json.JSONDecodeError:
            logger.warning("Failed to parse token state file, resetting to zero")
        else:
            # Valid JSON is not necessarily an object with a numeric cost; guard
            # both so a malformed state file cannot abort the update.
            previous_cost = previous_state.get("cost_usd", 0.0) if isinstance(previous_state, dict) else None
            if isinstance(previous_cost, (int, float)) and not isinstance(previous_cost, bool):
                cumulative_cost_usd += previous_cost
            else:
                logger.warning("Token state file has no usable cost_usd, resetting to zero")

    try:
        session_path = session_id + ".jsonl"
        jsonl_content = environment.read_file(str(environment.get_claude_jsonl_path() / session_path))
        cumulative_tokens = _sum_usage_tokens_from_last_entry(jsonl_content)
    except FileNotFoundError:
        logger.warning("Failed to read claude jsonl file, resetting to zero")
    except json.JSONDecodeError:
        logger.warning("Failed to parse claude jsonl file, resetting to zero")

    new_state = {"tokens": cumulative_tokens, "cost_usd": cumulative_cost_usd}

    environment.write_file(str(environment.get_state_path() / TOKEN_AND_COST_STATE_FILE), json.dumps(new_state))
    logger.info("Updated token state: {} tokens, ${:.4f}", cumulative_tokens, cumulative_cost_usd)
    stream_token_and_cost_info(
        environment=environment,
        source_branch=source_branch,
        output_message_queue=output_message_queue,
        task_id=task_id,
    )
