"""High level Verifiable Delay Function using keccak (sha3)."""
from hashlib import sha3_256
from typing import Union
import time
from statistics import mean
from math import ceil

from .mimc import forward_mimc, reverse_mimc, is_fast
"""
Kevin Froman 2020

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
# Default MiMC round count passed to reverse_mimc / forward_mimc by
# vdf_create and vdf_verify when the caller does not supply `rounds`.
DEFAULT_ROUNDS = 8000


def _sha3_256_hash(data: bytes) -> bytes:
    sha3 = sha3_256()
    sha3.update(data)
    return sha3.digest()


def vdf_create(
        data: bytes,
        rounds: int = DEFAULT_ROUNDS,
        dec=False) -> Union[str, int]:
    """Create a VDF proof for *data*.

    The SHA3-256 digest of *data* is passed through ``reverse_mimc`` for
    *rounds* rounds; the result can be checked with ``vdf_verify``.

    Args:
        data: Input bytes to create the proof for.
        rounds: Number of MiMC rounds; must be greater than 1.
        dec: When true, return the proof as a big-endian unsigned integer
            instead of a hex string.

    Returns:
        The proof as a lowercase hex string, or as an ``int`` if *dec* is
        true. Both forms are accepted by ``vdf_verify``.
    """
    assert rounds > 1
    input_data: bytes = _sha3_256_hash(data)
    proof = reverse_mimc(input_data, rounds)
    if dec:
        return int.from_bytes(proof, "big")
    # The former `.lstrip('\0')` was a no-op: bytes.hex() only emits the
    # characters 0-9a-f, never NUL, so it has been dropped.
    return proof.hex()


def vdf_verify(
        data: bytes,
        test_hash: Union[str, int],
        rounds: int = DEFAULT_ROUNDS) -> bool:
    """Verify data for test_hash generated by vdf_create.

    Args:
        data: The original input bytes the proof was created for.
        test_hash: The proof in either form ``vdf_create`` emits: a hex
            string, or a non-negative ``int`` (``dec=True``).
        rounds: Number of MiMC rounds; must be greater than 1.

    Returns:
        True if ``forward_mimc`` over the decoded proof equals the SHA3-256
        digest of *data* with leading zero bytes stripped. Returns False
        otherwise, including when *test_hash* is not valid hex or is a
        negative integer.
    """
    assert rounds > 1
    # Leading zero bytes are stripped before comparing; presumably
    # forward_mimc returns a minimal-length byte string
    # (NOTE(review): confirm against .mimc).
    should_match = _sha3_256_hash(data).lstrip(b'\0')
    if isinstance(test_hash, int):
        if test_hash < 0:
            # vdf_create's int form comes from an unsigned int.from_bytes, so a
            # negative value is never a valid proof (and to_bytes would raise
            # OverflowError on it).
            return False
        proof = test_hash.to_bytes((test_hash.bit_length() + 7) // 8, "big")
    else:
        try:
            proof = bytes.fromhex(test_hash)
        except ValueError:
            return False
    return forward_mimc(proof, rounds) == should_match


def profile_cpu_speed(seconds=1, samples: int = 20) -> int:
    """Estimate a round count for ``vdf_create`` from timed trial runs.

    Each sample calls ``vdf_create(b't', n)`` for n = 2, 3, 4, ... and
    records the last *n* run once the time elapsed since the sample started
    reaches *seconds*. The result is the ceiling of the mean over all
    recorded samples.

    Pressing Ctrl-C stops sampling early and uses the samples collected so
    far. If none have completed, the KeyboardInterrupt propagates.

    NOTE(review): the elapsed time covers every call in a sample
    (2 + 3 + ... + n rounds), not a single n-round call. The result is
    therefore not directly "rounds one call needs for `seconds`"; confirm
    this is the intended metric.

    Args:
        seconds: Target elapsed time per sample, in seconds.
        samples: Number of samples to average; must be at least 1.

    Returns:
        The rounded-up mean of the recorded round counts.

    Raises:
        ValueError: If *samples* is less than 1.
    """
    if samples < 1:
        raise ValueError("samples must be at least 1")
    results = []
    try:
        for _ in range(samples):
            n = 2
            # perf_counter is monotonic, so system clock changes cannot skew
            # the measurement (time.time() could).
            start = time.perf_counter()
            while True:
                vdf_create(b't', n)
                if time.perf_counter() - start >= seconds:
                    break
                n += 1
            results.append(n)
    except KeyboardInterrupt:
        # Best-effort early stop. With no samples, mean() would raise a
        # confusing StatisticsError, so surface the interrupt instead.
        if not results:
            raise
    return ceil(mean(results))


if __name__ == "__main__":
    from math import isfinite

    print("Calculate how many rounds are needed for X seconds (influenced by system processes): ")
    # float allows fractional targets; profile_cpu_speed compares the value
    # against an elapsed time in seconds.
    seconds = float(input("Seconds: "))
    if not (isfinite(seconds) and seconds > 0):
        # NaN or inf would never satisfy the elapsed-time check in
        # profile_cpu_speed, so it would loop forever.
        raise SystemExit("Seconds must be a positive, finite number.")
    print("Rounds:", profile_cpu_speed(seconds))
