Solution 1: Using a Stack

Intuition:

To find the kth smallest element in a binary search tree (BST), we can perform an in-order traversal (left subtree, then the node, then the right subtree). Because every value in a node's left subtree is smaller than the node's value and every value in its right subtree is larger, an in-order traversal of a BST visits the nodes in ascending order. By counting nodes as they are visited, we can stop as soon as we reach the kth node, which is the kth smallest element in the BST.
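The property this approach relies on can be checked directly. The sketch below is illustrative and not part of the original solutions: `insert` and `inorder` are hypothetical helper names, and `insert` assumes equal values go to the right subtree. It builds the same tree as the second test case used in the solutions that follow, then confirms that the in-order sequence comes out sorted. Note that it materializes the whole traversal, costing O(n) time and space, which is exactly what the counting approach avoids by stopping at k.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Hypothetical helper: insert val into the BST and return its root.
    Smaller values go left; equal or larger values go right."""
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def inorder(node: Optional[TreeNode]) -> List[int]:
    """Hypothetical helper: left subtree, then the node, then the right subtree."""
    if node is None:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)


# Insertion order chosen to reproduce the second test tree (root 5)
root = None
for v in [5, 3, 6, 2, 4, 1]:
    root = insert(root, v)

values = inorder(root)
print(values)         # [1, 2, 3, 4, 5, 6]
print(values[3 - 1])  # 3, i.e. the 3rd smallest
```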

Solution:

from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    # Stack of nodes whose left subtrees are still being explored
    stack = []
    # Pointer to the node currently being processed
    current = root
    # Number of nodes visited so far, in ascending order
    count = 0

    # Loop until every node has been visited (or we return early)
    while current is not None or stack:
        # Push the current node and all of its left descendants
        while current is not None:
            stack.append(current)
            current = current.left

        # The top of the stack is the smallest node not yet visited
        current = stack.pop()
        count += 1

        # Stop as soon as the kth node has been visited
        if count == k:
            return current.val

        # Continue with the right subtree, whose values come next in sorted order
        current = current.right

    # Reached only if k is less than 1 or greater than the number of nodes
    raise ValueError("k must be between 1 and the number of nodes in the tree")

# Test cases
root = TreeNode(3)
root.left = TreeNode(1)
root.right = TreeNode(4)
root.left.right = TreeNode(2)
result = kth_smallest(root, 1)
print(result)  # Output should be 1

root = TreeNode(5)
root.left = TreeNode(3)
root.right = TreeNode(6)
root.left.left = TreeNode(2)
root.left.right = TreeNode(4)
root.left.left.left = TreeNode(1)
result = kth_smallest(root, 3)
print(result)  # Output should be 3

Time Complexity:

  • The time complexity of the kth_smallest function is O(H + k), where H is the height of the BST: every node that gets pushed is either popped (at most k times) or still on the stack when the function returns (at most H nodes).
  • In the worst case (a skewed tree), H equals n, the number of nodes in the BST, and since k ≤ n the bound becomes O(n).
  • For a balanced BST, H is O(log(n)), so the time complexity is O(log(n) + k).

Space Complexity:

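  • The space complexity of the algorithm is O(H), where H is the height of the BST, because the stack only ever holds ancestors of the current node, which lie on a single root-to-leaf path.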
  • In the worst case, the space complexity can be O(n), where n is the number of nodes in the BST, if the tree is skewed. However, for a balanced BST, the space complexity is O(log(n)).
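Both bounds can be observed by instrumenting the loop. In this sketch, `kth_smallest_stats`, `left_chain`, and `balanced` are hypothetical helpers written for illustration: the first is the Solution 1 loop with two counters added, returning (value, nodes pushed, peak stack size); the other two build a left-skewed chain and a balanced BST.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_stats(root: Optional[TreeNode], k: int) -> Tuple[int, int, int]:
    """Solution 1 loop, also reporting (value, total pushes, peak stack size)."""
    stack = []
    current = root
    count = 0
    pushes = 0  # total nodes pushed: the work done, bounded by H + k
    peak = 0    # largest stack size seen: the extra space, bounded by H
    while current is not None or stack:
        while current is not None:
            stack.append(current)
            pushes += 1
            peak = max(peak, len(stack))
            current = current.left
        current = stack.pop()
        count += 1
        if count == k:
            return current.val, pushes, peak
        current = current.right
    raise ValueError("k must be between 1 and the number of nodes in the tree")


def left_chain(n: int) -> TreeNode:
    """Skewed BST n -> n-1 -> ... -> 1 where every node has only a left child."""
    root = TreeNode(n)
    node = root
    for v in range(n - 1, 0, -1):
        node.left = TreeNode(v)
        node = node.left
    return root


def balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Balanced BST holding lo..hi, built by choosing the middle value as the root."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, balanced(lo, mid - 1), balanced(mid + 1, hi))


n = 1023
print(kth_smallest_stats(left_chain(n), 1))   # (1, 1023, 1023)
print(kth_smallest_stats(balanced(1, n), 1))  # (1, 10, 10)
print(kth_smallest_stats(balanced(1, n), n))  # (1023, 1023, 10)
```

On the skewed chain, finding the minimum already pushes all 1,023 nodes (the H term dominates). On the balanced tree of height 10 the stack never holds more than 10 nodes, while reaching k = 1023 takes 1,023 pushes (the k term dominates).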

Solution 2: Recursion

Intuition:

Another solution uses recursion: a depth-first search (DFS) that performs the in-order traversal recursively, counts nodes as they are visited, and records the value of the kth smallest node when the count reaches k.
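The same recursive DFS can also be written as a generator, which makes stopping at k fall out naturally: `itertools.islice` pulls only the first k values, and the traversal is simply never resumed afterwards. This is an alternative sketch, not the solution below; `inorder_values` and `kth_smallest_lazy` are hypothetical names, and k is assumed to be at least 1.

```python
from itertools import islice
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_values(node: Optional[TreeNode]) -> Iterator[int]:
    """Recursive in-order DFS as a generator: yields BST values in ascending order, on demand."""
    if node is None:
        return
    yield from inorder_values(node.left)   # every value smaller than node.val
    yield node.val
    yield from inorder_values(node.right)  # every value larger than node.val


def kth_smallest_lazy(root: Optional[TreeNode], k: int) -> int:
    # islice skips the first k - 1 values and takes the kth; after the return below
    # the generator is never resumed, so nothing past the kth value is produced.
    for value in islice(inorder_values(root), k - 1, k):
        return value
    raise ValueError("k is larger than the number of nodes")


# Same shape as the second test tree
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest_lazy(root, 3))  # 3
```

This is mostly a readability choice: each yielded value is relayed up through one nested generator per tree level, so it is not faster than an explicit counter.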

Solution:

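from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    # Number of nodes visited so far, in ascending order
    count = 0
    # Value of the kth smallest node, once found
    result = None

    # Helper function for recursive in-order traversal
    def inorder_traversal(node: Optional[TreeNode]) -> None:
        nonlocal count, result
        # Stop at empty subtrees, and skip all remaining work once the answer is known
        if node is None or result is not None:
            return

        # Traverse left subtree
        inorder_traversal(node.left)
        # The answer may already have been found in the left subtree
        if result is not None:
            return

        # Visit this node: increment count and check if it is the kth smallest
        count += 1
        if count == k:
            result = node.val
            return

        # Traverse right subtree
        inorder_traversal(node.right)

    # Start in-order traversal from the root
    inorder_traversal(root)

    # Reached with no result only if k is less than 1 or greater than the number of nodes
    if result is None:
        raise ValueError("k must be between 1 and the number of nodes in the tree")
    # Return the kth smallest element
    return result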

# Test cases
root = TreeNode(3)
root.left = TreeNode(1)
root.right = TreeNode(4)
root.left.right = TreeNode(2)
result = kth_smallest(root, 1)
print(result)  # Output should be 1

root = TreeNode(5)
root.left = TreeNode(3)
root.right = TreeNode(6)
root.left.left = TreeNode(2)
root.left.right = TreeNode(4)
root.left.left.left = TreeNode(1)
result = kth_smallest(root, 3)
print(result)  # Output should be 3

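The time and space complexity are the same as those of Solution 1: O(H + k) time, provided the traversal stops once the kth node has been visited, and O(H) space, which here is taken up by the recursion call stack instead of an explicit stack.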
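One practical difference between the two solutions is recursion depth. CPython limits the depth of nested Python calls (see `sys.getrecursionlimit()`, 1000 by default), and Solution 2 recurses once per tree level, so a sufficiently skewed tree raises `RecursionError` where the explicit stack of Solution 1 keeps working. The sketch below demonstrates this under a default CPython configuration; `kth_recursive` and `kth_iterative` are hypothetical, condensed copies of the two solutions.

```python
import sys
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_recursive(root: Optional[TreeNode], k: int) -> Optional[int]:
    """Condensed Solution 2: recursive in-order traversal with early exit."""
    count = 0
    result = None

    def visit(node: Optional[TreeNode]) -> None:
        nonlocal count, result
        if node is None or result is not None:
            return
        visit(node.left)
        if result is not None:
            return
        count += 1
        if count == k:
            result = node.val
            return
        visit(node.right)

    visit(root)
    return result


def kth_iterative(root: Optional[TreeNode], k: int) -> Optional[int]:
    """Condensed Solution 1: in-order traversal with an explicit stack."""
    stack, current, count = [], root, 0
    while current is not None or stack:
        while current is not None:
            stack.append(current)
            current = current.left
        current = stack.pop()
        count += 1
        if count == k:
            return current.val
        current = current.right
    return None


# Left-skewed chain n -> n-1 -> ... -> 1, a little deeper than the current recursion limit
n = sys.getrecursionlimit() + 100
root = TreeNode(n)
node = root
for v in range(n - 1, 0, -1):
    node.left = TreeNode(v)
    node = node.left

print(kth_iterative(root, 1))  # 1

hit_limit = False
try:
    kth_recursive(root, 1)
except RecursionError:
    hit_limit = True
print(hit_limit)  # True
```

Raising the limit with `sys.setrecursionlimit` only postpones the problem, and a limit set too high can crash the interpreter instead; the iterative version avoids the issue entirely.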