192 lines
6.5 KiB
Python
192 lines
6.5 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
import scipy.signal
|
|
from scipy.io import wavfile
|
|
import scipy.stats as st
|
|
|
|
import sounddevice as sd
|
|
from pesq import pesq
|
|
from pns.noise_suppressor import NoiseSuppressor
|
|
|
|
import random
|
|
import json
|
|
|
|
|
|
|
|
#SIGNAL_PATH = "speechfiles/sp01.wav"
|
|
#NOISE_PATH = "noisefiles/white.dat"
|
|
|
|
|
|
def normalize_signal(signal):
    """Affinely rescale ``signal`` so its minimum maps to -1 and its maximum to +1.

    Parameters
    ----------
    signal : array_like
        Samples to rescale. Integer input is accepted and converted to
        float64 first.

    Returns
    -------
    numpy.ndarray
        A new float64 array in the range [-1, 1]; the input is not modified.
        An empty input yields an empty array. A constant input (zero
        peak-to-peak range) yields all zeros instead of NaNs.
    """
    # Work on a float64 copy: in-place float scaling would fail on integer
    # arrays, and the caller's data must never be mutated.
    normalized_signal = np.array(signal, dtype=np.float64)
    if normalized_signal.size == 0:
        return normalized_signal

    normalized_signal -= np.min(normalized_signal)
    max_amp = np.max(normalized_signal)
    if max_amp == 0:
        # Constant signal: there is no range to stretch, so map it to 0
        # rather than dividing by zero.
        return np.zeros_like(normalized_signal)

    normalized_signal *= 2.0 / max_amp
    normalized_signal -= 1.0
    return normalized_signal
|
|
|
|
|
|
def load_audiofile(path):
    """Load an audio file from disk and return it normalized to [-1, 1].

    The format is chosen by file extension (case-insensitive):

    * ``.dat`` -- plain text with one numeric sample per line. The sample rate
      is not read from the file; 8 kHz is assumed.
    * ``.wav`` -- read with ``scipy.io.wavfile.read``, which supplies the
      sample rate.

    Parameters
    ----------
    path : str
        Path of the file to load.

    Returns
    -------
    tuple
        ``(sample_rate, samples)``, where ``samples`` is a float64 array
        rescaled by ``normalize_signal``.

    Raises
    ------
    ValueError
        If the extension is neither ``.dat`` nor ``.wav``, or a ``.dat`` line
        is not a number.
    OSError
        If the file cannot be opened.
    """
    lowered = str(path).lower()
    if lowered.endswith(".dat"):
        # Load .dat files as sound files sampled at 8[kHz].
        sample_rate = 8000
        with open(path, "r") as sound_file:
            # float() instead of eval(): eval would execute whatever code a
            # data file contains. Blank lines (e.g. a trailing newline) are
            # skipped instead of crashing.
            sound_data = [float(line) for line in sound_file if line.strip()]
    elif lowered.endswith(".wav"):
        sample_rate, sound_data = wavfile.read(path)
    else:
        raise ValueError(
            f"Unsupported audio file type: {path!r} (expected .dat or .wav)"
        )

    # Convert to float64 up front so integer PCM samples (e.g. int16 .wav
    # data) can go through floating-point normalization and filtering.
    sound_data = np.asarray(sound_data, dtype=np.float64)
    return sample_rate, normalize_signal(sound_data)
|
|
|
|
|
|
def add_noise(signal, noise, snr, rng=None):
    """Mix a random excerpt of ``noise`` into ``signal`` at a target SNR.

    Parameters
    ----------
    signal : numpy.ndarray
        Clean signal.
    noise : numpy.ndarray
        Noise recording; must be at least as long as ``signal``. It is not
        modified.
    snr : float
        Desired signal-to-noise ratio in dB.
    rng : random.Random, optional
        Source of randomness for choosing the excerpt. Defaults to the global
        ``random`` module; pass a seeded ``random.Random`` for reproducible
        mixes.

    Returns
    -------
    numpy.ndarray
        ``signal`` plus the scaled noise excerpt, rescaled to [-1, 1] by
        ``normalize_signal``.

    Raises
    ------
    ValueError
        If ``noise`` is shorter than ``signal``, or the chosen excerpt is all
        zeros (its level cannot be scaled to the requested SNR).
    """
    if rng is None:
        rng = random

    len_signal = len(signal)
    len_noise = len(noise)
    if len_noise < len_signal:
        raise ValueError(
            f"Noise ({len_noise} samples) is shorter than the signal "
            f"({len_signal} samples)"
        )

    # Get a random crop of the noise to match the length of the signal.
    # Valid start offsets are 0..len_noise-len_signal inclusive, and
    # randrange's stop is exclusive, hence the +1.
    noise_crop_start = rng.randrange(len_noise - len_signal + 1)
    # Copy the slice: scaling a view in place would overwrite the caller's
    # noise array and corrupt it for every later call that reuses it.
    noise_crop = np.array(
        noise[noise_crop_start:noise_crop_start + len_signal], dtype=np.float64
    )

    # Both vectors have the same length, so the ratio of their L2 norms equals
    # the ratio of their RMS amplitudes; an amplitude ratio in dB uses 20*log10.
    noise_norm = np.linalg.norm(noise_crop, 2)
    signal_norm = np.linalg.norm(signal, 2)
    if noise_norm == 0:
        raise ValueError(
            "Selected noise excerpt is silent; cannot scale it to the requested SNR"
        )

    desired_noise_norm = signal_norm / 10 ** (snr / 20)
    noise_crop *= desired_noise_norm / noise_norm

    return normalize_signal(signal + noise_crop)
|
|
|
|
|
|
def load_all_signals(count=30, directory="speechfiles"):
    """Load the speech recordings ``<directory>/sp01.wav`` .. ``sp<count>.wav``.

    Parameters
    ----------
    count : int
        Number of sentences to load. Files are numbered from 1 and
        zero-padded to two digits. Defaults to the 30 sentences used by this
        experiment.
    directory : str
        Folder holding the ``spNN.wav`` files.

    Returns
    -------
    tuple
        ``(sample_rate, signals)``: the sample rate shared by all files
        (``None`` when ``count`` is 0) and the normalized signals in file
        order.

    Raises
    ------
    ValueError
        If the files do not all share the same sample rate.
    """
    signal_paths = [f"{directory}/sp{i:02}.wav" for i in range(1, count + 1)]

    signal_sample_rate = None
    all_signals = []
    for signal_path in signal_paths:
        new_signal_sample_rate, signal_data = load_audiofile(signal_path)
        # Explicit check rather than assert so it still runs under `python -O`.
        if signal_sample_rate is not None and new_signal_sample_rate != signal_sample_rate:
            raise ValueError(
                "Non-uniform signal sampling rates: "
                f"{signal_path} is {new_signal_sample_rate} Hz, "
                f"expected {signal_sample_rate} Hz"
            )
        signal_sample_rate = new_signal_sample_rate
        all_signals.append(signal_data)

    return signal_sample_rate, all_signals
|
|
|
|
|
|
def load_all_noises(
    noise_paths=(
        "noisefiles/white.dat",
        "noisefiles/train.dat",
        "noisefiles/street.dat",
        "noisefiles/exhibition.dat",
    ),
):
    """Load every noise recording in ``noise_paths``, in the order given.

    Parameters
    ----------
    noise_paths : sequence of str
        Noise files to load. Defaults to the white, train, street and
        exhibition recordings. A tuple is used so the default is immutable.

    Returns
    -------
    tuple
        ``(sample_rate, noises)``: the sample rate shared by all files
        (``None`` when ``noise_paths`` is empty) and the normalized noise
        arrays.

    Raises
    ------
    ValueError
        If the files do not all share the same sample rate.
    """
    noise_sample_rate = None
    all_noises = []
    for noise_path in noise_paths:
        new_noise_sample_rate, noise_data = load_audiofile(noise_path)
        # Explicit check rather than assert so it still runs under `python -O`.
        if noise_sample_rate is not None and new_noise_sample_rate != noise_sample_rate:
            raise ValueError(
                "Non-uniform noise sampling rates: "
                f"{noise_path} is {new_noise_sample_rate} Hz, "
                f"expected {noise_sample_rate} Hz"
            )
        noise_sample_rate = new_noise_sample_rate
        all_noises.append(noise_data)

    return noise_sample_rate, all_noises
|
|
|
|
|
|
def enhanced(signal, sample_rate, mysize=None, noise=None):
    """Denoise ``signal`` with a Wiener filter and rescale it to [-1, 1].

    Parameters
    ----------
    signal : numpy.ndarray
        Noisy input signal.
    sample_rate : int
        Sampling rate in Hz. It is not used by the Wiener filter; it is kept
        for interface compatibility with a frame-based
        ``pns.NoiseSuppressor(sample_rate)`` implementation.
    mysize : int, optional
        Window size forwarded to ``scipy.signal.wiener``. ``None`` selects
        SciPy's default.
    noise : float, optional
        Noise power forwarded to ``scipy.signal.wiener``. ``None`` lets SciPy
        estimate it from the local variance.

    Returns
    -------
    numpy.ndarray
        The filtered signal, normalized with ``normalize_signal``.
    """
    # NOTE: an earlier frame-by-frame pns.NoiseSuppressor implementation was
    # disabled in favour of SciPy's Wiener filter.
    filtered_signal = scipy.signal.wiener(signal, mysize=mysize, noise=noise)
    return normalize_signal(filtered_signal)
|
|
|
|
|
|
def calculate_pesqs(all_signals, all_noises, sample_rate, snrs):
    """Compute narrow-band PESQ scores for noisy and Wiener-filtered speech.

    For every SNR, each noise is mixed into each clean signal with
    ``add_noise``, and the mix is enhanced with ``enhanced``. Both the noisy
    and the enhanced versions are then scored against the clean reference
    with ``pesq(..., mode='nb')``.

    Returns
    -------
    tuple
        ``(noisy_pesqs, filtered_pesqs)``. Each is a dict mapping
        ``snr -> list`` of scores. Each list holds
        ``len(all_noises) * len(all_signals)`` entries in noise-major order:
        all signals for the first noise, then all signals for the next noise,
        and so on.
    """
    noisy_pesqs = {}
    filtered_pesqs = {}

    for snr in snrs:
        noisy_pesqs[snr] = []
        filtered_pesqs[snr] = []
        for noise_data in all_noises:
            for signal_data in all_signals:
                noisy_signal = add_noise(signal_data, noise_data, snr)
                filtered_signal = enhanced(noisy_signal, sample_rate)

                noisy_pesqs[snr].append(
                    pesq(sample_rate, signal_data, noisy_signal, mode='nb')
                )
                filtered_pesqs[snr].append(
                    pesq(sample_rate, signal_data, filtered_signal, mode='nb')
                )

    return noisy_pesqs, filtered_pesqs
|
|
|
|
|
|
def main():
    """Run the PESQ experiment and plot the results.

    Loads all speech sentences and noise recordings, then scores noisy and
    Wiener-filtered speech at each SNR. It then shows two figures:

    1. mean PESQ versus SNR for noisy and filtered speech;
    2. one panel per SNR with the 95% Student-t confidence interval of the
       mean PESQ, for unfiltered and filtered speech.

    Raises
    ------
    ValueError
        If the speech and noise files use different sample rates.
    """
    # SNR in dB
    snrs = (0, 10, 20, 30)

    # Load all signals and all noises
    signal_sample_rate, all_signals = load_all_signals()
    noise_sample_rate, all_noises = load_all_noises()

    # Explicit check rather than assert so it still runs under `python -O`.
    if signal_sample_rate != noise_sample_rate:
        raise ValueError("Signal and noise sampling rates didn't match.")
    sample_rate = signal_sample_rate

    noisy_pesqs, filtered_pesqs = calculate_pesqs(all_signals, all_noises, sample_rate, snrs)

    # Average over every (noise, sentence) trial. Indexing [0] would plot a
    # single trial only (the first sentence with the first noise).
    n_pesq = [np.mean(noisy_pesqs[snr]) for snr in snrs]
    f_pesq = [np.mean(filtered_pesqs[snr]) for snr in snrs]
    plt.plot(snrs, n_pesq, label="Noisy")
    plt.plot(snrs, f_pesq, label="Filtered")
    plt.xlabel("SNR [dB]")
    plt.ylabel("PESQ")
    plt.legend()
    plt.show()

    # One panel per SNR. squeeze=False keeps `axes` 2-D even when there is a
    # single SNR, so the indexing below works for any number of SNRs.
    _, axes = plt.subplots(
        nrows=1, ncols=len(snrs), figsize=(9, 9), sharey=True, squeeze=False
    )
    axes = axes[0]

    # Calculate the 95% confidence interval for the noise and filter PESQ at each SNR
    for ax, snr in zip(axes, snrs):
        data_noise = np.asarray(noisy_pesqs[snr])
        data_filter = np.asarray(filtered_pesqs[snr])

        noise_ci = st.t.interval(
            0.95, len(data_noise) - 1, loc=np.mean(data_noise), scale=st.sem(data_noise)
        )
        filter_ci = st.t.interval(
            0.95, len(data_filter) - 1, loc=np.mean(data_filter), scale=st.sem(data_filter)
        )

        ax.vlines((1, 2), (min(noise_ci), min(filter_ci)), (max(noise_ci), max(filter_ci)))
        ax.set_xticks(np.arange(1, 3), labels=["Unfiltered", "Filtered"])
        ax.set_xlim(0.25, 2.75)
        ax.set_xlabel(f"{snr}[dB] SNR")

    # The y-axis is shared, so it is labelled only once.
    axes[0].set_ylabel("PESQ")

    # To keep a copy of the figure: plt.savefig("PESQ_Confidence_Interval.png")
    plt.show()
|
|
|
|
# Run the experiment only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
|