# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/10_loop.ipynb (unless otherwise specified).

# Public names exported by ``from <this module> import *``.
__all__ = ['method4all', 'StorageCore', 'Loop', 'ProgressBar', 'Tolerate', 'LambdaCall', 'Event']

# Cell
import numpy as np

# Cell
from tqdm import tqdm
from types import MethodType
import inspect
from functools import partial

def method4all(f):
    """
    Decorator that flags ``f`` as a "forall" function.

    ``StorageCore`` collects every member carrying the ``forall`` flag from a
    newly registered layer into its pool. It then attaches the whole pool to
    each layer registered from then on, so a decorated method becomes
    callable on the outer layers too.
    Returns ``f`` itself (no wrapper is created).
    """
    f.forall = True
    return f

class StorageCore:
    """
    Shared state of one chain of loop layers.

    Holds the ordered list of layers, a map from loop-layer index to layer,
    the pool of ``forall`` functions and the iteration counters ``i`` and
    ``epoch``. Every layer is also reachable as an attribute of the core
    named after the layer.
    """
    def __init__(self,start_layer):
        self.layers = []
        self.lmap = dict()
        self.forall_pool = dict()
        # Counters exist before the first layer is registered, so a layer
        # named "i" or "epoch" is rejected instead of being silently
        # overwritten by the counter.
        self.i = -1
        self.epoch = -1
        self.new_layer(start_layer)

    def new_layer(self,layer):
        """
        Register ``layer`` as the newest (innermost) layer of the chain.

        Links it to the previous layers, records it in ``lmap`` under its
        ``_loop_layer`` index, exposes it as ``core.<layer.name>``, then
        shares ``forall`` functions with it.

        Raises KeyError if ``layer.name`` is already an attribute of the
        core; in that case nothing is modified.
        """
        # Validate before mutating anything, so a rejected layer leaves no
        # trace in ``layers``/``lmap`` or in the next/last links.
        if hasattr(self,layer.name):
            raise KeyError(f"Layer name already exist: {layer.name}")
        layer.core = self
        self.layers.append(layer)
        self.seq_layers(self.layers)
        self.lmap[layer._loop_layer]=layer
        setattr(self,layer.name,layer)
        self.update_forall(layer)
        self.assign_forall(layer)

    def seq_layers(self,layers):
        """
        Link ``layers`` in order.

        Each layer's ``next_layer`` becomes the layer after it (None for the
        last). Every layer except the first gets ``last_layer`` pointing to
        its predecessor.
        """
        last = None
        for layer in layers[::-1]:
            if last is not None: last.last_layer = layer
            layer.next_layer = last
            last = layer

    def __repr__(self):
        return str(self.lmap)

    def immitate_func(self,obj,name,func):
        """Attach ``func`` to ``obj`` as attribute ``name``."""
        setattr(obj,name,func)

    def for_all_functions(self,obj):
        """
        Return ``{name: member}`` for every member of ``obj`` that carries
        the ``forall`` flag (see ``method4all``).
        """
        return dict(inspect.getmembers(obj,
                  predicate=lambda x:hasattr(x,"forall")))

    def update_forall(self,obj):
        """Add the ``forall`` members of ``obj`` to the shared pool."""
        self.forall_pool.update(self.for_all_functions(obj))

    def assign_forall(self,obj):
        """Attach every function currently in the pool to ``obj``."""
        for name,f in self.forall_pool.items():
            self.immitate_func(obj,name,f)

class Loop:
    """
    Basic loop class

    Layers are chained by passing one layer as the ``iterable`` of the next.
    All layers of a chain share one ``StorageCore`` (``self.core``), and any
    attribute not found on a layer is looked up on that core.
    """
    _loop_layer = 0.
    def __init__(self,iterable = [],name=None):
        """
        iterable: an upstream ``Loop`` (this layer is registered on its core,
            one level deeper) or any other iterable (a new core is created
            with this layer as its first layer)
        name: layer name, defaults to the class name; it becomes an
            attribute of the core, so it must be unique within a chain
        """
        self.iterable  = iterable
        self.next_layer = None
        self.name = name if name is not None else self.__class__.__name__

        if hasattr(iterable,"_loop_layer"):
            self._loop_layer = iterable._loop_layer + 1
            iterable.core.new_layer(self)
        else:
            self.core = StorageCore(self)

    def __call__(self):
        """Per-iteration work of this layer; a no-op by default."""
        pass

    def __repr__(self):
        return f"layer:>>>{self.name}"

    def on(self,iterable):
        """
        Point this layer at ``iterable`` and set its level to one below it.
        Note: this does not register the layer on ``iterable``'s core.
        """
        self.iterable = iterable
        self._loop_layer = iterable._loop_layer + 1

    def func_detail(self,f):
        """
        Describe function ``f``: its name, docstring, local variable names
        (``co_varnames``) and the names referenced by its bytecode
        (``co_names``).
        """
        detail = dict({"🐍func_name":f.__name__,
                       "⛰doc":f.__doc__,
                       "😝var":",".join(f.__code__.co_varnames),
                       "😜names":",".join(f.__code__.co_names),
                      })
        return detail

    def summary(self):
        """
        Print every layer of the chain and the ``forall`` functions visible
        on it. A function shared by several layers is printed only the first
        time it appears.
        """
        rt = f"Loop layer {self.name} summary:\n"
        rt+= "="*50+"\n"
        # Keep the method objects themselves rather than their id(): bound
        # methods are re-created on each attribute access, so ids of
        # discarded ones can be reused and never match the propagated
        # copies. Bound-method equality compares __self__ and __func__.
        seen = []
        for idx,layer in self.core.lmap.items():
            rt+= f"🍰layer{idx}\t{str(layer)}\n"
            for f in self.core.for_all_functions(layer).values():
                if f in seen:
                    continue
                rt+="\t"
                rt+="\n\t".join(f"[{k}]\t{v}" for k,v in self.func_detail(f).items())
                rt+="\n\t"+"."*35+"\n"
                seen.append(f)
            rt+="-"*50+"\n"
        rt+= "="*50+"\n"
        print(rt)

    def __len__(self):
        """Length of the wrapped iterable."""
        return self.iterable.__len__()

    def run(self,):
        """
        Run through iteration
        run start_call for every layer on the start of iteration
            run __call__ for every layer for each iteration
        run end_call for every layer on the end of iteration
        """
        first = self.layers[0]
        self.refresh_i()
        first.start_callon()
        # Iterating ``first`` already runs ``callon`` (its ``__call__`` and
        # the downstream callbacks) once per element; calling ``first()``
        # here as well would run the first layer twice per element.
        for _ in first:
            pass
        first.end_callon()

    def refresh_i(self):
        """Reset the iteration index and advance the epoch counter."""
        self.core.i=-1
        self.core.epoch+=1

    def update_i(self):
        """Advance the shared iteration index by one."""
        self.core.i+=1

    def callon(self):
        """Run this layer's ``__call__``, then the downstream layers'."""
        self()
        self.iter_cb()

    def start_callon(self):
        """Run this layer's ``start_call``, then the downstream layers'."""
        self.start_call()
        self.start_cb()

    def end_callon(self):
        """Run this layer's ``end_call``, then the downstream layers'."""
        self.end_call()
        self.end_cb()

    def iter_cb(self):
        """
        call back during each iteration
        """
        if self.next_layer is not None:
            self.next_layer.callon()

    def start_cb(self):
        """
        callback at the start of iteration
        """
        if self.next_layer is not None:
            self.next_layer.start_callon()

    def end_cb(self):
        """
        callback at the end of iteration
        """
        if self.next_layer is not None:
            self.next_layer.end_callon()

    def start_call(self):
        """Hook run once at the start of an iteration; no-op by default."""
        pass

    def end_call(self):
        """Hook run once at the end of an iteration; no-op by default."""
        pass

    def __iter__(self,):
        """
        Iterate the wrapped iterable. For each element, the outermost layer
        advances the index, the element is stored on the core, and
        ``callon`` runs before the element is yielded.
        """
        for element in self.iterable:
            if self._loop_layer ==0:
                self.update_i()
            self.core.element = element
            self.callon()
            yield self.element

    def __getattr__(self,k):
        """Fall back to the shared core for attributes missing on the layer."""
        # Without this guard an instance lacking ``core`` (e.g. while being
        # copied or unpickled) recurses forever through ``self.core``.
        if k == "core":
            raise AttributeError(k)
        return getattr(self.core,k)

    def is_newloop(self):
        """
        return Bool:Is this a new loop ready to start
        """
        return (self.i==-1 or self.i==self.__len__()-1)

# Cell
class ProgressBar(Loop):
    """
    Layer that drives a tqdm progress bar.

    The bar is created in ``start_call``, advanced by one on every
    iteration, and closed in ``end_call``.

    iterable: upstream layer or iterable; its length is used as the bar
        total when it can be measured
    jupyter: use the ``tqdm.notebook`` widget if True, console tqdm otherwise
    mininterval: minimum display refresh interval passed to tqdm
    """
    def __init__(self,iterable=[],jupyter = True,mininterval = 1e-1):
        # NOTE: the layer name keeps its historical spelling because it is
        # also the attribute name under which the layer is stored on the core.
        super().__init__(iterable,"Progressb Bar")

        if jupyter: # jupyter widget
            from tqdm.notebook import tqdm

        else: # compatible for console print
            from tqdm import tqdm

        self.tqdm = tqdm
        self.mininterval = mininterval
        self.data = dict()

    @method4all
    def pgbar_data(self,data):
        """
        update progress bar with python dictionary
        data: python dictionary
        """
        self.t.set_postfix(data)

    @method4all
    def pgbar_description(self,text):
        """
        update progress bar prefix with text string
        """
        self.t.set_description_str(f"{text}")

    def start_call(self):
        """Open a fresh bar at the start of each run."""
        self.create_bar()

    def end_call(self):
        """Close the bar at the end of each run."""
        self.t.close()

    def __call__(self):
        """Advance the bar by one step per iteration."""
        self.t.update(1)

    def create_bar(self):
        """
        Create the tqdm bar.

        The total is the length of the iterable when it can be measured.
        Otherwise it is None, and tqdm shows only basic progress statistics
        instead of the layer raising.
        """
        try:
            total = len(self.iterable)
        except (TypeError, AttributeError):
            # TypeError: an object without __len__ (e.g. a generator).
            # AttributeError: a Loop whose own iterable has no __len__.
            total = None
        self.t = self.tqdm(total=total,
                           mininterval=self.mininterval)

# Cell
class Tolerate(Loop):
    """
    Tolerate any error raised downstream of this layer.

    Usage::

        layer2 = Tolerate(upstream_iterable)
        # build downstream tasks
        layer3 = OtherApplication(layer2)
        layer3.run()
        # show the recorded error messages
        layer3.error_list()

    Every caught exception is stored as a dict with the keys ``stage``
    ("start", "middle" or "end"), ``i``, ``epoch`` and ``error``.
    """
    def __init__(self,iterable = []):
        super().__init__(iterable)
        self.errors = []

    @method4all
    def error_list(self):
        """
        Return the (live) list of errors recorded so far.
        """
        return self.errors

    def _record(self, stage, error):
        # ``i`` and ``epoch`` are read from the shared core at failure time
        self.errors.append(dict(stage=stage, i=self.i, epoch=self.epoch, error=error))

    def callon(self):
        """
        Per-iteration hook: run only the downstream callbacks (this layer's
        own ``__call__`` is skipped). Any exception they raise is recorded
        instead of propagated.
        """
        try:
            self.iter_cb()
        except Exception as err:
            self._record("middle", err)

    def start_callon(self):
        """
        Run only the downstream start callbacks, recording any exception.
        """
        try:
            self.start_cb()
        except Exception as err:
            self._record("start", err)

    def end_callon(self):
        """
        Run only the downstream end callbacks, recording any exception.
        """
        try:
            self.end_cb()
        except Exception as err:
            self._record("end", err)

class LambdaCall(Loop):
    """
    Layer that runs ``func(layer)`` on every iteration.

    The layer name embeds the object's id, so several LambdaCall layers
    can join the same chain without a layer-name clash.
    """
    def __init__(self,iterable = [],func = lambda x:x):
        label = f"Lambda<{hex(id(self))}>"
        super().__init__(iterable, name=label)
        self.func = func

    def __call__(self):
        # the value returned by ``func`` is discarded
        callback = self.func
        callback(self)

# Cell
class Event(Loop):
    """
    An event is the landmark with in 1 iteration

    Callbacks run, in registration order, each time the event layer is
    called; each one is called as ``cb(event)``. Two helpers are attached
    to the instance and flagged with ``method4all``, so layers added after
    this event can use them too:

    * ``on_<event_name>()``  - run all callbacks now
    * ``set_<event_name>(f)`` - register ``f`` as a callback (decorator)

    iterable: upstream layer or iterable
    event_name: event name, also used as the layer name
    cbs: optional initial list of callbacks
    """
    def __init__(self,iterable=[],event_name = None,cbs = None):
        super().__init__(iterable,event_name)
        # A fresh list per instance: a shared ``[]`` default would leak
        # callbacks registered on one event into every other event.
        self.cbs = [] if cbs is None else cbs

        def call(self):
            return self.__call__()
        call.__name__ = f"on_{event_name}"
        call.__doc__ = f"""
            Excute callback for event:{event_name}
        """
        setattr(self,call.__name__,MethodType(method4all(call),self))

        def set_(self,f):
            return self.on(f)
        set_.__name__ = f"set_{event_name}"
        set_.__doc__ = f"""
            Append new callback for event:{event_name}
            Use this function as decorator
        """
        setattr(self,set_.__name__,MethodType(method4all(set_),self))

        # Registration inside ``super().__init__`` collected this layer's
        # ``forall`` members before the two helpers above existed; collect
        # again so that layers added after this event receive them.
        self.core.update_forall(self)

    def __call__(self):
        """Run every registered callback with this event as the argument."""
        for cb in self.cbs:
            cb(self)

    def on(self,f):
        """
        Register ``f`` as a callback and return the stored wrapper, so the
        method can be used as a decorator.
        """
        def wrapper(self):
            return f(self)
        self.cbs.append(wrapper)
        return wrapper