Heimdallr/commands/polls.py

397 lines
12 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime, timedelta
import logging
import json
from typing import Dict, List, Tuple
from naff import (
Client,
Extension,
slash_command,
slash_option,
InteractionContext,
OptionTypes,
Permissions,
Modal,
ShortText,
ParagraphText,
ModalContext,
Button,
ButtonStyles,
Embed,
spread_to_rows,
Message,
listen,
Task,
IntervalTrigger,
GuildText,
)
from naff.api import events
from database import Polls as PollsModel, PollVotes as PollVotesModel
from peewee import fn
PollOptions = List[Tuple[str | None, str]]
def datetime_to_discord_time(dt: datetime, style: str = "R") -> str:
    """Format *dt* as Discord timestamp markup, e.g. ``<t:1577836800:R>``.

    Uses ``dt.timestamp()`` rather than ``strftime("%s")``. ``%s`` is a
    platform-specific strftime extension: it is unavailable on Windows and
    does not honour ``tzinfo``. Naive datetimes are still interpreted as
    local time.

    Args:
        dt: the moment to render.
        style: Discord timestamp style letter. Defaults to ``"R"``, the
            style this module has always produced.

    Returns:
        The ``<t:<unix seconds>:<style>>`` markup string.
    """
    return f"<t:{int(dt.timestamp())}:{style}>"
def generate_bar(num: int, total: int, length: int = 10) -> str:
    """Render *num* out of *total* as a text progress bar with a percentage.

    The bar is *length* shade characters: dark (U+2593) for the rounded
    filled share and light (U+2591) for the rest. It is followed by a
    ``" (NN.N%)"`` label (formatted with ``" 3.1f"``) whose ``.0`` is
    dropped. When *total* is not positive, only an all-light bar is returned.
    """
    if total <= 0:
        return "\u2591" * length
    share = num / total
    filled = round(share * length)
    label = f"{share * 100: 3.1f}".replace(".0", "")
    return "".join(("\u2593" * filled, "\u2591" * (length - filled), f" ({label}%)"))
def generate_poll_embed(
    title: str,
    options: PollOptions,
    votes: List[int],
    *,
    multiple_choice: bool = False,
    expires: datetime | None = None,
) -> Embed:
    """Build the embed that shows a poll and its current results.

    Args:
        title: poll title, used as the embed title.
        options: ``(emoji, text)`` pairs in display order. An option without
            an emoji is labelled with its 1-based number emoji.
        votes: vote count per option, in the same order as *options*
            (expected to have the same length).
        multiple_choice: add a "Multiple choice" line to the description.
        expires: when set, add an expiry line with a Discord timestamp.

    Returns:
        An embed with one non-inline field per option: a result bar plus the
        vote count.
    """
    details = []
    if multiple_choice:
        details.append("\u2022 Multiple choice")
    if expires:
        details.append(f"\u2022 Expiry: {datetime_to_discord_time(expires)}")
    embed = Embed(
        title=title,
        description=("\n".join(details) if details else None),
    )
    # Every bar's percentage is relative to all votes cast on the poll.
    total_votes = sum(votes)
    for number, ((emoji, option), count) in enumerate(zip(options, votes), start=1):
        noun = "vote" if count == 1 else "votes"
        embed.add_field(
            name=f"**{emoji or num_to_emoji(number)} {option}**",
            value=f"{generate_bar(count, total_votes)}\n{count} {noun}",
            inline=False,
        )
    return embed
# Index n holds the emoji for n. 0-9 are keycap sequences
# (digit + U+FE0F VARIATION SELECTOR-16 + U+20E3 COMBINING ENCLOSING KEYCAP);
# 10 is U+1F51F KEYCAP TEN. They are written as escapes so the invisible
# code points cannot silently disappear from the source.
_NUMBER_EMOJIS = tuple(f"{digit}\ufe0f\u20e3" for digit in range(10)) + ("\U0001f51f",)


def num_to_emoji(num: int) -> str:
    """Return the keycap emoji for *num*, where ``0 <= num <= 10``.

    The result is used as the button emoji and as the embed label prefix for
    poll options that have no emoji of their own.

    Raises:
        ValueError: if *num* is not an int in ``0..10``.
    """
    if isinstance(num, int) and 0 <= num < len(_NUMBER_EMOJIS):
        return _NUMBER_EMOJIS[num]
    raise ValueError(f"Invalid number {num!r}: `num` must be 0 <= num <= 10.")
def _parse_poll_options(raw: str) -> PollOptions:
    """Parse the "Poll options" modal field into ``(emoji, text)`` pairs.

    Each option starts on a line beginning with ``-``. Following lines that
    do not start with ``-`` stay part of the same option. When an option
    contains a colon, the text before the first colon is its emoji
    (``"- ✅: Yes"`` -> ``("✅", "Yes")``); otherwise the emoji is ``None``.
    Blank options are skipped.

    NOTE(review): the colon split also fires on text such as ``"10:30"`` and
    on custom emoji markup (``<:name:id>``); a stricter emoji check may be
    wanted.
    """
    # Only a hyphen at the very start is a bullet (the first option's);
    # hyphens inside option text must be kept.
    body = raw.strip().removeprefix("-")
    options: PollOptions = []
    for chunk in body.split("\n-"):
        chunk = chunk.strip()
        if not chunk:
            continue
        emoji, colon, text = chunk.partition(":")
        if colon:
            options.append((emoji.strip() or None, text.strip()))
        else:
            options.append((None, chunk))
    return options


class Polls(Extension):
    """Poll feature: ``/polls create``, the vote/delete buttons and expiry.

    Button custom IDs encode their target:
    ``poll-vote:<poll id>:<option index>`` and ``poll-delete:<poll id>``.
    """

    def __init__(self, client: Client):
        self.client = client
        # End of the time window handled by the previous expiry check.
        # None until the first check has run; that first run therefore
        # closes every poll that has already expired.
        self._last_expiry_check: datetime | None = None

    @listen(events.Ready)
    async def on_ready(self):
        """Run one expiry check immediately, then start the periodic task."""
        await self.poll_expiry_check()
        self.poll_expiry_check.start()

    @slash_command(
        name="polls",
        description="Polls",
        sub_cmd_name="create",
        sub_cmd_description="Create a poll",
        dm_permission=False,
        default_member_permissions=Permissions.SEND_MESSAGES,
    )
    @slash_option(
        name="title",
        description="Title of the poll",
        required=True,
        opt_type=OptionTypes.STRING,
    )
    @slash_option(
        name="duration",
        description="Duration of the poll in minutes",
        required=False,
        opt_type=OptionTypes.INTEGER,
    )
    @slash_option(
        name="multiple-choice",
        description="If users can vote for multiple options",
        required=False,
        opt_type=OptionTypes.BOOLEAN,
    )
    async def create_poll(
        self,
        ctx: InteractionContext,
        *,
        title: str,
        duration: int | None = None,
        multiple_choice: bool = False,
    ):
        """Collect the poll options through a modal, then post the poll.

        Args:
            ctx: the slash command context.
            title: initial poll title (still editable in the modal).
            duration: minutes until the poll expires; None or 0 means never.
            multiple_choice: whether a user may vote for several options.
        """
        if duration is not None and duration < 0:
            await ctx.send("The duration can't be negative.", ephemeral=True)
            return

        modal = Modal(
            title="Creating poll",
            components=[
                ShortText(
                    custom_id="title",
                    label="Title",
                    value=title,
                    required=True,
                    max_length=120,
                ),
                ParagraphText(
                    custom_id="options",
                    label="Poll options",
                    placeholder=(
                        "Add poll options here.\n\n" "- ✅: Yes\n" "- ❌: No\n" "- Unsure"
                    ),
                    required=True,
                    max_length=1200,
                ),
            ],
        )
        await ctx.send_modal(modal)
        modal_ctx: ModalContext = await self.client.wait_for_modal(
            modal=modal, author=ctx.author
        )

        # The poll's lifetime starts when the modal is submitted.
        try:
            expires: datetime | None = (
                (datetime.now() + timedelta(minutes=duration)) if duration else None
            )
        except OverflowError:
            await modal_ctx.send("That duration is too long.", ephemeral=True)
            return

        title = modal_ctx.responses["title"]
        options = _parse_poll_options(modal_ctx.responses["options"])
        # num_to_emoji only provides number emojis for 1-10.
        if len(options) > 10:
            await modal_ctx.send("You can only have up to 10 options.", ephemeral=True)
            return
        if len(options) < 2:
            await modal_ctx.send("You must have at least 2 options.", ephemeral=True)
            return

        poll_entry: PollsModel = PollsModel.create(
            guild_id=ctx.guild.id,
            author_id=ctx.author.id,
            title=title,
            options=json.dumps(options),
            no_options=len(options),
            multiple_choice=multiple_choice,
            expires=expires,
        )

        buttons: List[Button] = [
            Button(
                emoji=(emoji or num_to_emoji(index + 1)),
                style=ButtonStyles.PRIMARY,
                custom_id=f"poll-vote:{poll_entry.id}:{index}",
            )
            for index, (emoji, _text) in enumerate(options)
        ]
        buttons.append(
            Button(
                label="Delete",
                style=ButtonStyles.DANGER,
                custom_id=f"poll-delete:{poll_entry.id}",
            )
        )
        embed = generate_poll_embed(
            title,
            options,
            len(options) * [0],
            multiple_choice=multiple_choice,
            expires=expires,
        )
        poll_message: Message = await modal_ctx.send(
            embed=embed,
            components=spread_to_rows(*buttons),
        )
        # Remember where the poll lives so the expiry task can find it later.
        poll_entry.message_id = poll_message.id
        poll_entry.channel_id = poll_message.channel.id
        poll_entry.save()

    @listen(events.Button)
    async def on_button(self, button: events.Button):
        """Dispatch poll vote/delete button presses; ignore all other buttons."""
        ctx = button.context
        # Defer only our own buttons. Acknowledging every button here would
        # swallow interactions that belong to other extensions.
        if ctx.custom_id.startswith("poll-vote:"):
            await ctx.defer(ephemeral=True)
            await self._handle_vote(ctx)
        elif ctx.custom_id.startswith("poll-delete:"):
            await ctx.defer(ephemeral=True)
            await self._handle_delete(ctx)

    async def _handle_vote(self, ctx: InteractionContext) -> None:
        """Add, change or withdraw the user's vote, then refresh the poll embed."""
        poll_id, option_str = ctx.custom_id.split(":", 2)[1:]
        option_index = int(option_str)
        poll_entry: PollsModel | None = PollsModel.get_or_none(
            guild_id=ctx.guild.id, id=poll_id
        )
        if not poll_entry:
            await ctx.send("That poll doesn't exist.")
            return
        if poll_entry.expires and datetime.now() > poll_entry.expires:
            await ctx.send("This poll has expired.")
            return
        options: PollOptions = json.loads(poll_entry.options)
        if not 0 <= option_index < len(options):
            await ctx.send("That option doesn't exist.")
            return

        user_votes: List[PollVotesModel] = list(
            PollVotesModel.select().where(
                PollVotesModel.poll_id == poll_id,
                PollVotesModel.user_id == ctx.author.id,
            )
        )
        same_option = [v for v in user_votes if int(v.option) == option_index]

        if poll_entry.multiple_choice:
            # Each option toggles independently.
            if same_option:
                for vote in same_option:
                    vote.delete_instance()
                reply = "You have removed your vote."
            else:
                PollVotesModel.create(
                    poll_id=poll_id, user_id=ctx.author.id, option=option_str
                )
                reply = "You have voted."
        elif len(user_votes) == 1:
            vote = user_votes[0]
            if same_option:
                vote.delete_instance()
                reply = "You have removed your vote."
            else:
                vote.option = option_str
                vote.save()
                reply = "You have changed your vote."
        else:
            # No vote yet, or several votes on a single-choice poll (an
            # inconsistent state): replace them with this one. The delete is
            # scoped to this user and poll, never the whole table.
            if user_votes:
                PollVotesModel.delete().where(
                    PollVotesModel.poll_id == poll_id,
                    PollVotesModel.user_id == ctx.author.id,
                ).execute()
            PollVotesModel.create(
                poll_id=poll_id, user_id=ctx.author.id, option=option_str
            )
            reply = "You have voted."
        await ctx.send(reply)

        # Options are stored as JSON rather than in their own table, so the
        # per-option counts are mapped back onto option indices here.
        tally_q = (
            PollVotesModel.select(
                PollVotesModel.option,
                fn.COUNT(PollVotesModel.option).alias("count"),
            )
            .where(PollVotesModel.poll_id == poll_id)
            .group_by(PollVotesModel.option)
        )
        votes = len(options) * [0]
        for row in tally_q:
            index = int(row.option)
            if 0 <= index < len(votes):
                votes[index] = row.count
        embed = generate_poll_embed(
            poll_entry.title,
            options,
            votes,
            multiple_choice=poll_entry.multiple_choice,
            expires=poll_entry.expires,
        )
        await ctx.message.edit(embed=embed)

    async def _handle_delete(self, ctx: InteractionContext) -> None:
        """Delete a poll and its votes if the user is its author or can manage messages."""
        poll_id = ctx.custom_id.split(":", 1)[1]
        poll_entry: PollsModel | None = PollsModel.get_or_none(
            guild_id=ctx.guild.id, id=poll_id
        )
        if not poll_entry:
            await ctx.send("That poll doesn't exist.")
            return
        is_author = ctx.author.id == int(poll_entry.author_id)
        if not is_author and not ctx.author.has_permission(Permissions.MANAGE_MESSAGES):
            await ctx.send("You don't have permission to delete that poll.")
            return
        # Remove the votes first so no vote row outlives its poll.
        PollVotesModel.delete().where(PollVotesModel.poll_id == poll_id).execute()
        poll_entry.delete_instance()
        await ctx.message.delete()
        await ctx.send("Poll deleted.")

    @Task.create(IntervalTrigger(minutes=1))
    async def poll_expiry_check(self):
        """Remove the buttons from polls whose expiry time has passed.

        Each run handles only the polls that expired since the previous run,
        so closed polls are not fetched and edited again every minute. The
        first run handles every poll that has already expired.
        """
        logging.info("Checking for expired polls.")
        now = datetime.now()
        expired_q = PollsModel.select().where(PollsModel.expires < now)
        if self._last_expiry_check is not None:
            expired_q = expired_q.where(PollsModel.expires >= self._last_expiry_check)
        for poll_entry in expired_q:
            await self._close_poll(poll_entry)
        self._last_expiry_check = now

    async def _close_poll(self, poll_entry: PollsModel) -> None:
        """Strip the buttons from one expired poll's message, logging failures."""
        # channel_id/message_id are only filled in after the poll message is sent.
        if not poll_entry.channel_id or not poll_entry.message_id:
            return
        try:
            channel: GuildText = await self.client.fetch_channel(
                int(poll_entry.channel_id)
            )
            if not channel:
                return
            message: Message = await channel.fetch_message(int(poll_entry.message_id))
            if not message:
                return
            await message.edit(components=[])
        except Exception:
            # Best effort: one unreachable message must not stop the others.
            logging.exception("Could not close expired poll %s.", poll_entry.id)
def setup(client: Client):
    """Extension entry point: ensure the poll tables exist, then load Polls."""
    for model in (PollsModel, PollVotesModel):
        model.create_table()
    Polls(client)
    logging.info("Polls extension loaded")