增加人声分离标签

main
mula.liu 2025-08-25 16:10:29 +08:00
parent 06cffe7cfa
commit 3528ba717d
4 changed files with 77 additions and 7 deletions

View File

@ -1,6 +1,6 @@
from fastapi import APIRouter, HTTPException, UploadFile, File, Form from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from app.models.models import Meeting, TranscriptSegment, CreateMeetingRequest, UpdateMeetingRequest from app.models.models import Meeting, TranscriptSegment, CreateMeetingRequest, UpdateMeetingRequest, SpeakerTagUpdateRequest, BatchSpeakerTagUpdateRequest
from app.core.database import get_db_connection from app.core.database import get_db_connection
from app.core.config import BASE_DIR, UPLOAD_DIR, AUDIO_DIR, MARKDOWN_DIR, ALLOWED_EXTENSIONS, ALLOWED_IMAGE_EXTENSIONS, MAX_FILE_SIZE, MAX_IMAGE_SIZE from app.core.config import BASE_DIR, UPLOAD_DIR, AUDIO_DIR, MARKDOWN_DIR, ALLOWED_EXTENSIONS, ALLOWED_IMAGE_EXTENSIONS, MAX_FILE_SIZE, MAX_IMAGE_SIZE
from app.services.qiniu_service import qiniu_service from app.services.qiniu_service import qiniu_service
@ -125,7 +125,7 @@ def get_meeting_transcript(meeting_id: int):
# Get transcript segments # Get transcript segments
transcript_query = ''' transcript_query = '''
SELECT segment_id, meeting_id, speaker_tag, start_time_ms, end_time_ms, text_content SELECT segment_id, meeting_id, speaker_id, speaker_tag, start_time_ms, end_time_ms, text_content
FROM transcript_segments FROM transcript_segments
WHERE meeting_id = %s WHERE meeting_id = %s
ORDER BY start_time_ms ASC ORDER BY start_time_ms ASC
@ -133,7 +133,19 @@ def get_meeting_transcript(meeting_id: int):
cursor.execute(transcript_query, (meeting_id,)) cursor.execute(transcript_query, (meeting_id,))
segments = cursor.fetchall() segments = cursor.fetchall()
return [TranscriptSegment(**segment) for segment in segments] transcript_segments = []
for segment in segments:
transcript_segments.append(TranscriptSegment(
segment_id=segment['segment_id'],
meeting_id=segment['meeting_id'],
speaker_id=segment['speaker_id'],
speaker_tag=segment['speaker_tag'] if segment['speaker_tag'] else f"发言人 {segment['speaker_id']}",
start_time_ms=segment['start_time_ms'],
end_time_ms=segment['end_time_ms'],
text_content=segment['text_content']
))
return transcript_segments
@router.post("/meetings") @router.post("/meetings")
def create_meeting(meeting_request: CreateMeetingRequest): def create_meeting(meeting_request: CreateMeetingRequest):
@ -443,3 +455,51 @@ async def upload_image(
"file_name": image_file.filename, "file_name": image_file.filename,
"file_path": '/'+ str(relative_path) "file_path": '/'+ str(relative_path)
} }
# 发言人标签更新接口
@router.put("/meetings/{meeting_id}/speaker-tags")
def update_speaker_tag(meeting_id: int, request: SpeakerTagUpdateRequest):
"""更新单个发言人标签"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
# 更新指定meeting_id和speaker_id的所有记录的speaker_tag
update_query = """
UPDATE transcript_segments
SET speaker_tag = %s
WHERE meeting_id = %s AND speaker_id = %s
"""
cursor.execute(update_query, (request.new_tag, meeting_id, request.speaker_id))
if cursor.rowcount == 0:
raise HTTPException(status_code=404, detail="No segments found for this speaker")
connection.commit()
return {'message': 'Speaker tag updated successfully', 'updated_count': cursor.rowcount}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to update speaker tag: {str(e)}")
@router.put("/meetings/{meeting_id}/speaker-tags/batch")
def batch_update_speaker_tags(meeting_id: int, request: BatchSpeakerTagUpdateRequest):
"""批量更新发言人标签"""
try:
with get_db_connection() as connection:
cursor = connection.cursor()
total_updated = 0
for update_item in request.updates:
update_query = """
UPDATE transcript_segments
SET speaker_tag = %s
WHERE meeting_id = %s AND speaker_id = %s
"""
cursor.execute(update_query, (update_item.new_tag, meeting_id, update_item.speaker_id))
total_updated += cursor.rowcount
connection.commit()
return {'message': 'Speaker tags updated successfully', 'total_updated': total_updated}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to batch update speaker tags: {str(e)}")

View File

@ -41,6 +41,7 @@ class Meeting(BaseModel):
class TranscriptSegment(BaseModel): class TranscriptSegment(BaseModel):
segment_id: int segment_id: int
meeting_id: int meeting_id: int
speaker_id: Optional[int] = None # AI解析的原始结果
speaker_tag: str speaker_tag: str
start_time_ms: int start_time_ms: int
end_time_ms: int end_time_ms: int
@ -57,3 +58,10 @@ class UpdateMeetingRequest(BaseModel):
meeting_time: Optional[datetime.datetime] meeting_time: Optional[datetime.datetime]
summary: Optional[str] summary: Optional[str]
attendee_ids: list[int] attendee_ids: list[int]
class SpeakerTagUpdateRequest(BaseModel):
speaker_id: int # 使用原始speaker_id整数
new_tag: str
class BatchSpeakerTagUpdateRequest(BaseModel):
updates: List[SpeakerTagUpdateRequest]

View File

@ -54,9 +54,11 @@ class AIService:
segments_to_insert = [] segments_to_insert = []
for transcript in data.get('transcripts', []): for transcript in data.get('transcripts', []):
for sentence in transcript.get('sentences', []): for sentence in transcript.get('sentences', []):
speaker_id = sentence.get('speaker_id', -1)
segments_to_insert.append(( segments_to_insert.append((
meeting_id, meeting_id,
sentence.get('speaker_id', 'Unknown'), speaker_id, # For the new speaker_id column
speaker_id, # For the speaker_tag column (initial value)
sentence.get('begin_time'), sentence.get('begin_time'),
sentence.get('end_time'), sentence.get('end_time'),
sentence.get('text') sentence.get('text')
@ -76,8 +78,8 @@ class AIService:
print(f"Deleted existing segments for meeting_id: {meeting_id}") print(f"Deleted existing segments for meeting_id: {meeting_id}")
insert_query = ''' insert_query = '''
INSERT INTO transcript_segments (meeting_id, speaker_tag, start_time_ms, end_time_ms, text_content) INSERT INTO transcript_segments (meeting_id, speaker_id, speaker_tag, start_time_ms, end_time_ms, text_content)
VALUES (%s, %s, %s, %s, %s) VALUES (%s, %s, %s, %s, %s, %s)
''' '''
cursor.executemany(insert_query, segments_to_insert) cursor.executemany(insert_query, segments_to_insert)
connection.commit() connection.commit()
@ -96,7 +98,7 @@ if __name__ == '__main__':
# 1. Make sure you have a meeting with meeting_id = 1 in your database. # 1. Make sure you have a meeting with meeting_id = 1 in your database.
# 2. Make sure the audio file URL is correct and accessible. # 2. Make sure the audio file URL is correct and accessible.
test_meeting_id = 37 test_meeting_id = 38
# Please replace with your own publicly accessible audio file URL # Please replace with your own publicly accessible audio file URL
test_file_urls = ['http://t0vogyxkz.hn-bkt.clouddn.com/record/meeting_records_2.mp3'] test_file_urls = ['http://t0vogyxkz.hn-bkt.clouddn.com/record/meeting_records_2.mp3']