#!/usr/bin/env python

import numpy as np
import pylab as plt


def statefilter(thestates, minlength, minhold, debug=False):
    """Smooth a sequence of discrete state labels by absorbing short runs.

    The sequence is scanned once while tracking the current run (its state
    and length) and the state of the run that preceded it.  Whenever a
    sample differs from the current run's state, one of three things happens:

    * patch  -- the run is shorter than ``minlength`` and the new sample
      returns to the previous state: the run is overwritten with the
      previous state and merged back into it.
    * fill   -- the run is shorter than ``minhold`` and the new sample is
      some other state: the run *and the new sample* are overwritten with
      the previous state.
    * switch -- otherwise the new sample starts a new run.

    Parameters
    ----------
    thestates : sequence of int
        State label for each sample (list or 1-D array).
    minlength : int
        Runs shorter than this are patched when the sequence returns to the
        previous state.
    minhold : int
        Runs shorter than this are filled when the sequence moves on to a
        different state than the previous one.
    debug : bool, optional
        If true, print the action taken for every sample after the first.

    Returns
    -------
    numpy.ndarray
        Integer array of the same length as ``thestates`` holding the
        filtered labels.  Empty input yields an empty array.

    Notes
    -----
    A banner line is always printed to stdout, regardless of ``debug``.
    """
    print("state filtering with length", minlength)
    numstates = len(thestates)
    thefiltstates = np.zeros(numstates, dtype=int)
    # len() rather than truthiness: the truth value of a multi-element numpy
    # array is ambiguous and raises.
    if numstates == 0:
        return thefiltstates

    currentstate = thestates[0]
    # The loop below starts at index 1, so the first sample has to be copied
    # here; otherwise it would be left at the zero fill value.
    thefiltstates[0] = currentstate
    laststate = currentstate
    currentlen = 1
    lastlen = 1
    for state in range(1, numstates):
        newstate = thestates[state]
        if newstate == currentstate:
            # Same state as the current run: extend it.
            currentlen += 1
            thefiltstates[state] = newstate
            action = "continue"
        elif (currentlen < minlength and newstate == laststate) or (
            currentlen < minhold and newstate != laststate
        ):
            # Current run is too short to keep: overwrite it (and the sample
            # at this index) with the previous state and resume that run.
            action = "patch" if newstate == laststate else "fill"
            thefiltstates[state - currentlen : state + 1] = laststate
            currentstate = laststate
            currentlen += lastlen
        else:
            # Current run is long enough: accept the transition.
            # NOTE(review): the extra 1 in lastlen appears to account for the
            # sample at the index where a later patch/fill merges back into
            # this run -- confirm before changing the bookkeeping.
            lastlen = currentlen + 1
            currentlen = 1
            laststate = currentstate
            currentstate = newstate
            thefiltstates[state] = newstate
            action = "switch"
        if debug:
            print("state", state, "(", newstate, "):" + action)
    return thefiltstates


# Demo: run the filter on a short hand-written state sequence and plot the
# result next to the raw input.
thestates = np.array(
    [0, 0, 0, 1, 1, 1, 2, 2, 2, 1, 2, 0, 1, 1, 1, 1, 2, 2, 1, 1, 1]
)

# Both thresholds are 2 samples; debug tracing is switched on.
thefiltstates = statefilter(thestates, minlength=2, minhold=2, debug=True)

plt.plot(thestates)
# The filtered trace is shifted up by one, presumably so the two curves do
# not lie on top of each other.
plt.plot(thefiltstates + 1)
plt.show()
