websocket_manager.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from fastapi import WebSocket
  2. from typing import Dict, List, Set
  3. import json
  4. class ConnectionManager:
  5. def __init__(self):
  6. self.active_connections: Dict[WebSocket, Dict] = {}
  7. self.token_to_connection: Dict[str, WebSocket] = {}
  8. self.game_connections: Dict[int, Set[WebSocket]] = {}
  9. async def connect(self, websocket: WebSocket):
  10. await websocket.accept()
  11. self.active_connections[websocket] = {
  12. "token": None,
  13. "game_id": None
  14. }
  15. async def disconnect(self, websocket: WebSocket):
  16. if websocket in self.active_connections:
  17. conn_data = self.active_connections[websocket]
  18. if conn_data["token"] in self.token_to_connection:
  19. del self.token_to_connection[conn_data["token"]]
  20. if conn_data["game_id"] in self.game_connections:
  21. if websocket in self.game_connections[conn_data["game_id"]]:
  22. self.game_connections[conn_data["game_id"]].remove(websocket)
  23. del self.active_connections[websocket]
  24. async def disconnect_all(self):
  25. for websocket in list(self.active_connections.keys()):
  26. await self.disconnect(websocket)
  27. def register_user(self, websocket: WebSocket, token: str, game_id: int):
  28. if websocket in self.active_connections:
  29. self.active_connections[websocket] = {
  30. "token": token,
  31. "game_id": game_id
  32. }
  33. self.token_to_connection[token] = websocket
  34. if game_id not in self.game_connections:
  35. self.game_connections[game_id] = set()
  36. self.game_connections[game_id].add(websocket)
  37. async def send_personal_message(self, token: str, message: dict):
  38. if token in self.token_to_connection:
  39. websocket = self.token_to_connection[token]
  40. await websocket.send_text(json.dumps(message))
  41. async def broadcast_to_game(self, game_id: int, message: dict):
  42. if game_id in self.game_connections:
  43. for websocket in self.game_connections[game_id]:
  44. await websocket.send_text(json.dumps(message))
  45. async def broadcast(self, message: dict):
  46. for connection in self.active_connections:
  47. await connection.send_text(json.dumps(message))