# Names of the sequencing runs processed by this workflow; each one is
# substituted for the {run} wildcard in the rules below.
SOURCE_DATA = ['CC1', 'CC2', 'HeLa1', 'HeLa2', 'MUX1']

# Layout of the raw dump files, keyed by array type:
# [number of values per row, numpy dtype code].
ARRAY_DTYPES = {
    'signals': [2000, 'f4'],
    'scaling': [2, 'f8'],
}

# Every run is split into this many subtasks, labelled '000', '001', ...
SEQUENCE_SUBTASKS_SPLIT = 300
SEQUENCE_SUBTASKS_NUMBERS = [
    format(taskno, '03d') for taskno in range(SEQUENCE_SUBTASKS_SPLIT)
]

# Seed handed to numpy's global random generator before each shuffle.
RANDOM_SEED = 922

# Number of rows used to fit the outlier detector, and the fraction of
# rows it is told to treat as outliers.
OUTLIER_FILTERING_SAMPLE_SIZE = 10000
OUTLIER_CONTAMINATION = 0.02

# Fraction of the purified data reserved for testing, and the factor by
# which the standard deviation of the training targets is widened.
TESTSET_SPLIT = 0.2
TRAINING_STDEV_BOOST = 1.8

# k-mer model file passed to scripts/extract-signals.py.
KMERMODEL = '../../poreplex/kmer_models/r9.4_180mv_70bps_5mer_RNA/template_median69pA.model'

# Site-specific prefix for the singularity environment: when the workflow is
# run by user `hyeshik`, every shell command is prefixed with
# `set -o pipefail; sing `.
# NOTE(review): `sing` is presumably a site-local wrapper that enters the
# singularity container -- confirm.
import subprocess

_current_user = subprocess.check_output('whoami').decode().strip()
if _current_user == 'hyeshik':
    shell.prefix('set -o pipefail; sing ')

# Default target: the per-run signal/scaling arrays with their row counts,
# plus the two combined arrays balanced across runs.
rule all:
    input:
        expand(
            'dataarrays/{type}-{run}.{sufx}',
            type=['signals', 'scaling'],
            run=SOURCE_DATA,
            sufx=['npy', 'size'],
        ),
        'dataarrays/signals.npy',
        'dataarrays/scaling.npy'

# Distribute the reads listed in a run's sequencing summary over
# SEQUENCE_SUBTASKS_SPLIT gzipped, tab-separated read tables (one per
# extraction subtask).
rule split_reads_to_tasks:
    input: 'sources/{run}/sequencing_summary.txt'
    output:
        expand('readtbl/seqtasks-{{run}}/{taskno}.txt.gz', taskno=SEQUENCE_SUBTASKS_NUMBERS)
    run:
        import os
        import pandas as pd

        seqtbl = pd.read_table(input[0], low_memory=False)

        # Replace the `filename` column by an absolute path of the form
        # <directory of the summary>/fast5/<label>/<filename>.  Built
        # column-wise so that it also works for a table without rows.
        fast5pathfmt = os.path.abspath(os.path.dirname(input[0])) + '/fast5/{}/{}'
        seqtbl['filename'] = [
            fast5pathfmt.format(label, filename)
            for label, filename in zip(seqtbl['label'], seqtbl['filename'])
        ]

        # Round-robin assignment: row i goes to bin (i mod number of subtasks).
        seqtbl['binno'] = seqtbl.index.to_series() % SEQUENCE_SUBTASKS_SPLIT
        reads_by_bin = {binno: reads for binno, reads in seqtbl.groupby('binno')}
        no_reads = seqtbl.iloc[:0]

        # `output` lists the files in SEQUENCE_SUBTASKS_NUMBERS order, so its
        # position is the bin number.  Every declared output is written, even
        # for bins that received no reads (fewer reads than subtasks);
        # otherwise Snakemake would reject the job for missing output files.
        for binno, outputfile in enumerate(output):
            reads = reads_by_bin.get(binno, no_reads)
            reads.to_csv(outputfile, index=False, sep='\t', compression='gzip')

# Run scripts/extract-signals.py on one subtask read table of a run.  The
# script is given, in order: the read table, the k-mer model path (KMERMODEL),
# the run's event inventory, and the paths of the signal and scaling dumps.
# Both dumps are declared temp(), so Snakemake removes them once no remaining
# job needs them.
# NOTE(review): two threads are reserved per job; whether the script itself is
# multi-threaded is not visible from this file.
rule extract_signals_and_scales:
    input:
        seqtasks='readtbl/seqtasks-{run}/{taskno}.txt.gz',
        events='sources/{run}/events/inventory.h5'
    output:
        signaldump=temp('dumptmp/signals/{run}-{taskno}.dmp'),
        scalingdump=temp('dumptmp/scaling/{run}-{taskno}.dmp')
    threads: 2
    shell:
        'python3 scripts/extract-signals.py {input.seqtasks} {KMERMODEL} \
            {input.events} {output.signaldump} {output.scalingdump}'

# Concatenate all subtask dumps of one array type and run into a single
# temporary dump file.
# NOTE(review): eight threads are reserved although the job is a plain `cat`;
# presumably this limits how many merges run concurrently -- confirm.
rule merge_extracted_signals:
    input:
        expand('dumptmp/{{type}}/{{run}}-{taskno}.dmp',
               taskno=SEQUENCE_SUBTASKS_NUMBERS)
    output:
        temp('dataarrays/{type}-{run}.dmp')
    threads: 8
    run:
        # Concatenate in path order so that the signal and scaling dumps of a
        # run keep their rows in the same subtask order.
        sorted_input = list(input)
        sorted_input.sort()
        shell('cat {sorted_input} > {output}')

# Convert a merged raw dump into a 2-D .npy array and record its row count in
# a companion .size file.  Row width and dtype come from ARRAY_DTYPES for the
# {type} wildcard.
rule convert_extracted_signals:
    input: 'dataarrays/{type}-{run}.dmp'
    output:
        npy='dataarrays/{type}-{run,[^.]+}.npy',
        size='dataarrays/{type}-{run}.size'
    run:
        import numpy as np

        rowwidth, dtype = ARRAY_DTYPES[wildcards.type]

        # Map the dump read-only as a flat vector; np.save() below streams it
        # to disk.
        flat = np.memmap(input[0], dtype=dtype, mode='r', offset=0)
        if flat.shape[0] % rowwidth != 0:
            # A partial trailing row means the dump is truncated or was
            # written with a different layout.
            raise ValueError(
                '{}: {} values is not a multiple of the row width {}'.format(
                    input[0], flat.shape[0], rowwidth))

        arr = flat.reshape([flat.shape[0] // rowwidth, rowwidth])
        np.save(output.npy, arr)

        # Close the handle explicitly so the row count is flushed before the
        # job is reported as finished.
        with open(output.size, 'w') as sizefile:
            print(arr.shape[0], file=sizefile)

# Determine the smallest per-run row count: concatenate the signals `.size`
# files of all runs, sort them numerically in ascending order and keep the
# first line.
rule get_minimum_sample_size:
    input: expand('dataarrays/signals-{run}.size', run=SOURCE_DATA)
    output: 'dataarrays/sample-size-per-run.txt'
    shell: 'cat {input} | sort -n | head -1 > {output}'

# Draw the same number of rows (the smallest per-run row count) from every
# run and stack them, so that each run contributes equally to the combined
# signal and scaling arrays.
rule subsample_for_balanced_weights:
    input:
        samplesize='dataarrays/sample-size-per-run.txt',
        signalfiles=expand('dataarrays/signals-{run}.npy', run=SOURCE_DATA),
        paramfiles=expand('dataarrays/scaling-{run}.npy', run=SOURCE_DATA)
    output:
        fullsignals='dataarrays/signals.npy',
        fullparams='dataarrays/scaling.npy'
    run:
        import numpy as np
        from numpy import random

        with open(input.samplesize) as sizefile:
            samplesize = int(sizefile.read().strip())

        # Both lists share the run names, so sorting keeps the signal and
        # scaling file of a run at the same position.
        signalfiles = sorted(input.signalfiles)
        paramfiles = sorted(input.paramfiles)

        selected_signal_arrays = []
        selected_scaling_arrays = []

        for sigf, paramf in zip(signalfiles, paramfiles):
            # Memory-map instead of loading: only `samplesize` rows of each
            # run are copied into memory by the index lookups below.
            sigarr = np.load(sigf, mmap_mode='r')
            paramarr = np.load(paramf, mmap_mode='r')

            if sigarr.shape[0] != paramarr.shape[0]:
                raise ValueError(
                    'row count mismatch: {} has {} rows, {} has {}'.format(
                        sigf, sigarr.shape[0], paramf, paramarr.shape[0]))

            # Re-seeding for every run makes the selection of a run
            # independent of how many runs were processed before it.
            indices = np.arange(sigarr.shape[0])
            random.seed(RANDOM_SEED)
            random.shuffle(indices)
            # Sorted so the selected rows keep their original relative order.
            selected_indices = np.sort(indices[:samplesize])

            selected_signal_arrays.append(np.asarray(sigarr[selected_indices]))
            selected_scaling_arrays.append(np.asarray(paramarr[selected_indices]))

        np.save(output.fullsignals, np.vstack(selected_signal_arrays))
        np.save(output.fullparams, np.vstack(selected_scaling_arrays))

# Remove rows whose scaling parameters are classified as outliers by an
# isolation forest fitted on a random subsample of the scaling parameters.
# The same row mask is applied to the signals and the scaling parameters.
rule exclude_outliers:
    input:
        fullsignals='dataarrays/signals.npy',
        fullparams='dataarrays/scaling.npy'
    output:
        purified_signals='dataarrays/signals-purified.npy',
        purified_params='dataarrays/scaling-purified.npy'
    run:
        from sklearn.ensemble import IsolationForest
        from numpy import random
        from numpy.lib.format import open_memmap
        import numpy as np

        paramsarr = np.load(input.fullparams)

        # Fit on at most OUTLIER_FILTERING_SAMPLE_SIZE randomly chosen rows.
        subsample_indices = np.arange(paramsarr.shape[0])
        random.seed(RANDOM_SEED)
        random.shuffle(subsample_indices)
        subsample_indices = subsample_indices[:OUTLIER_FILTERING_SAMPLE_SIZE]
        params_subsample = paramsarr[subsample_indices]

        ifor = IsolationForest(max_samples=OUTLIER_FILTERING_SAMPLE_SIZE,
                               contamination=OUTLIER_CONTAMINATION)
        ifor.fit(params_subsample)
        # predict() labels inliers +1 and outliers -1.
        is_inlier = ifor.predict(paramsarr) > 0

        np.save(output.purified_params, paramsarr[is_inlier])

        # open_memmap() defaults to mode 'r+'; the signals are an input of
        # this rule and must only be read, so map them read-only.
        signals = open_memmap(input.fullsignals, mode='r')
        np.save(output.purified_signals, signals[is_inlier])

# Split the purified data into training and testing sets and prepare the
# training set for model fitting:
#   1. shuffle the rows and reserve a TESTSET_SPLIT fraction for testing;
#   2. apply each training row's own (scale, shift) to its signal;
#   3. draw new (scale, shift) pairs from normal distributions centred on the
#      training means with the standard deviations widened by
#      TRAINING_STDEV_BOOST, and undo them on the corrected signals;
#   4. standardise the targets of both sets with those means and widened
#      standard deviations, which are also written to a text file.
# Signals are saved with an extra trailing axis of length one.
rule split_testing_set:
    input:
        signalfile='dataarrays/signals-purified.npy',
        paramsfile='dataarrays/scaling-purified.npy'
    output:
        signals_testing='dataarrays/signals-testing.npy',
        signals_training='dataarrays/signals-training.npy',
        params_testing='dataarrays/scaling-testing.npy',
        params_training='dataarrays/scaling-training.npy',
        params_transform_info='dataarrays/scaling-transform.txt'
    run:
        import numpy as np
        from numpy.lib.format import open_memmap
        from numpy import random

        # open_memmap() defaults to mode 'r+'; these files are inputs of the
        # rule, so map them read-only.
        signals = open_memmap(input.signalfile, mode='r')
        sparams = open_memmap(input.paramsfile, mode='r')
        total_size = sparams.shape[0]

        print("Planning splitting the data into training and testing")
        total_indices = np.arange(total_size)
        random.seed(RANDOM_SEED)
        random.shuffle(total_indices)
        training_size = int(total_size * (1 - TESTSET_SPLIT))
        # Sorted so that rows are fetched from the memory maps in file order.
        training_indices = np.sort(total_indices[:training_size])
        testing_indices = np.sort(total_indices[training_size:])

        print(' - Total size:', total_size)
        print(' - Training set size:', len(training_indices))
        print(' - Testing set size:', len(testing_indices))

        print("Spliting")
        # Index-array lookups copy the selected rows into memory.
        training_x = signals[training_indices]
        training_y = sparams[training_indices]
        testing_x = signals[testing_indices]
        testing_y = sparams[testing_indices]

        print("Setting target mean and stdev for the redispersed training data")
        target_mean = training_y.mean(axis=0)
        target_stdev = training_y.std(axis=0) * TRAINING_STDEV_BOOST

        print("Correcting the training set signals to the standard model")
        # Column 0 of the parameters is the scale, column 1 the shift.
        training_x_normed = (training_x * training_y[:, 0, np.newaxis] +
                             training_y[:, 1, np.newaxis])
        del training_x, signals

        print("Redistributing the expected outputs of the training set randomly")
        training_y_redist = np.vstack([
            random.normal(target_mean[0], target_stdev[0], len(training_y)),
            random.normal(target_mean[1], target_stdev[1], len(training_y))
        ]).T
        # Invert the newly drawn transform so that the redistributed
        # parameters map each signal back onto its corrected form.
        training_x_redist = ((training_x_normed - training_y_redist[:, 1, np.newaxis]) /
                             training_y_redist[:, 0, np.newaxis])
        del training_y, training_x_normed

        print("Saving parameters for recovery of the outputs")
        # Convert to built-in floats: the repr() of NumPy scalars depends on
        # the NumPy version (e.g. 'np.float64(1.0)'), whereas plain floats
        # always give a bare numeric literal.
        transform_params = {
            'scale_mean': float(target_mean[0]),
            'scale_std': float(target_stdev[0]),
            'shift_mean': float(target_mean[1]),
            'shift_std': float(target_stdev[1]),
        }
        with open(output.params_transform_info, 'w') as infofile:
            print(repr(transform_params), file=infofile)

        print("Balance the variances of output variables")
        training_y_final = np.vstack([
            (training_y_redist[:, 0] - target_mean[0]) / target_stdev[0],
            (training_y_redist[:, 1] - target_mean[1]) / target_stdev[1],
        ]).T
        testing_y_final = np.vstack([
            (testing_y[:, 0] - target_mean[0]) / target_stdev[0],
            (testing_y[:, 1] - target_mean[1]) / target_stdev[1],
        ]).T

        print("Saving the transformed values")
        np.save(output.signals_training, training_x_redist[:, :, np.newaxis])
        np.save(output.signals_testing, testing_x[:, :, np.newaxis])
        np.save(output.params_training, training_y_final)
        np.save(output.params_testing, testing_y_final)

# Development helper: save the first 10000 rows of any array under
# dataarrays/ as <name>.mini.npy.
rule make_mini_dataset: # for development
    input: 'dataarrays/{name}.npy'
    output: 'dataarrays/{name}.mini.npy'
    run:
        from numpy.lib.format import open_memmap
        import numpy as np

        # open_memmap() defaults to mode 'r+'; the source array is only read,
        # so map it read-only.
        source = open_memmap(input[0], mode='r')
        np.save(output[0], source[:10000])

