@@ -0,0 +1,27 @@ | |||
# Once edited, rename this file to .env | |||
PORT = 8080 | |||
GIN_MODE = "debug" | |||
DB_NAME = "dero_merchant" | |||
DB_USER = "postgres" | |||
DB_PASSWORD = "password_here" | |||
DB_HOST = "localhost" | |||
DB_PORT = 5432 | |||
REDIS_ADDRESS = "localhost:6379" | |||
DERO_NETWORK = "testnet" | |||
DERO_DAEMON_ADDRESS = "http://explorer.dero.io:30306" # URL Scheme must be included | |||
WALLETS_PATH = "./wallets/" | |||
PAYMENT_MAX_TTL = 60 | |||
PAYMENT_MIN_CONFIRMATIONS = 10 | |||
TEST_DB_NAME = "dero_merchant_test" | |||
TEST_DB_USER = "postgres" | |||
TEST_DB_PASSWORD = "password_here" | |||
TEST_DB_HOST = "localhost" | |||
TEST_DB_PORT = 5432 | |||
TEST_REDIS_ADDRESS = "localhost:6379" | |||
TEST_DERO_NETWORK = "testnet" | |||
TEST_DERO_DAEMON_ADDRESS = "http://explorer.dero.io:30306" # URL Scheme must be included | |||
TEST_WALLETS_PATH = "../test_wallets/" |
@@ -0,0 +1,24 @@ | |||
# Binaries for programs and plugins | |||
*.exe | |||
*.exe~ | |||
*.dll | |||
*.so | |||
*.dylib | |||
# Test binary, build with `go test -c` | |||
*.test | |||
# Output of the go coverage tool, specifically when used with LiteIDE | |||
*.out | |||
/.env | |||
/docker-compose-databases/docker-compose.yml | |||
/docker-compose-databases/postgres/data/* | |||
/run.sh | |||
/test.sh | |||
/dero-merchant | |||
/logs/ | |||
/wallets/ | |||
/test_wallets/ | |||
/coverage.out | |||
/TODO.txt |
@@ -0,0 +1,21 @@ | |||
MIT License | |||
Copyright (c) 2020 Peppinux | |||
Permission is hereby granted, free of charge, to any person obtaining a copy | |||
of this software and associated documentation files (the "Software"), to deal | |||
in the Software without restriction, including without limitation the rights | |||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |||
copies of the Software, and to permit persons to whom the Software is | |||
furnished to do so, subject to the following conditions: | |||
The above copyright notice and this permission notice shall be included in all | |||
copies or substantial portions of the Software. | |||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |||
SOFTWARE. |
@@ -0,0 +1,3 @@ | |||
# DERO Merchant | |||
Source code of DERO Merchant (merchant.dero.io) |
@@ -0,0 +1,349 @@ | |||
package api | |||
import ( | |||
"database/sql" | |||
"fmt" | |||
"math" | |||
"net/http" | |||
"strings" | |||
"time" | |||
"github.com/lib/pq" | |||
"github.com/pkg/errors" | |||
deroglobals "github.com/deroproject/derosuite/globals" | |||
"github.com/peppinux/dero-merchant/coingecko" | |||
"github.com/peppinux/dero-merchant/config" | |||
"github.com/peppinux/dero-merchant/postgres" | |||
"github.com/peppinux/dero-merchant/processor" | |||
"github.com/peppinux/dero-merchant/redis" | |||
"github.com/peppinux/dero-merchant/stringutil" | |||
) | |||
// Payment represents a payment made to a store | |||
type Payment struct { | |||
PaymentID string `json:"paymentID,omitempty"` | |||
Status string `json:"status,omitempty"` | |||
Currency string `json:"currency,omitempty"` | |||
CurrencyAmount float64 `json:"currencyAmount,omitempty"` | |||
ExchangeRate float64 `json:"exchangeRate,omitempty"` | |||
DeroAmount string `json:"deroAmount,omitempty"` | |||
AtomicDeroAmount uint64 `json:"atomicDeroAmount,omitempty"` | |||
IntegratedAddress string `json:"integratedAddress,omitempty"` | |||
CreationTime time.Time `json:"creationTime,omitempty"` | |||
TTL int `json:"ttl"` | |||
StoreID int `json:"-"` | |||
} | |||
// HasValidCurrency returns whether the currency of Payment is supported by CoinGecko API or not | |||
func (p *Payment) HasValidCurrency() bool { | |||
currency := strings.ToLower(p.Currency) | |||
if currency == "dero" { | |||
return true | |||
} | |||
// Check if currency is in cached set of supported currencies in Redis | |||
supported, _ := redis.IsSupportedCurrency(currency) | |||
if supported { | |||
return true | |||
} | |||
// If currency is not in cached set, get supported currencies from CoinGecko API | |||
currencies, err := coingecko.SupportedVsCurrencies() | |||
if err != nil { | |||
return false | |||
} | |||
// Update set in Redis | |||
go redis.SetSupportedCurrencies(currencies) | |||
// Check if currency is supported | |||
for _, c := range currencies { | |||
if currency == c { | |||
return true | |||
} | |||
} | |||
return false | |||
} | |||
// HasValidCurrencyAmount checks if the amount of currency of Payment is a positive number | |||
func (p *Payment) HasValidCurrencyAmount() bool { | |||
return p.CurrencyAmount > 0 | |||
} | |||
func isUniqueIntegratedAddress(iaddr, payid string) (bool, error) { | |||
var existingPaymentID string | |||
err := postgres.DB.QueryRow(` | |||
SELECT payment_id | |||
FROM payments | |||
WHERE payment_id=$1 OR integrated_address=$2`, payid, iaddr). | |||
Scan(&existingPaymentID) | |||
if err != nil { | |||
if err == sql.ErrNoRows { // Integrated Address and PaymentID are unique | |||
return true, nil | |||
} | |||
return false, errors.Wrap(err, "cannot query database") | |||
} | |||
return false, nil | |||
} | |||
// GenerateUniqueIntegratedAddress returns an integrated address and its payment ID, that have never been used for any other payments before | |||
func GenerateUniqueIntegratedAddress(w *processor.StoreWallet) (iaddr, payid string, err error) { | |||
for { | |||
iaddr, payid = w.GenerateIntegratedAddress() | |||
isUnique, err := isUniqueIntegratedAddress(iaddr, payid) | |||
if err != nil { | |||
return "", "", errors.Wrap(err, "cannot check if generated integrated address is unique") | |||
} | |||
if isUnique { | |||
break | |||
} | |||
} | |||
return | |||
} | |||
// CalculateTTL calculates and updates Payment TTL based on the number of minutes passed from the creation of the payment | |||
func (p *Payment) CalculateTTL(minsFromCreation int) { | |||
if p.Status == processor.PaymentStatusPending { | |||
p.TTL = config.PaymentMaxTTL - minsFromCreation | |||
if p.TTL < 0 { | |||
p.TTL = 0 | |||
} | |||
} | |||
} | |||
// CreateNewPayment errors | |||
var ( | |||
ErrInvalidCurrency = errors.New("Invalid Param 'currency': required 3-4 chars long string") | |||
ErrInvalidAmount = errors.New("Invalid Param 'amount': required .12f float") | |||
) | |||
// CreateNewPayment returns a new Payment ready to be stored in DB and be listened to by processor | |||
func CreateNewPayment(currency string, currencyAmount float64, storeID int) (p *Payment, w *processor.StoreWallet, errCode int, err error) { | |||
p = &Payment{ | |||
Status: processor.PaymentStatusPending, | |||
TTL: config.PaymentMaxTTL, | |||
StoreID: storeID, | |||
} | |||
// Validate params | |||
p.Currency = strings.ToUpper(currency) | |||
if !p.HasValidCurrency() { | |||
return nil, nil, http.StatusUnprocessableEntity, ErrInvalidCurrency | |||
} | |||
p.CurrencyAmount = currencyAmount | |||
if !p.HasValidCurrencyAmount() { | |||
return nil, nil, http.StatusUnprocessableEntity, ErrInvalidAmount | |||
} | |||
if p.Currency == "DERO" { | |||
p.ExchangeRate = 1 | |||
p.DeroAmount = fmt.Sprintf("%.12f", p.CurrencyAmount) | |||
} else { | |||
// Get current exchange rate from CoinGecko API | |||
exchangeRate, err := coingecko.DeroPrice(p.Currency) // DERO value in payment currency. 1 DERO = x CURRENCY. Exchange Rate = x CURRENCY | |||
if err != nil { | |||
return nil, nil, http.StatusInternalServerError, errors.Wrap(err, "cannot get DERO price") | |||
} | |||
// Convert amount of currency to DERO | |||
p.ExchangeRate = exchangeRate | |||
deroAmount := p.CurrencyAmount / exchangeRate // 1 DERO : Exchange Rate = Dero Amount : Currency Amount => Dero Amount = 1 * Currency Amount / Exchange Rate | |||
p.DeroAmount = fmt.Sprintf("%.12f", deroAmount) | |||
} | |||
// Convert amount of DERO to atomic DERO | |||
p.AtomicDeroAmount, err = deroglobals.ParseAmount(p.DeroAmount) | |||
if err != nil { | |||
return nil, nil, http.StatusUnprocessableEntity, ErrInvalidAmount | |||
} | |||
w, err = processor.ActiveWallets.GetWalletFromStoreID(p.StoreID) | |||
if err != nil { | |||
return nil, nil, http.StatusInternalServerError, errors.Wrap(err, "cannot get wallet from Store ID") | |||
} | |||
err = w.DeroWallet.IsDaemonOnline() | |||
if err != nil { | |||
return nil, nil, http.StatusInternalServerError, errors.Wrap(err, "daemon offline") | |||
} | |||
p.IntegratedAddress, p.PaymentID, err = GenerateUniqueIntegratedAddress(w) | |||
if err != nil { | |||
return nil, nil, http.StatusInternalServerError, errors.Wrap(err, "cannot generate unique integrated address") | |||
} | |||
return | |||
} | |||
// Insert inserts a Payment into DB | |||
func (p *Payment) Insert() error { | |||
err := postgres.DB.QueryRow(` | |||
INSERT INTO payments (payment_id, status, currency, currency_amount, exchange_rate, dero_amount, atomic_dero_amount, integrated_address, store_id) | |||
VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9) | |||
RETURNING creation_time`, p.PaymentID, p.Status, p.Currency, p.CurrencyAmount, p.ExchangeRate, p.DeroAmount, p.AtomicDeroAmount, p.IntegratedAddress, p.StoreID). | |||
Scan(&p.CreationTime) | |||
if err != nil { | |||
return errors.Wrap(err, "cannot query database") | |||
} | |||
return nil | |||
} | |||
// Payment(s) not found errors | |||
var ( | |||
ErrPaymentNotFound = errors.New("Payment not found") | |||
ErrPaymentsNotFound = errors.New("Payments not found") | |||
ErrNoPaymentsFound = errors.New("No payments found") | |||
ErrNoPaymentsFoundPage = errors.New("No payments found on this page") | |||
) | |||
// FetchPaymentFromID returns a Payment fetched from DB based on its Payment ID | |||
func FetchPaymentFromID(paymentID string, storeID int) (p *Payment, errCode int, err error) { | |||
p = &Payment{ | |||
PaymentID: paymentID, | |||
StoreID: storeID, | |||
} | |||
var minsFromCreation int | |||
err = postgres.DB.QueryRow(` | |||
SELECT status, currency, currency_amount, exchange_rate, dero_amount, atomic_dero_amount, integrated_address, creation_time, CEIL(EXTRACT('epoch' FROM NOW() - creation_time) / 60) | |||
FROM payments | |||
WHERE payment_id=$1 AND store_id=$2`, p.PaymentID, p.StoreID). | |||
Scan(&p.Status, &p.Currency, &p.CurrencyAmount, &p.ExchangeRate, &p.DeroAmount, &p.AtomicDeroAmount, &p.IntegratedAddress, &p.CreationTime, &minsFromCreation) | |||
if err != nil { | |||
if err == sql.ErrNoRows { | |||
return nil, http.StatusNotFound, ErrPaymentNotFound | |||
} | |||
return nil, http.StatusInternalServerError, errors.Wrap(err, "cannot query database") | |||
} | |||
p.CalculateTTL(minsFromCreation) | |||
return | |||
} | |||
// FetchPaymentsFromIDs returns a slice of Payments fetched from DB based on their Payment IDs | |||
func FetchPaymentsFromIDs(paymentIDs []string, storeID int) (ps []*Payment, errCode int, err error) { | |||
rows, err := postgres.DB.Query(` | |||
SELECT payment_id, status, currency, currency_amount, exchange_rate, dero_amount, atomic_dero_amount, integrated_address, creation_time, CEIL(EXTRACT('epoch' FROM NOW() - creation_time) / 60) | |||
FROM payments | |||
WHERE store_id=$1 AND payment_id = ANY($2)`, storeID, pq.Array(paymentIDs)) | |||
if err != nil { | |||
return nil, http.StatusInternalServerError, errors.Wrap(err, "cannot query database") | |||
} | |||
defer rows.Close() | |||
var minsFromCreation int | |||
for rows.Next() { | |||
var p Payment | |||
err := rows.Scan(&p.PaymentID, &p.Status, &p.Currency, &p.CurrencyAmount, &p.ExchangeRate, &p.DeroAmount, &p.AtomicDeroAmount, &p.IntegratedAddress, &p.CreationTime, &minsFromCreation) | |||
if err != nil { | |||
continue | |||
} | |||
p.CalculateTTL(minsFromCreation) | |||
ps = append(ps, &p) | |||
} | |||
if err := rows.Err(); err != nil { | |||
return nil, http.StatusInternalServerError, errors.Wrap(err, "cannot iterate over rows") | |||
} | |||
if len(ps) == 0 { | |||
return nil, http.StatusNotFound, ErrPaymentsNotFound | |||
} | |||
return | |||
} | |||
// FetchFilteredPayments returns a slice of Payments fetched from DB based on given filters | |||
func FetchFilteredPayments(storeID, limit, page int, sortBy, orderBy, statusFilter, currencyFilter string) (ps []*Payment, totalPayments, totalPages, errCode int, err error) { | |||
// Note: Input comes already sanitized from caller function GetPaymentsFromStoreID. | |||
// Fetch total number of filtered payments from DB | |||
err = postgres.DB.QueryRow(` | |||
SELECT COUNT(*) | |||
FROM payments | |||
WHERE store_id=$1 AND ($2='' OR status=LOWER($2)) AND ($3='' OR currency=UPPER($3))`, storeID, statusFilter, currencyFilter). | |||
Scan(&totalPayments) | |||
if err != nil { | |||
errCode = http.StatusInternalServerError | |||
err = errors.Wrap(err, "cannot query database") | |||
return | |||
} | |||
if totalPayments == 0 { | |||
errCode = http.StatusNotFound | |||
err = ErrNoPaymentsFound | |||
return | |||
} | |||
baseQuery := ` | |||
SELECT payment_id, status, currency, currency_amount, exchange_rate, dero_amount, atomic_dero_amount, integrated_address, creation_time, CEIL(EXTRACT('epoch' FROM NOW() - creation_time) / 60) | |||
FROM payments | |||
WHERE store_id=$1 AND ($2='' OR status=LOWER($2)) AND ($3='' OR currency=UPPER($3)) | |||
` | |||
orderByQuery := fmt.Sprintf(`ORDER BY %s %s `, sortBy, orderBy) // SQL Injection safe because params were previously validated. Could not use named parameters. | |||
limitQuery := "" | |||
if limit > 0 { | |||
offset := (page - 1) * limit | |||
limitQuery = fmt.Sprintf(`LIMIT %d OFFSET %d `, limit, offset) | |||
totalPages = int(math.Ceil(float64(totalPayments) / float64(limit))) | |||
} else { | |||
totalPages = 1 | |||
} | |||
if page > totalPages { | |||
errCode = http.StatusNotFound | |||
err = ErrNoPaymentsFoundPage | |||
return | |||
} | |||
// Fetch filtered payments from DB | |||
query := stringutil.Build(baseQuery, orderByQuery, limitQuery) | |||
rows, err := postgres.DB.Query(query, storeID, statusFilter, currencyFilter) | |||
if err != nil { | |||
errCode = http.StatusInternalServerError | |||
err = errors.Wrap(err, "cannot query database") | |||
return | |||
} | |||
defer rows.Close() | |||
var minsFromCreation int | |||
for rows.Next() { | |||
var p Payment | |||
err = rows.Scan(&p.PaymentID, &p.Status, &p.Currency, &p.CurrencyAmount, &p.ExchangeRate, &p.DeroAmount, &p.AtomicDeroAmount, &p.IntegratedAddress, &p.CreationTime, &minsFromCreation) | |||
if err != nil { | |||
errCode = http.StatusInternalServerError | |||
err = errors.Wrap(err, "cannot scan row") | |||
return | |||
} | |||
p.CalculateTTL(minsFromCreation) | |||
ps = append(ps, &p) | |||
} | |||
if err = rows.Err(); err != nil { | |||
errCode = http.StatusInternalServerError | |||
err = errors.Wrap(err, "cannot iterate over rows") | |||
return | |||
} | |||
return | |||
} |
@@ -0,0 +1,369 @@ | |||
package api | |||
import ( | |||
"net/http" | |||
"os" | |||
"testing" | |||
"github.com/alexedwards/argon2id" | |||
"github.com/stretchr/testify/suite" | |||
"github.com/peppinux/dero-merchant/config" | |||
"github.com/peppinux/dero-merchant/postgres" | |||
"github.com/peppinux/dero-merchant/processor" | |||
"github.com/peppinux/dero-merchant/redis" | |||
"github.com/peppinux/dero-merchant/stringutil" | |||
) | |||
type UserMock struct { | |||
ID int | |||
Username string | |||
Email string | |||
Password string | |||
HashedPassword string | |||
VerificationToken string | |||
Verified bool | |||
} | |||
type StoreMock struct { | |||
ID int | |||
Title string | |||
ViewKey string | |||
Webhook string | |||
WebhookSecretKey string | |||
APIKey string | |||
SecretKey string | |||
OwnerID int | |||
} | |||
type APITestSuite struct { | |||
suite.Suite | |||
mockUser *UserMock | |||
mockStore *StoreMock | |||
} | |||
func (suite *APITestSuite) SetupSuite() { | |||
err := config.LoadFromENV("../.env") | |||
if err != nil { | |||
panic(err) | |||
} | |||
redis.Pool = redis.NewPool(config.TestRedisAddress) | |||
err = redis.Ping() | |||
if err != nil { | |||
panic(err) | |||
} | |||
err = redis.FlushAll() | |||
if err != nil { | |||
panic(err) | |||
} | |||
postgres.DB, err = postgres.Connect(config.TestDBName, config.TestDBUser, config.TestDBPassword, config.TestDBHost, config.TestDBPort, "disable") // TODO: Enable SSLMode? | |||
if err != nil { | |||
panic(err) | |||
} | |||
postgres.DropTables() | |||
postgres.CreateTablesIfNotExist() | |||
suite.mockUser = &UserMock{ | |||
Username: "Test user", | |||
Email: "test@user.com", | |||
Password: "foobarbaz", | |||
Verified: true, | |||
} | |||
suite.mockStore = &StoreMock{ | |||
Title: "Test store", | |||
ViewKey: "c53d44b598141c5527ab6a39e82e107d09620fda2af8c9bdc6cb06db2d4ff368cd73811194dbe53cbbe375fd3d9dc1ad1e334f56726d1289a8c096a13b76fd0c", | |||
Webhook: "", | |||
} | |||
u := suite.mockUser | |||
s := suite.mockStore | |||
u.HashedPassword, err = argon2id.CreateHash(u.Password, argon2id.DefaultParams) | |||
if err != nil { | |||
panic(err) | |||
} | |||
u.VerificationToken, err = stringutil.RandomBase64RawURLString(48) | |||
if err != nil { | |||
panic(err) | |||
} | |||
err = postgres.DB.QueryRow(` | |||
INSERT INTO users (username, email, password, verification_token, email_verified) | |||
VALUES ($1, $2, $3, $4, $5) | |||
RETURNING id`, u.Username, u.Email, u.HashedPassword, u.VerificationToken, u.Verified). | |||
Scan(&u.ID) | |||
if err != nil { | |||
panic(err) | |||
} | |||
s.OwnerID = u.ID | |||
s.WebhookSecretKey, err = stringutil.RandomHexString(32) | |||
if err != nil { | |||
panic(err) | |||
} | |||
s.APIKey, err = stringutil.RandomHexString(32) | |||
if err != nil { | |||
panic(err) | |||
} | |||
s.SecretKey, err = stringutil.RandomHexString(32) | |||
if err != nil { | |||
panic(err) | |||
} | |||
err = postgres.DB.QueryRow(` | |||
INSERT INTO stores (title, wallet_view_key, webhook, webhook_secret_key, api_key, secret_key, owner_id) | |||
VALUES ($1, $2, $3, $4, $5, $6, $7) | |||
RETURNING id`, s.Title, s.ViewKey, s.Webhook, s.WebhookSecretKey, s.APIKey, s.SecretKey, s.OwnerID). | |||
Scan(&s.ID) | |||
if err != nil { | |||
panic(err) | |||
} | |||
config.DeroNetwork = config.TestDeroNetwork | |||
config.DeroDaemonAddress = config.TestDeroDaemonAddress | |||
processor.ActiveWallets = processor.NewStoresWallets() | |||
err = processor.SetupDaemonConnection() | |||
if err != nil { | |||
panic(err) | |||
} | |||
config.WalletsPath = config.TestWalletsPath | |||
err = processor.CreateWalletsDirectory() | |||
if err != nil { | |||
panic(err) | |||
} | |||
} | |||
func (suite *APITestSuite) TearDownSuite() { | |||
redis.FlushAll() | |||
redis.Pool.Close() | |||
postgres.DropTables() | |||
postgres.DB.Close() | |||
os.RemoveAll(config.TestWalletsPath) | |||
} | |||
func TestAPITestSuite(t *testing.T) { | |||
suite.Run(t, new(APITestSuite)) | |||
} | |||
func (suite *APITestSuite) TestHasValidCurrency() { | |||
testPayments := map[*Payment]bool{ | |||
&Payment{Currency: "DERO"}: true, | |||
&Payment{Currency: "usd"}: true, | |||
&Payment{Currency: "Eur"}: true, | |||
&Payment{Currency: "ABC"}: false, | |||
&Payment{Currency: "xyz"}: false, | |||
} | |||
for p, shouldValid := range testPayments { | |||
isValid := p.HasValidCurrency() | |||
suite.Equal(shouldValid, isValid) | |||
// Test is so quick the goroutine the caches currencies in Redis does not finish execution. | |||
// This is the reason why line 50 of api.go is not cover. | |||
// To cover it in spite of testing speed uncomment the following line: | |||
// time.Sleep(time.Millisecond * 100) | |||
} | |||
} | |||
func (suite *APITestSuite) TestHasValidCurrencyAmount() { | |||
testPayments := map[*Payment]bool{ | |||
&Payment{CurrencyAmount: 1}: true, | |||
&Payment{CurrencyAmount: 0.1}: true, | |||
&Payment{CurrencyAmount: 1234.567}: true, | |||
&Payment{CurrencyAmount: 0}: false, | |||
&Payment{CurrencyAmount: -0.1}: false, | |||
&Payment{CurrencyAmount: -100}: false, | |||
} | |||
for p, shouldValid := range testPayments { | |||
isValid := p.HasValidCurrencyAmount() | |||
suite.Equal(shouldValid, isValid) | |||
} | |||
} | |||
func (suite *APITestSuite) TestGenerateUniqueIntegratedAddress() { | |||
w, _ := processor.ActiveWallets.GetWalletFromStoreID(suite.mockStore.ID) | |||
count := 5 | |||
generatedAddresses := make([]string, count) | |||
for i := 0; i < count; i++ { | |||
iaddr, _, err := GenerateUniqueIntegratedAddress(w) | |||
suite.Nil(err) | |||
suite.NotContains(generatedAddresses, iaddr) | |||
generatedAddresses = append(generatedAddresses, iaddr) | |||
} | |||
} | |||
func (suite *APITestSuite) TestCalculateTTL() { | |||
oldMaxTTL := config.PaymentMaxTTL | |||
config.PaymentMaxTTL = 60 | |||
test := []struct { | |||
Payment *Payment | |||
MinsFromCreation int | |||
ExpectedTTL int | |||
}{ | |||
{Payment: &Payment{Status: processor.PaymentStatusPending}, MinsFromCreation: 0, ExpectedTTL: 60 - 0}, | |||
{Payment: &Payment{Status: processor.PaymentStatusPending}, MinsFromCreation: 1, ExpectedTTL: 60 - 1}, | |||
{Payment: &Payment{Status: processor.PaymentStatusPending}, MinsFromCreation: 20, ExpectedTTL: 60 - 20}, | |||
{Payment: &Payment{Status: processor.PaymentStatusPending}, MinsFromCreation: 60, ExpectedTTL: 0}, | |||
{Payment: &Payment{Status: processor.PaymentStatusPending}, MinsFromCreation: 100, ExpectedTTL: 0}, | |||
{Payment: &Payment{Status: processor.PaymentStatusPaid}, MinsFromCreation: 1000, ExpectedTTL: 0}, | |||
} | |||
for _, t := range test { | |||
t.Payment.CalculateTTL(t.MinsFromCreation) | |||
suite.Equal(t.ExpectedTTL, t.Payment.TTL) | |||
} | |||
config.PaymentMaxTTL = oldMaxTTL | |||
} | |||
func (suite *APITestSuite) TestPayments() { | |||
testPayments := []struct { | |||
Currency string | |||
CurrencyAmount float64 | |||
StoreID int | |||
Payment *Payment | |||
ExpectedErrCode int | |||
ExpectedErr error | |||
}{ | |||
{Currency: "DERO", CurrencyAmount: 50, StoreID: suite.mockStore.ID, ExpectedErrCode: 0, ExpectedErr: nil}, | |||
{Currency: "EUR", CurrencyAmount: 10, StoreID: suite.mockStore.ID, ExpectedErrCode: 0, ExpectedErr: nil}, | |||
{Currency: "ABC", CurrencyAmount: 10, StoreID: suite.mockStore.ID, ExpectedErrCode: http.StatusUnprocessableEntity, ExpectedErr: ErrInvalidCurrency}, | |||
{Currency: "USD", CurrencyAmount: 0, StoreID: suite.mockStore.ID, ExpectedErrCode: http.StatusUnprocessableEntity, ExpectedErr: ErrInvalidAmount}, | |||
} | |||
var ( | |||
errCode int | |||
err error | |||
validPaymentIDs []string | |||
) | |||
for _, p := range testPayments { | |||
// Test CreateNewPayment | |||
p.Payment, _, errCode, err = CreateNewPayment(p.Currency, p.CurrencyAmount, p.StoreID) | |||
suite.Equal(p.ExpectedErrCode, errCode) | |||
suite.Equal(p.ExpectedErr, err) | |||
if err == nil { | |||
validPaymentIDs = append(validPaymentIDs, p.Payment.PaymentID) | |||
// Test Insert | |||
err = p.Payment.Insert() | |||
suite.Nil(err) | |||
suite.NotZero(p.Payment.CreationTime) | |||
// Test FetchPaymentFromID | |||
_, errCode, err = FetchPaymentFromID(p.Payment.PaymentID, p.Payment.StoreID) | |||
suite.Zero(errCode) | |||
suite.Nil(err) | |||
} | |||
} | |||
// Test FetchPaymentFromID (payment not found) | |||
invalidPaymentID, _ := stringutil.RandomHexString(32) | |||
_, errCode, err = FetchPaymentFromID(invalidPaymentID, suite.mockStore.ID) | |||
suite.Equal(http.StatusNotFound, errCode) | |||
suite.Equal(ErrPaymentNotFound, err) | |||
// Test FetchPaymentsFromIDs | |||
_, errCode, err = FetchPaymentsFromIDs(validPaymentIDs, suite.mockStore.ID) | |||
suite.Zero(errCode) | |||
suite.Nil(err) | |||
// Test FetchPaymentsFromIDs (payments not found) | |||
_, errCode, err = FetchPaymentsFromIDs([]string{invalidPaymentID}, suite.mockStore.ID) | |||
suite.Equal(http.StatusNotFound, errCode) | |||
suite.Equal(ErrPaymentsNotFound, err) | |||
} | |||
func (suite *APITestSuite) TestFetchFilteredPayments() { | |||
storeID := suite.mockStore.ID | |||
mockPayments := []*Payment{} | |||
addMockPayment := func(status string, currency string, amount float64) { | |||
p, _, _, _ := CreateNewPayment(currency, amount, storeID) | |||
p.Status = status | |||
p.Insert() | |||
mockPayments = append(mockPayments, p) | |||
} | |||
// Test fetching payments before adding any | |||
payments, numPayments, numPages, errCode, err := FetchFilteredPayments(storeID, 0, 1, "creation_time", "desc", "", "") | |||
suite.Equal(http.StatusNotFound, errCode) | |||
suite.Equal(ErrNoPaymentsFound, err) | |||
suite.Equal(0, numPayments) | |||
suite.Equal(0, numPages) | |||
suite.Nil(payments) | |||
// 9 mock payments | |||
addMockPayment("pending", "DERO", 10) | |||
addMockPayment("paid", "DERO", 20) | |||
addMockPayment("pending", "DERO", 30) | |||
addMockPayment("paid", "DERO", 40) | |||
addMockPayment("paid", "USD", 50) | |||
addMockPayment("pending", "USD", 60) | |||
addMockPayment("expired", "USD", 70) | |||
addMockPayment("pending", "EUR", 80) | |||
addMockPayment("error", "EUR", 90) | |||
// Test fetching payments one by one | |||
payments, numPayments, numPages, errCode, err = FetchFilteredPayments(storeID, 1, 1, "creation_time", "desc", "", "") | |||
suite.Zero(errCode) | |||
suite.Nil(err) | |||
suite.Equal(9, numPayments) | |||
suite.Equal(9, numPages) // Because limit = 1 | |||
suite.Equal(float64(90), payments[0].CurrencyAmount) // Last added payment | |||
// Test fetching all payments | |||
payments, numPayments, numPages, errCode, err = FetchFilteredPayments(storeID, 0, 1, "creation_time", "desc", "", "") | |||
suite.Zero(errCode) | |||
suite.Nil(err) | |||
suite.Equal(9, numPayments) | |||
suite.Equal(1, numPages) // Because no limit | |||
suite.Equal(float64(90), payments[0].CurrencyAmount) // Last added payment | |||
suite.Equal(float64(10), payments[8].CurrencyAmount) // First added payment | |||
// Test fetching the 3rd page of the 9 payments divided in groups of 3 | |||
payments, numPayments, numPages, errCode, err = FetchFilteredPayments(storeID, 3, 3, "creation_time", "desc", "", "") | |||
suite.Zero(errCode) | |||
suite.Nil(err) | |||
suite.Equal(9, numPayments) | |||
suite.Equal(3, numPages) | |||
suite.Equal(float64(30), payments[0].CurrencyAmount) | |||
suite.Equal(float64(20), payments[1].CurrencyAmount) | |||
suite.Equal(float64(10), payments[2].CurrencyAmount) | |||
// Test fetching the 5th page (out of range by 2) of the 9 payments divided in groups of 3 | |||
payments, numPayments, numPages, errCode, err = FetchFilteredPayments(storeID, 3, 5, "creation_time", "desc", "", "") | |||
suite.Equal(http.StatusNotFound, errCode) | |||
suite.Equal(ErrNoPaymentsFoundPage, err) | |||
suite.Equal(9, numPayments) | |||
suite.Equal(3, numPages) | |||
suite.Nil(payments) | |||
// Test fetching the *only* the *first* *paid* payment in *USD* | |||
payments, numPayments, numPages, errCode, err = FetchFilteredPayments(storeID, 1, 1, "creation_time", "asc", "paid", "USD") | |||
suite.Zero(errCode) | |||
suite.Nil(err) | |||
suite.Equal(1, numPayments) | |||
suite.Equal(1, numPages) | |||
suite.Equal(float64(50), payments[0].CurrencyAmount) | |||
} |
@@ -0,0 +1,222 @@ | |||
package api | |||
import ( | |||
"net/http" | |||
"github.com/gin-gonic/gin" | |||
"github.com/go-playground/validator" | |||
"github.com/peppinux/dero-merchant/httperror" | |||
) | |||
// PingGetHandler handles GET requests to /api/v1/ping | |||
func PingGetHandler(c *gin.Context) { | |||
c.JSON(http.StatusOK, gin.H{ | |||
"ping": "pong", | |||
}) | |||
} | |||
type paymentPostRequest struct { | |||
Currency string `json:"currency" binding:"required,max=4,min=3"` | |||
Amount float64 `json:"amount" binding:"required"` | |||
} | |||
var paymentPostFieldsErrors = map[string]string{ | |||
"Currency": ErrInvalidCurrency.Error(), | |||
"Amount": ErrInvalidAmount.Error(), | |||
} | |||
// PaymentPostHandler handles POST requests to /api/v1/payment | |||
func PaymentPostHandler(c *gin.Context) { | |||
// Get and validate request params | |||
var req paymentPostRequest | |||
err := c.ShouldBindJSON(&req) | |||
if err != nil { | |||
errs, ok := err.(validator.ValidationErrors) | |||
if !ok { | |||
httperror.Send(c, http.StatusBadRequest, "Invalid request params") | |||
return | |||
} | |||
for _, err := range errs { | |||
httperror.Send(c, http.StatusUnprocessableEntity, paymentPostFieldsErrors[err.Field()]) | |||
return | |||
} | |||
} | |||
storeID := c.MustGet("storeID").(int) | |||
// Create Payment | |||
p, w, errCode, err := CreateNewPayment(req.Currency, req.Amount, storeID) | |||
if err != nil { | |||
if errCode == http.StatusInternalServerError { | |||
httperror.Send500(c, err, "Error creating new payment") | |||
return | |||
} | |||
httperror.Send(c, errCode, err.Error()) | |||
return | |||
} | |||
// Insert Payment into DB | |||
err = p.Insert() | |||
if httperror.Send500IfErr(c, err, "Error inserting Payment into DB") != nil { | |||
return | |||
} | |||
// Add Payment to wallet's pending payments | |||
err = w.AddPendingPayment(p.PaymentID, p.AtomicDeroAmount) | |||
if httperror.Send500IfErr(c, err, "Error adding pending payment to wallet") != nil { | |||
return | |||
} | |||
c.JSON(http.StatusCreated, p) | |||
} | |||
// PaymentGetHandler handles GET requests to /api/v1/payment/:payment_id | |||
func PaymentGetHandler(c *gin.Context) { | |||
paymentID := c.Param("payment_id") | |||
storeID := c.MustGet("storeID").(int) | |||
p, errCode, err := FetchPaymentFromID(paymentID, storeID) | |||
if err != nil { | |||
if errCode == http.StatusInternalServerError { | |||
httperror.Send500(c, err, "Error fetching payment from database") | |||
return | |||
} | |||
httperror.Send(c, errCode, err.Error()) | |||
return | |||
} | |||
c.JSON(http.StatusOK, p) | |||
} | |||
type paymentsPostRequest []string | |||
type paymentsPostResponse []*Payment | |||
// PaymentsPostHandler handles POST requests to /api/v1/payments | |||
func PaymentsPostHandler(c *gin.Context) { | |||
storeID := c.MustGet("storeID").(int) | |||
var ( | |||
paymentIDs paymentsPostRequest | |||
payments paymentsPostResponse | |||
) | |||
// Get request params | |||
err := c.ShouldBindJSON(&paymentIDs) | |||
if err != nil { | |||
httperror.Send(c, http.StatusBadRequest, "Invalid request params") | |||
return | |||
} | |||
if len(paymentIDs) == 0 { | |||
httperror.Send(c, http.StatusBadRequest, "No Payment IDs submitted") | |||
return | |||
} | |||
payments, errCode, err := FetchPaymentsFromIDs(paymentIDs, storeID) | |||
if err != nil { | |||
if errCode == http.StatusInternalServerError { | |||
httperror.Send500(c, err, "Error fetching payments from database") | |||
return | |||
} | |||
httperror.Send(c, errCode, err.Error()) | |||
return | |||
} | |||
if len(payments) == 0 { | |||
httperror.Send(c, http.StatusNotFound, "Payments not found") | |||
return | |||
} | |||
c.JSON(http.StatusOK, payments) | |||
} | |||
type paymentsGetRequest struct { | |||
// Pagination | |||
Limit int `form:"limit,default=0" binding:"min=0"` | |||
Page int `form:"page,default=1" binding:"min=1"` | |||
// Sorting | |||
SortBy string `form:"sort_by,default=creation_time" binding:"eq=|eq=currency_amount|eq=exchange_rate|eq=atomic_dero_amount|eq=creation_time"` | |||
OrderBy string `form:"order_by,default=desc" binding:"eq=|eq=asc|eq=desc"` | |||
// Filtering | |||
Status string `form:"status,default=" binding:"eq=|eq=pending|eq=paid|eq=expired|eq=error"` | |||
Currency string `form:"currency,default=" binding:"max=4"` | |||
} | |||
var paymentsGetFieldsErrors = map[string]string{ | |||
"Limit": "Query param 'limit' not valid. Allowed values: (empty) or min 0", | |||
"Page": "Query param 'page' not valid. Allowed values: (empty) or min 1", | |||
"SortBy": "Query param 'sort_by' not valid. Allowed values: (empty), creation_time, currency_amount, exchange_rate, atomic_dero_amount", | |||
"OrderBy": "Query param 'order_by' not valid. Allowed values: (empty), asc, desc", | |||
"Status": "Query param 'status' not valid. Allowed values: (empty), pending, paid, expired, error", | |||
"Currency": "Query param 'currency' not valid. Allowed values: (empty) or max 4 characters", | |||
} | |||
type paymentsGetResponse struct { | |||
// Pagination | |||
Limit int `json:"limit"` | |||
Page int `json:"page,omitempty"` | |||
// Total number of filtered Payment(s) and pages (of "Limit" # of items) | |||
TotalPayments int `json:"totalPayments,omitempty"` | |||
TotalPages int `json:"totalPages,omitempty"` | |||
// Array of "limit" number of Payment(s) | |||
Payments []*Payment `json:"payments"` | |||
} | |||
// GetFilteredPaymentsFromStoreID is called by both PaymentsGetHandler (in this file) and PaymentsGetHandler (in webapp/store/handler.go) | |||
func GetFilteredPaymentsFromStoreID(c *gin.Context, storeID int) { | |||
var ( | |||
req paymentsGetRequest | |||
resp paymentsGetResponse | |||
) | |||
// Get and Validate URL Query params | |||
err := c.ShouldBindQuery(&req) | |||
if err != nil { | |||
errs, ok := err.(validator.ValidationErrors) | |||
if !ok { | |||
httperror.Send(c, http.StatusBadRequest, "Invalid query params") | |||
return | |||
} | |||
for _, err := range errs { | |||
httperror.Send(c, http.StatusUnprocessableEntity, paymentsGetFieldsErrors[err.Field()]) | |||
return | |||
} | |||
} | |||
// Fill empty query params (necessary because default value in struct binding only accounts for unset params, not for params set to empty) | |||
if req.SortBy == "" { | |||
req.SortBy = "creation_time" | |||
} | |||
if req.OrderBy == "" { | |||
req.OrderBy = "desc" | |||
} | |||
resp.Limit = req.Limit | |||
resp.Page = req.Page | |||
var errCode int | |||
resp.Payments, resp.TotalPayments, resp.TotalPages, errCode, err = FetchFilteredPayments(storeID, req.Limit, req.Page, req.SortBy, req.OrderBy, req.Status, req.Currency) | |||
if err != nil { | |||
if errCode == http.StatusInternalServerError { | |||
httperror.Send500(c, err, "Error fetching filtered payments") | |||
return | |||
} | |||
httperror.Send(c, errCode, err.Error()) | |||
return | |||
} | |||
c.JSON(http.StatusOK, resp) | |||
} | |||
// PaymentsGetHandler handles GET requests to /api/v1/payments | |||
func PaymentsGetHandler(c *gin.Context) { | |||
storeID := c.MustGet("storeID").(int) | |||
GetFilteredPaymentsFromStoreID(c, storeID) | |||
} |
@@ -0,0 +1,120 @@ | |||
package auth | |||
import ( | |||
"bytes" | |||
"database/sql" | |||
"encoding/hex" | |||
"io/ioutil" | |||
"net/http" | |||
"github.com/gin-gonic/gin" | |||
"github.com/peppinux/dero-merchant/cryptoutil" | |||
"github.com/peppinux/dero-merchant/httperror" | |||
"github.com/peppinux/dero-merchant/postgres" | |||
"github.com/peppinux/dero-merchant/redis" | |||
) | |||
// APIKeyAuth provides a middleware that rejects unauthenticated requests to the API | |||
func APIKeyAuth() gin.HandlerFunc { | |||
return func(c *gin.Context) { | |||
// Get and validate X-API-Key header | |||
apiKey := c.GetHeader("X-API-Key") | |||
if len(apiKey) != 64 { | |||
httperror.Send(c, http.StatusBadRequest, "Invalid Header X-API-Key: required 64 characters long string") | |||
return | |||
} | |||
// Fetch Store ID associated to (hashed) API Key from Redis | |||
hashedAPIKey := cryptoutil.HashStringToSHA256Hex(apiKey) | |||
storeID, err := redis.GetAPIKeyStore(hashedAPIKey) | |||
if err != nil { | |||
// If Store ID was not found in Redis, try fetching it from DB | |||
err := postgres.DB.QueryRow(` | |||
SELECT id | |||
FROM stores | |||
WHERE api_key=$1 AND removed=$2`, apiKey, false). | |||
Scan(&storeID) | |||
if err != nil { | |||
if err == sql.ErrNoRows { // No store associated to API Key was found | |||
httperror.Send(c, http.StatusForbidden, "Invalid API Key") | |||
} else { | |||
httperror.Send500(c, err, "Error querying database") | |||
} | |||
return | |||
} | |||
// Store value in Redis for quick retrieving in future requests | |||
redis.SetAPIKeyStore(hashedAPIKey, storeID) | |||
} | |||
c.Set("apiKey", apiKey) | |||
c.Set("storeID", storeID) | |||
c.Next() | |||
} | |||
} | |||
// SecretKeyAuth provides a middleware that rejects unauthorized requests to the API. Needs to be used in conjuction with APIKeyAuth | |||
func SecretKeyAuth() gin.HandlerFunc { | |||
return func(c *gin.Context) { | |||
// Get and validate X-Signature header | |||
signature := c.GetHeader("X-Signature") | |||
if len(signature) != 64 { | |||
httperror.Send(c, http.StatusBadRequest, "Invalid Header X-Signature: required 64 characters long SHA256 hex encoded string") | |||
return | |||
} | |||
apiKey := c.MustGet("apiKey").(string) | |||
hashedAPIKey := cryptoutil.HashStringToSHA256Hex(apiKey) | |||
// Fetch Secret Key associated to API Key from Redis | |||
secretKey, err := redis.GetAPIKeySecretKey(hashedAPIKey) | |||
if err != nil { | |||
// If Secret Key was not found in Redis, try fetching it from DB | |||
err = postgres.DB.QueryRow(` | |||
SELECT secret_key | |||
FROM stores | |||
WHERE api_key=$1 AND removed=$2`, apiKey, false). | |||
Scan(&secretKey) | |||
if err != nil { | |||
if err == sql.ErrNoRows { | |||
httperror.Send(c, http.StatusUnauthorized, "Invalid Signature") | |||
} else { | |||
httperror.Send500(c, err, "Error querying database") | |||
} | |||
return | |||
} | |||
// Store value in Redis for quick retrieving in future requests | |||
redis.SetAPIKeySecretKey(hashedAPIKey, secretKey) | |||
} | |||
body, _ := c.GetRawData() // Read request body from stream | |||
c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(body)) // Copy body back into stream for consumption by next route | |||
signatureBytes, err := hex.DecodeString(signature) | |||
if httperror.Send500IfErr(c, err, "Error decoding hex string") != nil { | |||
return | |||
} | |||
secretKeyBytes, err := hex.DecodeString(secretKey) | |||
if httperror.Send500IfErr(c, err, "Error decoding hex string") != nil { | |||
return | |||
} | |||
// Verify Signature | |||
validSignature, err := cryptoutil.ValidMAC(body, signatureBytes, secretKeyBytes) | |||
if httperror.Send500IfErr(c, err, "Error verifying signature") != nil { | |||
return | |||
} | |||
if !validSignature { | |||
httperror.Send(c, http.StatusUnauthorized, "Invalid Signature") | |||
return | |||
} | |||
c.Next() | |||
} | |||
} |
@@ -0,0 +1,232 @@ | |||
package auth | |||
import ( | |||
"encoding/hex" | |||
"net/http" | |||
"net/http/httptest" | |||
"strings" | |||
"testing" | |||
"github.com/alexedwards/argon2id" | |||
"github.com/gin-contrib/gzip" | |||
"github.com/gin-gonic/gin" | |||
"github.com/stretchr/testify/suite" | |||
"github.com/peppinux/dero-merchant/config" | |||
"github.com/peppinux/dero-merchant/cryptoutil" | |||
"github.com/peppinux/dero-merchant/postgres" | |||
"github.com/peppinux/dero-merchant/redis" | |||
"github.com/peppinux/dero-merchant/stringutil" | |||
) | |||
type APIAuthUserMock struct { | |||
ID int | |||
Username string | |||
Email string | |||
Password string | |||
HashedPassword string | |||
VerificationToken string | |||
Verified bool | |||
} | |||
type APIAuthStoreMock struct { | |||
ID int | |||
Title string | |||
ViewKey string | |||
Webhook string | |||
WebhookSecretKey string | |||
APIKey string | |||
SecretKey string | |||
OwnerID int | |||
} | |||
type APIAuthTestSuite struct { | |||
suite.Suite | |||
mockUser *APIAuthUserMock | |||
mockStore *APIAuthStoreMock | |||
doRequest func(r *gin.Engine, body, apiKey, secretKey string) *httptest.ResponseRecorder | |||
} | |||
func (suite *APIAuthTestSuite) SetupSuite() { | |||
err := config.LoadFromENV("../.env") | |||
if err != nil { | |||
panic(err) | |||
} | |||
redis.Pool = redis.NewPool(config.TestRedisAddress) | |||
err = redis.Ping() | |||
if err != nil { | |||
panic(err) | |||
} | |||
err = redis.FlushAll() | |||
if err != nil { | |||
panic(err) | |||
} | |||
postgres.DB, err = postgres.Connect(config.TestDBName, config.TestDBUser, config.TestDBPassword, config.TestDBHost, config.TestDBPort, "disable") // TODO: Enable SSLMode? | |||
if err != nil { | |||
panic(err) | |||
} | |||
postgres.DropTables() | |||
postgres.CreateTablesIfNotExist() | |||
suite.mockUser = &APIAuthUserMock{ | |||
Username: "Test user foo", | |||
Email: "foo@bar.baz", | |||
Password: "foobarbaz", | |||
Verified: true, | |||
} | |||
suite.mockStore = &APIAuthStoreMock{ | |||
Title: "Test store bar", | |||
ViewKey: "foobarbazfoobarbazfoobarbazfoobarbazfoobarbazfoobarbazfoobarbazfoobarbazfoobarbazfoobarbazfoobarbazfoobarbazfoobarbazfoobarbaz12", | |||
Webhook: "", | |||
} | |||
u := suite.mockUser | |||
s := suite.mockStore | |||
u.HashedPassword, err = argon2id.CreateHash(u.Password, argon2id.DefaultParams) | |||
if err != nil { | |||
panic(err) | |||
} | |||
u.VerificationToken, err = stringutil.RandomBase64RawURLString(48) | |||
if err != nil { | |||
panic(err) | |||
} | |||
err = postgres.DB.QueryRow(` | |||
INSERT INTO users (username, email, password, verification_token, email_verified) | |||
VALUES ($1, $2, $3, $4, $5) | |||
RETURNING id`, u.Username, u.Email, u.HashedPassword, u.VerificationToken, u.Verified). | |||
Scan(&u.ID) | |||
if err != nil { | |||
panic(err) | |||
} | |||
s.OwnerID = u.ID | |||
s.WebhookSecretKey, err = stringutil.RandomHexString(32) | |||
if err != nil { | |||
panic(err) | |||
} | |||
s.APIKey, err = stringutil.RandomHexString(32) | |||
if err != nil { | |||
panic(err) | |||
} | |||
s.SecretKey, err = stringutil.RandomHexString(32) | |||
if err != nil { | |||
panic(err) | |||
} | |||
err = postgres.DB.QueryRow(` | |||
INSERT INTO stores (title, wallet_view_key, webhook, webhook_secret_key, api_key, secret_key, owner_id) | |||
VALUES ($1, $2, $3, $4, $5, $6, $7) | |||
RETURNING id`, s.Title, s.ViewKey, s.Webhook, s.WebhookSecretKey, s.APIKey, s.SecretKey, s.OwnerID). | |||
Scan(&s.ID) | |||
if err != nil { | |||
panic(err) | |||
} | |||
// HTTP request reusable function | |||
suite.doRequest = func(r *gin.Engine, body, apiKey, secretKey string) *httptest.ResponseRecorder { | |||
w := httptest.NewRecorder() | |||
req, _ := http.NewRequest("POST", "/", strings.NewReader(body)) | |||
if apiKey != "" { | |||
req.Header.Add("X-API-Key", apiKey) | |||
} | |||
if secretKey != "" { | |||
key, _ := hex.DecodeString(secretKey) | |||
sign, _ := cryptoutil.SignMessage([]byte(body), key) | |||
hexSign := hex.EncodeToString(sign) | |||
req.Header.Add("X-Signature", hexSign) | |||
} | |||
r.ServeHTTP(w, req) | |||
return w | |||
} | |||
} | |||
func (suite *APIAuthTestSuite) TearDownSuite() { | |||
redis.FlushAll() | |||
redis.Pool.Close() | |||
postgres.DropTables() | |||
postgres.DB.Close() | |||
} | |||
func TestAPIAuthTestSuite(t *testing.T) { | |||
suite.Run(t, new(APIAuthTestSuite)) | |||
} | |||
func (suite *APIAuthTestSuite) TestAPIKeyAuth() { | |||
s := suite.mockStore | |||
// Setup router | |||
r := gin.Default() | |||
r.Use(gzip.Gzip(gzip.DefaultCompression)) | |||
r.Use(APIKeyAuth()) | |||
r.POST("/", func(c *gin.Context) { | |||
c.Status(http.StatusOK) | |||
}) | |||
// Valid API Key | |||
w := suite.doRequest(r, "", s.APIKey, "") | |||
suite.Equal(http.StatusOK, w.Code) | |||
// Valid API Key again in order to fetch it from Redis instead of Postgres | |||
w = suite.doRequest(r, "", s.APIKey, "") | |||
suite.Equal(http.StatusOK, w.Code) | |||
// No API Key | |||
w = suite.doRequest(r, "", "", "") | |||
suite.Equal(http.StatusBadRequest, w.Code) | |||
// Inexistent API Key | |||
k, _ := stringutil.RandomHexString(32) | |||
w = suite.doRequest(r, "", k, "") | |||
suite.Equal(http.StatusForbidden, w.Code) | |||
} | |||
func (suite *APIAuthTestSuite) TestSecretKeyAuth() { | |||
s := suite.mockStore | |||
// Setup router | |||
r := gin.Default() | |||
r.Use(gzip.Gzip(gzip.DefaultCompression)) | |||
r.Use(APIKeyAuth()) | |||
r.Use(SecretKeyAuth()) | |||
r.POST("/", func(c *gin.Context) { | |||
c.Status(http.StatusOK) | |||
}) | |||
testBody := `{"foo":"bar"}` | |||
// Valid Secret Key | |||
w := suite.doRequest(r, testBody, s.APIKey, s.SecretKey) | |||
suite.Equal(http.StatusOK, w.Code) | |||
// No Secret Key | |||
w = suite.doRequest(r, testBody, s.APIKey, "") | |||
suite.Equal(http.StatusBadRequest, w.Code) | |||
// Invalid Secret Key | |||
k, _ := stringutil.RandomHexString(32) | |||
w = suite.doRequest(r, testBody, s.APIKey, k) | |||
suite.Equal(http.StatusUnauthorized, w.Code) | |||
} |
@@ -0,0 +1,63 @@ | |||
package auth | |||
import ( | |||
"encoding/base64" | |||
"net/http" | |||
"strings" | |||
"github.com/alexedwards/argon2id" | |||
"github.com/gin-gonic/gin" | |||
"github.com/peppinux/dero-merchant/httperror" | |||
"github.com/peppinux/dero-merchant/postgres" | |||
) | |||
// RequireUserPassword provides a middleware that requires users โ authenticated by their Session ID โ to provide additional authorization by confirming their password | |||
func RequireUserPassword() gin.HandlerFunc { | |||
return func(c *gin.Context) { | |||
h := c.GetHeader("Authorization") | |||
if len(h) == 0 { | |||
httperror.Send(c, http.StatusBadRequest, "Invalid Authorization header") | |||
return | |||
} | |||
splitPassword := strings.Split(h, " ") | |||
if len(splitPassword) != 2 || splitPassword[0] != "Password" { | |||
httperror.Send(c, http.StatusBadRequest, "Invalid Authorization header") | |||
return | |||
} | |||
base64Password := strings.TrimSpace(splitPassword[1]) | |||
passwordBytes, _ := base64.StdEncoding.DecodeString(base64Password) | |||
password := string(passwordBytes) | |||
if l := len(password); l < 8 || l > 64 { | |||
httperror.Send(c, http.StatusUnprocessableEntity, "Password needs to be between 8 and 64 characters long") | |||
return | |||
} | |||
userID := c.MustGet("session").(*Session).UserID | |||
var hashedPassword string | |||
err := postgres.DB.QueryRow(` | |||
SELECT password | |||
FROM users | |||
WHERE id=$1 AND email_verified=$2`, userID, true). | |||
Scan(&hashedPassword) | |||
if httperror.Send500IfErr(c, err, "Error querying databse") != nil { | |||
return | |||
} | |||
match, err := argon2id.ComparePasswordAndHash(password, hashedPassword) | |||
if httperror.Send500IfErr(c, err, "Error comparing passwords") != nil { | |||
return | |||
} | |||
if !match { | |||
httperror.Send(c, http.StatusUnauthorized, "Wrong password") | |||
return | |||
} | |||
c.Next() | |||
} | |||
} |
@@ -0,0 +1,197 @@ | |||
package auth | |||
import ( | |||
"encoding/base64" | |||
"fmt" | |||
"net/http" | |||
"net/http/httptest" | |||
"testing" | |||
"github.com/alexedwards/argon2id" | |||
"github.com/gin-contrib/gzip" | |||
"github.com/gin-gonic/gin" | |||
_ "github.com/lib/pq" | |||
"github.com/peppinux/dero-merchant/config" | |||
"github.com/peppinux/dero-merchant/postgres" | |||
"github.com/peppinux/dero-merchant/redis" | |||
"github.com/peppinux/dero-merchant/stringutil" | |||
"github.com/stretchr/testify/suite" | |||
) | |||
type PasswordAuthUserMock struct { | |||
ID int | |||
Username string | |||
Email string | |||
Password string | |||
Verified bool | |||
HashedPassword string | |||
VerificationToken string | |||
Base64Password string | |||
AuthHeader string | |||
SessionID string | |||
} | |||
type PasswordAuthTestSuite struct { | |||
suite.Suite | |||
mockUsers map[string]*PasswordAuthUserMock | |||
doRequest func(r *gin.Engine, sessionID, authHeader string) *httptest.ResponseRecorder | |||
} | |||
func (suite *PasswordAuthTestSuite) SetupSuite() { | |||
err := config.LoadFromENV("../.env") | |||
if err != nil { | |||
panic(err) | |||
} | |||
redis.Pool = redis.NewPool(config.TestRedisAddress) | |||
err = redis.Ping() | |||
if err != nil { | |||
panic(err) | |||
} | |||
err = redis.FlushAll() | |||
if err != nil { | |||
panic(err) | |||
} | |||
postgres.DB, err = postgres.Connect(config.TestDBName, config.TestDBUser, config.TestDBPassword, config.TestDBHost, config.TestDBPort, "disable") // TODO: Enable SSLMode? | |||
if err != nil { | |||
panic(err) | |||
} | |||
postgres.DropTables() | |||
postgres.CreateTablesIfNotExist() | |||
suite.mockUsers = map[string]*PasswordAuthUserMock{ | |||
"valid": { | |||
Username: "Valid User", | |||
Email: "foo@bar.baz", | |||
Password: "foobarbaz", | |||
Verified: true, | |||
}, | |||
"unverified": { | |||
Username: "Not Verified", | |||
Email: "test@test.com", | |||
Password: "password", | |||
Verified: false, | |||
}, | |||
"inexistent": { | |||
Username: "I Do Not Exist", | |||
Email: "fake@email.it", | |||
Password: "123456789", | |||
}, | |||
} | |||
for _, u := range suite.mockUsers { | |||
u.HashedPassword, _ = argon2id.CreateHash(u.Password, argon2id.DefaultParams) | |||
u.VerificationToken, _ = stringutil.RandomBase64RawURLString(48) | |||
postgres.DB.QueryRow(` | |||
INSERT INTO users (username, email, password, verification_token, email_verified) | |||
VALUES ($1, $2, $3, $4, $5) | |||
RETURNING id`, u.Username, u.Email, u.HashedPassword, u.VerificationToken, u.Verified). | |||
Scan(&u.ID) | |||
} | |||
// Mock valid user's session | |||
suite.mockUsers["valid"].SessionID, _ = mockSessionID(suite.mockUsers["valid"].ID) | |||
// HTTP request reusable function | |||
suite.doRequest = func(r *gin.Engine, sessionID, authHeader string) *httptest.ResponseRecorder { | |||
w := httptest.NewRecorder() | |||
req, _ := http.NewRequest("POST", "/", nil) | |||
if sessionID != "" { | |||
req.AddCookie(&http.Cookie{ | |||
Name: "DM_SessionID", | |||
Value: sessionID, | |||
}) | |||
} | |||
if authHeader != "" { | |||
req.Header.Add("Authorization", authHeader) | |||
r.ServeHTTP(w, req) | |||
} | |||
r.ServeHTTP(w, req) | |||
return w | |||
} | |||
} | |||
func (suite *PasswordAuthTestSuite) TearDownSuite() { | |||
redis.FlushAll() | |||
redis.Pool.Close() | |||
postgres.DropTables() | |||
postgres.DB.Close() | |||
} | |||
func TestPasswordAuthTestSuite(t *testing.T) { | |||
suite.Run(t, new(PasswordAuthTestSuite)) | |||
} | |||
func (suite *PasswordAuthTestSuite) TestRequireUserPassword() { | |||
// Setup router | |||
r := gin.Default() | |||
r.Use(gzip.Gzip(gzip.DefaultCompression)) | |||
r.Use(SessionAuth()) | |||
r.Use(SessionAuthOrForbidden()) | |||
r.Use(RequireUserPassword()) | |||
r.POST("/", func(c *gin.Context) { | |||
c.Status(http.StatusOK) | |||
}) | |||
users := suite.mockUsers | |||
// Valid user with valid password, unverified user and inexistent user | |||
for k, u := range users { | |||
u.Base64Password = base64.StdEncoding.EncodeToString([]byte(u.Password)) | |||
u.AuthHeader = fmt.Sprintf("Password %s", u.Base64Password) | |||
w := suite.doRequest(r, u.SessionID, u.AuthHeader) | |||
switch k { | |||
case "valid": | |||
suite.Equal(http.StatusOK, w.Code) // Request went through | |||
case "unverified", "inexistent": | |||
suite.Equal(http.StatusForbidden, w.Code) // Request rejected by SessionAuthOrForbidden middleware | |||
} | |||
} | |||
u := users["valid"] | |||
// Valid user with valid Authorization token but no Session cookie | |||
w := suite.doRequest(r, "", u.AuthHeader) | |||
suite.Equal(http.StatusForbidden, w.Code) // Request rejected by SessionAuthOrForbidden middelware | |||
// Valid user with invalid Authorization token | |||
w = suite.doRequest(r, u.SessionID, "") | |||
suite.Equal(http.StatusBadRequest, w.Code) // Request rejected by RequireUserPassword middlware because of no Authorization header | |||
invalidHeader := fmt.Sprintf("ThisIsNotPassword %s", u.Base64Password) | |||
w = suite.doRequest(r, u.SessionID, invalidHeader) | |||
suite.Equal(http.StatusBadRequest, w.Code) // Request rejected by RequireUserPassword middlware because of invalid Authorization header | |||
invalidPass := base64.StdEncoding.EncodeToString([]byte("passwor")) // Less than 8 chars | |||
invalidHeader = fmt.Sprintf("Password %s", invalidPass) | |||
w = suite.doRequest(r, u.SessionID, invalidHeader) | |||
suite.Equal(http.StatusUnprocessableEntity, w.Code) // Request rejected by RequireUserPassword middleware because of invalid password format | |||
// Valid user with wrong password | |||
wrongPass := base64.StdEncoding.EncodeToString([]byte("foobarba")) // Misses last char | |||
invalidHeader = fmt.Sprintf("Password %s", wrongPass) | |||
w = suite.doRequest(r, u.SessionID, invalidHeader) | |||
suite.Equal(http.StatusUnauthorized, w.Code) // Request rejected by RequirePasswordMiddleware because of invalid password format | |||
} |
@@ -0,0 +1,100 @@ | |||
package auth | |||
import ( | |||
"github.com/pkg/errors" | |||
"github.com/peppinux/dero-merchant/cryptoutil" | |||
"github.com/peppinux/dero-merchant/redis" | |||
"github.com/peppinux/dero-merchant/stringutil" | |||
) | |||
// Session represents the cookie session of a user | |||
type Session struct { | |||
ID string | |||
SignedIn bool | |||
UserID int | |||
} | |||
// GetSessionFromCookie returns a new user Session loaded from Redis through the sessionid cookie | |||
func GetSessionFromCookie(cookie string) (s *Session) { | |||
s = &Session{} | |||
if len(cookie) != 64 { | |||
return | |||
} | |||
s.ID = cryptoutil.HashStringToSHA256Hex(cookie) | |||
var err error | |||
s.UserID, err = redis.GetSessionUser(s.ID) | |||
if err != nil { | |||
s.SignedIn = false | |||
} else { | |||
s.SignedIn = true | |||
} | |||
return | |||
} | |||
// Username returns the username of the user associated to the session | |||
func (s *Session) Username() (username string, err error) { | |||
username, err = redis.GetUserUsername(s.UserID) | |||
if err != nil { | |||
err = errors.Wrap(err, "cannot get user's username from Redis") | |||
} | |||
return | |||
} | |||
// Email returns the email of the user associated to the session | |||
func (s *Session) Email() (email string, err error) { | |||
email, err = redis.GetUserEmail(s.UserID) | |||
if err != nil { | |||
err = errors.Wrap(err, "cannot get user's email from Redis") | |||
} | |||
return | |||
} | |||
// StoresMap returns a map of the stores (ID: Title) of the user associated to the session | |||
func (s *Session) StoresMap() (storesMap map[int]string, err error) { | |||
stores, err := redis.GetUserStores(s.UserID) | |||
if err != nil { | |||
err = errors.Wrap(err, "cannot get user's stores from Redis") | |||
return | |||
} | |||
storesMap = make(map[int]string, len(stores)) | |||
for _, storeID := range stores { | |||
storesMap[storeID], err = redis.GetStoreTitle(storeID) | |||
if err != nil { | |||
err = errors.Wrap(err, "cannot get store's title") | |||
return | |||
} | |||
} | |||
return | |||
} | |||
func generateSessionID() (string, error) { | |||
return stringutil.RandomBase64RawURLString(48) | |||
} | |||
// GenerateUniqueSessionID generates a unique Session ID | |||
func GenerateUniqueSessionID() (sessionID string, err error) { | |||
for { | |||
// Generate Session ID | |||
sessionID, err = generateSessionID() | |||
if err != nil { | |||
err = errors.Wrap(err, "cannot generate session ID") | |||
return | |||
} | |||
// Get User ID associated to generated Session ID | |||
userID, _ := redis.GetSessionUser(sessionID) | |||
// If NO User ID is found, generated Session ID is unique, therefore return its value | |||
if userID == 0 { | |||
return | |||
} | |||
} | |||
} |
@@ -0,0 +1,165 @@ | |||
package auth | |||
import ( | |||
"testing" | |||
"github.com/stretchr/testify/suite" | |||
"github.com/peppinux/dero-merchant/config" | |||
"github.com/peppinux/dero-merchant/cryptoutil" | |||
"github.com/peppinux/dero-merchant/redis" | |||
"github.com/peppinux/dero-merchant/stringutil" | |||
) | |||
type SessionTestSuite struct { | |||
suite.Suite | |||
mockSessionID func(userID int) (sessionID, hash string) | |||
mockUnsetSessionID func() (sessionID, hash string) | |||
} | |||
func (suite *SessionTestSuite) SetupSuite() { | |||
err := config.LoadFromENV("../.env") | |||
if err != nil { | |||
panic(err) | |||
} | |||
redis.Pool = redis.NewPool(config.RedisAddress) | |||
err = redis.Ping() | |||
if err != nil { | |||
panic(err) | |||
} | |||
err = redis.FlushAll() | |||
if err != nil { | |||
panic(err) | |||
} | |||
} | |||
func (suite *SessionTestSuite) TearDownSuite() { | |||
redis.FlushAll() | |||
redis.Pool.Close() | |||
} | |||
func TestSessionTestSuite(t *testing.T) { | |||
suite.Run(t, new(SessionTestSuite)) | |||
} | |||
func mockSessionID(userID int) (sessionID, hash string) { | |||
sessionID, _ = GenerateUniqueSessionID() | |||
hash = cryptoutil.HashStringToSHA256Hex(sessionID) | |||
redis.SetSessionUser(hash, userID) | |||
return | |||
} | |||
func mockUnsetSessionID() (sessionID, hash string) { | |||
sessionID, _ = GenerateUniqueSessionID() | |||
hash = cryptoutil.HashStringToSHA256Hex(sessionID) | |||
return | |||
} | |||
func (suite *SessionTestSuite) TestGenerateUniqueSessionID() { | |||
count := 5 | |||
generatedIDs := make([]string, count) | |||
for i := 0; i < count; i++ { | |||
id, err := GenerateUniqueSessionID() | |||
suite.Nil(err) | |||
suite.NotContains(generatedIDs, id) | |||
generatedIDs = append(generatedIDs, id) | |||
} | |||
} | |||
func (suite *SessionTestSuite) TestGetSessionFromCookie() { | |||
userID := 123 | |||
sessionID, hashedSessionID := mockSessionID(userID) | |||
s := GetSessionFromCookie(sessionID) | |||
suite.Equal(hashedSessionID, s.ID) | |||
suite.Equal(userID, s.UserID) | |||
suite.True(s.SignedIn) | |||
unsetSessionID, hashedUnsetSessionID := mockUnsetSessionID() | |||
s = GetSessionFromCookie(unsetSessionID) | |||
suite.Equal(hashedUnsetSessionID, s.ID) | |||
suite.Zero(s.UserID) | |||
suite.False(s.SignedIn) | |||
invalidSessionID, _ := stringutil.RandomBase64RawURLString(49) | |||
s = GetSessionFromCookie(invalidSessionID) | |||
suite.Equal(&Session{}, s) | |||
} | |||
func (suite *SessionTestSuite) TestUsername() { | |||
userID := 123 | |||
username := "foobar" | |||
redis.SetUserUsername(userID, username) | |||
sessionID, _ := mockSessionID(userID) | |||
session := GetSessionFromCookie(sessionID) | |||
u, err := session.Username() | |||
suite.Nil(err) | |||
suite.Equal(username, u) | |||
unsetSessionID, _ := GenerateUniqueSessionID() | |||
invalidSession := GetSessionFromCookie(unsetSessionID) | |||
u, err = invalidSession.Username() | |||
suite.NotNil(err) | |||
suite.Zero(u) | |||
} | |||
func (suite *SessionTestSuite) TestEmail() { | |||
userID := 123 | |||
email := "foo@bar.baz" | |||
redis.SetUserEmail(userID, email) | |||
sessionID, _ := mockSessionID(userID) | |||
session := GetSessionFromCookie(sessionID) | |||
e, err := session.Email() | |||
suite.Nil(err) | |||
suite.Equal(email, e) | |||
unsetSessionID, _ := GenerateUniqueSessionID() | |||
invalidSession := GetSessionFromCookie(unsetSessionID) | |||
e, err = invalidSession.Email() | |||
suite.NotNil(err) | |||
suite.Zero(e) | |||
} | |||
func (suite *SessionTestSuite) TestStoresMap() { | |||
userID := 123 | |||
storesMap := map[int]string{ | |||
2: "Test Store Foo", | |||
4: "Bar Test Store", | |||
8: "Baz baz baz 123", | |||
} | |||
sessionID, _ := mockSessionID(userID) | |||
for id, title := range storesMap { | |||
redis.AddUserStore(userID, id) | |||
redis.SetStoreTitle(id, title) | |||
} | |||
session := GetSessionFromCookie(sessionID) | |||
stores, err := session.StoresMap() | |||
suite.Nil(err) | |||
for id, title := range stores { | |||
suite.Equal(storesMap[id], title) | |||
} | |||
unsetSessionID, _ := GenerateUniqueSessionID() | |||
invalidSession := GetSessionFromCookie(unsetSessionID) | |||
stores, _ = invalidSession.StoresMap() | |||
suite.Equal(map[int]string{}, stores) | |||
} |
@@ -0,0 +1,64 @@ | |||
package auth | |||
import ( | |||
"net/http" | |||
"github.com/gin-gonic/gin" | |||
"github.com/peppinux/dero-merchant/httperror" | |||
) | |||
// SessionAuth provides a middleware to authenticate a user given their sessionid cookie | |||
func SessionAuth() gin.HandlerFunc { | |||
return func(c *gin.Context) { | |||
cookie, _ := c.Cookie("DM_SessionID") | |||
s := GetSessionFromCookie(cookie) | |||
c.Set("session", s) | |||
c.Next() | |||
} | |||
} | |||
// SessionAuthOrRedirect provides a middleware that redirects user to the Sign In page if they are not authenticated | |||
func SessionAuthOrRedirect() gin.HandlerFunc { | |||
return func(c *gin.Context) { | |||
s := c.MustGet("session").(*Session) | |||
if !s.SignedIn { | |||
c.Redirect(http.StatusFound, "/user/signin") | |||
c.Abort() | |||
return | |||
} | |||
c.Next() | |||
} | |||
} | |||
// SessionNotAuthOrRedirect provides a middleware that redirects user to Dashboard if they are already authenticated | |||
func SessionNotAuthOrRedirect() gin.HandlerFunc { | |||
return func(c *gin.Context) { | |||
s := c.MustGet("session").(*Session) | |||
if s.SignedIn { | |||
c.Redirect(http.StatusFound, "/dashboard") | |||
c.Abort() | |||
return | |||
} | |||
c.Next() | |||
} | |||
} | |||
// SessionAuthOrForbidden provides a middleware that sends an error message to the user if they are not authenticated | |||
func SessionAuthOrForbidden() gin.HandlerFunc { | |||
return func(c *gin.Context) { | |||
s := c.MustGet("session").(*Session) | |||
if !s.SignedIn { | |||
httperror.Send(c, http.StatusForbidden, "Invalid session ID") | |||
return | |||
} | |||
c.Next() | |||
} | |||
} |
@@ -0,0 +1,179 @@ | |||
package auth | |||
import ( | |||
"encoding/json" | |||
"net/http" | |||
"net/http/httptest" | |||
"testing" | |||
"github.com/gin-contrib/gzip" | |||
"github.com/gin-gonic/gin" | |||
"github.com/stretchr/testify/suite" | |||
"github.com/peppinux/dero-merchant/config" | |||
"github.com/peppinux/dero-merchant/redis" | |||
) | |||
type SessionAuthTestSuite struct { | |||
suite.Suite | |||
doRequest func(r *gin.Engine, sessionID string) *httptest.ResponseRecorder | |||
} | |||
func (suite *SessionAuthTestSuite) SetupSuite() { | |||
err := config.LoadFromENV("../.env") | |||
if err != nil { | |||
panic(err) | |||
} | |||
redis.Pool = redis.NewPool(config.TestRedisAddress) | |||
err = redis.Ping() | |||
if err != nil { | |||
panic(err) | |||
} | |||
err = redis.FlushAll() | |||
if err != nil { | |||
panic(err) | |||
} | |||
// HTTP request reusable function | |||
suite.doRequest = func(r *gin.Engine, sessionID string) *httptest.ResponseRecorder { | |||
w := httptest.NewRecorder() | |||
req, _ := http.NewRequest("POST", "/", nil) | |||
req.AddCookie(&http.Cookie{ | |||
Name: "DM_SessionID", | |||
Value: sessionID, | |||
}) | |||
r.ServeHTTP(w, req) | |||
return w | |||
} | |||
} | |||
func (suite *SessionAuthTestSuite) TearDownSuite() { | |||
redis.FlushAll() | |||
redis.Pool.Close() | |||
} | |||
func TestSessionAuthTestSuite(t *testing.T) { | |||
suite.Run(t, new(SessionAuthTestSuite)) | |||
} | |||
func (suite *SessionAuthTestSuite) TestSessionAuth() { | |||
// Setup router | |||
r := gin.Default() | |||
r.Use(gzip.Gzip(gzip.DefaultCompression)) | |||
r.Use(SessionAuth()) | |||
r.POST("/", func(c *gin.Context) { | |||
s := c.MustGet("session").(*Session) | |||
c.JSON(http.StatusOK, s) | |||
}) | |||
// Mock valid session | |||
userID := 123 | |||
sessionID, hashedSessionID := mockSessionID(userID) | |||
w := suite.doRequest(r, sessionID) | |||
var resp *Session | |||
json.Unmarshal(w.Body.Bytes(), &resp) | |||
suite.Equal(hashedSessionID, resp.ID) | |||
suite.True(resp.SignedIn) | |||
suite.Equal(userID, resp.UserID) | |||
// Mock unset session | |||
sessionID, hashedSessionID = mockUnsetSessionID() | |||
w = suite.doRequest(r, sessionID) | |||
json.Unmarshal(w.Body.Bytes(), &resp) | |||
suite.Equal(hashedSessionID, resp.ID) | |||
suite.False(resp.SignedIn) | |||
suite.Zero(resp.UserID) | |||
} | |||
func (suite *SessionAuthTestSuite) TestSessionAuthOrRedirect() { | |||
// Setup router | |||
r := gin.Default() | |||
r.Use(gzip.Gzip(gzip.DefaultCompression)) | |||
r.Use(SessionAuth()) | |||
r.Use(SessionAuthOrRedirect()) | |||
r.POST("/", func(c *gin.Context) { | |||
c.Status(http.StatusOK) | |||
}) | |||
// Mock valid session | |||
userID := 123 | |||
sessionID, _ := mockSessionID(userID) | |||
w := suite.doRequest(r, sessionID) | |||
suite.Equal(http.StatusOK, w.Code) | |||
// Mock unset session | |||