#!/usr/bin/env python

# Created on Fri Jun 29 14:22:41 2018
# Author: XiaoTao Wang

## Required modules

import argparse, sys, os, logging, logging.handlers, traceback, neoloop

# Version string of the installed neoloop package; shown by the -v/--version option
currentVersion = neoloop.__version__

def getargs():
    """Build the command-line parser and parse the arguments in ``sys.argv``.

    When the program is started without any argument, ``-h`` is appended to
    the argument list so that the help text is printed.

    The output path, the cooler URI and the assembly file are consumed
    unconditionally by ``run``, so they are declared as required here: a
    missing one is reported by argparse immediately instead of surfacing as an
    error after the computation has started.

    Returns
    -------
    args : argparse.Namespace
        The parsed arguments.
    commands : list of str
        The raw argument strings that were parsed (``['-h']`` if none were
        given on the command line).
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Identify neo-TADs.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Output
    parser.add_argument('-O', '--output', required=True, help='Output path.')

    # Input
    parser.add_argument('-H', '--hic', required=True, help='''Cooler URI.''')
    parser.add_argument('--assembly', required=True,
                        help='''The assembled SV list outputed by assemble-complexSVs.''')

    # Algorithm
    parser.add_argument('-R', '--region-size', default=3000000, type=int,
                        help = '''The extended genomic span of SV break points.(bp)''')
    parser.add_argument('--balance-type', default='CNV', choices=['CNV', 'ICE'],
                        help = 'Normalization method.')
    parser.add_argument('--protocol', default='insitu', choices=['insitu', 'dilution'],
                        help='''Experimental protocol of your Hi-C.''')
    parser.add_argument('--window-size', default=2000000, type=int,
                        help='Window size for calculating DI.')
    parser.add_argument('--nproc', default=1, type=int, help='Number of worker processes.')
    parser.add_argument('--logFile', default = 'neotad.log', help = '''Logging file name.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No argument at all: behave as if help had been requested
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands

def run():
    """Command-line entry point: call TADs on every SV assembly and write the merged result.

    Steps performed:

    1. Parse the command line and configure the root logger (console at INFO,
       rotating log file at DEBUG).
    2. Load the pre-trained HMM model shipped in the ``data`` folder of the
       neoloop package and open the cooler given by ``--hic``.
    3. Read the assembly file: the first whitespace-separated field of each
       line is the assembly ID, the remaining fields are re-joined with tabs.
       Blank lines are skipped.
    4. Run ``pipeline`` for every assembly through joblib.
    5. Merge the per-assembly dictionaries and write one line per TAD key:
       the key fields followed by a comma-joined list of
       ``assemblyID,value,value`` labels of the assemblies reporting it.
       Labels are sorted so the output is reproducible between runs.

    If any of steps 2-5 raises an exception, the traceback is appended to the
    log file and the process exits with status 1.
    """
    # Parse Arguments
    args, commands = getargs()
    # Nothing to set up when only the help or version text was requested
    if commands[0] in ('-h', '-v', '--help', '--version'):
        return

    ## Root Logger Configuration
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    console = logging.StreamHandler()
    filehandler = logging.handlers.RotatingFileHandler(args.logFile,
                                                       maxBytes=100000,
                                                       backupCount=5)
    # Set level for Handlers
    console.setLevel('INFO')
    filehandler.setLevel('DEBUG')
    # Customizing Formatter
    formatter = logging.Formatter(fmt = '%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')

    ## Unified Formatter
    console.setFormatter(formatter)
    filehandler.setFormatter(formatter)
    # Add Handlers
    logger.addHandler(console)
    logger.addHandler(filehandler)

    # Name of the weight column handed to the pipeline for the chosen normalization
    balance = 'sweight' if args.balance_type == 'CNV' else 'weight'

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Output Path = {0}'.format(args.output),
               '# SV assembly = {0}'.format(args.assembly),
               '# Cooler URI = {0}'.format(args.hic),
               '# Extended Genomic Span = {0}bp'.format(args.region_size),
               '# Weight column = {0}'.format(balance),
               '# Experimental protocol = {0}'.format(args.protocol),
               '# Window size = {0}'.format(args.window_size),
               '# Number of Processes = {0}'.format(args.nproc),
               '# Log file name = {0}'.format(args.logFile)
               ]
    argtxt = '\n'.join(arglist)
    logger.info('\n' + argtxt)

    # Heavy dependencies are imported late so that -h / -v stay fast
    from joblib import Parallel, delayed
    # Genome is only referenced by the disabled training path described below
    from neoloop.tadtool.hmm import Genome
    import cooler, joblib

    rsize = args.region_size
    protocol = args.protocol

    try:
        # Training a model from the input map is currently disabled in favour
        # of the packaged model. The disabled code was:
        #     logger.info('Training HMM using reference genome ...')
        #     G = Genome(args.hic, balance_type=balance, cpu_core=args.nproc)
        #     model = G.model
        #     joblib.dump(model, 'HMM-model.pkl')
        logger.info('Load the pre-trained HMM model')
        hmm_folder = os.path.join(os.path.split(neoloop.__file__)[0], 'data')
        model = joblib.load(os.path.join(hmm_folder, 'HMM-model.pkl'))

        # load Hi-C matrix
        clr = cooler.Cooler(args.hic)

        # load structural variations
        lines = {}
        with open(args.assembly, 'r') as source:
            for line in source:
                parse = line.rstrip().split()
                if not parse:
                    # blank line: no assembly ID to index by
                    continue
                lines[parse[0]] = '\t'.join(parse[1:])

        logger.info('Detect TADs on the assemblies ...')
        Params = [(clr, k, lines, rsize, balance, protocol, args.window_size, model)
                  for k in lines]

        results = Parallel(n_jobs=args.nproc, verbose=10)(delayed(pipeline)(*i) for i in Params)

        logger.info('Merge TADs across assemblies ...')
        cache = {}
        for tmp in results:
            if tmp is None:
                # pipeline() returns None for assemblies it rejected
                continue
            for l in tmp:
                cache.setdefault(l, []).append(tmp[l])

        logger.info('Output ...')
        with open(args.output, 'w') as out:
            for k in sorted(cache):
                # De-duplicate the origins, then sort the rendered labels so
                # the column does not depend on set iteration order
                labels = sorted(','.join([t[0], str(t[1]), str(t[2])])
                                for t in set(cache[k]))
                label = ','.join(labels)
                line = list(map(str, k))+[label]
                out.write('\t'.join(line)+'\n')

        logger.info('Done')

    except Exception:
        # Keep the traceback in the log file; KeyboardInterrupt and SystemExit
        # are deliberately not intercepted
        with open(args.logFile, 'a') as log:
            traceback.print_exc(file=log)
        sys.exit(1)
    
def pipeline(clr, index, assemblies, rsize, balance, protocol, window_size, model):
    """Call TADs on a single SV assembly.

    Parameters
    ----------
    clr
        Opened Hi-C matrix object; its ``binsize`` is passed to the TAD caller.
    index
        Key of the assembly to process in ``assemblies``.
    assemblies
        Mapping from assembly key to its description string.
    rsize
        Forwarded to ``complexSV`` as ``span``.
    balance
        Forwarded to ``complexSV`` as ``col``.
    protocol
        Forwarded to ``complexSV`` as ``protocol``.
    window_size
        Forwarded to ``TADcaller``.
    model
        Forwarded to ``TADcaller``.

    Returns
    -------
    dict or None
        ``{tad: (index, first, second)}`` for every entry produced by
        ``index_to_coordinates``, where ``first``/``second`` are the first two
        items of that entry's value. ``None`` when the assembly is rejected
        (invalid assembly, or gap-free matrix with fewer than 20 rows).
    """
    from neoloop.tadtool.core import TADcaller
    from neoloop.assembly import complexSV
    import numpy as np

    sv = complexSV(
        clr,
        assemblies[index],
        span=rsize,
        col=balance,
        protocol=protocol,
        flexible=True,
    )
    sv.reorganize_matrix()

    # The map and the fused matrix must agree in size, otherwise the
    # assembly is considered invalid
    n_mapped = len(sv.Map)
    if n_mapped != sv.fusion_matrix.shape[0]:
        return None

    sv.correct_heterozygosity()
    sv.remove_gaps()

    # Too little data left once the gaps are removed
    if sv.gap_free.shape[0] < 20:
        return None

    caller = TADcaller(sv.gap_free, clr.binsize, model, window_size)
    caller.callDomains()
    domain_indices = caller.loop_like()

    # Translate matrix indices to coordinates and tag each TAD with the
    # assembly it was found on
    coords = sv.index_to_coordinates(domain_indices)
    return {tad: (index, coords[tad][0], coords[tad][1]) for tad in coords}
    

# Start the command-line workflow only when this file is executed as a script
if __name__ == '__main__':
    run()