utils.py 20 KB


  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # utils.py
  4. # Copyright (C) 2018-2021 github.com/googlehosts Group:Z
  5. #
  6. # This module is part of googlehosts/telegram-repeater and is released under
  7. # the AGPL v3 License: https://www.gnu.org/licenses/agpl-3.0.txt
  8. #
  9. # This program is free software: you can redistribute it and/or modify
  10. # it under the terms of the GNU Affero General Public License as published by
  11. # the Free Software Foundation, either version 3 of the License, or
  12. # any later version.
  13. #
  14. # This program is distributed in the hope that it will be useful,
  15. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  16. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  17. # GNU Affero General Public License for more details.
  18. #
  19. # You should have received a copy of the GNU Affero General Public License
  20. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

import asyncio
import concurrent.futures
import html
import logging
import random
import string
import time
import traceback
import warnings
from configparser import ConfigParser
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import asyncpg
from pyrogram import Client
from pyrogram.errors import FloodWait
from pyrogram.types import (InlineKeyboardButton, InlineKeyboardMarkup,
                            Message, MessageEntity, User)
  38. logger = logging.getLogger(__name__)
  39. logger.setLevel(logging.INFO)
  40. _FixedDataType = TypeVar('_FixedDataType', str, bool, int)
  41. class TextParser:
  42. class BuildMessage:
  43. def __init__(self, msg: Message):
  44. # NOTE: Remove special handling code in the official version
  45. self.text: bytes = (msg.text if msg.text else msg.caption if msg.caption else '').encode('utf-16-le')
  46. self.chat_id: int = msg.chat.id
  47. self.entities: List[MessageEntity] = msg.entities if msg.text else msg.caption_entities
  48. self.user_name, self.user_id = TextParser.UserName(msg.from_user).get_name_id()
  49. self.message_id: int = msg.message_id
  50. try:
  51. self.forward_from: str = msg.forward_from_chat.title if msg.forward_from_chat else \
  52. ('DELETED' if msg.forward_from.is_deleted else (msg.forward_from.first_name + (' {}'.format(
  53. msg.forward_from.last_name) if msg.forward_from.last_name else ''))) if msg.forward_from else msg.forward_sender_name if msg.forward_sender_name else ''
  54. except TypeError:
  55. print(msg)
  56. self.forward_from = 'Error: unable to get the name of the account you wish to forward from'
  57. self.forward_fom_id: Optional[
  58. int] = msg.forward_from_chat.id if msg.forward_from_chat else msg.forward_from.id if msg.forward_from else None
  59. class UserName:
  60. def __init__(self, user: User):
  61. self.first_name: str = user.first_name
  62. self.last_name: str = user.last_name if user.last_name else ''
  63. self.full_name: str = user.first_name if self.last_name == '' else ' '.join(
  64. (self.first_name, self.last_name))
  65. self.id: int = user.id
  66. self.user: User = user
  67. def get_name_id(self) -> Tuple[str, int]:
  68. return self.full_name, self.id
  69. def __str__(self) -> str:
  70. return self.full_name
  71. _dict = {
  72. 'italic': ('i', 'i'),
  73. 'bold': ('b', 'b'),
  74. 'code': ('code', 'code'),
  75. 'pre': ('pre', 'pre'),
  76. 'text_link': ('a href="{}"', 'a'),
  77. 'strike': ('del', 'del'),
  78. 'underline': ('u', 'u'),
  79. 'text_mention': ('a href=tg://user?id={}', 'a')
  80. }
  81. filter_keyword = tuple(key for key, _ in _dict.items())
  82. def __init__(self):
  83. self._msg: Message = None
  84. self.parsed_msg: str = ''
  85. def parse_html_msg(self) -> str:
  86. result = []
  87. tag_stack = []
  88. # self.text = msg['text'].encode(_ENCODE)
  89. if self._msg.entities is None:
  90. return self._msg.text.decode('utf-16-le')
  91. start_pos = set(_entity.offset * 2 for _entity in self._msg.entities if _entity.type in self.filter_keyword)
  92. if not len(start_pos):
  93. return self._msg.text.decode('utf-16-le')
  94. # print(start_pos)
  95. _close_tag_pos = -1
  96. _close_tag = ''
  97. _last_cut = 0
  98. for _pos in range(len(self._msg.text) + 1):
  99. while _close_tag_pos == _pos:
  100. result.append(self._msg.text[_last_cut:_pos])
  101. _last_cut = _pos
  102. result.append(f'</{_close_tag}>'.encode('utf-16-le'))
  103. if not len(tag_stack):
  104. break
  105. _close_tag, _close_tag_pos = tag_stack.pop()
  106. if _pos in start_pos:
  107. result.append(self._msg.text[_last_cut:_pos])
  108. _last_cut = _pos
  109. for _entity in self._msg.entities:
  110. if _entity.offset * 2 == _pos:
  111. format_value = _entity.url
  112. if format_value is None and _entity.user:
  113. format_value = _entity.user.id
  114. result.append(f'<{self._dict[_entity["type"]][0]}>'.format(format_value).encode('utf-16-le'))
  115. tag_stack.append((self._dict[_entity.type][1], (_entity.offset + _entity.length) * 2))
  116. if _close_tag_pos <= _pos:
  117. _close_tag, _close_tag_pos = tag_stack.pop()
  118. result.append(self._msg.text[_last_cut:])
  119. return b''.join(result).decode('utf-16-le')
  120. def parse_main(self) -> str:
  121. return self.parse_html_msg()
  122. def split_offset(self) -> str:
  123. return self.parsed_msg
  124. def get_full_message(self) -> str:
  125. return ''.join(('<b>',
  126. self._msg.user_name[:30],
  127. ' (\u21a9 {})'.format(self._msg.forward_from[:30]) if self._msg.forward_from != '' else '',
  128. '</b>',
  129. '<a href="https://t.me/c/',
  130. str(-self._msg.chat_id - 1000000000000),
  131. '/',
  132. str(self._msg.message_id),
  133. '">:</a> ',
  134. self.parsed_msg
  135. ))
  136. @staticmethod
  137. def parse_user_markdown(user_id: Union[int, str], user_name: Optional[str] = None) -> str:
  138. if user_name is None:
  139. user_name = str(user_id)
  140. return f'[{user_name}](tg://user?id={user_id})'
  141. @staticmethod
  142. def parse_user_html(user_id: int, user_name: Optional[str] = None) -> str:
  143. if user_name is None:
  144. user_name = str(user_id)
  145. return f'<a href="tg://user?id={user_id}">{user_name}</a>'
  146. @staticmethod
  147. def markdown_replace(name: str) -> str:
  148. for x in ('['):
  149. name = name.replace(x, ''.join(('\\', x)))
  150. return name
  151. class PgSQLdb:
  152. def __init__(
  153. self,
  154. host: str,
  155. port: int,
  156. user: str,
  157. password: str,
  158. db: str,
  159. ):
  160. self.logger: logging.Logger = logging.getLogger(__name__)
  161. self.logger.setLevel(logging.DEBUG)
  162. self.host: str = host
  163. self.port: int = port
  164. self.user: str = user
  165. self.password: str = password
  166. self.db: str = db
  167. self.execute_lock: asyncio.Lock = asyncio.Lock()
  168. self.pgsql_connection: asyncpg.pool.Pool = None
  169. self.last_execute_time: float = 0.0
  170. async def create_connect(self) -> None:
  171. self.pgsql_connection = await asyncpg.create_pool(
  172. host=self.host,
  173. port=self.port,
  174. user=self.user,
  175. password=self.password,
  176. database=self.db
  177. )
  178. @classmethod
  179. async def create(cls,
  180. host: str,
  181. port: int,
  182. user: str,
  183. password: str,
  184. db: str,
  185. ) -> 'PgSQLdb':
  186. self = cls(host, port, user, password, db)
  187. await self.create_connect()
  188. return self
  189. async def query(self, sql: str, *args: Optional[_FixedDataType]) -> List[asyncpg.Record]:
  190. async with self.pgsql_connection.acquire() as conn:
  191. return await conn.fetch(sql, *args)
  192. async def query1(self, sql: str, *args: Optional[_FixedDataType]) -> Optional[asyncpg.Record]:
  193. async with self.pgsql_connection.acquire() as conn:
  194. return await conn.fetchrow(sql, *args)
  195. async def execute(self, sql: str, *args: Union[Sequence[Tuple[_FixedDataType, ...]],
  196. Optional[_FixedDataType]], many: bool = False) -> None:
  197. async with self.pgsql_connection.acquire() as conn:
  198. if many:
  199. await conn.executemany(sql, *args)
  200. else:
  201. await conn.execute(sql, *args)
  202. async def close(self) -> None:
  203. await self.pgsql_connection.close()
  204. async def insert_ex(self, id1: int, id2: int, user_id: Optional[int] = None) -> None:
  205. await self.execute(
  206. '''INSERT INTO "msg_id" VALUES ($1, $2, CURRENT_TIMESTAMP, $3)''',
  207. id1, id2, user_id)
  208. async def insert(self, msg: Message, msg_2: Message) -> None:
  209. try:
  210. await self.insert_ex(msg.message_id, msg_2.message_id, msg.from_user.id)
  211. except:
  212. traceback.print_exc()
  213. await self.insert_ex(msg.message_id, msg_2.message_id)
  214. async def get_user_id(self, msg: Union[Message, int]) -> Optional[asyncpg.Record]:
  215. return await self.query1(
  216. '''SELECT "user_id" FROM "msg_id" WHERE "msg_id" = (
  217. SELECT "msg_id" FROM "msg_id" WHERE "target_id" = $1
  218. )''',
  219. (msg if isinstance(msg, int) else msg.reply_to_message.message_id))
  220. async def get_id(self, msg_id: int, reverse: bool = False) -> Optional[int]:
  221. r = await self.query1('{} = $1'.format('''SELECT "{}" FROM "msg_id" WHERE "{}"'''.format(
  222. *(('target_id', 'msg_id') if not reverse else ('msg_id', 'target_id')))), msg_id)
  223. return r['target_id' if not reverse else 'msg_id'] if r else None
  224. async def get_reply_id(self, msg: Message) -> Optional[int]:
  225. return await self.get_id(msg.reply_to_message.message_id) if msg.reply_to_message else None
  226. async def get_reply_id_reverse(self, msg: Message) -> Optional[int]:
  227. return await self.get_id(msg.reply_to_message.message_id, True) if msg.reply_to_message else None
  228. async def get_msg_name_history_channel_msg_id(self, msg: Message) -> int:
  229. return (await self.query1(
  230. '''SELECT "channel_msg_id" FROM "username" WHERE "user_id" = (
  231. SELECT "user_id" FROM "msg_id" WHERE "target_id" = $1
  232. )''',
  233. msg.reply_to_message.message_id))['channel_msg_id']
  234. async def insert_new_warn(self, user_id: int, msg: str, msg_id: Optional[int]) -> int:
  235. await self.execute('''INSERT INTO "reasons" ("user_id", "text", "msg_id") VALUES ($1, $2, $3)''',
  236. user_id, msg, msg_id)
  237. # FIXME:
  238. return (await self.query1("SELECT LAST_INSERT_ID()"))['LAST_INSERT_ID()']
  239. async def delete_warn_by_id(self, warn_id: int) -> None:
  240. await self.execute('''DELETE FROM "reasons" WHERE "user_id" = $1''', warn_id)
  241. async def query_warn_by_user(self, user_id: int) -> int:
  242. return (await self.query1('''SELECT COUNT(*) FROM "reasons" WHERE "user_id" = $1''', user_id))['count']
  243. async def query_warn_reason_by_id(self, reason_id: int) -> str:
  244. return (await self.query1('''SELECT "text" FROM "reasons" WHERE "id" = $1''', reason_id))['text']
  245. async def query_user_in_banlist(self, user_id: int) -> bool:
  246. return await self.query1('''SELECT * FROM "banlist" WHERE "id" = $1''', user_id) is not None
  247. async def insert_user_to_banlist(self, user_id: int) -> None:
  248. await self.execute('''INSERT INTO "banlist" ("id") VALUES ($1)''', user_id)
  249. class InviteLinkTracker:
  250. @dataclass
  251. class _UserTracker:
  252. message_id: int
  253. timestamp: float
  254. def __init__(self, client: Client, problem_set: dict, chat_id: int):
  255. self.client: Client = client
  256. self.chat_id: int = chat_id
  257. self.user_dict: Dict[int, InviteLinkTracker._UserTracker] = {}
  258. self.revoke_time: int = problem_set['configs']['revoke_time'] + 10
  259. self.join_group_msg: str = problem_set['messages']['success_msg']
  260. self.tricket_msg: str = problem_set['messages']['join_group_message']
  261. self.last_revoke_time: float = 0.0
  262. self.current_link: str = ''
  263. self.stop_event: asyncio.Event = asyncio.Event()
  264. self.future: Optional[concurrent.futures.Future] = None
  265. def start(self) -> concurrent.futures.Future:
  266. if self.future is not None:
  267. return self.future
  268. self.future = asyncio.run_coroutine_threadsafe(self._boost_run(), asyncio.get_event_loop())
  269. return self.future
  270. async def do_revoke(self) -> None:
  271. while True:
  272. try:
  273. self.current_link = await self.client.export_chat_invite_link(self.chat_id)
  274. break
  275. except FloodWait as e:
  276. logger.warning('Got Floodwait, wait for %d seconds', e.x)
  277. await asyncio.sleep(e.x)
  278. await self.revoke_users()
  279. self.last_revoke_time = time.time()
  280. async def revoke_users(self) -> None:
  281. current_time = time.time()
  282. pending_delete = []
  283. need_update_user = asyncio.Queue()
  284. for user_id, user_tracker in self.user_dict.items():
  285. if current_time - user_tracker.timestamp > self.revoke_time:
  286. pending_delete.append(user_id)
  287. else:
  288. need_update_user.put_nowait((user_id, user_tracker.message_id))
  289. for user_id in pending_delete:
  290. self.user_dict.pop(user_id, None)
  291. while not need_update_user.empty():
  292. await self.client.edit_message_reply_markup(*need_update_user.get_nowait(),
  293. reply_markup=self.generate_keyboard())
  294. del pending_delete, need_update_user, current_time
  295. def get(self) -> str:
  296. return self.current_link
  297. async def join(self, timeout: float = 0) -> None:
  298. if self.future is None:
  299. return
  300. if timeout > 0:
  301. while not self.future.done():
  302. for _ in range(int(timeout // .05)):
  303. if self.future.done():
  304. return
  305. await asyncio.sleep(.05)
  306. else:
  307. await asyncio.sleep(0)
  308. @property
  309. def is_alive(self) -> bool:
  310. return self.future is not None and not self.future.done()
  311. def request_stop(self) -> None:
  312. self.stop_event.set()
  313. def generate_keyboard(self) -> InlineKeyboardMarkup:
  314. return InlineKeyboardMarkup(
  315. inline_keyboard=[
  316. [
  317. InlineKeyboardButton(text='Join group', url=self.current_link)
  318. ]
  319. ]
  320. )
  321. async def send_link(self, chat_id: int, from_ticket: bool = False) -> None:
  322. self.user_dict.update(
  323. {
  324. chat_id: InviteLinkTracker._UserTracker(
  325. # NOTE: KNOWN ISSUE, IF NEVER CONTACT FROM THIS BOT
  326. (await self.client.send_message(
  327. chat_id,
  328. self.join_group_msg if from_ticket else self.tricket_msg,
  329. 'html',
  330. reply_markup=self.generate_keyboard()
  331. )).message_id,
  332. time.time()
  333. )
  334. }
  335. )
  336. async def _boost_run(self) -> None:
  337. # Wait start:
  338. while not self.client.is_connected:
  339. await asyncio.sleep(0.01)
  340. # Do revoke first. (init process)
  341. await self.do_revoke()
  342. while not self.stop_event.is_set():
  343. try:
  344. if self.user_dict:
  345. if time.time() - self.last_revoke_time > 30:
  346. await self.do_revoke()
  347. except:
  348. traceback.print_exc()
  349. else:
  350. if not self.stop_event.is_set():
  351. await asyncio.sleep(1)
  352. def get_random_string(length: int = 8) -> str:
  353. return ''.join(random.choices(string.ascii_lowercase, k=length))
  354. class AuthSystem:
  355. class_self = None
  356. def __init__(self, conn: PgSQLdb):
  357. self.conn = conn
  358. self.authed_user: List[int] = []
  359. self.non_ignore_user: List[int] = []
  360. self.whitelist: List[int] = []
  361. async def init(self, owner: Optional[int] = None) -> None:
  362. sql_obj = await self.conn.query('''SELECT "uid", "authorized", "muted", "whitelist" FROM "auth_user"''')
  363. self.authed_user = [row['uid'] for row in sql_obj if row['authorized']]
  364. self.non_ignore_user = [row['uid'] for row in sql_obj if not row['muted']]
  365. self.whitelist = [row['uid'] for row in sql_obj if row['whitelist']]
  366. if owner is not None and owner not in self.authed_user:
  367. self.authed_user.append(owner)
  368. @classmethod
  369. async def create(cls, conn: PgSQLdb, owner: Optional[int] = None) -> AuthSystem:
  370. self = cls(conn)
  371. try:
  372. await self.init(owner)
  373. except KeyError:
  374. logger.critical('Got key error', exc_info=True)
  375. return self
  376. def check_ex(self, user_id: int) -> bool:
  377. return user_id in self.authed_user
  378. async def add_user(self, user_id: Union[str, int]) -> None:
  379. user_id = int(user_id)
  380. self.authed_user.append(user_id)
  381. self.authed_user = list(set(self.authed_user))
  382. if await self.query_user(user_id) is not None:
  383. await self.update_user(user_id, 'authorized', True)
  384. else:
  385. await self.conn.execute('''INSERT INTO "auth_user" ("uid", "authorized") VALUES ($1, true)''', user_id)
  386. async def update_user(self, user_id: int, column_name: str, value: Union[str, bool]) -> None:
  387. if isinstance(value, str):
  388. warnings.warn('value should passed by bool instead', DeprecationWarning, 2)
  389. value = value == 'Y'
  390. await self.conn.execute('''UPDATE "auth_user" SET "{}" = $1 WHERE "uid" = $2'''.format(column_name),
  391. value, user_id)
  392. async def query_user(self, user_id: int) -> Optional[asyncpg.Record]:
  393. return await self.conn.query1('''SELECT * FROM "auth_user" WHERE "uid" = $1''', user_id)
  394. async def del_user(self, user_id: int) -> None:
  395. self.authed_user.remove(user_id)
  396. await self.update_user(user_id, 'authorized', False)
  397. def check_muted(self, user_id: int) -> bool:
  398. return user_id not in self.non_ignore_user
  399. async def unmute_user(self, user_id: int):
  400. self.non_ignore_user.append(user_id)
  401. self.non_ignore_user = list(set(self.non_ignore_user))
  402. await self.update_user(user_id, 'muted', False)
  403. async def mute_user(self, user_id: int) -> None:
  404. self.non_ignore_user.remove(user_id)
  405. await self.update_user(user_id, 'muted', True)
  406. def check(self, user_id: int) -> bool:
  407. return self.check_ex(user_id) and not self.check_muted(user_id)
  408. def check_full(self, user_id: int) -> bool:
  409. return self.check_ex(user_id) or user_id in self.whitelist
  410. async def mute_or_unmute(self, r: str, chat_id: int) -> None:
  411. if not self.check_ex(chat_id):
  412. return
  413. try:
  414. await (self.mute_user if r == 'off' else self.unmute_user)(chat_id)
  415. except ValueError:
  416. pass
  417. @staticmethod
  418. def get_instance() -> AuthSystem:
  419. if AuthSystem.class_self is None:
  420. raise RuntimeError('Instance not initialize')
  421. return AuthSystem.class_self
  422. @staticmethod
  423. async def initialize_instance(conn: PgSQLdb, owner: int = None) -> AuthSystem:
  424. AuthSystem.class_self = await AuthSystem.create(conn, owner)
  425. return AuthSystem.class_self
  426. def get_language() -> str:
  427. config = ConfigParser()
  428. config.read('config.ini')
  429. return config.get('i18n', 'language', fallback='en_US')