Added comments for readability
This commit is contained in:
parent
f1987d2788
commit
285c617871
37
Heimdallr.py
37
Heimdallr.py
|
@ -121,18 +121,47 @@ async def bot_info_command(ctx: InteractionContext):
|
|||
),
|
||||
)
|
||||
|
||||
def set_loglevel(level: str):
|
||||
loglevel = logging.WARNING
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
match level.lower():
|
||||
case "d", "debug":
|
||||
loglevel = logging.DEBUG
|
||||
case "i", "info", "information":
|
||||
loglevel = logging.INFO
|
||||
|
||||
case "w", "warn", "warning":
|
||||
loglevel = logging.WARNING
|
||||
|
||||
case "e", "error":
|
||||
loglevel = logging.ERROR
|
||||
|
||||
case "c", "critical":
|
||||
loglevel = logging.CRITICAL
|
||||
|
||||
case _:
|
||||
loglevel = logging.WARNING
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
|
||||
|
||||
def main():
|
||||
load_dotenv() # Load environment variables from .env file
|
||||
set_loglevel(getenv("HEIMDALLR_LOGLEVEL"))
|
||||
|
||||
# Create basic tables
|
||||
GuildSettingsModel.create_table()
|
||||
JoinLeave.create_table()
|
||||
|
||||
load_dotenv()
|
||||
# Load extensions
|
||||
bot.load_extension("commands.admin")
|
||||
bot.load_extension("commands.gatekeep")
|
||||
bot.load_extension("commands.gatekeep")
|
||||
bot.load_extension("commands.quote")
|
||||
bot.load_extension("commands.infractions")
|
||||
bot.load_extension("commands.self_roles")
|
||||
bot.load_extension("commands.polls")
|
||||
bot.load_extension("commands.bot_messages")
|
||||
bot.start(getenv("DISCORD_TOKEN"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -29,10 +29,13 @@ class Admin(Extension):
|
|||
default_member_permissions=Permissions.MANAGE_GUILD,
|
||||
)
|
||||
async def adm_list(self, ctx: InteractionContext) -> None:
|
||||
# Get the guild's settings from the database.
|
||||
guild_settings: Optional[GuildSettingsModel]
|
||||
guild_settings, _ = GuildSettingsModel.get_or_create(guild_id=ctx.guild_id)
|
||||
joinleave_settings: Optional[GuildSettingsModel]
|
||||
joinleave_settings, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild_id)
|
||||
|
||||
# Create an embed to display settings.
|
||||
embed = Embed(
|
||||
title="Settings for {}".format(ctx.guild.name),
|
||||
fields=[
|
||||
|
|
|
@ -19,7 +19,7 @@ from naff import (
|
|||
|
||||
from database import BotMessages as BotMessagesModel
|
||||
|
||||
|
||||
# Template modal for creating/editing bot messages.
|
||||
message_creation_modal = Modal(
|
||||
custom_id="bot-message-create",
|
||||
title=f"Create a message as the bot",
|
||||
|
@ -44,6 +44,7 @@ class BotMessages(Extension):
|
|||
def __init__(self, client: Client) -> None:
|
||||
self.client = client
|
||||
|
||||
# Create a new bot message.
|
||||
@slash_command(
|
||||
name="bot-message-create",
|
||||
description="Create a message as the bot.",
|
||||
|
@ -52,13 +53,19 @@ class BotMessages(Extension):
|
|||
)
|
||||
async def bot_message_create_command(self, ctx: InteractionContext):
|
||||
|
||||
|
||||
# Respond with the template modal. No values have been set in the modal, as it is
|
||||
# a new message.
|
||||
await ctx.send_modal(message_creation_modal)
|
||||
|
||||
# Wait for the user to submit the modal, and ensure that we are receiving the
|
||||
# correct modal.
|
||||
modal_ctx: ModalContext = await self.client.wait_for_modal(message_creation_modal, author=ctx.author)
|
||||
if modal_ctx.custom_id != "bot-message-create":
|
||||
return
|
||||
|
||||
# Retrieve the values from the modal.
|
||||
# Ensure that at least one of the fields is filled in, as a message cannot be
|
||||
# empty.
|
||||
embeds_string: str = modal_ctx.responses["embeds"]
|
||||
content_string: str = modal_ctx.responses["content"]
|
||||
if (
|
||||
|
@ -71,6 +78,10 @@ class BotMessages(Extension):
|
|||
)
|
||||
return
|
||||
|
||||
# Attempt to parse the embed(s) JSON into a Python object.
|
||||
# Try loading it as a single or multiple embeds depending on if the object
|
||||
# returned is a list or a dictionary.
|
||||
# If the JSON is invalid, return an error.
|
||||
embed: dict | None = None
|
||||
embeds: list | None = None
|
||||
|
||||
|
@ -88,21 +99,27 @@ class BotMessages(Extension):
|
|||
)
|
||||
return
|
||||
|
||||
# Send the bot message in the channel.
|
||||
msg = await ctx.channel.send(
|
||||
content=content_string if content_string else None,
|
||||
embed=embed,
|
||||
embeds=embeds,
|
||||
)
|
||||
|
||||
# Add an entry of the message in the database.
|
||||
BotMessagesModel.create(
|
||||
guild_id=msg.guild.id,
|
||||
channel_id=msg.channel.id,
|
||||
message_id=msg.id,
|
||||
)
|
||||
|
||||
# Send a confirmation message, as Discord requires us to respond to the interaction.
|
||||
await modal_ctx.send(
|
||||
"Message created!",
|
||||
ephemeral=True,
|
||||
)
|
||||
|
||||
# A context menu to allow moderators to edit a bot message.
|
||||
@context_menu(
|
||||
name="Edit bot message",
|
||||
context_type=CommandTypes.MESSAGE,
|
||||
|
@ -111,6 +128,9 @@ class BotMessages(Extension):
|
|||
)
|
||||
async def edit_bot_message_context_menu(self, ctx: InteractionContext):
|
||||
message: Message = ctx.target
|
||||
|
||||
# Ensure that the target message is from the bot.
|
||||
# If it is not, return an error.
|
||||
if message.author.id != self.client.user.id:
|
||||
await ctx.send(
|
||||
"This is not a bot message.",
|
||||
|
@ -118,30 +138,38 @@ class BotMessages(Extension):
|
|||
)
|
||||
return
|
||||
|
||||
bot_message: BotMessagesModel | None = BotMessagesModel.get(
|
||||
# Retrieve the bot message from the database, if any.
|
||||
bot_message: BotMessagesModel | None = BotMessagesModel.get_or_none(
|
||||
BotMessagesModel.guild_id == message.channel.guild.id,
|
||||
BotMessagesModel.channel_id == message.channel.id,
|
||||
BotMessagesModel.message_id == message.id,
|
||||
)
|
||||
|
||||
# If there is no bot message, return an error.
|
||||
if bot_message is None:
|
||||
await ctx.send(
|
||||
"This is not an editable bot message.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
# Create a copy of the template modal, and insert the contents of the bot message.
|
||||
modal = deepcopy(message_creation_modal)
|
||||
modal.title = "Edit bot message"
|
||||
modal.custom_id = "bot-message-edit"
|
||||
modal.components[0].value = json.dumps(
|
||||
[e.to_dict() for e in message.embeds] if message.embeds else "",
|
||||
indent=4,
|
||||
)
|
||||
modal.components[1].value = message.content
|
||||
|
||||
# Send the modal to the user
|
||||
await ctx.send_modal(modal)
|
||||
|
||||
# Wait for the user to submit the modal, and ensure that we are receiving the
|
||||
# correct modal.
|
||||
modal_ctx: ModalContext = await self.client.wait_for_modal(modal, author=ctx.author)
|
||||
if modal_ctx.custom_id != "bot-message-create":
|
||||
if modal_ctx.custom_id != "bot-message-edit":
|
||||
return
|
||||
|
||||
embeds_string: str = modal_ctx.responses["embeds"]
|
||||
|
|
|
@ -128,8 +128,10 @@ class Gatekeep(Extension):
|
|||
jl, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild.id)
|
||||
await user.add_role(int(gk.gatekeep_approve_role))
|
||||
|
||||
# Check if a welcome channel is set
|
||||
welcome_channel = not jl.message_channel is None
|
||||
|
||||
# If there is no approval message set, inform the issuer privately.
|
||||
if gk.gatekeep_approve_message is None:
|
||||
await ctx.send(
|
||||
f"{user.mention} has been approved.\nNB: No approval message has been sent.",
|
||||
|
@ -137,6 +139,7 @@ class Gatekeep(Extension):
|
|||
)
|
||||
return
|
||||
|
||||
# If there is no welcome channel set, attempt to DM the approval message to the user.
|
||||
if not welcome_channel:
|
||||
await ctx.send(
|
||||
f"{user.mention} has been approved.\nNB: No welcome channel has been set – attempting to DM {user.mention}",
|
||||
|
@ -147,6 +150,7 @@ class Gatekeep(Extension):
|
|||
)
|
||||
return
|
||||
|
||||
# DM the user if the bot fails to retrieve the welcome channel.
|
||||
channel = await ctx.guild.fetch_channel(jl.message_channel)
|
||||
if not channel:
|
||||
await ctx.send(
|
||||
|
@ -158,6 +162,7 @@ class Gatekeep(Extension):
|
|||
)
|
||||
return
|
||||
|
||||
# If none of the above occur, finally send the approval message to the welcome channel.
|
||||
await channel.send(
|
||||
str(gk.gatekeep_approve_message).format(member=user, guild=ctx.guild)
|
||||
)
|
||||
|
@ -212,6 +217,11 @@ class Gatekeep(Extension):
|
|||
)
|
||||
await ctx.send(f"{user.mention} has been approved.", ephemeral=True)
|
||||
|
||||
# Allow the use of a reaction to approve a user.
|
||||
# This is mainly for compatibility with NLL.
|
||||
# The permission thingy should probably be reworked, as it currently allows anyone
|
||||
# with the manage roles permission to use this.
|
||||
# TODO: Rewrite this to require a specific role.
|
||||
@listen(events.MessageReactionAdd)
|
||||
async def on_reaction_add(self, reaction: events.MessageReactionAdd):
|
||||
if not reaction.emoji.name in [
|
||||
|
|
|
@ -211,18 +211,18 @@ class Infractions(Extension):
|
|||
|
||||
# TODO: Add this in again when GuildSettings is implemented
|
||||
|
||||
# guild_settings: Optional[GuildSettings] = GuildSettings.get_or_none(GuildSettings.guild_id == int(ctx.guild_id))
|
||||
# if guild_settings is not None:
|
||||
# if guild_settings.admin_channel is not None:
|
||||
# admin_channel = self.client.fetch_channel(int(guild_settings.admin_channel))
|
||||
# if admin_channel is not None:
|
||||
# await admin_channel.send(embed=Embed(
|
||||
# title=f"Warned {user.display_name} ({user.username}#{user.discriminator}, {user.id})",
|
||||
# description=f"{reason}",
|
||||
# color=infraction_colour(0x0000FF),
|
||||
# fields=[
|
||||
# ],
|
||||
# ))
|
||||
guild_settings: Optional[GuildSettings] = GuildSettings.get_or_none(GuildSettings.guild_id == int(ctx.guild_id))
|
||||
if guild_settings is not None:
|
||||
if guild_settings.admin_channel is not None:
|
||||
admin_channel = self.client.fetch_channel(int(guild_settings.admin_channel))
|
||||
if admin_channel is not None:
|
||||
await admin_channel.send(embed=Embed(
|
||||
title=f"Warned {user.display_name} ({user.username}#{user.discriminator}, {user.id})",
|
||||
description=f"{reason}",
|
||||
color=infraction_colour(0x0000FF),
|
||||
fields=[
|
||||
],
|
||||
))
|
||||
|
||||
if not silent and warning_msg is None:
|
||||
await ctx.send(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
import json
|
||||
from time import sleep
|
||||
from typing import Dict, List, Tuple
|
||||
from naff import (
|
||||
Client,
|
||||
|
@ -34,12 +35,24 @@ from peewee import fn
|
|||
PollOptions = List[Tuple[str | None, str]]
|
||||
|
||||
|
||||
def datetime_to_discord_time(dt: datetime) -> str:
|
||||
def datetime_to_discord_relative_time(dt: datetime) -> str:
|
||||
"""Create a Discord relative time text from a datetime object."""
|
||||
t = dt.strftime("%s")
|
||||
return f"<t:{int(t)}:R>"
|
||||
|
||||
|
||||
def generate_bar(num: int, total: int, length: int = 10) -> str:
|
||||
"""Create a bar graph from a number and a total.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num : int
|
||||
The current amount.
|
||||
total : int
|
||||
The total amount.
|
||||
length : int
|
||||
The amount of characters to use for the bar.
|
||||
"""
|
||||
full_char = "\u2593"
|
||||
empty_char = "\u2591"
|
||||
|
||||
|
@ -63,11 +76,27 @@ def generate_poll_embed(
|
|||
multiple_choice: bool = False,
|
||||
expires: datetime = None,
|
||||
) -> Embed:
|
||||
"""Create a poll embed from a title, options and votes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
title : str
|
||||
The title of the poll.
|
||||
options : PollOptions
|
||||
The options of the poll.
|
||||
votes : List[int]
|
||||
The votes of the poll, in the same order as the options.
|
||||
multiple_choice : bool, optional
|
||||
Whether the poll is multiple choice. Defaults to False.
|
||||
expires : datetime, optional
|
||||
The time at which the poll expires. Defaults to None.
|
||||
"""
|
||||
data = []
|
||||
# \u2022 is a bullet point
|
||||
if multiple_choice:
|
||||
data.append("\u2022 Multiple choice")
|
||||
if expires:
|
||||
data.append(f"\u2022 Expiry: {datetime_to_discord_time(expires)}")
|
||||
data.append(f"\u2022 Expiry: {datetime_to_discord_relative_time(expires)}")
|
||||
|
||||
embed = Embed(
|
||||
title=title,
|
||||
|
@ -75,6 +104,7 @@ def generate_poll_embed(
|
|||
)
|
||||
sum_votes = sum(votes)
|
||||
|
||||
# Add a field for each option, with the vote count and a bar graph showing the percentage.
|
||||
for i, (emoji, option) in enumerate(options):
|
||||
embed.add_field(
|
||||
name=f"**{emoji if emoji else num_to_emoji(i+1)} {option}**",
|
||||
|
@ -178,22 +208,31 @@ class Polls(Extension):
|
|||
],
|
||||
)
|
||||
|
||||
# Display the modal and wait for the user to submit.
|
||||
await ctx.send_modal(modal)
|
||||
modal_ctx: ModalContext = await self.client.wait_for_modal(
|
||||
modal=modal, author=ctx.author
|
||||
)
|
||||
|
||||
# If the user set a duration for the poll, create a datetime for the time in the
|
||||
# future when the poll expires.
|
||||
duration: datetime | None = (
|
||||
(datetime.now() + timedelta(minutes=duration)) if duration else None
|
||||
)
|
||||
|
||||
title = modal_ctx.responses["title"]
|
||||
options: PollOptions = []
|
||||
|
||||
# Get each option from the poll options.
|
||||
# We're stripping the first occurance of a dash, as it otherwise would be included.
|
||||
for i, option in enumerate(
|
||||
modal_ctx.responses["options"].replace("-", "", 1).split("\n-")
|
||||
):
|
||||
# If the option is empty, skip it.
|
||||
if option == "":
|
||||
continue
|
||||
|
||||
# Check if the option contains an emoji.
|
||||
parts = option.split(":", 1)
|
||||
if len(parts) == 1:
|
||||
options.append((None, parts[0].strip()))
|
||||
|
@ -217,6 +256,7 @@ class Polls(Extension):
|
|||
expires=duration,
|
||||
)
|
||||
|
||||
# Create vote buttons for each option.
|
||||
buttons: List[Button] = []
|
||||
for i, option in enumerate(options):
|
||||
try:
|
||||
|
@ -231,6 +271,7 @@ class Polls(Extension):
|
|||
)
|
||||
)
|
||||
|
||||
# Create a button to allow locking the poll.
|
||||
buttons.append(
|
||||
Button(
|
||||
emoji="🔒",
|
||||
|
@ -251,9 +292,11 @@ class Polls(Extension):
|
|||
embed=embed,
|
||||
components=spread_to_rows(*buttons),
|
||||
)
|
||||
# Naive error handling.
|
||||
except HTTPException as e:
|
||||
logging.error(f"Error sending poll: {e}")
|
||||
await modal_ctx.send(
|
||||
# TODO: This should probably also include a sample for poll options.
|
||||
"Error during poll creation. NB: You can not use server-specific emojis",
|
||||
ephemeral=True,
|
||||
)
|
||||
|
@ -268,7 +311,9 @@ class Polls(Extension):
|
|||
ctx = button.context
|
||||
await ctx.defer(ephemeral=True)
|
||||
|
||||
# Ensure that the pressed button is a vote button.
|
||||
if ctx.custom_id.startswith("poll-vote:"):
|
||||
# Get the poll ID and the option index.
|
||||
poll_id, option_num = ctx.custom_id.split(":", 1)[1].split(":", 1)
|
||||
|
||||
poll_entry: PollsModel | None = PollsModel.get_or_none(
|
||||
|
@ -285,9 +330,12 @@ class Polls(Extension):
|
|||
PollVotesModel.poll_id == poll_id,
|
||||
PollVotesModel.user_id == ctx.author.id,
|
||||
)
|
||||
# If the user somehow already has more than one vote, delete them.
|
||||
# This should never happen, but just in case.
|
||||
# Then, add the new vote.
|
||||
if votes_q.count() > 1:
|
||||
for vote in votes_q:
|
||||
vote.delete().execute()
|
||||
vote.delete_instance()
|
||||
|
||||
PollVotesModel.create(
|
||||
poll_id=poll_id,
|
||||
|
@ -295,13 +343,16 @@ class Polls(Extension):
|
|||
option=option_num,
|
||||
)
|
||||
elif votes_q.count() == 1:
|
||||
# If the vote is the current vote, delete it.
|
||||
if int(votes_q[0].option) == int(option_num):
|
||||
votes_q[0].delete_instance()
|
||||
await ctx.send("You have removed your vote.")
|
||||
# If it's not the current vote, change the vote to the new one.
|
||||
else:
|
||||
votes_q[0].option = option_num
|
||||
votes_q[0].save()
|
||||
await ctx.send("You have changed your vote.")
|
||||
#If the user has no votes, add a new vote.
|
||||
else:
|
||||
PollVotesModel.create(
|
||||
poll_id=poll_id,
|
||||
|
@ -309,12 +360,15 @@ class Polls(Extension):
|
|||
option=option_num,
|
||||
)
|
||||
await ctx.send("You have voted.")
|
||||
|
||||
# If the poll is multiple choice
|
||||
else:
|
||||
votes_q: List[PollVotesModel] = PollVotesModel.select().where(
|
||||
PollVotesModel.poll_id == poll_id,
|
||||
PollVotesModel.user_id == ctx.author.id,
|
||||
)
|
||||
|
||||
# If the user has already voted for this option, remove their vote.
|
||||
exists = False
|
||||
for vote in votes_q:
|
||||
if int(vote.option) == (option_num):
|
||||
|
@ -323,6 +377,7 @@ class Polls(Extension):
|
|||
await ctx.send("You have removed your vote.")
|
||||
break
|
||||
|
||||
# If the user has not voted for this option, add a new vote.
|
||||
if not exists:
|
||||
PollVotesModel.create(
|
||||
poll_id=poll_id,
|
||||
|
@ -360,8 +415,10 @@ class Polls(Extension):
|
|||
expires=poll_entry.expires,
|
||||
)
|
||||
|
||||
# Edit the message with the updated information.
|
||||
await ctx.message.edit(embed=embed)
|
||||
|
||||
# If the "lock poll" button is pressed, lock the poll.
|
||||
elif ctx.custom_id.startswith("poll-lock:"):
|
||||
poll_id = ctx.custom_id.split(":", 1)[1]
|
||||
|
||||
|
@ -372,18 +429,22 @@ class Polls(Extension):
|
|||
await ctx.send("That poll doesn't exist.")
|
||||
return
|
||||
|
||||
# Ensure that the user is the poll creator, or can manage messages.
|
||||
if not ctx.author.id == int(
|
||||
poll_entry.author_id
|
||||
) or not ctx.author.has_permission(Permissions.MANAGE_MESSAGES):
|
||||
await ctx.send("You don't have permission to lock that poll.")
|
||||
return
|
||||
|
||||
# Set the poll to be expired to lock it.
|
||||
poll_entry.expires = datetime.now() - timedelta(minutes=1)
|
||||
poll_entry.save()
|
||||
|
||||
# Force the "poll expiry check" task to run.
|
||||
await self.poll_expiry_check()
|
||||
await ctx.send("Poll locked.")
|
||||
|
||||
# A task that runs each minute to check for expired polls.
|
||||
@Task.create(IntervalTrigger(minutes=1))
|
||||
async def poll_expiry_check(self):
|
||||
logging.info("Checking for expired polls.")
|
||||
|
@ -402,6 +463,10 @@ class Polls(Extension):
|
|||
continue
|
||||
|
||||
await message.edit(components=[])
|
||||
# Delete associated database entries, as they will no longer be updated.
|
||||
PollVotesModel.delete().where(PollVotesModel.poll_id == poll_entry.id).execute()
|
||||
poll_entry.delete_instance()
|
||||
|
||||
|
||||
|
||||
def setup(client: Client):
|
||||
|
|
Loading…
Reference in New Issue