We illustrate some key features of proving by learning with generative models based on Dependent Type Theory in the ProvingGround project. Some details of the project are available on its website.
We have a set $M$ with a binary operation $*$, and given elements $e_l$ and $e_r$. We also have terms corresponding to the assumptions that $e_l$ is a left identity, that $e_r$ is a right identity, and that equality is symmetric and transitive.
From these, we conclude the basic algebraic result that $e_l = e_r$. Note that the proof in its fully expanded form is fairly long, involving the correct instantiation of each of the identity axioms, a use of symmetry of equality and a use of transitivity of equality. In particular, it has very low weight in a direct generative model.
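Informally, and writing the left identity axiom as $e_l * a = a$ and the right identity axiom as $a * e_r = a$ (this notation is an assumption for illustration, not taken from the library), the argument is a single chain of equalities:

$$e_l = e_l * e_r = e_r$$

The first equality is the right identity axiom instantiated at $a = e_l$ and reversed by symmetry; the second is the left identity axiom instantiated at $a = e_r$; transitivity joins the two. Each of these instantiations and combinations has to be found by the generative model.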
However, our system finds the proof quite readily (much less than a minute even running synchronously in the notebook), due to a few key principles.
We import the locally published core of ProvingGround.
import $ivy.`io.github.siddhartha-gadgil::provingground-core-jvm:0.1.1-SNAPSHOT`
We import the relevant packages, including the one with our example.
import provingground._ , interface._, HoTT._
import learning._
import library._, MonoidSimple._
repl.pprinter.bind(translation.FansiShow.fansiPrint)
We set up a generator. For simplicity in this demo, we do not generate $\lambda$'s or $\Pi$-types.
val tg = TermGenParams(lmW=0, piW=0)
We set up the initial state for learning.
val ts = TermState(dist1, dist1.map(_.typ))
We use the monix library for concurrency, though in this notebook we run code synchronously.
import monix.execution.Scheduler.Implicits.global
We synchronously evolve to the next state with our generative model.
val evT = tg.evolvedStateTask(ts, 0.00003)
val ev = evT.runSyncUnsafe()
We identify lemmas as types that we know to be inhabited but are not inhabited by terms in the generating set. We have kept things simple, so there are fewer lemmas to start with. But the main step is learning the weights of the lemmas.
The evolution of weights is generated as an iterator. As this is a structure that is generated just once, we store it as a stream, which is a lazy, infinite list.
import EntropyAtomWeight._
val its = evolvedLemmaIters(ev, 1, 1)
val streams = its.map{case (tp, it) => (tp, it.toStream)}
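The iterator-to-stream pattern used above can be sketched with the Scala standard library alone. This is only an illustration: the names `weights` and `snapshots` and the update rule are hypothetical stand-ins, not part of ProvingGround.

```scala
// Hypothetical stand-in for an evolving weight: each step halves the gap to 1.0.
val weights: Iterator[Double] = Iterator.iterate(0.0)(w => w + (1.0 - w) / 2)

// An iterator can be traversed only once; a Stream memoizes the values it computes.
// (In Scala 2.13 and later, Stream is deprecated in favour of LazyList.)
val snapshots: Stream[Double] = weights.toStream

// Indexing forces only as many elements as needed; repeated lookups reuse them.
val early = snapshots(1)
val later = snapshots(10)
```

Because the stream caches what it has computed, taking several snapshots at different times does not rerun the evolution from the start.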
We see three snapshots of the evolution of weights, at times $0$, $1000$ and $10000$. We also plot the evolution below.
streams.map{case (tp, s) => (tp, s(0), s(1000), s(10000))}
import $ivy.`org.plotly-scala::plotly-almond:0.5.2`
import plotly._
import plotly.element._
import plotly.layout._
import plotly.Almond._
val scatters = streams.map{case (typ, s) => Scatter((1 to 50).map(i => (i * 50)), (1 to 50).map(i => s(i * 50)), name=typ.toString )}
val tsdata = Seq(scatters: _*)
tsdata.plot(title="Evolution of the lemma coefficients")
We pick the lemma that will give us the result. Note that there are four lemmas that have high weight after the evolution, two of which will work, so one could evolve along all of them (which we do in the interactive code).
val pf = "lemma" :: toTyp(streams(8)._1)
pf.typ
For quick further exploration, we use the derivatives of the recursive generating rules in the direction of the new lemma.
val tt = tg.nextTangStateTask(ev.result, TermState(FiniteDistribution.unif(pf), FiniteDistribution.empty), math.pow(10.0, -4))
val tts = tt.runSyncUnsafe()
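To illustrate the idea under a simplifying assumption (a single generating rule that forms applications $f(a)$ with weight proportional to $p(f)\,p(a)$; the project's actual rules are more varied), perturbing the distribution $p$ in a direction $q$ gives, by the product rule,

$$\frac{d}{dt}\Big|_{t=0} (p + tq)(f)\,(p + tq)(a) = q(f)\,p(a) + p(f)\,q(a).$$

If $q$ is concentrated on the new lemma, only applications that use the lemma as the function or as the argument contribute, so the search is focused on consequences of the lemma.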
The next step indeed gives a proof of the desired result $e_l = e_r$. Note that this proof has a high weight of 1.2%.
tts.terms.filter(_.typ == eqM(l)(r))
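// Keep the lemmas whose weight grew between steps 1000 and 10000, with their weight at step 10000.
val ss = streams.filter{case (tp, s) => s(10000) > s(1000)}.map{case (tp, s) => (tp, s(10000))}
// Turn each such lemma into a named term, scaling its weight by 10.
val ls = FiniteDistribution(ss.map{case (tp, p) => Weighted("lemma" :: tp, p * 10)})
// Add the lemmas to the initial distribution and normalize.
val ts1 = TermState((dist1 ++ ls).normalized(), dist1.map(_.typ))
// Evolve again, with a second argument ten times smaller than before.
val evT1 = tg.evolvedStateTask(ts1, 0.000003)
val ev1 = evT1.runSyncUnsafe()
// Look for terms whose type is the desired equality.
ev1.result.terms.filter(_.typ == eqM(l)(r))
ev1.result.terms.map(_.typ).entropyVec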
import Unify._
unifApply(trans, "lemma" :: ss(0)._1, Vector())
ss(0)._1
trans.typ
unifApply(sym, "lemma" :: ss(0)._1, Vector())
import Fold._
domain(trans)
val arg = "lemma" :: ss(0)._1
unify(M, arg.typ, _ => false)
arg.typ
appln(sym, arg)
appln(trans, arg)
val fn = appln(trans, arg).get
fn.typ