#!/usr/bin/python3

# builds a Pegasus subdax for a processor plugin and saves to an XML file

import base64
import json
import os
import sys

import yaml
from yaml import FullLoader

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding

from Pegasus.api import *

from geoedfengine.helper.GeoEDFProcessor import GeoEDFProcessor
from geoedfengine.helper.WorkflowUtils import WorkflowUtils

# process command line arguments (positional, in this order):
#  1. workflow filepath
#  2. workflow stage
#  3. plugin name to be used to identify right executable (from right container)
#  4. subdax filepath
#  5. remote job directory (prefix)
#  6. stage references as comma separated string
#  7. arg to local file bindings encoded as JSON string
#  8. sensitive arg bindings
#  9. stage references that have dir modifier applied to them
# 10. public key file
# 11+ stage ref result files if any (one per stage reference, in the same order)

# basic validation on number of args
# ten positional args are mandatory, so argv (which also holds the script
# name at index 0) needs at least 11 entries before sys.argv[10] can be read
if len(sys.argv) < 11:
    raise Exception("Insufficient arguments to processor plugin subdax construction job")

# extract the mandatory args
(workflow_filename,
 workflow_stage,
 plugin_name,
 subdax_filename,
 job_dir,
 stage_refs_str,
 local_file_binds_str,
 sensitive_arg_binds_str,
 dir_mod_refs_str,
 workflow_pubkey_filename) = (str(arg) for arg in sys.argv[1:11])

# helper used later to build the stage reference binding combinations
helper = WorkflowUtils()

# the sub-workflow (DAX) that will hold this stage's processor jobs
proc_plugin_wf = Workflow("stage-%s-Processor" % workflow_stage)

# the workflow file is an input to the jobs; it is known to the DAX by its
# file name alone, and the replica catalog maps that name to the full path
# on the "local" site
workflow_fname = os.path.basename(workflow_filename)
rc = ReplicaCatalog().add_replica("local", workflow_fname, workflow_filename)
workflow_dax_file = File(workflow_fname)

# process the local file arg bindings JSON
# the literal string 'None' means that there are no such bindings; otherwise
# the string is a JSON object mapping arg name to local filepath
if local_file_binds_str != 'None':
    local_file_binds = json.loads(local_file_binds_str)
else:
    local_file_binds = {}

# arg names and the filepaths bound to them, in matching order
local_file_args = list(local_file_binds)
local_files_needed = list(local_file_binds.values())

# an empty JSON object is treated the same as 'None' rather than failing
# when the first arg name is looked up
local_file_args_exist = bool(local_file_args)

# comma separated string of local file args, passed on to the plugin job
if local_file_args_exist:
    local_file_args_str = ','.join(local_file_args)
else:
    local_file_args_str = 'None'

# if local file args exist, add input files to DAX
# local_dax_files: files bound to args (become both job args and job inputs)
# other_local_dax_files: shapefile companion files (become job inputs only)
local_dax_files = []
other_local_dax_files = []
if local_file_args_exist:
    for local_file in local_files_needed:
        # assumes fully qualified path provided
        local_file_dir, local_file_fname = os.path.split(local_file)
        rc.add_replica("local", local_file_fname, local_file)
        local_dax_files.append(File(local_file_fname))

        # special handling for shapefiles; also add other related files (dbf,shx,prj, etc.)
        basename, local_file_ext = os.path.splitext(local_file_fname)
        if local_file_ext != '.shp':
            continue

        # companion files are named <basename>.<ext>; anchoring the match on
        # the dot keeps out unrelated files that merely share a prefix
        # (e.g. roads2.dbf when the shapefile is roads.shp)
        companion_prefix = basename + '.'
        for local_dir_file in os.listdir(local_file_dir):
            if not local_dir_file.startswith(companion_prefix):
                continue
            if os.path.splitext(local_dir_file)[1] == '.shp':
                continue
            full_path = os.path.join(local_file_dir, local_dir_file)
            if not os.path.isfile(full_path):
                continue
            # add these also to DAX as inputs
            rc.add_replica("local", local_dir_file, full_path)
            other_local_dax_files.append(File(local_dir_file))

# process sensitive arg bindings if any, encrypting them with the public key
# create a new JSON dictionary with arg to encrypted value pairs
# the literal string 'None' means that there are no sensitive bindings
if sensitive_arg_binds_str != 'None':
    sensitive_arg_binds = json.loads(sensitive_arg_binds_str)
    encrypted_arg_binds = dict()

    # the key file only needs to stay open while the key is being loaded
    with open(workflow_pubkey_filename,'rb') as key_file:
        public_key = serialization.load_pem_public_key(
            key_file.read(),
            backend=default_backend())

    # NOTE(review): OAEP limits the size of a single plaintext relative to
    # the key size; presumably sensitive values are short - confirm
    oaep_padding = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None)

    for arg, sensitive_val in sensitive_arg_binds.items():
        encrypted_val = public_key.encrypt(
            bytes(sensitive_val,'utf-8'),
            oaep_padding)

        # base64 encode val so it can be transmitted via JSON
        # encodebytes replaces encodestring (removed in Python 3.9) and
        # produces the same newline-wrapped output
        b64encoded_val = base64.encodebytes(encrypted_val)

        encrypted_arg_binds[arg] = b64encoded_val.decode('ascii')

    # encoded twice: the job argument is a JSON string literal whose content
    # is the JSON object of arg to encrypted value pairs
    encrypted_arg_binds_str = json.dumps(json.dumps(encrypted_arg_binds))

else:
    encrypted_arg_binds_str = 'None'

# stage references with the dir modifier applied arrive as a comma separated
# string, or as the literal string 'None' when there are none
dir_modified_refs = [] if dir_mod_refs_str == 'None' else dir_mod_refs_str.split(',')

# build the processor plugin job(s) for this stage, add a job that collects
# the output file names, and write out the subdax

# all jobs of this stage share the same stage-specific output directory
output_dir = '%s/%s' % (job_dir,workflow_stage)


def _build_plugin_job(stage_binds_str):
    """Create one processor plugin job and add it to the subdax.

    stage_binds_str is the stage reference bindings argument for the job:
    either a JSON encoded string or the literal string 'None'.

    Returns the job, which has already been added to proc_plugin_wf.
    """
    # executable name is different for each proc since
    # it needs to be run in the processor's own container
    exec_name = "run-processor-plugin-%s" % plugin_name.lower()
    plugin_job = Job(exec_name)

    # args:
    # workflow_file
    # workflow_stage
    # plugin type
    # output_path
    # stage refs JSON str
    # encrypted arg binds str
    # local file args str
    # optional trailing list of file args (local files converted to inputs)
    plugin_job.add_args(workflow_dax_file)
    plugin_job.add_inputs(workflow_dax_file)

    plugin_job.add_args(workflow_stage)
    plugin_job.add_args("Processor")
    plugin_job.add_args(output_dir)

    plugin_job.add_args(stage_binds_str)
    plugin_job.add_args(encrypted_arg_binds_str)
    plugin_job.add_args(local_file_args_str)

    for local_dax_file in local_dax_files:
        plugin_job.add_args(local_dax_file)
        plugin_job.add_inputs(local_dax_file)

    # shapefile companion files are inputs only, they are not passed as args
    for other_local_dax_file in other_local_dax_files:
        plugin_job.add_inputs(other_local_dax_file)

    proc_plugin_wf.add_jobs(plugin_job)
    return plugin_job


# parse the workflow file
# NOTE(review): the parsed workflow is not used further in this script, so
# this only verifies that the file can be read and parsed as YAML
# NOTE(review): FullLoader can construct more than plain YAML types; if
# workflow files can come from untrusted users, consider yaml.safe_load
with open(workflow_filename,'r') as workflow_file:
    workflow_dict = yaml.load(workflow_file,Loader=FullLoader)

# process stage references for this plugin (processors cannot have vars)
# reconvert back to list
# determine output files holding their values
# and create a dictionary of stage to value bindings
plugin_jobs = []

if stage_refs_str != 'None':
    stage_refs = stage_refs_str.split(',')

    # result files follow the ten mandatory args, one per stage reference
    # and in the same order
    first_res_file_arg = 11
    if len(sys.argv) < first_res_file_arg + len(stage_refs):
        raise Exception("Insufficient stage reference result files for processor plugin subdax construction job")

    stage_ref_values = dict()
    for res_file_arg, stage_ref in enumerate(stage_refs, start=first_res_file_arg):
        stage_ref_val_filename = str(sys.argv[res_file_arg])
        # each line of the result file holds one value
        with open(stage_ref_val_filename,'r') as stage_ref_val_file:
            ref_values = [val.rstrip() for val in stage_ref_val_file]
        # if this stage ref has a dir modifier applied to it, only retain one value
        if stage_ref in dir_modified_refs:
            ref_values = ref_values[:1]
        stage_ref_values[stage_ref] = ref_values

    # build binding combinations and one job for each of them
    binding_combs = helper.create_binding_combs(stage_ref_values,None)

    for binding in binding_combs:
        # encoded twice: the job argument is a JSON string literal whose
        # content is the JSON encoded binding
        stage_binds_str = json.dumps(json.dumps(binding))
        plugin_jobs.append(_build_plugin_job(stage_binds_str))

else: # no bindings, a single job suffices
    plugin_jobs.append(_build_plugin_job('None'))

# collect output file names; this job runs after all of the plugin jobs
collect_job = Job("collect.py")
collect_job.add_args(workflow_stage)
collect_job.add_args(output_dir)
collect_res_filename = 'results_%s.txt' % workflow_stage
collect_res_file = File(collect_res_filename)
collect_job.add_outputs(collect_res_file,register_replica=False)
proc_plugin_wf.add_jobs(collect_job)
proc_plugin_wf.add_dependency(collect_job,parents=plugin_jobs)

# write out replica catalog
rc.write()
proc_plugin_wf.add_replica_catalog(rc)

# write out to the subdax file
proc_plugin_wf.write(subdax_filename)

