#!/usr/bin/env python

import os
import sys
import logging
import argparse
from multiprocessing import Pool
import pandas as pd
import anndata as ad
import beret

# Send all log records of level INFO and above to stderr.
# NOTE: the former ``filemode="w"`` argument was dropped; basicConfig only
# honours ``filemode`` together with ``filename`` and ignores it when a
# ``stream`` is given, so it had no effect here.
logging.basicConfig(level=logging.INFO,
                     format='%(levelname)-5s @ %(asctime)s:\n\t %(message)s \n',
                     datefmt='%a, %d %b %Y %H:%M:%S',
                     stream=sys.stderr,
                     )

# Short module-level aliases for the root-logger helpers used in this script.
error   = logging.critical
warn    = logging.warning
debug   = logging.debug
info    = logging.info


def _str2bool(value):
    """Convert a command-line token such as ``true``/``no``/``1`` to a bool.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value (true/false), got {!r}'.format(value))


def get_input_parser():
    """Print the beretCount banner and build the command-line parser.

    The banner is written to stdout as a side effect of calling this
    function. The returned parser has not parsed anything yet; call
    ``parse_args`` on it.

    Returns:
        argparse.ArgumentParser: parser configured with all beretCount options.
    """
    print('  \n~~~beretCount~~~')
    print('-Utility to perform sgRNA and reporter count from CRISPR base editors-')
    print(r'''
          )                                             )
         (           ________________________          (
        __)__       | __   __            ___ |        __)__
     C\|     \      |/  ` /  \ |  | |\ |  |  |     C\|     \
       \     /      |\__, \__/ \__/ | \|  |  |       \     /
        \___/       |________________________|        \___/
    ''')

    print('\n[Luca Pinello 2017, Jayoung Ryu 2021, send bugs, suggestions or *green coffee* to jayoung_ryu AT g DOT harvard DOT edu]\n\n')

    parser = argparse.ArgumentParser(
        description='CRISPRessoCount parameters',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Required arguments.
    parser.add_argument('-i', '--input', type=str, required=True,
                        help='List of fastq and sample ids. Formatted as R1_filepath, R2_filepath, sample_id')
    parser.add_argument('-t', '--threads', type=int, default=10,
                        help='Number of threads')
    # Only A and C are supported downstream (count_sample maps A->G and C->T),
    # so reject anything else here instead of failing later in a worker.
    parser.add_argument('-b', '--edited_base', type=str, required=True,
                        choices=('A', 'C'),
                        help='For base editors, the base that should be ignored when matching the gRNA sequence')
    parser.add_argument('-f', '--sgRNA_filename', type=str, required=True,
                        help='''sgRNA description file. The format requires three columns: gRNA, Reporter, gRNA_barcode.''')

    # Optional arguments.
    parser.add_argument('--guide_start_seq', type=str, default="GGAAAGGACGAAACACCG",
                        help="Guide starts after this sequence in R1")
    parser.add_argument('-r', '--count_reporter', action='store_true',
                        help="Count reporter edits.")
    parser.add_argument('-q', '--min_average_read_quality', type=int, default=30,
                        help='Minimum average quality score (phred33) to keep a read')
    parser.add_argument('-s', '--min_single_bp_quality', type=int, default=0,
                        help='Minimum single bp score (phred33) to keep a read')
    parser.add_argument('-n', '--name', default='', help='Output name')
    parser.add_argument('-o', '--output_folder', default='', help='')
    parser.add_argument('-l', '--reporter_length', type=int, default=32,
                        help="length of the reporter")
    parser.add_argument('--keep_intermediate', action='store_true',
                        help='Keep all the  intermediate files')
    parser.add_argument('--qstart_R1', type=int, default=0,
                        help='Start position of the read when filtering for quality score of the read 1')
    parser.add_argument('--qend_R1', type=int, default=47,
                        help='End position of the read when filtering for quality score of the read 1')
    # type=int so that user-supplied values match the integer defaults
    # (without it, command-line values would arrive as strings).
    parser.add_argument('--qstart_R2', type=int, default=0,
                        help='Same as qstart_R1, for read 2 fastq file')
    parser.add_argument('--qend_R2', type=int, default=36,
                        help='Same as qend_R1, for read 2 fastq file')
    parser.add_argument('--gstart_reporter', type=int, default=6,
                        help="Start position of the guide sequence in the reporter")
    parser.add_argument('--match_target_pos', action='store_true',
                        help="Only count the edit in the exact target position.")
    # Parsed as a real boolean so that "--guide_bc False" actually disables it
    # (an untyped option would yield the truthy string "False").
    parser.add_argument('--guide_bc', type=_str2bool, default=True,
                        help='Construct has guide barcode')
    # A length is an integer; the default was already the int 4.
    parser.add_argument('--guide_bc_len', type=int, default=4,
                        help='Guide barcode sequence length at the beginning of the R2')
    parser.add_argument('--offset', action='store_true',
                        help='Guide file has offest column that will be added to the relative position of reporters.')
    parser.add_argument('--align_fasta', type=str, default='',
                        help='gRNA is aligned to this sequence to infer the offset. Can be used when the exact offset is not provided.')
    parser.add_argument('-a', '--count_allele', action='store_true',
                        help='count gRNA alleles')
    parser.add_argument('-g', "--count_guide_edits", action='store_true',
                        help="count the self editing of guides")
    parser.add_argument('-m', "--count_guide_reporter_alleles", action='store_true',
                        help="count the matched allele of guide and reporter edit")
    parser.add_argument('--rerun', action='store_true',
                        help='Recount each sample')

    return parser

def count_sample(R1, R2, sample_id, args):
    """Count one sample, or load its previously written result.

    A ``beret.pp.GuideEditCounter`` is built from the command-line options
    plus the per-sample values. If ``<counter.output_dir>.h5ad`` already
    exists and ``args.rerun`` is false, that file is read back; otherwise the
    fastq files are filtered and counted and the result is written to that
    path.

    Args:
        R1: read 1 fastq path for this sample.
        R2: read 2 fastq path for this sample.
        sample_id: sample identifier; used as the counter name and as the
            sub-folder of ``args.output_folder`` for this sample.
        args: parsed command-line namespace. It is not modified.

    Returns:
        The screen object for this sample (either read from the existing
        h5ad file or taken from ``counter.screen``).

    Raises:
        KeyError: if ``args.edited_base`` is not ``"A"`` or ``"C"``.
    """
    # vars(args) is the namespace's own __dict__; copy it so the per-sample
    # overrides below do not leak into `args`. Without the copy, a namespace
    # shared by several calls accumulates nested output folders
    # (out/sample1/sample2/...).
    args_dict = dict(vars(args))
    args_dict["R1"] = R1
    args_dict["R2"] = R2
    args_dict["name"] = sample_id
    args_dict["output_folder"] = os.path.join(args.output_folder, sample_id)

    # Supported base edits (edited base -> product base). Unsupported bases
    # fail here, before any counting work starts.
    base_editing_map = {"A": "G", "C": "T"}
    if args_dict["edited_base"] not in base_editing_map:
        raise KeyError(args_dict["edited_base"])

    counter = beret.pp.GuideEditCounter(**args_dict)
    h5ad_path = "{}.h5ad".format(counter.output_dir)

    if os.path.exists(h5ad_path) and not args_dict["rerun"]:
        # Reuse the result of an earlier run.
        screen = beret.read_h5ad(h5ad_path)
        info("Reading already existing data for {} from \n"
             "            {}.h5ad".format(sample_id, counter.output_dir))
        return screen

    info("Counting {}".format(sample_id))
    counter.check_filter_fastq()
    counter.get_counts()
    counter.screen.write(h5ad_path)
    info("Done for {}. \n"
         "            Output written at {}.h5ad".format(sample_id, counter.output_dir))
    return counter.screen



def main():
    """Count every sample listed in the input table and write the merged screen.

    Reads the header-less CSV given by ``--input`` (R1 path, R2 path,
    sample id), runs ``count_sample`` for each row in a process pool,
    concatenates the per-sample screens and writes
    ``<output_folder>/beret_count_<name>.h5ad``.
    """
    parser = get_input_parser()
    args = parser.parse_args()

    # Read everything as text so that numeric-looking sample ids stay usable
    # as folder names.
    sample_tbl = pd.read_csv(args.input, header=None, dtype=str)
    if sample_tbl.shape[1] < 3:
        print("Input file must have three columns: R1_filepath, R2_filepath, sample_id. Exiting.")
        sys.exit(1)
    if not sample_tbl[2].is_unique:
        print("Sample ID not unique. Please check your input file. Exiting.")
        sys.exit(1)

    # One task per row: (R1, R2, sample_id, args).
    tasks = [
        row + (args,)
        for row in sample_tbl[[0, 1, 2]].itertuples(index=False, name=None)
    ]
    # The context manager guarantees the workers are cleaned up even when a
    # sample raises.
    with Pool(processes=args.threads) as pool:
        result = pool.starmap(count_sample, tasks)

    screen = beret.concat(result, axis=1)

    if args.name:
        database_id = args.name
    else:
        # Use only the file name: splitting the full path on '.' yields an
        # empty id for './samples.csv' and drags directories into the name.
        database_id = os.path.basename(args.input).split('.')[0]
    output_path = os.path.join(os.path.abspath(args.output_folder),
                               "beret_count_%s" % database_id)

    try:
        screen.guides = result[0].guides.loc[screen.guides.index, :].reset_index()
    except Exception:
        # Best effort: keep the guides produced by the concatenation, but
        # report why the first sample's guide table could not be used.
        logging.warning(
            "Could not take guide annotations from the first sample:\n%s",
            result[0], exc_info=True)
    screen.condit = screen.condit.reset_index()
    screen.write("{}.h5ad".format(output_path))

    info('All Done!')
    print(r'''
          )                                             )
         (           ________________________          (
        __)__       | __   __            ___ |        __)__
     C\|     \      |/  ` /  \ |  | |\ |  |  |     C\|     \
       \     /      |\__, \__/ \__/ | \|  |  |       \     /
        \___/       |________________________|        \___/
    ''')
    sys.exit(0)


if __name__ == '__main__':
    main()