from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import ClassVar, Optional
class NodeABC(ABC):
    """Abstract interface for tree nodes that support iteration."""

    @abstractmethod
    def __iter__(self):
        """Return an iterator; concrete subclasses must override this."""
@dataclass
class Node:
    """Binary-tree node whose children are linked back to it.

    Attributes:
        value: Payload stored at this node.
        left: Left child, or ``None``.
        right: Right child, or ``None``.
        parent: The node this one was passed to as a child, or ``None``.
            It is assigned in ``__post_init__`` as a plain attribute rather
            than declared as a dataclass field. As a result, the generated
            ``__eq__``/``__repr__`` only look at ``value``/``left``/``right``
            and never follow the back-link.
    """

    value: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None

    def __post_init__(self):
        self.parent = None
        # Point each supplied child back at this node.
        for child in (self.left, self.right):
            if child is not None:
                child.parent = self

    def __iter__(self):
        """Return an ``InOrderIterator`` that starts at this node."""
        return InOrderIterator(self)
@dataclass
class InOrderIterator:
    """Iterate a binary tree in order: left subtree, node, right subtree.

    The traversal follows each node's ``parent`` back-link instead of
    keeping an explicit stack. It is confined to the subtree rooted at
    ``root`` and never climbs above ``root``, even if ``root`` has a parent.

    Attributes:
        root: Subtree to traverse; ``None`` yields nothing.
        current: Next node to return, or ``None`` once exhausted.
    """

    root: "Node"

    def __post_init__(self):
        # The first in-order node is the leftmost node of the subtree.
        self.current = self._leftmost(self.root)

    def __iter__(self):
        # The iterator protocol requires iterators to be iterable themselves.
        return self

    def __next__(self):
        node = self.current
        if node is None:
            # Stay exhausted: every call after the end raises StopIteration.
            raise StopIteration
        self.current = self._successor(node)
        return node

    @staticmethod
    def _leftmost(node):
        """Return the leftmost descendant of *node* (``None`` for ``None``)."""
        if node is None:
            return None
        while node.left is not None:
            node = node.left
        return node

    def _successor(self, node):
        """Return the in-order successor of *node* within ``self.root``, or None."""
        if node.right is not None:
            return self._leftmost(node.right)
        # Climb while *node* is a right child. Identity (``is``) is required:
        # the dataclass ``__eq__`` on Node compares subtrees structurally,
        # so two equal-looking siblings would be confused by ``==``.
        while node is not self.root and node is node.parent.right:
            node = node.parent
        # Reaching the subtree root from its right side means we are done.
        return None if node is self.root else node.parent
if __name__ == '__main__':
    root = Node(1, Node(2), Node(3))

    # Step an iterator by hand for the first three nodes...
    walker = iter(root)
    first_three = [next(walker).value for _ in range(3)]
    print(first_three)

    # ...then let a for-loop drive a fresh traversal of the same tree.
    for node in root:
        print(node.value)
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import ClassVar
class NodeABC(ABC):
    """Interface for iterable tree nodes.

    Subclasses cannot be instantiated until they provide ``__iter__``.
    """

    @abstractmethod
    def __iter__(self):
        ...
@dataclass
class Node:
    """Binary-tree node whose children are linked back to it.

    Attributes:
        value: Payload stored at this node.
        left: Left child, or ``None``.
        right: Right child, or ``None``.
        parent: The node this one was passed to as a child, or ``None``.
            It is a plain attribute set in ``__post_init__``, not a
            dataclass field, so the generated ``__eq__``/``__repr__`` do
            not follow the back-link.
    """

    value: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None

    def __post_init__(self):
        self.parent = None
        if self.left is not None:
            self.left.parent = self
        if self.right is not None:
            self.right.parent = self

    def __iter__(self):
        """Yield this subtree's nodes in order.

        Delegates to the module-level ``traverse_in_order``, which is
        resolved at call time. This keeps this redefinition of ``Node``
        iterable, like the earlier definition it replaces.
        """
        return traverse_in_order(self)
def traverse_in_order(root):
    """Yield the nodes of the binary tree rooted at *root* in order.

    The traversal is iterative, using an explicit stack, for two reasons.
    Deep or degenerate (list-shaped) trees do not hit Python's recursion
    limit. Each node is produced in amortized O(1) time, rather than being
    re-yielded once per enclosing ``yield from`` level.

    Args:
        root: Tree root. Any object with ``left``/``right`` attributes works,
            with ``None`` meaning "no child". ``None`` yields nothing.

    Yields:
        Each node: its left subtree first, then the node, then its right
        subtree.
    """
    pending = []  # ancestors whose left subtree is currently being visited
    node = root
    while pending or node is not None:
        # Descend as far left as possible, remembering the path back up.
        while node is not None:
            pending.append(node)
            node = node.left
        node = pending.pop()
        yield node
        # Read ``right`` only after yielding, so the walk stays lazy.
        node = node.right
if __name__ == '__main__':
    root = Node(1, Node(2), Node(3))
    # Print each node's payload, one per line, in in-order sequence.
    values = (node.value for node in traverse_in_order(root))
    for value in values:
        print(value)