import numpy as np
import scipy.io.wavfile as wf
import matplotlib.pyplot as plt
from tqdm import tqdm


class VoiceActivityDetector(object):
    """Detect voice activity in an audio file using spectral energy.

    The signal is cut into short overlapping windows.  A window is flagged
    as speech when the share of its spectral energy that lies inside the
    voice band (``speech_start_band`` .. ``speech_end_band`` Hz) exceeds
    ``speech_energy_threshold``.  The per-window flags are then smoothed
    with a median filter spanning roughly ``speech_window`` seconds.
    """

    def __init__(self, wave_input_filename):
        """Load ``wave_input_filename``, down-mix it to mono and set defaults."""
        self._read_wav(wave_input_filename)._convert_to_mono()
        self.sample_window = 0.02  # analysis window length: 20 ms
        self.sample_overlap = 0.01  # hop between window starts: 10 ms
        self.speech_window = 0.5  # smoothing span: half a second
        self.speech_energy_threshold = 0.6  # 60% of energy in voice band
        self.speech_start_band = 300
        self.speech_end_band = 3000

    def _read_wav(self, wave_file):
        """Read ``wave_file`` into ``self.rate`` / ``self.data``; return self."""
        # rate is the number of samples per second (e.g. 44100); data is 1-D
        # for a mono file and (n_samples, n_channels) for a multi-channel one,
        # e.g. (26460000, 2) for a stereo recording.
        self.rate, self.data = wf.read(wave_file)
        # Store the real channel count rather than the array rank.
        self.channels = self.data.shape[1] if self.data.ndim > 1 else 1
        self.filename = wave_file
        return self

    def _convert_to_mono(self):
        """Average all channels into one, keeping the sample dtype; return self."""
        if self.data.ndim > 1:
            sample_dtype = self.data.dtype
            # Average in float64: accumulating in the (integer) sample dtype
            # overflows as soon as the channel sum leaves its range.
            mono = np.mean(self.data, axis=1, dtype=np.float64)
            self.data = mono.astype(sample_dtype)  # e.g. shape (26460000,)
            self.channels = 1
        return self

    def _calculate_frequencies(self, audio_data):
        """Return the FFT bin frequencies (Hz) of ``audio_data`` without DC."""
        data_freq = np.fft.fftfreq(len(audio_data), 1.0 / self.rate)
        return data_freq[1:]

    def _calculate_amplitude(self, audio_data):
        """Return the FFT magnitudes of ``audio_data`` without the DC bin."""
        data_ampl = np.abs(np.fft.fft(audio_data))
        return data_ampl[1:]

    def _calculate_energy(self, data):
        """Return the squared FFT magnitudes of ``data`` (DC bin dropped)."""
        return self._calculate_amplitude(data) ** 2

    def _znormalize_energy(self, data_energy):
        """Return ``data_energy`` shifted to zero mean and scaled by its std."""
        energy_mean = np.mean(data_energy)
        energy_std = np.std(data_energy)
        return (data_energy - energy_mean) / energy_std

    def _connect_energy_with_frequencies(self, data_freq, data_energy):
        """Map each absolute frequency to twice the energy of its first bin.

        Positive and negative frequencies share one key; only the first bin
        seen for a given absolute frequency contributes, doubled.
        """
        energy_freq = {}
        for i, freq in enumerate(data_freq):
            abs_freq = abs(freq)
            if abs_freq not in energy_freq:
                energy_freq[abs_freq] = data_energy[i] * 2
        return energy_freq

    def _calculate_normalized_energy(self, data):
        """Return a ``{frequency: energy}`` dict for one window of samples."""
        data_freq = self._calculate_frequencies(data)
        data_energy = self._calculate_energy(data)
        # z-normalising the energy here was tried and gave worse results.
        return self._connect_energy_with_frequencies(data_freq, data_energy)

    def _sum_energy_in_band(self, energy_frequencies, start_band, end_band):
        """Sum the energies whose frequency lies strictly inside the band."""
        return sum(
            energy
            for freq, energy in energy_frequencies.items()
            if start_band < freq < end_band
        )

    def _median_filter(self, x, k):
        """Return the length-``k`` running median of the 1-D array ``x``.

        The signal is extended at both ends by repeating its first / last
        value, so the output has the same length as ``x``.
        """
        assert k % 2 == 1, "Median filter length must be odd."
        assert x.ndim == 1, "Input must be one-dimensional."
        k2 = (k - 1) // 2
        # Column c of y holds x shifted by (c - k2) samples, edge-padded.
        y = np.zeros((len(x), k), dtype=x.dtype)
        y[:, k2] = x
        for i in range(k2):
            j = k2 - i
            y[j:, i] = x[:-j]
            y[:j, i] = x[0]
            y[:-j, -(i + 1)] = x[j:]
            y[-j:, -(i + 1)] = x[-1]
        return np.median(y, axis=1)

    def _smooth_speech_detection(self, detected_windows):
        """Median-filter the speech flags in column 1 of ``detected_windows``."""
        median_window = int(self.speech_window / self.sample_window)
        if median_window % 2 == 0:
            median_window -= 1
        # The filter needs an odd, positive length even when speech_window
        # is shorter than sample_window.
        median_window = max(median_window, 1)
        return self._median_filter(detected_windows[:, 1], median_window)

    def convert_to_labels(self, detected_windows):
        """Convert per-window speech flags into speech time intervals.

        Args:
            detected_windows: rows of ``(sample position, speech flag)`` as
                returned by :meth:`detect_speech`.

        Returns:
            A list of dicts with ``speech_begin`` and ``speech_end`` in
            seconds.  A region still open at the last window is closed at
            that window's position instead of being dropped.
        """
        speech_time = []
        speech_label = None  # the currently open region, if any
        last_position = None
        for window in detected_windows:
            position, flag = window[0], window[1]
            if flag == 1.0 and speech_label is None:
                speech_label = {'speech_begin': position / self.rate}
            elif flag == 0.0 and speech_label is not None:
                speech_label['speech_end'] = position / self.rate
                speech_time.append(speech_label)
                speech_label = None
            last_position = position
        if speech_label is not None:
            # Speech ran until the end of the analysed signal.
            speech_label['speech_end'] = last_position / self.rate
            speech_time.append(speech_label)
        return speech_time

    def plot_detected_speech_regions(self):
        """Run speech detection and plot the signal with the speech samples.

        Only the sample at each window start is copied into the speech
        trace (scaled by that window's flag); every other sample stays 0.
        """
        data = self.data
        detected_windows = self.detect_speech()
        positions = detected_windows[:, 0].astype(int)
        data_speech = np.zeros(len(data))
        data_speech[positions] = data[positions] * detected_windows[:, 1]
        plt.figure()
        plt.plot(data_speech)
        plt.plot(data)
        plt.show()
        return self

    def detect_speech(self):
        """Flag speech windows from the voice-band share of spectral energy.

        Returns:
            A float array of shape ``(n_windows, 2)``; column 0 is the start
            sample of each window and column 1 the smoothed speech flag
            (1 - speech, 0 - nonspeech).  The array is empty when the signal
            is too short to hold a single window.

        Raises:
            ValueError: if the window or hop length rounds to zero samples.
        """
        data = self.data
        sample_window = int(self.rate * self.sample_window)
        sample_overlap = int(self.rate * self.sample_overlap)
        if sample_window <= 0 or sample_overlap <= 0:
            raise ValueError(
                "sample_window and sample_overlap must each span at least "
                "one sample at the file's sample rate")
        start_band = self.speech_start_band
        end_band = self.speech_end_band
        threshold = self.speech_energy_threshold

        window_starts = range(0, len(data) - sample_window, sample_overlap)
        if not window_starts:
            return np.empty((0, 2))

        speech_flags = []
        with tqdm(total=len(data), unit='frames',
                  dynamic_ncols=True) as progress_bar:
            for sample_start in window_starts:
                data_window = data[sample_start:sample_start + sample_window]
                energy_freq = self._calculate_normalized_energy(data_window)
                sum_voice_energy = self._sum_energy_in_band(
                    energy_freq, start_band, end_band)
                sum_full_energy = sum(energy_freq.values())
                # Hypothesis: during speech the voice band holds more than
                # `threshold` of the window's energy.  A window without any
                # energy (digital silence) cannot be speech.
                is_speech = bool(
                    sum_full_energy > 0
                    and sum_voice_energy / sum_full_energy > threshold)
                speech_flags.append(is_speech)
                progress_bar.update(sample_overlap)

        detected_windows = np.column_stack((
            np.array(window_starts, dtype=float),
            np.array(speech_flags, dtype=float),
        ))
        detected_windows[:, 1] = self._smooth_speech_detection(detected_windows)
        return detected_windows

    def process_labels(self, labels, margin=12):
        """Merge speech intervals separated by less than ``margin`` seconds.

        Args:
            labels: interval dicts with ``speech_begin`` / ``speech_end``,
                in chronological order.
            margin: largest gap (seconds) that is still bridged.

        Returns:
            A new list of merged interval dicts; ``labels`` is not modified.
        """
        select_labels = []
        for label in labels:
            speech_begin = label["speech_begin"]
            speech_end = label["speech_end"]
            if (select_labels
                    and speech_begin - select_labels[-1]["speech_end"] < margin):
                # Close enough to the previous region: extend it.
                select_labels[-1]["speech_end"] = speech_end
            else:
                select_labels.append(
                    {"speech_begin": speech_begin, "speech_end": speech_end})
        # Optional post-processing that is currently disabled:
        # for i in range(len(select_labels)):
        #     if select_labels[i]["speech_end"] - select_labels[i]["speech_begin"] > 300:  # clip regions longer than 5 minutes to 5 minutes
        #         select_labels[i]["speech_end"] = select_labels[i]["speech_begin"] + 300
        # select_labels = [item for item in select_labels if select_labels.index(item) < 10]  # keep only the first 10 regions
        return select_labels
