Files
History_Api/internal/repositories/userRepository.go
AzenKain f04441bf2a
Some checks failed
Build and Release / release (push) Failing after 1m25s
UPDATE: Auth module, User module
2026-03-30 00:27:57 +07:00

482 lines
13 KiB
Go

package repositories
import (
"context"
"crypto/md5"
"encoding/json"
"fmt"
"github.com/jackc/pgx/v5/pgtype"
"history-api/internal/gen/sqlc"
"history-api/internal/models"
"history-api/pkg/cache"
"history-api/pkg/constants"
"history-api/pkg/convert"
)
type UserRepository interface {
GetByID(ctx context.Context, id pgtype.UUID) (*models.UserEntity, error)
GetByIDWithoutDeleted(ctx context.Context, id pgtype.UUID) (*models.UserEntity, error)
GetByEmail(ctx context.Context, email string) (*models.UserEntity, error)
All(ctx context.Context, params sqlc.GetUsersParams) ([]*models.UserEntity, error)
Search(ctx context.Context, params sqlc.SearchUsersParams) ([]*models.UserEntity, error)
UpsertUser(ctx context.Context, params sqlc.UpsertUserParams) (*models.UserEntity, error)
CreateProfile(ctx context.Context, params sqlc.CreateUserProfileParams) (*models.UserProfileSimple, error)
UpdateProfile(ctx context.Context, params sqlc.UpdateUserProfileParams) (*models.UserEntity, error)
UpdatePassword(ctx context.Context, params sqlc.UpdateUserPasswordParams) error
UpdateRefreshToken(ctx context.Context, params sqlc.UpdateUserRefreshTokenParams) error
GetTokenVersion(ctx context.Context, id pgtype.UUID) (int32, error)
UpdateTokenVersion(ctx context.Context, params sqlc.UpdateTokenVersionParams) error
Delete(ctx context.Context, id pgtype.UUID) error
Restore(ctx context.Context, id pgtype.UUID) error
}
type userRepository struct {
q *sqlc.Queries
c cache.Cache
}
func NewUserRepository(db sqlc.DBTX, c cache.Cache) UserRepository {
return &userRepository{
q: sqlc.New(db),
c: c,
}
}
func (r *userRepository) generateQueryKey(prefix string, params any) string {
b, _ := json.Marshal(params)
hash := fmt.Sprintf("%x", md5.Sum(b))
return fmt.Sprintf("%s:%s", prefix, hash)
}
func (r *userRepository) getByIDsWithFallback(ctx context.Context, ids []string) ([]*models.UserEntity, error) {
if len(ids) == 0 {
return []*models.UserEntity{}, nil
}
keys := make([]string, len(ids))
for i, id := range ids {
keys[i] = fmt.Sprintf("user:id:%s", id)
}
raws := r.c.MGet(ctx, keys...)
var users []*models.UserEntity
missingUsersToCache := make(map[string]any)
for i, b := range raws {
if len(b) > 0 {
var u models.UserEntity
if err := json.Unmarshal(b, &u); err == nil {
users = append(users, &u)
}
} else {
pgId := pgtype.UUID{}
err := pgId.Scan(ids[i])
if err != nil {
continue
}
dbUser, err := r.GetByID(ctx, pgId)
if err == nil && dbUser != nil {
users = append(users, dbUser)
missingUsersToCache[keys[i]] = dbUser
}
}
}
if len(missingUsersToCache) > 0 {
_ = r.c.MSet(ctx, missingUsersToCache, constants.NormalCacheDuration)
}
return users, nil
}
func (r *userRepository) GetByID(ctx context.Context, id pgtype.UUID) (*models.UserEntity, error) {
cacheId := fmt.Sprintf("user:id:%s", convert.UUIDToString(id))
var user models.UserEntity
err := r.c.Get(ctx, cacheId, &user)
if err == nil {
return &user, nil
}
row, err := r.q.GetUserByID(ctx, id)
if err != nil {
return nil, err
}
user = models.UserEntity{
ID: convert.UUIDToString(row.ID),
Email: row.Email,
PasswordHash: convert.TextToString(row.PasswordHash),
TokenVersion: row.TokenVersion,
IsDeleted: row.IsDeleted,
CreatedAt: convert.TimeToPtr(row.CreatedAt),
UpdatedAt: convert.TimeToPtr(row.UpdatedAt),
}
if err := user.ParseRoles(row.Roles); err != nil {
return nil, err
}
if err := user.ParseProfile(row.Profile); err != nil {
return nil, err
}
_ = r.c.Set(ctx, cacheId, user, constants.NormalCacheDuration)
return &user, nil
}
func (r *userRepository) GetByIDWithoutDeleted(ctx context.Context, id pgtype.UUID) (*models.UserEntity, error) {
cacheId := fmt.Sprintf("user:deleted:id:%s", convert.UUIDToString(id))
var user models.UserEntity
err := r.c.Get(ctx, cacheId, &user)
if err == nil {
return &user, nil
}
row, err := r.q.GetUserByID(ctx, id)
if err != nil {
return nil, err
}
user = models.UserEntity{
ID: convert.UUIDToString(row.ID),
Email: row.Email,
PasswordHash: convert.TextToString(row.PasswordHash),
TokenVersion: row.TokenVersion,
IsDeleted: row.IsDeleted,
CreatedAt: convert.TimeToPtr(row.CreatedAt),
UpdatedAt: convert.TimeToPtr(row.UpdatedAt),
}
if err := user.ParseRoles(row.Roles); err != nil {
return nil, err
}
if err := user.ParseProfile(row.Profile); err != nil {
return nil, err
}
_ = r.c.Set(ctx, cacheId, user, constants.NormalCacheDuration)
return &user, nil
}
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*models.UserEntity, error) {
cacheId := fmt.Sprintf("user:email:%s", email)
var user models.UserEntity
err := r.c.Get(ctx, cacheId, &user)
if err == nil {
return &user, nil
}
row, err := r.q.GetUserByEmail(ctx, email)
if err != nil {
return nil, err
}
user = models.UserEntity{
ID: convert.UUIDToString(row.ID),
Email: row.Email,
PasswordHash: convert.TextToString(row.PasswordHash),
TokenVersion: row.TokenVersion,
IsDeleted: row.IsDeleted,
CreatedAt: convert.TimeToPtr(row.CreatedAt),
UpdatedAt: convert.TimeToPtr(row.UpdatedAt),
}
if err := user.ParseRoles(row.Roles); err != nil {
return nil, err
}
if err := user.ParseProfile(row.Profile); err != nil {
return nil, err
}
_ = r.c.Set(ctx, cacheId, user, constants.NormalCacheDuration)
return &user, nil
}
func (r *userRepository) UpsertUser(ctx context.Context, params sqlc.UpsertUserParams) (*models.UserEntity, error) {
row, err := r.q.UpsertUser(ctx, params)
if err != nil {
return nil, err
}
go func() {
bgCtx := context.Background()
_ = r.c.DelByPattern(bgCtx, "user:all*")
_ = r.c.DelByPattern(bgCtx, "user:search*")
}()
return &models.UserEntity{
ID: convert.UUIDToString(row.ID),
Email: row.Email,
PasswordHash: convert.TextToString(row.PasswordHash),
TokenVersion: row.TokenVersion,
IsDeleted: row.IsDeleted,
CreatedAt: convert.TimeToPtr(row.CreatedAt),
UpdatedAt: convert.TimeToPtr(row.UpdatedAt),
Roles: make([]*models.RoleSimple, 0),
}, nil
}
func (r *userRepository) UpdateProfile(ctx context.Context, params sqlc.UpdateUserProfileParams) (*models.UserEntity, error) {
user, err := r.GetByID(ctx, params.UserID)
if err != nil {
return nil, err
}
row, err := r.q.UpdateUserProfile(ctx, params)
if err != nil {
return nil, err
}
profile := models.UserProfileSimple{
DisplayName: convert.TextToString(row.DisplayName),
FullName: convert.TextToString(row.FullName),
AvatarUrl: convert.TextToString(row.AvatarUrl),
Bio: convert.TextToString(row.Bio),
Location: convert.TextToString(row.Location),
Website: convert.TextToString(row.Website),
CountryCode: convert.TextToString(row.CountryCode),
Phone: convert.TextToString(row.Phone),
}
user.Profile = &profile
mapCache := map[string]any{
fmt.Sprintf("user:email:%s", user.Email): user,
fmt.Sprintf("user:id:%s", user.ID): user,
}
_ = r.c.MSet(ctx, mapCache, constants.NormalCacheDuration)
return user, nil
}
func (r *userRepository) CreateProfile(ctx context.Context, params sqlc.CreateUserProfileParams) (*models.UserProfileSimple, error) {
row, err := r.q.CreateUserProfile(ctx, params)
if err != nil {
return nil, err
}
return &models.UserProfileSimple{
DisplayName: convert.TextToString(row.DisplayName),
FullName: convert.TextToString(row.FullName),
AvatarUrl: convert.TextToString(row.AvatarUrl),
Bio: convert.TextToString(row.Bio),
Location: convert.TextToString(row.Location),
Website: convert.TextToString(row.Website),
CountryCode: convert.TextToString(row.CountryCode),
Phone: convert.TextToString(row.Phone),
}, nil
}
func (r *userRepository) All(ctx context.Context, params sqlc.GetUsersParams) ([]*models.UserEntity, error) {
queryKey := r.generateQueryKey("user:all", params)
var cachedIDs []string
if err := r.c.Get(ctx, queryKey, &cachedIDs); err == nil && len(cachedIDs) > 0 {
return r.getByIDsWithFallback(ctx, cachedIDs)
}
rows, err := r.q.GetUsers(ctx, params)
if err != nil {
return nil, err
}
var users []*models.UserEntity
var ids []string
usersToCache := make(map[string]any)
for _, row := range rows {
user := &models.UserEntity{
ID: convert.UUIDToString(row.ID),
Email: row.Email,
PasswordHash: convert.TextToString(row.PasswordHash),
TokenVersion: row.TokenVersion,
IsDeleted: row.IsDeleted,
CreatedAt: convert.TimeToPtr(row.CreatedAt),
UpdatedAt: convert.TimeToPtr(row.UpdatedAt),
}
if err := user.ParseRoles(row.Roles); err != nil {
return nil, err
}
if err := user.ParseProfile(row.Profile); err != nil {
return nil, err
}
users = append(users, user)
ids = append(ids, user.ID)
usersToCache[fmt.Sprintf("user:id:%s", user.ID)] = user
}
if len(usersToCache) > 0 {
_ = r.c.MSet(ctx, usersToCache, constants.NormalCacheDuration)
}
if len(ids) > 0 {
_ = r.c.Set(ctx, queryKey, ids, constants.ListCacheDuration)
}
return users, nil
}
func (r *userRepository) Search(ctx context.Context, params sqlc.SearchUsersParams) ([]*models.UserEntity, error) {
queryKey := r.generateQueryKey("user:search", params)
var cachedIDs []string
if err := r.c.Get(ctx, queryKey, &cachedIDs); err == nil && len(cachedIDs) > 0 {
return r.getByIDsWithFallback(ctx, cachedIDs)
}
rows, err := r.q.SearchUsers(ctx, params)
if err != nil {
return nil, err
}
var users []*models.UserEntity
var ids []string
usersToCache := make(map[string]any)
for _, row := range rows {
user := &models.UserEntity{
ID: convert.UUIDToString(row.ID),
Email: row.Email,
PasswordHash: convert.TextToString(row.PasswordHash),
TokenVersion: row.TokenVersion,
IsDeleted: row.IsDeleted,
CreatedAt: convert.TimeToPtr(row.CreatedAt),
UpdatedAt: convert.TimeToPtr(row.UpdatedAt),
}
if err := user.ParseRoles(row.Roles); err != nil {
return nil, err
}
if err := user.ParseProfile(row.Profile); err != nil {
return nil, err
}
users = append(users, user)
ids = append(ids, user.ID)
usersToCache[fmt.Sprintf("user:id:%s", user.ID)] = user
}
if len(usersToCache) > 0 {
_ = r.c.MSet(ctx, usersToCache, constants.NormalCacheDuration)
}
if len(ids) > 0 {
_ = r.c.Set(ctx, queryKey, ids, constants.ListCacheDuration)
}
return users, nil
}
func (r *userRepository) Delete(ctx context.Context, id pgtype.UUID) error {
user, err := r.GetByID(ctx, id)
if err != nil {
return err
}
err = r.q.DeleteUser(ctx, id)
if err != nil {
return err
}
_ = r.c.Del(
ctx,
fmt.Sprintf("user:id:%s", user.ID),
fmt.Sprintf("user:email:%s", user.Email),
fmt.Sprintf("user:token:%s", user.ID),
)
return nil
}
func (r *userRepository) Restore(ctx context.Context, id pgtype.UUID) error {
err := r.q.RestoreUser(ctx, id)
if err != nil {
return err
}
_ = r.c.Del(ctx, fmt.Sprintf("user:id:%s", convert.UUIDToString(id)))
return nil
}
func (r *userRepository) GetTokenVersion(ctx context.Context, id pgtype.UUID) (int32, error) {
cacheId := fmt.Sprintf("user:token:%s", convert.UUIDToString(id))
var token int32
err := r.c.Get(ctx, cacheId, &token)
if err == nil {
return token, nil
}
raw, err := r.q.GetTokenVersion(ctx, id)
if err != nil {
return 0, err
}
_ = r.c.Set(ctx, cacheId, raw, constants.NormalCacheDuration)
return raw, nil
}
func (r *userRepository) UpdateTokenVersion(ctx context.Context, params sqlc.UpdateTokenVersionParams) error {
err := r.q.UpdateTokenVersion(ctx, params)
if err != nil {
return err
}
cacheId := fmt.Sprintf("user:token:%s", convert.UUIDToString(params.ID))
_ = r.c.Set(ctx, cacheId, params.TokenVersion, constants.NormalCacheDuration)
return nil
}
func (r *userRepository) UpdatePassword(ctx context.Context, params sqlc.UpdateUserPasswordParams) error {
user, err := r.GetByID(ctx, params.ID)
if err != nil {
return err
}
err = r.q.UpdateUserPassword(ctx, params)
if err != nil {
return err
}
err = r.UpdateTokenVersion(ctx, sqlc.UpdateTokenVersionParams{
ID: params.ID,
TokenVersion: user.TokenVersion + 1,
})
if err != nil {
return err
}
user.PasswordHash = convert.TextToString(params.PasswordHash)
user.TokenVersion += 1
mapCache := map[string]any{
fmt.Sprintf("user:email:%s", user.Email): user,
fmt.Sprintf("user:id:%s", user.ID): user,
fmt.Sprintf("user:token:%s", user.ID): user.TokenVersion,
}
_ = r.c.MSet(ctx, mapCache, constants.NormalCacheDuration)
return nil
}
func (r *userRepository) UpdateRefreshToken(ctx context.Context, params sqlc.UpdateUserRefreshTokenParams) error {
user, err := r.GetByID(ctx, params.ID)
if err != nil {
return err
}
err = r.q.UpdateUserRefreshToken(ctx, params)
if err != nil {
return err
}
user.RefreshToken = convert.TextToString(params.RefreshToken)
mapCache := map[string]any{
fmt.Sprintf("user:email:%s", user.Email): user,
fmt.Sprintf("user:id:%s", user.ID): user,
fmt.Sprintf("user:token:%s", user.ID): user.TokenVersion,
}
_ = r.c.MSet(ctx, mapCache, constants.NormalCacheDuration)
return nil
}