from datetime import datetime, timedelta
import logging
import json
from typing import Dict, List, Tuple

from naff import (
    Client,
    Extension,
    slash_command,
    slash_option,
    InteractionContext,
    OptionTypes,
    Permissions,
    Modal,
    ShortText,
    ParagraphText,
    ModalContext,
    Button,
    ButtonStyles,
    Embed,
    spread_to_rows,
    Message,
    listen,
    Task,
    IntervalTrigger,
    GuildText,
    PartialEmoji,
)
from naff.api import events
from naff.client.errors import HTTPException

from database import Polls as PollsModel, PollVotes as PollVotesModel
from peewee import fn

# Each poll option is an (emoji, label) pair. The emoji is None when the user
# did not supply one; a number emoji is shown instead.
PollOptions = List[Tuple[str | None, str]]

# Keycap emojis for 0..10, indexed by the number they represent.
_NUMBER_EMOJIS = (
    "0️⃣",
    "1️⃣",
    "2️⃣",
    "3️⃣",
    "4️⃣",
    "5️⃣",
    "6️⃣",
    "7️⃣",
    "8️⃣",
    "9️⃣",
    "🔟",
)


def datetime_to_discord_time(dt: datetime, style: str | None = None) -> str:
    """Render *dt* as Discord timestamp markup (``<t:UNIX>`` / ``<t:UNIX:STYLE>``).

    Naive datetimes are interpreted as local time, which matches the
    ``datetime.now()`` values this extension stores.

    Args:
        dt: The moment to render.
        style: Optional Discord timestamp style letter (e.g. ``"R"`` for
            relative time). ``None`` uses Discord's default style.

    Returns:
        The markup string; Discord renders it in each viewer's timezone.
    """
    # ``timestamp()`` is portable and honours tzinfo, unlike
    # ``strftime("%s")`` which is a glibc-only extension.
    t = int(dt.timestamp())
    return f"<t:{t}:{style}>" if style else f"<t:{t}>"


def generate_bar(num: int, total: int, length: int = 10) -> str:
    """Return a text progress bar for ``num`` out of ``total``.

    Args:
        num: Votes for this option.
        total: Votes across all options.
        length: Number of bar characters.

    Returns:
        ``length`` filled/empty block characters followed by the percentage,
        e.g. ``"▓▓▓▓▓░░░░░ ( 50%)"``. When ``total`` is not positive, an empty
        bar is returned without a percentage.
    """
    full_char = "\u2593"
    empty_char = "\u2591"
    if total <= 0:
        return length * empty_char
    result = round(num / total * length)
    percent = num / total * 100
    # One decimal place, but drop a trailing ".0" ("50%" rather than "50.0%").
    percent_text = f"{percent: 3.1f}".removesuffix(".0")
    return result * full_char + (length - result) * empty_char + f" ({percent_text}%)"


def generate_poll_embed(
    title: str,
    options: PollOptions,
    votes: List[int],
    *,
    multiple_choice: bool = False,
    expires: datetime | None = None,
) -> Embed:
    """Build the embed displaying a poll and its current results.

    Args:
        title: Poll title.
        options: ``(emoji, label)`` pairs in display order.
        votes: Vote count per option, in the same order as ``options``.
        multiple_choice: Whether to show the "Multiple choice" note.
        expires: When the poll closes, if it has an expiry.

    Returns:
        An embed with one field per option showing a bar and a vote count.
    """
    data = []
    if multiple_choice:
        data.append("\u2022 Multiple choice")
    if expires:
        data.append(f"\u2022 Expiry: {datetime_to_discord_time(expires)}")

    embed = Embed(
        title=title,
        description=("\n".join(data) if data else None),
    )

    sum_votes = sum(votes)
    for i, ((emoji, option), count) in enumerate(zip(options, votes)):
        embed.add_field(
            name=f"**{emoji if emoji else num_to_emoji(i+1)} {option}**",
            value=generate_bar(count, sum_votes) + "\n" + f"{count} votes",
            inline=False,
        )

    return embed


def num_to_emoji(num: int) -> str:
    """Return the keycap emoji for ``num``.

    Raises:
        ValueError: If ``num`` is outside 0..10.
    """
    if not 0 <= num < len(_NUMBER_EMOJIS):
        raise ValueError("Invalid number: `num` must be 0 <= num <= 10.")
    return _NUMBER_EMOJIS[num]


def _parse_poll_options(raw: str) -> PollOptions:
    """Parse the modal's option text into ``(emoji, label)`` pairs.

    Options are written one per line, each starting with ``-``; an optional
    ``emoji:`` prefix sets the button emoji, e.g.::

        - ✅: Yes
        - ❌: No
        - Unsure

    Blank or whitespace-only options are skipped.
    """
    # Only the very first option lacks a preceding newline, so strip its
    # leading dash explicitly (not just the first "-" found anywhere).
    text = raw.strip().removeprefix("-")
    options: PollOptions = []
    for chunk in text.split("\n-"):
        chunk = chunk.strip()
        if not chunk:
            continue
        emoji, sep, label = chunk.partition(":")
        if sep:
            options.append((emoji.strip(), label.strip()))
        else:
            options.append((None, chunk))
    return options


class Polls(Extension):
    """Slash-command polls with button voting, locking and timed expiry."""

    def __init__(self, client: Client):
        self.client = client

    @listen(events.Ready)
    async def on_ready(self):
        # Close polls that expired while the bot was offline, then keep
        # checking on an interval.
        await self.poll_expiry_check()
        self.poll_expiry_check.start()

    @slash_command(
        name="polls",
        description="Polls",
        sub_cmd_name="create",
        sub_cmd_description="Create a poll",
        dm_permission=False,
        default_member_permissions=Permissions.SEND_MESSAGES,
    )
    @slash_option(
        name="title",
        description="Title of the poll",
        required=True,
        opt_type=OptionTypes.STRING,
    )
    @slash_option(
        name="duration",
        description="Duration of the poll in minutes",
        required=False,
        opt_type=OptionTypes.INTEGER,
    )
    @slash_option(
        name="multiple-choice",
        description="If users can vote for multiple options",
        required=False,
        opt_type=OptionTypes.BOOLEAN,
    )
    async def create_poll(
        self,
        ctx: InteractionContext,
        *,
        title: str,
        duration: int | None = None,
        multiple_choice: bool = False,
    ):
        """Show a modal to collect poll options, then post the poll.

        A ``duration`` of ``None`` or ``0`` creates a poll without expiry.
        """
        if duration is not None and duration < 0:
            await ctx.send("Duration must not be negative.", ephemeral=True)
            return

        modal = Modal(
            title="Creating poll",
            components=[
                ShortText(
                    custom_id="title",
                    label="Title",
                    value=title,
                    required=True,
                    max_length=120,
                ),
                ParagraphText(
                    custom_id="options",
                    label="Poll options",
                    placeholder=(
                        "Add poll options here.\n\n"
                        "- ✅: Yes\n"
                        "- ❌: No\n"
                        "- Unsure"
                    ),
                    required=True,
                    max_length=1200,
                ),
            ],
        )
        await ctx.send_modal(modal)
        modal_ctx: ModalContext = await self.client.wait_for_modal(
            modal=modal, author=ctx.author
        )

        expires: datetime | None = (
            datetime.now() + timedelta(minutes=duration) if duration else None
        )
        title = modal_ctx.responses["title"]
        options = _parse_poll_options(modal_ctx.responses["options"])

        if len(options) > 10:
            await modal_ctx.send("You can only have up to 10 options.", ephemeral=True)
            return

        if len(options) < 2:
            await modal_ctx.send("You must have at least 2 options.", ephemeral=True)
            return

        # The row must exist before sending so its id can go into the
        # buttons' custom ids.
        poll_entry: PollsModel = PollsModel.create(
            guild_id=ctx.guild.id,
            author_id=ctx.author.id,
            title=title,
            options=json.dumps(options),
            no_options=len(options),
            multiple_choice=multiple_choice,
            expires=expires,
        )

        buttons: List[Button] = []
        for i, (option_emoji, _label) in enumerate(options):
            try:
                emoji = PartialEmoji.from_str(option_emoji or num_to_emoji(i + 1))
            except ValueError:
                emoji = num_to_emoji(i + 1)
            buttons.append(
                Button(
                    emoji=emoji,
                    style=ButtonStyles.PRIMARY,
                    custom_id=f"poll-vote:{poll_entry.id}:{i}",
                )
            )
        buttons.append(
            Button(
                emoji="🔒",
                label="Lock",
                style=ButtonStyles.DANGER,
                custom_id=f"poll-lock:{poll_entry.id}",
            )
        )

        embed = generate_poll_embed(
            title,
            options,
            len(options) * [0],
            multiple_choice=multiple_choice,
            expires=expires,
        )

        try:
            poll_message: Message = await modal_ctx.send(
                embed=embed,
                components=spread_to_rows(*buttons),
            )
        except HTTPException as e:
            logging.error("Error sending poll: %s", e)
            # Don't leave a row without channel/message ids behind; nothing
            # can ever reference it and the expiry task would choke on it.
            poll_entry.delete_instance()
            await modal_ctx.send(
                "Error during poll creation. NB: You can not use server-specific emojis",
                ephemeral=True,
            )
            return

        poll_entry.message_id = poll_message.id
        poll_entry.channel_id = poll_message.channel.id
        poll_entry.save()

    @listen(events.Button)
    async def on_button(self, button: events.Button):
        """Dispatch poll vote/lock button presses; ignore all other buttons."""
        ctx = button.context
        # This listener sees every button press in the bot, so only acknowledge
        # interactions that belong to this extension.
        if ctx.custom_id.startswith("poll-vote:"):
            await ctx.defer(ephemeral=True)
            _, poll_id, option_num = ctx.custom_id.split(":", 2)
            await self._handle_vote(ctx, int(poll_id), int(option_num))
        elif ctx.custom_id.startswith("poll-lock:"):
            await ctx.defer(ephemeral=True)
            poll_id = ctx.custom_id.split(":", 1)[1]
            await self._handle_lock(ctx, int(poll_id))

    async def _handle_vote(
        self, ctx: InteractionContext, poll_id: int, option_num: int
    ) -> None:
        """Record, change or remove the user's vote, then refresh the embed."""
        poll_entry: PollsModel | None = PollsModel.get_or_none(
            guild_id=ctx.guild.id, id=poll_id
        )
        if not poll_entry:
            await ctx.send("That poll doesn't exist.")
            return
        if poll_entry.expires and datetime.now() > poll_entry.expires:
            await ctx.send("This poll has closed.")
            return

        user_votes: List[PollVotesModel] = list(
            PollVotesModel.select().where(
                PollVotesModel.poll_id == poll_id,
                PollVotesModel.user_id == ctx.author.id,
            )
        )

        if poll_entry.multiple_choice:
            # Clicking an option toggles the user's vote for it.
            existing = next(
                (vote for vote in user_votes if int(vote.option) == option_num),
                None,
            )
            if existing is not None:
                existing.delete_instance()
                await ctx.send("You have removed your vote.")
            else:
                PollVotesModel.create(
                    poll_id=poll_id,
                    user_id=ctx.author.id,
                    option=option_num,
                )
                await ctx.send("You have voted.")
        elif len(user_votes) == 1:
            vote = user_votes[0]
            if int(vote.option) == option_num:
                vote.delete_instance()
                await ctx.send("You have removed your vote.")
            else:
                vote.option = option_num
                vote.save()
                await ctx.send("You have changed your vote.")
        else:
            if user_votes:
                # Several votes on a single-choice poll: reset to just this one.
                # (Filter explicitly -- an unfiltered Model.delete() would wipe
                # the whole table.)
                PollVotesModel.delete().where(
                    PollVotesModel.poll_id == poll_id,
                    PollVotesModel.user_id == ctx.author.id,
                ).execute()
            PollVotesModel.create(
                poll_id=poll_id,
                user_id=ctx.author.id,
                option=option_num,
            )
            await ctx.send("You have voted.")

        options: PollOptions = json.loads(poll_entry.options)
        embed = generate_poll_embed(
            poll_entry.title,
            options,
            self._tally_votes(poll_id, len(options)),
            multiple_choice=poll_entry.multiple_choice,
            expires=poll_entry.expires,
        )
        await ctx.message.edit(embed=embed)

    @staticmethod
    def _tally_votes(poll_id: int, option_count: int) -> List[int]:
        """Return the vote count per option index for a poll.

        Options are stored as JSON on the poll row rather than in their own
        table, so counts are aggregated per option index here.
        """
        rows = (
            PollVotesModel.select(
                PollVotesModel.option,
                fn.COUNT(PollVotesModel.option).alias("count"),
            )
            .where(PollVotesModel.poll_id == poll_id)
            .group_by(PollVotesModel.option)
        )
        votes = option_count * [0]
        for row in rows:
            votes[int(row.option)] = row.count
        return votes

    async def _handle_lock(self, ctx: InteractionContext, poll_id: int) -> None:
        """Close a poll early if the user is its author or can manage messages."""
        poll_entry: PollsModel | None = PollsModel.get_or_none(
            guild_id=ctx.guild.id, id=poll_id
        )
        if not poll_entry:
            await ctx.send("That poll doesn't exist.")
            return

        is_author = ctx.author.id == int(poll_entry.author_id)
        if not is_author and not ctx.author.has_permission(
            Permissions.MANAGE_MESSAGES
        ):
            await ctx.send("You don't have permission to lock that poll.")
            return

        # Back-date the expiry so the regular expiry check closes it now.
        poll_entry.expires = datetime.now() - timedelta(minutes=1)
        poll_entry.save()
        await self.poll_expiry_check()

        await ctx.send("Poll locked.")

    @Task.create(IntervalTrigger(minutes=1))
    async def poll_expiry_check(self):
        """Remove the voting buttons from every poll whose expiry has passed."""
        logging.info("Checking for expired polls.")
        now = datetime.now()
        polls_q: List[PollsModel] = PollsModel.select().where(
            PollsModel.expires < now,
            # Rows whose message was never posted have nothing to close.
            PollsModel.message_id.is_null(False),
            PollsModel.channel_id.is_null(False),
        )
        for poll_entry in polls_q:
            try:
                channel: GuildText = await self.client.fetch_channel(
                    int(poll_entry.channel_id)
                )
                if not channel:
                    continue
                message: Message = await channel.fetch_message(
                    int(poll_entry.message_id)
                )
                # Skip messages that are gone or whose buttons were already
                # removed, to avoid re-editing every closed poll each minute.
                if not message or not message.components:
                    continue
                await message.edit(components=[])
            except HTTPException as e:
                # One inaccessible channel/message must not stop the others.
                logging.warning("Could not close poll %s: %s", poll_entry.id, e)


def setup(client: Client):
    """Create the poll tables if needed and register the extension."""
    PollsModel.create_table()
    PollVotesModel.create_table()
    Polls(client)
    logging.info("Polls extension loaded")