#!/usr/bin/env python

# Created on Tue Nov 12 19:05:11 2019
# Author: XiaoTao Wang

## Required modules

import argparse, sys, logging, logging.handlers, traceback, neoloop

# Version of the neoloop package; used to build the message printed by the
# -v/--version command-line option.
currentVersion = neoloop.__version__


def getargs():
    """Build the command-line interface and parse ``sys.argv``.

    If the program is invoked without any argument, ``-h`` is substituted so
    that the help message is displayed.

    Returns
    -------
    tuple
        ``(args, commands)`` where ``args`` is the parsed
        ``argparse.Namespace`` and ``commands`` is the list of raw argument
        strings that was handed to the parser.
    """
    cli = argparse.ArgumentParser(
        description='''Perform HMM segmentation on a pre-calculated
                                     copy number variation profile.''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Each entry pairs the option strings with the keyword arguments passed to
    # ``add_argument``. Options are registered in list order, which is also
    # the order they appear in the help message. Continuation lines inside the
    # multi-line help strings keep their historical indentation so that the
    # literal text is unchanged.
    option_table = [
        (('-v', '--version'),
         dict(action='version',
              version=' '.join(['%(prog)s', currentVersion]),
              help='Print version number and exit.')),
        (('--cnv-file',),
         dict(help='''Copy number profile in bedGraph format.''')),
        (('--ploidy',),
         dict(type=int, default=2, help='Ploidy of the cell.')),
        (('--binsize',),
         dict(type=int, help='Bin size in base pairs.')),
        (('--num-of-states',),
         dict(type=int,
              help='''The number of states to initialize. By default, this number
                        is estimated using a Gaussian Mixture model and the AIC criterion.''')),
        (('--cbs-pvalue',),
         dict(type=float, default=1e-5,
              help='''The P-value cutoff used in the circular binary segmentation (CBS) algorithm.''')),
        (('--max-dist',),
         dict(type=int, default=4,
              help='''Maximum allowed distance between HMM and CBS breakpoints.''')),
        (('--min-diff',),
         dict(type=float, default=0.4,
              help='''Minimum copy number difference between neighboring segments.
                        Two neighboring segments will be merged together if the difference of
                        their copy number ratio is less than this value.''')),
        (('--min-segment-size',),
         dict(type=int, default=3,
              help='''Minimum segment size. Small segments with a size less than
                        this number of bins will be merged with nearby larger segments.''')),
        (('--output',),
         dict(help='''Output file path.''')),
        (('--nproc',),
         dict(default=1, type=int, help='Number of worker processes.')),
        (('--logFile',),
         dict(default='cnv-seg.log', help='''Logging file name.''')),
    ]
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    # Fall back to the help screen when the user supplied nothing at all.
    raw_args = sys.argv[1:] or ['-h']
    parsed = cli.parse_args(raw_args)

    return parsed, raw_args


def run():
    """Command-line entry point: segment a CNV profile and write the result.

    Steps performed:

    1. Parse the command line with ``getargs``.
    2. Configure the root logger with a console handler and a rotating file
       handler writing to ``args.logFile``.
    3. Log the argument settings.
    4. Build an ``HMMsegment`` object from the CNV profile, call its
       ``segment`` method and then its ``output`` method.

    If any exception is raised during step 4, the traceback is appended to
    the log file and the process exits with status 1. ``KeyboardInterrupt``
    and ``SystemExit`` are not intercepted and propagate normally.
    """
    # Parse Arguments
    args, commands = getargs()

    # Nothing to set up when only the help or version message was requested.
    if commands[0] in ('-h', '-v', '--help', '--version'):
        return

    ## Root Logger Configuration
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    console = logging.StreamHandler()
    filehandler = logging.handlers.RotatingFileHandler(args.logFile,
                                                       maxBytes=100000,
                                                       backupCount=5)
    # Both handlers only emit records at INFO level or above
    console.setLevel('INFO')
    filehandler.setLevel('INFO')

    ## Unified Formatter
    formatter = logging.Formatter(fmt='%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt='%m/%d/%y %H:%M:%S')
    console.setFormatter(formatter)
    filehandler.setFormatter(formatter)

    # Add Handlers
    logger.addHandler(console)
    logger.addHandler(filehandler)

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# CNV Profile = {0}'.format(args.cnv_file),
               '# Number of States = {0}'.format(args.num_of_states),
               '# CBS P-value Cutoff = {0}'.format(args.cbs_pvalue),
               '# Maximum HMM-CBS distance = {0}'.format(args.max_dist),
               '# Minimum Copy Number difference = {0}'.format(args.min_diff),
               '# Minimum Segment Size = {0}'.format(args.min_segment_size),
               '# Ploidy = {0}'.format(args.ploidy),
               '# Bin Size = {0}'.format(args.binsize),
               '# Output Path = {0}'.format(args.output),
               '# Number of Processes = {0}'.format(args.nproc),
               '# Log file name = {0}'.format(args.logFile)
               ]
    argtxt = '\n'.join(arglist)
    logger.info('\n' + argtxt)

    # Imported here rather than at module level so that help/version requests
    # do not pay for loading the segmentation module.
    from neoloop.cnv.segcnv import HMMsegment

    try:
        logger.info('Loading CNV profile ...')
        work = HMMsegment(args.cnv_file,
                          res=args.binsize,
                          nproc=args.nproc,
                          ploidy=args.ploidy,
                          n_states=args.num_of_states)

        work.segment(
            min_seg=args.min_segment_size,
            min_diff=args.min_diff,
            max_dist=args.max_dist,
            p=args.cbs_pvalue
        )

        logger.info('Writing results into bedGraph ...')
        work.output(args.output)

        logger.info('Done')
    except Exception:
        # Append the traceback to the log file. The context manager guarantees
        # the file handle is flushed and closed before the process exits.
        with open(args.logFile, 'a') as log_stream:
            traceback.print_exc(file=log_stream)
        sys.exit(1)

# Execute the command-line entry point only when this file is run as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    run()