sayakpaul's picture
sayakpaul HF Staff
uo
99ef3c9
import os
import shutil
import tempfile

import gradio as gr
import spaces

from separator import StemSeparatorService
# Every stem the UI can request, in the fixed order that the output
# components (and the values returned by ``separate``) follow.
STEMS = [
    "Vocals",
    "Drums",
    "Bass",
    "Guitar",
    "Piano",
    "Other",
]

# Build the shared separator once, at import time, and load its weights.
# The weights stay on CPU between requests; ``separate`` moves the model to
# the GPU only for the duration of a call.
separator = StemSeparatorService()
separator.load_model()
@spaces.GPU
def separate(audio_file: str, stems: list[str], output_format: str):
    """Separate an audio file into the requested stems.

    Decorated with ``@spaces.GPU``. The shared ``separator`` model is moved
    to CUDA for the duration of the call and is always moved back to CPU
    afterwards, including when separation fails.

    Args:
        audio_file: Path to the input audio (the Gradio Audio input uses
            ``type="filepath"``).
        stems: Names of the stems to extract; expected to be drawn from STEMS.
        output_format: Output format forwarded to the separator (the UI
            offers "wav", "mp3" and "aac").

    Returns:
        A list with exactly one entry per name in STEMS, in STEMS order. An
        entry is ``os.path.join(output_dir, result[name])`` when the separator
        reported that stem, and None otherwise (for example, when it was not
        requested).

    Raises:
        gr.Error: If no audio file or no stems were provided.
        Exception: Errors from the separator propagate unchanged after cleanup.
    """
    # Gradio passes None for an empty Audio input; an empty path is treated
    # the same way rather than failing later inside the separator.
    if not audio_file:
        raise gr.Error("No audio file provided.")
    if not stems:
        raise gr.Error("No stems selected.")

    output_dir = tempfile.mkdtemp(prefix="stems-")
    try:
        # The move happens inside the try so that a failed or partial transfer
        # to the GPU still returns the model to CPU in ``finally``.
        separator.move_to("cuda")
        result = separator.separate(
            audio_file, output_dir, stems, output_format,
            progress_callback=lambda state, pct: None,  # progress is not surfaced
        )
    except BaseException:
        # No files will be served from this directory, so remove it instead of
        # leaking one temp dir per failed request; then re-raise unchanged.
        shutil.rmtree(output_dir, ignore_errors=True)
        raise
    finally:
        separator.move_to("cpu")

    # Gradio maps the returned list onto the output components by position,
    # so always return one value per STEMS entry, in that fixed order.
    # NOTE(review): assumes result maps stem name -> file name inside
    # output_dir (an absolute path would also work with os.path.join).
    return [
        os.path.join(output_dir, result[stem_name]) if stem_name in result else None
        for stem_name in STEMS
    ]
# Single-function Gradio app. The three input components map positionally
# onto ``separate``'s parameters, and its returned list maps onto one Audio
# output per entry in STEMS. ``api_name`` names the endpoint for API clients.
demo = gr.Interface(
    fn=separate,
    inputs=[
        gr.Audio(label="Audio", type="filepath"),
        gr.CheckboxGroup(label="Stems", choices=STEMS, value=STEMS),
        gr.Dropdown(label="Output Format", choices=["wav", "mp3", "aac"], value="wav"),
    ],
    outputs=[gr.Audio(label=stem_label, type="filepath") for stem_label in STEMS],
    title="Stem Separator Inference",
    description="BS-RoFormer stem separation API. Upload audio and select stems to extract.",
    api_name="separate",
)

if __name__ == "__main__":
    # Start the server only when this file is run as a script, not on import.
    demo.launch()