diff --git a/Heimdallr.py b/Heimdallr.py index 5116510..4dce5c5 100644 --- a/Heimdallr.py +++ b/Heimdallr.py @@ -133,4 +133,5 @@ if __name__ == "__main__": bot.load_extension("commands.quote") bot.load_extension("commands.infractions") bot.load_extension("commands.self_roles") + bot.load_extension("commands.polls") bot.start(getenv("DISCORD_TOKEN")) diff --git a/commands/gatekeep.py b/commands/gatekeep.py index 9c6bdab..b99613f 100644 --- a/commands/gatekeep.py +++ b/commands/gatekeep.py @@ -267,7 +267,7 @@ class Gatekeep(Extension): await ctx.send("Your captcha was incorrect.", ephemeral=True) return - gkc.delete() + gkc.delete_instance() await ctx.author.add_role(int(gk.gatekeep_approve_role)) await ctx.send( str(gk.gatekeep_approve_message).format(member=ctx.author, guild=ctx.guild), diff --git a/commands/polls.py b/commands/polls.py new file mode 100644 index 0000000..39634d8 --- /dev/null +++ b/commands/polls.py @@ -0,0 +1,396 @@ +from datetime import datetime, timedelta +import logging +import json +from typing import Dict, List, Tuple +from naff import ( + Client, + Extension, + slash_command, + slash_option, + InteractionContext, + OptionTypes, + Permissions, + Modal, + ShortText, + ParagraphText, + ModalContext, + Button, + ButtonStyles, + Embed, + spread_to_rows, + Message, + listen, + Task, + IntervalTrigger, + GuildText, +) +from naff.api import events + +from database import Polls as PollsModel, PollVotes as PollVotesModel +from peewee import fn + +PollOptions = List[Tuple[str | None, str]] + + +def datetime_to_discord_time(dt: datetime) -> str: + t = dt.strftime("%s") + return f"" + + +def generate_bar(num: int, total: int, length: int = 10) -> str: + full_char = "\u2593" + empty_char = "\u2591" + + if total <= 0: + return length * empty_char + + result = round(num / total * length) + percent = num / total * 100 + return ( + result * full_char + + (length - result) * empty_char + + f" ({percent: 3.1f}%)".replace(".0", "") + ) + + +def generate_poll_embed( + title: str, + options: PollOptions, + votes: List[int], + *, + multiple_choice: bool = False, + expires: datetime = None, +) -> Embed: + data = [] + if multiple_choice: + data.append("\u2022 Multiple choice") + if expires: + data.append(f"\u2022 Expiry: {datetime_to_discord_time(expires)}") + + embed = Embed( + title=title, + description=("\n".join(data) if data else None), + ) + sum_votes = sum(votes) + + for i, (emoji, option) in enumerate(options): + embed.add_field( + name=f"**{emoji if emoji else num_to_emoji(i+1)} {option}**", + value=generate_bar(votes[i], sum_votes) + "\n" + f"{votes[i]} votes", + inline=False, + ) + return embed + + +def num_to_emoji(num: int) -> str: + match num: + case 0: + return "0️⃣" + case 1: + return "1️⃣" + case 2: + return "2️⃣" + case 3: + return "3️⃣" + case 4: + return "4️⃣" + case 5: + return "5️⃣" + case 6: + return "6️⃣" + case 7: + return "7️⃣" + case 8: + return "8️⃣" + case 9: + return "9️⃣" + case 10: + return "🔟" + + case _: + raise ValueError(f"Invalid number: `num` must be 0 <= num <= 10.") + + +class Polls(Extension): + def __init__(self, client: Client): + self.client = client + + @listen(events.Ready) + async def on_ready(self): + await self.poll_expiry_check() + self.poll_expiry_check.start() + + @slash_command( + name="polls", + description="Polls", + sub_cmd_name="create", + sub_cmd_description="Create a poll", + dm_permission=False, + default_member_permissions=Permissions.SEND_MESSAGES, + ) + @slash_option( + name="title", + description="Title of the poll", + required=True, + opt_type=OptionTypes.STRING, + ) + @slash_option( + name="duration", + description="Duration of the poll in minutes", + required=False, + opt_type=OptionTypes.INTEGER, + ) + @slash_option( + name="multiple-choice", + description="If users can vote for multiple options", + required=False, + opt_type=OptionTypes.BOOLEAN, + ) + async def create_poll( + self, + ctx: InteractionContext, + *, + title: str, + duration: int | None = None, + multiple_choice: bool = False, + ): + modal = Modal( + title="Creating poll", + components=[ + ShortText( + custom_id="title", + label="Title", + value=title, + required=True, + max_length=120, + ), + ParagraphText( + custom_id="options", + label="Poll options", + placeholder=( + "Add poll options here.\n\n" "- ✅: Yes\n" "- ❌: No\n" "- Unsure" + ), + required=True, + max_length=1200, + ), + ], + ) + + await ctx.send_modal(modal) + modal_ctx: ModalContext = await self.client.wait_for_modal( + modal=modal, author=ctx.author + ) + + duration: datetime | None = ( + (datetime.now() + timedelta(minutes=duration)) if duration else None + ) + + title = modal_ctx.responses["title"] + options: PollOptions = [] + for i, option in enumerate( + modal_ctx.responses["options"].replace("-", "", 1).split("\n-") + ): + if option == "": + continue + parts = option.split(":", 1) + if len(parts) == 1: + options.append((None, parts[0].strip())) + else: + options.append((parts[0].strip(), parts[1].strip())) + + if len(options) > 10: + await modal_ctx.send("You can only have up to 10 options.", ephemeral=True) + return + if len(options) < 2: + await modal_ctx.send("You must have at least 2 options.", ephemeral=True) + return + + poll_entry: PollsModel = PollsModel.create( + guild_id=ctx.guild.id, + author_id=ctx.author.id, + title=title, + options=json.dumps(options), + no_options=len(options), + multiple_choice=multiple_choice, + expires=duration, + ) + + buttons: List[Button] = [] + for i, option in enumerate(options): + buttons.append( + Button( + emoji=(option[0] or num_to_emoji(i + 1)), + style=ButtonStyles.PRIMARY, + custom_id=f"poll-vote:{poll_entry.id}:{i}", + ) + ) + + buttons.append( + Button( + label="Delete", + style=ButtonStyles.DANGER, + custom_id=f"poll-delete:{poll_entry.id}", + ) + ) + embed = generate_poll_embed( + title, + options, + len(options) * [0], + multiple_choice=multiple_choice, + expires=duration, + ) + poll_message: Message = await modal_ctx.send( + embed=embed, + components=spread_to_rows(*buttons), + ) + + poll_entry.message_id = poll_message.id + poll_entry.channel_id = poll_message.channel.id + poll_entry.save() + + @listen(events.Button) + async def on_button(self, button: events.Button): + ctx = button.context + await ctx.defer(ephemeral=True) + + if ctx.custom_id.startswith("poll-vote:"): + poll_id, option_num = ctx.custom_id.split(":", 1)[1].split(":", 1) + + poll_entry: PollsModel | None = PollsModel.get_or_none( + guild_id=ctx.guild.id, id=poll_id + ) + if not poll_entry: + return + + if poll_entry.expires and datetime.now() > poll_entry.expires: + return + + if not poll_entry.multiple_choice: + votes_q: List[PollVotesModel] = PollVotesModel.select().where( + PollVotesModel.poll_id == poll_id, + PollVotesModel.user_id == ctx.author.id, + ) + if votes_q.count() > 1: + for vote in votes_q: + vote.delete().execute() + + PollVotesModel.create( + poll_id=poll_id, + user_id=ctx.author.id, + option=option_num, + ) + elif votes_q.count() == 1: + if int(votes_q[0].option) == int(option_num): + votes_q[0].delete_instance() + await ctx.send("You have removed your vote.") + else: + votes_q[0].option = option_num + votes_q[0].save() + await ctx.send("You have changed your vote.") + else: + PollVotesModel.create( + poll_id=poll_id, + user_id=ctx.author.id, + option=option_num, + ) + await ctx.send("You have voted.") + else: + votes_q: List[PollVotesModel] = PollVotesModel.select().where( + PollVotesModel.poll_id == poll_id, + PollVotesModel.user_id == ctx.author.id, + ) + + exists = False + for vote in votes_q: + if int(vote.option) == (option_num): + exists = True + vote.delete_instance() + await ctx.send("You have removed your vote.") + break + + if not exists: + PollVotesModel.create( + poll_id=poll_id, + user_id=ctx.author.id, + option=option_num, + ) + await ctx.send("You have voted.") + + votes_q: List[PollVotesModel] = ( + PollVotesModel.select( + PollVotesModel.poll_id, + PollVotesModel.option, + fn.COUNT(PollVotesModel.option).alias("count"), + ) + .where(PollVotesModel.poll_id == poll_id) + .group_by(PollVotesModel.option) + .order_by(PollVotesModel.option) + ) + + # This is such absolutely an awful way to do this. I'm sorry. + # It's the cost of not adding the options in a separate table, I guess. + # Anyway this just gets the votes for each option, and adds them to + # the list `votes`. They're in the same order as the options. + options: PollOptions = json.loads(poll_entry.options) + votes = len(options) * [0] + + for i, vote in enumerate(votes_q): + votes[int(vote.option)] = vote.count + + embed = generate_poll_embed( + poll_entry.title, + options, + votes, + multiple_choice=poll_entry.multiple_choice, + expires=poll_entry.expires, + ) + + await ctx.message.edit(embed=embed) + + elif ctx.custom_id.startswith("poll-delete:"): + poll_id = ctx.custom_id.split(":", 1)[1] + + poll_entry: PollsModel | None = PollsModel.get_or_none( + guild_id=ctx.guild.id, id=poll_id + ) + if not poll_entry: + await ctx.send("That poll doesn't exist.") + return + + if ( + not ctx.author.id == int(poll_entry.author_id) + or not ctx.author.has_permission(Permissions.MANAGE_MESSAGES) + ): + await ctx.send("You don't have permission to delete that poll.") + return + + poll_entry.delete_instance() + PollVotesModel.delete().where(PollVotesModel.poll_id == poll_id).execute() + await ctx.message.delete() + await ctx.send("Poll deleted.") + + @Task.create(IntervalTrigger(minutes=1)) + async def poll_expiry_check(self): + logging.info("Checking for expired polls.") + now = datetime.now() + + polls_q: List[PollsModel] = PollsModel.select().where(PollsModel.expires < now) + for poll_entry in polls_q: + channel: GuildText = await self.client.fetch_channel(int(poll_entry.channel_id)) + if not channel: + continue + + message: Message = await channel.fetch_message(int(poll_entry.message_id)) + if not message: + continue + + await message.edit(components=[]) + + + + +def setup(client: Client): + PollsModel.create_table() + PollVotesModel.create_table() + Polls(client) + logging.info("Polls extension loaded") diff --git a/database.py b/database.py index 9f8b74f..80751e8 100644 --- a/database.py +++ b/database.py @@ -15,6 +15,7 @@ from peewee import ( BooleanField, CompositeKey, ForeignKeyField, + IntegerField, ) @@ -225,4 +226,30 @@ class GatekeepCaptchas(Model): class Meta: primary_key = CompositeKey("guild_id", "user_id") table_name = "GatekeepCaptchas" + database = db + +class Polls(Model): + id = AutoField() + guild_id = BigIntegerField() + channel_id = BigIntegerField(null=True) + message_id = BigIntegerField(null=True) + author_id = BigIntegerField() + title = TextField() + options = TextField() + no_options = IntegerField() + multiple_choice = BooleanField() + expires = DateTimeField(null=True) + + class Meta: + table_name = "Polls" + database = db + +class PollVotes(Model): + id = AutoField() + poll_id = ForeignKeyField(Polls, to_field="id") + user_id = BigIntegerField() + option = IntegerField() + + class Meta: + table_name = "PollVotes" database = db \ No newline at end of file