from typing import List, Set, FrozenSet, Tuple

from ImportData import import_data

# Puzzle day number, passed to import_data to fetch the real input.
day: int = 17

# Worked example from the puzzle statement: the initial z = 0 slice.
sample_data: List[str] = """\
.#.
..#
###""".splitlines()

# Both parsers return a frozen snapshot of the cubes they created.
data_structure: type = FrozenSet['Cube']


class Cube:
    """One cell of the 3-D Conway Cubes simulation.

    Every cube of the running simulation is kept in the class-level ``cubes``
    set, which ``initialize`` replaces. A generation is computed in two phases:
    ``prep_change`` stores each cube's next state in ``next_state`` (reading only
    the current ``active`` flags), then ``commit_changes`` copies ``next_state``
    into ``active`` for all cubes at once.
    """

    # All cubes of the current simulation; ``initialize`` assigns a fresh set.
    cubes: Set["Cube"] = set()

    def __init__(self, x: int, y: int, z: int, active: bool = False):
        # A plain cube has no fourth coordinate; the attribute is only a placeholder.
        self.w = None
        self.x = x
        self.y = y
        self.z = z
        self.active = active
        self.next_state = active

    @classmethod
    def initialize(cls, starter_data: List[str]) -> FrozenSet["Cube"]:
        """Reset the simulation to the 2-D starting slice at z = 0.

        Text row ``i`` is stored at ``y = len(starter_data) - 1 - i``, so the first
        row gets the largest y. This matches ``print_cubes``, which prints from
        the largest y downwards. A '#' marks an active cube; any other character
        marks an inactive one.

        Returns:
            A frozen snapshot of the newly created cubes.
        """
        cls.cubes = set()
        top_y: int = len(starter_data) - 1
        for row_index, row in enumerate(starter_data):
            for x, char in enumerate(row):
                cls.cubes.add(Cube(x, top_y - row_index, 0, char == '#'))
        return frozenset(cls.cubes)

    @classmethod
    def get_cube(cls, x: int, y: int, z: int, w: int = 0) -> "Cube":
        """Return the cube at (x, y, z), creating and tracking an inactive one if absent.

        ``w`` is accepted only so the signature matches the 4-D subclass; it is ignored.

        Raises:
            RuntimeError: if more than one tracked cube has these coordinates.
        """
        matches: Set["Cube"] = {cube for cube in cls.cubes if (cube.x, cube.y, cube.z) == (x, y, z)}
        if not matches:
            cube: "Cube" = Cube(x, y, z)
            cls.cubes.add(cube)
            return cube
        if len(matches) > 1:
            raise RuntimeError(f'Found {len(matches)} cubes with coordinates ({x}, {y}, {z}).')
        return matches.pop()

    @classmethod
    def absent(cls, x: int, y: int, z: int, w: int = 0) -> bool:
        """Return True when no tracked cube sits at (x, y, z); ``w`` is ignored."""
        return not any((cube.x, cube.y, cube.z) == (x, y, z) for cube in cls.cubes)

    @classmethod
    def get_bounds(cls) -> Tuple[int, ...]:
        """Return the bounding box of all tracked cubes.

        The order is (max_x, max_y, max_z, min_x, min_y, min_z); callers,
        including ``Hypercube.get_bounds``, unpack it positionally.

        Raises:
            ValueError: if no cubes are tracked.
        """
        if not cls.cubes:
            raise ValueError('Cannot compute the bounds of an empty space.')
        xs = [cube.x for cube in cls.cubes]
        ys = [cube.y for cube in cls.cubes]
        zs = [cube.z for cube in cls.cubes]
        return max(xs), max(ys), max(zs), min(xs), min(ys), min(zs)

    @classmethod
    def expand_space(cls) -> FrozenSet["Cube"]:
        """Fill the bounding box, grown by one cell on every side, with inactive cubes.

        Afterwards every neighbour of every cube tracked before the call is
        itself tracked, so the next generation can be evaluated. An empty space
        stays empty.

        Returns:
            A frozen snapshot of all tracked cubes.
        """
        if not cls.cubes:
            return frozenset()
        maximum_x, maximum_y, maximum_z, minimum_x, minimum_y, minimum_z = cls.get_bounds()
        # Snapshot the existing coordinates once instead of rescanning every cube per cell.
        occupied = {(cube.x, cube.y, cube.z) for cube in cls.cubes}
        for x in range(minimum_x - 1, maximum_x + 2):
            for y in range(minimum_y - 1, maximum_y + 2):
                for z in range(minimum_z - 1, maximum_z + 2):
                    if (x, y, z) not in occupied:
                        cls.cubes.add(Cube(x, y, z))
        return frozenset(cls.cubes)

    def number_of_active_neighbors(self) -> int:
        """Count active tracked cubes within one step on every axis, excluding this cube."""
        return sum(
            1 for cube in self.cubes
            if cube is not self
            and cube.active
            and abs(cube.x - self.x) <= 1
            and abs(cube.y - self.y) <= 1
            and abs(cube.z - self.z) <= 1
        )

    def activate(self) -> None:
        """Schedule this cube to be active after the next ``commit_changes``."""
        self.next_state = True

    def deactivate(self) -> None:
        """Schedule this cube to be inactive after the next ``commit_changes``."""
        self.next_state = False

    def prep_change(self) -> None:
        """Compute ``next_state`` from the current neighbourhood without touching ``active``.

        An active cube stays active with 2 or 3 active neighbours. An inactive
        cube becomes active with exactly 3.
        """
        active_neighbors: int = self.number_of_active_neighbors()
        if self.active:
            self.next_state = 2 <= active_neighbors <= 3
        else:
            self.next_state = active_neighbors == 3

    @classmethod
    def commit_changes(cls) -> FrozenSet["Cube"]:
        """Apply every tracked cube's ``next_state`` and return a frozen snapshot."""
        for cube in cls.cubes:
            cube.active = cube.next_state
        return frozenset(cls.cubes)

    @classmethod
    def print_cubes(cls) -> None:
        """Print every z-slice of the bounding box, largest y first ('#' active, '.' inactive).

        Cells without a tracked cube print as '.'. Printing never adds cubes
        to the space. Nothing is printed for an empty space.
        """
        if not cls.cubes:
            return
        maximum_x, maximum_y, maximum_z, minimum_x, minimum_y, minimum_z = cls.get_bounds()
        active_cells = {(cube.x, cube.y, cube.z) for cube in cls.cubes if cube.active}
        for z in range(minimum_z, maximum_z + 1):
            print(f'z={z}')
            for y in range(maximum_y, minimum_y - 1, -1):
                print(''.join('#' if (x, y, z) in active_cells else '.'
                              for x in range(minimum_x, maximum_x + 1)))
            print()

    def __repr__(self) -> str:
        return f'({self.x},{self.y},{self.z}): {self.active}'


class Hypercube(Cube):
    """One cell of the 4-D Conway Cubes simulation: a ``Cube`` with a ``w`` coordinate.

    Reuses ``Cube``'s two-phase update (``prep_change`` / ``commit_changes``).
    ``initialize`` assigns ``cubes`` on this class, so the 4-D space is tracked
    separately from the 3-D one.
    """

    def __init__(self, x: int, y: int, z: int, w: int, active: bool = False):
        super().__init__(x, y, z, active)
        self.w = w

    @classmethod
    def initialize(cls, starter_data: List[str]) -> FrozenSet[Cube]:
        """Reset the 4-D simulation to the 2-D starting slice at z = 0, w = 0.

        Text row ``i`` is stored at ``y = len(starter_data) - 1 - i``, so the first
        row gets the largest y. A '#' marks an active hypercube.

        Returns:
            A frozen snapshot of the newly created hypercubes.
        """
        cls.cubes = set()
        top_y: int = len(starter_data) - 1
        for row_index, row in enumerate(starter_data):
            for x, char in enumerate(row):
                cls.cubes.add(Hypercube(x, top_y - row_index, 0, 0, char == '#'))
        return frozenset(cls.cubes)

    @classmethod
    def get_cube(cls, x: int, y: int, z: int, w: int = 0) -> Cube:
        """Return the hypercube at (x, y, z, w), creating and tracking an inactive one if absent.

        Raises:
            RuntimeError: if more than one tracked hypercube has these coordinates.
        """
        matches: Set[Cube] = {cube for cube in cls.cubes
                              if (cube.x, cube.y, cube.z, cube.w) == (x, y, z, w)}
        if not matches:
            cube: "Hypercube" = Hypercube(x, y, z, w)
            cls.cubes.add(cube)
            return cube
        if len(matches) > 1:
            raise RuntimeError(f'Found {len(matches)} cubes with coordinates ({x}, {y}, {z}, {w}).')
        return matches.pop()

    @classmethod
    def absent(cls, x: int, y: int, z: int, w: int = 0) -> bool:
        """Return True when no tracked hypercube sits at (x, y, z, w)."""
        return not any((cube.x, cube.y, cube.z, cube.w) == (x, y, z, w) for cube in cls.cubes)

    @classmethod
    def get_bounds(cls) -> Tuple[int, ...]:
        """Return the 4-D bounding box of all tracked hypercubes.

        The order is (max_x, max_y, max_z, min_x, min_y, min_z, min_w, max_w).
        Note that w is listed min-first, unlike the other axes.

        Raises:
            ValueError: if no hypercubes are tracked.
        """
        maximum_x, maximum_y, maximum_z, minimum_x, minimum_y, minimum_z = super(Hypercube, cls).get_bounds()
        w_values = [cube.w for cube in cls.cubes]
        return maximum_x, maximum_y, maximum_z, minimum_x, minimum_y, minimum_z, min(w_values), max(w_values)

    @classmethod
    def expand_space(cls) -> FrozenSet["Cube"]:
        """Fill the 4-D bounding box, grown by one cell on every side, with inactive hypercubes.

        Afterwards every neighbour of every hypercube tracked before the call is
        itself tracked. An empty space stays empty.

        Returns:
            A frozen snapshot of all tracked hypercubes.
        """
        if not cls.cubes:
            return frozenset()
        maximum_x, maximum_y, maximum_z, minimum_x, minimum_y, minimum_z, minimum_w, maximum_w = cls.get_bounds()
        # Snapshot the existing coordinates once instead of rescanning every cube per cell.
        occupied = {(cube.x, cube.y, cube.z, cube.w) for cube in cls.cubes}
        for x in range(minimum_x - 1, maximum_x + 2):
            for y in range(minimum_y - 1, maximum_y + 2):
                for z in range(minimum_z - 1, maximum_z + 2):
                    for w in range(minimum_w - 1, maximum_w + 2):
                        if (x, y, z, w) not in occupied:
                            cls.cubes.add(Hypercube(x, y, z, w))
        return frozenset(cls.cubes)

    def number_of_active_neighbors(self) -> int:
        """Count active tracked hypercubes within one step on all four axes, excluding this one."""
        return sum(
            1 for cube in self.cubes
            if cube is not self
            and cube.active
            and abs(cube.x - self.x) <= 1
            and abs(cube.y - self.y) <= 1
            and abs(cube.z - self.z) <= 1
            and abs(cube.w - self.w) <= 1
        )

    @classmethod
    def print_cubes(cls) -> None:
        """Print every (z, w) slice of the bounding box with y labels and two x-digit rows.

        The labels show absolute values (the sign is dropped). Cells without a
        tracked hypercube print as '.'. Printing never adds hypercubes to the
        space. Nothing is printed for an empty space.
        """
        if not cls.cubes:
            return
        maximum_x, maximum_y, maximum_z, minimum_x, minimum_y, minimum_z, minimum_w, maximum_w = cls.get_bounds()
        active_cells = {(cube.x, cube.y, cube.z, cube.w) for cube in cls.cubes if cube.active}
        x_values = range(minimum_x, maximum_x + 1)
        tens_row = '    ' + ''.join(str(abs(x) // 10) for x in x_values)
        units_row = '    ' + ''.join(str(abs(x) % 10) for x in x_values)
        for w in range(minimum_w, maximum_w + 1):
            for z in range(minimum_z, maximum_z + 1):
                print(f'z={z}, w={w}')
                for y in range(maximum_y, minimum_y - 1, -1):
                    cells = ''.join('#' if (x, y, z, w) in active_cells else '.' for x in x_values)
                    print(f'{str(abs(y)).rjust(3)} {cells}')
                print(tens_row)
                print(units_row)
                print()
            print()

    def __repr__(self) -> str:
        return f'({self.x},{self.y},{self.z},{self.w}): {self.active}'

    @classmethod
    def remove_inactive_hyperplane(cls, dimension: str, value: int) -> bool:
        """Drop the hyperplane ``<dimension> == value`` if none of its hypercubes is active.

        Args:
            dimension: attribute name of the axis ('x', 'y', 'z' or 'w').
            value: coordinate of the plane on that axis.

        Returns:
            True if the plane was removed (including when it held no hypercubes), else False.
        """
        in_plane = [cube for cube in cls.cubes if getattr(cube, dimension) == value]
        if any(cube.active for cube in in_plane):
            return False
        cls.cubes = cls.cubes - set(in_plane)
        return True

    @classmethod
    def trim_space(cls) -> None:
        """Repeatedly drop boundary hyperplanes that contain no active hypercube.

        Neighbour counting scans every tracked hypercube, so shrinking the space
        keeps each generation cheaper. Stops when no boundary plane can be
        removed, or when the space becomes empty. Without that check,
        ``get_bounds`` would fail once every hypercube is inactive.
        """
        trimming: bool = True
        while trimming and cls.cubes:
            maximum_x, maximum_y, maximum_z, minimum_x, minimum_y, minimum_z, minimum_w, maximum_w = cls.get_bounds()
            boundaries = (
                ('x', minimum_x), ('x', maximum_x),
                ('y', minimum_y), ('y', maximum_y),
                ('z', minimum_z), ('z', maximum_z),
                ('w', minimum_w), ('w', maximum_w),
            )
            trimming = False
            for dimension, value in boundaries:
                # `|=` (not `or`) so every boundary is tried on each pass.
                trimming |= cls.remove_inactive_hyperplane(dimension, value)


def parse_data1(data: List[str]) -> data_structure:
    """Seed the 3-D simulation from the puzzle rows and return the starting snapshot."""
    starting_space = Cube.initialize(data)
    return starting_space


def parse_data2(data: List[str]) -> data_structure:
    """Seed the 4-D simulation from the puzzle rows and return the starting snapshot."""
    starting_space = Hypercube.initialize(data)
    return starting_space


def part1(data: data_structure, cycles: int = 6) -> int:
    """Run the 3-D simulation for ``cycles`` generations and count the active cubes.

    Each generation grows the tracked space by one cell on every side, computes
    every cube's next state, then commits them all at once. The loop advances
    the class-level ``Cube.cubes`` state rather than ``data``, so ``data`` only
    determines the result when ``cycles`` is 0.

    Args:
        data: the snapshot returned by ``parse_data1``.
        cycles: number of generations to simulate (the puzzle uses 6).

    Returns:
        The number of active cubes after the last generation.
    """
    cubes: FrozenSet[Cube] = data
    for _ in range(cycles):
        cubes = Cube.expand_space()
        for cube in cubes:
            cube.prep_change()
        cubes = Cube.commit_changes()
    return sum(1 for cube in cubes if cube.active)


def part2(data: data_structure, cycles: int = 6) -> int:
    """Run the 4-D simulation for ``cycles`` generations and count the active hypercubes.

    Each generation first trims boundary hyperplanes that hold no active
    hypercube, which keeps the scanned space small. It then grows the space by
    one cell on every side, computes every next state, and commits them. The
    loop advances the class-level ``Hypercube.cubes`` state rather than
    ``data``, so ``data`` only determines the result when ``cycles`` is 0.
    Unlike the previous version, this prints nothing, so the caller's output
    contains only the answer.

    Args:
        data: the snapshot returned by ``parse_data2``.
        cycles: number of generations to simulate (the puzzle uses 6).

    Returns:
        The number of active hypercubes after the last generation.
    """
    cubes: FrozenSet[Cube] = data
    for _ in range(cycles):
        Hypercube.trim_space()
        cubes = Hypercube.expand_space()
        for cube in cubes:
            cube.prep_change()
        cubes = Hypercube.commit_changes()
    return sum(1 for cube in cubes if cube.active)


if __name__ == '__main__':
    # Set to False to run against the worked example instead of the real puzzle input.
    production_ready = True
    if production_ready:
        raw_data = import_data(day)
    else:
        raw_data = sample_data
    # Part 1 runs and prints before part 2 starts, as before.
    answer1 = part1(parse_data1(raw_data))
    print(answer1)
    answer2 = part2(parse_data2(raw_data))
    print(answer2)