from datetime import datetime
from typing import Optional


import rayleaf
import rayleaf.entities as entities
from rayleaf.entities import Server, Client


def fedavg_cnn(
    num_rounds: int = 100,
    eval_every: int = 10,
    num_clients: int = 200,
    clients_per_round: int = 40,
    client_lr: float = 0.05,
    batch_size: int = 64,
    seed: int = 0,
    num_epochs: int = 10,
    gpus_per_client_cluster: float = 1,
    num_client_clusters: int = 8,
    save_model: bool = False,
    use_grads: bool = False,
    notes: Optional[str] = None
):
    """Launch a FedAvg experiment with the ``cnn`` model on the FEMNIST dataset.

    Data is read from ``../data/femnist/`` and results go to
    ``output/fedavg_cnn-<YYYY-mm-dd_HH:MM:SS>/``. The validation set is disabled
    (``use_val_set=False``). All other arguments are forwarded unchanged to
    ``rayleaf.run_experiment``.

    Args:
        num_rounds: Forwarded as ``num_rounds``.
        eval_every: Forwarded as ``eval_every``.
        num_clients: How many clients of the chosen client class are requested
            (via ``client_types=[(ClientClass, num_clients)]``).
        clients_per_round: Forwarded as ``clients_per_round``.
        client_lr: Forwarded as ``client_lr``.
        batch_size: Forwarded as ``batch_size``.
        seed: Forwarded as ``seed``.
        num_epochs: Forwarded as ``num_epochs``.
        gpus_per_client_cluster: Forwarded as ``gpus_per_client_cluster``.
        num_client_clusters: Forwarded as ``num_client_clusters``.
        save_model: Forwarded as ``save_model``.
        use_grads: If True, use custom server/client classes. The clients train
            with ``train_model(compute_grads=True)`` and return ``self.grads``.
            The server adds the sample-weighted average of those updates to the
            current parameters. If False, the stock rayleaf ``Server`` and
            ``Client`` classes are used as-is.
        notes: Optional free-form notes forwarded as ``notes``.

    Returns:
        None. Any value returned by ``rayleaf.run_experiment`` is discarded.
    """
    # NOTE: the timestamp contains ':' which is not a legal path character on
    # Windows. The format is kept unchanged so output directory names stay stable.
    curr_time = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

    if use_grads:
        class FedAvgClient(Client):
            def train(self):
                # Train locally with gradient computation enabled, then return
                # self.grads as this client's update (presumably populated by
                # train_model when compute_grads=True).
                self.train_model(compute_grads=True)

                return self.grads

        class FedAvgServer(Server):
            def update_layer(self, current_params, updates: list, client_num_samples: list, num_clients: int):
                # Sample-weighted sum of the first ``num_clients`` updates.
                # sum() starts from 0, matching the original accumulation order.
                weighted_sum = sum(
                    updates[i] * client_num_samples[i] for i in range(num_clients)
                )
                # NOTE(review): assumes self.num_train_samples is the total
                # sample count of this round's clients (i.e. sum of
                # client_num_samples) -- confirm against rayleaf.Server.
                average_grads = weighted_sum / self.num_train_samples

                return current_params + average_grads
    else:
        FedAvgServer, FedAvgClient = Server, Client

    rayleaf.run_experiment(
        dataset = "femnist",
        dataset_dir = "../data/femnist/",
        output_dir= f"output/fedavg_cnn-{curr_time}/",
        model = "cnn",
        num_rounds = num_rounds,
        eval_every = eval_every,
        ServerType=FedAvgServer,
        client_types=[(FedAvgClient, num_clients)],
        clients_per_round = clients_per_round,
        client_lr = client_lr,
        batch_size = batch_size,
        seed = seed,
        use_val_set = False,
        num_epochs = num_epochs,
        gpus_per_client_cluster = gpus_per_client_cluster,
        num_client_clusters = num_client_clusters,
        save_model = save_model,
        notes = notes
    )
