117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
|
from rest_framework import routers, viewsets, serializers, permissions
|
||
|
from rest_framework.decorators import api_view, permission_classes, authentication_classes
|
||
|
from rest_framework.authtoken.serializers import AuthTokenSerializer
|
||
|
from rest_framework.permissions import IsAuthenticated
|
||
|
from rest_framework.response import Response
|
||
|
from django.contrib.auth import login
|
||
|
from django.urls import path
|
||
|
from django.dispatch import receiver
|
||
|
from django.db.models.signals import post_save
|
||
|
from django.contrib.auth.models import Group
|
||
|
from knox.models import AuthToken
|
||
|
from knox.views import LoginView as KnoxLoginView
|
||
|
|
||
|
from authentication.models import ExtendedUser
|
||
|
|
||
|
|
||
|
class UserSerializer(serializers.ModelSerializer):
    """Read-only view of an ``ExtendedUser``.

    Exposes the basic identity fields, the names of the groups the user
    belongs to, and the de-duplicated result of ``obj.get_permissions()``.
    """

    # Groups are rendered by their ``name`` rather than by primary key.
    groups = serializers.SlugRelatedField(many=True, read_only=True, slug_field='name')
    # Computed by ``get_permissions`` below.
    permissions = serializers.SerializerMethodField()

    class Meta:
        model = ExtendedUser
        fields = ('id', 'username', 'email', 'first_name', 'last_name', 'permissions', 'groups')
        # Every exposed field is read-only.
        read_only_fields = fields

    def get_permissions(self, obj):
        """Return the user's permissions with duplicates removed.

        The values pass through a set, so the order of the returned list is
        unspecified.
        """
        unique_permissions = {*obj.get_permissions()}
        return list(unique_permissions)
|
||
|
|
||
|
|
||
|
@receiver(post_save, sender=ExtendedUser)
def create_auth_token(sender, instance=None, created=False, **kwargs):
    """Create a knox ``AuthToken`` when an ``ExtendedUser`` is first saved.

    Connected to ``post_save`` for ``ExtendedUser``; saves of already
    existing users (``created`` is false) are ignored.
    """
    if not created:
        return
    # NOTE(review): the return value of ``create`` is discarded here. Knox
    # appears to hand back the plaintext token only from this call, so the
    # token made here presumably cannot be given to a client - confirm that
    # this is intended.
    AuthToken.objects.create(user=instance)
|
||
|
|
||
|
|
||
|
class UserViewSet(viewsets.ModelViewSet):
    """Model viewset over every ``ExtendedUser``, rendered with ``UserSerializer``.

    NOTE(review): no ``permission_classes`` are declared on this class, so
    access control depends on configuration outside this file - confirm the
    project defaults are restrictive enough for a user endpoint.
    """

    serializer_class = UserSerializer
    queryset = ExtendedUser.objects.all()
|
||
|
|
||
|
|
||
|
class GroupSerializer(serializers.ModelSerializer):
    """Serialize a ``Group`` with its permission codenames and member usernames."""

    # Both fields are computed by the ``get_<name>`` methods below.
    members = serializers.SerializerMethodField()
    permissions = serializers.SerializerMethodField()

    class Meta:
        model = Group
        fields = ('id', 'name', 'permissions', 'members')

    def get_members(self, obj):
        """Return the usernames of all users in the group."""
        usernames = []
        for member in obj.user_set.all():
            usernames.append(member.username)
        return usernames

    def get_permissions(self, obj):
        """Return each group permission as ``"*:<codename>"``."""
        return [f"*:{perm.codename}" for perm in obj.permissions.all()]
|
||
|
|
||
|
|
||
|
class GroupViewSet(viewsets.ModelViewSet):
    """Model viewset over every ``Group``, rendered with ``GroupSerializer``.

    NOTE(review): no ``permission_classes`` are declared on this class, so
    access control depends on configuration outside this file - confirm.
    """

    serializer_class = GroupSerializer
    queryset = Group.objects.all()
|
||
|
|
||
|
|
||
|
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def selfUser(request):
    """Return the authenticated requester's own user record.

    Responds with HTTP 200 and the ``UserSerializer`` representation of
    ``request.user``. Only authenticated requests are allowed.
    """
    payload = UserSerializer(request.user).data
    return Response(payload, status=200)
|
||
|
|
||
|
|
||
|
@api_view(['POST'])
@permission_classes([])
@authentication_classes([])
def registerUser(request):
    """Register a new ``ExtendedUser`` from ``username``, ``password`` and ``email``.

    The endpoint is open: both permission and authentication classes are
    empty, so unauthenticated clients can call it.

    Responses:
        201: ``{'username': ..., 'email': ...}`` for the created user.
        400: ``{'errors': {field: message}}`` when a field is missing/blank
            or the username/email is already taken.
        400: ``{'errors': '<exception text>'}`` when anything else raises.
    """
    try:
        username = request.data.get('username')
        password = request.data.get('password')
        email = request.data.get('email')

        errors = {}

        # Uniqueness is only checked for values that were actually supplied.
        # Otherwise a blank value could match stored rows with an empty
        # field and replace the "is required" message with a misleading
        # "already exists" one (and cost a pointless query).
        if not username:
            errors['username'] = 'Username is required'
        elif ExtendedUser.objects.filter(username=username).exists():
            errors['username'] = 'Username already exists'

        if not password:
            errors['password'] = 'Password is required'

        if not email:
            errors['email'] = 'Email is required'
        elif ExtendedUser.objects.filter(email=email).exists():
            errors['email'] = 'Email already exists'

        if errors:
            return Response({'errors': errors}, status=400)

        user = ExtendedUser.objects.create_user(username, email, password)
        return Response({'username': user.username, 'email': user.email}, status=201)
    except Exception as exc:
        # NOTE(review): this catch-all is kept so existing clients still get
        # the same 400 payload, but it reports server-side failures as a
        # client error and echoes raw exception text to unauthenticated
        # callers. Consider narrowing it and logging instead.
        return Response({'errors': str(exc)}, status=400)
|
||
|
|
||
|
|
||
|
class LoginView(KnoxLoginView):
    """Knox login endpoint that accepts credentials in the request body.

    Open to anyone (``AllowAny``, no authentication classes). The posted data
    is validated with DRF's ``AuthTokenSerializer`` before the knox login
    view takes over.
    """

    authentication_classes = ()
    permission_classes = (permissions.AllowAny,)

    def post(self, request, format=None):
        """Validate the posted credentials, log the user in, then defer to knox.

        Invalid data raises from ``is_valid(raise_exception=True)``.
        """
        credentials = AuthTokenSerializer(data=request.data)
        credentials.is_valid(raise_exception=True)
        authenticated_user = credentials.validated_data['user']
        # Attach the validated user to the request via Django's ``login``
        # before handing over to the parent view.
        login(request, authenticated_user)
        # The incoming ``format`` is not forwarded; ``None`` is always passed.
        return super().post(request, format=None)
|
||
|
|
||
|
|
||
|
# Viewset routes: /users/ and /groups/.
router = routers.SimpleRouter()
router.register('users', UserViewSet, basename='users')
router.register('groups', GroupViewSet, basename='groups')

# Router-generated routes first, followed by the function/class based views.
urlpatterns = [
    *router.urls,
    path('self/', selfUser),
    path('login/', LoginView.as_view()),
    path('register/', registerUser),
]
|