# Heimdallr/heimdallr/commands/polls.py
# (Captured from the GitHub file view: 479 lines, 16 KiB, Python.)
#
# NOTE: GitHub warned that this file contains invisible Unicode characters that
# are indistinguishable to humans but may be processed differently by a computer.

# pylint: disable=not-an-iterable
# pylint: disable=unsubscriptable-object
# pylint: disable=logging-fstring-interpolation
from datetime import datetime, timedelta
import logging
import json
from typing import List, Tuple
from naff import (
Client,
Extension,
slash_command,
slash_option,
InteractionContext,
OptionTypes,
Permissions,
Modal,
ShortText,
ParagraphText,
ModalContext,
Button,
ButtonStyles,
Embed,
spread_to_rows,
Message,
listen,
Task,
IntervalTrigger,
GuildText,
PartialEmoji,
)
from naff.api import events
from naff.client.errors import HTTPException
from peewee import fn
from database import Polls as PollsModel, PollVotes as PollVotesModel
PollOptions = List[Tuple[str | None, str]]
def datetime_to_discord_relative_time(dt: datetime) -> str:
    """Create a Discord relative time markup (``<t:UNIX_SECONDS:R>``) from a datetime.

    Parameters
    ----------
    dt : datetime
        The point in time to render. Naive datetimes are interpreted as local
        time; aware datetimes are converted using their own timezone.

    Returns
    -------
    str
        The Discord timestamp markup in relative ("R") style.
    """
    # strftime("%s") is a non-portable glibc extension (unsupported on Windows)
    # and ignores tzinfo; datetime.timestamp() is the documented equivalent.
    return f"<t:{int(dt.timestamp())}:R>"
def generate_bar(num: int, total: int, length: int = 10) -> str:
    """Render ``num`` out of ``total`` as a text bar followed by a percentage.

    Parameters
    ----------
    num : int
        The current amount.
    total : int
        The total amount.
    length : int
        The amount of characters to use for the bar.

    Returns
    -------
    str
        Dark-shade characters (U+2593) for the filled share and light-shade
        characters (U+2591) for the remainder, followed by the percentage in
        parentheses with a ``.0`` decimal dropped. If ``total`` is not positive,
        only ``length`` light-shade characters are returned, without a percentage.
    """
    dark, light = "\u2593", "\u2591"
    if total <= 0:
        # Nothing to divide by: draw an empty bar and omit the percentage.
        return light * length
    share = num / total
    filled = round(share * length)
    # The space flag in the format spec pads positive numbers with a blank
    # instead of a sign; replace() then turns e.g. "50.0" into "50".
    percentage = f" ({share * 100: 3.1f}%)".replace(".0", "")
    return "".join([dark * filled, light * (length - filled), percentage])
def generate_poll_embed(
    title: str,
    options: PollOptions,
    votes: List[int],
    *,
    multiple_choice: bool = False,
    expires: datetime | None = None,
) -> Embed:
    """Create a poll embed from a title, options and votes.

    Parameters
    ----------
    title : str
        The title of the poll.
    options : PollOptions
        The options of the poll as (emoji or None, text) pairs.
    votes : List[int]
        The votes of the poll, in the same order as the options.
    multiple_choice : bool, optional
        Whether the poll is multiple choice. Defaults to False.
    expires : datetime, optional
        The time at which the poll expires. Defaults to None.

    Returns
    -------
    Embed
        An embed whose description lists the poll's flags (multiple choice,
        expiry) and which has one field per option. Each field shows a bar
        graph of the option's share of all votes and its vote count.
    """
    flags = []
    # \u2022 is a bullet point
    if multiple_choice:
        flags.append("\u2022 Multiple choice")
    if expires:
        flags.append(f"\u2022 Expiry: {datetime_to_discord_relative_time(expires)}")
    embed = Embed(
        title=title,
        description=("\n".join(flags) if flags else None),
    )
    total = sum(votes)
    for i, (emoji, option) in enumerate(options):
        count = votes[i]
        # Options without their own emoji fall back to the 1-based number emoji.
        label = emoji if emoji else num_to_emoji(i + 1)
        noun = "vote" if count == 1 else "votes"
        embed.add_field(
            name=f"**{label} {option}**",
            value=f"{generate_bar(count, total)}\n{count} {noun}",
            inline=False,
        )
    return embed
# Number emojis for options without their own emoji. 0-9 are keycap sequences:
# "<digit>" + U+FE0F VARIATION SELECTOR-16 + U+20E3 COMBINING ENCLOSING KEYCAP.
# 10 is the single KEYCAP TEN character. They are written as escapes so the
# source contains no invisible characters that editors or copy/paste can drop.
_NUMBER_EMOJIS = {n: f"{n}\ufe0f\u20e3" for n in range(10)}
_NUMBER_EMOJIS[10] = "\U0001f51f"


def num_to_emoji(num: int) -> str:
    """Return the keycap emoji representing ``num``.

    Parameters
    ----------
    num : int
        A number between 0 and 10 inclusive.

    Returns
    -------
    str
        The keycap emoji for ``num`` (e.g. 1 -> 1️⃣, 10 -> 🔟).

    Raises
    ------
    ValueError
        If ``num`` is outside 0 <= num <= 10.
    """
    try:
        return _NUMBER_EMOJIS[num]
    except (KeyError, TypeError):
        raise ValueError("Invalid number: `num` must be 0 <= num <= 10.") from None
class Polls(Extension):
    """Poll slash command, vote/lock button handlers and the poll expiry task.

    Polls are stored in ``PollsModel``, with their options serialized as JSON.
    Individual votes are stored in ``PollVotesModel``. Vote buttons carry the
    custom id ``poll-vote:<poll id>:<option index>``; the lock button carries
    ``poll-lock:<poll id>``.
    """

    def __init__(self, client: Client):
        self.client = client

    @listen(events.Ready)
    async def on_ready(self):
        """Close any already-expired polls, then start the periodic expiry check."""
        await self.poll_expiry_check()
        self.poll_expiry_check.start()  # pylint: disable=no-member

    @slash_command(
        name="polls",
        description="Polls",
        sub_cmd_name="create",
        sub_cmd_description="Create a poll",
        dm_permission=False,
        default_member_permissions=Permissions.SEND_MESSAGES,
    )
    @slash_option(
        name="title",
        description="Title of the poll",
        required=True,
        opt_type=OptionTypes.STRING,
    )
    @slash_option(
        name="duration",
        description="Duration of the poll in minutes",
        required=False,
        opt_type=OptionTypes.INTEGER,
    )
    @slash_option(
        name="multiple-choice",
        description="If users can vote for multiple options",
        required=False,
        opt_type=OptionTypes.BOOLEAN,
    )
    async def create_poll(  # pylint: disable=too-many-locals
        self,
        ctx: InteractionContext,
        *,
        title: str,
        duration: int | None = None,
        multiple_choice: bool = False,
    ):
        """Ask for the poll details in a modal, then post the poll message.

        ``duration`` is in minutes. When it is omitted or 0, the poll has no
        expiry. A negative duration is rejected.
        """
        # A negative duration would create a poll that is already expired.
        if duration is not None and duration < 0:
            await ctx.send("The duration can't be negative.", ephemeral=True)
            return

        modal = Modal(
            title="Creating poll",
            components=[
                ShortText(
                    custom_id="title",
                    label="Title",
                    value=title,
                    required=True,
                    max_length=120,
                ),
                ParagraphText(
                    custom_id="options",
                    label="Poll options",
                    placeholder=(
                        "Add poll options here.\n\n" "- ✅: Yes\n" "- ❌: No\n" "- Unsure"
                    ),
                    required=True,
                    max_length=1200,
                ),
            ],
        )
        # Display the modal and wait for the user to submit.
        await ctx.send_modal(modal)
        modal_ctx: ModalContext = await self.client.wait_for_modal(
            modal=modal, author=ctx.author
        )

        # Absolute time at which the poll expires, or None for no expiry.
        expires: datetime | None = (
            (datetime.now() + timedelta(minutes=duration)) if duration else None
        )
        title = modal_ctx.responses["title"]
        options = self._parse_options(modal_ctx.responses["options"])
        if len(options) > 10:
            await modal_ctx.send("You can only have up to 10 options.", ephemeral=True)
            return
        if len(options) < 2:
            await modal_ctx.send("You must have at least 2 options.", ephemeral=True)
            return

        poll_entry: PollsModel = PollsModel.create(
            guild_id=ctx.guild.id,
            author_id=ctx.author.id,
            title=title,
            options=json.dumps(options),
            no_options=len(options),
            multiple_choice=multiple_choice,
            expires=expires,
        )

        # Create a vote button for each option.
        buttons: List[Button] = []
        for i, (emoji_text, _) in enumerate(options):
            # Emojis that cannot be parsed fall back to the option's number emoji.
            try:
                emoji = PartialEmoji.from_str(emoji_text or num_to_emoji(i + 1))
            except ValueError:
                emoji = num_to_emoji(i + 1)
            buttons.append(
                Button(
                    emoji=emoji,
                    style=ButtonStyles.PRIMARY,
                    custom_id=f"poll-vote:{poll_entry.id}:{i}",
                )
            )
        # Create a button to allow locking the poll.
        buttons.append(
            Button(
                emoji="🔒",
                label="Lock",
                style=ButtonStyles.DANGER,
                custom_id=f"poll-lock:{poll_entry.id}",
            )
        )

        embed = generate_poll_embed(
            title,
            options,
            len(options) * [0],
            multiple_choice=multiple_choice,
            expires=expires,
        )
        try:
            poll_message: Message = await modal_ctx.send(
                embed=embed,
                components=spread_to_rows(*buttons),
            )
        # Naive error handling.
        except HTTPException as e:
            logging.error(f"Error sending poll: {e}")
            # The poll was never posted. Drop its row so that no entry without a
            # channel/message id is left for the expiry check to trip over.
            poll_entry.delete_instance()
            await modal_ctx.send(
                # TODO: This should probably also include a sample for poll options.
                "Error during poll creation. NB: You can not use server-specific emojis",
                ephemeral=True,
            )
            return

        poll_entry.message_id = poll_message.id
        poll_entry.channel_id = poll_message.channel.id
        poll_entry.save()

    @staticmethod
    def _parse_options(raw: str) -> PollOptions:
        """Parse the modal's options text into (emoji or None, text) pairs.

        Options are separated by lines starting with ``-``. Text before the first
        ``:`` of an option is treated as its emoji (e.g. ``- <emoji>: Yes``).
        Blank options are skipped.
        """
        options: PollOptions = []
        # Only a *leading* dash opens the first option; dashes inside option
        # text must survive.
        body = raw.strip().removeprefix("-")
        for chunk in body.split("\n-"):
            if not chunk.strip():
                continue
            head, sep, tail = chunk.partition(":")
            if sep:
                options.append((head.strip(), tail.strip()))
            else:
                options.append((None, head.strip()))
        return options

    @listen(events.ButtonPressed)
    async def on_button(self, button: events.ButtonPressed):
        """Dispatch poll vote and lock button presses; other buttons are ignored."""
        ctx = button.ctx
        if ctx.custom_id.startswith("poll-vote:"):
            await self._handle_vote(ctx)
        elif ctx.custom_id.startswith("poll-lock:"):
            await self._handle_lock(ctx)

    async def _handle_vote(self, ctx) -> None:
        """Record, change or remove the user's vote, then refresh the poll embed."""
        # custom_id format: "poll-vote:<poll id>:<option index>"
        _, poll_id, option_str = ctx.custom_id.split(":", 2)
        option_num = int(option_str)
        poll_entry: PollsModel | None = PollsModel.get_or_none(
            guild_id=ctx.guild.id, id=poll_id
        )
        # Always answer the interaction, otherwise Discord reports it as failed.
        if not poll_entry:
            await ctx.send("That poll doesn't exist.", ephemeral=True)
            return
        if poll_entry.expires and datetime.now() > poll_entry.expires:
            await ctx.send("This poll has expired.", ephemeral=True)
            return

        await ctx.defer(ephemeral=True)
        if poll_entry.multiple_choice:
            reply = self._toggle_multiple_choice_vote(poll_id, ctx.author.id, option_num)
        else:
            reply = self._set_single_choice_vote(poll_id, ctx.author.id, option_num)
        await ctx.send(reply)

        # Edit the message with the updated information.
        await ctx.message.edit(embed=self._build_embed(poll_entry))

    @staticmethod
    def _set_single_choice_vote(poll_id, user_id, option_num: int) -> str:
        """Apply a vote on a single-choice poll and return the reply for the user.

        Pressing the option the user already voted for removes the vote. Pressing
        a different option moves the vote there.
        """
        existing: List[PollVotesModel] = list(
            PollVotesModel.select().where(
                PollVotesModel.poll_id == poll_id,
                PollVotesModel.user_id == user_id,
            )
        )
        if len(existing) == 1:
            vote = existing[0]
            if int(vote.option) == option_num:
                vote.delete_instance()
                return "You have removed your vote."
            vote.option = option_num
            vote.save()
            return "You have changed your vote."
        # More than one vote should never happen on a single-choice poll; if it
        # does, discard them all before recording the new vote.
        for vote in existing:
            vote.delete_instance()
        PollVotesModel.create(poll_id=poll_id, user_id=user_id, option=option_num)
        return "You have voted."

    @staticmethod
    def _toggle_multiple_choice_vote(poll_id, user_id, option_num: int) -> str:
        """Toggle the user's vote for one option of a multiple-choice poll.

        Returns the reply for the user.
        """
        existing: List[PollVotesModel] = PollVotesModel.select().where(
            PollVotesModel.poll_id == poll_id,
            PollVotesModel.user_id == user_id,
        )
        for vote in existing:
            # Both sides are ints, so this matches however the column is typed.
            if int(vote.option) == option_num:
                vote.delete_instance()
                return "You have removed your vote."
        PollVotesModel.create(poll_id=poll_id, user_id=user_id, option=option_num)
        return "You have voted."

    @staticmethod
    def _build_embed(poll_entry: PollsModel) -> Embed:
        """Re-render a poll's embed from the current vote counts in the database."""
        tally_q = (
            PollVotesModel.select(
                PollVotesModel.option,
                fn.COUNT(PollVotesModel.option).alias("count"),
            )
            .where(PollVotesModel.poll_id == poll_entry.id)
            .group_by(PollVotesModel.option)
        )
        # Options live as JSON on the poll row, so the per-option counts are
        # placed into a list indexed by option number.
        options: PollOptions = json.loads(poll_entry.options)
        votes = len(options) * [0]
        for row in tally_q:
            votes[int(row.option)] = row.count
        return generate_poll_embed(
            poll_entry.title,
            options,
            votes,
            multiple_choice=poll_entry.multiple_choice,
            expires=poll_entry.expires,
        )

    async def _handle_lock(self, ctx) -> None:
        """Lock a poll for its creator or for a member who can manage messages."""
        poll_id = ctx.custom_id.split(":", 1)[1]
        poll_entry: PollsModel | None = PollsModel.get_or_none(
            guild_id=ctx.guild.id, id=poll_id
        )
        if not poll_entry:
            await ctx.send("That poll doesn't exist.")
            return

        # Either condition is sufficient to lock the poll.
        is_creator = ctx.author.id == int(poll_entry.author_id)
        if not (is_creator or ctx.author.has_permission(Permissions.MANAGE_MESSAGES)):
            await ctx.send("You don't have permission to lock that poll.")
            return

        # Set the poll to be expired to lock it.
        poll_entry.expires = datetime.now() - timedelta(minutes=1)
        poll_entry.save()
        # Force the "poll expiry check" task to run.
        await self.poll_expiry_check()
        await ctx.send("Poll locked.")

    # A task that runs each minute to check for expired polls.
    @Task.create(IntervalTrigger(minutes=1))
    async def poll_expiry_check(self):
        """Close every poll whose expiry time has passed.

        The buttons are removed from the poll message when it can still be
        reached. The poll's rows are always deleted, so a vanished channel or
        message cannot keep a poll in the expiry query forever.
        """
        logging.info("Checking for expired polls.")
        now = datetime.now()
        expired: List[PollsModel] = list(
            PollsModel.select().where(PollsModel.expires < now)
        )
        for poll_entry in expired:
            await self._close_poll_message(poll_entry)
            # Delete associated database entries, as they will no longer be updated.
            PollVotesModel.delete().where(
                PollVotesModel.poll_id == poll_entry.id
            ).execute()
            poll_entry.delete_instance()

    async def _close_poll_message(self, poll_entry: PollsModel) -> None:
        """Best-effort removal of the vote and lock buttons from a poll's message."""
        # Polls whose message was never posted have nothing to edit.
        if not poll_entry.channel_id or not poll_entry.message_id:
            return
        try:
            channel: GuildText = await self.client.fetch_channel(
                int(poll_entry.channel_id)
            )
            if not channel:
                return
            message: Message = await channel.fetch_message(int(poll_entry.message_id))
            if not message:
                return
            await message.edit(components=[])
        except HTTPException as e:
            # One unreachable message must not abort the whole expiry run.
            logging.warning(f"Could not close poll {poll_entry.id}: {e}")
def setup(client: Client):
    """Extension entry point: create the poll tables, then register ``Polls``."""
    # Same order as before: the polls table first, then the votes table.
    for model in (PollsModel, PollVotesModel):
        model.create_table()
    Polls(client)
    logging.info("Polls extension loaded")