Added comments for readability

This commit is contained in:
Vegard Berg 2022-08-11 01:39:15 +02:00
parent f1987d2788
commit 285c617871
6 changed files with 158 additions and 23 deletions

View File

@ -121,13 +121,39 @@ async def bot_info_command(ctx: InteractionContext):
), ),
) )
def set_loglevel(level: str):
loglevel = logging.WARNING
if __name__ == "__main__": match level.lower():
logging.basicConfig(level=logging.INFO) case "d", "debug":
loglevel = logging.DEBUG
case "i", "info", "information":
loglevel = logging.INFO
case "w", "warn", "warning":
loglevel = logging.WARNING
case "e", "error":
loglevel = logging.ERROR
case "c", "critical":
loglevel = logging.CRITICAL
case _:
loglevel = logging.WARNING
logging.basicConfig(level=loglevel)
def main():
load_dotenv() # Load environment variables from .env file
set_loglevel(getenv("HEIMDALLR_LOGLEVEL"))
# Create basic tables
GuildSettingsModel.create_table() GuildSettingsModel.create_table()
JoinLeave.create_table() JoinLeave.create_table()
load_dotenv() # Load extensions
bot.load_extension("commands.admin") bot.load_extension("commands.admin")
bot.load_extension("commands.gatekeep") bot.load_extension("commands.gatekeep")
bot.load_extension("commands.quote") bot.load_extension("commands.quote")
@ -136,3 +162,6 @@ if __name__ == "__main__":
bot.load_extension("commands.polls") bot.load_extension("commands.polls")
bot.load_extension("commands.bot_messages") bot.load_extension("commands.bot_messages")
bot.start(getenv("DISCORD_TOKEN")) bot.start(getenv("DISCORD_TOKEN"))
if __name__ == "__main__":
main()

View File

@ -29,10 +29,13 @@ class Admin(Extension):
default_member_permissions=Permissions.MANAGE_GUILD, default_member_permissions=Permissions.MANAGE_GUILD,
) )
async def adm_list(self, ctx: InteractionContext) -> None: async def adm_list(self, ctx: InteractionContext) -> None:
# Get the guild's settings from the database.
guild_settings: Optional[GuildSettingsModel] guild_settings: Optional[GuildSettingsModel]
guild_settings, _ = GuildSettingsModel.get_or_create(guild_id=ctx.guild_id) guild_settings, _ = GuildSettingsModel.get_or_create(guild_id=ctx.guild_id)
joinleave_settings: Optional[GuildSettingsModel] joinleave_settings: Optional[GuildSettingsModel]
joinleave_settings, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild_id) joinleave_settings, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild_id)
# Create an embed to display settings.
embed = Embed( embed = Embed(
title="Settings for {}".format(ctx.guild.name), title="Settings for {}".format(ctx.guild.name),
fields=[ fields=[

View File

@ -19,7 +19,7 @@ from naff import (
from database import BotMessages as BotMessagesModel from database import BotMessages as BotMessagesModel
# Template modal for creating/editing bot messages.
message_creation_modal = Modal( message_creation_modal = Modal(
custom_id="bot-message-create", custom_id="bot-message-create",
title=f"Create a message as the bot", title=f"Create a message as the bot",
@ -44,6 +44,7 @@ class BotMessages(Extension):
def __init__(self, client: Client) -> None: def __init__(self, client: Client) -> None:
self.client = client self.client = client
# Create a new bot message.
@slash_command( @slash_command(
name="bot-message-create", name="bot-message-create",
description="Create a message as the bot.", description="Create a message as the bot.",
@ -52,13 +53,19 @@ class BotMessages(Extension):
) )
async def bot_message_create_command(self, ctx: InteractionContext): async def bot_message_create_command(self, ctx: InteractionContext):
# Respond with the template modal. No values have been set in the modal, as it is
# a new message.
await ctx.send_modal(message_creation_modal) await ctx.send_modal(message_creation_modal)
# Wait for the user to submit the modal, and ensure that we are receiving the
# correct modal.
modal_ctx: ModalContext = await self.client.wait_for_modal(message_creation_modal, author=ctx.author) modal_ctx: ModalContext = await self.client.wait_for_modal(message_creation_modal, author=ctx.author)
if modal_ctx.custom_id != "bot-message-create": if modal_ctx.custom_id != "bot-message-create":
return return
# Retrieve the values from the modal.
# Ensure that at least one of the fields is filled in, as a message cannot be
# empty.
embeds_string: str = modal_ctx.responses["embeds"] embeds_string: str = modal_ctx.responses["embeds"]
content_string: str = modal_ctx.responses["content"] content_string: str = modal_ctx.responses["content"]
if ( if (
@ -71,6 +78,10 @@ class BotMessages(Extension):
) )
return return
# Attempt to parse the embed(s) JSON into a Python object.
# Try loading it as a single or multiple embeds depending on if the object
# returned is a list or a dictionary.
# If the JSON is invalid, return an error.
embed: dict | None = None embed: dict | None = None
embeds: list | None = None embeds: list | None = None
@ -88,21 +99,27 @@ class BotMessages(Extension):
) )
return return
# Send the bot message in the channel.
msg = await ctx.channel.send( msg = await ctx.channel.send(
content=content_string if content_string else None, content=content_string if content_string else None,
embed=embed, embed=embed,
embeds=embeds, embeds=embeds,
) )
# Add an entry of the message in the database.
BotMessagesModel.create( BotMessagesModel.create(
guild_id=msg.guild.id, guild_id=msg.guild.id,
channel_id=msg.channel.id, channel_id=msg.channel.id,
message_id=msg.id, message_id=msg.id,
) )
# Send a confirmation message, as Discord requires us to respond to the interaction.
await modal_ctx.send( await modal_ctx.send(
"Message created!", "Message created!",
ephemeral=True, ephemeral=True,
) )
# A context menu to allow moderators to edit a bot message.
@context_menu( @context_menu(
name="Edit bot message", name="Edit bot message",
context_type=CommandTypes.MESSAGE, context_type=CommandTypes.MESSAGE,
@ -111,6 +128,9 @@ class BotMessages(Extension):
) )
async def edit_bot_message_context_menu(self, ctx: InteractionContext): async def edit_bot_message_context_menu(self, ctx: InteractionContext):
message: Message = ctx.target message: Message = ctx.target
# Ensure that the target message is from the bot.
# If it is not, return an error.
if message.author.id != self.client.user.id: if message.author.id != self.client.user.id:
await ctx.send( await ctx.send(
"This is not a bot message.", "This is not a bot message.",
@ -118,30 +138,38 @@ class BotMessages(Extension):
) )
return return
bot_message: BotMessagesModel | None = BotMessagesModel.get( # Retrieve the bot message from the database, if any.
bot_message: BotMessagesModel | None = BotMessagesModel.get_or_none(
BotMessagesModel.guild_id == message.channel.guild.id, BotMessagesModel.guild_id == message.channel.guild.id,
BotMessagesModel.channel_id == message.channel.id, BotMessagesModel.channel_id == message.channel.id,
BotMessagesModel.message_id == message.id, BotMessagesModel.message_id == message.id,
) )
# If there is no bot message, return an error.
if bot_message is None: if bot_message is None:
await ctx.send( await ctx.send(
"This is not an editable bot message.", "This is not an editable bot message.",
ephemeral=True, ephemeral=True,
) )
return return
# Create a copy of the template modal, and insert the contents of the bot message.
modal = deepcopy(message_creation_modal) modal = deepcopy(message_creation_modal)
modal.title = "Edit bot message" modal.title = "Edit bot message"
modal.custom_id = "bot-message-edit"
modal.components[0].value = json.dumps( modal.components[0].value = json.dumps(
[e.to_dict() for e in message.embeds] if message.embeds else "", [e.to_dict() for e in message.embeds] if message.embeds else "",
indent=4, indent=4,
) )
modal.components[1].value = message.content modal.components[1].value = message.content
# Send the modal to the user
await ctx.send_modal(modal) await ctx.send_modal(modal)
# Wait for the user to submit the modal, and ensure that we are receiving the
# correct modal.
modal_ctx: ModalContext = await self.client.wait_for_modal(modal, author=ctx.author) modal_ctx: ModalContext = await self.client.wait_for_modal(modal, author=ctx.author)
if modal_ctx.custom_id != "bot-message-create": if modal_ctx.custom_id != "bot-message-edit":
return return
embeds_string: str = modal_ctx.responses["embeds"] embeds_string: str = modal_ctx.responses["embeds"]

View File

@ -128,8 +128,10 @@ class Gatekeep(Extension):
jl, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild.id) jl, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild.id)
await user.add_role(int(gk.gatekeep_approve_role)) await user.add_role(int(gk.gatekeep_approve_role))
# Check if a welcome channel is set
welcome_channel = not jl.message_channel is None welcome_channel = not jl.message_channel is None
# If there is no approval message set, inform the issuer privately.
if gk.gatekeep_approve_message is None: if gk.gatekeep_approve_message is None:
await ctx.send( await ctx.send(
f"{user.mention} has been approved.\nNB: No approval message has been sent.", f"{user.mention} has been approved.\nNB: No approval message has been sent.",
@ -137,6 +139,7 @@ class Gatekeep(Extension):
) )
return return
# If there is no welcome channel set, attempt to DM the approval message to the user.
if not welcome_channel: if not welcome_channel:
await ctx.send( await ctx.send(
f"{user.mention} has been approved.\nNB: No welcome channel has been set attempting to DM {user.mention}", f"{user.mention} has been approved.\nNB: No welcome channel has been set attempting to DM {user.mention}",
@ -147,6 +150,7 @@ class Gatekeep(Extension):
) )
return return
# DM the user if the bot fails to retrieve the welcome channel.
channel = await ctx.guild.fetch_channel(jl.message_channel) channel = await ctx.guild.fetch_channel(jl.message_channel)
if not channel: if not channel:
await ctx.send( await ctx.send(
@ -158,6 +162,7 @@ class Gatekeep(Extension):
) )
return return
# If none of the above occur, finally send the approval message to the welcome channel.
await channel.send( await channel.send(
str(gk.gatekeep_approve_message).format(member=user, guild=ctx.guild) str(gk.gatekeep_approve_message).format(member=user, guild=ctx.guild)
) )
@ -212,6 +217,11 @@ class Gatekeep(Extension):
) )
await ctx.send(f"{user.mention} has been approved.", ephemeral=True) await ctx.send(f"{user.mention} has been approved.", ephemeral=True)
# Allow the use of a reaction to approve a user.
# This is mainly for compatibility with NLL.
# The permission thingy should probably be reworked, as it currently allows anyone
# with the manage roles permission to use this.
# TODO: Rewrite this to require a specific role.
@listen(events.MessageReactionAdd) @listen(events.MessageReactionAdd)
async def on_reaction_add(self, reaction: events.MessageReactionAdd): async def on_reaction_add(self, reaction: events.MessageReactionAdd):
if not reaction.emoji.name in [ if not reaction.emoji.name in [

View File

@ -211,18 +211,18 @@ class Infractions(Extension):
# TODO: Add this in again when GuildSettings is implemented # TODO: Add this in again when GuildSettings is implemented
# guild_settings: Optional[GuildSettings] = GuildSettings.get_or_none(GuildSettings.guild_id == int(ctx.guild_id)) guild_settings: Optional[GuildSettings] = GuildSettings.get_or_none(GuildSettings.guild_id == int(ctx.guild_id))
# if guild_settings is not None: if guild_settings is not None:
# if guild_settings.admin_channel is not None: if guild_settings.admin_channel is not None:
# admin_channel = self.client.fetch_channel(int(guild_settings.admin_channel)) admin_channel = self.client.fetch_channel(int(guild_settings.admin_channel))
# if admin_channel is not None: if admin_channel is not None:
# await admin_channel.send(embed=Embed( await admin_channel.send(embed=Embed(
# title=f"Warned {user.display_name} ({user.username}#{user.discriminator}, {user.id})", title=f"Warned {user.display_name} ({user.username}#{user.discriminator}, {user.id})",
# description=f"{reason}", description=f"{reason}",
# color=infraction_colour(0x0000FF), color=infraction_colour(0x0000FF),
# fields=[ fields=[
# ], ],
# )) ))
if not silent and warning_msg is None: if not silent and warning_msg is None:
await ctx.send( await ctx.send(

View File

@ -1,6 +1,7 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging import logging
import json import json
from time import sleep
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from naff import ( from naff import (
Client, Client,
@ -34,12 +35,24 @@ from peewee import fn
PollOptions = List[Tuple[str | None, str]] PollOptions = List[Tuple[str | None, str]]
def datetime_to_discord_time(dt: datetime) -> str: def datetime_to_discord_relative_time(dt: datetime) -> str:
"""Create a Discord relative time text from a datetime object."""
t = dt.strftime("%s") t = dt.strftime("%s")
return f"<t:{int(t)}:R>" return f"<t:{int(t)}:R>"
def generate_bar(num: int, total: int, length: int = 10) -> str: def generate_bar(num: int, total: int, length: int = 10) -> str:
"""Create a bar graph from a number and a total.
Parameters
----------
num : int
The current amount.
total : int
The total amount.
length : int
The amount of characters to use for the bar.
"""
full_char = "\u2593" full_char = "\u2593"
empty_char = "\u2591" empty_char = "\u2591"
@ -63,11 +76,27 @@ def generate_poll_embed(
multiple_choice: bool = False, multiple_choice: bool = False,
expires: datetime = None, expires: datetime = None,
) -> Embed: ) -> Embed:
"""Create a poll embed from a title, options and votes.
Parameters
----------
title : str
The title of the poll.
options : PollOptions
The options of the poll.
votes : List[int]
The votes of the poll, in the same order as the options.
multiple_choice : bool, optional
Whether the poll is multiple choice. Defaults to False.
expires : datetime, optional
The time at which the poll expires. Defaults to None.
"""
data = [] data = []
# \u2022 is a bullet point
if multiple_choice: if multiple_choice:
data.append("\u2022 Multiple choice") data.append("\u2022 Multiple choice")
if expires: if expires:
data.append(f"\u2022 Expiry: {datetime_to_discord_time(expires)}") data.append(f"\u2022 Expiry: {datetime_to_discord_relative_time(expires)}")
embed = Embed( embed = Embed(
title=title, title=title,
@ -75,6 +104,7 @@ def generate_poll_embed(
) )
sum_votes = sum(votes) sum_votes = sum(votes)
# Add a field for each option, with the vote count and a bar graph showing the percentage.
for i, (emoji, option) in enumerate(options): for i, (emoji, option) in enumerate(options):
embed.add_field( embed.add_field(
name=f"**{emoji if emoji else num_to_emoji(i+1)} {option}**", name=f"**{emoji if emoji else num_to_emoji(i+1)} {option}**",
@ -178,22 +208,31 @@ class Polls(Extension):
], ],
) )
# Display the modal and wait for the user to submit.
await ctx.send_modal(modal) await ctx.send_modal(modal)
modal_ctx: ModalContext = await self.client.wait_for_modal( modal_ctx: ModalContext = await self.client.wait_for_modal(
modal=modal, author=ctx.author modal=modal, author=ctx.author
) )
# If the user set a duration for the poll, create a datetime for the time in the
# future when the poll expires.
duration: datetime | None = ( duration: datetime | None = (
(datetime.now() + timedelta(minutes=duration)) if duration else None (datetime.now() + timedelta(minutes=duration)) if duration else None
) )
title = modal_ctx.responses["title"] title = modal_ctx.responses["title"]
options: PollOptions = [] options: PollOptions = []
# Get each option from the poll options.
# We're stripping the first occurance of a dash, as it otherwise would be included.
for i, option in enumerate( for i, option in enumerate(
modal_ctx.responses["options"].replace("-", "", 1).split("\n-") modal_ctx.responses["options"].replace("-", "", 1).split("\n-")
): ):
# If the option is empty, skip it.
if option == "": if option == "":
continue continue
# Check if the option contains an emoji.
parts = option.split(":", 1) parts = option.split(":", 1)
if len(parts) == 1: if len(parts) == 1:
options.append((None, parts[0].strip())) options.append((None, parts[0].strip()))
@ -217,6 +256,7 @@ class Polls(Extension):
expires=duration, expires=duration,
) )
# Create vote buttons for each option.
buttons: List[Button] = [] buttons: List[Button] = []
for i, option in enumerate(options): for i, option in enumerate(options):
try: try:
@ -231,6 +271,7 @@ class Polls(Extension):
) )
) )
# Create a button to allow locking the poll.
buttons.append( buttons.append(
Button( Button(
emoji="🔒", emoji="🔒",
@ -251,9 +292,11 @@ class Polls(Extension):
embed=embed, embed=embed,
components=spread_to_rows(*buttons), components=spread_to_rows(*buttons),
) )
# Naive error handling.
except HTTPException as e: except HTTPException as e:
logging.error(f"Error sending poll: {e}") logging.error(f"Error sending poll: {e}")
await modal_ctx.send( await modal_ctx.send(
# TODO: This should probably also include a sample for poll options.
"Error during poll creation. NB: You can not use server-specific emojis", "Error during poll creation. NB: You can not use server-specific emojis",
ephemeral=True, ephemeral=True,
) )
@ -268,7 +311,9 @@ class Polls(Extension):
ctx = button.context ctx = button.context
await ctx.defer(ephemeral=True) await ctx.defer(ephemeral=True)
# Ensure that the pressed button is a vote button.
if ctx.custom_id.startswith("poll-vote:"): if ctx.custom_id.startswith("poll-vote:"):
# Get the poll ID and the option index.
poll_id, option_num = ctx.custom_id.split(":", 1)[1].split(":", 1) poll_id, option_num = ctx.custom_id.split(":", 1)[1].split(":", 1)
poll_entry: PollsModel | None = PollsModel.get_or_none( poll_entry: PollsModel | None = PollsModel.get_or_none(
@ -285,9 +330,12 @@ class Polls(Extension):
PollVotesModel.poll_id == poll_id, PollVotesModel.poll_id == poll_id,
PollVotesModel.user_id == ctx.author.id, PollVotesModel.user_id == ctx.author.id,
) )
# If the user somehow already has more than one vote, delete them.
# This should never happen, but just in case.
# Then, add the new vote.
if votes_q.count() > 1: if votes_q.count() > 1:
for vote in votes_q: for vote in votes_q:
vote.delete().execute() vote.delete_instance()
PollVotesModel.create( PollVotesModel.create(
poll_id=poll_id, poll_id=poll_id,
@ -295,13 +343,16 @@ class Polls(Extension):
option=option_num, option=option_num,
) )
elif votes_q.count() == 1: elif votes_q.count() == 1:
# If the vote is the current vote, delete it.
if int(votes_q[0].option) == int(option_num): if int(votes_q[0].option) == int(option_num):
votes_q[0].delete_instance() votes_q[0].delete_instance()
await ctx.send("You have removed your vote.") await ctx.send("You have removed your vote.")
# If it's not the current vote, change the vote to the new one.
else: else:
votes_q[0].option = option_num votes_q[0].option = option_num
votes_q[0].save() votes_q[0].save()
await ctx.send("You have changed your vote.") await ctx.send("You have changed your vote.")
#If the user has no votes, add a new vote.
else: else:
PollVotesModel.create( PollVotesModel.create(
poll_id=poll_id, poll_id=poll_id,
@ -309,12 +360,15 @@ class Polls(Extension):
option=option_num, option=option_num,
) )
await ctx.send("You have voted.") await ctx.send("You have voted.")
# If the poll is multiple choice
else: else:
votes_q: List[PollVotesModel] = PollVotesModel.select().where( votes_q: List[PollVotesModel] = PollVotesModel.select().where(
PollVotesModel.poll_id == poll_id, PollVotesModel.poll_id == poll_id,
PollVotesModel.user_id == ctx.author.id, PollVotesModel.user_id == ctx.author.id,
) )
# If the user has already voted for this option, remove their vote.
exists = False exists = False
for vote in votes_q: for vote in votes_q:
if int(vote.option) == (option_num): if int(vote.option) == (option_num):
@ -323,6 +377,7 @@ class Polls(Extension):
await ctx.send("You have removed your vote.") await ctx.send("You have removed your vote.")
break break
# If the user has not voted for this option, add a new vote.
if not exists: if not exists:
PollVotesModel.create( PollVotesModel.create(
poll_id=poll_id, poll_id=poll_id,
@ -360,8 +415,10 @@ class Polls(Extension):
expires=poll_entry.expires, expires=poll_entry.expires,
) )
# Edit the message with the updated information.
await ctx.message.edit(embed=embed) await ctx.message.edit(embed=embed)
# If the "lock poll" button is pressed, lock the poll.
elif ctx.custom_id.startswith("poll-lock:"): elif ctx.custom_id.startswith("poll-lock:"):
poll_id = ctx.custom_id.split(":", 1)[1] poll_id = ctx.custom_id.split(":", 1)[1]
@ -372,18 +429,22 @@ class Polls(Extension):
await ctx.send("That poll doesn't exist.") await ctx.send("That poll doesn't exist.")
return return
# Ensure that the user is the poll creator, or can manage messages.
if not ctx.author.id == int( if not ctx.author.id == int(
poll_entry.author_id poll_entry.author_id
) or not ctx.author.has_permission(Permissions.MANAGE_MESSAGES): ) or not ctx.author.has_permission(Permissions.MANAGE_MESSAGES):
await ctx.send("You don't have permission to lock that poll.") await ctx.send("You don't have permission to lock that poll.")
return return
# Set the poll to be expired to lock it.
poll_entry.expires = datetime.now() - timedelta(minutes=1) poll_entry.expires = datetime.now() - timedelta(minutes=1)
poll_entry.save() poll_entry.save()
# Force the "poll expiry check" task to run.
await self.poll_expiry_check() await self.poll_expiry_check()
await ctx.send("Poll locked.") await ctx.send("Poll locked.")
# A task that runs each minute to check for expired polls.
@Task.create(IntervalTrigger(minutes=1)) @Task.create(IntervalTrigger(minutes=1))
async def poll_expiry_check(self): async def poll_expiry_check(self):
logging.info("Checking for expired polls.") logging.info("Checking for expired polls.")
@ -402,6 +463,10 @@ class Polls(Extension):
continue continue
await message.edit(components=[]) await message.edit(components=[])
# Delete associated database entries, as they will no longer be updated.
PollVotesModel.delete().where(PollVotesModel.poll_id == poll_entry.id).execute()
poll_entry.delete_instance()
def setup(client: Client): def setup(client: Client):