You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

169 lines
5.0KB

  1. package main
  2. //import "fmt"
  3. import "math"
  4. import "math/big"
  5. //import "crypto/rand"
  6. //import "encoding/hex"
  7. import "github.com/clearmatics/bn256"
  8. //import "golang.org/x/crypto/sha3"
  9. // basically the Σ-protocol
  10. type InnerProduct struct {
  11. a, b *big.Int
  12. ls, rs []*bn256.G1
  13. }
  14. func (ip *InnerProduct) Size() int {
  15. return FIELDELEMENT_SIZE + FIELDELEMENT_SIZE + len(ip.ls)*POINT_SIZE + len(ip.rs)*POINT_SIZE
  16. }
  17. func NewInnerProductProof(ips *IPStatement, witness *IPWitness, salt *big.Int) *InnerProduct {
  18. var ip InnerProduct
  19. ip.generateInnerProductProof(ips.PrimeBase, ips.P, witness.L, witness.R, salt)
  20. return &ip
  21. }
  22. func (ip *InnerProduct) generateInnerProductProof(base *GeneratorParams, P *bn256.G1, as, bs *FieldVector, prev_challenge *big.Int) {
  23. n := as.Length()
  24. if n == 1 { // the proof is done, ls,rs are already in place
  25. ip.a = as.vector[0]
  26. ip.b = bs.vector[0]
  27. return
  28. }
  29. nPrime := n / 2
  30. asLeft := as.Slice(0, nPrime)
  31. asRight := as.Slice(nPrime, n)
  32. bsLeft := bs.Slice(0, nPrime)
  33. bsRight := bs.Slice(nPrime, n)
  34. gLeft := base.Gs.Slice(0, nPrime)
  35. gRight := base.Gs.Slice(nPrime, n)
  36. hLeft := base.Hs.Slice(0, nPrime)
  37. hRight := base.Hs.Slice(nPrime, n)
  38. cL := asLeft.InnerProduct(bsRight)
  39. cR := asRight.InnerProduct(bsLeft)
  40. u := base.H
  41. L := new(bn256.G1).Add(gRight.Commit(asLeft.vector), hLeft.Commit(bsRight.vector))
  42. L = new(bn256.G1).Add(L, new(bn256.G1).ScalarMult(u, cL))
  43. R := new(bn256.G1).Add(gLeft.Commit(asRight.vector), hRight.Commit(bsLeft.vector))
  44. R = new(bn256.G1).Add(R, new(bn256.G1).ScalarMult(u, cR))
  45. ip.ls = append(ip.ls, L)
  46. ip.rs = append(ip.rs, R)
  47. var input []byte
  48. input = append(input, convertbiginttobyte(prev_challenge)...)
  49. input = append(input, L.Marshal()...)
  50. input = append(input, R.Marshal()...)
  51. x := reducedhash(input)
  52. xinv := new(big.Int).ModInverse(x, bn256.Order)
  53. gPrime := gLeft.Times(xinv).Add(gRight.Times(x))
  54. hPrime := hLeft.Times(x).Add(hRight.Times(xinv))
  55. aPrime := asLeft.Times(x).Add(asRight.Times(xinv))
  56. bPrime := bsLeft.Times(xinv).Add(bsRight.Times(x))
  57. basePrime := NewGeneratorParams3(u, gPrime, hPrime)
  58. PPrimeL := new(bn256.G1).ScalarMult(L, new(big.Int).Mod(new(big.Int).Mul(x, x), bn256.Order)) //L * (x*x)
  59. PPrimeR := new(bn256.G1).ScalarMult(R, new(big.Int).Mod(new(big.Int).Mul(xinv, xinv), bn256.Order)) //R * (xinv*xinv)
  60. PPrime := new(bn256.G1).Add(PPrimeL, PPrimeR)
  61. PPrime = new(bn256.G1).Add(PPrime, P)
  62. ip.generateInnerProductProof(basePrime, PPrime, aPrime, bPrime, x)
  63. return
  64. }
  65. func (ip *InnerProduct) Verify(hs []*bn256.G1, u, P *bn256.G1, salt *big.Int, gp *GeneratorParams) bool {
  66. log_n := uint(len(ip.ls))
  67. if len(ip.ls) != len(ip.rs) { // length must be same
  68. return false
  69. }
  70. n := uint(math.Pow(2, float64(log_n)))
  71. o := salt
  72. var challenges []*big.Int
  73. for i := uint(0); i < log_n; i++ {
  74. var input []byte
  75. input = append(input, convertbiginttobyte(o)...)
  76. input = append(input, ip.ls[i].Marshal()...)
  77. input = append(input, ip.rs[i].Marshal()...)
  78. o = reducedhash(input)
  79. challenges = append(challenges, o)
  80. o_inv := new(big.Int).ModInverse(o, bn256.Order)
  81. PPrimeL := new(bn256.G1).ScalarMult(ip.ls[i], new(big.Int).Mod(new(big.Int).Mul(o, o), bn256.Order)) //L * (x*x)
  82. PPrimeR := new(bn256.G1).ScalarMult(ip.rs[i], new(big.Int).Mod(new(big.Int).Mul(o_inv, o_inv), bn256.Order)) //L * (x*x)
  83. PPrime := new(bn256.G1).Add(PPrimeL, PPrimeR)
  84. P = new(bn256.G1).Add(PPrime, P)
  85. }
  86. exp := new(big.Int).SetUint64(1)
  87. for i := uint(0); i < log_n; i++ {
  88. exp = new(big.Int).Mod(new(big.Int).Mul(exp, challenges[i]), bn256.Order)
  89. }
  90. exp_inv := new(big.Int).ModInverse(exp, bn256.Order)
  91. exponents := make([]*big.Int, n, n)
  92. exponents[0] = exp_inv // initializefirst element
  93. bits := make([]bool, n, n)
  94. for i := uint(0); i < n/2; i++ {
  95. for j := uint(0); (1<<j)+i < n; j++ {
  96. i1 := (1 << j) + i
  97. if !bits[i1] {
  98. temp := new(big.Int).Mod(new(big.Int).Mul(challenges[log_n-1-j], challenges[log_n-1-j]), bn256.Order)
  99. exponents[i1] = new(big.Int).Mod(new(big.Int).Mul(exponents[i], temp), bn256.Order)
  100. bits[i1] = true
  101. }
  102. }
  103. }
  104. var zeroes [64]byte
  105. gtemp := new(bn256.G1) // obtain zero element, this should be static and
  106. htemp := new(bn256.G1) // obtain zero element, this should be static and
  107. gtemp.Unmarshal(zeroes[:])
  108. htemp.Unmarshal(zeroes[:])
  109. for i := uint(0); i < n; i++ {
  110. gtemp = new(bn256.G1).Add(gtemp, new(bn256.G1).ScalarMult(gp.Gs.vector[i], exponents[i]))
  111. htemp = new(bn256.G1).Add(htemp, new(bn256.G1).ScalarMult(hs[i], exponents[n-1-i]))
  112. }
  113. gtemp = new(bn256.G1).ScalarMult(gtemp, ip.a)
  114. htemp = new(bn256.G1).ScalarMult(htemp, ip.b)
  115. utemp := new(bn256.G1).ScalarMult(u, new(big.Int).Mod(new(big.Int).Mul(ip.a, ip.b), bn256.Order))
  116. P_calculated := new(bn256.G1).Add(gtemp, htemp)
  117. P_calculated = new(bn256.G1).Add(P_calculated, utemp)
  118. // fmt.Printf("P %s\n",P.String())
  119. // fmt.Printf("P_calculated %s\n",P_calculated.String())
  120. if P_calculated.String() != P.String() { // need something better here
  121. panic("Faulty or invalid proof")
  122. return false
  123. }
  124. return true
  125. }