# By: Riasat Ullah
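# This file contains functions that create and read the tokens used to identify
# a user every time they send a request to the REST API and the Incidents API.
# It also covers refresh, registration and password reset tokens, API key extraction and public request verification.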

from cryptography.fernet import Fernet
from exceptions.user_exceptions import UnauthorizedRequest
from taskcallrest import settings
from utils import constants, errors, file_storage, s3, times, var_names
import configuration
import datetime
import jwt
import random


# S3 location of the JSON file holding the token keys (grouped by region, see get_token_keys).
token_bucket: str = "taskcall-prod-data"
token_key: str = "credentials/token_key.json"
# JWT algorithm used when encoding and decoding this module's own tokens.
algo_type: str = "HS256"


#########################
# General functions
#########################

def get_token_keys():
    '''
    Gets the stored token keys for the current region (settings.REGION) from the token key file in S3.
    :return: (tuple) access key, refresh key, secret keys (mapping keyed by key version),
        fernet key (bytes, encoded from the stored string)
    :errors: KeyError if the region or one of the expected keys is missing from the file;
        OSError if the token key file could not be read
    '''
    try:
        data = s3.read_json(token_bucket, token_key)
        region_keys = data[settings.REGION]
        return (
            region_keys['access_key'],
            region_keys['refresh_key'],
            region_keys['secret_keys'],
            region_keys['fernet_key'].encode(),
        )
    except KeyError as e:
        err = 'Could not find one of the specified keys. ' + '\n' + str(e)
        raise KeyError(err) from e
    except OSError as e:
        # IOError is an alias of OSError in Python 3, so this covers both.
        err = 'Could not read token key file' + '\n' + str(e)
        raise OSError(err) from e


def get_randomized_secret_key_version():
    '''
    Pick one of the available secret key versions at random.
    The chosen version is stored in the token next to the signed payload so the matching key can be found on read.
    :return: one of the keys of the module-level secret_keys mapping (presumably a str if loaded from JSON - confirm)
    :errors: ValueError if no secret key versions are available
    '''
    available_versions = [*secret_keys]
    chosen_index = random.randrange(len(available_versions))
    return available_versions[chosen_index]


def new_access_token_expiry_date(org_sub_end: datetime.datetime):
    '''
    Work out when a newly issued access token should expire.
    The result is the earliest of: 15 days from now, the last date and time of the current month
    (per times.get_last_datetime_of_current_month) and the end of the organization's current subscription.
    :param org_sub_end: date and time the organization's current subscription ends
    :return: (datetime.datetime) -> expiry date
    '''
    candidates = (
        times.get_current_timestamp() + datetime.timedelta(days=15),
        times.get_last_datetime_of_current_month(),
        org_sub_end,
    )
    return min(candidates)


def new_refresh_token_expiry_date(access_method: str):
    '''
    Work out when a newly issued refresh token should expire.
    :param access_method: APP or WEBSITE
    :return: (datetime.datetime) -> 365 * 4 days from now for the app, 30 * 3 days from now for anything else
    '''
    now = times.get_current_timestamp()
    lifetime_days = 365 * 4 if access_method == constants.app else 30 * 3
    return now + datetime.timedelta(days=lifetime_days)


def get_token_time_string():
    '''
    Return the current timestamp as a compact digit string, from year down to microseconds.
    It is embedded in token payloads so that the same set of values does not always produce the same token.
    :return: (str) timestamp formatted with '%Y%m%d%H%M%S%f'
    '''
    now = times.get_current_timestamp()
    return now.strftime('%Y%m%d%H%M%S%f')


#############################################
# Access token functions - Mobile/Web App
#############################################

def create_token(user_id: int, org_id: int, user_permissions: str, org_permissions: str,
                 expiry_date: datetime.datetime):
    '''
    Creates a JWT token to be used as access token for Mobile and Web apps.
    The payload is [user_id, org_id, user_permissions, org_permissions, time string, region]; the time string varies
    the token between calls and the region lets read_token reject tokens issued in another region.
    :param user_id: user_id of the user
    :param org_id: the organization ID
    :param user_permissions: binary string of the user's permissions combination
    :param org_permissions: the organization's permissions (presumably the same binary-string form - confirm)
    :param expiry_date: date and time the token expires; stored under the var_names.exp claim
    :return: (str) JWT token signed with the access key
    '''
    payload = [user_id, org_id, user_permissions, org_permissions, get_token_time_string(), settings.REGION]
    return jwt.encode({var_names.payload: payload, var_names.exp: expiry_date}, access_key, algorithm=algo_type)


def read_token(token: str, region_agnostic=False):
    '''
    Read an access token created by create_token and retrieve its hidden values.
    :param token: (str) the token to read
    :param region_agnostic: True if the token should be read without taking the region into consideration;
        this should be used with extreme caution. This parameter has only been introduced to handle integrations
        like Microsoft Teams that require verification tokens to be set only in the Europe region and requires cross
        region api calls.
    :return: (list) -> [user_id, organization ID, user permissions, organization permissions]
    :errors: jwt.ExpiredSignatureError if the token has expired;
        jwt.InvalidSignatureError if the token cannot be decoded or verified;
        UnauthorizedRequest if the token was issued for another region (unless region_agnostic)
    '''
    try:
        # PyJWT expects the allowed algorithms as a list; a bare string would make its
        # "alg in algorithms" check a substring test.
        token_layer = jwt.decode(token, access_key, algorithms=[algo_type])
        payload = token_layer[var_names.payload]

        if len(payload) == 5:
            # Legacy payload without a trailing region entry; treated as a Europe (Paris) token.
            if settings.REGION != constants.aws_europe_paris and not region_agnostic:
                raise UnauthorizedRequest(errors.err_internal_token_region_invalid)
        elif payload.pop() != settings.REGION and not region_agnostic:
            # The region entry has just been removed; it must match the current region.
            raise UnauthorizedRequest(errors.err_internal_token_region_invalid)

        # Remove the token time string.
        payload.pop()
        return payload
    except jwt.ExpiredSignatureError:
        raise
    except (jwt.InvalidTokenError, KeyError) as e:
        # Every other decode/verification failure (bad signature, malformed token, disallowed algorithm,
        # missing payload claim) is reported as an invalid signature.
        raise jwt.InvalidSignatureError from e


def extract_token(request):
    '''
    Pull the token out of the authorization header of an http request.
    The header must consist of exactly two space-separated parts, the first being var_names.token.
    :param request: Http request
    :return: (str) -> token
    :errors: UnauthorizedRequest if the header is missing or malformed
    '''
    auth_attr = constants.authorization_attribute
    if auth_attr in request.headers:
        parts = request.headers.get(auth_attr).split(' ')
        if len(parts) == 2 and parts[0] == var_names.token:
            return parts[1]
    raise UnauthorizedRequest(errors.err_authorization)


def authorize_request(request, region_agnostic=False):
    '''
    Check that a request carries an authorized access token in its header and return the token's details.
    :param request: Http request
    :param region_agnostic: True if the token should be read without taking the region into consideration;
        this should be used with extreme caution. This parameter has only been introduced to handle integrations
        like Microsoft Teams that require verification tokens to be set only in the Europe region and requires cross
        region api calls.
    :return: (list) -> [user_id, organization_id, user permissions, organization permissions]
    :errors: UnauthorizedRequest, jwt.ExpiredSignatureError, jwt.InvalidSignatureError (all propagated unchanged)
    '''
    return read_token(extract_token(request), region_agnostic=region_agnostic)


def get_region(token: str):
    '''
    Get the region an access token was issued for. The token's signature and expiry are verified first.
    :param token: (str) the access token to read
    :return: (str) -> region; legacy payloads with no region entry (5 items) are reported as constants.aws_europe_paris
    :errors: jwt.ExpiredSignatureError if the token has expired;
        jwt.InvalidSignatureError if the token cannot be decoded or verified
    '''
    try:
        # PyJWT expects the allowed algorithms as a list.
        token_layer = jwt.decode(token, access_key, algorithms=[algo_type])
        payload = token_layer[var_names.payload]
    except jwt.ExpiredSignatureError:
        raise
    except (jwt.InvalidTokenError, KeyError) as e:
        raise jwt.InvalidSignatureError from e

    if len(payload) == 5:
        return constants.aws_europe_paris
    # The region is always the last payload entry of current-format tokens.
    return payload[-1]


#############################################
# Refresh token functions - Mobile/Web App
#############################################

def create_refresh_token(user_id: int, org_id: int, access_method: str, token_end: datetime.datetime):
    '''
    Build a refresh token to accompany the access tokens created by create_token.
    The refresh token should be saved in the database.
    It has two layers: an inner JWT holding the payload, signed with a randomly chosen versioned secret key,
    and an outer JWT holding the inner token, that key version and the expiry, signed with the refresh key.
    :param user_id: user_id of the user this refresh token is for
    :param org_id: ID of the organization the user is in
    :param access_method: APP or WEBSITE
    :param token_end: (datetime.datetime) date and time the token should expire on
    :return: (str) refresh token
    '''
    inner_payload = [user_id, org_id, access_method, get_token_time_string(), settings.REGION]
    key_version = get_randomized_secret_key_version()
    inner_token = jwt.encode({var_names.payload: inner_payload}, secret_keys[key_version], algorithm=algo_type)

    outer_claims = {var_names.payload: inner_token, var_names.version: key_version, var_names.exp: token_end}
    return jwt.encode(outer_claims, refresh_key, algorithm=algo_type)


def read_refresh_token(token: str):
    '''
    Read a refresh token created by create_refresh_token and retrieve its hidden values.
    The outer layer is verified with the refresh key (including expiry), then the inner layer is verified with the
    secret key version recorded in the outer layer.
    :param token: (str) the token to read
    :return: (list) -> [user_id, org ID, access method]
    :errors: jwt.ExpiredSignatureError if the token has expired;
        jwt.InvalidSignatureError if either layer cannot be decoded or verified, or the key version is unknown;
        UnauthorizedRequest if the token was issued for another region
    '''
    try:
        # PyJWT expects the allowed algorithms as a list.
        outer_layer = jwt.decode(token, refresh_key, algorithms=[algo_type])
        inner_token = outer_layer[var_names.payload].encode()
        version = outer_layer[var_names.version]
        inner_layer = jwt.decode(inner_token, secret_keys[version], algorithms=[algo_type])
        payload = inner_layer[var_names.payload]

        if len(payload) == 4:
            # Legacy payload without a trailing region entry; treated as a Europe (Paris) token.
            if settings.REGION != constants.aws_europe_paris:
                raise UnauthorizedRequest(errors.err_internal_token_region_invalid)
        elif payload.pop() != settings.REGION:
            # The region entry has just been removed; it must match the current region.
            raise UnauthorizedRequest(errors.err_internal_token_region_invalid)

        # Remove the token time string.
        payload.pop()
        return payload
    except jwt.ExpiredSignatureError:
        raise
    except (jwt.InvalidTokenError, KeyError) as e:
        raise jwt.InvalidSignatureError from e


#############################################
# API Key functions - Incidents API
#############################################

def extract_api_key(request):
    '''
    Pull the API key out of the authorization header of an http request.
    The header must consist of exactly two space-separated parts, the first being var_names.token, and the key must
    have the length configured for the current region in configuration.api_key_length.
    :param request: Http request
    :return: (str) -> api key
    :errors: UnauthorizedRequest if the header is missing or malformed, or the key has the wrong length
    '''
    auth_attr = constants.authorization_attribute
    if auth_attr in request.headers:
        parts = request.headers.get(auth_attr).split(' ')
        if len(parts) == 2 and parts[0] == var_names.token:
            api_key = parts[1]
            if len(api_key) == configuration.api_key_length[settings.REGION]:
                return api_key
    raise UnauthorizedRequest(errors.err_authorization)


#############################################
# Registration token functions
#############################################

def create_registration_token(requested_email, ip_address, expiry_date):
    '''
    Build a new registration token in three layers:
      1. a JWT of the payload signed with a randomly chosen versioned secret key,
      2. a JWT wrapping layer 1 with that key version and the expiry, signed with the access key,
      3. a Fernet encryption of layer 2.
    :param requested_email: email address of the new user being registered
    :param ip_address: IP address where the request is being received from
    :param expiry_date: time when the token must expire
    :return: (str) registration token
    '''
    payload = [requested_email, ip_address, get_token_time_string(), settings.REGION]
    key_version = get_randomized_secret_key_version()

    secret_layer = jwt.encode({var_names.payload: payload}, secret_keys[key_version], algorithm=algo_type)
    signed_layer = jwt.encode(
        {var_names.payload: secret_layer, var_names.version: key_version, var_names.exp: expiry_date},
        access_key,
        algorithm=algo_type,
    )
    encrypted_layer = Fernet(fernet_key).encrypt(bytes(signed_layer, 'utf-8'))
    return encrypted_layer.decode()


def read_registration_token(registration_token: str):
    '''
    Read a registration token created by create_registration_token.
    The Fernet encryption is removed first. The access-key layer is then verified, including expiry. Finally, the
    inner layer is verified with the secret key version recorded in the access-key layer.
    :param registration_token: (str) the registration token
    :return: (list) -> [requested email, ip address]
    :errors: jwt.ExpiredSignatureError if the token has expired;
        jwt.InvalidSignatureError if the token cannot be decrypted, decoded or verified;
        UnauthorizedRequest if the token was issued for another region
    '''
    from cryptography.fernet import InvalidToken

    try:
        cipher = Fernet(fernet_key)
        decrypted_key = cipher.decrypt(registration_token.encode())

        # PyJWT expects the allowed algorithms as a list.
        access_layer = jwt.decode(decrypted_key, access_key, algorithms=[algo_type])
        access_layer_payload = access_layer[var_names.payload]
        version = access_layer[var_names.version]

        secret_layer = jwt.decode(access_layer_payload, secret_keys[version], algorithms=[algo_type])
        payload = secret_layer[var_names.payload]

        # Remove the region and verify that it is the same as the current region. Then remove the timestamp.
        if payload.pop() != settings.REGION:
            raise UnauthorizedRequest(errors.err_internal_token_region_invalid)
        payload.pop()

        return payload
    except jwt.ExpiredSignatureError:
        raise
    except (InvalidToken, jwt.InvalidTokenError, KeyError) as e:
        # Fernet raises InvalidToken for tampered or non-Fernet input; report it like any other bad token.
        raise jwt.InvalidSignatureError from e


#####################################
# Password Reset Token
#####################################

def create_password_reset_token(user_id, token_perm, expiry_date):
    '''
    Create a new password reset token in three layers:
      1. a JWT of the payload signed with a randomly chosen versioned secret key,
      2. a JWT wrapping layer 1 with that key version and the expiry, signed with the access key,
      3. a Fernet encryption of layer 2.
    :param user_id: ID of the user who the request is for
    :param token_perm: the permissions the token has
    :param expiry_date: time when the token must expire
    :return: (str) password reset token
    '''
    payload = [user_id, token_perm, get_token_time_string(), settings.REGION]
    version = get_randomized_secret_key_version()

    first_layer = jwt.encode({var_names.payload: payload}, secret_keys[version], algorithm=algo_type)
    second_layer = jwt.encode({var_names.payload: first_layer, var_names.version: version, var_names.exp: expiry_date},
                              access_key,
                              algorithm=algo_type)
    cipher = Fernet(fernet_key)
    third_layer = cipher.encrypt(bytes(second_layer, 'utf-8'))
    return third_layer.decode()


def read_password_reset_token(password_reset_token: str):
    '''
    Read a password reset token created by create_password_reset_token.
    The Fernet encryption is removed first. The access-key layer is then verified, including expiry. Finally, the
    inner layer is verified with the secret key version recorded in the access-key layer.
    :param password_reset_token: (str) the password reset token
    :return: (list) -> [user_id, token_perm]
    :errors: jwt.ExpiredSignatureError if the token has expired;
        jwt.InvalidSignatureError if the token cannot be decrypted, decoded or verified;
        UnauthorizedRequest if the token was issued for another region
    '''
    from cryptography.fernet import InvalidToken

    try:
        cipher = Fernet(fernet_key)
        decrypted_key = cipher.decrypt(password_reset_token.encode())

        # PyJWT expects the allowed algorithms as a list.
        access_layer = jwt.decode(decrypted_key, access_key, algorithms=[algo_type])
        access_layer_payload = access_layer[var_names.payload]
        version = access_layer[var_names.version]

        secret_layer = jwt.decode(access_layer_payload, secret_keys[version], algorithms=[algo_type])
        payload = secret_layer[var_names.payload]

        # Remove the region and verify that it is the same as the current region. Then remove the timestamp.
        if payload.pop() != settings.REGION:
            raise UnauthorizedRequest(errors.err_internal_token_region_invalid)
        payload.pop()

        return payload
    except jwt.ExpiredSignatureError:
        raise
    except (InvalidToken, jwt.InvalidTokenError, KeyError) as e:
        # Fernet raises InvalidToken for tampered or non-Fernet input; report it like any other bad token.
        raise jwt.InvalidSignatureError from e


#####################################################################
# Public Request Verifiers (requests originate from the web app)
#####################################################################
def verify_public_request(request):
    '''
    Verifies public requests (originally from outside TaskCall) by reading the verification token sent by the web
    application and retrieving its hidden values. The verifier key is read from S3 on every call.
    :param request: HttpRequest
    :return: (list) -> [subdomain, account_number]
    :errors: UnauthorizedRequest if the authorization header is missing or malformed;
        jwt.ExpiredSignatureError if the token has expired;
        jwt.InvalidSignatureError if the token cannot be decoded or verified, or an expected key is missing
    '''
    try:
        token = extract_token(request)
        verifier_key = s3.read_json(file_storage.S3_BUCKET_TASKCALL_PROD_DATA,
                                    file_storage.S3_KEY_PUBLIC_REQUEST_VERIFIER)[var_names.access_key]
        # PyJWT expects the allowed algorithms as a list.
        token_layer = jwt.decode(token, verifier_key, algorithms=['HS256'])
        return token_layer[var_names.payload]
    except jwt.ExpiredSignatureError:
        raise
    except (jwt.InvalidTokenError, KeyError) as e:
        raise jwt.InvalidSignatureError from e


# Load the token keys once at import time when global initialisation is enabled; otherwise leave them all as None.
access_key, refresh_key, secret_keys, fernet_key = (
    get_token_keys() if settings.INITIALIZE_GLOBAL_VARIABLES else (None, None, None, None)
)
