#!/usr/bin/env ganga

'''
Script used to submit filtering jobs using Ganga
'''
import os
import json
import argparse

from typing              import Union
from dataclasses         import dataclass
from importlib.resources import files

import yaml

from GangaCore.GPI             import Job, Executable, DiracFile, LocalFile, Interactive, GenericSplitter, Dirac, Local, File
from post_ap.pfn_reader        import PFNReader
from dmu.logging.log_store     import LogStore

# Module-level logger for this script, obtained from LogStore under the name 'submit_filter'
log = LogStore.add_logger('submit_filter')

# pylint: disable=line-too-long,too-many-instance-attributes
# -------------------------------------------------
@dataclass
class Data:
    '''
    Class used to share data between the functions of this script.

    In the code visible here it is not instantiated: attributes are assigned
    and read directly on the class (e.g. `Data.name = args.name`), so it
    behaves as a namespace.
    '''
    # Job name, taken from --name and given to the ganga Job
    name     : str
    # Production, taken from --prod
    prod     : str
    # Sample, taken from --samp; used as the `nickname` when requesting PFNs
    samp     : str
    # Path to the YAML config file, taken from --conf
    conf     : str
    # Version of the virtual environment, taken from --venv; part of the environment LFN
    venv     : str
    # Backend name, taken from --back: 'Interactive', 'Local' or 'Dirac'
    back     : str
    # LFN of the environment tarball, built in _initialize from user and venv
    env_lfn  : str
    # Number of jobs, computed in _save_pfns from the number of PFNs
    njob     : int
    # True if --test was used, in which case only one job is sent
    test     : bool
    # Configuration loaded from the YAML file in `conf`
    cfg      : dict

    # The attributes below carry no annotation, so they are plain class
    # attributes and not dataclass fields

    # Upper limit on the number of jobs
    maxj     = 500
    # Path passed to every subjob as an argument
    env_path = '/home/acampove/VENVS'
    # User name; used to build the environment LFN and passed to every subjob
    user     = 'acampove'
    # Local path where the PFNs are saved as JSON, also shipped as job input
    pfn_path = '/tmp/pfns.json'
    # Name of the file, inside the post_ap_grid package, used as the job executable
    runner   = 'run_filter'
    # Value assigned to the `platform` attribute of the Executable
    platform = 'x86_64-centos7-gcc8-opt'
    # Value assigned to the 'Destination' setting of the Dirac backend
    grid_site= 'LCG.CERN.cern'
# -------------------------------------------------
def _initialize() -> None:
    '''
    Fills the entries of Data that are derived from the command line arguments:

    - env_lfn: LFN of the environment tarball, built from the user and venv
    - cfg    : Contents of the YAML config file
    - njob   : Set by _save_pfns, which also writes the PFNs to disk
    '''
    Data.env_lfn = f'LFN:/lhcb/user/{Data.user[0]}/{Data.user}/run3/venv/{Data.venv}/dcheck.tar'

    with open(Data.conf, encoding='utf-8') as config_file:
        loaded_cfg = yaml.safe_load(config_file)

    Data.cfg = loaded_cfg

    # Needs Data.cfg, so it has to go last
    _save_pfns()
# -------------------------------------------------
def _pfns_to_njob(d_pfn : dict[str, list[str]]) -> int:
    '''
    Takes dictionary whose values are lists of PFNs and returns
    the number of jobs: the total number of PFNs, capped at Data.maxj
    '''
    total = sum(len(l_pfn) for l_pfn in d_pfn.values())

    return min(total, Data.maxj)
# -------------------------------------------------
def _save_pfns() -> None:
    '''
    Retrieves the PFNs for the requested production and sample, then:

    - Sets Data.njob from the number of PFNs
    - Saves the PFNs as JSON in Data.pfn_path
    '''
    d_pfn     = PFNReader(cfg=Data.cfg).get_pfns(production=Data.prod, nickname=Data.samp)
    Data.njob = _pfns_to_njob(d_pfn)

    log.info(f'Will use {Data.njob} jobs')

    with open(Data.pfn_path, 'w', encoding='utf-8') as pfn_file:
        json.dump(d_pfn, pfn_file, indent=4)
# -------------------------------------------------
def _get_splitter() -> GenericSplitter:
    '''
    Returns a GenericSplitter acting on `application.args`, whose values
    are the argument lists provided by _get_splitter_args
    '''
    l_args = _get_splitter_args()

    obj           = GenericSplitter()
    obj.attribute = 'application.args'
    obj.values    = l_args

    return obj
# -------------------------------------------------
def _get_splitter_args() -> list[list]:
    '''
    Returns list with one list of arguments per job, each of the form:

    [production, sample, config file name, total number of jobs, job index, environment path, user]
    '''
    # A single job is sent, no matter what Data.njob is, when either:
    #
    # 1. The backend is Interactive or Local
    # 2. The test flag was used
    #
    # The total number of jobs in the arguments is still Data.njob
    single_job = Data.back in ['Interactive', 'Local'] or Data.test
    nsend      = 1 if single_job else Data.njob

    conf_name  = os.path.basename(Data.conf)
    l_head     = [Data.prod, Data.samp, conf_name, Data.njob]
    l_tail     = [Data.env_path, Data.user]

    return [ l_head + [i_job] + l_tail for i_job in range(nsend) ]
# -------------------------------------------------
def _get_dirac_backend() -> Dirac:
    '''
    Returns Dirac backend with the 'Destination' setting pointing to Data.grid_site
    '''
    dirac = Dirac()
    dirac.settings['Destination'] = Data.grid_site

    return dirac
# -------------------------------------------------
def _get_backend() -> Union[Dirac, Interactive, Local]:
    '''
    Returns the backend object corresponding to the name in Data.back

    Raises ValueError if the name is not Interactive, Local or Dirac
    '''
    # Maps the backend name to the callable that builds the backend
    d_builder = {
        'Interactive' : Interactive,
        'Local'       : Local,
        'Dirac'       : _get_dirac_backend,
    }

    if Data.back not in d_builder:
        raise ValueError(f'Invalid backend: {Data.back}')

    return d_builder[Data.back]()
# -------------------------------------------------
def _parse_args() -> None:
    '''
    Reads the command line arguments and stores each of them
    as an attribute of Data with the same name
    '''
    parser = argparse.ArgumentParser(description='Script used to send ntuple filtering jobs to the Grid, through ganga')

    parser.add_argument('-n', '--name', type=str, required=True, help='Job name')
    parser.add_argument('-p', '--prod', type=str, required=True, help='Production')
    parser.add_argument('-s', '--samp', type=str, required=True, help='Sample')
    parser.add_argument('-c', '--conf', type=str, required=True, help='Path to config file')
    parser.add_argument('-b', '--back', type=str, default='Interactive', choices=['Interactive', 'Local', 'Dirac'], help='Backend')
    parser.add_argument('-t', '--test', action='store_true', help='Will run one job only if used')
    parser.add_argument('-v', '--venv', type=str, required=True, help='Version of virtual environment used to run filtering')

    parsed = parser.parse_args()

    # Same names are used for the arguments and for the attributes of Data
    for attr in ['name', 'prod', 'samp', 'conf', 'back', 'venv', 'test']:
        setattr(Data, attr, getattr(parsed, attr))
# -------------------------------------------------
def _get_executable() -> Executable:
    '''
    Returns Executable application, whose executable is the file named
    Data.runner in the post_ap_grid package and whose platform is Data.platform
    '''
    exe_path = str(files('post_ap_grid').joinpath(Data.runner))

    app          = Executable()
    app.exe      = File(exe_path)
    app.platform = Data.platform

    return app
# -------------------------------------------------
def _get_output_files() -> list[Union[DiracFile,LocalFile]]:
    '''
    Returns list with the specification of the output files, all of them
    matching *.root. The type of file depends on the backend in Data.back:

    Dirac             : DiracFile
    Interactive, Local: LocalFile

    Raises ValueError for any other backend
    '''
    if Data.back == 'Dirac':
        return [DiracFile('*.root')]

    if Data.back in ['Interactive', 'Local']:
        return [LocalFile('*.root')]

    raise ValueError(f'Invalid backend: {Data.back}')
# -------------------------------------------------
def _get_job() -> Job:
    '''
    Builds the job named Data.name with the application, input files,
    splitter, backend and output files provided by the helpers of this script

    The input files are the config file, the JSON file with the PFNs
    and the file at Data.env_lfn
    '''
    obj             = Job(name = Data.name)
    obj.application = _get_executable()

    l_input         = [ LocalFile(Data.conf), LocalFile(Data.pfn_path), DiracFile(Data.env_lfn) ]
    obj.inputfiles  = l_input

    obj.splitter    = _get_splitter()
    obj.backend     = _get_backend()
    obj.outputfiles = _get_output_files()

    return obj
# -------------------------------------------------
def main():
    '''
    Script starts here

    Parses the arguments, initializes the shared data,
    then builds, prepares and submits the job
    '''
    _parse_args()
    _initialize()

    ganga_job = _get_job()
    ganga_job.prepare()
    ganga_job.submit()
# -------------------------------------------------
# NOTE(review): the check is against 'GangaCore.GPI' and not '__main__';
# presumably this is the module name seen when ganga runs the script (see shebang) -- confirm
if __name__ == 'GangaCore.GPI':
    main()
