#!/usr/bin/env python

# Created on Tue Sep 8 20:45:18 2020
# Author: XiaoTao Wang

## Required modules

import argparse, sys, logging, logging.handlers, traceback, neoloop, os, subprocess

# Version string of the imported neoloop package; reported by the -v/--version option.
currentVersion = neoloop.__version__


def getargs():
    """Parse the command-line arguments of the CNV calculation script.

    When the script is invoked without any argument, ``-h`` is injected so
    that the help message is printed and the program exits.

    ``--hic`` and ``--output`` are mandatory: the workflow cannot run without
    an input matrix and an output path, so a missing value is reported by
    argparse right away instead of surfacing later as a traceback.

    Returns
    -------
    args : argparse.Namespace
        The parsed arguments.
    commands : list of str
        The raw argument list that was parsed (``sys.argv[1:]``, or
        ``['-h']`` when no argument was supplied).
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Calculate the copy number variation profile
                                     from Hi-C map using a generalized additive model with the
                                     Poisson link function''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')

    # Input / output. Both are required; -h and -v still work on their own
    # because their actions exit before the required-argument check.
    parser.add_argument('-H', '--hic', required=True, help='''Cool URI.''')
    parser.add_argument('-g', '--genome', default='hg38', choices = ['hg19', 'hg38', 'mm9', 'mm10'],
                       help='''Reference genome name.''')
    parser.add_argument('-e', '--enzyme', default='MboI', choices = ['HindIII', 'MboI', 'DpnII', 'Arima', 'BglII', 'uniform'],
                        help='''The restriction enzyme name used in the Hi-C experiment. If the genome
                        was cutted using a sequence-independent/uniform-cutting enzyme, please consider
                        setting this parameter to "uniform".''')
    parser.add_argument('--output', required=True, help='''Output file path.''')

    parser.add_argument('--cachefolder', default='.cache', help='''Cache folder. We suggest keep this setting the same
                        between different runs.''')
    parser.add_argument('--logFile', default = 'cnv-calculation.log', help = '''Logging file name.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No argument at all: show the help message instead of failing.
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands


# Remote folder hosting the auxiliary files (mappability / GC tracks and
# restriction-site tables) needed by the CNV model.
_RESOURCE_BASE_URL = 'http://3dgenome.fsm.northwestern.edu/neoLoopFinder/'

# Cache file names that are fetched from a remote file with a different name.
# For hg19 and hg38 the DpnII table is downloaded from the MboI URL
# (presumably because both enzymes share the same cut sites -- confirm);
# mm9 and mm10 have dedicated DpnII files on the server.
_RESOURCE_ALIASES = {
    'hg38.DpnII.npz': 'hg38.MboI.npz',
    'hg19.DpnII.npz': 'hg19.MboI.npz',
}


def _configure_logging(logfile):
    """Attach a console handler and a rotating file handler to the root logger.

    Both handlers emit records at INFO level and above with a shared format.
    Returns the configured root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter(fmt = '%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')
    handlers = [
        logging.StreamHandler(),
        logging.handlers.RotatingFileHandler(logfile, maxBytes=100000, backupCount=5),
    ]
    for handler in handlers:
        handler.setLevel('INFO')
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


def _download_resource(path, logger):
    """Download the auxiliary file that belongs at ``path`` in the cache folder.

    The data is written to a temporary ``.part`` file and moved into place
    only after wget succeeded, so an interrupted or failed download never
    leaves a truncated file that a later run would mistake for a valid cache
    entry. Any error raised by the download is propagated to the caller.
    """
    key = os.path.basename(path)
    logger.info('{0} was not found in the cache folder, downloading ...'.format(key))
    url = _RESOURCE_BASE_URL + _RESOURCE_ALIASES.get(key, key)
    partial = path + '.part'
    try:
        # Argument list without a shell: paths containing spaces or shell
        # metacharacters are passed to wget verbatim.
        subprocess.check_call(['wget', '-O', partial, '-L', url])
        os.replace(partial, path)
    finally:
        # Still present only when the download or the move failed.
        if os.path.exists(partial):
            os.remove(partial)


def run():
    """Entry point: estimate a CNV profile from a Hi-C matrix.

    Steps performed:

    1. parse the command line and configure logging;
    2. make sure the mappability track, the GC-content track and the
       restriction-site table for the chosen genome/enzyme are in the cache
       folder, downloading the missing ones;
    3. build a per-bin table (coverage, GC, mappability, number of cut sites)
       with the helpers of ``neoloop.cnv.runcnv``;
    4. fit ``Coverage ~ s(GC) + s(Mappability) + s(RE)`` with R's
       ``mgcv.gam`` (Poisson family, log link) through rpy2;
    5. write the working residuals, shifted so that their minimum is zero,
       as a header-less tab-separated ``chrom start end CNV`` file. Bins
       excluded by the filtering step get a value of 0.

    If step 3-5 fails, the traceback is appended to the log file and the
    process exits with status 1. A failed download propagates its exception.
    """
    # Parse Arguments
    args, commands = getargs()
    # Nothing to do if only the help message or the version was requested
    if commands[0] in ('-h', '-v', '--help', '--version'):
        return

    logger = _configure_logging(args.logFile)

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Cool URI = {0}'.format(args.hic),
               '# Reference Genome = {0}'.format(args.genome),
               '# Enzyme = {0}'.format(args.enzyme),
               '# Output Path = {0}'.format(args.output),
               '# Cache Folder = {0}'.format(args.cachefolder),
               '# Log file name = {0}'.format(args.logFile)
               ]
    argtxt = '\n'.join(arglist)
    logger.info('\n' + argtxt)

    # Heavy dependencies are imported lazily so that -h / -v stay fast.
    from neoloop.cnv import runcnv
    import numpy as np
    from rpy2.robjects import numpy2ri, Formula
    from rpy2.robjects.packages import importr

    cachefolder = os.path.abspath(os.path.expanduser(args.cachefolder))
    os.makedirs(cachefolder, exist_ok=True)

    mapscore_fil = os.path.join(cachefolder, '{0}_mappability_100mer.1kb.bw'.format(args.genome))
    cutsites_fil = os.path.join(cachefolder, '{0}.{1}.npz'.format(args.genome, args.enzyme))
    gc_fil = os.path.join(cachefolder, '{0}_1kb_GC.bw'.format(args.genome))
    for fil in (mapscore_fil, cutsites_fil, gc_fil):
        if not os.path.exists(fil):
            _download_resource(fil, logger)

    try:
        mgcv = importr('mgcv')
        stats = importr('stats')

        logger.info('Calculate the 1D coverage from Hi-C matrix ...')
        table, res = runcnv.get_marginals(args.hic)
        logger.info('Load GC content ...')
        table = runcnv.signal_from_bigwig(table, gc_fil, name='GC')
        logger.info('Load mappability scores ...')
        table = runcnv.signal_from_bigwig(table, mapscore_fil, name='Mappability')
        logger.info('Count the number of cut sizes for each bin ...')
        table = runcnv.count_REsites(table, cutsites_fil, res)
        logger.info('Filter out invalid bins ...')
        # mask is used below to locate the retained rows of `table`;
        # filtered holds the corresponding per-bin values.
        mask, filtered = runcnv.filterZeros(table)
        logger.info('Done')

        logger.info('Fitting a Generalized Additive Model with log link and Poisson distribution ...')
        formula = Formula('Coverage ~ s(GC) + s(Mappability) + s(RE)')
        # Expose each model variable to R under the name used in the formula
        for column in ('Coverage', 'GC', 'Mappability', 'RE'):
            formula.environment[column] = numpy2ri.numpy2rpy(filtered[column].values)
        gam = mgcv.gam(formula, family=stats.poisson(link='log'))

        logger.info('Output residuals ...')
        rs = mgcv.residuals_gam(gam, type='working')
        residuals = numpy2ri.rpy2py(rs)
        # Shift the residuals so that the smallest reported value is zero
        residuals = residuals - residuals.min()

        # Scatter the residuals back to the full bin table; bins that were
        # filtered out keep the initial value of 0.
        idx = np.where(mask)[0]
        CNV = np.zeros(table.shape[0])
        CNV[idx] = residuals
        table['CNV'] = CNV
        bedgraph = table[['chrom', 'start', 'end', 'CNV']]
        bedgraph.to_csv(args.output, sep='\t', header=False, index=False)

        logger.info('Done')
    except Exception:
        # Keep the full traceback in the log file (closing the handle
        # afterwards) and tell the user where to look before exiting.
        with open(args.logFile, 'a') as handle:
            traceback.print_exc(file=handle)
        logger.error('CNV calculation failed, the traceback was appended to {0}'.format(args.logFile))
        sys.exit(1)

# Execute the command-line workflow only when this file is run as a script.
if __name__ == '__main__':
    run()