#!/usr/bin/env python

# Created on Fri Jun 29 14:22:41 2018
# Author: XiaoTao Wang

## Required modules

import argparse, sys, os, logging, logging.handlers, traceback, neoloop

# Version string of the installed neoloop package; reported by -v/--version.
currentVersion = neoloop.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    The output path (-O), the cooler URIs (-H) and the SV assembly file
    (--assembly) are declared as required, so an incomplete invocation is
    rejected by argparse right away instead of failing later in the workflow.
    When no argument is given at all, the help message is shown.

    Returns
    -------
    args : argparse.Namespace
        The parsed options.
    commands : list of str
        The raw argument list that was parsed (``['-h']`` when the command
        line was empty).
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Identify novel loop interactions across SV
                                     points.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    
    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Output
    parser.add_argument('-O', '--output', required=True, help='Path to the output file.')

    # Input
    parser.add_argument('-H', '--hic', nargs='+', required=True,
                        help='''List of cooler URIs. If URIs at multiple
                        resolutions are provided, the program will first detect neo-loops from
                        each individual resolution, and then combine results from all resolutions
                        in a non-redundant way. If an interaction is detected as a loop in multiple
                        resolutions, only the one with the highest resolution will be recorded.''')
    parser.add_argument('--assembly', required=True,
                        help='''The assembled SV list outputed by assemble-complexSVs.''')
    
    # Algorithm
    parser.add_argument('--cachefolder', default='.cache', help='''Path to a folder to place the pre-trained Peakachu models.
                        This command will automatically download appropriate models according to the
                        sequencing depths and resolutions of your input Hi-C matrices.''')
    parser.add_argument('-R', '--region-size', default=3000000, type=int,
                        help = '''The extended genomic span of SV break points.(bp)''')
    parser.add_argument('--balance-type', default='CNV', choices=['CNV', 'ICE'],
                        help = 'Normalization method.')
    parser.add_argument('--protocol', default='insitu', choices=['insitu', 'dilution'],
                        help='''Experimental protocol of your Hi-C.''')
    parser.add_argument('--prob', type=float, default=0.9,
                        help = 'Probability threshold.')
    parser.add_argument('--no-clustering', action = 'store_true',
                        help = 'No pooling will be performed if specified.')
    parser.add_argument('--min-marginal-peaks', type = int, default = 2,
                        help = '''Minimum marginal number of loops when detecting loop anchors.''')
    parser.add_argument('--nproc', default=1, type=int, help='Number of worker processes.')
    parser.add_argument('--logFile', default = 'neoloop.log', help = '''Logging file name.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # An empty command line is treated as a request for help.
        commands.append('-h')
    args = parser.parse_args(commands)
    
    return args, commands

def run():
    """Entry point of the neo-loop calling command.

    Parses the command line, configures the root logger (console plus a
    rotating log file), and then, for every input cooler resolution, computes
    the expected contact frequencies, runs ``pipeline`` on each assembled SV
    in parallel and merges the per-assembly results. The loops combined
    across resolutions are finally written to the output file.

    Any exception raised by the main workflow is logged together with its
    traceback and the process exits with status 1.
    """
    # Parse Arguments
    args, commands = getargs()
    # Skip the logging setup and the heavy imports below when the first
    # argument is only a help/version request.
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    ## Root Logger Configuration
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    console = logging.StreamHandler()
    filehandler = logging.handlers.RotatingFileHandler(args.logFile,
                                                       maxBytes=100000,
                                                       backupCount=5)
    # Set level for Handlers
    console.setLevel('INFO')
    filehandler.setLevel('INFO')
    # Customizing Formatter
    formatter = logging.Formatter(fmt = '%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')

    ## Unified Formatter
    console.setFormatter(formatter)
    filehandler.setFormatter(formatter)
    # Add Handlers
    logger.addHandler(console)
    logger.addHandler(filehandler)

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Output Path = {0}'.format(args.output),
               '# Cache Folder = {0}'.format(args.cachefolder),
               '# SV assembly = {0}'.format(args.assembly),
               '# Cooler List = {0}'.format(args.hic),
               '# Extended Genomic Span = {0}bp'.format(args.region_size),
               '# Balance Type = {0}'.format(args.balance_type),
               '# Probability threshold = {0}'.format(args.prob),
               '# No pooling = {0}'.format(args.no_clustering),
               '# Minimum marginal peaks = {0}'.format(args.min_marginal_peaks),
               '# Number of Processes = {0}'.format(args.nproc),
               '# Log file name = {0}'.format(args.logFile)
               ]
    argtxt = '\n'.join(arglist)
    logger.info('\n' + argtxt)

    # Imported lazily so that help/version requests stay fast.
    from joblib import Parallel, delayed
    from neoloop.callers import combine_annotations
    from neoloop.util import calculate_expected
    import cooler

    rsize = args.region_size
    prob, mmp = args.prob, args.min_marginal_peaks
    no_pool = args.no_clustering
    # Value handed to calculate_expected (balance=) and complexSV (col=);
    # presumably the name of the bias column in the cooler. argparse limits
    # --balance-type to CNV/ICE, so the False fallback is only a safeguard.
    balance = {'CNV': 'sweight', 'ICE': 'weight'}.get(args.balance_type, False)
    # NOTE(review): args.protocol is parsed by getargs() but never read here;
    # pipeline() hard-codes protocol='insitu' -- confirm this is intended.

    # Chromosome names containing any of these substrings are left out of
    # the expected-value calculation.
    excluded_tags = ('_', 'M', 'X', 'Y', 'MT', 'EBV')

    try:
        # load Hi-C matrix
        cools = {}
        for path in args.hic:
            res = cooler.Cooler(path).binsize
            if res in cools:
                # Only one URI per resolution can be kept; make the drop visible.
                logger.warning('Multiple coolers share the resolution %s: '
                               '%s is replaced by %s', res, cools[res], path)
            cools[res] = path
        # download pre-trained models
        cachefolder = os.path.abspath(os.path.expanduser(args.cachefolder))
        if not os.path.exists(cachefolder):
            os.makedirs(cachefolder)
        models_by_res = download_peakachu_models(cachefolder, cools)

        # load structural variations
        logger.info('Load assembled SVs')
        lines = {}
        with open(args.assembly, 'r') as source:
            for line in source:
                parse = line.rstrip().split()
                if not parse:
                    # blank line: nothing to index
                    continue
                # first field is the assembly ID, the rest is re-joined by tabs
                lines[parse[0]] = '\t'.join(parse[1:])

        byres = {}
        for res in sorted(cools, reverse=True):
            logger.info('Current resolution: {0}'.format(res))
            clr = cooler.Cooler(cools[res])
            logger.info('Calculate the global average contact frequencies at each genomic distance ...')
            chroms = [c for c in clr.chromnames
                      if not any(tag in c for tag in excluded_tags)]
            expected = calculate_expected(clr, chroms, maxdis=5000000, balance=balance, nproc=args.nproc)
            logger.info('Done')
            Params = [(clr, k, lines, rsize, balance, prob, mmp, no_pool, models_by_res, expected)
                      for k in lines]

            results = Parallel(n_jobs=args.nproc, verbose=10)(delayed(pipeline)(*i) for i in Params)

            logger.info('Merge loops across assemblies ...')
            cache = {}
            for tmp in results:
                if tmp is None:
                    # pipeline() returns None for an invalid assembly
                    continue
                for l in tmp:
                    cache.setdefault(l, []).append(tmp[l])
            byres[res] = cache

        if len(cools) > 1:
            logger.info('Combine loops from multiple resolutions')
        loop_list = combine_annotations(byres)
        logger.info('Output ...')
        with open(args.output, 'w') as out:
            for line in loop_list:
                out.write('\t'.join(line)+'\n')

        logger.info('Done')
    except Exception:
        # Route the traceback through the configured handlers so that it
        # reaches both the console and the log file, then signal failure.
        logger.exception('neo-loop calling failed')
        sys.exit(1)
    
def pipeline(clr, index, assemblies, rsize, balance, prob, mmp, nopool, models_by_res, expected):
    """Call loops on a single assembled SV.

    The assembly ``assemblies[index]`` is first rebuilt with its ends forced
    to extend by ``rsize`` and scanned with Peakachu; the calls are then
    restricted to those whose both anchors are found in the assembly rebuilt
    with the default (non-forced) settings.

    Returns
    -------
    dict or None
        Maps each retained loop key to ``(index, loops[key][0], loops[key][1])``.
        ``None`` is returned when the length of the assembly's ``Map`` differs
        from the size of its fusion matrix (invalid assembly).
    """
    from neoloop.callers import Peakachu
    from neoloop.assembly import complexSV
    import numpy as np

    # keyword arguments common to both complexSV constructions below
    sv_options = dict(span=rsize, col=balance, protocol='insitu',
                      expected_values=expected)

    # Extending the assembly ends by rsize is enforced here (flexible=False),
    # which helps to detect loops near the assembly boundaries.
    sv = complexSV(clr, assemblies[index], flexible=False, **sv_options)
    sv.reorganize_matrix()
    if len(sv.Map) != sv.fusion_matrix.shape[0]:
        return # invalid assembly

    sv.correct_heterozygosity()
    sv.remove_gaps()

    caller = Peakachu(sv.gap_free, upper=400*sv.res, res=sv.res, protocol='insitu')
    caller.load_expected(expected_values=expected)
    caller.load_models(models_by_res[sv.res])
    loop_index = caller.predict(thre=prob,
                                no_pool=nopool,
                                min_count=mmp,
                                index_map=sv.index_map,
                                chains=sv.chains)
    loops = sv.index_to_coordinates(loop_index)

    # Rebuild the assembly without the forced extension and keep only the
    # loops whose two anchors both lie inside it.
    sv = complexSV(clr, assemblies[index], **sv_options)
    sv.reorganize_matrix()

    return {
        k: (index, loops[k][0], loops[k][1])
        for k in loops
        if (k[0], k[1]) in sv.Map and (k[3], k[4]) in sv.Map
    }

def match_digital_values(arr, v):
    """Return the element of ``arr`` that is numerically closest to ``v``.

    When several elements are equally close, the one that comes first in
    ``arr`` is returned.
    """
    import numpy as np

    distances = np.abs(v - np.asarray(arr))
    nearest = int(distances.argmin())

    return arr[nearest]


def download_peakachu_models(cachefolder, cools):
    """Pick and fetch a pre-trained Peakachu model for every input resolution.

    For each resolution the closest resolution available in the packaged
    link table is selected, the sequencing depth of the cooler is estimated,
    and the model whose depth is closest to that estimate is downloaded into
    ``cachefolder`` with ``wget`` unless the file is already there.

    Parameters
    ----------
    cachefolder : str
        Folder that receives the model files.
    cools : dict
        Maps a resolution (bin size) to a cooler URI.

    Returns
    -------
    dict
        Maps each resolution in ``cools`` to the local path of its model file.

    Raises
    ------
    subprocess.CalledProcessError
        If ``wget`` exits with a non-zero status.
    OSError
        If ``wget`` cannot be executed.
    """
    import cooler, joblib, subprocess, time

    data_folder = os.path.join(os.path.split(neoloop.__file__)[0], 'data')
    # table shipped with the package: resolution -> depth -> download link
    links = joblib.load(os.path.join(data_folder, 'peakachu-model-links.pkl'))

    model_by_res = {}
    for res in cools:
        m_res = match_digital_values(list(links), res)
        # estimate sequencing depth from cool
        clr = cooler.Cooler(cools[res])
        if 'sum' in clr.info:
            seq_depth = clr.info['sum']
        else:
            if clr.info['nnz'] < 100000000:
                # few enough pixels to sum the counts directly
                seq_depth = clr.pixels()[:]['count'].values.sum()
            else:
                # presumably an empirical contacts-per-pixel ratio -- TODO confirm
                seq_depth = clr.info['nnz'] * 2.14447683

        genome_size = clr.chromsizes.sum()
        # rescale to a 3031042417 bp genome (presumably the human reference size)
        seq_depth = 3031042417 / genome_size * seq_depth
        # rescale from the input resolution to the matched model resolution
        seq_depth = res / m_res * seq_depth
        # presumably the expected fraction of intra-chromosomal contacts -- TODO confirm
        intra_depth = seq_depth * 0.73256754

        m_depth = match_digital_values(list(links[m_res]), intra_depth)
        weblink = links[m_res][m_depth]
        ofil = os.path.join(cachefolder, 'peakachu_v2.res{0}.depth{1}.pkl'.format(m_res, m_depth))
        if not os.path.exists(ofil):
            # Download to a temporary name and rename on success, so that a
            # failed or interrupted transfer never leaves a truncated file
            # that a later run would mistake for a cached model.
            partial = ofil + '.part'
            try:
                # Argument list without a shell: paths containing spaces and
                # URLs containing shell metacharacters are passed verbatim.
                subprocess.check_call(['wget', '-O', partial, '-L', weblink])
            except BaseException:
                if os.path.exists(partial):
                    os.remove(partial)
                raise
            os.replace(partial, ofil)
            time.sleep(3)
        model_by_res[res] = ofil
    
    return model_by_res

# Start the command-line workflow only when this file is executed as a script.
if __name__ == '__main__':
    run()