UPDATE: Change auth logic
All checks were successful
Build and Release / release (push) Successful in 1m27s
All checks were successful
Build and Release / release (push) Successful in 1m27s
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"history-api/internal/models"
|
||||
"history-api/internal/services"
|
||||
"history-api/pkg/validator"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
@@ -136,6 +137,16 @@ func (h *AuthController) Signup(c fiber.Ctx) error {
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
func (h *AuthController) getRefreshToken(c fiber.Ctx) string {
|
||||
auth := c.Get("Authorization")
|
||||
if auth != "" {
|
||||
return strings.TrimPrefix(auth, "Bearer ")
|
||||
}
|
||||
|
||||
return c.Cookies("refresh_token")
|
||||
}
|
||||
|
||||
// RefreshToken godoc
|
||||
// @Summary Refresh session tokens
|
||||
// @Description Generate a new access token using a valid refresh token from context
|
||||
@@ -151,7 +162,15 @@ func (h *AuthController) RefreshToken(c fiber.Ctx) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
res, err := h.service.RefreshToken(ctx, c.Locals("uid").(string))
|
||||
tokenJwt := h.getRefreshToken(c)
|
||||
if tokenJwt == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(response.CommonResponse{
|
||||
Status: false,
|
||||
Message: "Missing refresh token",
|
||||
})
|
||||
}
|
||||
|
||||
res, err := h.service.RefreshToken(ctx, c.Locals("uid").(string), tokenJwt)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(response.CommonResponse{
|
||||
Status: false,
|
||||
|
||||
@@ -124,6 +124,57 @@ func (m *MediaController) DeleteMedia(c fiber.Ctx) error {
|
||||
})
|
||||
}
|
||||
|
||||
// BulkDeleteMedia godoc
|
||||
// @Summary Delete media
|
||||
// @Description Delete multiple media files by IDs
|
||||
// @Tags Media
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param body body request.MediaBulkDeleteDto true "Media IDs to delete"
|
||||
// @Success 200 {object} response.CommonResponse
|
||||
// @Failure 500 {object} response.CommonResponse
|
||||
// @Router /media [delete]
|
||||
func (m *MediaController) BulkDeleteMedia(c fiber.Ctx) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
claimsVal := c.Locals("user_claims")
|
||||
if claimsVal == nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(response.CommonResponse{
|
||||
Status: false,
|
||||
Message: "Unauthorized",
|
||||
})
|
||||
}
|
||||
|
||||
claims, ok := claimsVal.(*response.JWTClaims)
|
||||
if !ok {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(response.CommonResponse{
|
||||
Status: false,
|
||||
Message: "Invalid user claims",
|
||||
})
|
||||
}
|
||||
|
||||
dto := &request.MediaBulkDeleteDto{}
|
||||
if err := validator.ValidateBodyDto(c, dto); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(response.CommonResponse{
|
||||
Status: false,
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
err := m.service.BulkDeleteMedia(ctx, claims, dto)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(response.CommonResponse{
|
||||
Status: false,
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
return c.Status(fiber.StatusOK).JSON(response.CommonResponse{
|
||||
Status: true,
|
||||
Message: "Media deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// UploadServerSide godoc
|
||||
// @Summary Upload media (server-side)
|
||||
// @Description Upload media file through server
|
||||
|
||||
@@ -19,3 +19,7 @@ type SearchMediaDto struct {
|
||||
MinSize *int64 `json:"min_size" query:"min_size" validate:"omitempty,min=0"`
|
||||
MaxSize *int64 `json:"max_size" query:"max_size" validate:"omitempty,min=0,gtefield=MinSize"`
|
||||
}
|
||||
|
||||
type MediaBulkDeleteDto struct {
|
||||
MediaIDs []string `json:"media_ids" validate:"required,dive,uuid"`
|
||||
}
|
||||
|
||||
@@ -100,6 +100,16 @@ func (q *Queries) DeleteMedia(ctx context.Context, id pgtype.UUID) error {
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteMedias = `-- name: DeleteMedias :exec
|
||||
DELETE FROM medias
|
||||
WHERE id = ANY($1::uuid[])
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteMedias(ctx context.Context, dollar_1 []pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteMedias, dollar_1)
|
||||
return err
|
||||
}
|
||||
|
||||
const getMediaByID = `-- name: GetMediaByID :one
|
||||
SELECT id, user_id, storage_key, original_name, mime_type, size, file_metadata, created_at, updated_at FROM medias
|
||||
WHERE id = $1
|
||||
|
||||
@@ -31,7 +31,7 @@ func JwtAccess(userRepo repositories.UserRepository) fiber.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
func JwtRefresh(userRepo repositories.UserRepository) fiber.Handler {
|
||||
func JwtRefresh() fiber.Handler {
|
||||
jwtRefreshSecret, err := config.GetConfig("JWT_REFRESH_SECRET")
|
||||
if err != nil {
|
||||
return nil
|
||||
@@ -40,7 +40,7 @@ func JwtRefresh(userRepo repositories.UserRepository) fiber.Handler {
|
||||
return jwtware.New(jwtware.Config{
|
||||
SigningKey: jwtware.SigningKey{Key: []byte(jwtRefreshSecret)},
|
||||
ErrorHandler: jwtError,
|
||||
SuccessHandler: jwtSuccess(userRepo),
|
||||
SuccessHandler: jwtSuccessRefresh(),
|
||||
Extractor: extractors.Chain(
|
||||
extractors.FromAuthHeader("Bearer"),
|
||||
extractors.FromCookie("refresh_token"),
|
||||
@@ -100,6 +100,38 @@ func jwtSuccess(userRepo repositories.UserRepository) fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func jwtSuccessRefresh() fiber.Handler {
|
||||
return func(c fiber.Ctx) error {
|
||||
unauthorized := func() error {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(response.CommonResponse{
|
||||
Status: false,
|
||||
Message: "Invalid or missing token",
|
||||
})
|
||||
}
|
||||
|
||||
jwtToken := jwtware.FromContext(c)
|
||||
if jwtToken == nil {
|
||||
return unauthorized()
|
||||
}
|
||||
|
||||
claims, ok := jwtToken.Claims.(*response.JWTClaims)
|
||||
if !ok {
|
||||
return unauthorized()
|
||||
}
|
||||
|
||||
if slices.Contains(claims.Roles, constants.BANNED) {
|
||||
return c.Status(fiber.StatusForbidden).JSON(response.CommonResponse{
|
||||
Status: false,
|
||||
Message: "User account is banned",
|
||||
})
|
||||
}
|
||||
|
||||
c.Locals("uid", claims.UId)
|
||||
c.Locals("user_claims", claims)
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func jwtError(c fiber.Ctx, err error) error {
|
||||
if err.Error() == "Missing or malformed JWT" {
|
||||
return c.Status(fiber.StatusBadRequest).
|
||||
|
||||
@@ -50,3 +50,12 @@ func MediaEntitiesToResponse(entities []*MediaEntity) []*response.MediaResponse
|
||||
}
|
||||
return responses
|
||||
}
|
||||
|
||||
|
||||
func MediaEntitiesToStorageEntitye(entities []*MediaEntity) []*MediaStorageEntity {
|
||||
responses := make([]*MediaStorageEntity, len(entities))
|
||||
for i, entity := range entities {
|
||||
responses[i] = entity.ToStorageEntity()
|
||||
}
|
||||
return responses
|
||||
}
|
||||
@@ -16,10 +16,12 @@ import (
|
||||
|
||||
type MediaRepository interface {
|
||||
GetByID(ctx context.Context, id pgtype.UUID) (*models.MediaEntity, error)
|
||||
GetByIDs(ctx context.Context, ids []string) ([]*models.MediaEntity, error)
|
||||
GetByUserID(ctx context.Context, userId pgtype.UUID) ([]*models.MediaEntity, error)
|
||||
Search(ctx context.Context, params sqlc.SearchMediasParams) ([]*models.MediaEntity, error)
|
||||
Count(ctx context.Context, params sqlc.CountMediasParams) (int64, error)
|
||||
Delete(ctx context.Context, id pgtype.UUID) error
|
||||
BulkDelete(ctx context.Context, ids []pgtype.UUID) error
|
||||
Create(ctx context.Context, params sqlc.CreateMediaParams) (*models.MediaEntity, error)
|
||||
}
|
||||
|
||||
@@ -81,6 +83,10 @@ func (r *mediaRepository) getByIDsWithFallback(ctx context.Context, ids []string
|
||||
return medias, nil
|
||||
}
|
||||
|
||||
func (r *mediaRepository) GetByIDs(ctx context.Context, ids []string) ([]*models.MediaEntity, error) {
|
||||
return r.getByIDsWithFallback(ctx, ids)
|
||||
}
|
||||
|
||||
func (r *mediaRepository) GetByID(ctx context.Context, id pgtype.UUID) (*models.MediaEntity, error) {
|
||||
cacheId := fmt.Sprintf("media:id:%s", convert.UUIDToString(id))
|
||||
var media models.MediaEntity
|
||||
@@ -152,6 +158,23 @@ func (r *mediaRepository) Delete(ctx context.Context, id pgtype.UUID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *mediaRepository) BulkDelete(ctx context.Context, ids []pgtype.UUID) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
err := r.q.DeleteMedias(ctx, ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keys := make([]string, len(ids))
|
||||
for i, id := range ids {
|
||||
keys[i] = fmt.Sprintf("media:id:%s", convert.UUIDToString(id))
|
||||
}
|
||||
_ = r.c.Del(ctx, keys...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *mediaRepository) Search(ctx context.Context, params sqlc.SearchMediasParams) ([]*models.MediaEntity, error) {
|
||||
queryKey := r.generateQueryKey("media:search", params)
|
||||
var cachedIDs []string
|
||||
|
||||
@@ -12,7 +12,7 @@ func AuthRoutes(app *fiber.App, controller *controllers.AuthController, userRepo
|
||||
route := app.Group("/auth")
|
||||
route.Post("/signin", controller.Signin)
|
||||
route.Post("/signup", controller.Signup)
|
||||
route.Post("/refresh", middlewares.JwtRefresh(userRepo), controller.RefreshToken)
|
||||
route.Post("/refresh", middlewares.JwtRefresh(), controller.RefreshToken)
|
||||
route.Post("/token/create", controller.CreateToken)
|
||||
route.Post("/token/verify", controller.VerifyToken)
|
||||
route.Post("/forgot-password", controller.ForgotPassword)
|
||||
|
||||
@@ -17,7 +17,12 @@ func MediaRoutes(app *fiber.App, controller *controllers.MediaController, userRe
|
||||
middlewares.RequireAnyRole(constants.ADMIN, constants.MOD),
|
||||
controller.SearchMedia,
|
||||
)
|
||||
|
||||
route.Delete(
|
||||
"/",
|
||||
middlewares.JwtAccess(userRepo),
|
||||
controller.BulkDeleteMedia,
|
||||
)
|
||||
|
||||
route.Post(
|
||||
"/upload",
|
||||
middlewares.JwtAccess(userRepo),
|
||||
|
||||
@@ -38,7 +38,7 @@ type AuthService interface {
|
||||
VerifyToken(ctx context.Context, dto *request.VerifyTokenDto) (*response.VerifyTokenResponse, error)
|
||||
CreateToken(ctx context.Context, dto *request.CreateTokenDto) error
|
||||
SigninWithGoogle(ctx context.Context, dto *request.SigninWithGoogleDto) (*response.AuthResponse, error)
|
||||
RefreshToken(ctx context.Context, id string) (*response.AuthResponse, error)
|
||||
RefreshToken(ctx context.Context, id string, refreshToken string) (*response.AuthResponse, error)
|
||||
}
|
||||
|
||||
type authService struct {
|
||||
@@ -203,7 +203,7 @@ func (a *authService) Logout(ctx context.Context, userId string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *authService) RefreshToken(ctx context.Context, id string) (*response.AuthResponse, error) {
|
||||
func (a *authService) RefreshToken(ctx context.Context, id string, refreshToken string) (*response.AuthResponse, error) {
|
||||
var pgID pgtype.UUID
|
||||
err := pgID.Scan(id)
|
||||
if err != nil {
|
||||
@@ -213,6 +213,11 @@ func (a *authService) RefreshToken(ctx context.Context, id string) (*response.Au
|
||||
if err != nil {
|
||||
return nil, fiber.NewError(fiber.StatusInternalServerError, "Invalid user data")
|
||||
}
|
||||
|
||||
if user.RefreshToken != refreshToken {
|
||||
return nil, fiber.NewError(fiber.StatusUnauthorized, "Invalid refresh token")
|
||||
}
|
||||
|
||||
roles := models.RolesEntityToRoleConstant(user.Roles)
|
||||
|
||||
if slices.Contains(roles, constants.BANNED) {
|
||||
|
||||
@@ -32,6 +32,7 @@ type MediaService interface {
|
||||
GetMediaByUserID(ctx context.Context, userId string) ([]*response.MediaResponse, error)
|
||||
SearchMedia(ctx context.Context, dto *request.SearchMediaDto) (*response.PaginatedResponse, error)
|
||||
DeleteMedia(ctx context.Context, claims *response.JWTClaims, mediaId string) error
|
||||
BulkDeleteMedia(ctx context.Context, claims *response.JWTClaims, dto *request.MediaBulkDeleteDto) error
|
||||
UploadServerSide(ctx context.Context, userId string, fileHeader *multipart.FileHeader) (*response.MediaResponse, error)
|
||||
GeneratePresignedURL(ctx context.Context, userId string, dto *request.PreSignedDto) (*response.PreSignedResponse, error)
|
||||
PreSignedCompleted(ctx context.Context, userId string, dto *request.PreSignedCompleteDto) (*response.MediaResponse, error)
|
||||
@@ -88,6 +89,39 @@ func (m *mediaService) DeleteMedia(ctx context.Context, claims *response.JWTClai
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mediaService) BulkDeleteMedia(ctx context.Context, claims *response.JWTClaims, dto *request.MediaBulkDeleteDto) error {
|
||||
listMedia, err := m.mediaRepo.GetByIDs(ctx, dto.MediaIDs)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
shoudDelete := false
|
||||
if slices.Contains(claims.Roles, constants.ADMIN) || slices.Contains(claims.Roles, constants.MOD) {
|
||||
shoudDelete = true
|
||||
}
|
||||
listMediaIds := make([]pgtype.UUID, len(listMedia))
|
||||
listMediaStorageEntities := make([]*models.MediaStorageEntity, len(listMedia))
|
||||
for _, media := range listMedia {
|
||||
if media.UserID != claims.UId && !shoudDelete {
|
||||
return fiber.NewError(fiber.StatusForbidden, "You don't have permission to delete this media")
|
||||
}
|
||||
id, err := convert.StringToUUID(media.ID)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
listMediaIds = append(listMediaIds, id)
|
||||
listMediaStorageEntities = append(listMediaStorageEntities, media.ToStorageEntity())
|
||||
}
|
||||
|
||||
err = m.mediaRepo.BulkDelete(ctx, listMediaIds)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
m.c.PublishTask(ctx, constants.StreamStorageName, constants.TaskTypeBulkDeleteMedia, listMediaStorageEntities)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mediaService) GetMediaByID(ctx context.Context, id string) (*response.MediaResponse, error) {
|
||||
mediaId, err := convert.StringToUUID(id)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user