#!/usr/bin/env python

# Command-line script for running disorder-domain prediction in parallel (one worker process per core)

'''
Specifically for the 'predict_disorder_domains'
functionality.

Jeff wanted disorder predictions to be faster.
I made this for Jeff.
You're welcome, Jeff.
'''


# import stuff for making CLI
import os
import argparse
import protfasta
import multiprocessing

# for the predictions...
from metapredict.backend.meta_tools import split_fasta
import metapredict as meta

if __name__ == "__main__":
    # Command-line entry point. Reads a FASTA file, splits the sequences into
    # batches, and hands each batch to its own worker process running
    # meta.predict_for_multithread, all targeting the same output file.

    # Parse command line arguments.
    parser = argparse.ArgumentParser(description='Predict IDRs for all sequences in a FASTA file. Uses multiple cores for speeeeeeed!')

    parser.add_argument('data_file', help='Path to fasta file containing sequences to be predicted.')

    parser.add_argument('number_cores', help='Number of cores to run simultaneously. Do not specify more than you have..', type=int)

    parser.add_argument('-o', '--output-file', help='Filename for where to save the outputfile. Default = idrs_shephard.tsv', default='idrs_shephard.tsv')

    parser.add_argument('-l', '--legacy', action='store_true', help='Optional. Use this flag to use the original legacy version of metapredict.')

    parser.add_argument('--invalid-sequence-action', help="For parsing FASTA file, defines how to deal with non-standard amino acids. See https://protfasta.readthedocs.io/en/latest/read_fasta.html for details. Default='convert' ", default='convert')

    parser.add_argument('--mode', help='Defines the mode in which IDRs are reported. By default this generates a FASTA file with header format that matches  the input file with an additional set of fields that are "IDR_START=$START   IDR_END=$END" where $START and $END are the starting and ending IDRs. If mode is set to shephard-domains than a SHEPHAD-compliant domains file is generated. If shephard-domains-uniprot the uniprot ID is extracted from the header assuming standard uniprot formatting. Default = fasta', default='fasta')

    parser.add_argument('--threshold', help='Defines the threshold used to define a region as disordered or not. Default=0.42 which is recommended.', default=0.42, type=float)
    parser.add_argument('--verbose', help='If included then prints out status updates', action='store_true')

    args = parser.parse_args()

    if args.mode not in ['fasta', 'shephard-domains','shephard-domains-uniprot', ]:
        raise Exception("--mode must be set to one of 'fasta', 'shephard-domains', or 'shephard-domains-uniprot'")

    # At least one worker is needed; reject 0/negative values up front
    # instead of letting them reach the batching code.
    if args.number_cores < 1:
        parser.error('number_cores must be a positive integer')

    # store_true guarantees args.legacy is a bool.
    use_legacy = args.legacy

    if use_legacy:
        # The argparse attribute for --threshold is 'threshold'
        # (there is no 'threshold_val' attribute on args).
        threshold_val = args.threshold
    elif args.threshold == 0.42:
        # Not using legacy and the legacy default is still in place:
        # adjust it to 0.5.
        # NOTE(review): an explicit '--threshold 0.42' is indistinguishable
        # from the default here and is also replaced by 0.5.
        threshold_val = 0.5
    else:
        # The user set their own threshold value that isn't 0.42, keep it.
        threshold_val = args.threshold

    # Stop here if the input is missing rather than failing later while
    # trying to parse a file that is not there.
    if not os.path.isfile(args.data_file):
        print(f'Error: Could not find passed fasta file [{args.data_file:s}]')
        raise SystemExit(1)

    # Read in sequences once; return_list=True is requested so the result can
    # be split into per-process batches.
    protfasta_seqs = protfasta.read_fasta(args.data_file,
        invalid_sequence_action=args.invalid_sequence_action, return_list=True)
    if args.verbose:
        print('Read in FASTA file')

    num_cores_used = args.number_cores

    # Create (or truncate) the output file before any worker starts.
    # NOTE(review): presumably the workers append to this file - confirm
    # against meta.predict_for_multithread.
    with open(args.output_file, 'w'):
        pass

    # Split into one batch of sequences per requested core.
    split_fasta_seqs = split_fasta(protfasta_seqs, number_splits=num_cores_used)

    # Never index past the batches that were actually returned (guards against
    # an IndexError should split_fasta produce fewer batches than cores).
    num_batches = min(num_cores_used, len(split_fasta_seqs))

    # Build one worker process per batch.
    cur_threads = []
    for sub_batch in range(num_batches):
        cur_seq_batch = split_fasta_seqs[sub_batch]
        worker = multiprocessing.Process(
            target=meta.predict_for_multithread,
            args=(cur_seq_batch, args.output_file, args.mode, threshold_val, use_legacy),
        )
        cur_threads.append(worker)

    # Start every worker before joining any, so they run concurrently.
    for thr in cur_threads:
        thr.start()

    # Wait for all workers to finish.
    for thr in cur_threads:
        thr.join()

    # A non-zero exit code means a worker died; surface that to the caller
    # instead of exiting successfully with a possibly incomplete output file.
    failed = [thr for thr in cur_threads if thr.exitcode != 0]
    if failed:
        print(f'Error: {len(failed)} of {len(cur_threads)} prediction processes exited abnormally; output may be incomplete.')
        raise SystemExit(1)
    
                
        

