fix: fixed retrain seg fault

This commit is contained in:
Lilian1024
2024-12-04 08:28:31 +01:00
parent 92534f9e4d
commit f74db484dd
2 changed files with 10 additions and 7 deletions

View File

@ -116,19 +116,19 @@ int main(int argc, char *argv[])
network_train(&network, data_dir, network_path, batch_pourcent, iterations, warmup, warmup_iterations, learning_rate, AdaFactor);
}
else if (strcmp(action, "retrain") == 0) // retrain network: ./network <network.csv> <data directory> <batch pourcent> <iterations>
else if (strcmp(action, "retrain") == 0) // retrain network: ./network retrain <network.csv> <data directory> <batch pourcent> <iterations>
{
if (argc < 4)
errx(EXIT_FAILURE, "missing arguments, usage: ./network <network.csv> <data directory> <batch pourcent> <iterations> [learning_rate]");
errx(EXIT_FAILURE, "missing arguments, usage: ./network retrain <network.csv> <data directory> <batch pourcent> <iterations> [learning_rate]");
double batch_pourcent = atof(argv[3]);
size_t iterations = (size_t)atoi(argv[4]);
double batch_pourcent = atof(argv[4]);
size_t iterations = (size_t)atoi(argv[5]);
double learning_rate = 0.1;
if (argc > 4)
learning_rate = atof(argv[5]);
if (argc > 5)
learning_rate = atof(argv[6]);
network_retrain(argv[1], argv[2], batch_pourcent, iterations, learning_rate);
network_retrain(combine_path(network_application_directory, argv[2]), combine_path(network_application_directory, argv[3]), batch_pourcent, iterations, learning_rate);
}
else if (strcmp(action, "use") == 0) // use network: ./network use <network.csv> <image path>
{

View File

@ -388,6 +388,9 @@ char* read_file(const char* file)
{
FILE * stream = fopen( file, "r" );
if (stream == NULL)
errx(EXIT_FAILURE, "read_file cannot open file: %s", file);
size_t str_len = 0;
char* buff = calloc(1, sizeof(char));