# pylint: disable=not-an-iterable
# pylint: disable=unsubscriptable-object
# pylint: disable=logging-fstring-interpolation
"""Poll commands for a naff Discord bot.

Users create a poll through a slash command and a modal, vote with buttons on
the poll message, and the poll's buttons are removed once it expires or is
locked.
"""
from datetime import datetime, timedelta
import logging
import json
from typing import List, Tuple

from naff import (
    Client,
    Extension,
    slash_command,
    slash_option,
    InteractionContext,
    OptionTypes,
    Permissions,
    Modal,
    ShortText,
    ParagraphText,
    ModalContext,
    Button,
    ButtonStyles,
    Embed,
    spread_to_rows,
    Message,
    listen,
    Task,
    IntervalTrigger,
    GuildText,
    PartialEmoji,
)
from naff.api import events
from naff.client.errors import HTTPException
from peewee import fn

from database import Polls as PollsModel, PollVotes as PollVotesModel

# (emoji or None, option text) pairs, in display order.
PollOptions = List[Tuple[str | None, str]]

# Keycap emojis for the numbers 0..10, indexed by the number itself.
_NUMBER_EMOJIS = (
    "0️⃣",
    "1️⃣",
    "2️⃣",
    "3️⃣",
    "4️⃣",
    "5️⃣",
    "6️⃣",
    "7️⃣",
    "8️⃣",
    "9️⃣",
    "🔟",
)


def datetime_to_discord_relative_time(dt: datetime) -> str:
    """Create a Discord relative-time markup string (e.g. "in 5 minutes").

    Parameters
    ----------
    dt : datetime
        The moment to display. Naive datetimes are interpreted as local time
        by ``datetime.timestamp``.

    Returns
    -------
    str
        A ``<t:UNIX:R>`` timestamp tag that Discord renders relative to now.
    """
    # `timestamp()` is portable, unlike `strftime("%s")` which only exists on glibc.
    return f"<t:{int(dt.timestamp())}:R>"


def generate_bar(num: int, total: int, length: int = 10) -> str:
    """Create a bar graph from a number and a total.

    Parameters
    ----------
    num : int
        The current amount.
    total : int
        The total amount.
    length : int
        The amount of characters to use for the bar.

    Returns
    -------
    str
        ``length`` shaded characters followed by the percentage in
        parentheses. With ``total <= 0`` only an empty bar is returned.
    """
    full_char = "\u2593"
    empty_char = "\u2591"
    if total <= 0:
        return length * empty_char
    result = round(num / total * length)
    percent = num / total * 100
    # The replace drops a trailing ".0" so whole percentages read "50%" rather than "50.0%".
    return (
        result * full_char
        + (length - result) * empty_char
        + f" ({percent: 3.1f}%)".replace(".0", "")
    )


def generate_poll_embed(
    title: str,
    options: PollOptions,
    votes: List[int],
    *,
    multiple_choice: bool = False,
    expires: datetime | None = None,
) -> Embed:
    """Create a poll embed from a title, options and votes.

    Parameters
    ----------
    title : str
        The title of the poll.
    options : PollOptions
        The options of the poll.
    votes : List[int]
        The votes of the poll, in the same order as the options.
    multiple_choice : bool, optional
        Whether the poll is multiple choice. Defaults to False.
    expires : datetime, optional
        The time at which the poll expires. Defaults to None.

    Returns
    -------
    Embed
        One field per option showing its emoji, text, a bar graph and the
        vote count.
    """
    data = []
    # \u2022 is a bullet point
    if multiple_choice:
        data.append("\u2022 Multiple choice")
    if expires:
        data.append(f"\u2022 Expiry: {datetime_to_discord_relative_time(expires)}")
    embed = Embed(
        title=title,
        description=("\n".join(data) if data else None),
    )
    total_votes = sum(votes)
    # Add a field for each option, with the vote count and a bar graph showing the percentage.
    for i, ((emoji, option), count) in enumerate(zip(options, votes)):
        embed.add_field(
            name=f"**{emoji or num_to_emoji(i + 1)} {option}**",
            value=f"{generate_bar(count, total_votes)}\n{count} votes",
            inline=False,
        )
    return embed


def num_to_emoji(num: int) -> str:
    """Return the keycap emoji for ``num``.

    Raises
    ------
    ValueError
        If ``num`` is not in the range 0..10.
    """
    if not 0 <= num <= 10:
        raise ValueError("Invalid number: `num` must be 0 <= num <= 10.")
    return _NUMBER_EMOJIS[num]


def parse_poll_options(text: str) -> PollOptions:
    """Parse the free-text "Poll options" field of the creation modal.

    Options are separated by lines that start with a dash. An option may be
    prefixed by an emoji and a colon, e.g. ``- ✅: Yes``; options without a
    colon get ``None`` as their emoji.

    Parameters
    ----------
    text : str
        The raw contents of the modal's options field.

    Returns
    -------
    PollOptions
        ``(emoji, option)`` pairs in the order written. Blank options are
        skipped.
    """
    options: PollOptions = []
    # Remove only a dash that starts the text (not the first dash anywhere),
    # so the first option is not mangled when it has no leading dash.
    body = text.strip().removeprefix("-")
    for raw_option in body.split("\n-"):
        option = raw_option.strip()
        if not option:
            continue
        # Everything before the first colon is treated as the option's emoji.
        emoji, separator, label = option.partition(":")
        if separator:
            options.append((emoji.strip(), label.strip()))
        else:
            options.append((None, option))
    return options


class Polls(Extension):
    """Slash command, button handlers and expiry task for polls."""

    def __init__(self, client: Client):
        self.client = client

    @listen(events.Ready)
    async def on_ready(self):
        """Run one expiry check immediately, then start the periodic task."""
        await self.poll_expiry_check()
        self.poll_expiry_check.start()  # pylint: disable=no-member

    @slash_command(
        name="polls",
        description="Polls",
        sub_cmd_name="create",
        sub_cmd_description="Create a poll",
        dm_permission=False,
        default_member_permissions=Permissions.SEND_MESSAGES,
    )
    @slash_option(
        name="title",
        description="Title of the poll",
        required=True,
        opt_type=OptionTypes.STRING,
    )
    @slash_option(
        name="duration",
        description="Duration of the poll in minutes",
        required=False,
        opt_type=OptionTypes.INTEGER,
    )
    # NOTE(review): this option is named "multiple-choice" while the keyword
    # parameter is `multiple_choice`; confirm naff maps one to the other.
    @slash_option(
        name="multiple-choice",
        description="If users can vote for multiple options",
        required=False,
        opt_type=OptionTypes.BOOLEAN,
    )
    async def create_poll(  # pylint: disable=too-many-locals
        self,
        ctx: InteractionContext,
        *,
        title: str,
        duration: int | None = None,
        multiple_choice: bool = False,
    ):
        """Create a poll from a modal and post it with vote buttons.

        The modal lets the user edit the title and enter 2-10 options. The
        poll is stored in the database, then posted with one vote button per
        option plus a lock button.
        """
        modal = Modal(
            title="Creating poll",
            components=[
                ShortText(
                    custom_id="title",
                    label="Title",
                    value=title,
                    required=True,
                    max_length=120,
                ),
                ParagraphText(
                    custom_id="options",
                    label="Poll options",
                    placeholder=(
                        "Add poll options here.\n\n"
                        "- ✅: Yes\n"
                        "- ❌: No\n"
                        "- Unsure"
                    ),
                    required=True,
                    max_length=1200,
                ),
            ],
        )

        # Display the modal and wait for the user to submit.
        await ctx.send_modal(modal)
        modal_ctx: ModalContext = await self.client.wait_for_modal(
            modal=modal, author=ctx.author
        )

        # If the user set a duration for the poll, compute when it expires.
        expires: datetime | None = (
            (datetime.now() + timedelta(minutes=duration)) if duration else None
        )

        title = modal_ctx.responses["title"]
        options = parse_poll_options(modal_ctx.responses["options"])

        if len(options) > 10:
            await modal_ctx.send("You can only have up to 10 options.", ephemeral=True)
            return
        if len(options) < 2:
            await modal_ctx.send("You must have at least 2 options.", ephemeral=True)
            return

        # The row is created first because its id is embedded in the button custom ids.
        poll_entry: PollsModel = PollsModel.create(
            guild_id=ctx.guild.id,
            author_id=ctx.author.id,
            title=title,
            options=json.dumps(options),
            no_options=len(options),
            multiple_choice=multiple_choice,
            expires=expires,
        )

        # Create vote buttons for each option.
        buttons: List[Button] = []
        for i, (emoji_text, _) in enumerate(options):
            fallback = num_to_emoji(i + 1)
            try:
                # `or fallback` covers a falsy result for text that is not
                # recognised as an emoji (may depend on the naff version).
                emoji = PartialEmoji.from_str(emoji_text or fallback) or fallback
            except ValueError:
                emoji = fallback
            buttons.append(
                Button(
                    emoji=emoji,
                    style=ButtonStyles.PRIMARY,
                    custom_id=f"poll-vote:{poll_entry.id}:{i}",
                )
            )

        # Create a button to allow locking the poll.
        buttons.append(
            Button(
                emoji="🔒",
                label="Lock",
                style=ButtonStyles.DANGER,
                custom_id=f"poll-lock:{poll_entry.id}",
            )
        )

        embed = generate_poll_embed(
            title,
            options,
            len(options) * [0],
            multiple_choice=multiple_choice,
            expires=expires,
        )

        try:
            poll_message: Message = await modal_ctx.send(
                embed=embed,
                components=spread_to_rows(*buttons),
            )
        # Naive error handling.
        except HTTPException as e:
            logging.error("Error sending poll: %s", e)
            # The poll never reached the channel. Drop its row so it does not
            # linger without a channel/message id.
            poll_entry.delete_instance()
            await modal_ctx.send(
                # TODO: This should probably also include a sample for poll options.
                "Error during poll creation. NB: You can not use server-specific emojis",
                ephemeral=True,
            )
            return

        poll_entry.message_id = poll_message.id
        poll_entry.channel_id = poll_message.channel.id
        poll_entry.save()

    @listen(events.Button)
    async def on_button(self, button: events.Button):
        """Route poll vote and lock button presses to their handlers."""
        ctx = button.context
        if ctx.custom_id.startswith("poll-vote:"):
            await self._handle_vote(ctx)
        elif ctx.custom_id.startswith("poll-lock:"):
            await self._handle_lock(ctx)

    async def _handle_vote(self, ctx: InteractionContext) -> None:
        """Record or toggle the presser's vote, then refresh the poll embed."""
        # custom_id format: "poll-vote:<poll id>:<option index>"
        poll_id, option_str = ctx.custom_id.split(":", 1)[1].split(":", 1)
        option_num = int(option_str)

        poll_entry: PollsModel | None = PollsModel.get_or_none(
            guild_id=ctx.guild.id, id=poll_id
        )
        # Always answer the interaction so the user does not see "interaction failed".
        if not poll_entry:
            await ctx.send("That poll doesn't exist.", ephemeral=True)
            return
        if poll_entry.expires and datetime.now() > poll_entry.expires:
            await ctx.send("This poll has expired.", ephemeral=True)
            return

        await ctx.defer(ephemeral=True)

        # Fetched once so the branches below do not re-run the query.
        user_votes: List[PollVotesModel] = list(
            PollVotesModel.select().where(
                PollVotesModel.poll_id == poll_id,
                PollVotesModel.user_id == ctx.author.id,
            )
        )

        if poll_entry.multiple_choice:
            # Toggle only this option; votes for other options are untouched.
            # The stored option is compared as an int, matching `option_num`.
            same_vote = next(
                (vote for vote in user_votes if int(vote.option) == option_num), None
            )
            if same_vote is not None:
                same_vote.delete_instance()
                reply = "You have removed your vote."
            else:
                PollVotesModel.create(
                    poll_id=poll_id, user_id=ctx.author.id, option=option_str
                )
                reply = "You have voted."
        elif len(user_votes) > 1:
            # A single-choice voter should never have several votes; reset to this one.
            for vote in user_votes:
                vote.delete_instance()
            PollVotesModel.create(
                poll_id=poll_id, user_id=ctx.author.id, option=option_str
            )
            reply = "You have voted."
        elif user_votes:
            current = user_votes[0]
            if int(current.option) == option_num:
                # Pressing the already-chosen option withdraws the vote.
                current.delete_instance()
                reply = "You have removed your vote."
            else:
                current.option = option_str
                current.save()
                reply = "You have changed your vote."
        else:
            PollVotesModel.create(
                poll_id=poll_id, user_id=ctx.author.id, option=option_str
            )
            reply = "You have voted."
        await ctx.send(reply)

        counts_q: List[PollVotesModel] = (
            PollVotesModel.select(
                PollVotesModel.poll_id,
                PollVotesModel.option,
                fn.COUNT(PollVotesModel.option).alias("count"),
            )
            .where(PollVotesModel.poll_id == poll_id)
            .group_by(PollVotesModel.option)
            .order_by(PollVotesModel.option)
        )

        # Options are stored as JSON on the poll row, so votes are mapped back
        # to options by index. `votes` is in the same order as `options`.
        options: PollOptions = json.loads(poll_entry.options)
        votes = len(options) * [0]
        for row in counts_q:
            index = int(row.option)
            if 0 <= index < len(votes):
                votes[index] = row.count

        embed = generate_poll_embed(
            poll_entry.title,
            options,
            votes,
            multiple_choice=poll_entry.multiple_choice,
            expires=poll_entry.expires,
        )
        # Edit the message with the updated information.
        await ctx.message.edit(embed=embed)

    async def _handle_lock(self, ctx: InteractionContext) -> None:
        """Lock a poll by expiring it, if the presser is allowed to."""
        # custom_id format: "poll-lock:<poll id>"
        poll_id = ctx.custom_id.split(":", 1)[1]
        poll_entry: PollsModel | None = PollsModel.get_or_none(
            guild_id=ctx.guild.id, id=poll_id
        )
        if not poll_entry:
            await ctx.send("That poll doesn't exist.", ephemeral=True)
            return

        # The poll creator may always lock it; anyone else needs Manage Messages.
        is_author = ctx.author.id == int(poll_entry.author_id)
        if not is_author and not ctx.author.has_permission(
            Permissions.MANAGE_MESSAGES
        ):
            await ctx.send("You don't have permission to lock that poll.", ephemeral=True)
            return

        # Set the poll to be expired to lock it.
        poll_entry.expires = datetime.now() - timedelta(minutes=1)
        poll_entry.save()

        # Force the "poll expiry check" task to run.
        await self.poll_expiry_check()

        await ctx.send("Poll locked.")

    # A task that runs each minute to check for expired polls.
    @Task.create(IntervalTrigger(minutes=1))
    async def poll_expiry_check(self):
        """Remove the buttons from expired polls and delete their database rows."""
        logging.info("Checking for expired polls.")

        now = datetime.now()
        # Fetched up front: rows are deleted inside the loop, and the awaits
        # below let other handlers use the database meanwhile.
        expired_polls: List[PollsModel] = list(
            PollsModel.select().where(PollsModel.expires < now)
        )
        for poll_entry in expired_polls:
            await self._remove_poll_buttons(poll_entry)

            # Delete associated database entries, as they will no longer be
            # updated. This also happens when the message could not be edited,
            # so an unreachable poll is not retried every minute forever.
            PollVotesModel.delete().where(PollVotesModel.poll_id == poll_entry.id).execute()
            poll_entry.delete_instance()

    async def _remove_poll_buttons(self, poll_entry: PollsModel) -> None:
        """Best-effort removal of the vote/lock buttons from a poll's message."""
        if not poll_entry.channel_id or not poll_entry.message_id:
            return
        try:
            channel: GuildText = await self.client.fetch_channel(
                int(poll_entry.channel_id)
            )
            if not channel:
                return
            message: Message = await channel.fetch_message(int(poll_entry.message_id))
            if not message:
                return
            await message.edit(components=[])
        except HTTPException as e:
            logging.warning("Could not close poll %s: %s", poll_entry.id, e)


def setup(client: Client):
    """Create the poll tables if needed and load the extension."""
    PollsModel.create_table()
    PollVotesModel.create_table()
    Polls(client)
    logging.info("Polls extension loaded")