Added comments for readability
This commit is contained in:
parent
f1987d2788
commit
285c617871
37
Heimdallr.py
37
Heimdallr.py
|
@ -121,18 +121,47 @@ async def bot_info_command(ctx: InteractionContext):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_loglevel(level: str):
|
||||||
|
loglevel = logging.WARNING
|
||||||
|
|
||||||
if __name__ == "__main__":
|
match level.lower():
|
||||||
logging.basicConfig(level=logging.INFO)
|
case "d", "debug":
|
||||||
|
loglevel = logging.DEBUG
|
||||||
|
case "i", "info", "information":
|
||||||
|
loglevel = logging.INFO
|
||||||
|
|
||||||
|
case "w", "warn", "warning":
|
||||||
|
loglevel = logging.WARNING
|
||||||
|
|
||||||
|
case "e", "error":
|
||||||
|
loglevel = logging.ERROR
|
||||||
|
|
||||||
|
case "c", "critical":
|
||||||
|
loglevel = logging.CRITICAL
|
||||||
|
|
||||||
|
case _:
|
||||||
|
loglevel = logging.WARNING
|
||||||
|
|
||||||
|
logging.basicConfig(level=loglevel)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
load_dotenv() # Load environment variables from .env file
|
||||||
|
set_loglevel(getenv("HEIMDALLR_LOGLEVEL"))
|
||||||
|
|
||||||
|
# Create basic tables
|
||||||
GuildSettingsModel.create_table()
|
GuildSettingsModel.create_table()
|
||||||
JoinLeave.create_table()
|
JoinLeave.create_table()
|
||||||
|
|
||||||
load_dotenv()
|
# Load extensions
|
||||||
bot.load_extension("commands.admin")
|
bot.load_extension("commands.admin")
|
||||||
bot.load_extension("commands.gatekeep")
|
bot.load_extension("commands.gatekeep")
|
||||||
bot.load_extension("commands.quote")
|
bot.load_extension("commands.quote")
|
||||||
bot.load_extension("commands.infractions")
|
bot.load_extension("commands.infractions")
|
||||||
bot.load_extension("commands.self_roles")
|
bot.load_extension("commands.self_roles")
|
||||||
bot.load_extension("commands.polls")
|
bot.load_extension("commands.polls")
|
||||||
bot.load_extension("commands.bot_messages")
|
bot.load_extension("commands.bot_messages")
|
||||||
bot.start(getenv("DISCORD_TOKEN"))
|
bot.start(getenv("DISCORD_TOKEN"))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -29,10 +29,13 @@ class Admin(Extension):
|
||||||
default_member_permissions=Permissions.MANAGE_GUILD,
|
default_member_permissions=Permissions.MANAGE_GUILD,
|
||||||
)
|
)
|
||||||
async def adm_list(self, ctx: InteractionContext) -> None:
|
async def adm_list(self, ctx: InteractionContext) -> None:
|
||||||
|
# Get the guild's settings from the database.
|
||||||
guild_settings: Optional[GuildSettingsModel]
|
guild_settings: Optional[GuildSettingsModel]
|
||||||
guild_settings, _ = GuildSettingsModel.get_or_create(guild_id=ctx.guild_id)
|
guild_settings, _ = GuildSettingsModel.get_or_create(guild_id=ctx.guild_id)
|
||||||
joinleave_settings: Optional[GuildSettingsModel]
|
joinleave_settings: Optional[GuildSettingsModel]
|
||||||
joinleave_settings, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild_id)
|
joinleave_settings, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild_id)
|
||||||
|
|
||||||
|
# Create an embed to display settings.
|
||||||
embed = Embed(
|
embed = Embed(
|
||||||
title="Settings for {}".format(ctx.guild.name),
|
title="Settings for {}".format(ctx.guild.name),
|
||||||
fields=[
|
fields=[
|
||||||
|
|
|
@ -19,7 +19,7 @@ from naff import (
|
||||||
|
|
||||||
from database import BotMessages as BotMessagesModel
|
from database import BotMessages as BotMessagesModel
|
||||||
|
|
||||||
|
# Template modal for creating/editing bot messages.
|
||||||
message_creation_modal = Modal(
|
message_creation_modal = Modal(
|
||||||
custom_id="bot-message-create",
|
custom_id="bot-message-create",
|
||||||
title=f"Create a message as the bot",
|
title=f"Create a message as the bot",
|
||||||
|
@ -44,6 +44,7 @@ class BotMessages(Extension):
|
||||||
def __init__(self, client: Client) -> None:
|
def __init__(self, client: Client) -> None:
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
|
# Create a new bot message.
|
||||||
@slash_command(
|
@slash_command(
|
||||||
name="bot-message-create",
|
name="bot-message-create",
|
||||||
description="Create a message as the bot.",
|
description="Create a message as the bot.",
|
||||||
|
@ -52,13 +53,19 @@ class BotMessages(Extension):
|
||||||
)
|
)
|
||||||
async def bot_message_create_command(self, ctx: InteractionContext):
|
async def bot_message_create_command(self, ctx: InteractionContext):
|
||||||
|
|
||||||
|
# Respond with the template modal. No values have been set in the modal, as it is
|
||||||
|
# a new message.
|
||||||
await ctx.send_modal(message_creation_modal)
|
await ctx.send_modal(message_creation_modal)
|
||||||
|
|
||||||
|
# Wait for the user to submit the modal, and ensure that we are receiving the
|
||||||
|
# correct modal.
|
||||||
modal_ctx: ModalContext = await self.client.wait_for_modal(message_creation_modal, author=ctx.author)
|
modal_ctx: ModalContext = await self.client.wait_for_modal(message_creation_modal, author=ctx.author)
|
||||||
if modal_ctx.custom_id != "bot-message-create":
|
if modal_ctx.custom_id != "bot-message-create":
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Retrieve the values from the modal.
|
||||||
|
# Ensure that at least one of the fields is filled in, as a message cannot be
|
||||||
|
# empty.
|
||||||
embeds_string: str = modal_ctx.responses["embeds"]
|
embeds_string: str = modal_ctx.responses["embeds"]
|
||||||
content_string: str = modal_ctx.responses["content"]
|
content_string: str = modal_ctx.responses["content"]
|
||||||
if (
|
if (
|
||||||
|
@ -71,6 +78,10 @@ class BotMessages(Extension):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Attempt to parse the embed(s) JSON into a Python object.
|
||||||
|
# Try loading it as a single or multiple embeds depending on if the object
|
||||||
|
# returned is a list or a dictionary.
|
||||||
|
# If the JSON is invalid, return an error.
|
||||||
embed: dict | None = None
|
embed: dict | None = None
|
||||||
embeds: list | None = None
|
embeds: list | None = None
|
||||||
|
|
||||||
|
@ -88,21 +99,27 @@ class BotMessages(Extension):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Send the bot message in the channel.
|
||||||
msg = await ctx.channel.send(
|
msg = await ctx.channel.send(
|
||||||
content=content_string if content_string else None,
|
content=content_string if content_string else None,
|
||||||
embed=embed,
|
embed=embed,
|
||||||
embeds=embeds,
|
embeds=embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add an entry of the message in the database.
|
||||||
BotMessagesModel.create(
|
BotMessagesModel.create(
|
||||||
guild_id=msg.guild.id,
|
guild_id=msg.guild.id,
|
||||||
channel_id=msg.channel.id,
|
channel_id=msg.channel.id,
|
||||||
message_id=msg.id,
|
message_id=msg.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Send a confirmation message, as Discord requires us to respond to the interaction.
|
||||||
await modal_ctx.send(
|
await modal_ctx.send(
|
||||||
"Message created!",
|
"Message created!",
|
||||||
ephemeral=True,
|
ephemeral=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# A context menu to allow moderators to edit a bot message.
|
||||||
@context_menu(
|
@context_menu(
|
||||||
name="Edit bot message",
|
name="Edit bot message",
|
||||||
context_type=CommandTypes.MESSAGE,
|
context_type=CommandTypes.MESSAGE,
|
||||||
|
@ -111,6 +128,9 @@ class BotMessages(Extension):
|
||||||
)
|
)
|
||||||
async def edit_bot_message_context_menu(self, ctx: InteractionContext):
|
async def edit_bot_message_context_menu(self, ctx: InteractionContext):
|
||||||
message: Message = ctx.target
|
message: Message = ctx.target
|
||||||
|
|
||||||
|
# Ensure that the target message is from the bot.
|
||||||
|
# If it is not, return an error.
|
||||||
if message.author.id != self.client.user.id:
|
if message.author.id != self.client.user.id:
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
"This is not a bot message.",
|
"This is not a bot message.",
|
||||||
|
@ -118,30 +138,38 @@ class BotMessages(Extension):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
bot_message: BotMessagesModel | None = BotMessagesModel.get(
|
# Retrieve the bot message from the database, if any.
|
||||||
|
bot_message: BotMessagesModel | None = BotMessagesModel.get_or_none(
|
||||||
BotMessagesModel.guild_id == message.channel.guild.id,
|
BotMessagesModel.guild_id == message.channel.guild.id,
|
||||||
BotMessagesModel.channel_id == message.channel.id,
|
BotMessagesModel.channel_id == message.channel.id,
|
||||||
BotMessagesModel.message_id == message.id,
|
BotMessagesModel.message_id == message.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If there is no bot message, return an error.
|
||||||
if bot_message is None:
|
if bot_message is None:
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
"This is not an editable bot message.",
|
"This is not an editable bot message.",
|
||||||
ephemeral=True,
|
ephemeral=True,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Create a copy of the template modal, and insert the contents of the bot message.
|
||||||
modal = deepcopy(message_creation_modal)
|
modal = deepcopy(message_creation_modal)
|
||||||
modal.title = "Edit bot message"
|
modal.title = "Edit bot message"
|
||||||
|
modal.custom_id = "bot-message-edit"
|
||||||
modal.components[0].value = json.dumps(
|
modal.components[0].value = json.dumps(
|
||||||
[e.to_dict() for e in message.embeds] if message.embeds else "",
|
[e.to_dict() for e in message.embeds] if message.embeds else "",
|
||||||
indent=4,
|
indent=4,
|
||||||
)
|
)
|
||||||
modal.components[1].value = message.content
|
modal.components[1].value = message.content
|
||||||
|
|
||||||
|
# Send the modal to the user
|
||||||
await ctx.send_modal(modal)
|
await ctx.send_modal(modal)
|
||||||
|
|
||||||
|
# Wait for the user to submit the modal, and ensure that we are receiving the
|
||||||
|
# correct modal.
|
||||||
modal_ctx: ModalContext = await self.client.wait_for_modal(modal, author=ctx.author)
|
modal_ctx: ModalContext = await self.client.wait_for_modal(modal, author=ctx.author)
|
||||||
if modal_ctx.custom_id != "bot-message-create":
|
if modal_ctx.custom_id != "bot-message-edit":
|
||||||
return
|
return
|
||||||
|
|
||||||
embeds_string: str = modal_ctx.responses["embeds"]
|
embeds_string: str = modal_ctx.responses["embeds"]
|
||||||
|
|
|
@ -128,8 +128,10 @@ class Gatekeep(Extension):
|
||||||
jl, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild.id)
|
jl, _ = JoinLeaveModel.get_or_create(guild_id=ctx.guild.id)
|
||||||
await user.add_role(int(gk.gatekeep_approve_role))
|
await user.add_role(int(gk.gatekeep_approve_role))
|
||||||
|
|
||||||
|
# Check if a welcome channel is set
|
||||||
welcome_channel = not jl.message_channel is None
|
welcome_channel = not jl.message_channel is None
|
||||||
|
|
||||||
|
# If there is no approval message set, inform the issuer privately.
|
||||||
if gk.gatekeep_approve_message is None:
|
if gk.gatekeep_approve_message is None:
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
f"{user.mention} has been approved.\nNB: No approval message has been sent.",
|
f"{user.mention} has been approved.\nNB: No approval message has been sent.",
|
||||||
|
@ -137,6 +139,7 @@ class Gatekeep(Extension):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If there is no welcome channel set, attempt to DM the approval message to the user.
|
||||||
if not welcome_channel:
|
if not welcome_channel:
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
f"{user.mention} has been approved.\nNB: No welcome channel has been set – attempting to DM {user.mention}",
|
f"{user.mention} has been approved.\nNB: No welcome channel has been set – attempting to DM {user.mention}",
|
||||||
|
@ -147,6 +150,7 @@ class Gatekeep(Extension):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# DM the user if the bot fails to retrieve the welcome channel.
|
||||||
channel = await ctx.guild.fetch_channel(jl.message_channel)
|
channel = await ctx.guild.fetch_channel(jl.message_channel)
|
||||||
if not channel:
|
if not channel:
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
|
@ -158,6 +162,7 @@ class Gatekeep(Extension):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If none of the above occur, finally send the approval message to the welcome channel.
|
||||||
await channel.send(
|
await channel.send(
|
||||||
str(gk.gatekeep_approve_message).format(member=user, guild=ctx.guild)
|
str(gk.gatekeep_approve_message).format(member=user, guild=ctx.guild)
|
||||||
)
|
)
|
||||||
|
@ -212,6 +217,11 @@ class Gatekeep(Extension):
|
||||||
)
|
)
|
||||||
await ctx.send(f"{user.mention} has been approved.", ephemeral=True)
|
await ctx.send(f"{user.mention} has been approved.", ephemeral=True)
|
||||||
|
|
||||||
|
# Allow the use of a reaction to approve a user.
|
||||||
|
# This is mainly for compatibility with NLL.
|
||||||
|
# The permission thingy should probably be reworked, as it currently allows anyone
|
||||||
|
# with the manage roles permission to use this.
|
||||||
|
# TODO: Rewrite this to require a specific role.
|
||||||
@listen(events.MessageReactionAdd)
|
@listen(events.MessageReactionAdd)
|
||||||
async def on_reaction_add(self, reaction: events.MessageReactionAdd):
|
async def on_reaction_add(self, reaction: events.MessageReactionAdd):
|
||||||
if not reaction.emoji.name in [
|
if not reaction.emoji.name in [
|
||||||
|
|
|
@ -211,18 +211,18 @@ class Infractions(Extension):
|
||||||
|
|
||||||
# TODO: Add this in again when GuildSettings is implemented
|
# TODO: Add this in again when GuildSettings is implemented
|
||||||
|
|
||||||
# guild_settings: Optional[GuildSettings] = GuildSettings.get_or_none(GuildSettings.guild_id == int(ctx.guild_id))
|
guild_settings: Optional[GuildSettings] = GuildSettings.get_or_none(GuildSettings.guild_id == int(ctx.guild_id))
|
||||||
# if guild_settings is not None:
|
if guild_settings is not None:
|
||||||
# if guild_settings.admin_channel is not None:
|
if guild_settings.admin_channel is not None:
|
||||||
# admin_channel = self.client.fetch_channel(int(guild_settings.admin_channel))
|
admin_channel = self.client.fetch_channel(int(guild_settings.admin_channel))
|
||||||
# if admin_channel is not None:
|
if admin_channel is not None:
|
||||||
# await admin_channel.send(embed=Embed(
|
await admin_channel.send(embed=Embed(
|
||||||
# title=f"Warned {user.display_name} ({user.username}#{user.discriminator}, {user.id})",
|
title=f"Warned {user.display_name} ({user.username}#{user.discriminator}, {user.id})",
|
||||||
# description=f"{reason}",
|
description=f"{reason}",
|
||||||
# color=infraction_colour(0x0000FF),
|
color=infraction_colour(0x0000FF),
|
||||||
# fields=[
|
fields=[
|
||||||
# ],
|
],
|
||||||
# ))
|
))
|
||||||
|
|
||||||
if not silent and warning_msg is None:
|
if not silent and warning_msg is None:
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
from time import sleep
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
from naff import (
|
from naff import (
|
||||||
Client,
|
Client,
|
||||||
|
@ -34,12 +35,24 @@ from peewee import fn
|
||||||
PollOptions = List[Tuple[str | None, str]]
|
PollOptions = List[Tuple[str | None, str]]
|
||||||
|
|
||||||
|
|
||||||
def datetime_to_discord_time(dt: datetime) -> str:
|
def datetime_to_discord_relative_time(dt: datetime) -> str:
|
||||||
|
"""Create a Discord relative time text from a datetime object."""
|
||||||
t = dt.strftime("%s")
|
t = dt.strftime("%s")
|
||||||
return f"<t:{int(t)}:R>"
|
return f"<t:{int(t)}:R>"
|
||||||
|
|
||||||
|
|
||||||
def generate_bar(num: int, total: int, length: int = 10) -> str:
|
def generate_bar(num: int, total: int, length: int = 10) -> str:
|
||||||
|
"""Create a bar graph from a number and a total.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
num : int
|
||||||
|
The current amount.
|
||||||
|
total : int
|
||||||
|
The total amount.
|
||||||
|
length : int
|
||||||
|
The amount of characters to use for the bar.
|
||||||
|
"""
|
||||||
full_char = "\u2593"
|
full_char = "\u2593"
|
||||||
empty_char = "\u2591"
|
empty_char = "\u2591"
|
||||||
|
|
||||||
|
@ -63,11 +76,27 @@ def generate_poll_embed(
|
||||||
multiple_choice: bool = False,
|
multiple_choice: bool = False,
|
||||||
expires: datetime = None,
|
expires: datetime = None,
|
||||||
) -> Embed:
|
) -> Embed:
|
||||||
|
"""Create a poll embed from a title, options and votes.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
title : str
|
||||||
|
The title of the poll.
|
||||||
|
options : PollOptions
|
||||||
|
The options of the poll.
|
||||||
|
votes : List[int]
|
||||||
|
The votes of the poll, in the same order as the options.
|
||||||
|
multiple_choice : bool, optional
|
||||||
|
Whether the poll is multiple choice. Defaults to False.
|
||||||
|
expires : datetime, optional
|
||||||
|
The time at which the poll expires. Defaults to None.
|
||||||
|
"""
|
||||||
data = []
|
data = []
|
||||||
|
# \u2022 is a bullet point
|
||||||
if multiple_choice:
|
if multiple_choice:
|
||||||
data.append("\u2022 Multiple choice")
|
data.append("\u2022 Multiple choice")
|
||||||
if expires:
|
if expires:
|
||||||
data.append(f"\u2022 Expiry: {datetime_to_discord_time(expires)}")
|
data.append(f"\u2022 Expiry: {datetime_to_discord_relative_time(expires)}")
|
||||||
|
|
||||||
embed = Embed(
|
embed = Embed(
|
||||||
title=title,
|
title=title,
|
||||||
|
@ -75,6 +104,7 @@ def generate_poll_embed(
|
||||||
)
|
)
|
||||||
sum_votes = sum(votes)
|
sum_votes = sum(votes)
|
||||||
|
|
||||||
|
# Add a field for each option, with the vote count and a bar graph showing the percentage.
|
||||||
for i, (emoji, option) in enumerate(options):
|
for i, (emoji, option) in enumerate(options):
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=f"**{emoji if emoji else num_to_emoji(i+1)} {option}**",
|
name=f"**{emoji if emoji else num_to_emoji(i+1)} {option}**",
|
||||||
|
@ -178,22 +208,31 @@ class Polls(Extension):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Display the modal and wait for the user to submit.
|
||||||
await ctx.send_modal(modal)
|
await ctx.send_modal(modal)
|
||||||
modal_ctx: ModalContext = await self.client.wait_for_modal(
|
modal_ctx: ModalContext = await self.client.wait_for_modal(
|
||||||
modal=modal, author=ctx.author
|
modal=modal, author=ctx.author
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If the user set a duration for the poll, create a datetime for the time in the
|
||||||
|
# future when the poll expires.
|
||||||
duration: datetime | None = (
|
duration: datetime | None = (
|
||||||
(datetime.now() + timedelta(minutes=duration)) if duration else None
|
(datetime.now() + timedelta(minutes=duration)) if duration else None
|
||||||
)
|
)
|
||||||
|
|
||||||
title = modal_ctx.responses["title"]
|
title = modal_ctx.responses["title"]
|
||||||
options: PollOptions = []
|
options: PollOptions = []
|
||||||
|
|
||||||
|
# Get each option from the poll options.
|
||||||
|
# We're stripping the first occurance of a dash, as it otherwise would be included.
|
||||||
for i, option in enumerate(
|
for i, option in enumerate(
|
||||||
modal_ctx.responses["options"].replace("-", "", 1).split("\n-")
|
modal_ctx.responses["options"].replace("-", "", 1).split("\n-")
|
||||||
):
|
):
|
||||||
|
# If the option is empty, skip it.
|
||||||
if option == "":
|
if option == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Check if the option contains an emoji.
|
||||||
parts = option.split(":", 1)
|
parts = option.split(":", 1)
|
||||||
if len(parts) == 1:
|
if len(parts) == 1:
|
||||||
options.append((None, parts[0].strip()))
|
options.append((None, parts[0].strip()))
|
||||||
|
@ -217,6 +256,7 @@ class Polls(Extension):
|
||||||
expires=duration,
|
expires=duration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create vote buttons for each option.
|
||||||
buttons: List[Button] = []
|
buttons: List[Button] = []
|
||||||
for i, option in enumerate(options):
|
for i, option in enumerate(options):
|
||||||
try:
|
try:
|
||||||
|
@ -231,6 +271,7 @@ class Polls(Extension):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create a button to allow locking the poll.
|
||||||
buttons.append(
|
buttons.append(
|
||||||
Button(
|
Button(
|
||||||
emoji="🔒",
|
emoji="🔒",
|
||||||
|
@ -251,9 +292,11 @@ class Polls(Extension):
|
||||||
embed=embed,
|
embed=embed,
|
||||||
components=spread_to_rows(*buttons),
|
components=spread_to_rows(*buttons),
|
||||||
)
|
)
|
||||||
|
# Naive error handling.
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
logging.error(f"Error sending poll: {e}")
|
logging.error(f"Error sending poll: {e}")
|
||||||
await modal_ctx.send(
|
await modal_ctx.send(
|
||||||
|
# TODO: This should probably also include a sample for poll options.
|
||||||
"Error during poll creation. NB: You can not use server-specific emojis",
|
"Error during poll creation. NB: You can not use server-specific emojis",
|
||||||
ephemeral=True,
|
ephemeral=True,
|
||||||
)
|
)
|
||||||
|
@ -268,7 +311,9 @@ class Polls(Extension):
|
||||||
ctx = button.context
|
ctx = button.context
|
||||||
await ctx.defer(ephemeral=True)
|
await ctx.defer(ephemeral=True)
|
||||||
|
|
||||||
|
# Ensure that the pressed button is a vote button.
|
||||||
if ctx.custom_id.startswith("poll-vote:"):
|
if ctx.custom_id.startswith("poll-vote:"):
|
||||||
|
# Get the poll ID and the option index.
|
||||||
poll_id, option_num = ctx.custom_id.split(":", 1)[1].split(":", 1)
|
poll_id, option_num = ctx.custom_id.split(":", 1)[1].split(":", 1)
|
||||||
|
|
||||||
poll_entry: PollsModel | None = PollsModel.get_or_none(
|
poll_entry: PollsModel | None = PollsModel.get_or_none(
|
||||||
|
@ -285,9 +330,12 @@ class Polls(Extension):
|
||||||
PollVotesModel.poll_id == poll_id,
|
PollVotesModel.poll_id == poll_id,
|
||||||
PollVotesModel.user_id == ctx.author.id,
|
PollVotesModel.user_id == ctx.author.id,
|
||||||
)
|
)
|
||||||
|
# If the user somehow already has more than one vote, delete them.
|
||||||
|
# This should never happen, but just in case.
|
||||||
|
# Then, add the new vote.
|
||||||
if votes_q.count() > 1:
|
if votes_q.count() > 1:
|
||||||
for vote in votes_q:
|
for vote in votes_q:
|
||||||
vote.delete().execute()
|
vote.delete_instance()
|
||||||
|
|
||||||
PollVotesModel.create(
|
PollVotesModel.create(
|
||||||
poll_id=poll_id,
|
poll_id=poll_id,
|
||||||
|
@ -295,13 +343,16 @@ class Polls(Extension):
|
||||||
option=option_num,
|
option=option_num,
|
||||||
)
|
)
|
||||||
elif votes_q.count() == 1:
|
elif votes_q.count() == 1:
|
||||||
|
# If the vote is the current vote, delete it.
|
||||||
if int(votes_q[0].option) == int(option_num):
|
if int(votes_q[0].option) == int(option_num):
|
||||||
votes_q[0].delete_instance()
|
votes_q[0].delete_instance()
|
||||||
await ctx.send("You have removed your vote.")
|
await ctx.send("You have removed your vote.")
|
||||||
|
# If it's not the current vote, change the vote to the new one.
|
||||||
else:
|
else:
|
||||||
votes_q[0].option = option_num
|
votes_q[0].option = option_num
|
||||||
votes_q[0].save()
|
votes_q[0].save()
|
||||||
await ctx.send("You have changed your vote.")
|
await ctx.send("You have changed your vote.")
|
||||||
|
#If the user has no votes, add a new vote.
|
||||||
else:
|
else:
|
||||||
PollVotesModel.create(
|
PollVotesModel.create(
|
||||||
poll_id=poll_id,
|
poll_id=poll_id,
|
||||||
|
@ -309,12 +360,15 @@ class Polls(Extension):
|
||||||
option=option_num,
|
option=option_num,
|
||||||
)
|
)
|
||||||
await ctx.send("You have voted.")
|
await ctx.send("You have voted.")
|
||||||
|
|
||||||
|
# If the poll is multiple choice
|
||||||
else:
|
else:
|
||||||
votes_q: List[PollVotesModel] = PollVotesModel.select().where(
|
votes_q: List[PollVotesModel] = PollVotesModel.select().where(
|
||||||
PollVotesModel.poll_id == poll_id,
|
PollVotesModel.poll_id == poll_id,
|
||||||
PollVotesModel.user_id == ctx.author.id,
|
PollVotesModel.user_id == ctx.author.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If the user has already voted for this option, remove their vote.
|
||||||
exists = False
|
exists = False
|
||||||
for vote in votes_q:
|
for vote in votes_q:
|
||||||
if int(vote.option) == (option_num):
|
if int(vote.option) == (option_num):
|
||||||
|
@ -323,6 +377,7 @@ class Polls(Extension):
|
||||||
await ctx.send("You have removed your vote.")
|
await ctx.send("You have removed your vote.")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# If the user has not voted for this option, add a new vote.
|
||||||
if not exists:
|
if not exists:
|
||||||
PollVotesModel.create(
|
PollVotesModel.create(
|
||||||
poll_id=poll_id,
|
poll_id=poll_id,
|
||||||
|
@ -360,8 +415,10 @@ class Polls(Extension):
|
||||||
expires=poll_entry.expires,
|
expires=poll_entry.expires,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Edit the message with the updated information.
|
||||||
await ctx.message.edit(embed=embed)
|
await ctx.message.edit(embed=embed)
|
||||||
|
|
||||||
|
# If the "lock poll" button is pressed, lock the poll.
|
||||||
elif ctx.custom_id.startswith("poll-lock:"):
|
elif ctx.custom_id.startswith("poll-lock:"):
|
||||||
poll_id = ctx.custom_id.split(":", 1)[1]
|
poll_id = ctx.custom_id.split(":", 1)[1]
|
||||||
|
|
||||||
|
@ -372,18 +429,22 @@ class Polls(Extension):
|
||||||
await ctx.send("That poll doesn't exist.")
|
await ctx.send("That poll doesn't exist.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Ensure that the user is the poll creator, or can manage messages.
|
||||||
if not ctx.author.id == int(
|
if not ctx.author.id == int(
|
||||||
poll_entry.author_id
|
poll_entry.author_id
|
||||||
) or not ctx.author.has_permission(Permissions.MANAGE_MESSAGES):
|
) or not ctx.author.has_permission(Permissions.MANAGE_MESSAGES):
|
||||||
await ctx.send("You don't have permission to lock that poll.")
|
await ctx.send("You don't have permission to lock that poll.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Set the poll to be expired to lock it.
|
||||||
poll_entry.expires = datetime.now() - timedelta(minutes=1)
|
poll_entry.expires = datetime.now() - timedelta(minutes=1)
|
||||||
poll_entry.save()
|
poll_entry.save()
|
||||||
|
|
||||||
|
# Force the "poll expiry check" task to run.
|
||||||
await self.poll_expiry_check()
|
await self.poll_expiry_check()
|
||||||
await ctx.send("Poll locked.")
|
await ctx.send("Poll locked.")
|
||||||
|
|
||||||
|
# A task that runs each minute to check for expired polls.
|
||||||
@Task.create(IntervalTrigger(minutes=1))
|
@Task.create(IntervalTrigger(minutes=1))
|
||||||
async def poll_expiry_check(self):
|
async def poll_expiry_check(self):
|
||||||
logging.info("Checking for expired polls.")
|
logging.info("Checking for expired polls.")
|
||||||
|
@ -402,6 +463,10 @@ class Polls(Extension):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await message.edit(components=[])
|
await message.edit(components=[])
|
||||||
|
# Delete associated database entries, as they will no longer be updated.
|
||||||
|
PollVotesModel.delete().where(PollVotesModel.poll_id == poll_entry.id).execute()
|
||||||
|
poll_entry.delete_instance()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def setup(client: Client):
|
def setup(client: Client):
|
||||||
|
|
Loading…
Reference in New Issue