#!/usr/bin/env python3

# builds a Pegasus subdax for a connector plugin and saves to an XML file

import sys
import os
import yaml
from yaml import FullLoader
import json
import base64

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding

from Pegasus.api import *

from geoedfengine.helper.GeoEDFConnector import GeoEDFConnector
from geoedfengine.helper.WorkflowUtils import WorkflowUtils

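# process command line arguments (in positional order):
# workflow filepath
# workflow stage identifier
# plugin identifier; contains a ':' followed by the variable name for a filter plugin
# plugin name to be used as Conda environment name
# subdax filepath
# remote job directory (prefix)
# dependent vars as comma separated string ('None' if there are none)
# stage references as comma separated string ('None' if there are none)
# arg to local file bindings encoded as JSON string ('None' if there are none)
# sensitive arg bindings encoded as JSON string ('None' if there are none)
# stage references with dir modifier applied to them ('None' if there are none)
# public key file
# dep var result files if any (one per dependent var, in order)
# stage ref result files if any (one per stage reference, in order)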

# basic validation on number of args; additional ones for connector
if len(sys.argv) < 12:
    raise Exception("Insufficient arguments to connector plugin subdax construction job")

# extract the args
workflow_filename = str(sys.argv[1])
workflow_stage = str(sys.argv[2])
plugin_id = str(sys.argv[3])
plugin_name = str(sys.argv[4])
subdax_filename = str(sys.argv[5])
job_dir = str(sys.argv[6])
dep_vars_str = str(sys.argv[7])
stage_refs_str = str(sys.argv[8])
local_file_binds_str = str(sys.argv[9])
sensitive_arg_binds_str =  str(sys.argv[10])
dir_mod_refs_str = str(sys.argv[11])
workflow_pubkey_filename = str(sys.argv[12])

filter_var = ''

# determine the kind of plugin
if ':' in plugin_id:
    plugin_type = 'Filter'
    filter_var = plugin_id.split(':')[1]
else:
    plugin_type = 'Input'

# initialize DAX, set up some basic executables
if ':' in plugin_id:
    conn_plugin_wf = Workflow("stage-%s-Filter-%s" % (workflow_stage,plugin_id.split(':')[1]))
else:
    conn_plugin_wf = Workflow("stage-%s-Input" % workflow_stage)

# add the workflow file to this DAX
workflow_fname = os.path.split(workflow_filename)[1]
workflow_dax_file = File(workflow_fname)
rc = ReplicaCatalog().add_replica("local",workflow_fname,workflow_filename)

# get a helper
helper = WorkflowUtils()

# process the local file arg bindings JSON
if local_file_binds_str != 'None':
    local_file_args_exist = True
    local_file_binds = json.loads(local_file_binds_str)
    
    # create comma separated string of local file args and list of filepath vals
    local_file_args = list(local_file_binds.keys())

    local_files_needed = []
    for arg in local_file_args:
        local_files_needed.append(local_file_binds[arg])

    local_file_args_str = '%s' % local_file_args[0]

    for arg in local_file_args[1:]:
        local_file_args_str = '%s,%s' % (local_file_args_str,arg)
else:
    local_file_args_exist = False
    local_file_args_str = 'None'

# if local file args exist, add input files to DAX
local_dax_files = []
if local_file_args_exist:
    for local_file in local_files_needed:
        local_file_fname = os.path.split(local_file)[1]
        local_dax_file = File(local_file_fname)
        rc.add_replica("local",local_file_fname,local_file)
        local_dax_file.addPFN(PFN("file://%s" % local_file,"local"))
        local_dax_files.append(local_dax_file)

# process sensitive arg bindings if any, encrypting them with the public key
# create a new JSON dictionary with arg to encrypted value pairs
if sensitive_arg_binds_str != 'None':
    sensitive_arg_binds = json.loads(sensitive_arg_binds_str)
    encrypted_arg_binds = dict()

    # process these bindings, encrypting their value
    #workflow_pubkey_filename = '%s/output/public.pem' % run_dir

    with open(workflow_pubkey_filename,'rb') as key_file:
        public_key = serialization.load_pem_public_key(
            key_file.read(),
            backend=default_backend())

        for arg in sensitive_arg_binds:
            sensitive_val = sensitive_arg_binds[arg]
            sensitive_val_bytes = bytes(sensitive_val,'utf-8')
            encrypted_val = public_key.encrypt(
                sensitive_val_bytes,
                padding.OAEP(
                    mgf=padding.MGF1(algorithm=hashes.SHA256()),
                    algorithm=hashes.SHA256(),
                    label=None))

            # use base64 encoding to enable this to be transmitted via JSON
            b64encoded_val = base64.encodebytes(encrypted_val)
            encrypted_arg_binds[arg] = b64encoded_val.decode('ascii')
            
    encrypted_arg_binds_str = json.dumps(json.dumps(encrypted_arg_binds))

else:
    encrypted_arg_binds_str = 'None'

# process stage references that have dir modifiers applied
# these references need to be bound only once
if dir_mod_refs_str != 'None':
    dir_modified_refs = dir_mod_refs_str.split(',')
else:
    dir_modified_refs = []
    
# extract the workflow stage
with open(workflow_filename,'r') as workflow_file:
    try:
        workflow_dict = yaml.load(workflow_file,Loader=FullLoader)

        stage_id = '$%s' % workflow_stage

        # process vars and stage references for this plugin
        # reconvert back to list
        # determine output files holding their values
        # and create a dictionary of variable/stage to value bindings
        dep_vars_exist = False
        stage_refs_exist = False

        # index for rest of result file args
        res_file_indx = 1
        
        if dep_vars_str != 'None':
            dep_vars_exist = True
            dep_vars = dep_vars_str.split(',')
            dep_var_values = dict()
    
            for var in dep_vars:
                dep_var_values[var] = []
                var_val_filename = str(sys.argv[12+res_file_indx])
                #var_val_filename = '%s/output/results_%s_%s.txt' % (run_dir,workflow_stage,var)
                with open(var_val_filename,'r') as var_val_file:
                    for val in var_val_file:
                        # strip trailing newline
                        dep_var_values[var].append(val.rstrip())
                res_file_indx += 1

        if stage_refs_str != 'None':
            stage_refs_exist = True
            stage_refs = stage_refs_str.split(',')
            stage_ref_values = dict()
            for stage_ref in stage_refs:
                stage_ref_values[stage_ref] = []
                #stage_ref_val_filename = '%s/output/results_%s.txt' % (run_dir,stage_ref)
                stage_ref_val_filename = str(sys.argv[12+res_file_indx])
                with open (stage_ref_val_filename,'r') as stage_ref_val_file:
                    for val in stage_ref_val_file:
                        stage_ref_values[stage_ref].append(val.rstrip())
                # if this stage ref has a dir modifier applied to it, only retain one value
                if stage_ref in dir_modified_refs:
                    stage_ref_values[stage_ref] = stage_ref_values[stage_ref][:1]
                res_file_indx += 1
                            

        # create binding combos from the values and stage refs
        if dep_vars_exist and stage_refs_exist:
            binding_combs = helper.create_binding_combs(dep_var_values,stage_ref_values)

            # loop through binding_combs, creating parallel jobs
            # convert binding to JSON string to send to job
            indx = 0
            plugin_jobs = []
            res_files = []
            for binding in binding_combs:
                var_binds = binding[0]
                stage_binds = binding[1]

                var_binds_str = json.dumps(json.dumps(var_binds))
                stage_binds_str = json.dumps(json.dumps(stage_binds))

                # if this is a filter, we need to provide an output filepath
                # filter outputs will be saved there
                if plugin_type == 'Filter':
                    res_filename = '%s/results_%s_%s_%d.txt' % (job_dir,workflow_stage,filter_var,indx)
                    res_file = File(res_filename)
                    res_files.append(res_file)

                # create job for this plugin
                exec_name = "run-connector-plugin-%s" % plugin_name.lower()
                plugin_job = Job(exec_name)

                # args:
                # workflow_file
                # workflow_stage (encoded with :)
                # output_path
                # var binds JSON str
                # stage refs JSON str
                # encrypted arg binds str
                # local file args str
                # optional trailing list of file args (local files converted to inputs)
                plugin_job.add_args(workflow_dax_file)
                plugin_job.add_inputs(workflow_dax_file)
                complete_workflow_stage = '%s:%s' % (workflow_stage,plugin_id)
                plugin_job.add_args(complete_workflow_stage)
            
                if plugin_type == 'Filter':
                    plugin_job.add_args("Filter")
                    plugin_job.add_args(res_file)
                else:
                    plugin_job.add_args("Input")
                    output_dir = '%s/%s' % (job_dir,workflow_stage)
                    plugin_job.add_args(output_dir)

                plugin_job.add_args(var_binds_str)
                plugin_job.add_args(stage_binds_str)
                plugin_job.add_args(encrypted_arg_binds_str)
                plugin_job.add_args(local_file_args_str)

                for local_dax_file in local_dax_files:
                    plugin_job.add_args(local_dax_file)
                    
                conn_plugin_wf.add_jobs(plugin_job)
                plugin_jobs.append(plugin_job)

                indx += 1
                
        # if just one of vars or stage refs is needed
        elif dep_vars_exist or stage_refs_exist:
            if dep_vars_exist:
                binding_combs = helper.create_binding_combs(dep_var_values,None)
            else:
                binding_combs = helper.create_binding_combs(stage_ref_values,None)

            indx = 0
            plugin_jobs = []
            res_files = []
            for binding in binding_combs:
                if dep_vars_exist:
                    var_binds_str = json.dumps(json.dumps(binding))
                    stage_binds_str = 'None'
                else:
                    var_binds_str = 'None'
                    stage_binds_str = json.dumps(json.dumps(binding))

                # if this is a filter, we need to provide an output filepath
                # filter outputs will be saved there
                if plugin_type == 'Filter':
                    res_filename = '%s/results_%s_%s_%d.txt' % (job_dir,workflow_stage,filter_var,indx)
                    res_file = File(res_filename)
                    res_files.append(res_file)

                # create job for this plugin
                exec_name = "run-connector-plugin-%s" % plugin_name.lower()
                plugin_job = Job(exec_name)

                # args:
                # workflow_file
                # workflow_stage (encoded with :)
                # output_path
                # var binds JSON str
                # stage refs JSON str
                # encrypted arg binds str
                # local file args str
                # optional trailing list of file args (local files converted to inputs)
                plugin_job.add_args(workflow_dax_file)
                plugin_job.add_inputs(workflow_dax_file)
                
                complete_workflow_stage = '%s:%s' % (workflow_stage,plugin_id)
                plugin_job.add_args(complete_workflow_stage)
            
                if plugin_type == 'Filter':
                    plugin_job.add_args("Filter")
                    plugin_job.add_args(res_file)
                else:
                    output_dir = '%s/%s' % (job_dir,workflow_stage)
                    plugin_job.add_args("Input")
                    plugin_job.add_args(output_dir)
                
                plugin_job.add_args(var_binds_str)
                plugin_job.add_args(stage_binds_str)
                plugin_job.add_args(encrypted_arg_binds_str)
                plugin_job.add_args(local_file_args_str)

                for local_dax_file in local_dax_files:
                    plugin_job.add_args(local_dax_file)
                    
                conn_plugin_wf.add_jobs(plugin_job)
                plugin_jobs.append(plugin_job)

                indx += 1

        else: # no bindings
            plugin_jobs = []
            res_files = []
            
            stage_binds_str = 'None'
            var_binds_str = 'None'

            # if this is a filter, we need to provide an output filepath
            # filter outputs will be saved there
            if plugin_type == 'Filter':
                res_filename = '%s/results_%s_%s_0.txt' % (job_dir,workflow_stage,filter_var)
                res_file = File(res_filename)
                res_files.append(res_file)

            # create job for this plugin
            exec_name = "run-connector-plugin-%s" % plugin_name.lower()
            plugin_job = Job(exec_name)

            # args:
            # workflow_file
            # workflow_stage (encoded with :)
            # output_path
            # var binds JSON str
            # stage refs JSON str
            # encrypted arg binds str
            # local file args str
            # optional trailing list of file args (local files converted to inputs)
            plugin_job.add_args(workflow_dax_file)
            plugin_job.add_inputs(workflow_dax_file)
            
            complete_workflow_stage = '%s:%s' % (workflow_stage,plugin_id)
            plugin_job.add_args(complete_workflow_stage)
            
            if plugin_type == 'Filter':
                plugin_job.add_args("Filter")
                plugin_job.add_args(res_file)
            else:
                output_dir = '%s/%s' % (job_dir,workflow_stage)
                plugin_job.add_args("Input")
                plugin_job.add_args(output_dir)
                
            plugin_job.add_args(var_binds_str)
            plugin_job.add_args(stage_binds_str)
            plugin_job.add_args(encrypted_arg_binds_str)
            plugin_job.add_args(local_file_args_str)

            for local_dax_file in local_dax_files:
                plugin_job.add_args(local_dax_file)
                    
            conn_plugin_wf.add_jobs(plugin_job)
            plugin_jobs.append(plugin_job)

        # merge results (for filters)
        if plugin_type == 'Filter':
            merge_job = Job("merge.py")
            merge_job.add_args(job_dir)
            merge_job.add_args(workflow_stage)
            merge_job.add_args(filter_var)
            for res_file in res_files:
                merge_job.add_args(res_file)
            merged_res_filename = 'results_%s_%s.txt' % (workflow_stage,filter_var)
            merged_res_file = File(merged_res_filename)
            merge_job.add_outputs(merged_res_file,register_replica=False)
            conn_plugin_wf.add_jobs(merge_job)
            conn_plugin_wf.add_dependency(merge_job,parents=plugin_jobs)

        # collect output names (for input)
        if plugin_type == 'Input':
            collect_job = Job("collect.py")
            output_dir = '%s/%s' % (job_dir,workflow_stage)
            collect_job.add_args(workflow_stage)
            collect_job.add_args(output_dir)
            collect_res_filename = 'results_%s.txt' % workflow_stage
            collect_res_file = File(collect_res_filename)
            collect_job.add_outputs(collect_res_file,register_replica=False)
            conn_plugin_wf.add_jobs(collect_job)
            conn_plugin_wf.add_dependency(collect_job,parents=plugin_jobs)

        # write replica catalog
        rc.write()
        conn_plugin_wf.add_replica_catalog(rc)
        
        # write out to DAX xml file
        conn_plugin_wf.write(subdax_filename)

    except:
        raise

