diff --git a/src/puzzle/Main.java b/src/puzzle/Main.java index 493f961..30c90af 100644 --- a/src/puzzle/Main.java +++ b/src/puzzle/Main.java @@ -62,7 +62,7 @@ public class Main { public static void main(String[] args) { var opts = parseArgs(args); - var res = SwedishGenerator.generatePuzzle(opts); + var res = generatePuzzle(opts); if (res == null) { System.out.println("No solution found within tries."); System.exit(1); @@ -117,6 +117,70 @@ public class Main { } } + // Package-private method for testing + static PuzzleResult generatePuzzle(Opts opts) { + var llmScores = loadScores(); + var tLoad0 = System.nanoTime(); + var dict = loadWords(opts.wordsPath, llmScores); + var tLoad1 = System.nanoTime(); + System.out.printf(Locale.ROOT, "LOAD_WORDS: %.3fs%n %s words%n", (tLoad1 - tLoad0) / 1e9, dict.words.size()); + + if (opts.threads > 1) { + System.out.println("Running in multi-threaded mode with " + opts.threads + " threads..."); + var executor = Executors.newFixedThreadPool(opts.threads); + try { + var tasks = new ArrayList>(); + for (int i = 1; i <= opts.tries; i++) { + final int attempt = i; + tasks.add(() -> { + var threadRng = new Rng(opts.seed + attempt); + var mask = generateMask(threadRng, dict.lenCounts, opts.pop, opts.gens, false); + var filled = fillMask(threadRng, mask, dict.index, llmScores, 200, 60000, false); + + if (filled.ok && (opts.minSimplicity <= 0 || filled.simplicity >= opts.minSimplicity)) { + System.out.println("\nSolution found on attempt " + attempt); + return new PuzzleResult(mask, filled); + } + throw new RuntimeException("No solution found in attempt " + attempt); + }); + } + return executor.invokeAny(tasks); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (ExecutionException e) { + // all failed + } finally { + executor.shutdownNow(); + } + return null; + } else { + var rng = new Rng(opts.seed); + for (var attempt = 1; attempt <= opts.tries; attempt++) { + System.out.println("\nAttempt " + attempt + "/" + opts.tries); + + var tMask0 = System.nanoTime(); + var mask = generateMask(rng, dict.lenCounts, opts.pop, opts.gens, true); + var tMask1 = System.nanoTime(); + System.out.printf(Locale.ROOT, "MASK: %.3fs%n", (tMask1 - tMask0) / 1e9); + + var tFill0 = System.nanoTime(); + var filled = fillMask(rng, mask, dict.index, llmScores, 200, 60000, true); + var tFill1 = System.nanoTime(); + System.out.printf(Locale.ROOT, "FILL: %.3fms | Simplicity: %.2f%n", (tFill1 - tFill0) / 1e6, filled.simplicity); + + if (filled.ok && (opts.minSimplicity <= 0 || filled.simplicity >= opts.minSimplicity)) { + return new PuzzleResult(mask, filled); + } + if (filled.ok) { + System.out.printf(Locale.ROOT, "Puzzle simplicity %.2f is below min %.2f, retrying...%n", + filled.simplicity, opts.minSimplicity); + } + } + } + + return null; + } + private static String toJson(ExportFormat.ExportedPuzzle puzzle, String date, String theme) { var sb = new StringBuilder(); sb.append("{\n"); @@ -189,12 +253,12 @@ public class Main { if (content.isEmpty() || content.equals("[]")) { content = "[\n " + newRecordJson + "\n]"; } else { - int firstBracket = content.indexOf('['); - if (firstBracket != -1) { - content = content.substring(0, firstBracket + 1) + "\n " + newRecordJson + "," + content.substring(firstBracket + 1); - } else { - content = "[\n " + newRecordJson + "\n]"; - } + int firstBracket = content.indexOf('['); + if (firstBracket != -1) { + content = content.substring(0, firstBracket + 1) + "\n " + newRecordJson + "," + content.substring(firstBracket + 1); + } else { + content = "[\n " + newRecordJson + "\n]"; + } } Files.writeString(indexPath, content, StandardCharsets.UTF_8); System.out.println("Updated index.json at: " + indexPath); @@ -202,4 +266,4 @@ public class Main { System.err.println("Failed to update index.json: " + e.getMessage()); } } -} \ No newline at end of file +} diff --git a/src/test/java/puzzle/MainTest.java b/src/test/java/puzzle/MainTest.java new file mode 100644 index 0000000..b7eb2d6 --- /dev/null +++ b/src/test/java/puzzle/MainTest.java @@ -0,0 +1,29 @@ +package puzzle; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +public class MainTest { + + @Test + public void testGeneratePuzzle() { + // Arrange + var opts = new Main.Opts(); + opts.seed = 1234; + opts.pop = 18; + opts.gens = 300; + opts.wordsPath = "src/test/resources/puzzle/pool.txt"; + opts.minSimplicity = 0; + opts.threads = 1; + opts.tries = 1; + + // Act + var result = Main.generatePuzzle(opts); + + // Assert + assertNotNull(result); + assertNotNull(result.mask()); + assertNotNull(result.filled()); + assertTrue(result.filled().ok); + } +}