diff --git a/app/api/endpoints/meetings.py b/app/api/endpoints/meetings.py index a639736..dbfc1f6 100644 --- a/app/api/endpoints/meetings.py +++ b/app/api/endpoints/meetings.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, HTTPException, UploadFile, File, Form -from app.models.models import Meeting, TranscriptSegment, CreateMeetingRequest, UpdateMeetingRequest +from app.models.models import Meeting, TranscriptSegment, CreateMeetingRequest, UpdateMeetingRequest, SpeakerTagUpdateRequest, BatchSpeakerTagUpdateRequest from app.core.database import get_db_connection from app.core.config import BASE_DIR, UPLOAD_DIR, AUDIO_DIR, MARKDOWN_DIR, ALLOWED_EXTENSIONS, ALLOWED_IMAGE_EXTENSIONS, MAX_FILE_SIZE, MAX_IMAGE_SIZE from app.services.qiniu_service import qiniu_service @@ -125,7 +125,7 @@ def get_meeting_transcript(meeting_id: int): # Get transcript segments transcript_query = ''' - SELECT segment_id, meeting_id, speaker_tag, start_time_ms, end_time_ms, text_content + SELECT segment_id, meeting_id, speaker_id, speaker_tag, start_time_ms, end_time_ms, text_content FROM transcript_segments WHERE meeting_id = %s ORDER BY start_time_ms ASC @@ -133,7 +133,19 @@ def get_meeting_transcript(meeting_id: int): cursor.execute(transcript_query, (meeting_id,)) segments = cursor.fetchall() - return [TranscriptSegment(**segment) for segment in segments] + transcript_segments = [] + for segment in segments: + transcript_segments.append(TranscriptSegment( + segment_id=segment['segment_id'], + meeting_id=segment['meeting_id'], + speaker_id=segment['speaker_id'], + speaker_tag=segment['speaker_tag'] if segment['speaker_tag'] else f"发言人 {segment['speaker_id']}", + start_time_ms=segment['start_time_ms'], + end_time_ms=segment['end_time_ms'], + text_content=segment['text_content'] + )) + + return transcript_segments @router.post("/meetings") def create_meeting(meeting_request: CreateMeetingRequest): @@ -443,3 +455,51 @@ async def upload_image( "file_name": image_file.filename, "file_path": 
'/'+ str(relative_path) }
+
+# Speaker tag update endpoints
+@router.put("/meetings/{meeting_id}/speaker-tags")
+def update_speaker_tag(meeting_id: int, request: SpeakerTagUpdateRequest):
+    """Set speaker_tag on every segment of this meeting that has request.speaker_id.
+
+    Raises HTTPException 404 when the UPDATE affects no rows, 500 on any other failure.
+    """
+    update_query = """
+        UPDATE transcript_segments
+        SET speaker_tag = %s
+        WHERE meeting_id = %s AND speaker_id = %s
+    """
+    try:
+        with get_db_connection() as connection:
+            cursor = connection.cursor()
+            cursor.execute(update_query, (request.new_tag, meeting_id, request.speaker_id))
+            if cursor.rowcount == 0:
+                raise HTTPException(status_code=404, detail="No segments found for this speaker")
+            connection.commit()
+            return {'message': 'Speaker tag updated successfully', 'updated_count': cursor.rowcount}
+    except HTTPException:
+        raise  # pass the 404 through; the generic handler below would rewrap it as a 500
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to update speaker tag: {str(e)}") from e
+
+@router.put("/meetings/{meeting_id}/speaker-tags/batch")
+def batch_update_speaker_tags(meeting_id: int, request: BatchSpeakerTagUpdateRequest):
+    """Apply every (speaker_id, new_tag) pair in request.updates to this meeting's segments.
+
+    Commits once after all updates ran; raises HTTPException 500 if anything fails.
+    """
+    update_query = """
+        UPDATE transcript_segments
+        SET speaker_tag = %s
+        WHERE meeting_id = %s AND speaker_id = %s
+    """
+    try:
+        with get_db_connection() as connection:
+            cursor = connection.cursor()
+            total_updated = 0
+            for update_item in request.updates:
+                cursor.execute(update_query, (update_item.new_tag, meeting_id, update_item.speaker_id))
+                total_updated += cursor.rowcount
+            connection.commit()
+            return {'message': 'Speaker tags updated successfully', 'total_updated': total_updated}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to batch update speaker tags: {str(e)}") from e
diff --git a/app/models/models.py b/app/models/models.py index 44100e2..22e6f38 100644 --- a/app/models/models.py +++ b/app/models/models.py @@ -41,6 +41,7 @@ class Meeting(BaseModel): class TranscriptSegment(BaseModel): segment_id: int meeting_id: int + speaker_id: Optional[int] = None # AI解析的原始结果 speaker_tag: str start_time_ms: int
end_time_ms: int @@ -57,3 +58,10 @@ class UpdateMeetingRequest(BaseModel): meeting_time: Optional[datetime.datetime] summary: Optional[str] attendee_ids: list[int]
+
+class SpeakerTagUpdateRequest(BaseModel):
+    speaker_id: int  # raw speaker_id (integer) of the speaker to rename
+    new_tag: str  # new value for speaker_tag
+
+class BatchSpeakerTagUpdateRequest(BaseModel):
+    updates: list[SpeakerTagUpdateRequest]  # one entry per speaker to rename
diff --git a/app/services/ai_service.py b/app/services/ai_service.py index d2b128b..92e1f03 100644 --- a/app/services/ai_service.py +++ b/app/services/ai_service.py @@ -54,9 +54,11 @@ class AIService: segments_to_insert = [] for transcript in data.get('transcripts', []): for sentence in transcript.get('sentences', []): + speaker_id = sentence.get('speaker_id', -1) segments_to_insert.append(( meeting_id, - sentence.get('speaker_id', 'Unknown'), + speaker_id, # For the new speaker_id column + speaker_id, # For the speaker_tag column (initial value) sentence.get('begin_time'), sentence.get('end_time'), sentence.get('text') @@ -76,8 +78,8 @@ class AIService: print(f"Deleted existing segments for meeting_id: {meeting_id}") insert_query = ''' - INSERT INTO transcript_segments (meeting_id, speaker_tag, start_time_ms, end_time_ms, text_content) - VALUES (%s, %s, %s, %s, %s) + INSERT INTO transcript_segments (meeting_id, speaker_id, speaker_tag, start_time_ms, end_time_ms, text_content) + VALUES (%s, %s, %s, %s, %s, %s) ''' cursor.executemany(insert_query, segments_to_insert) connection.commit() @@ -96,7 +98,7 @@ if __name__ == '__main__': # 1. Make sure you have a meeting with meeting_id = 1 in your database. # 2. Make sure the audio file URL is correct and accessible.
- test_meeting_id = 37 + test_meeting_id = 38 # Please replace with your own publicly accessible audio file URL test_file_urls = ['http://t0vogyxkz.hn-bkt.clouddn.com/record/meeting_records_2.mp3'] diff --git a/uploads/audio/38/c08b25f9-6029-4495-ad5d-19512bbc10a2.mp3 b/uploads/audio/38/c08b25f9-6029-4495-ad5d-19512bbc10a2.mp3 new file mode 100644 index 0000000..0fd96c9 Binary files /dev/null and b/uploads/audio/38/c08b25f9-6029-4495-ad5d-19512bbc10a2.mp3 differ