This commit is contained in:
Raphael Rouiller
2024-07-08 14:06:52 +02:00
commit aa54287126
96 changed files with 2718 additions and 0 deletions

41
user/Dockerfile Normal file
View File

@ -0,0 +1,41 @@
FROM python:3.11-slim
# declaration of service variables environment
ENV USER_SERVICE_NAME=${USER_SERVICE_NAME}
ENV DJANGO_SUPERUSER_USERNAME=${DJANGO_SUPERUSER_USERNAME}
ENV DJANGO_SUPERUSER_EMAIL=${DJANGO_SUPERUSER_EMAIL}
ENV DJANGO_SUPERUSER_PASSWORD=${DJANGO_SUPERUSER_PASSWORD}
ARG USER_SERVICE_NAME
ARG DJANGO_SUPERUSER_USERNAME
ARG DJANGO_SUPERUSER_EMAIL
ARG DJANGO_SUPERUSER_PASSWORD
COPY ./user_auth_system /home/archive/${USER_SERVICE_NAME}
# declaration of environment variables for python
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
RUN apt-get clean \
&& apt-get update \
&& apt-get install -y netcat-openbsd \
&& mkdir -p /home/archive
WORKDIR /home/archive
# installation of dependencies
COPY ./conf/requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt \
&& mkdir depedencies && mv requirements.txt depedencies
RUN mkdir -p logs
# copy and execute the initialization script
COPY ./tools/init.sh /home/archive/init.sh
RUN chmod +x /home/archive/init.sh
# set the final working directory
WORKDIR /home/archive/${USER_SERVICE_NAME}
# Command for running the application
CMD ["/bin/bash", "/home/archive/init.sh"]

View File

@ -0,0 +1,9 @@
Django>=4.2,<4.3
djangorestframework==3.14.0
psycopg2-binary>=2.9,<3.0
django-environ==0.10.0
django-cors-headers==4.0.0
pillow==9.5.0
djangorestframework-simplejwt==5.2.2
pyotp==2.8.0
qrcode==7.4.2

21
user/tools/init.sh Normal file
View File

@ -0,0 +1,21 @@
#!/bin/bash
LOGFILE="/home/archive/logs/setup.log"
echo "initialisation of Django" >> $LOGFILE
echo "backend name: $USER_SERVICE_NAME" >> $LOGFILE
echo "initialisation of the project $USER_SERVICE_NAME" >> $LOGFILE
echo "initialisation of Django done" >> $LOGFILE
cd /home/archive/$USER_SERVICE_NAME
echo "Waiting for postgres to get up and running..."
while ! nc -z db_archive 5434; do
echo "waiting for postgress to be listening..."
sleep 1
done
echo "PostgreSQL started"
pip install -U 'Twisted[tls,http2]'
python3 manage.py makemigrations
python3 manage.py migrate
daphne -b 0.0.0.0 -p 8003 user_auth_system.asgi:application

23
user/user_auth_system/manage.py Executable file
View File

@ -0,0 +1,23 @@
#!/usr/bin/env python
"""Django's command-line utility for administrative tasks."""
import os
import sys
from user_auth_system.settings import USER_SERVICE_NAME
def main():
"""Run administrative tasks."""
os.environ.setdefault('DJANGO_SETTINGS_MODULE', f'{USER_SERVICE_NAME}.settings')
try:
from django.core.management import execute_from_command_line
except ImportError as exc:
raise ImportError(
"Couldn't import Django. Are you sure it's installed and "
"available on your PYTHONPATH environment variable? Did you "
"forget to activate a virtual environment?"
) from exc
execute_from_command_line(sys.argv)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,130 @@
from pathlib import Path
import os
import environ
env = environ.Env()
ARCHIVE_APP_NAME = env('ARCHIVE_APP_NAME')
DB_NAME = env('DB_ARCHIVE_NAME')
DB_USER = env('DB_ARCHIVE_USER')
DB_PASSWORD = env('DB_ARCHIVE_PASSWORD')
DB_HOST = env('DB_ARCHIVE_HOST')
DB_PORT = env.int('DB_ARCHIVE_PORT')
SECRET_KEY = env('SECRET_KEY_ARCHIVE')
BASE_DIR = Path(__file__).resolve().parent.parent
DEBUG = False
ALLOWED_HOSTS = ['*']
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'daphne',
'rest_framework',
'corsheaders',
f'{ARCHIVE_APP_NAME}.apps.ArchiveAppConfig',
]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'corsheaders.middleware.CorsMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
ROOT_URLCONF = f'{ARCHIVE_APP_NAME}.urls'
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]
DATABASES = {
'default': {
"ENGINE": "django.db.backends.postgresql",
"NAME": DB_NAME,
"USER": DB_USER,
"PASSWORD": DB_PASSWORD,
"HOST": DB_HOST,
"PORT": DB_PORT,
}
}
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
'OPTIONS': {
'min_length': 8,
}
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
{
'NAME': 'your_app.validators.SpecialCharacterValidator',
'OPTIONS': {
'special_chars': '@$!%*?&',
}
},
{
'NAME': 'your_app.validators.UppercaseValidator',
},
{
'NAME': 'your_app.validators.LengthValidator',
'OPTIONS': {
'min_length': 8,
}
},
]
LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'UTC'
USE_I18N = True
USE_TZ = True
STATIC_URL = 'static/'
STATIC_ROOT = os.path.join(BASE_DIR, 'staticfiles')
MEDIA_URL = '/media/'
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
AUTH_USER_MODEL = f'{ARCHIVE_APP_NAME}.CustomUser'
REST_FRAMEWORK = {
'DEFAULT_AUTHENTICATION_CLASSES': [
'rest_framework_simplejwt.authentication.JWTAuthentication',
],
}
CORS_ALLOW_ALL_ORIGINS = True
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True

View File

@ -0,0 +1,30 @@
"""
URL configuration for user_auth_system project.
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.urls import path, include
from django.conf import settings
from django.conf.urls.static import static
from user_auth_system.settings import AUTH_APP_NAME
urlpatterns = [
path('admin/', admin.site.urls),
path('auth/', include(f'{AUTH_APP_NAME}.urls')),
]
if settings.DEBUG:
urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)

View File

@ -0,0 +1,3 @@
from user_auth_system.settings import AUTH_APP_NAME
default_app_config = f'{AUTH_APP_NAME}.apps.AuthUserConfig'

View File

@ -0,0 +1,3 @@
from django.contrib import admin
# Register your models here.

View File

@ -0,0 +1,6 @@
from django.apps import AppConfig
from user_auth_system.settings import AUTH_APP_NAME
class AuthUserConfig(AppConfig):
name = AUTH_APP_NAME
verbose_name = 'Authentication and Authorization'

View File

@ -0,0 +1,4 @@
from .user import CustomUser
from .source import Source
from .tag import Tag
from .suggestion import Suggestion

View File

@ -0,0 +1,17 @@
from django.db import models
from .user import CustomUser
from .tag import Tag
class Source(models.Model):
title = models.CharField(max_length=200)
url = models.URLField()
archived_url = models.URLField()
description = models.TextField()
category = models.CharField(max_length=50)
tags = models.ManyToManyField(Tag, related_name='sources')
added_by = models.ForeignKey(CustomUser, on_delete=models.SET_NULL, null=True)
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
def __str__(self):
return self.title

View File

@ -0,0 +1,12 @@
from django.db import models
from .user import CustomUser
class Suggestion(models.Model):
url = models.URLField()
description = models.TextField()
suggested_by = models.ForeignKey(CustomUser, on_delete=models.SET_NULL, null=True)
created_at = models.DateTimeField(auto_now_add=True)
is_approved = models.BooleanField(default=False)
def __str__(self):
return self.url

View File

@ -0,0 +1,7 @@
from django.db import models
class Tag(models.Model):
name = models.CharField(max_length=50, unique=True)
def __str__(self):
return self.name

View File

@ -0,0 +1,11 @@
from django.contrib.auth.models import AbstractUser
from django.db import models
class CustomUser(AbstractUser):
profile_picture = models.ImageField(upload_to='profile_pics/', blank=True, null=True)
language = models.CharField(max_length=2, default='en')
is_2fa_enabled = models.BooleanField(default=False)
otp_secret = models.CharField(max_length=32, blank=True)
def __str__(self):
return self.username

View File

@ -0,0 +1,4 @@
from .user import UserSerializer
from .source import SourceSerializer
from .tag import TagSerializer
from .suggestion import SuggestionSerializer

View File

@ -0,0 +1,10 @@
from rest_framework import serializers
from ..models import Source
class SourceSerializer(serializers.ModelSerializer):
tags = serializers.StringRelatedField(many=True)
added_by = serializers.StringRelatedField()
class Meta:
model = Source
fields = ['id', 'title', 'url', 'archived_url', 'description', 'category', 'tags', 'added_by', 'created_at', 'updated_at']

View File

@ -0,0 +1,10 @@
from rest_framework import serializers
from ..models import Suggestion
class SuggestionSerializer(serializers.ModelSerializer):
suggested_by = serializers.StringRelatedField()
class Meta:
model = Suggestion
fields = ['id', 'url', 'description', 'suggested_by', 'created_at', 'is_approved']
read_only_fields = ['suggested_by', 'is_approved']

View File

@ -0,0 +1,7 @@
from rest_framework import serializers
from ..models import Tag
class TagSerializer(serializers.ModelSerializer):
class Meta:
model = Tag
fields = ['id', 'name']

View File

@ -0,0 +1,12 @@
from rest_framework import serializers
from ..models import CustomUser
class UserSerializer(serializers.ModelSerializer):
class Meta:
model = CustomUser
fields = ['id', 'username', 'email', 'profile_picture', 'language', 'is_2fa_enabled']
extra_kwargs = {'password': {'write_only': True}}
def create(self, validated_data):
user = CustomUser.objects.create_user(**validated_data)
return user

View File

@ -0,0 +1,27 @@
from django.urls import reverse
from django.contrib.auth import get_user_model
from rest_framework import status
from rest_framework.test import APITestCase
from your_app.models import Source, Tag
User = get_user_model()
class SearchTests(APITestCase):
def setUp(self):
self.user = User.objects.create_user(username='testuser', password='testpassword123')
self.client.force_authenticate(user=self.user)
self.tag = Tag.objects.create(name='covid')
self.source = Source.objects.create(
title='COVID-19 Research',
url='https://example.com/covid',
description='Latest research on COVID-19',
added_by=self.user
)
self.source.tags.add(self.tag)
def test_search_source(self):
url = reverse('search')
response = self.client.get(url, {'q': 'COVID'})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0]['title'], 'COVID-19 Research')

View File

@ -0,0 +1,35 @@
from django.urls import reverse
from django.contrib.auth import get_user_model
from rest_framework import status
from rest_framework.test import APITestCase
from your_app.models import Source, Tag
User = get_user_model()
class SourceTests(APITestCase):
def setUp(self):
self.user = User.objects.create_user(username='testuser', password='testpassword123')
self.client.force_authenticate(user=self.user)
self.tag = Tag.objects.create(name='test_tag')
def test_create_source(self):
url = reverse('source-list-create')
data = {
'title': 'Test Source',
'url': 'https://example.com',
'archived_url': 'https://archive.is/example.com',
'description': 'This is a test source',
'category': 'test',
'tags': [self.tag.id]
}
response = self.client.post(url, data, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(Source.objects.count(), 1)
self.assertEqual(Source.objects.get().title, 'Test Source')
def test_list_sources(self):
Source.objects.create(title='Test Source', url='https://example.com', added_by=self.user)
url = reverse('source-list-create')
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)

View File

@ -0,0 +1,30 @@
from django.urls import reverse
from django.contrib.auth import get_user_model
from rest_framework import status
from rest_framework.test import APITestCase
from your_app.models import Suggestion
User = get_user_model()
class SuggestionTests(APITestCase):
def setUp(self):
self.user = User.objects.create_user(username='testuser', password='testpassword123')
self.client.force_authenticate(user=self.user)
def test_create_suggestion(self):
url = reverse('suggestion-list-create')
data = {
'url': 'https://example.com',
'description': 'This is a test suggestion'
}
response = self.client.post(url, data, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(Suggestion.objects.count(), 1)
self.assertEqual(Suggestion.objects.get().url, 'https://example.com')
def test_list_suggestions(self):
Suggestion.objects.create(url='https://example.com', suggested_by=self.user)
url = reverse('suggestion-list-create')
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)

View File

@ -0,0 +1,41 @@
from django.urls import reverse
from django.contrib.auth import get_user_model
from rest_framework import status
from rest_framework.test import APITestCase
from rest_framework_simplejwt.tokens import RefreshToken
User = get_user_model()
class UserAuthTests(APITestCase):
def setUp(self):
self.user = User.objects.create_user(username='testuser', password='testpassword123', email='test@example.com')
self.user.save()
def test_register_user(self):
url = reverse('user-register')
data = {'username': 'newuser', 'password': 'newpassword123', 'email': 'newuser@example.com'}
response = self.client.post(url, data, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
def test_login_user(self):
url = reverse('user-login')
data = {'username': 'testuser', 'password': 'testpassword123'}
response = self.client.post(url, data, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('access', response.data)
self.assertIn('refresh', response.data)
def test_logout_user(self):
refresh = RefreshToken.for_user(self.user)
self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {refresh.access_token}')
url = reverse('user-logout')
data = {'refresh': str(refresh)}
response = self.client.post(url, data, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_user_profile(self):
self.client.force_authenticate(user=self.user)
url = reverse('user-profile')
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['username'], 'testuser')

View File

@ -0,0 +1,26 @@
from django.urls import path
from .views import (
SourceListCreateView, SourceRetrieveUpdateDestroyView,
TagListCreateView, SuggestionListCreateView, SuggestionApproveView,
UserRegistrationView, UserLoginView, UserLogoutView,
Enable2FAView, Verify2FAView, UserProfileView
)
from rest_framework_simplejwt.views import TokenRefreshView
urlpatterns = [
path('sources/', SourceListCreateView.as_view(), name='source-list-create'),
path('sources/<int:pk>/', SourceRetrieveUpdateDestroyView.as_view(), name='source-detail'),
path('tags/', TagListCreateView.as_view(), name='tag-list-create'),
path('suggestions/', SuggestionListCreateView.as_view(), name='suggestion-list-create'),
path('suggestions/<int:pk>/approve/', SuggestionApproveView.as_view(), name='suggestion-approve'),
path('users/register/', UserRegistrationView.as_view(), name='user-register'),
path('users/login/', UserLoginView.as_view(), name='user-login'),
path('users/logout/', UserLogoutView.as_view(), name='user-logout'),
path('users/profile/', UserProfileView.as_view(), name='user-profile'),
path('users/enable-2fa/', Enable2FAView.as_view(), name='enable-2fa'),
path('users/verify-2fa/', Verify2FAView.as_view(), name='verify-2fa'),
path('token/refresh/', TokenRefreshView.as_view(), name='token-refresh'),
]

View File

@ -0,0 +1,21 @@
from django.core.exceptions import ValidationError
class SpecialCharacterValidator:
def validate(self, password, user=None):
if not any(char in '@$!%*?&' for char in password):
raise ValidationError(
"The password must contain at least one special character (@$!%*?&)."
)
def get_help_text(self):
return "Your password must contain at least one special character (@$!%*?&)."
class UppercaseValidator:
def validate(self, password, user=None):
if not any(char.isupper() for char in password):
raise ValidationError(
"The password must contain at least one uppercase letter."
)
def get_help_text(self):
return "Your password must contain at least one uppercase letter."

View File

@ -0,0 +1,3 @@
from .password_validators import SpecialCharacterValidator, UppercaseValidator, LengthValidator
__all__ = ['SpecialCharacterValidator', 'UppercaseValidator', 'LengthValidator']

View File

@ -0,0 +1,43 @@
from django.core.exceptions import ValidationError
from django.utils.translation import gettext as _
class SpecialCharacterValidator:
def __init__(self, special_chars='@$!%*?&'):
self.special_chars = special_chars
def validate(self, password, user=None):
if not any(char in self.special_chars for char in password):
raise ValidationError(
_("Le mot de passe doit contenir au moins un caractère spécial (%(special_chars)s)."),
code='password_no_symbol',
params={'special_chars': self.special_chars},
)
def get_help_text(self):
return _(f"Votre mot de passe doit contenir au moins un caractère spécial ({self.special_chars}).")
class UppercaseValidator:
def validate(self, password, user=None):
if not any(char.isupper() for char in password):
raise ValidationError(
_("Le mot de passe doit contenir au moins une lettre majuscule."),
code='password_no_upper',
)
def get_help_text(self):
return _("Votre mot de passe doit contenir au moins une lettre majuscule.")
class LengthValidator:
def __init__(self, min_length=8):
self.min_length = min_length
def validate(self, password, user=None):
if len(password) < self.min_length:
raise ValidationError(
_("Le mot de passe doit contenir au moins %(min_length)d caractères."),
code='password_too_short',
params={'min_length': self.min_length},
)
def get_help_text(self):
return _(f"Votre mot de passe doit contenir au moins {self.min_length} caractères.")

View File

@ -0,0 +1,5 @@
from .user import UserRegistrationView, UserLoginView, UserLogoutView, UserProfileView, Enable2FAView, Verify2FAView
from .source import SourceListCreateView, SourceRetrieveUpdateDestroyView
from .tag import TagListCreateView
from .suggestion import SuggestionListCreateView, SuggestionApproveView
from .search import SearchView

View File

@ -0,0 +1,13 @@
from rest_framework import generics, permissions
from ..models import Source
from ..serializers import SourceSerializer
class SearchView(generics.ListAPIView):
serializer_class = SourceSerializer
permission_classes = [permissions.AllowAny]
def get_queryset(self):
query = self.request.query_params.get('q', '')
return Source.objects.filter(title__icontains=query) | \
Source.objects.filter(description__icontains=query) | \
Source.objects.filter(tags__name__icontains=query).distinct()

View File

@ -0,0 +1,16 @@
from rest_framework import generics, permissions
from ..models import Source
from ..serializers import SourceSerializer
class SourceListCreateView(generics.ListCreateAPIView):
queryset = Source.objects.all()
serializer_class = SourceSerializer
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def perform_create(self, serializer):
serializer.save(added_by=self.request.user)
class SourceRetrieveUpdateDestroyView(generics.RetrieveUpdateDestroyAPIView):
queryset = Source.objects.all()
serializer_class = SourceSerializer
permission_classes = [permissions.IsAuthenticatedOrReadOnly]

View File

@ -0,0 +1,23 @@
from rest_framework import generics, permissions, status
from rest_framework.response import Response
from ..models import Suggestion
from ..serializers import SuggestionSerializer
class SuggestionListCreateView(generics.ListCreateAPIView):
queryset = Suggestion.objects.all()
serializer_class = SuggestionSerializer
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def perform_create(self, serializer):
serializer.save(suggested_by=self.request.user)
class SuggestionApproveView(generics.UpdateAPIView):
queryset = Suggestion.objects.all()
serializer_class = SuggestionSerializer
permission_classes = [permissions.IsAdminUser]
def update(self, request, *args, **kwargs):
instance = self.get_object()
instance.is_approved = True
instance.save()
return Response({"message": "Suggestion approved"})

View File

@ -0,0 +1,8 @@
from rest_framework import generics, permissions
from ..models import Tag
from ..serializers import TagSerializer
class TagListCreateView(generics.ListCreateAPIView):
queryset = Tag.objects.all()
serializer_class = TagSerializer
permission_classes = [permissions.IsAuthenticatedOrReadOnly]

View File

@ -0,0 +1,100 @@
from rest_framework import generics, permissions, status
from rest_framework.response import Response
from rest_framework.views import APIView
from ..models import CustomUser
from ..serializers import UserSerializer
from django.contrib.auth import authenticate
from rest_framework_simplejwt.tokens import RefreshToken
import pyotp
import qrcode
import base64
from io import BytesIO
class UserRegistrationView(generics.CreateAPIView):
queryset = CustomUser.objects.all()
serializer_class = UserSerializer
permission_classes = [permissions.AllowAny]
class UserLoginView(APIView):
permission_classes = [permissions.AllowAny]
def post(self, request):
username = request.data.get('username')
password = request.data.get('password')
user = authenticate(username=username, password=password)
if user:
if user.is_2fa_enabled:
return Response({"message": "2FA is enabled. Please provide OTP.", "require_2fa": True, "user_id": user.id})
refresh = RefreshToken.for_user(user)
return Response({
'refresh': str(refresh),
'access': str(refresh.access_token),
'user': UserSerializer(user).data
})
return Response({"error": "Invalid Credentials"}, status=status.HTTP_400_BAD_REQUEST)
class UserLogoutView(APIView):
permission_classes = [permissions.IsAuthenticated]
def post(self, request):
try:
refresh_token = request.data["refresh_token"]
token = RefreshToken(refresh_token)
token.blacklist()
return Response({"message": "Successfully Logged out."}, status=status.HTTP_200_OK)
except Exception as e:
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
class UserProfileView(generics.RetrieveUpdateAPIView):
serializer_class = UserSerializer
permission_classes = [permissions.IsAuthenticated]
def get_object(self):
return self.request.user
class Enable2FAView(APIView):
permission_classes = [permissions.IsAuthenticated]
def post(self, request):
user = request.user
secret_key = pyotp.random_base32()
user.otp_secret = secret_key
user.save()
totp = pyotp.TOTP(secret_key)
uri = totp.provisioning_uri(name=user.email, issuer_name="ArchiveApp")
qr = qrcode.QRCode(version=1, box_size=10, border=5)
qr.add_data(uri)
qr.make(fit=True)
img = qr.make_image(fill_color="black", back_color="white")
buffered = BytesIO()
img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return Response({
'secret_key': secret_key,
'qr_code': f"data:image/png;base64,{img_str}"
})
class Verify2FAView(APIView):
permission_classes = [permissions.IsAuthenticated]
def post(self, request):
user = request.user
otp = request.data.get('otp')
totp = pyotp.TOTP(user.otp_secret)
if totp.verify(otp):
user.is_2fa_enabled = True
user.save()
refresh = RefreshToken.for_user(user)
return Response({
"message": "2FA verified successfully",
'refresh': str(refresh),
'access': str(refresh.access_token),
'user': UserSerializer(user).data
})
return Response({"error": "Invalid OTP"}, status=status.HTTP_400_BAD_REQUEST)