# Heimdallr/commands/polls.py
# (Source-view metadata: 397 lines, 12 KiB, Python; timestamp 2022-08-05 13:19:08 +02:00.)
from datetime import datetime, timedelta
import logging
import json
from typing import Dict, List, Tuple
from naff import (
Client,
Extension,
slash_command,
slash_option,
InteractionContext,
OptionTypes,
Permissions,
Modal,
ShortText,
ParagraphText,
ModalContext,
Button,
ButtonStyles,
Embed,
spread_to_rows,
Message,
listen,
Task,
IntervalTrigger,
GuildText,
)
from naff.api import events
from database import Polls as PollsModel, PollVotes as PollVotesModel
from peewee import fn
# One entry per poll option: (emoji or None, option text). When the emoji is
# None/empty, the option is shown with its 1-based number emoji instead.
PollOptions = List[Tuple[str | None, str]]
def datetime_to_discord_time(dt: datetime, *, style: str = "R") -> str:
    """Format ``dt`` as Discord timestamp markup (``<t:UNIX:STYLE>``).

    Args:
        dt: The moment to display. Naive datetimes are interpreted as local
            time (as ``datetime.timestamp`` does); aware ones use their tzinfo.
        style: Discord timestamp style letter; defaults to ``"R"`` (relative).

    Returns:
        A string such as ``"<t:1700000000:R>"``.
    """
    # strftime("%s") is a platform-specific (glibc) extension that ignores
    # tzinfo; timestamp() is portable and timezone-correct.
    return f"<t:{int(dt.timestamp())}:{style}>"
def generate_bar(num: int, total: int, length: int = 10) -> str:
    """Draw a shaded text bar showing ``num`` as a share of ``total``.

    Returns ``length`` block characters (dark shade for the filled share,
    light shade for the rest), followed by the percentage in parentheses
    formatted with one decimal place, with any ``.0`` dropped. A
    non-positive ``total`` yields an all-light bar with no percentage.
    """
    dark, light = "\u2593", "\u2591"
    if total <= 0:
        return light * length

    share = num / total
    filled = round(share * length)
    percent_label = f" ({share * 100: 3.1f}%)".replace(".0", "")
    return "".join((dark * filled, light * (length - filled), percent_label))
def generate_poll_embed(
    title: str,
    options: PollOptions,
    votes: List[int],
    *,
    multiple_choice: bool = False,
    expires: datetime | None = None,
) -> Embed:
    """Build the embed that shows a poll and its current tally.

    Args:
        title: Poll title, used as the embed title.
        options: ``(emoji or None, text)`` pairs, one per option.
        votes: Vote count per option, in the same order as ``options``.
        multiple_choice: Whether to note "Multiple choice" in the description.
        expires: Expiry time shown as a relative Discord timestamp, if any.

    Returns:
        An embed with one non-inline field per option containing a vote bar
        and the vote count.
    """
    info_lines = []
    if multiple_choice:
        info_lines.append("\u2022 Multiple choice")
    if expires:
        info_lines.append(f"\u2022 Expiry: {datetime_to_discord_time(expires)}")

    embed = Embed(
        title=title,
        description=("\n".join(info_lines) if info_lines else None),
    )

    total_votes = sum(votes)
    for i, (emoji, option) in enumerate(options):
        count = votes[i]
        # Options without their own emoji are labelled with their 1-based number.
        label = emoji if emoji else num_to_emoji(i + 1)
        noun = "vote" if count == 1 else "votes"
        embed.add_field(
            name=f"**{label} {option}**",
            value=f"{generate_bar(count, total_votes)}\n{count} {noun}",
            inline=False,
        )
    return embed
def num_to_emoji(num: int) -> str:
    """Return the keycap emoji for an integer from 0 to 10.

    Digits 0-9 map to keycap sequences (digit + U+FE0F + U+20E3, e.g. "1️⃣")
    and 10 maps to the single "keycap ten" emoji. The result is used as a
    button emoji and as the label of options that have no emoji of their own.

    Raises:
        ValueError: If ``num`` is not in the range 0..10.
    """
    if num == 10:
        return "\U0001F51F"  # 🔟 KEYCAP TEN
    if num in range(10):
        # A bare ASCII digit is not an emoji; the keycap sequence is.
        return f"{int(num)}\ufe0f\u20e3"
    raise ValueError(f"Invalid number {num!r}: `num` must be 0 <= num <= 10.")
class Polls(Extension):
    """Poll extension: ``/polls create``, vote/delete buttons, and expiry.

    Poll metadata is stored in ``PollsModel`` (options JSON-encoded as a
    ``PollOptions`` list). Each vote is a ``PollVotesModel`` row holding the
    voter's id and the 0-based option index.
    """

    def __init__(self, client: Client):
        self.client = client

    @listen(events.Ready)
    async def on_ready(self):
        """Close polls that expired while offline, then schedule the periodic check."""
        # NOTE(review): if Ready can fire more than once per process (e.g. on
        # reconnect), start() is called again here — confirm naff's semantics.
        await self.poll_expiry_check()
        self.poll_expiry_check.start()

    @slash_command(
        name="polls",
        description="Polls",
        sub_cmd_name="create",
        sub_cmd_description="Create a poll",
        dm_permission=False,
        default_member_permissions=Permissions.SEND_MESSAGES,
    )
    @slash_option(
        name="title",
        description="Title of the poll",
        required=True,
        opt_type=OptionTypes.STRING,
    )
    @slash_option(
        name="duration",
        description="Duration of the poll in minutes",
        required=False,
        opt_type=OptionTypes.INTEGER,
    )
    @slash_option(
        name="multiple-choice",
        description="If users can vote for multiple options",
        required=False,
        opt_type=OptionTypes.BOOLEAN,
    )
    async def create_poll(
        self,
        ctx: InteractionContext,
        *,
        title: str,
        duration: int | None = None,
        multiple_choice: bool = False,
    ):
        """Collect poll options through a modal and post the poll message.

        ``duration`` is in minutes; a missing or zero duration means the poll
        never expires. The poll row is created before the message is sent
        because the button custom ids embed the poll id.
        """
        if duration is not None and duration < 0:
            await ctx.send("Duration can't be negative.", ephemeral=True)
            return

        modal = Modal(
            title="Creating poll",
            components=[
                ShortText(
                    custom_id="title",
                    label="Title",
                    # Keep the pre-filled value within the field's max_length.
                    value=title[:120],
                    required=True,
                    max_length=120,
                ),
                ParagraphText(
                    custom_id="options",
                    label="Poll options",
                    placeholder=(
                        "Add poll options here.\n\n" "- ✅: Yes\n" "- ❌: No\n" "- Unsure"
                    ),
                    required=True,
                    max_length=1200,
                ),
            ],
        )
        await ctx.send_modal(modal)
        modal_ctx: ModalContext = await self.client.wait_for_modal(
            modal=modal, author=ctx.author
        )

        title = modal_ctx.responses["title"]
        options = self._parse_options(modal_ctx.responses["options"])
        if len(options) > 10:
            await modal_ctx.send("You can only have up to 10 options.", ephemeral=True)
            return
        if len(options) < 2:
            await modal_ctx.send("You must have at least 2 options.", ephemeral=True)
            return

        try:
            expires: datetime | None = (
                (datetime.now() + timedelta(minutes=duration)) if duration else None
            )
        except OverflowError:
            # Very large durations overflow timedelta/datetime arithmetic.
            await modal_ctx.send("That duration is too long.", ephemeral=True)
            return

        poll_entry: PollsModel = PollsModel.create(
            guild_id=ctx.guild.id,
            author_id=ctx.author.id,
            title=title,
            options=json.dumps(options),
            no_options=len(options),
            multiple_choice=multiple_choice,
            expires=expires,
        )

        buttons: List[Button] = [
            Button(
                emoji=(emoji or num_to_emoji(i + 1)),
                style=ButtonStyles.PRIMARY,
                custom_id=f"poll-vote:{poll_entry.id}:{i}",
            )
            for i, (emoji, _text) in enumerate(options)
        ]
        buttons.append(
            Button(
                label="Delete",
                style=ButtonStyles.DANGER,
                custom_id=f"poll-delete:{poll_entry.id}",
            )
        )

        embed = generate_poll_embed(
            title,
            options,
            len(options) * [0],
            multiple_choice=multiple_choice,
            expires=expires,
        )
        poll_message: Message = await modal_ctx.send(
            embed=embed,
            components=spread_to_rows(*buttons),
        )
        poll_entry.message_id = poll_message.id
        poll_entry.channel_id = poll_message.channel.id
        poll_entry.save()

    @staticmethod
    def _parse_options(raw: str) -> PollOptions:
        """Parse the modal's option text into ``(emoji or None, text)`` pairs.

        Options are separated by a newline followed by ``-``; the ``-`` before
        the first option is optional. Text before the first ``:`` of an option
        is taken as its emoji. Entries with neither emoji nor text are skipped.
        """
        text = raw.replace("\r\n", "\n").strip()
        # Strip only a leading bullet, not the first hyphen found anywhere.
        if text.startswith("-"):
            text = text[1:]

        options: PollOptions = []
        for chunk in text.split("\n-"):
            head, sep, tail = chunk.partition(":")
            if sep:
                emoji, label = head.strip() or None, tail.strip()
            else:
                emoji, label = None, head.strip()
            # NOTE(review): any text before a colon (e.g. "Time: 5pm") is
            # treated as an emoji and is likely invalid as a button emoji.
            if not emoji and not label:
                continue
            options.append((emoji, label))
        return options

    @listen(events.Button)
    async def on_button(self, button: events.Button):
        """Handle presses of poll buttons; buttons of other features are left alone."""
        ctx = button.context
        # Only defer interactions we own: deferring every button would break
        # handlers elsewhere that need to give the initial response themselves.
        if ctx.custom_id.startswith("poll-vote:"):
            await ctx.defer(ephemeral=True)
            await self._handle_vote(ctx)
        elif ctx.custom_id.startswith("poll-delete:"):
            await ctx.defer(ephemeral=True)
            await self._handle_delete(ctx)

    async def _handle_vote(self, ctx: InteractionContext) -> None:
        """Record/toggle the presser's vote, reply ephemerally, and refresh the embed."""
        # custom_id format: "poll-vote:<poll id>:<0-based option index>"
        _, poll_id, option_str = ctx.custom_id.split(":", 2)
        option_num = int(option_str)

        poll_entry: PollsModel | None = PollsModel.get_or_none(
            guild_id=ctx.guild.id, id=poll_id
        )
        if not poll_entry:
            await ctx.send("That poll doesn't exist.")
            return
        if poll_entry.expires and datetime.now() > poll_entry.expires:
            await ctx.send("This poll has expired.")
            return

        options: PollOptions = json.loads(poll_entry.options)
        if not 0 <= option_num < len(options):
            await ctx.send("That option doesn't exist.")
            return

        user_votes: List[PollVotesModel] = list(
            PollVotesModel.select().where(
                PollVotesModel.poll_id == poll_id,
                PollVotesModel.user_id == ctx.author.id,
            )
        )

        if not poll_entry.multiple_choice:
            if len(user_votes) > 1:
                # Inconsistent state: collapse to one vote for this option.
                # delete_instance() removes just this row; Model.delete() would
                # build an unfiltered DELETE for the whole table.
                for vote in user_votes:
                    vote.delete_instance()
                PollVotesModel.create(
                    poll_id=poll_id,
                    user_id=ctx.author.id,
                    option=option_str,
                )
                await ctx.send("You have voted.")
            elif len(user_votes) == 1:
                vote = user_votes[0]
                if int(vote.option) == option_num:
                    vote.delete_instance()
                    await ctx.send("You have removed your vote.")
                else:
                    vote.option = option_str
                    vote.save()
                    await ctx.send("You have changed your vote.")
            else:
                PollVotesModel.create(
                    poll_id=poll_id,
                    user_id=ctx.author.id,
                    option=option_str,
                )
                await ctx.send("You have voted.")
        else:
            # Pressing an option the user already voted for toggles it off.
            existing = next(
                (vote for vote in user_votes if int(vote.option) == option_num),
                None,
            )
            if existing is not None:
                existing.delete_instance()
                await ctx.send("You have removed your vote.")
            else:
                PollVotesModel.create(
                    poll_id=poll_id,
                    user_id=ctx.author.id,
                    option=option_str,
                )
                await ctx.send("You have voted.")

        # Options are stored as JSON rather than in their own table, so tally
        # per-option counts here and place them at their option index.
        counts_q = (
            PollVotesModel.select(
                PollVotesModel.poll_id,
                PollVotesModel.option,
                fn.COUNT(PollVotesModel.option).alias("count"),
            )
            .where(PollVotesModel.poll_id == poll_id)
            .group_by(PollVotesModel.option)
            .order_by(PollVotesModel.option)
        )
        votes = len(options) * [0]
        for row in counts_q:
            index = int(row.option)
            if 0 <= index < len(votes):
                votes[index] = row.count

        embed = generate_poll_embed(
            poll_entry.title,
            options,
            votes,
            multiple_choice=poll_entry.multiple_choice,
            expires=poll_entry.expires,
        )
        await ctx.message.edit(embed=embed)

    async def _handle_delete(self, ctx: InteractionContext) -> None:
        """Delete the poll, its votes, and its message if the presser may do so."""
        # custom_id format: "poll-delete:<poll id>"
        poll_id = ctx.custom_id.split(":", 1)[1]
        poll_entry: PollsModel | None = PollsModel.get_or_none(
            guild_id=ctx.guild.id, id=poll_id
        )
        if not poll_entry:
            await ctx.send("That poll doesn't exist.")
            return

        # Either the poll's author or a member who can manage messages may delete.
        is_author = ctx.author.id == int(poll_entry.author_id)
        if not (is_author or ctx.author.has_permission(Permissions.MANAGE_MESSAGES)):
            await ctx.send("You don't have permission to delete that poll.")
            return

        poll_entry.delete_instance()
        PollVotesModel.delete().where(PollVotesModel.poll_id == poll_id).execute()
        await ctx.message.delete()
        await ctx.send("Poll deleted.")

    @Task.create(IntervalTrigger(minutes=1))
    async def poll_expiry_check(self):
        """Remove the buttons from every poll whose expiry time has passed."""
        logging.info("Checking for expired polls.")
        now = datetime.now()
        polls_q: List[PollsModel] = PollsModel.select().where(PollsModel.expires < now)
        # NOTE(review): expired polls are never marked closed, so every run
        # revisits all of them; consider persisting a "closed" flag.
        for poll_entry in polls_q:
            # A poll whose message was never posted has no channel/message id.
            if not poll_entry.channel_id or not poll_entry.message_id:
                continue
            try:
                channel: GuildText = await self.client.fetch_channel(
                    int(poll_entry.channel_id)
                )
                if not channel:
                    continue
                message: Message = await channel.fetch_message(int(poll_entry.message_id))
                if not message:
                    continue
                # Skip the API call for polls that are already closed.
                if message.components:
                    await message.edit(components=[])
            except Exception:
                # One failing poll (e.g. missing permissions) must not stop
                # the rest from being closed.
                logging.exception("Failed to close expired poll %s.", poll_entry.id)
def setup(client: Client):
    """Extension setup hook (presumably invoked by the loader).

    Creates the poll and vote tables, in that order, then instantiates the
    ``Polls`` extension for ``client``.
    """
    for model in (PollsModel, PollVotesModel):
        model.create_table()
    Polls(client)
    logging.info("Polls extension loaded")