labeler: pass Github Action event data to script

This way, the script can take action based on specific events when that
data is available.
This commit is contained in:
Maxwell G
2023-07-31 01:46:31 +00:00
parent 47a86baeb4
commit 1f1252d356
2 changed files with 41 additions and 3 deletions

View File

@@ -4,11 +4,13 @@
from __future__ import annotations
import dataclasses
import json
import os
from collections.abc import Collection
from contextlib import suppress
from functools import cached_property
from pathlib import Path
from typing import Union
from typing import Any, Union
import github
import github.Auth
@@ -45,16 +47,30 @@ def get_repo(authed: bool = True) -> tuple[github.Github, github.Repository.Repo
return gclient, repo
def get_event_info() -> dict[str, Any]:
event_json = os.environ.get("event_json")
if not event_json:
return {}
with suppress(json.JSONDecodeError):
return json.loads(event_json)
return {}
@dataclasses.dataclass()
class LabelerCtx:
client: github.Github
repo: github.Repository.Repository
dry_run: bool
event_info: dict[str, Any]
@property
def member(self) -> IssueOrPr:
raise NotImplementedError
@property
def event_member(self) -> dict[str, Any]:
raise NotImplementedError
@cached_property
def previously_labeled(self) -> frozenset[str]:
labels: set[str] = set()
@@ -78,6 +94,10 @@ class IssueLabelerCtx(LabelerCtx):
def member(self) -> IssueOrPr:
return self.issue
@property
def event_member(self) -> dict[str, Any]:
return self.event_info.get("issue", {})
@dataclasses.dataclass()
class PRLabelerCtx(LabelerCtx):
@@ -87,6 +107,10 @@ class PRLabelerCtx(LabelerCtx):
def member(self) -> IssueOrPr:
return self.pr
@property
def event_member(self) -> dict[str, Any]:
return self.event_info.get("pull_request", {})
def create_comment(ctx: IssueOrPrCtx, body: str) -> None:
if ctx.dry_run:
@@ -167,7 +191,13 @@ def process_pr(
authed = True
gclient, repo = get_repo(authed=authed)
pr = repo.get_pull(pr_number)
ctx = PRLabelerCtx(client=gclient, repo=repo, pr=pr, dry_run=dry_run)
ctx = PRLabelerCtx(
client=gclient,
repo=repo,
pr=pr,
dry_run=dry_run,
event_info=get_event_info(),
)
if pr.state != "open":
log(ctx, "Refusing to process closed ticket")
return
@@ -187,7 +217,13 @@ def process_issue(
authed = True
gclient, repo = get_repo(authed=authed)
issue = repo.get_issue(issue_number)
ctx = IssueLabelerCtx(client=gclient, repo=repo, issue=issue, dry_run=dry_run)
ctx = IssueLabelerCtx(
client=gclient,
repo=repo,
issue=issue,
dry_run=dry_run,
event_info=get_event_info(),
)
if issue.state != "open":
log(ctx, "Refusing to process closed ticket")
return