1
# -*- coding: utf-8 -*-
2
# classes/models/client.py
3
# class:: Client
4
5
from datetime import datetime, timedelta
6
from flask import current_app
7
8
from swtstore.classes.database import db
9
from swtstore.classes.models.um import User
10
from swtstore.classes import oauth
11
12
class Client(db.Model):
    """
    The third-party application registering with the platform.

    The underscore-prefixed columns hold the raw stored values; the
    read-only properties below expose them in a friendlier form
    (``client_type`` as a string, ``redirect_uris`` and ``default_scopes``
    as lists).
    """

    __tablename__ = 'clients'

    id = db.Column(db.String(40), primary_key=True)

    client_secret = db.Column(db.String(55), nullable=False)

    name = db.Column(db.String(60), nullable=False)

    description = db.Column(db.String(400))

    # creator of the client application
    user_id = db.Column(db.ForeignKey('users.id'))
    creator = db.relationship('User')

    _is_private = db.Column(db.Boolean)

    _host_url = db.Column(db.String(60))

    # both are stored as whitespace-separated text (see the properties below)
    _redirect_uris = db.Column(db.Text)
    _default_scopes = db.Column(db.Text)

    @property
    def client_id(self):
        """Return the client's primary key (alias of ``id``)."""
        return self.id

    @property
    def client_type(self):
        """Return ``'private'`` if ``_is_private`` is truthy, else ``'public'``."""
        if self._is_private:
            return 'private'
        return 'public'

    @property
    def host_url(self):
        """Return the stored host URL."""
        return self._host_url

    @property
    def redirect_uris(self):
        """Return the stored redirect URIs as a list (empty if none are set)."""
        if self._redirect_uris:
            return self._redirect_uris.split()
        return []

    @property
    def default_redirect_uri(self):
        """
        Return the first stored redirect URI.

        Returns ``None`` when the client has no redirect URIs, instead of
        failing with an ``IndexError`` on the empty list.
        """
        uris = self.redirect_uris
        return uris[0] if uris else None

    @property
    def default_scopes(self):
        """Return the stored default scopes as a list (empty if none are set)."""
        if self._default_scopes:
            return self._default_scopes.split()
        return []

    def __repr__(self):
        return '<Client: %s :: ID: %s>' % (self.name, self.id)

    def __str__(self):
        # same text as __repr__; delegate so the format lives in one place
        return self.__repr__()

    def persist(self):
        """Add this client to the session and commit it to the database."""
        db.session.add(self)
        db.session.commit()

    @staticmethod
    def getClientsByCreator(user_id):
        """Return a list of all clients whose ``user_id`` equals `user_id`."""
        return Client.query.filter_by(user_id=user_id).all()
85
86
87
class Grant(db.Model):
    """
    Short-lived grant created while the authorization flow is running and
    destroyed once the authorization has finished.

    Because of that short lifetime, a cache would be a better-performing
    home for this data than the database.
    """
    # TODO: this would perform better if it lived only in a cache, not in a db.

    __tablename__ = 'grants'

    id = db.Column(db.Integer, primary_key=True)

    # the user the grant belongs to
    user_id = db.Column(
        db.Integer,
        db.ForeignKey('users.id', ondelete='CASCADE'),
    )
    user = db.relationship('User')

    # the client the grant belongs to
    client_id = db.Column(
        db.String(40),
        db.ForeignKey('clients.id'),
        nullable=False,
    )
    client = db.relationship('Client')

    code = db.Column(db.String(255), index=True, nullable=False)

    redirect_uri = db.Column(db.String(255))
    expires = db.Column(db.DateTime)

    # whitespace-separated scope names stored as text
    _scopes = db.Column(db.Text)

    @property
    def scopes(self):
        """Return the stored scopes as a list; empty when nothing is stored."""
        raw = self._scopes
        return raw.split() if raw else []

    def delete(self):
        """Remove this grant from the database and commit."""
        session = db.session
        session.delete(self)
        session.commit()
122
123
124
class Token(db.Model):
    """
    The final token a client uses.
    """

    __tablename__ = 'tokens'

    id = db.Column(db.Integer, primary_key=True)

    # the client and user this token is tied to
    client_id = db.Column(
        db.String(40),
        db.ForeignKey('clients.id'),
        nullable=False,
    )
    client = db.relationship('Client')
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'))
    user = db.relationship('User')

    token_type = db.Column(db.String(40))

    access_token = db.Column(db.String(255), unique=True)
    refresh_token = db.Column(db.String(255), unique=True)
    expires = db.Column(db.DateTime)

    # whitespace-separated scope names stored as text
    _scopes = db.Column(db.Text)

    @property
    def scopes(self):
        """Return the stored scopes as a list; empty when nothing is stored."""
        raw = self._scopes
        return raw.split() if raw else []
151
152
153
154
#TODO: find out how to better structure the following code
155
156
# Callbacks registered through the `oauth` decorators; used during the OAuth flow
157
158
@oauth.clientgetter
def loadClient(client_id):
    """Fetch the Client whose primary key is `client_id`."""
    current_app.logger.debug('@oauth.clientgetter')
    client = Client.query.get(client_id)
    return client
163
164
@oauth.grantgetter
def loadGrant(client_id, code):
    """Fetch the first Grant matching both `client_id` and `code`."""
    current_app.logger.debug('@oauth.grantgetter')
    matching = Grant.query.filter_by(client_id=client_id, code=code)
    return matching.first()
168
169
@oauth.grantsetter
def saveGrant(client_id, code, request, *args, **kwargs):
    """
    Create, store and return a Grant for the current user.

    The grant takes its code from ``code['code']``, its redirect URI and
    scopes from `request`, and an expiry stamped 100 seconds from now (UTC).
    """
    current_app.logger.debug('@oauth.grantsetter')

    expiry = datetime.utcnow() + timedelta(seconds=100)
    new_grant = Grant(
        client_id=client_id,
        code=code['code'],
        redirect_uri=request.redirect_uri,
        # scopes are stored as one space-separated string
        _scopes=' '.join(request.scopes),
        user=User.getCurrentUser(),
        expires=expiry,
    )

    session = db.session
    session.add(new_grant)
    session.commit()
    return new_grant
184
185
@oauth.tokengetter
def loadToken(access_token=None, refresh_token=None):
    """
    Fetch a Token by access token, or else by refresh token.

    `access_token` takes precedence when both are given; returns None when
    neither is supplied.
    """
    current_app.logger.debug('@oauth.tokengetter')

    if access_token:
        return Token.query.filter_by(access_token=access_token).first()
    if refresh_token:
        return Token.query.filter_by(refresh_token=refresh_token).first()
    return None
192
193
@oauth.tokensetter
def saveToken(token, request, *args, **kwargs):
    """
    Store a newly issued token, replacing earlier ones for the same pair.

    Any existing Token rows for this client/user combination are deleted in
    the same commit, so every client has only one token connected to a user.

    :param token: dict describing the issued token. ``access_token``,
        ``token_type``, ``scope`` and ``expires_in`` (seconds) are required;
        ``refresh_token`` is optional and stored as None when absent.
    :param request: object providing ``client`` and ``user``.
    :return: the stored Token.
    :raises KeyError: if a required key is missing from `token`.
    """
    current_app.logger.debug('@oauth.tokensetter')

    client_id = request.client.id
    toks = Token.query.filter_by(client_id=client_id,
                                 user_id=request.user.id)
    # make sure that every client has only one token connected to a user
    for t in toks:
        db.session.delete(t)

    # Read rather than pop: `token` belongs to the caller, and removing
    # `expires_in` here would drop it from whatever the caller does with
    # the dict afterwards (NOTE(review): presumably the token response).
    expires_in = token['expires_in']
    expires = datetime.utcnow() + timedelta(seconds=expires_in)

    tok = Token(
        access_token=token['access_token'],
        # not every token carries a refresh token; avoid a KeyError then
        refresh_token=token.get('refresh_token'),
        token_type=token['token_type'],
        _scopes=token['scope'],
        expires=expires,
        client_id=client_id,
        user=request.user
    )
    db.session.add(tok)
    db.session.commit()
    return tok
218
219
@oauth.usergetter
def getUser():
    """Return the result of ``User.getCurrentUser()``."""
    current_user = User.getCurrentUser()
    return current_user