diff --git a/Heimdallr.py b/Heimdallr.py index 7e54dbe..186cfa6 100644 --- a/Heimdallr.py +++ b/Heimdallr.py @@ -121,18 +121,47 @@ async def bot_info_command(ctx: InteractionContext): ), ) +def set_loglevel(level: str): + loglevel = logging.WARNING -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) + match level.lower(): + case "d", "debug": + loglevel = logging.DEBUG + case "i", "info", "information": + loglevel = logging.INFO + + case "w", "warn", "warning": + loglevel = logging.WARNING + + case "e", "error": + loglevel = logging.ERROR + + case "c", "critical": + loglevel = logging.CRITICAL + + case _: + loglevel = logging.WARNING + + logging.basicConfig(level=loglevel) + + +def main(): + load_dotenv() # Load environment variables from .env file + set_loglevel(getenv("HEIMDALLR_LOGLEVEL")) + + # Create basic tables GuildSettingsModel.create_table() JoinLeave.create_table() - load_dotenv() + # Load extensions bot.load_extension("commands.admin") - bot.load_extension("commands.gatekeep") + bot.load_extension("commands.gatekeep") bot.load_extension("commands.quote") bot.load_extension("commands.infractions") bot.load_extension("commands.self_roles") bot.load_extension("commands.polls") bot.load_extension("commands.bot_messages") bot.start(getenv("DISCORD_TOKEN")) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/commands/admin.py b/commands/admin.py index 349caf8..f201c72 100644 --- a/commands/admin.py +++ b/commands/admin.py @@ -29,10 +29,13 @@ class Admin(Extension): default_member_permissions=Permissions.MANAGE_GUILD, ) async def adm_list(self, ctx: InteractionContext) -> None: + # Get the guild's settings from the database. guild_settings: Optional[GuildSettingsModel] guild_settings, _ = GuildSettingsModel.get_or_create(guild_id=ctx.guild_id) joinleave_settings: Optional[GuildSettingsModel] joinleave_settings, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild_id) + + # Create an embed to display settings. embed = Embed( title="Settings for {}".format(ctx.guild.name), fields=[ diff --git a/commands/bot_messages.py b/commands/bot_messages.py index d036df8..f63a660 100644 --- a/commands/bot_messages.py +++ b/commands/bot_messages.py @@ -19,7 +19,7 @@ from naff import ( from database import BotMessages as BotMessagesModel - +# Template modal for creating/editing bot messages. message_creation_modal = Modal( custom_id="bot-message-create", title=f"Create a message as the bot", @@ -44,6 +44,7 @@ class BotMessages(Extension): def __init__(self, client: Client) -> None: self.client = client + # Create a new bot message. @slash_command( name="bot-message-create", description="Create a message as the bot.", @@ -52,13 +53,19 @@ class BotMessages(Extension): ) async def bot_message_create_command(self, ctx: InteractionContext): - + # Respond with the template modal. No values have been set in the modal, as it is + # a new message. await ctx.send_modal(message_creation_modal) + # Wait for the user to submit the modal, and ensure that we are receiving the + # correct modal. modal_ctx: ModalContext = await self.client.wait_for_modal(message_creation_modal, author=ctx.author) if modal_ctx.custom_id != "bot-message-create": return + # Retrieve the values from the modal. + # Ensure that at least one of the fields is filled in, as a message cannot be + # empty. embeds_string: str = modal_ctx.responses["embeds"] content_string: str = modal_ctx.responses["content"] if ( @@ -71,6 +78,10 @@ class BotMessages(Extension): ) return + # Attempt to parse the embed(s) JSON into a Python object. + # Try loading it as a single or multiple embeds depending on if the object + # returned is a list or a dictionary. + # If the JSON is invalid, return an error. embed: dict | None = None embeds: list | None = None @@ -88,21 +99,27 @@ class BotMessages(Extension): ) return + # Send the bot message in the channel. msg = await ctx.channel.send( content=content_string if content_string else None, embed=embed, embeds=embeds, ) + + # Add an entry of the message in the database. BotMessagesModel.create( guild_id=msg.guild.id, channel_id=msg.channel.id, message_id=msg.id, ) + + # Send a confirmation message, as Discord requires us to respond to the interaction. await modal_ctx.send( "Message created!", ephemeral=True, ) + # A context menu to allow moderators to edit a bot message. @context_menu( name="Edit bot message", context_type=CommandTypes.MESSAGE, @@ -111,6 +128,9 @@ class BotMessages(Extension): ) async def edit_bot_message_context_menu(self, ctx: InteractionContext): message: Message = ctx.target + + # Ensure that the target message is from the bot. + # If it is not, return an error. if message.author.id != self.client.user.id: await ctx.send( "This is not a bot message.", @@ -118,30 +138,38 @@ class BotMessages(Extension): ) return - bot_message: BotMessagesModel | None = BotMessagesModel.get( + # Retrieve the bot message from the database, if any. + bot_message: BotMessagesModel | None = BotMessagesModel.get_or_none( BotMessagesModel.guild_id == message.channel.guild.id, BotMessagesModel.channel_id == message.channel.id, BotMessagesModel.message_id == message.id, ) + # If there is no bot message, return an error. if bot_message is None: await ctx.send( "This is not an editable bot message.", ephemeral=True, ) return + + # Create a copy of the template modal, and insert the contents of the bot message. modal = deepcopy(message_creation_modal) modal.title = "Edit bot message" + modal.custom_id = "bot-message-edit" modal.components[0].value = json.dumps( [e.to_dict() for e in message.embeds] if message.embeds else "", indent=4, ) modal.components[1].value = message.content + # Send the modal to the user await ctx.send_modal(modal) + # Wait for the user to submit the modal, and ensure that we are receiving the + # correct modal. modal_ctx: ModalContext = await self.client.wait_for_modal(modal, author=ctx.author) - if modal_ctx.custom_id != "bot-message-create": + if modal_ctx.custom_id != "bot-message-edit": return embeds_string: str = modal_ctx.responses["embeds"] diff --git a/commands/gatekeep.py b/commands/gatekeep.py index b99613f..44bedef 100644 --- a/commands/gatekeep.py +++ b/commands/gatekeep.py @@ -128,8 +128,10 @@ class Gatekeep(Extension): jl, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild.id) await user.add_role(int(gk.gatekeep_approve_role)) + # Check if a welcome channel is set welcome_channel = not jl.message_channel is None + # If there is no approval message set, inform the issuer privately. if gk.gatekeep_approve_message is None: await ctx.send( f"{user.mention} has been approved.\nNB: No approval message has been sent.", @@ -137,6 +139,7 @@ class Gatekeep(Extension): ) return + # If there is no welcome channel set, attempt to DM the approval message to the user. if not welcome_channel: await ctx.send( f"{user.mention} has been approved.\nNB: No welcome channel has been set – attempting to DM {user.mention}", @@ -147,6 +150,7 @@ class Gatekeep(Extension): ) return + # DM the user if the bot fails to retrieve the welcome channel. channel = await ctx.guild.fetch_channel(jl.message_channel) if not channel: await ctx.send( @@ -158,6 +162,7 @@ class Gatekeep(Extension): ) return + # If none of the above occur, finally send the approval message to the welcome channel. await channel.send( str(gk.gatekeep_approve_message).format(member=user, guild=ctx.guild) ) @@ -212,6 +217,11 @@ class Gatekeep(Extension): ) await ctx.send(f"{user.mention} has been approved.", ephemeral=True) + # Allow the use of a reaction to approve a user. + # This is mainly for compatibility with NLL. + # The permission thingy should probably be reworked, as it currently allows anyone + # with the manage roles permission to use this. + # TODO: Rewrite this to require a specific role. @listen(events.MessageReactionAdd) async def on_reaction_add(self, reaction: events.MessageReactionAdd): if not reaction.emoji.name in [ diff --git a/commands/infractions.py b/commands/infractions.py index 2bbb1d9..56a6296 100644 --- a/commands/infractions.py +++ b/commands/infractions.py @@ -211,18 +211,18 @@ class Infractions(Extension): # TODO: Add this in again when GuildSettings is implemented - # guild_settings: Optional[GuildSettings] = GuildSettings.get_or_none(GuildSettings.guild_id == int(ctx.guild_id)) - # if guild_settings is not None: - # if guild_settings.admin_channel is not None: - # admin_channel = self.client.fetch_channel(int(guild_settings.admin_channel)) - # if admin_channel is not None: - # await admin_channel.send(embed=Embed( - # title=f"Warned {user.display_name} ({user.username}#{user.discriminator}, {user.id})", - # description=f"{reason}", - # color=infraction_colour(0x0000FF), - # fields=[ - # ], - # )) + guild_settings: Optional[GuildSettings] = GuildSettings.get_or_none(GuildSettings.guild_id == int(ctx.guild_id)) + if guild_settings is not None: + if guild_settings.admin_channel is not None: + admin_channel = self.client.fetch_channel(int(guild_settings.admin_channel)) + if admin_channel is not None: + await admin_channel.send(embed=Embed( + title=f"Warned {user.display_name} ({user.username}#{user.discriminator}, {user.id})", + description=f"{reason}", + color=infraction_colour(0x0000FF), + fields=[ + ], + )) if not silent and warning_msg is None: await ctx.send( diff --git a/commands/polls.py b/commands/polls.py index 73f47b0..f1bad19 100644 --- a/commands/polls.py +++ b/commands/polls.py @@ -1,6 +1,7 @@ from datetime import datetime, timedelta import logging import json +from time import sleep from typing import Dict, List, Tuple from naff import ( Client, @@ -34,12 +35,24 @@ from peewee import fn PollOptions = List[Tuple[str | None, str]] -def datetime_to_discord_time(dt: datetime) -> str: +def datetime_to_discord_relative_time(dt: datetime) -> str: + """Create a Discord relative time text from a datetime object.""" t = dt.strftime("%s") return f"" def generate_bar(num: int, total: int, length: int = 10) -> str: + """Create a bar graph from a number and a total. + + Parameters + ---------- + num : int + The current amount. + total : int + The total amount. + length : int + The amount of characters to use for the bar. + """ full_char = "\u2593" empty_char = "\u2591" @@ -63,11 +76,27 @@ def generate_poll_embed( multiple_choice: bool = False, expires: datetime = None, ) -> Embed: + """Create a poll embed from a title, options and votes. + + Parameters + ---------- + title : str + The title of the poll. + options : PollOptions + The options of the poll. + votes : List[int] + The votes of the poll, in the same order as the options. + multiple_choice : bool, optional + Whether the poll is multiple choice. Defaults to False. + expires : datetime, optional + The time at which the poll expires. Defaults to None. + """ data = [] + # \u2022 is a bullet point if multiple_choice: data.append("\u2022 Multiple choice") if expires: - data.append(f"\u2022 Expiry: {datetime_to_discord_time(expires)}") + data.append(f"\u2022 Expiry: {datetime_to_discord_relative_time(expires)}") embed = Embed( title=title, @@ -75,6 +104,7 @@ def generate_poll_embed( ) sum_votes = sum(votes) + # Add a field for each option, with the vote count and a bar graph showing the percentage. for i, (emoji, option) in enumerate(options): embed.add_field( name=f"**{emoji if emoji else num_to_emoji(i+1)} {option}**", @@ -178,22 +208,31 @@ class Polls(Extension): ], ) + # Display the modal and wait for the user to submit. await ctx.send_modal(modal) modal_ctx: ModalContext = await self.client.wait_for_modal( modal=modal, author=ctx.author ) + # If the user set a duration for the poll, create a datetime for the time in the + # future when the poll expires. duration: datetime | None = ( (datetime.now() + timedelta(minutes=duration)) if duration else None ) title = modal_ctx.responses["title"] options: PollOptions = [] + + # Get each option from the poll options. + # We're stripping the first occurance of a dash, as it otherwise would be included. for i, option in enumerate( modal_ctx.responses["options"].replace("-", "", 1).split("\n-") ): + # If the option is empty, skip it. if option == "": continue + + # Check if the option contains an emoji. parts = option.split(":", 1) if len(parts) == 1: options.append((None, parts[0].strip())) @@ -217,6 +256,7 @@ class Polls(Extension): expires=duration, ) + # Create vote buttons for each option. buttons: List[Button] = [] for i, option in enumerate(options): try: @@ -231,6 +271,7 @@ class Polls(Extension): ) ) + # Create a button to allow locking the poll. buttons.append( Button( emoji="🔒", @@ -251,9 +292,11 @@ class Polls(Extension): embed=embed, components=spread_to_rows(*buttons), ) + # Naive error handling. except HTTPException as e: logging.error(f"Error sending poll: {e}") await modal_ctx.send( + # TODO: This should probably also include a sample for poll options. "Error during poll creation. NB: You can not use server-specific emojis", ephemeral=True, ) @@ -268,7 +311,9 @@ class Polls(Extension): ctx = button.context await ctx.defer(ephemeral=True) + # Ensure that the pressed button is a vote button. if ctx.custom_id.startswith("poll-vote:"): + # Get the poll ID and the option index. poll_id, option_num = ctx.custom_id.split(":", 1)[1].split(":", 1) poll_entry: PollsModel | None = PollsModel.get_or_none( @@ -285,9 +330,12 @@ class Polls(Extension): PollVotesModel.poll_id == poll_id, PollVotesModel.user_id == ctx.author.id, ) + # If the user somehow already has more than one vote, delete them. + # This should never happen, but just in case. + # Then, add the new vote. if votes_q.count() > 1: for vote in votes_q: - vote.delete().execute() + vote.delete_instance() PollVotesModel.create( poll_id=poll_id, @@ -295,13 +343,16 @@ class Polls(Extension): option=option_num, ) elif votes_q.count() == 1: + # If the vote is the current vote, delete it. if int(votes_q[0].option) == int(option_num): votes_q[0].delete_instance() await ctx.send("You have removed your vote.") + # If it's not the current vote, change the vote to the new one. else: votes_q[0].option = option_num votes_q[0].save() await ctx.send("You have changed your vote.") + #If the user has no votes, add a new vote. else: PollVotesModel.create( poll_id=poll_id, @@ -309,12 +360,15 @@ class Polls(Extension): option=option_num, ) await ctx.send("You have voted.") + + # If the poll is multiple choice else: votes_q: List[PollVotesModel] = PollVotesModel.select().where( PollVotesModel.poll_id == poll_id, PollVotesModel.user_id == ctx.author.id, ) + # If the user has already voted for this option, remove their vote. exists = False for vote in votes_q: if int(vote.option) == (option_num): @@ -323,6 +377,7 @@ class Polls(Extension): await ctx.send("You have removed your vote.") break + # If the user has not voted for this option, add a new vote. if not exists: PollVotesModel.create( poll_id=poll_id, @@ -360,8 +415,10 @@ class Polls(Extension): expires=poll_entry.expires, ) + # Edit the message with the updated information. await ctx.message.edit(embed=embed) + # If the "lock poll" button is pressed, lock the poll. elif ctx.custom_id.startswith("poll-lock:"): poll_id = ctx.custom_id.split(":", 1)[1] @@ -372,18 +429,22 @@ class Polls(Extension): await ctx.send("That poll doesn't exist.") return + # Ensure that the user is the poll creator, or can manage messages. if not ctx.author.id == int( poll_entry.author_id ) or not ctx.author.has_permission(Permissions.MANAGE_MESSAGES): await ctx.send("You don't have permission to lock that poll.") return + # Set the poll to be expired to lock it. poll_entry.expires = datetime.now() - timedelta(minutes=1) poll_entry.save() + # Force the "poll expiry check" task to run. await self.poll_expiry_check() await ctx.send("Poll locked.") + # A task that runs each minute to check for expired polls. @Task.create(IntervalTrigger(minutes=1)) async def poll_expiry_check(self): logging.info("Checking for expired polls.") @@ -402,6 +463,10 @@ class Polls(Extension): continue await message.edit(components=[]) + # Delete associated database entries, as they will no longer be updated. + PollVotesModel.delete().where(PollVotesModel.poll_id == poll_entry.id).execute() + poll_entry.delete_instance() + def setup(client: Client):