from fastapi import APIRouter, HTTPException, Request , status 
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2AuthorizationCodeBearer
from jose import jwt
from typing import Dict
import httpx
import requests
from app.global_constants import GlobalConstants
from app.config import Config
import requests
from app.models.simple_model import RefreshTokenSchema
from app.utils.exception import OAuthException
import json
from app.config import Config
from cryptography.hazmat.primitives import serialization
from functools import wraps
import jwt


# Azure AD (Microsoft Entra ID) settings, read once from the application Config.
TENANT_ID = Config.TENANT_ID
CLIENT_ID = Config.CLIENT_ID
CLIENT_SECRET = Config.CLIENT_SECRET
REDIRECT_URI = Config.REDIRECT_URI
POST_LOGOUT_REDIRECT_URI = Config.POST_LOGOUT_REDIRECT_URI

# Tenant-specific Microsoft identity platform endpoints.
_AUTHORITY = f"https://login.microsoftonline.com/{TENANT_ID}"
AUTH_URL = f"{_AUTHORITY}/oauth2/v2.0/authorize"
TOKEN_URL = f"{_AUTHORITY}/oauth2/v2.0/token"
KEYS_URL = f"{_AUTHORITY}/discovery/keys"

# Delegated API permission plus `offline_access`, which asks for a refresh token.
SCOPE = f"api://{CLIENT_ID}/access_as_user offline_access"

   
# Validate Token using Azure AD Public Keys
def validate_azure_ad_token(token: str) -> str:
    """Verify an Azure AD access token and return its claims as JSON text.

    The signing key is selected from the tenant's JWKS document (``KEYS_URL``)
    by the ``kid`` in the token header. The token is then verified as RS256
    against audience ``api://<CLIENT_ID>`` and issuer
    ``https://sts.windows.net/<TENANT_ID>/``.

    Args:
        token: The raw JWT string.

    Returns:
        The decoded claims serialized with ``json.dumps(..., indent=2)``.

    Raises:
        OAuthException: 401 when ``token`` is empty or None.
        HTTPException: 401 when no JWKS key matches the token's ``kid``, or
            when the token is expired, malformed or fails signature, audience
            or issuer checks. 502 when the JWKS endpoint cannot be fetched.
            500 for any other unexpected failure.
    """
    if not token:
        raise OAuthException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authorization token missing or invalid")

    try:
        # NOTE(review): the JWKS document is downloaded on every call. Caching it
        # (and refreshing when an unknown `kid` shows up) would cut per-request latency.
        keys_response = requests.get(KEYS_URL, timeout=10)
        keys_response.raise_for_status()
        keys = keys_response.json().get("keys", [])

        # Pick the signing key the token header says it was signed with.
        token_kid = jwt.get_unverified_header(token).get("kid")
        public_key = next((key for key in keys if key.get("kid") == token_kid), None)
        if not public_key:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Public key not found")

        rsa_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(public_key))
        rsa_pem_key_bytes = rsa_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )

        decoded_token = jwt.decode(
            token,
            key=rsa_pem_key_bytes,
            algorithms=["RS256"],
            audience=f"api://{CLIENT_ID}",
            issuer=f"https://sts.windows.net/{TENANT_ID}/",
        )
        return json.dumps(decoded_token, indent=2)
    except HTTPException:
        # Already carries the intended status code. Don't let the catch-all
        # below turn it into a 500.
        raise
    except requests.RequestException as e:
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"Request error: {str(e)}") from e
    except jwt.ExpiredSignatureError as e:
        # Must precede InvalidTokenError, of which it is a subclass.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired") from e
    except jwt.InvalidTokenError as e:
        # Malformed token, bad signature, wrong audience/issuer: client errors, not 500s.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Invalid token: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal server error: {str(e)}") from e
    
    
   
def _path_matches(path: str, endpoint: str) -> bool:
    """Return True when *path* equals *endpoint* or is a sub-path of it.

    Matching is segment-aware, so ``/simple-models`` does not also match
    ``/simple-models-admin``.
    """
    base = endpoint.rstrip("/")
    return path == base or path.startswith(base + "/")


def validate_and_authorize(allowed_paths=None):
    """Return a decorator that authenticates and authorizes a FastAPI route.

    The wrapped route reads ``token`` (the raw bearer token) and ``request``
    (the incoming ``Request``) from its keyword arguments. The wrapper then:

    1. validates the token with ``validate_azure_ad_token``;
    2. checks the request's method and path against ``allowed_paths``;
    3. stores ``{"name", "user_name"}`` (plus ``roles`` when the token has
       them) on ``request.state.user_info``;
    4. calls the route, awaiting it when it is a coroutine function.

    Any exception, including one raised by the route itself, is returned as
    a ``GlobalConstants.return_api_response`` error payload rather than
    propagated.

    Args:
        allowed_paths: Optional iterable of ``{"endpoint": str, "method": str}``
            rules. A request is allowed when its method equals ``method`` and
            its path equals ``endpoint`` or lies beneath it. Defaults to
            ``GET /simple-models``.
    """
    import inspect

    rules = (
        tuple(allowed_paths)
        if allowed_paths is not None
        else ({"endpoint": "/simple-models", "method": "GET"},)
    )

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                decoded_token = validate_azure_ad_token(kwargs.get("token", None))

                request: Request = kwargs.get("request")
                if request is None:
                    # The route must declare a `request: Request` parameter for the checks below.
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail="Route is missing the 'request' parameter required for authorization",
                    )

                # validate_azure_ad_token returns the claims as a JSON string.
                claims = json.loads(decoded_token)

                has_access = any(
                    rule["method"] == request.method and _path_matches(request.url.path, rule["endpoint"])
                    for rule in rules
                )
                if not has_access:
                    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=GlobalConstants.api_response_messages.forbidden)

                # Expose only a minimal, normalized view of the claims to the route.
                user_info = {
                    "name": claims.get("name"),
                    "user_name": claims.get("unique_name"),
                }
                if "roles" in claims:
                    user_info["roles"] = claims["roles"]
                request.state.user_info = user_info

                # Support both `async def` and plain `def` routes. An un-awaited
                # coroutine must never be handed back to FastAPI.
                result = func(*args, **kwargs)
                if inspect.isawaitable(result):
                    result = await result
                return result

            except (HTTPException, OAuthException) as e:
                return GlobalConstants.return_api_response(
                    message=e.detail,
                    result=None,
                    status_code=e.status_code,
                )
            except requests.RequestException as e:
                # requests exceptions carry no `detail`/`status_code`. Map them explicitly.
                return GlobalConstants.return_api_response(
                    message=f"Request error: {str(e)}",
                    result=None,
                    status_code=status.HTTP_502_BAD_GATEWAY,
                )
            except jwt.ExpiredSignatureError:
                return GlobalConstants.return_api_response(
                    message="Token has expired",
                    result=None,
                    status_code=status.HTTP_401_UNAUTHORIZED,
                )
            except Exception as e:
                print("Error in Exception : ", str(e))
                return GlobalConstants.return_api_response(
                    message="Internal server error",
                    result=None,
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )

        return wrapper

    return decorator
   




class AuthRoutes:
    """Azure AD OAuth 2.0 authorization-code flow endpoints.

    All routes are registered on the class-level ``router``:
    ``GET /login``, ``GET /authenticate``, ``POST /refresh`` and ``GET /logout``.
    """

    router = APIRouter()

    # FastAPI security scheme describing the authorization-code flow.
    # Presumably consumed elsewhere via ``Depends(AuthRoutes.oauth2_scheme)``; verify.
    oauth2_scheme = OAuth2AuthorizationCodeBearer(
        authorizationUrl=AUTH_URL,
        tokenUrl=TOKEN_URL,
    )

    # NOTE(review): not read or written anywhere in this class. Kept so any
    # external references keep working.
    tokens: Dict[str, str] = {}

    @staticmethod
    @router.get("/login")
    def login():
        """Redirect the browser to the Microsoft sign-in page.

        Azure AD is asked to return the authorization ``code`` as a query
        parameter (``response_mode=query``) on ``REDIRECT_URI``.
        """
        from urllib.parse import urlencode

        # NOTE(review): no `state` (CSRF) or PKCE parameter is sent. Consider
        # adding one and verifying it in the callback.
        query = urlencode(
            {
                "client_id": CLIENT_ID,
                "response_type": "code",
                "redirect_uri": REDIRECT_URI,
                "response_mode": "query",
                "scope": SCOPE,
            }
        )
        return RedirectResponse(url=f"{AUTH_URL}?{query}")

    @staticmethod
    @router.get("/authenticate")
    async def callback(request: Request):
        """Exchange the authorization code from the sign-in redirect for tokens.

        On success the response ``result`` holds ``access_token`` and
        ``refresh_token``. Error responses:

        - 400 when the redirect has no ``code``; Azure AD's ``error`` and
          ``error_description`` are echoed in the message.
        - 500 when the token response contains no access token.
        - 504 on timeout.
        - 502 on connection or transport failures, or on a non-2xx reply from
          the token endpoint.
        """
        code = request.query_params.get("code")
        if not code:
            error = request.query_params.get("error")
            error_description = request.query_params.get("error_description")
            return GlobalConstants.return_api_response(
                message=f"Authentication failed: {error} - {error_description}",
                result=[],
                status_code=status.HTTP_400_BAD_REQUEST,
            )

        token_data = {
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": REDIRECT_URI,
        }

        try:
            # Async client keeps the event loop free. The timeout bounds the wait.
            async with httpx.AsyncClient(timeout=30) as client:
                response = await client.post(TOKEN_URL, data=token_data)
                response.raise_for_status()  # Raises for non-2xx responses

            token_json = response.json()
            access_token = token_json.get("access_token")
            refresh_token = token_json.get("refresh_token")

            if not access_token:
                return GlobalConstants.return_api_response(
                    message="Access token not found in the response",
                    result=[],
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )

            return GlobalConstants.return_api_response(
                message=GlobalConstants.api_response_messages.success,
                result=[{
                    "access_token": access_token,
                    "refresh_token": refresh_token,
                }],
                status_code=status.HTTP_200_OK,
            )

        except httpx.TimeoutException as e:
            print("Error in Timeout : ", str(e))
            return GlobalConstants.return_api_response(
                message="Token exchange service timeout",
                result=None,
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            )

        except httpx.ConnectError as e:
            print("Error in ConnectionError : ", str(e))
            return GlobalConstants.return_api_response(
                message="Failed to connect to token exchange service",
                result=None,
                status_code=status.HTTP_502_BAD_GATEWAY,
            )

        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            # In httpx, HTTPStatusError (non-2xx reply) is not a RequestError,
            # so it is listed explicitly.
            print("Error in RequestException : ", str(e))
            return GlobalConstants.return_api_response(
                message="Error communicating with token exchange service",
                result=None,
                status_code=status.HTTP_502_BAD_GATEWAY,
            )

        except Exception as e:
            print("Error in Exception : ", str(e))
            return GlobalConstants.return_api_response(
                message="An unexpected error occurred",
                result=None,
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

    @staticmethod
    @router.post("/refresh")
    async def refresh_token(refresh_token: RefreshTokenSchema):
        """Obtain a new access token using a refresh token.

        On success the response ``result`` holds the new ``access_token``.
        It also holds ``refresh_token``: the newly issued one when the token
        endpoint returns it, otherwise the one that was sent. Error responses:

        - 400 when no refresh token is supplied.
        - 500 when no access token comes back.
        - 504 on timeout.
        - 502 on connection or transport errors.
        - 500 otherwise, including non-2xx replies from the token endpoint.
        """
        current_refresh_token = refresh_token.refresh_token
        if not current_refresh_token:
            return GlobalConstants.return_api_response(
                message="Refresh token not available",
                result=[],
                status_code=status.HTTP_400_BAD_REQUEST,
            )

        refresh_data = {
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
            "grant_type": "refresh_token",
            "refresh_token": current_refresh_token,
            "scope": SCOPE,
        }

        try:
            async with httpx.AsyncClient(timeout=30) as client:
                refresh_response = await client.post(TOKEN_URL, data=refresh_data)
                refresh_response.raise_for_status()  # Raises for non-2xx responses

            token_json = refresh_response.json()
            access_token = token_json.get("access_token")
            # Prefer a newly issued refresh token when the response includes one.
            new_refresh_token = token_json.get("refresh_token", current_refresh_token)

            if not access_token:
                return GlobalConstants.return_api_response(
                    message="Access token not found in the response",
                    result=[],
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )

            return GlobalConstants.return_api_response(
                message=GlobalConstants.api_response_messages.success,
                result=[{"access_token": access_token, "refresh_token": new_refresh_token}],
                status_code=status.HTTP_200_OK,
            )

        except httpx.TimeoutException as e:
            print("Error in Timeout : ", str(e))
            return GlobalConstants.return_api_response(
                message="Token refresh service timeout",
                result=None,
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            )

        except httpx.ConnectError as e:
            print("Error in ConnectError : ", str(e))
            return GlobalConstants.return_api_response(
                message="Failed to connect to token refresh service",
                result=None,
                status_code=status.HTTP_502_BAD_GATEWAY,
            )

        except httpx.RequestError as e:
            print("Error in RequestError : ", str(e))
            return GlobalConstants.return_api_response(
                message="Error communicating with token refresh service",
                result=None,
                status_code=status.HTTP_502_BAD_GATEWAY,
            )

        except Exception as e:
            print("Error in Exception : ", str(e))
            return GlobalConstants.return_api_response(
                message="An unexpected error occurred",
                result=None,
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

    @staticmethod
    @router.get("/logout")
    async def logout_user():
        """Redirect the browser to Microsoft's logout endpoint.

        Sign-out has to happen in the user's browser, so this only builds the
        logout URL and redirects to it. No server-side call is made.
        ``POST_LOGOUT_REDIRECT_URI`` is URL-encoded as the
        ``post_logout_redirect_uri`` query parameter.
        """
        from urllib.parse import urlencode

        query = urlencode({"post_logout_redirect_uri": POST_LOGOUT_REDIRECT_URI})
        logout_url = f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/logout?{query}"
        return RedirectResponse(url=logout_url)
