#!/usr/bin/env python

# Created on Tue Nov 12 19:05:11 2019
# Author: XiaoTao Wang

## Required modules

import argparse, sys, logging, logging.handlers, traceback, neoloop

# Version string of the neoloop package; shown by the -v/--version option.
currentVersion = neoloop.__version__


def getargs():
    """Build the command-line parser and parse ``sys.argv[1:]``.

    When the program is started without any argument, ``-h`` is appended to
    the argument list so that argparse prints the help message and exits.

    Returns
    -------
    args : argparse.Namespace
        The parsed option values.
    commands : list of str
        The raw argument list that was handed to the parser.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Remove copy number variation effect from cancer
                                     Hi-C.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Input
    # The cool URI is opened unconditionally by run(), so a missing value is
    # rejected here with a usage error instead of failing later on None.
    parser.add_argument('-H', '--hic', required=True, help='''Cool URI.''')

    # Algorithm
    parser.add_argument('--cnv-file',
                        help='''Copy number segmentation in bedGraph format.''')
    parser.add_argument('--mad-max', default=5, type=float, help='''Ignore bins from the
                        contact matrix using the 'MAD-max' filter. Increase the value of this
                        parameter to include more bins to be considered in matrix balancing.
                        Setting this parameter to 0 will ignore the 'MAD-max' filter.''')
    parser.add_argument('--min-nnz', default=10, type=int, help='''Ignore bins from the
                        contact matrix whose marginal number of nonzeros is less than
                        this number.''')
    parser.add_argument('--ignore-diags', default=1, type=int, help='''Number of diagonals
                        of the contact matrix to ignore. For example, 0 ignores nothing,
                        1 ignores the main diagonal, and 2 ignores diagonals (-1, 0, 1).''')
    parser.add_argument('--nproc', default=1, type=int, help='Number of worker processes.')
    parser.add_argument('-f', '--force', action = 'store_true',
                        help = '''When specified, overwrite the target bias vector, "sweight",
                        if it already exists.''')
    parser.add_argument('--logFile', default = 'cnv-norm.log', help = '''Logging file name.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No arguments at all: behave as if the user had asked for help.
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands


def run():
    """Command-line entry point.

    Parses the arguments, configures the root logger (console plus a rotating
    log file), records the argument values, and then, unless the cooler
    already has a "sweight" bins column and ``--force`` was not given,
    matches the CNV segmentation to the matrix bins and runs the CNV-aware
    matrix balancing.

    Any ``Exception`` raised during the computation has its traceback
    appended to the log file and the process exits with status 1.
    ``KeyboardInterrupt`` and ``SystemExit`` are not intercepted.
    """
    # Parse Arguments
    args, commands = getargs()
    # Skip the logging setup and heavy imports for help/version requests
    # (argparse normally exits by itself for these options).
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    ## Root Logger Configuration
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    console = logging.StreamHandler()
    filehandler = logging.handlers.RotatingFileHandler(args.logFile,
                                                       maxBytes=100000,
                                                       backupCount=5)
    # Set level for Handlers
    console.setLevel('INFO')
    filehandler.setLevel('INFO')
    # Customizing Formatter
    formatter = logging.Formatter(fmt = '%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')

    ## Unified Formatter
    console.setFormatter(formatter)
    filehandler.setFormatter(formatter)
    # Add Handlers
    logger.addHandler(console)
    logger.addHandler(filehandler)

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Cooler URI = {0}'.format(args.hic),
               '# CNV Profile = {0}'.format(args.cnv_file),
               '# MAD-MAX = {0}'.format(args.mad_max),
               '# Minimum Nonzeros = {0}'.format(args.min_nnz),
               '# Number of Diagonals to Ignore = {0}'.format(args.ignore_diags),
               '# Number of Processes = {0}'.format(args.nproc),
               '# Log file name = {0}'.format(args.logFile)
               ]
    argtxt = '\n'.join(arglist)
    logger.info('\n' + argtxt)

    # Imported lazily so that help/version requests stay fast.
    from neoloop.cnv.loadcnv import binCNV
    from neoloop.cnv.correctcnv import matrix_balance
    import cooler

    try:
        hic_pool = cooler.Cooler(args.hic)
        if ('sweight' not in hic_pool.bins()) or args.force:
            logger.info('Match CNV segmentation to matrix bins')
            bincnv = binCNV(args.cnv_file, hic_pool.binsize)
            bincnv.assign_cnv(args.hic)
            logger.info('Perform CNV-separate matrix balancing ...')
            matrix_balance(args.hic, nproc=args.nproc, mad_max=args.mad_max, min_nnz=args.min_nnz,
                           ignore_diags=args.ignore_diags)
        else:
            logger.info('sweight column already exists, skip')

        logger.info('Done')
    except Exception:
        # Append the traceback to the log file; the context manager makes
        # sure the handle is flushed and closed before the process exits.
        with open(args.logFile, 'a') as log_stream:
            traceback.print_exc(file=log_stream)
        sys.exit(1)

# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    run()