test: Add basic unit tests for Main class

Co-authored-by: aider (ollama/qwen2.5-coder:14b-instruct) <aider@aider.chat>
This commit is contained in:
mike
2025-12-27 02:54:18 +01:00
parent 5016bb1974
commit de7e34d594
2 changed files with 101 additions and 8 deletions

View File

@@ -62,7 +62,7 @@ public class Main {
public static void main(String[] args) { public static void main(String[] args) {
var opts = parseArgs(args); var opts = parseArgs(args);
var res = SwedishGenerator.generatePuzzle(opts); var res = generatePuzzle(opts);
if (res == null) { if (res == null) {
System.out.println("No solution found within tries."); System.out.println("No solution found within tries.");
System.exit(1); System.exit(1);
@@ -117,6 +117,70 @@ public class Main {
} }
} }
// Package-private method for testing
static PuzzleResult generatePuzzle(Opts opts) {
var llmScores = loadScores();
var tLoad0 = System.nanoTime();
var dict = loadWords(opts.wordsPath, llmScores);
var tLoad1 = System.nanoTime();
System.out.printf(Locale.ROOT, "LOAD_WORDS: %.3fs%n %s words%n", (tLoad1 - tLoad0) / 1e9, dict.words.size());
if (opts.threads > 1) {
System.out.println("Running in multi-threaded mode with " + opts.threads + " threads...");
var executor = Executors.newFixedThreadPool(opts.threads);
try {
var tasks = new ArrayList<Callable<PuzzleResult>>();
for (int i = 1; i <= opts.tries; i++) {
final int attempt = i;
tasks.add(() -> {
var threadRng = new Rng(opts.seed + attempt);
var mask = generateMask(threadRng, dict.lenCounts, opts.pop, opts.gens, false);
var filled = fillMask(threadRng, mask, dict.index, llmScores, 200, 60000, false);
if (filled.ok && (opts.minSimplicity <= 0 || filled.simplicity >= opts.minSimplicity)) {
System.out.println("\nSolution found on attempt " + attempt);
return new PuzzleResult(mask, filled);
}
throw new RuntimeException("No solution found in attempt " + attempt);
});
}
return executor.invokeAny(tasks);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (ExecutionException e) {
// all failed
} finally {
executor.shutdownNow();
}
return null;
} else {
var rng = new Rng(opts.seed);
for (var attempt = 1; attempt <= opts.tries; attempt++) {
System.out.println("\nAttempt " + attempt + "/" + opts.tries);
var tMask0 = System.nanoTime();
var mask = generateMask(rng, dict.lenCounts, opts.pop, opts.gens, true);
var tMask1 = System.nanoTime();
System.out.printf(Locale.ROOT, "MASK: %.3fs%n", (tMask1 - tMask0) / 1e9);
var tFill0 = System.nanoTime();
var filled = fillMask(rng, mask, dict.index, llmScores, 200, 60000, true);
var tFill1 = System.nanoTime();
System.out.printf(Locale.ROOT, "FILL: %.3fms | Simplicity: %.2f%n", (tFill1 - tFill0) / 1e6, filled.simplicity);
if (filled.ok && (opts.minSimplicity <= 0 || filled.simplicity >= opts.minSimplicity)) {
return new PuzzleResult(mask, filled);
}
if (filled.ok) {
System.out.printf(Locale.ROOT, "Puzzle simplicity %.2f is below min %.2f, retrying...%n",
filled.simplicity, opts.minSimplicity);
}
}
}
return null;
}
private static String toJson(ExportFormat.ExportedPuzzle puzzle, String date, String theme) { private static String toJson(ExportFormat.ExportedPuzzle puzzle, String date, String theme) {
var sb = new StringBuilder(); var sb = new StringBuilder();
sb.append("{\n"); sb.append("{\n");
@@ -189,12 +253,12 @@ public class Main {
if (content.isEmpty() || content.equals("[]")) { if (content.isEmpty() || content.equals("[]")) {
content = "[\n " + newRecordJson + "\n]"; content = "[\n " + newRecordJson + "\n]";
} else { } else {
int firstBracket = content.indexOf('['); int firstBracket = content.indexOf('[');
if (firstBracket != -1) { if (firstBracket != -1) {
content = content.substring(0, firstBracket + 1) + "\n " + newRecordJson + "," + content.substring(firstBracket + 1); content = content.substring(0, firstBracket + 1) + "\n " + newRecordJson + "," + content.substring(firstBracket + 1);
} else { } else {
content = "[\n " + newRecordJson + "\n]"; content = "[\n " + newRecordJson + "\n]";
} }
} }
Files.writeString(indexPath, content, StandardCharsets.UTF_8); Files.writeString(indexPath, content, StandardCharsets.UTF_8);
System.out.println("Updated index.json at: " + indexPath); System.out.println("Updated index.json at: " + indexPath);

View File

@@ -0,0 +1,29 @@
package puzzle;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class MainTest {
@Test
public void testGeneratePuzzle() {
// Arrange
var opts = new Main.Opts();
opts.seed = 1234;
opts.pop = 18;
opts.gens = 300;
opts.wordsPath = "src/test/resources/puzzle/pool.txt";
opts.minSimplicity = 0;
opts.threads = 1;
opts.tries = 1;
// Act
var result = Main.generatePuzzle(opts);
// Assert
assertNotNull(result);
assertNotNull(result.mask());
assertNotNull(result.filled());
assertTrue(result.filled().ok);
}
}