diff --git a/Makefile b/Makefile index c29a4c8..0d575a3 100644 --- a/Makefile +++ b/Makefile @@ -26,3 +26,5 @@ run_test_dfs: run_test_maze_gen: PYTHONPATH=src uv run pytest tests/test_MazeGenerator.py +run_test: + uv run pytest diff --git a/src/amaz_lib/MazeSolver.py b/src/amaz_lib/MazeSolver.py index 2efe587..2022c7f 100644 --- a/src/amaz_lib/MazeSolver.py +++ b/src/amaz_lib/MazeSolver.py @@ -1,7 +1,134 @@ from abc import ABC, abstractmethod from .Maze import Maze +import numpy as np class MazeSolver(ABC): + def __init__(self, start: tuple[int, int], end: tuple[int, int]) -> None: + self.start = (start[0] - 1, start[1] - 1) + self.end = (end[0] - 1, end[1] - 1) + @abstractmethod def solve(self, maze: Maze) -> str: ... + + +class AStar(MazeSolver): + + def __init__(self, start: tuple[int, int], end: tuple[int, int]) -> None: + super().__init__(start, end) + + def f(self, n): + def g(n: tuple[int, int]) -> int: + res = 0 + if n[0] < self.start[0]: + res += self.start[0] - n[0] + else: + res += n[0] - self.start[0] + if n[1] < self.start[1]: + res += self.start[1] - n[1] + else: + res += n[1] - self.start[1] + return res + + def h(n: tuple[int, int]) -> int: + res = 0 + if n[0] < self.end[0]: + res += self.end[0] - n[0] + else: + res += n[0] - self.end[0] + if n[1] < self.end[1]: + res += self.end[1] - n[1] + else: + res += n[1] - self.end[1] + return res + + try: + return g(n) + h(n) + except Exception: + return 1000 + + def best_path( + self, maze: np.ndarray, actual: tuple[int, int] + ) -> dict[str, int | None]: + print(actual) + path = { + "N": ( + self.f((actual[0], actual[1] - 1)) + if not maze[actual[0]][actual[1]].get_north() and actual[1] > 0 + else None + ), + "E": ( + self.f((actual[0] + 1, actual[1])) + if not maze[actual[0]][actual[1]].get_est() + and actual[0] < len(maze) - 1 + else None + ), + "S": ( + self.f((actual[0], actual[1] + 1)) + if not maze[actual[0]][actual[1]].get_south() + and actual[1] < len(maze[0]) - 1 + else None + ), + "W": ( + self.f((actual[0] - 1, actual[1])) + if not maze[actual[0]][actual[1]].get_west() and actual[0] > 0 + else None + ), + } + return { + k: v for k, v in sorted(path.items(), key=lambda item: item[0]) + } + + def get_opposit(self, dir: str) -> str: + match dir: + case "N": + return "S" + case "E": + return "W" + case "S": + return "N" + case "W": + return "E" + case _: + return "" + + def get_next_pos( + self, dir: str, actual: tuple[int, int] + ) -> tuple[int, int]: + match dir: + case "N": + return (actual[0], actual[1] - 1) + case "E": + return (actual[0] + 1, actual[1]) + case "S": + return (actual[0], actual[1] + 1) + case "W": + return (actual[0] - 1, actual[1]) + case _: + return actual + + def get_path( + self, actual: tuple[int, int], maze: np.ndarray, pre: str | None + ) -> str | None: + if actual == self.end: + return "" + paths = self.best_path(maze, actual) + for path in paths: + if paths[path] is None: + continue + if path != pre: + temp = self.get_path( + self.get_next_pos(path, actual), + maze, + self.get_opposit(path), + ) + if not temp is None: + return path + temp + return None + + def solve(self, maze: Maze) -> str: + print(maze) + res = self.get_path(self.start, maze.get_maze(), None) + if res is None: + raise Exception("Path not found") + return res diff --git a/src/amaz_lib/__init__.py b/src/amaz_lib/__init__.py index d772717..fda6b32 100644 --- a/src/amaz_lib/__init__.py +++ b/src/amaz_lib/__init__.py @@ -2,9 +2,9 @@ from .Cell import Cell from .Maze import Maze from .MazeGenerator import MazeGenerator, DepthFirstSearch from .MazeGenerator import Kruskal -from .MazeSolver import MazeSolver +from .MazeSolver import MazeSolver, AStar __version__ = "1.0.0" __author__ = "us" __all__ = ["Cell", "Maze", "MazeGenerator", - "MazeSolver", "DepthFirstSearch", "Kruskal"] + "MazeSolver", "AStar", "Kruskal", "DepthFirstSearch"] diff --git a/tests/test_MazeSolver.py b/tests/test_MazeSolver.py new file mode 100644 index 0000000..b8f2f4c --- /dev/null +++ b/tests/test_MazeSolver.py @@ -0,0 +1,19 @@ +from amaz_lib.Cell import Cell +import numpy as np +from amaz_lib import AStar, Maze, MazeSolver + + +def test_solver() -> None: + maze = Maze( + np.array( + [ + [Cell(value=13), Cell(value=3), Cell(value=11)], + [Cell(value=9), Cell(value=4), Cell(value=6)], + [Cell(value=12), Cell(value=5), Cell(value=7)], + ] + ) + ) + print(maze) + solver = AStar((1, 1), (3, 3)) + res = solver.solve(maze) + assert res == "ESWSEE"