1.0.1
parent
93f782ab50
commit
d85f12fb9d
|
|
@ -13,6 +13,7 @@ build/
|
|||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
uploads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from app.models.models import LoginRequest, LoginResponse
|
||||
from app.core.database import get_db_connection
|
||||
from app.services.jwt_service import jwt_service
|
||||
from app.core.auth import get_current_user
|
||||
import hashlib
|
||||
import datetime
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -26,7 +30,13 @@ def login(request: LoginRequest):
|
|||
if user['password_hash'] != hashed_input and user['password_hash'] != request.password:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
token = f"token_{user['user_id']}_{hash_password(str(datetime.datetime.now()))[:16]}"
|
||||
# 创建JWT token
|
||||
token_data = {
|
||||
"user_id": user['user_id'],
|
||||
"username": user['username'],
|
||||
"caption": user['caption']
|
||||
}
|
||||
token = jwt_service.create_access_token(token_data)
|
||||
|
||||
return LoginResponse(
|
||||
user_id=user['user_id'],
|
||||
|
|
@ -35,3 +45,70 @@ def login(request: LoginRequest):
|
|||
email=user['email'],
|
||||
token=token
|
||||
)
|
||||
|
||||
@router.post("/auth/logout")
|
||||
def logout(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""登出接口,撤销当前token"""
|
||||
token = credentials.credentials
|
||||
|
||||
# 验证token并获取用户信息(不查询数据库)
|
||||
payload = jwt_service.verify_token(token)
|
||||
if not payload:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid token payload")
|
||||
|
||||
# 撤销当前token
|
||||
revoked = jwt_service.revoke_token(token, user_id)
|
||||
|
||||
if revoked:
|
||||
return {"message": "Logged out successfully"}
|
||||
else:
|
||||
return {"message": "Already logged out or token not found"}
|
||||
|
||||
@router.post("/auth/logout-all")
|
||||
def logout_all(current_user: dict = Depends(get_current_user)):
|
||||
"""登出所有设备"""
|
||||
user_id = current_user['user_id']
|
||||
revoked_count = jwt_service.revoke_all_user_tokens(user_id)
|
||||
return {"message": f"Logged out from {revoked_count} devices"}
|
||||
|
||||
@router.post("/auth/admin/revoke-user-tokens/{user_id}")
|
||||
def admin_revoke_user_tokens(user_id: int, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""管理员功能:撤销指定用户的所有token"""
|
||||
token = credentials.credentials
|
||||
|
||||
# 验证管理员token(不查询数据库)
|
||||
payload = jwt_service.verify_token(token)
|
||||
if not payload:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
admin_user_id = payload.get("user_id")
|
||||
if not admin_user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid token payload")
|
||||
|
||||
# 这里可以添加管理员权限检查,目前暂时允许任何登录用户操作
|
||||
# if not payload.get('is_admin'):
|
||||
# raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||
|
||||
revoked_count = jwt_service.revoke_all_user_tokens(user_id)
|
||||
return {"message": f"Revoked {revoked_count} tokens for user {user_id}"}
|
||||
|
||||
@router.get("/auth/me")
|
||||
def get_me(current_user: dict = Depends(get_current_user)):
|
||||
"""获取当前用户信息"""
|
||||
return current_user
|
||||
|
||||
@router.post("/auth/refresh")
|
||||
def refresh_token(current_user: dict = Depends(get_current_user)):
|
||||
"""刷新token"""
|
||||
# 这里需要从请求中获取当前token,为简化先返回新token
|
||||
token_data = {
|
||||
"user_id": current_user['user_id'],
|
||||
"username": current_user['username'],
|
||||
"caption": current_user['caption']
|
||||
}
|
||||
new_token = jwt_service.create_access_token(token_data)
|
||||
return {"token": new_token}
|
||||
|
|
|
|||
|
|
@ -1,499 +0,0 @@
|
|||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||||
from app.models.models import Meeting, TranscriptSegment, CreateMeetingRequest, UpdateMeetingRequest
|
||||
from app.core.database import get_db_connection
|
||||
from app.core.config import UPLOAD_DIR, AUDIO_DIR, MARKDOWN_DIR, ALLOWED_EXTENSIONS, ALLOWED_IMAGE_EXTENSIONS, MAX_FILE_SIZE, MAX_IMAGE_SIZE
|
||||
from app.services.qiniu_service import qiniu_service
|
||||
from typing import Optional
|
||||
import os
|
||||
import uuid
|
||||
import shutil
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/meetings", response_model=list[Meeting])
|
||||
def get_meetings(user_id: Optional[int] = None):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
base_query = '''
|
||||
SELECT
|
||||
m.meeting_id, m.title, m.meeting_time, m.summary, m.created_at,
|
||||
m.user_id as creator_id, u.caption as creator_username
|
||||
FROM meetings m
|
||||
JOIN users u ON m.user_id = u.user_id
|
||||
'''
|
||||
|
||||
if user_id:
|
||||
query = f'''
|
||||
{base_query}
|
||||
LEFT JOIN attendees a ON m.meeting_id = a.meeting_id
|
||||
WHERE m.user_id = %s OR a.user_id = %s
|
||||
GROUP BY m.meeting_id
|
||||
ORDER BY m.meeting_time DESC, m.created_at DESC
|
||||
'''
|
||||
cursor.execute(query, (user_id, user_id))
|
||||
else:
|
||||
query = f" {base_query} ORDER BY m.meeting_time DESC, m.created_at DESC"
|
||||
cursor.execute(query)
|
||||
|
||||
meetings = cursor.fetchall()
|
||||
|
||||
meeting_list = []
|
||||
for meeting in meetings:
|
||||
attendees_query = '''
|
||||
SELECT u.user_id, u.caption
|
||||
FROM attendees a
|
||||
JOIN users u ON a.user_id = u.user_id
|
||||
WHERE a.meeting_id = %s
|
||||
'''
|
||||
cursor.execute(attendees_query, (meeting['meeting_id'],))
|
||||
attendees_data = cursor.fetchall()
|
||||
attendees = [{'user_id': row['user_id'], 'caption': row['caption']} for row in attendees_data]
|
||||
|
||||
meeting_list.append(Meeting(
|
||||
meeting_id=meeting['meeting_id'],
|
||||
title=meeting['title'],
|
||||
meeting_time=meeting['meeting_time'],
|
||||
summary=meeting['summary'],
|
||||
created_at=meeting['created_at'],
|
||||
attendees=attendees,
|
||||
creator_id=meeting['creator_id'],
|
||||
creator_username=meeting['creator_username']
|
||||
))
|
||||
|
||||
return meeting_list
|
||||
|
||||
@router.get("/meetings/{meeting_id}", response_model=Meeting)
|
||||
def get_meeting_details(meeting_id: int):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
query = '''
|
||||
SELECT
|
||||
m.meeting_id, m.title, m.meeting_time, m.summary, m.created_at,
|
||||
m.user_id as creator_id, u.caption as creator_username,
|
||||
af.file_path as audio_file_path
|
||||
FROM meetings m
|
||||
JOIN users u ON m.user_id = u.user_id
|
||||
LEFT JOIN audio_files af ON m.meeting_id = af.meeting_id
|
||||
WHERE m.meeting_id = %s
|
||||
'''
|
||||
cursor.execute(query, (meeting_id,))
|
||||
meeting = cursor.fetchone()
|
||||
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
attendees_query = '''
|
||||
SELECT u.user_id, u.caption
|
||||
FROM attendees a
|
||||
JOIN users u ON a.user_id = u.user_id
|
||||
WHERE a.meeting_id = %s
|
||||
'''
|
||||
cursor.execute(attendees_query, (meeting['meeting_id'],))
|
||||
attendees_data = cursor.fetchall()
|
||||
attendees = [{'user_id': row['user_id'], 'caption': row['caption']} for row in attendees_data]
|
||||
|
||||
meeting_data = Meeting(
|
||||
meeting_id=meeting['meeting_id'],
|
||||
title=meeting['title'],
|
||||
meeting_time=meeting['meeting_time'],
|
||||
summary=meeting['summary'],
|
||||
created_at=meeting['created_at'],
|
||||
attendees=attendees,
|
||||
creator_id=meeting['creator_id'],
|
||||
creator_username=meeting['creator_username']
|
||||
)
|
||||
|
||||
# Add audio file path if exists
|
||||
if meeting['audio_file_path']:
|
||||
meeting_data.audio_file_path = meeting['audio_file_path']
|
||||
|
||||
return meeting_data
|
||||
|
||||
@router.get("/meetings/{meeting_id}/transcript", response_model=list[TranscriptSegment])
|
||||
def get_meeting_transcript(meeting_id: int):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# First check if meeting exists
|
||||
meeting_query = "SELECT meeting_id FROM meetings WHERE meeting_id = %s"
|
||||
cursor.execute(meeting_query, (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
# Get transcript segments
|
||||
transcript_query = '''
|
||||
SELECT segment_id, meeting_id, speaker_tag, start_time_ms, end_time_ms, text_content
|
||||
FROM transcript_segments
|
||||
WHERE meeting_id = %s
|
||||
ORDER BY start_time_ms ASC
|
||||
'''
|
||||
cursor.execute(transcript_query, (meeting_id,))
|
||||
segments = cursor.fetchall()
|
||||
|
||||
return [TranscriptSegment(**segment) for segment in segments]
|
||||
|
||||
@router.post("/meetings")
|
||||
def create_meeting(meeting_request: CreateMeetingRequest):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Create meeting
|
||||
meeting_query = '''
|
||||
INSERT INTO meetings (user_id, title, meeting_time, summary)
|
||||
VALUES (%s, %s, %s, %s)
|
||||
'''
|
||||
# Note: You'll need to pass user_id, for now using hardcoded value
|
||||
cursor.execute(meeting_query, (1, meeting_request.title, meeting_request.meeting_time, None))
|
||||
meeting_id = cursor.lastrowid
|
||||
|
||||
# Add attendees
|
||||
for attendee_id in meeting_request.attendee_ids:
|
||||
attendee_query = '''
|
||||
INSERT INTO attendees (meeting_id, user_id)
|
||||
VALUES (%s, %s)
|
||||
ON DUPLICATE KEY UPDATE meeting_id = meeting_id
|
||||
'''
|
||||
cursor.execute(attendee_query, (meeting_id, attendee_id))
|
||||
|
||||
connection.commit()
|
||||
return {"meeting_id": meeting_id, "message": "Meeting created successfully"}
|
||||
|
||||
@router.put("/meetings/{meeting_id}")
|
||||
def update_meeting(meeting_id: int, meeting_request: UpdateMeetingRequest):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Check if meeting exists
|
||||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
# Update meeting
|
||||
update_query = '''
|
||||
UPDATE meetings
|
||||
SET title = %s, meeting_time = %s, summary = %s
|
||||
WHERE meeting_id = %s
|
||||
'''
|
||||
cursor.execute(update_query, (
|
||||
meeting_request.title,
|
||||
meeting_request.meeting_time,
|
||||
meeting_request.summary,
|
||||
meeting_id
|
||||
))
|
||||
|
||||
# Update attendees - remove existing ones and add new ones
|
||||
cursor.execute("DELETE FROM attendees WHERE meeting_id = %s", (meeting_id,))
|
||||
|
||||
for attendee_id in meeting_request.attendee_ids:
|
||||
attendee_query = '''
|
||||
INSERT INTO attendees (meeting_id, user_id)
|
||||
VALUES (%s, %s)
|
||||
'''
|
||||
cursor.execute(attendee_query, (meeting_id, attendee_id))
|
||||
|
||||
connection.commit()
|
||||
return {"message": "Meeting updated successfully"}
|
||||
|
||||
@router.delete("/meetings/{meeting_id}")
|
||||
def delete_meeting(meeting_id: int):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Check if meeting exists
|
||||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
# Delete related records first (foreign key constraints)
|
||||
cursor.execute("DELETE FROM transcript_segments WHERE meeting_id = %s", (meeting_id,))
|
||||
cursor.execute("DELETE FROM audio_files WHERE meeting_id = %s", (meeting_id,))
|
||||
cursor.execute("DELETE FROM attachments WHERE meeting_id = %s", (meeting_id,))
|
||||
cursor.execute("DELETE FROM attendees WHERE meeting_id = %s", (meeting_id,))
|
||||
|
||||
# Delete meeting
|
||||
cursor.execute("DELETE FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
|
||||
connection.commit()
|
||||
return {"message": "Meeting deleted successfully"}
|
||||
|
||||
@router.post("/meetings/{meeting_id}/regenerate-summary")
|
||||
def regenerate_summary(meeting_id: int):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Check if meeting exists
|
||||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
# For now, return a mock summary
|
||||
# In a real implementation, this would call an AI service
|
||||
mock_summary = """# AI 生成摘要
|
||||
|
||||
## 主要议题
|
||||
- 项目进度回顾
|
||||
- 技术方案讨论
|
||||
- 下阶段规划
|
||||
|
||||
## 关键决策
|
||||
- 采用新的技术架构
|
||||
- 调整项目时间节点
|
||||
- 分配任务责任
|
||||
|
||||
## 后续行动
|
||||
- [ ] 完成技术方案文档
|
||||
- [ ] 安排下次会议时间
|
||||
- [ ] 跟进项目进度"""
|
||||
|
||||
# Update meeting summary
|
||||
cursor.execute(
|
||||
"UPDATE meetings SET summary = %s WHERE meeting_id = %s",
|
||||
(mock_summary, meeting_id)
|
||||
)
|
||||
connection.commit()
|
||||
|
||||
return {"summary": mock_summary}
|
||||
|
||||
@router.get("/meetings/{meeting_id}/edit", response_model=Meeting)
|
||||
def get_meeting_for_edit(meeting_id: int):
|
||||
"""Get meeting details with full attendee information for editing"""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
query = '''
|
||||
SELECT
|
||||
m.meeting_id, m.title, m.meeting_time, m.summary, m.created_at,
|
||||
m.user_id as creator_id, u.caption as creator_username,
|
||||
af.file_path as audio_file_path
|
||||
FROM meetings m
|
||||
JOIN users u ON m.user_id = u.user_id
|
||||
LEFT JOIN audio_files af ON m.meeting_id = af.meeting_id
|
||||
WHERE m.meeting_id = %s
|
||||
'''
|
||||
cursor.execute(query, (meeting_id,))
|
||||
meeting = cursor.fetchone()
|
||||
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
# Get attendees with full info for editing
|
||||
attendees_query = '''
|
||||
SELECT u.user_id, u.caption
|
||||
FROM attendees a
|
||||
JOIN users u ON a.user_id = u.user_id
|
||||
WHERE a.meeting_id = %s
|
||||
'''
|
||||
cursor.execute(attendees_query, (meeting['meeting_id'],))
|
||||
attendees_data = cursor.fetchall()
|
||||
attendees = [{'user_id': row['user_id'], 'caption': row['caption']} for row in attendees_data]
|
||||
|
||||
meeting_data = Meeting(
|
||||
meeting_id=meeting['meeting_id'],
|
||||
title=meeting['title'],
|
||||
meeting_time=meeting['meeting_time'],
|
||||
summary=meeting['summary'],
|
||||
created_at=meeting['created_at'],
|
||||
attendees=attendees,
|
||||
creator_id=meeting['creator_id'],
|
||||
creator_username=meeting['creator_username']
|
||||
)
|
||||
|
||||
# Add audio file path if exists
|
||||
if meeting['audio_file_path']:
|
||||
meeting_data.audio_file_path = meeting['audio_file_path']
|
||||
|
||||
return meeting_data
|
||||
|
||||
@router.post("/meetings/upload-audio")
|
||||
async def upload_audio(
|
||||
audio_file: UploadFile = File(...),
|
||||
meeting_id: int = Form(...)
|
||||
):
|
||||
# Validate file extension
|
||||
file_extension = os.path.splitext(audio_file.filename)[1].lower()
|
||||
if file_extension not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type. Allowed types: {', '.join(ALLOWED_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
# Check file size
|
||||
if audio_file.size > MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="File size exceeds 100MB limit"
|
||||
)
|
||||
|
||||
# Check if meeting exists
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
# TEMP: Use existing file to test Qiniu upload instead of client file
|
||||
# This bypasses potential client file processing issues
|
||||
existing_file = AUDIO_DIR / "31ce039a-f619-4869-91c8-eab934bbd1d4.m4a"
|
||||
if not existing_file.exists():
|
||||
raise HTTPException(status_code=500, detail="Test file not found")
|
||||
|
||||
temp_path = existing_file
|
||||
print(f"DEBUG: Using existing test file: {temp_path}")
|
||||
print(f"DEBUG: Test file exists: {temp_path.exists()}")
|
||||
print(f"DEBUG: Test file size: {temp_path.stat().st_size}")
|
||||
|
||||
# Upload to Qiniu
|
||||
try:
|
||||
print(f"DEBUG: Attempting to upload audio to Qiniu - meeting_id: {meeting_id}, filename: {audio_file.filename}")
|
||||
print(f"DEBUG: Temp file path: {temp_path}")
|
||||
print(f"DEBUG: Temp file exists: {temp_path.exists()}")
|
||||
|
||||
success, qiniu_url, error_msg = qiniu_service.upload_audio_file(
|
||||
str(temp_path), meeting_id, audio_file.filename
|
||||
)
|
||||
|
||||
print(f"DEBUG: Qiniu upload result - success: {success}, url: {qiniu_url}, error: {error_msg}")
|
||||
|
||||
# TEMP: Don't delete existing test file
|
||||
# if temp_path.exists():
|
||||
# temp_path.unlink()
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to upload to Qiniu: {error_msg}")
|
||||
|
||||
# Save file info to database with Qiniu URL
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Insert audio file record with Qiniu URL
|
||||
insert_query = '''
|
||||
INSERT INTO audio_files (meeting_id, file_name, file_path, file_size, upload_time)
|
||||
VALUES (%s, %s, %s, %s, NOW())
|
||||
ON DUPLICATE KEY UPDATE
|
||||
file_name = VALUES(file_name),
|
||||
file_path = VALUES(file_path),
|
||||
file_size = VALUES(file_size),
|
||||
upload_time = VALUES(upload_time)
|
||||
'''
|
||||
cursor.execute(insert_query, (meeting_id, audio_file.filename, qiniu_url, audio_file.size))
|
||||
connection.commit()
|
||||
|
||||
return {
|
||||
"message": "Audio file uploaded successfully to Qiniu",
|
||||
"file_name": audio_file.filename,
|
||||
"file_path": qiniu_url,
|
||||
"qiniu_url": qiniu_url
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"DEBUG: Exception in audio upload: {str(e)}")
|
||||
print(f"DEBUG: Exception type: {type(e)}")
|
||||
import traceback
|
||||
print(f"DEBUG: Traceback: {traceback.format_exc()}")
|
||||
# TEMP: Don't delete existing test file in case of error
|
||||
# if temp_path.exists():
|
||||
# temp_path.unlink()
|
||||
raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
|
||||
|
||||
@router.get("/meetings/{meeting_id}/audio")
|
||||
def get_audio_file(meeting_id: int):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
query = '''
|
||||
SELECT file_name, file_path, file_size, upload_time
|
||||
FROM audio_files
|
||||
WHERE meeting_id = %s
|
||||
'''
|
||||
cursor.execute(query, (meeting_id,))
|
||||
audio_file = cursor.fetchone()
|
||||
|
||||
if not audio_file:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found for this meeting")
|
||||
|
||||
return {
|
||||
"file_name": audio_file['file_name'],
|
||||
"file_path": audio_file['file_path'],
|
||||
"file_size": audio_file['file_size'],
|
||||
"upload_time": audio_file['upload_time']
|
||||
}
|
||||
|
||||
@router.post("/meetings/{meeting_id}/upload-image")
|
||||
async def upload_image(
|
||||
meeting_id: int,
|
||||
image_file: UploadFile = File(...)
|
||||
):
|
||||
# Validate file extension
|
||||
file_extension = os.path.splitext(image_file.filename)[1].lower()
|
||||
if file_extension not in ALLOWED_IMAGE_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported image type. Allowed types: {', '.join(ALLOWED_IMAGE_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
# Check file size
|
||||
if image_file.size > MAX_IMAGE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Image size exceeds 10MB limit"
|
||||
)
|
||||
|
||||
# Check if meeting exists
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
# Create temporary file for upload
|
||||
temp_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
temp_path = MARKDOWN_DIR / temp_filename
|
||||
|
||||
# Save file temporarily
|
||||
# Save file temporarily
|
||||
try:
|
||||
contents = await image_file.read()
|
||||
with open(temp_path, "wb") as buffer:
|
||||
buffer.write(contents)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to save temporary image: {str(e)}")
|
||||
|
||||
# Upload to Qiniu
|
||||
try:
|
||||
print(f"DEBUG: Attempting to upload image to Qiniu - meeting_id: {meeting_id}, filename: {image_file.filename}")
|
||||
print(f"DEBUG: Temp file path: {temp_path}")
|
||||
print(f"DEBUG: Temp file exists: {temp_path.exists()}")
|
||||
|
||||
success, qiniu_url, error_msg = qiniu_service.upload_markdown_image(
|
||||
str(temp_path), meeting_id, image_file.filename
|
||||
)
|
||||
|
||||
print(f"DEBUG: Qiniu upload result - success: {success}, url: {qiniu_url}, error: {error_msg}")
|
||||
|
||||
# Clean up temporary file
|
||||
if temp_path.exists():
|
||||
temp_path.unlink()
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to upload image to Qiniu: {error_msg}")
|
||||
|
||||
return {
|
||||
"message": "Image uploaded successfully to Qiniu",
|
||||
"file_name": image_file.filename,
|
||||
"file_path": qiniu_url,
|
||||
"url": qiniu_url,
|
||||
"qiniu_url": qiniu_url
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"DEBUG: Exception in image upload: {str(e)}")
|
||||
print(f"DEBUG: Exception type: {type(e)}")
|
||||
import traceback
|
||||
print(f"DEBUG: Traceback: {traceback.format_exc()}")
|
||||
# Clean up temporary file in case of error
|
||||
if temp_path.exists():
|
||||
temp_path.unlink()
|
||||
raise HTTPException(status_code=500, detail=f"Image upload failed: {str(e)}")
|
||||
|
|
@ -1,10 +1,11 @@
|
|||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends
|
||||
from app.models.models import Meeting, TranscriptSegment, TranscriptionTaskStatus, CreateMeetingRequest, UpdateMeetingRequest, SpeakerTagUpdateRequest, BatchSpeakerTagUpdateRequest, TranscriptUpdateRequest, BatchTranscriptUpdateRequest
|
||||
from app.core.database import get_db_connection
|
||||
from app.core.config import BASE_DIR, UPLOAD_DIR, AUDIO_DIR, MARKDOWN_DIR, ALLOWED_EXTENSIONS, ALLOWED_IMAGE_EXTENSIONS, MAX_FILE_SIZE, MAX_IMAGE_SIZE
|
||||
from app.services.qiniu_service import qiniu_service
|
||||
from app.services.llm_service import LLMService
|
||||
from app.services.async_transcription_service import AsyncTranscriptionService
|
||||
from app.core.auth import get_current_user, get_optional_current_user
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
import os
|
||||
|
|
@ -22,7 +23,7 @@ class GenerateSummaryRequest(BaseModel):
|
|||
user_prompt: Optional[str] = ""
|
||||
|
||||
@router.get("/meetings", response_model=list[Meeting])
|
||||
def get_meetings(user_id: Optional[int] = None):
|
||||
def get_meetings(current_user: dict = Depends(get_current_user), user_id: Optional[int] = None):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
|
|
@ -75,7 +76,7 @@ def get_meetings(user_id: Optional[int] = None):
|
|||
return meeting_list
|
||||
|
||||
@router.get("/meetings/{meeting_id}", response_model=Meeting)
|
||||
def get_meeting_details(meeting_id: int):
|
||||
def get_meeting_details(meeting_id: int, current_user: dict = Depends(get_current_user)):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
|
|
@ -136,7 +137,7 @@ def get_meeting_details(meeting_id: int):
|
|||
return meeting_data
|
||||
|
||||
@router.get("/meetings/{meeting_id}/transcript", response_model=list[TranscriptSegment])
|
||||
def get_meeting_transcript(meeting_id: int):
|
||||
def get_meeting_transcript(meeting_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""获取会议的转录内容"""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
|
@ -172,7 +173,7 @@ def get_meeting_transcript(meeting_id: int):
|
|||
return transcript_segments
|
||||
|
||||
@router.post("/meetings")
|
||||
def create_meeting(meeting_request: CreateMeetingRequest):
|
||||
def create_meeting(meeting_request: CreateMeetingRequest, current_user: dict = Depends(get_current_user)):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
|
|
@ -202,15 +203,19 @@ def create_meeting(meeting_request: CreateMeetingRequest):
|
|||
return {"message": "Meeting created successfully", "meeting_id": meeting_id}
|
||||
|
||||
@router.put("/meetings/{meeting_id}")
|
||||
def update_meeting(meeting_id: int, meeting_request: UpdateMeetingRequest):
|
||||
def update_meeting(meeting_id: int, meeting_request: UpdateMeetingRequest, current_user: dict = Depends(get_current_user)):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Check if meeting exists
|
||||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
# Check if meeting exists and user has permission
|
||||
cursor.execute("SELECT user_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
meeting = cursor.fetchone()
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
if meeting['user_id'] != current_user['user_id']:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
|
||||
# Update meeting
|
||||
update_query = '''
|
||||
UPDATE meetings
|
||||
|
|
@ -238,15 +243,19 @@ def update_meeting(meeting_id: int, meeting_request: UpdateMeetingRequest):
|
|||
return {"message": "Meeting updated successfully"}
|
||||
|
||||
@router.delete("/meetings/{meeting_id}")
|
||||
def delete_meeting(meeting_id: int):
|
||||
def delete_meeting(meeting_id: int, current_user: dict = Depends(get_current_user)):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Check if meeting exists
|
||||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
# Check if meeting exists and user has permission
|
||||
cursor.execute("SELECT user_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
meeting = cursor.fetchone()
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
if meeting['user_id'] != current_user['user_id']:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
|
||||
# Delete related records first (foreign key constraints)
|
||||
cursor.execute("DELETE FROM transcript_segments WHERE meeting_id = %s", (meeting_id,))
|
||||
cursor.execute("DELETE FROM audio_files WHERE meeting_id = %s", (meeting_id,))
|
||||
|
|
@ -260,7 +269,7 @@ def delete_meeting(meeting_id: int):
|
|||
return {"message": "Meeting deleted successfully"}
|
||||
|
||||
@router.post("/meetings/{meeting_id}/regenerate-summary")
|
||||
def regenerate_summary(meeting_id: int):
|
||||
def regenerate_summary(meeting_id: int, current_user: dict = Depends(get_current_user)):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
|
|
@ -296,7 +305,7 @@ def regenerate_summary(meeting_id: int):
|
|||
return {"message": "Summary regenerated successfully", "summary": mock_summary}
|
||||
|
||||
@router.get("/meetings/{meeting_id}/edit", response_model=Meeting)
|
||||
def get_meeting_for_edit(meeting_id: int):
|
||||
def get_meeting_for_edit(meeting_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""获取会议信息用于编辑"""
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
|
@ -361,7 +370,8 @@ def get_meeting_for_edit(meeting_id: int):
|
|||
async def upload_audio(
|
||||
audio_file: UploadFile = File(...),
|
||||
meeting_id: int = Form(...),
|
||||
force_replace: str = Form("false") # 接收字符串,然后手动转换
|
||||
force_replace: str = Form("false"), # 接收字符串,然后手动转换
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
# Convert string to boolean
|
||||
force_replace_bool = force_replace.lower() in ("true", "1", "yes")
|
||||
|
|
@ -389,12 +399,17 @@ async def upload_audio(
|
|||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
# Check if meeting exists
|
||||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
# Check if meeting exists and user has permission
|
||||
cursor.execute("SELECT user_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
meeting = cursor.fetchone()
|
||||
if not meeting:
|
||||
cursor.close()
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
if meeting['user_id'] != current_user['user_id']:
|
||||
cursor.close()
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
|
||||
# Check existing audio file
|
||||
cursor.execute("""
|
||||
SELECT file_name, file_path, upload_time
|
||||
|
|
@ -514,7 +529,7 @@ async def upload_audio(
|
|||
}
|
||||
|
||||
@router.get("/meetings/{meeting_id}/audio")
|
||||
def get_audio_file(meeting_id: int):
|
||||
def get_audio_file(meeting_id: int, current_user: dict = Depends(get_current_user)):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
|
|
@ -538,7 +553,7 @@ def get_audio_file(meeting_id: int):
|
|||
|
||||
# 转录任务相关接口
|
||||
@router.get("/transcription/tasks/{task_id}/status")
|
||||
def get_transcription_task_status(task_id: str):
|
||||
def get_transcription_task_status(task_id: str, current_user: dict = Depends(get_current_user)):
|
||||
"""获取转录任务状态"""
|
||||
try:
|
||||
status_info = transcription_service.get_task_status(task_id)
|
||||
|
|
@ -550,7 +565,7 @@ def get_transcription_task_status(task_id: str):
|
|||
raise HTTPException(status_code=500, detail=f"Failed to get task status: {str(e)}")
|
||||
|
||||
@router.get("/meetings/{meeting_id}/transcription/status")
|
||||
def get_meeting_transcription_status(meeting_id: int):
|
||||
def get_meeting_transcription_status(meeting_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""获取会议的转录任务状态"""
|
||||
try:
|
||||
status_info = transcription_service.get_meeting_transcription_status(meeting_id)
|
||||
|
|
@ -561,7 +576,7 @@ def get_meeting_transcription_status(meeting_id: int):
|
|||
raise HTTPException(status_code=500, detail=f"Failed to get meeting transcription status: {str(e)}")
|
||||
|
||||
@router.post("/meetings/{meeting_id}/transcription/start")
|
||||
def start_meeting_transcription(meeting_id: int):
|
||||
def start_meeting_transcription(meeting_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""手动启动会议转录任务(如果有音频文件的话)"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
|
|
@ -606,7 +621,8 @@ def start_meeting_transcription(meeting_id: int):
|
|||
@router.post("/meetings/{meeting_id}/upload-image")
|
||||
async def upload_image(
|
||||
meeting_id: int,
|
||||
image_file: UploadFile = File(...)
|
||||
image_file: UploadFile = File(...),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
# Validate file extension
|
||||
file_extension = os.path.splitext(image_file.filename)[1].lower()
|
||||
|
|
@ -623,12 +639,16 @@ async def upload_image(
|
|||
detail="Image size exceeds 10MB limit"
|
||||
)
|
||||
|
||||
# Check if meeting exists
|
||||
# Check if meeting exists and user has permission
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute("SELECT meeting_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
if not cursor.fetchone():
|
||||
cursor.execute("SELECT user_id FROM meetings WHERE meeting_id = %s", (meeting_id,))
|
||||
meeting = cursor.fetchone()
|
||||
if not meeting:
|
||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||
|
||||
if meeting['user_id'] != current_user['user_id']:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
|
||||
# Create meeting-specific directory
|
||||
meeting_dir = MARKDOWN_DIR / str(meeting_id)
|
||||
|
|
@ -654,7 +674,7 @@ async def upload_image(
|
|||
|
||||
# 发言人标签更新接口
|
||||
@router.put("/meetings/{meeting_id}/speaker-tags")
|
||||
def update_speaker_tag(meeting_id: int, request: SpeakerTagUpdateRequest):
|
||||
def update_speaker_tag(meeting_id: int, request: SpeakerTagUpdateRequest, current_user: dict = Depends(get_current_user)):
|
||||
"""更新单个发言人标签(基于原始的speaker_id值)"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
|
|
@ -678,7 +698,7 @@ def update_speaker_tag(meeting_id: int, request: SpeakerTagUpdateRequest):
|
|||
raise HTTPException(status_code=500, detail=f"Failed to update speaker tag: {str(e)}")
|
||||
|
||||
@router.put("/meetings/{meeting_id}/speaker-tags/batch")
|
||||
def batch_update_speaker_tags(meeting_id: int, request: BatchSpeakerTagUpdateRequest):
|
||||
def batch_update_speaker_tags(meeting_id: int, request: BatchSpeakerTagUpdateRequest, current_user: dict = Depends(get_current_user)):
|
||||
"""批量更新发言人标签(基于原始的speaker_id值)"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
|
|
@ -703,7 +723,7 @@ def batch_update_speaker_tags(meeting_id: int, request: BatchSpeakerTagUpdateReq
|
|||
|
||||
# 转录内容更新接口
|
||||
@router.put("/meetings/{meeting_id}/transcript/batch")
|
||||
def batch_update_transcript(meeting_id: int, request: BatchTranscriptUpdateRequest):
|
||||
def batch_update_transcript(meeting_id: int, request: BatchTranscriptUpdateRequest, current_user: dict = Depends(get_current_user)):
|
||||
"""批量更新转录内容"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
|
|
@ -734,7 +754,7 @@ def batch_update_transcript(meeting_id: int, request: BatchTranscriptUpdateReque
|
|||
|
||||
# AI总结相关接口
|
||||
@router.post("/meetings/{meeting_id}/generate-summary")
|
||||
def generate_meeting_summary(meeting_id: int, request: GenerateSummaryRequest):
|
||||
def generate_meeting_summary(meeting_id: int, request: GenerateSummaryRequest, current_user: dict = Depends(get_current_user)):
|
||||
"""生成会议AI总结"""
|
||||
try:
|
||||
# 检查会议是否存在
|
||||
|
|
@ -763,7 +783,7 @@ def generate_meeting_summary(meeting_id: int, request: GenerateSummaryRequest):
|
|||
raise HTTPException(status_code=500, detail=f"Failed to generate summary: {str(e)}")
|
||||
|
||||
@router.get("/meetings/{meeting_id}/summaries")
|
||||
def get_meeting_summaries(meeting_id: int):
|
||||
def get_meeting_summaries(meeting_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""获取会议的所有AI总结历史"""
|
||||
try:
|
||||
# 检查会议是否存在
|
||||
|
|
@ -787,7 +807,7 @@ def get_meeting_summaries(meeting_id: int):
|
|||
raise HTTPException(status_code=500, detail=f"Failed to get summaries: {str(e)}")
|
||||
|
||||
@router.get("/meetings/{meeting_id}/summaries/{summary_id}")
|
||||
def get_summary_detail(meeting_id: int, summary_id: int):
|
||||
def get_summary_detail(meeting_id: int, summary_id: int, current_user: dict = Depends(get_current_user)):
|
||||
"""获取特定总结的详细内容"""
|
||||
try:
|
||||
with get_db_connection() as connection:
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from app.models.models import UserInfo
|
||||
from app.core.database import get_db_connection
|
||||
from app.core.auth import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/users", response_model=list[UserInfo])
|
||||
def get_all_users():
|
||||
def get_all_users(current_user: dict = Depends(get_current_user)):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
|
|
@ -24,7 +25,7 @@ def get_all_users():
|
|||
return [UserInfo(**user) for user in users]
|
||||
|
||||
@router.get("/users/{user_id}", response_model=UserInfo)
|
||||
def get_user_info(user_id: int):
|
||||
def get_user_info(user_id: int, current_user: dict = Depends(get_current_user)):
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
from fastapi import HTTPException, status, Request, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from typing import Optional
|
||||
from app.services.jwt_service import jwt_service
|
||||
from app.core.database import get_db_connection
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""获取当前用户信息的依赖函数"""
|
||||
token = credentials.credentials
|
||||
|
||||
# 验证JWT token
|
||||
payload = jwt_service.verify_token(token)
|
||||
if not payload:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# 从数据库验证用户是否仍然存在且有效
|
||||
user_id = payload.get("user_id")
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute(
|
||||
"SELECT user_id, username, caption, email FROM users WHERE user_id = %s",
|
||||
(user_id,)
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
def get_optional_current_user(request: Request) -> Optional[dict]:
|
||||
"""可选的用户认证(不强制要求登录)"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
|
||||
try:
|
||||
token = auth_header.split(" ")[1]
|
||||
payload = jwt_service.verify_token(token)
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
user_id = payload.get("user_id")
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute(
|
||||
"SELECT user_id, username, caption, email FROM users WHERE user_id = %s",
|
||||
(user_id,)
|
||||
)
|
||||
return cursor.fetchone()
|
||||
except:
|
||||
return None
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
import jwt
|
||||
import redis
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from app.core.config import REDIS_CONFIG
|
||||
import os
|
||||
|
||||
# JWT配置
|
||||
JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', 'your-super-secret-key-change-in-production')
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7天
|
||||
|
||||
class JWTService:
|
||||
def __init__(self):
|
||||
self.redis_client = redis.Redis(**REDIS_CONFIG)
|
||||
|
||||
def create_access_token(self, data: Dict[str, Any]) -> str:
|
||||
"""创建JWT访问令牌"""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire, "type": "access"})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
|
||||
# 将token存储到Redis,用于管理和撤销
|
||||
user_id = data.get("user_id")
|
||||
if user_id:
|
||||
self.redis_client.setex(
|
||||
f"token:{user_id}:{encoded_jwt}",
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES * 60, # Redis需要秒
|
||||
"active"
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""验证JWT令牌"""
|
||||
try:
|
||||
# 解码JWT
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
# 检查token类型
|
||||
if payload.get("type") != "access":
|
||||
return None
|
||||
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
# 检查token是否在Redis中且未被撤销
|
||||
redis_key = f"token:{user_id}:{token}"
|
||||
if not self.redis_client.exists(redis_key):
|
||||
return None
|
||||
|
||||
return payload
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
return None
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def revoke_token(self, token: str, user_id: int) -> bool:
|
||||
"""撤销token"""
|
||||
try:
|
||||
redis_key = f"token:{user_id}:{token}"
|
||||
return self.redis_client.delete(redis_key) > 0
|
||||
except:
|
||||
return False
|
||||
|
||||
def revoke_all_user_tokens(self, user_id: int) -> int:
|
||||
"""撤销用户的所有token"""
|
||||
try:
|
||||
pattern = f"token:{user_id}:*"
|
||||
keys = self.redis_client.keys(pattern)
|
||||
if keys:
|
||||
return self.redis_client.delete(*keys)
|
||||
return 0
|
||||
except:
|
||||
return 0
|
||||
|
||||
def refresh_token(self, token: str) -> Optional[str]:
|
||||
"""刷新token(可选功能)"""
|
||||
payload = self.verify_token(token)
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
# 撤销旧token
|
||||
user_id = payload.get("user_id")
|
||||
self.revoke_token(token, user_id)
|
||||
|
||||
# 创建新token
|
||||
new_data = {
|
||||
"user_id": user_id,
|
||||
"username": payload.get("username"),
|
||||
"caption": payload.get("caption")
|
||||
}
|
||||
return self.create_access_token(new_data)
|
||||
|
||||
# 全局实例
|
||||
jwt_service = JWTService()
|
||||
|
|
@ -6,4 +6,6 @@ pydantic[email]
|
|||
passlib[bcrypt]
|
||||
qiniu
|
||||
redis>=5.0.0
|
||||
dashscope
|
||||
dashscope
|
||||
PyJWT>=2.8.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
API安全性测试脚本
|
||||
测试添加JWT验证后,API端点是否正确拒绝未授权访问
|
||||
|
||||
运行方法:
|
||||
cd /Users/jiliu/工作/projects/imeeting/backend
|
||||
source venv/bin/activate
|
||||
python test/test_api_security.py
|
||||
"""
|
||||
import requests
|
||||
import json
|
||||
|
||||
BASE_URL = "http://127.0.0.1:8000"
|
||||
PROXIES = {'http': None, 'https': None}
|
||||
|
||||
def test_unauthorized_access():
|
||||
"""测试未授权访问各个API端点"""
|
||||
print("=== API安全性测试 ===")
|
||||
print("测试未授权访问是否被正确拒绝\n")
|
||||
|
||||
# 需要验证的API端点
|
||||
protected_endpoints = [
|
||||
# Users endpoints
|
||||
("GET", "/api/users", "获取所有用户"),
|
||||
("GET", "/api/users/1", "获取用户详情"),
|
||||
|
||||
# Meetings endpoints
|
||||
("GET", "/api/meetings", "获取会议列表"),
|
||||
("GET", "/api/meetings/1", "获取会议详情"),
|
||||
("GET", "/api/meetings/1/transcript", "获取会议转录"),
|
||||
("GET", "/api/meetings/1/edit", "获取会议编辑信息"),
|
||||
("GET", "/api/meetings/1/audio", "获取会议音频"),
|
||||
("POST", "/api/meetings/1/regenerate-summary", "重新生成摘要"),
|
||||
("GET", "/api/meetings/1/summaries", "获取会议摘要"),
|
||||
("GET", "/api/meetings/1/transcription/status", "获取转录状态"),
|
||||
|
||||
# Auth endpoints (需要token的)
|
||||
("GET", "/api/auth/me", "获取用户信息"),
|
||||
("POST", "/api/auth/logout", "登出"),
|
||||
("POST", "/api/auth/logout-all", "登出所有设备"),
|
||||
]
|
||||
|
||||
success_count = 0
|
||||
total_count = len(protected_endpoints)
|
||||
|
||||
for method, endpoint, description in protected_endpoints:
|
||||
try:
|
||||
url = f"{BASE_URL}{endpoint}"
|
||||
|
||||
if method == "GET":
|
||||
response = requests.get(url, proxies=PROXIES, timeout=5)
|
||||
elif method == "POST":
|
||||
response = requests.post(url, proxies=PROXIES, timeout=5)
|
||||
elif method == "PUT":
|
||||
response = requests.put(url, proxies=PROXIES, timeout=5)
|
||||
elif method == "DELETE":
|
||||
response = requests.delete(url, proxies=PROXIES, timeout=5)
|
||||
|
||||
if response.status_code == 401:
|
||||
print(f"✅ {method} {endpoint} - {description}")
|
||||
print(f" 正确返回401 Unauthorized")
|
||||
success_count += 1
|
||||
else:
|
||||
print(f"❌ {method} {endpoint} - {description}")
|
||||
print(f" 错误:返回 {response.status_code},应该返回401")
|
||||
print(f" 响应: {response.text[:100]}...")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"❌ {method} {endpoint} - {description}")
|
||||
print(f" 请求异常: {e}")
|
||||
|
||||
print()
|
||||
|
||||
print(f"=== 测试结果 ===")
|
||||
print(f"通过: {success_count}/{total_count}")
|
||||
print(f"成功率: {success_count/total_count*100:.1f}%")
|
||||
|
||||
if success_count == total_count:
|
||||
print("🎉 所有API端点都正确实施了JWT验证!")
|
||||
else:
|
||||
print("⚠️ 有些API端点未正确实施JWT验证,需要修复")
|
||||
|
||||
return success_count == total_count
|
||||
|
||||
def test_valid_token_access():
|
||||
"""测试有效token的访问"""
|
||||
print("\n=== 测试有效Token访问 ===")
|
||||
|
||||
# 1. 先登录获取token
|
||||
login_data = {"username": "mula", "password": "781126"}
|
||||
try:
|
||||
response = requests.post(f"{BASE_URL}/api/auth/login", json=login_data, proxies=PROXIES)
|
||||
if response.status_code != 200:
|
||||
print("❌ 无法登录获取测试token")
|
||||
print(f"登录响应: {response.status_code} - {response.text}")
|
||||
return False
|
||||
|
||||
user_data = response.json()
|
||||
token = user_data["token"]
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
print(f"✅ 登录成功,获得token")
|
||||
|
||||
# 2. 测试几个主要API端点
|
||||
test_endpoints = [
|
||||
("GET", "/api/auth/me", "获取当前用户信息"),
|
||||
("GET", "/api/users", "获取用户列表"),
|
||||
("GET", "/api/meetings", "获取会议列表"),
|
||||
]
|
||||
|
||||
success_count = 0
|
||||
for method, endpoint, description in test_endpoints:
|
||||
try:
|
||||
url = f"{BASE_URL}{endpoint}"
|
||||
response = requests.get(url, headers=headers, proxies=PROXIES, timeout=5)
|
||||
|
||||
if response.status_code == 200:
|
||||
print(f"✅ {method} {endpoint} - {description}")
|
||||
print(f" 正确返回200 OK")
|
||||
success_count += 1
|
||||
elif response.status_code == 500:
|
||||
print(f"⚠️ {method} {endpoint} - {description}")
|
||||
print(f" 返回500 (可能是数据库连接问题,但JWT验证通过了)")
|
||||
success_count += 1
|
||||
else:
|
||||
print(f"❌ {method} {endpoint} - {description}")
|
||||
print(f" 意外响应: {response.status_code}")
|
||||
print(f" 响应内容: {response.text[:100]}...")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"❌ {method} {endpoint} - {description}")
|
||||
print(f" 请求异常: {e}")
|
||||
|
||||
print(f"\n有效token测试: {success_count}/{len(test_endpoints)} 通过")
|
||||
return success_count == len(test_endpoints)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 测试失败: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("API JWT安全性测试工具")
|
||||
print("=" * 50)
|
||||
|
||||
# 测试未授权访问
|
||||
unauthorized_ok = test_unauthorized_access()
|
||||
|
||||
# 测试授权访问
|
||||
authorized_ok = test_valid_token_access()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
if unauthorized_ok and authorized_ok:
|
||||
print("🎉 JWT验证实施成功!")
|
||||
print("✅ 未授权访问被正确拒绝")
|
||||
print("✅ 有效token可以正常访问")
|
||||
else:
|
||||
print("⚠️ JWT验证实施不完整")
|
||||
if not unauthorized_ok:
|
||||
print("❌ 部分API未正确拒绝未授权访问")
|
||||
if not authorized_ok:
|
||||
print("❌ 有效token访问存在问题")
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>JWT Token 测试工具</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; margin: 20px; }
|
||||
.container { max-width: 800px; margin: 0 auto; }
|
||||
textarea { width: 100%; height: 100px; margin: 10px 0; }
|
||||
.result { background: #f5f5f5; padding: 15px; margin: 10px 0; border-radius: 5px; }
|
||||
.error { background: #ffebee; color: #c62828; }
|
||||
.success { background: #e8f5e8; color: #2e7d2e; }
|
||||
button { padding: 10px 20px; background: #007bff; color: white; border: none; border-radius: 4px; cursor: pointer; }
|
||||
button:hover { background: #0056b3; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>JWT Token 验证工具</h1>
|
||||
|
||||
<h3>步骤1: 从浏览器获取Token</h3>
|
||||
<p>1. 登录你的应用</p>
|
||||
<p>2. 打开开发者工具 → Application → Local Storage → 找到 'iMeetingUser'</p>
|
||||
<p>3. 复制其中的 token 值到下面的文本框</p>
|
||||
|
||||
<textarea id="tokenInput" placeholder="在此粘贴JWT token..."></textarea>
|
||||
<button onclick="decodeToken()">解码 JWT Token</button>
|
||||
|
||||
<div id="result"></div>
|
||||
|
||||
<h3>预期的JWT payload应该包含:</h3>
|
||||
<ul>
|
||||
<li><code>user_id</code>: 用户ID</li>
|
||||
<li><code>username</code>: 用户名</li>
|
||||
<li><code>caption</code>: 用户显示名</li>
|
||||
<li><code>exp</code>: 过期时间戳</li>
|
||||
<li><code>type</code>: "access"</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function decodeToken() {
|
||||
const token = document.getElementById('tokenInput').value.trim();
|
||||
const resultDiv = document.getElementById('result');
|
||||
|
||||
if (!token) {
|
||||
resultDiv.innerHTML = '<div class="result error">请输入JWT token</div>';
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// JWT 由三部分组成,用 . 分隔
|
||||
const parts = token.split('.');
|
||||
if (parts.length !== 3) {
|
||||
throw new Error('无效的JWT格式');
|
||||
}
|
||||
|
||||
// 解码 header
|
||||
const header = JSON.parse(atob(parts[0]));
|
||||
|
||||
// 解码 payload
|
||||
const payload = JSON.parse(atob(parts[1]));
|
||||
|
||||
// 检查是否是我们的JWT
|
||||
const isValidJWT = payload.type === 'access' &&
|
||||
payload.user_id &&
|
||||
payload.username &&
|
||||
payload.exp;
|
||||
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
const isExpired = payload.exp < now;
|
||||
|
||||
let resultHTML = '<div class="result ' + (isValidJWT ? 'success' : 'error') + '">';
|
||||
resultHTML += '<h4>JWT 解码结果:</h4>';
|
||||
resultHTML += '<p><strong>Header:</strong></p>';
|
||||
resultHTML += '<pre>' + JSON.stringify(header, null, 2) + '</pre>';
|
||||
resultHTML += '<p><strong>Payload:</strong></p>';
|
||||
resultHTML += '<pre>' + JSON.stringify(payload, null, 2) + '</pre>';
|
||||
|
||||
if (isExpired) {
|
||||
resultHTML += '<p style="color: red;"><strong>⚠️ Token已过期!</strong></p>';
|
||||
} else {
|
||||
const expireDate = new Date(payload.exp * 1000);
|
||||
resultHTML += '<p style="color: green;"><strong>✅ Token有效,过期时间: ' + expireDate.toLocaleString() + '</strong></p>';
|
||||
}
|
||||
|
||||
if (isValidJWT) {
|
||||
resultHTML += '<p style="color: green;"><strong>✅ 这是有效的iMeeting JWT token!</strong></p>';
|
||||
} else {
|
||||
resultHTML += '<p style="color: red;"><strong>❌ 这不是有效的iMeeting JWT token</strong></p>';
|
||||
}
|
||||
|
||||
resultHTML += '</div>';
|
||||
resultDiv.innerHTML = resultHTML;
|
||||
|
||||
} catch (error) {
|
||||
resultDiv.innerHTML = '<div class="result error">解码失败: ' + error.message + '</div>';
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
登录调试脚本 - 诊断JWT认证问题
|
||||
|
||||
运行方法:
|
||||
cd /Users/jiliu/工作/projects/imeeting/backend
|
||||
source venv/bin/activate # 激活虚拟环境
|
||||
python test/test_login_debug.py
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
BASE_URL = "http://127.0.0.1:8000"
|
||||
|
||||
# 禁用代理以避免本地请求被代理
|
||||
PROXIES = {'http': None, 'https': None}
|
||||
|
||||
def test_backend_connection():
|
||||
"""测试后端连接"""
|
||||
try:
|
||||
response = requests.get(f"{BASE_URL}/", proxies=PROXIES)
|
||||
print(f"✅ 后端服务连接成功: {response.status_code}")
|
||||
return True
|
||||
except requests.exceptions.ConnectionError:
|
||||
print("❌ 无法连接到后端服务")
|
||||
return False
|
||||
|
||||
def test_login_with_debug(username, password):
|
||||
"""详细的登录测试"""
|
||||
print(f"\n=== 测试登录: {username} ===")
|
||||
|
||||
login_data = {
|
||||
"username": username,
|
||||
"password": password
|
||||
}
|
||||
|
||||
try:
|
||||
print(f"请求URL: {BASE_URL}/api/auth/login")
|
||||
print(f"请求数据: {json.dumps(login_data, ensure_ascii=False)}")
|
||||
|
||||
response = requests.post(f"{BASE_URL}/api/auth/login", json=login_data, proxies=PROXIES)
|
||||
|
||||
print(f"响应状态码: {response.status_code}")
|
||||
print(f"响应头: {dict(response.headers)}")
|
||||
|
||||
if response.status_code == 200:
|
||||
user_data = response.json()
|
||||
print("✅ 登录成功!")
|
||||
print(f"用户信息: {json.dumps(user_data, ensure_ascii=False, indent=2)}")
|
||||
return user_data.get("token")
|
||||
else:
|
||||
print("❌ 登录失败")
|
||||
print(f"错误内容: {response.text}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 请求异常: {e}")
|
||||
return None
|
||||
|
||||
def test_authenticated_request(token):
|
||||
"""测试认证请求"""
|
||||
if not token:
|
||||
print("❌ 没有有效token,跳过认证测试")
|
||||
return
|
||||
|
||||
print(f"\n=== 测试认证请求 ===")
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
try:
|
||||
# 测试 /api/auth/me
|
||||
print("测试 /api/auth/me")
|
||||
response = requests.get(f"{BASE_URL}/api/auth/me", headers=headers, proxies=PROXIES)
|
||||
print(f"状态码: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
print("✅ 认证请求成功")
|
||||
print(f"用户信息: {json.dumps(response.json(), ensure_ascii=False, indent=2)}")
|
||||
else:
|
||||
print("❌ 认证请求失败")
|
||||
print(f"错误: {response.text}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 认证请求异常: {e}")
|
||||
|
||||
def check_database_users():
|
||||
"""检查数据库用户"""
|
||||
try:
|
||||
from app.core.database import get_db_connection
|
||||
|
||||
print(f"\n=== 检查数据库用户 ===")
|
||||
with get_db_connection() as connection:
|
||||
cursor = connection.cursor(dictionary=True)
|
||||
cursor.execute("SELECT user_id, username, caption, email FROM users LIMIT 10")
|
||||
users = cursor.fetchall()
|
||||
|
||||
print(f"数据库中的用户 (前10个):")
|
||||
for user in users:
|
||||
print(f" - ID: {user['user_id']}, 用户名: {user['username']}, 名称: {user['caption']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 无法访问数据库: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("JWT登录调试工具")
|
||||
print("=" * 50)
|
||||
|
||||
# 1. 测试后端连接
|
||||
if not test_backend_connection():
|
||||
exit(1)
|
||||
|
||||
# 2. 检查数据库用户
|
||||
check_database_users()
|
||||
|
||||
# 3. 测试登录
|
||||
username = input("\n请输入用户名 (默认: mula): ").strip() or "mula"
|
||||
password = input("请输入密码 (默认: 781126): ").strip() or "781126"
|
||||
|
||||
token = test_login_with_debug(username, password)
|
||||
|
||||
# 4. 测试认证请求
|
||||
test_authenticated_request(token)
|
||||
|
||||
print("\n=== 调试完成 ===")
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Redis JWT Token 验证脚本
|
||||
用于检查JWT token是否正确存储在Redis中
|
||||
|
||||
运行方法:
|
||||
cd /Users/jiliu/工作/projects/imeeting/backend
|
||||
source venv/bin/activate # 激活虚拟环境
|
||||
python test/test_redis_jwt.py
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import redis
|
||||
import json
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
try:
|
||||
from app.core.config import REDIS_CONFIG
|
||||
print("✅ 成功导入项目配置")
|
||||
except ImportError as e:
|
||||
print(f"❌ 导入项目配置失败: {e}")
|
||||
print("请确保在 backend 目录下运行: python test/test_redis_jwt.py")
|
||||
sys.exit(1)
|
||||
|
||||
def check_jwt_in_redis():
|
||||
"""检查Redis中的JWT token"""
|
||||
try:
|
||||
# 使用项目配置连接Redis
|
||||
r = redis.Redis(**REDIS_CONFIG)
|
||||
|
||||
# 测试连接
|
||||
r.ping()
|
||||
print("✅ Redis连接成功")
|
||||
print(f"连接配置: {REDIS_CONFIG}")
|
||||
|
||||
# 获取所有token相关的keys
|
||||
token_keys = r.keys("token:*")
|
||||
|
||||
if not token_keys:
|
||||
print("❌ Redis中没有找到JWT token")
|
||||
print("提示: 请先通过前端登录以生成token")
|
||||
return False
|
||||
|
||||
print(f"✅ 找到 {len(token_keys)} 个token记录:")
|
||||
|
||||
for key in token_keys:
|
||||
# 解析key格式: token:user_id:jwt_token
|
||||
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
|
||||
parts = key_str.split(":", 2)
|
||||
if len(parts) >= 3:
|
||||
user_id = parts[1]
|
||||
token_preview = parts[2][:20] + "..."
|
||||
ttl = r.ttl(key)
|
||||
value = r.get(key)
|
||||
value_str = value.decode('utf-8') if isinstance(value, bytes) else value
|
||||
|
||||
print(f" - 用户ID: {user_id}")
|
||||
print(f" Token预览: {token_preview}")
|
||||
if ttl > 0:
|
||||
print(f" 剩余时间: {ttl}秒 ({ttl/3600:.1f}小时)")
|
||||
else:
|
||||
print(f" TTL: {ttl} (永不过期)" if ttl == -1 else f" TTL: {ttl} (已过期)")
|
||||
print(f" 状态: {value_str}")
|
||||
print()
|
||||
|
||||
return True
|
||||
|
||||
except redis.ConnectionError:
|
||||
print("❌ 无法连接到Redis服务器")
|
||||
print("请确保Redis服务正在运行:")
|
||||
print(" brew services start redis # macOS")
|
||||
print(" 或 redis-server # 直接启动")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ 检查失败: {e}")
|
||||
return False
|
||||
|
||||
def test_token_operations():
|
||||
"""测试token操作"""
|
||||
try:
|
||||
r = redis.Redis(**REDIS_CONFIG)
|
||||
|
||||
print("\n=== Token操作测试 ===")
|
||||
|
||||
# 模拟创建token
|
||||
test_key = "token:999:test_token_12345"
|
||||
r.setex(test_key, 60, "active")
|
||||
print(f"✅ 创建测试token: {test_key}")
|
||||
|
||||
# 检查token存在
|
||||
if r.exists(test_key):
|
||||
print("✅ Token存在性验证通过")
|
||||
|
||||
# 检查TTL
|
||||
ttl = r.ttl(test_key)
|
||||
print(f"✅ Token TTL: {ttl}秒")
|
||||
|
||||
# 删除测试token
|
||||
r.delete(test_key)
|
||||
print("✅ 清理测试token")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Token操作测试失败: {e}")
|
||||
return False
|
||||
|
||||
def test_jwt_service():
|
||||
"""测试JWT服务"""
|
||||
try:
|
||||
from app.services.jwt_service import jwt_service
|
||||
|
||||
print("\n=== JWT服务测试 ===")
|
||||
|
||||
# 测试创建token
|
||||
test_data = {
|
||||
"user_id": 999,
|
||||
"username": "test_user",
|
||||
"caption": "测试用户"
|
||||
}
|
||||
|
||||
token = jwt_service.create_access_token(test_data)
|
||||
print(f"✅ 创建JWT token: {token[:30]}...")
|
||||
|
||||
# 测试验证token
|
||||
payload = jwt_service.verify_token(token)
|
||||
if payload:
|
||||
print(f"✅ Token验证成功: 用户ID={payload['user_id']}, 用户名={payload['username']}")
|
||||
else:
|
||||
print("❌ Token验证失败")
|
||||
return False
|
||||
|
||||
# 测试撤销token
|
||||
revoked = jwt_service.revoke_token(token, test_data["user_id"])
|
||||
print(f"✅ 撤销token: {'成功' if revoked else '失败'}")
|
||||
|
||||
# 验证撤销后token失效
|
||||
payload_after_revoke = jwt_service.verify_token(token)
|
||||
if not payload_after_revoke:
|
||||
print("✅ Token撤销后验证失败,符合预期")
|
||||
else:
|
||||
print("❌ Token撤销后仍然有效,不符合预期")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ JWT服务测试失败: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("JWT + Redis 认证系统测试")
|
||||
print("=" * 50)
|
||||
print(f"工作目录: {os.getcwd()}")
|
||||
print(f"测试脚本路径: {__file__}")
|
||||
|
||||
# 检查Redis中的JWT tokens
|
||||
redis_ok = check_jwt_in_redis()
|
||||
|
||||
# 测试token操作
|
||||
operations_ok = test_token_operations()
|
||||
|
||||
# 测试JWT服务
|
||||
jwt_service_ok = test_jwt_service()
|
||||
|
||||
print("=" * 50)
|
||||
if redis_ok and operations_ok and jwt_service_ok:
|
||||
print("✅ JWT + Redis 认证系统工作正常!")
|
||||
else:
|
||||
print("❌ JWT + Redis 认证系统存在问题")
|
||||
print("\n故障排除建议:")
|
||||
print("1. 确保在 backend 目录下运行测试")
|
||||
print("2. 确保Redis服务正在运行")
|
||||
print("3. 确保已安装所有依赖: pip install -r requirements.txt")
|
||||
print("4. 尝试先通过前端登录生成token")
|
||||
sys.exit(1)
|
||||
|
|
@ -0,0 +1,296 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
JWT Token 过期测试脚本
|
||||
用于测试JWT token的过期、撤销机制,可以模拟指定用户的token失效
|
||||
|
||||
运行方法:
|
||||
cd /Users/jiliu/工作/projects/imeeting/backend
|
||||
source venv/bin/activate # 激活虚拟环境
|
||||
python test/test_token_expiration.py
|
||||
|
||||
功能:
|
||||
1. 登录指定用户并获取token
|
||||
2. 验证token有效性
|
||||
3. 撤销指定用户的所有token(模拟失效)
|
||||
4. 验证撤销后token失效
|
||||
|
||||
期望结果:在网页上登录的用户执行失效命令后,网页会自动登出
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
BASE_URL = "http://127.0.0.1:8000"
|
||||
|
||||
# 禁用代理以避免本地请求被代理
|
||||
PROXIES = {'http': None, 'https': None}
|
||||
|
||||
def invalidate_user_tokens():
|
||||
"""模拟指定用户的token失效"""
|
||||
print("模拟用户Token失效工具")
|
||||
print("=" * 40)
|
||||
|
||||
# 获取要失效的用户名
|
||||
target_username = input("请输入要失效token的用户名 (默认: mula): ").strip()
|
||||
if not target_username:
|
||||
target_username = "mula"
|
||||
|
||||
# 获取管理员凭据来执行失效操作
|
||||
admin_username = input("请输入管理员用户名 (默认: mula): ").strip()
|
||||
admin_password = input("请输入管理员密码 (默认: 781126): ").strip()
|
||||
|
||||
if not admin_username:
|
||||
admin_username = "mula"
|
||||
if not admin_password:
|
||||
admin_password = "781126"
|
||||
|
||||
try:
|
||||
# 1. 管理员登录获取token
|
||||
print(f"\n步骤1: 管理员登录 ({admin_username})")
|
||||
admin_login_data = {
|
||||
"username": admin_username,
|
||||
"password": admin_password
|
||||
}
|
||||
|
||||
response = requests.post(f"{BASE_URL}/api/auth/login", json=admin_login_data, proxies=PROXIES)
|
||||
if response.status_code != 200:
|
||||
print(f"❌ 管理员登录失败")
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应内容: {response.text}")
|
||||
return
|
||||
|
||||
admin_data = response.json()
|
||||
admin_token = admin_data["token"]
|
||||
admin_headers = {"Authorization": f"Bearer {admin_token}"}
|
||||
|
||||
print(f"✅ 管理员登录成功: {admin_data['username']} ({admin_data['caption']})")
|
||||
|
||||
# 2. 如果目标用户不是管理员,先登录目标用户验证token存在
|
||||
if target_username != admin_username:
|
||||
print(f"\n步骤2: 验证目标用户 ({target_username}) 是否存在")
|
||||
target_password = input(f"请输入 {target_username} 的密码 (用于验证): ").strip()
|
||||
if not target_password:
|
||||
print("❌ 需要提供目标用户的密码来验证")
|
||||
return
|
||||
|
||||
target_login_data = {
|
||||
"username": target_username,
|
||||
"password": target_password
|
||||
}
|
||||
|
||||
response = requests.post(f"{BASE_URL}/api/auth/login", json=target_login_data, proxies=PROXIES)
|
||||
if response.status_code != 200:
|
||||
print(f"❌ 目标用户登录失败,无法验证用户存在")
|
||||
return
|
||||
|
||||
target_data = response.json()
|
||||
print(f"✅ 目标用户验证成功: {target_data['username']} ({target_data['caption']})")
|
||||
target_user_id = target_data['user_id']
|
||||
else:
|
||||
target_user_id = admin_data['user_id']
|
||||
|
||||
# 3. 撤销目标用户的所有token
|
||||
print(f"\n步骤3: 撤销用户 {target_username} (ID: {target_user_id}) 的所有token")
|
||||
|
||||
# 使用管理员权限调用新的admin API
|
||||
response = requests.post(f"{BASE_URL}/api/auth/admin/revoke-user-tokens/{target_user_id}",
|
||||
headers=admin_headers, proxies=PROXIES)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
print(f"✅ Token撤销成功: {result.get('message', '已撤销所有token')}")
|
||||
|
||||
# 4. 验证token是否真的失效了
|
||||
print(f"\n步骤4: 验证token失效")
|
||||
if target_username != admin_username:
|
||||
# 尝试使用目标用户的token访问protected API
|
||||
target_token = target_data["token"]
|
||||
target_headers = {"Authorization": f"Bearer {target_token}"}
|
||||
|
||||
response = requests.get(f"{BASE_URL}/api/auth/me", headers=target_headers, proxies=PROXIES)
|
||||
if response.status_code == 401:
|
||||
print(f"✅ 验证成功:用户 {target_username} 的token已失效")
|
||||
else:
|
||||
print(f"❌ 验证失败:用户 {target_username} 的token仍然有效")
|
||||
else:
|
||||
# 如果目标用户就是管理员,验证当前管理员token是否失效
|
||||
response = requests.get(f"{BASE_URL}/api/auth/me", headers=admin_headers, proxies=PROXIES)
|
||||
if response.status_code == 401:
|
||||
print(f"✅ 验证成功:用户 {target_username} 的token已失效")
|
||||
else:
|
||||
print(f"❌ 验证失败:用户 {target_username} 的token仍然有效")
|
||||
|
||||
print(f"\n🌟 操作完成!")
|
||||
print(f"如果用户 {target_username} 在网页上已登录,现在应该会自动登出。")
|
||||
print(f"你可以在网页上验证是否自动跳转到登录页面。")
|
||||
|
||||
else:
|
||||
print(f"❌ Token撤销失败: {response.status_code}")
|
||||
print(f"响应内容: {response.text}")
|
||||
|
||||
except requests.exceptions.ConnectionError:
|
||||
print("❌ 无法连接到后端服务器,请确保服务器正在运行")
|
||||
except Exception as e:
|
||||
print(f"❌ 操作失败: {e}")
|
||||
|
||||
def test_token_expiration():
|
||||
"""测试token过期机制"""
|
||||
print("JWT Token 过期测试")
|
||||
print("=" * 40)
|
||||
|
||||
# 1. 登录获取token
|
||||
username = input("请输入用户名 (默认: test): ").strip()
|
||||
password = input("请输入密码 (默认: test): ").strip()
|
||||
|
||||
# 使用默认值如果输入为空
|
||||
if not username:
|
||||
username = "test"
|
||||
if not password:
|
||||
password = "test"
|
||||
|
||||
login_data = {
|
||||
"username": username,
|
||||
"password": password
|
||||
}
|
||||
|
||||
try:
|
||||
# 登录
|
||||
print(f"正在尝试登录用户: {login_data['username']}")
|
||||
response = requests.post(f"{BASE_URL}/api/auth/login", json=login_data, proxies=PROXIES)
|
||||
if response.status_code != 200:
|
||||
print(f"❌ 登录失败")
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应内容: {response.text}")
|
||||
print(f"请求URL: {BASE_URL}/api/auth/login")
|
||||
print("请检查:")
|
||||
print("1. 后端服务是否正在运行")
|
||||
print("2. 用户名和密码是否正确")
|
||||
print("3. 数据库连接是否正常")
|
||||
return
|
||||
|
||||
user_data = response.json()
|
||||
token = user_data["token"]
|
||||
|
||||
print(f"✅ 登录成功,获得token: {token[:20]}...")
|
||||
|
||||
# 2. 测试token有效性
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
print("\n测试1: 验证token有效性")
|
||||
response = requests.get(f"{BASE_URL}/api/auth/me", headers=headers, proxies=PROXIES)
|
||||
if response.status_code == 200:
|
||||
user_info = response.json()
|
||||
print(f"✅ Token有效,用户: {user_info.get('username')}")
|
||||
else:
|
||||
print(f"❌ Token无效: {response.status_code}")
|
||||
return
|
||||
|
||||
# 3. 测试受保护的API
|
||||
print("\n测试2: 访问受保护的API")
|
||||
response = requests.get(f"{BASE_URL}/api/meetings", headers=headers, proxies=PROXIES)
|
||||
if response.status_code == 200:
|
||||
print("✅ 成功访问会议列表API")
|
||||
else:
|
||||
print(f"❌ 访问受保护API失败: {response.status_code}")
|
||||
|
||||
# 4. 登出token
|
||||
print("\n测试3: 登出token")
|
||||
response = requests.post(f"{BASE_URL}/api/auth/logout", headers=headers, proxies=PROXIES)
|
||||
if response.status_code == 200:
|
||||
print("✅ 登出成功")
|
||||
|
||||
# 5. 验证登出后token失效
|
||||
print("\n测试4: 验证登出后token失效")
|
||||
response = requests.get(f"{BASE_URL}/api/auth/me", headers=headers, proxies=PROXIES)
|
||||
if response.status_code == 401:
|
||||
print("✅ Token已失效,登出成功")
|
||||
else:
|
||||
print(f"❌ Token仍然有效,登出失败: {response.status_code}")
|
||||
else:
|
||||
print(f"❌ 登出失败: {response.status_code}")
|
||||
|
||||
except requests.exceptions.ConnectionError:
|
||||
print("❌ 无法连接到后端服务器,请确保服务器正在运行")
|
||||
except Exception as e:
|
||||
print(f"❌ 测试失败: {e}")
|
||||
|
||||
def check_token_format():
|
||||
"""检查token格式是否为JWT"""
|
||||
token = input("\n请粘贴JWT token (或按Enter跳过): ").strip()
|
||||
|
||||
if not token:
|
||||
return
|
||||
|
||||
print(f"\nJWT格式检查:")
|
||||
|
||||
# JWT应该有三个部分,用.分隔
|
||||
parts = token.split('.')
|
||||
if len(parts) != 3:
|
||||
print("❌ 不是有效的JWT格式 (应该有3个部分用.分隔)")
|
||||
return
|
||||
|
||||
try:
|
||||
import base64
|
||||
import json
|
||||
|
||||
# 解码header
|
||||
header_padding = parts[0] + '=' * (4 - len(parts[0]) % 4)
|
||||
header = json.loads(base64.urlsafe_b64decode(header_padding))
|
||||
|
||||
# 解码payload
|
||||
payload_padding = parts[1] + '=' * (4 - len(parts[1]) % 4)
|
||||
payload = json.loads(base64.urlsafe_b64decode(payload_padding))
|
||||
|
||||
print("✅ JWT格式有效")
|
||||
print(f"算法: {header.get('alg')}")
|
||||
print(f"类型: {header.get('typ')}")
|
||||
print(f"用户ID: {payload.get('user_id')}")
|
||||
print(f"用户名: {payload.get('username')}")
|
||||
|
||||
if 'exp' in payload:
|
||||
exp_time = datetime.fromtimestamp(payload['exp'])
|
||||
print(f"过期时间: {exp_time}")
|
||||
|
||||
if datetime.now() > exp_time:
|
||||
print("❌ Token已过期")
|
||||
else:
|
||||
print("✅ Token未过期")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ JWT解码失败: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("JWT Token 测试工具")
|
||||
print("=" * 50)
|
||||
print(f"工作目录: {os.getcwd()}")
|
||||
print(f"测试脚本路径: {__file__}")
|
||||
print()
|
||||
|
||||
print("请选择功能:")
|
||||
print("1. 模拟指定用户Token失效 (推荐)")
|
||||
print("2. 完整Token过期测试")
|
||||
print("3. JWT格式检查")
|
||||
|
||||
choice = input("\n请输入选项 (1-3, 默认: 1): ").strip()
|
||||
|
||||
if choice == "2":
|
||||
test_token_expiration()
|
||||
check_token_format()
|
||||
elif choice == "3":
|
||||
check_token_format()
|
||||
else:
|
||||
# 默认选择1
|
||||
invalidate_user_tokens()
|
||||
|
||||
print("\n=== 测试完成 ===")
|
||||
print("如果测试失败,请检查:")
|
||||
print("1. 确保后端服务正在运行: python main.py")
|
||||
print("2. 确保在 backend 目录下运行测试")
|
||||
print("3. 确保Redis服务正在运行")
|
||||
print("4. 如果选择了选项1,请在网页上验证用户是否自动登出")
|
||||
Binary file not shown.
Loading…
Reference in New Issue