import stainless.proof._
import stainless.lang._
import stainless.collection._

/** Binary search tree over `BigInt` keys with deletion and a BST validity predicate.
  *
  * The ordering invariant checked by [[BinarySearchTree.isBST]] is strict on both
  * sides, so a tree containing two equal keys is not considered a valid BST.
  */
object BinarySearchTree {

  /** Binary tree: either an empty [[Leaf]] or a [[Node]] with two subtrees. */
  sealed abstract class Tree[T]

  /** Inner node carrying `value` between a `left` and a `right` subtree. */
  case class Node[T](left: Tree[T], value: T, right: Tree[T]) extends Tree[T]

  /** Empty tree. */
  case class Leaf[T]() extends Tree[T]

  /** Removes `elem` from `tree`, returning `tree` rebuilt without it.
    *
    * If `elem` does not occur, the search bottoms out at a `Leaf` and the result
    * is structurally equal to the input.
    *
    * @param elem key to remove
    * @param tree tree to remove from; must satisfy [[isBST]]
    * @return the tree without `elem`; the postcondition states it satisfies [[isBST]]
    */
  def bstDelete(elem: BigInt, tree: Tree[BigInt]): Tree[BigInt] = {
    require(isBST(tree))
    tree match {
      case Leaf() =>
        Leaf[BigInt]()

      // Key lies strictly on one side: descend and rebuild this node.
      case Node(l, v, r) if elem < v =>
        Node(bstDelete(elem, l), v, r)
      case Node(l, v, r) if elem > v =>
        Node(l, v, bstDelete(elem, r))

      // From here on the node's value equals `elem`.
      case Node(Leaf(), _, r) =>
        r
      case Node(l, _, Leaf()) =>
        l
      case Node(l, _, r) =>
        // Two non-empty children: the smallest key of the right subtree
        // takes the place of the removed value.
        val (rest, successor) = bstDeleteMin(r)
        Node(l, successor, rest)
    }
  } ensuring { res: Tree[BigInt] => isBST(res) }

  /** Removes the leftmost value of `tree`.
    *
    * @param tree tree to remove from
    * @return a pair of (tree without its leftmost value, the removed value);
    *         for an empty tree the pair is `(Leaf(), 0)`, where `0` is only a
    *         placeholder
    */
  def bstDeleteMin(tree: Tree[BigInt]): (Tree[BigInt], BigInt) = tree match {
    case Leaf() =>
      (Leaf(), 0)
    case Node(Leaf(), smallest, rest) =>
      (rest, smallest)
    case Node(l, v, r) =>
      val (trimmed, smallest) = bstDeleteMin(l)
      (Node(trimmed, v, r), smallest)
  }

  /** Checks the binary-search-tree ordering of `tree`.
    *
    * A node is valid when its value is strictly greater than the maximum of its
    * left subtree, strictly smaller than the minimum of its right subtree, and
    * both subtrees are themselves valid. An empty tree is valid.
    */
  def isBST(tree: Tree[BigInt]): Boolean = tree match {
    case Leaf() =>
      true
    case Node(left, value, right) =>
      val (leftNonEmpty, leftMax) = treeMax(left)
      val (rightNonEmpty, rightMin) = treeMin(right)
      val aboveLeft = !leftNonEmpty || value > leftMax
      val belowRight = !rightNonEmpty || value < rightMin
      aboveLeft && belowRight && isBST(left) && isBST(right)
  }

  /** Largest value stored anywhere in `tree`.
    *
    * Both subtrees are inspected, so the result does not rely on `tree` being
    * ordered.
    *
    * @return `(true, max)` for a non-empty tree, `(false, 0)` for an empty one
    */
  def treeMax(tree: Tree[BigInt]): (Boolean, BigInt) = tree match {
    case Leaf() =>
      (false, 0)
    case Node(left, value, right) =>
      val (hasLeft, leftMax) = treeMax(left)
      val (hasRight, rightMax) = treeMax(right)
      // Running maximum: start from this node's value, then let each
      // non-empty subtree raise it.
      val withLeft = if (hasLeft && leftMax > value) leftMax else value
      val overall = if (hasRight && rightMax > withLeft) rightMax else withLeft
      (true, overall)
  }

  /** Smallest value stored anywhere in `tree`.
    *
    * Both subtrees are inspected, so the result does not rely on `tree` being
    * ordered.
    *
    * @return `(true, min)` for a non-empty tree, `(false, 0)` for an empty one
    */
  def treeMin(tree: Tree[BigInt]): (Boolean, BigInt) = tree match {
    case Leaf() =>
      (false, 0)
    case Node(left, value, right) =>
      val (hasLeft, leftMin) = treeMin(left)
      val (hasRight, rightMin) = treeMin(right)
      // Running minimum: start from this node's value, then let each
      // non-empty subtree lower it.
      val withLeft = if (hasLeft && leftMin < value) leftMin else value
      val overall = if (hasRight && rightMin < withLeft) rightMin else withLeft
      (true, overall)
  }
}