"""
Service for interacting with BRC Record Cleaner API.
"""
import logging
import requests
from typing import List, Dict, Any, Tuple
from datetime import datetime, timedelta


logger = logging.getLogger(__name__)

# Maximum number of records sent to the Record Cleaner API in one request.
PAGE_SIZE = 100

# Record Cleaner API endpoints. The validate and verify URLs carry a
# hard-coded ``verbose=0`` query string.
RECORD_CLEANER_BASE_URL = 'https://record-cleaner.brc.ac.uk'
TOKEN_ENDPOINT = RECORD_CLEANER_BASE_URL + '/token'
VALIDATE_ENDPOINT = RECORD_CLEANER_BASE_URL + '/validate?verbose=0'
VERIFY_ENDPOINT = RECORD_CLEANER_BASE_URL + '/verify?verbose=0'


class RecordCleanerClient:
    """
    Client for interacting with the BRC Record Cleaner Service.

    Handles authentication, validation, and verification of species records.
    Records are sent in batches of at most ``PAGE_SIZE``; a bearer token is
    obtained from the token endpoint and cached on the instance until shortly
    before its assumed expiry.
    """

    # Assumed lifetime of a freshly issued token (the token response is not
    # inspected for an expiry).
    TOKEN_LIFETIME = timedelta(minutes=15)
    # Refresh this long before the assumed expiry so a request never goes
    # out with a token that is about to lapse.
    TOKEN_REFRESH_MARGIN = timedelta(minutes=1)

    def __init__(self, username: str, password: str):
        """
        Initialize the Record Cleaner client.

        No network request is made here; the first token is fetched lazily.

        Args:
            username: Record Cleaner username
            password: Record Cleaner password
        """
        self.username = username
        self.password = password
        # Cached bearer token and the local time at which it is assumed to expire.
        self.token = None
        self.token_expires_at = None

    def _get_token(self) -> str:
        """
        Get a valid JWT token, refreshing if necessary.

        The cached token is reused until ``TOKEN_REFRESH_MARGIN`` before its
        recorded expiry; otherwise a new one is requested with the OAuth2
        password grant.

        Returns:
            Valid JWT token

        Raises:
            Exception: If the token request fails, the response is not JSON,
                or the response carries no ``access_token``.
        """
        if self.token and self.token_expires_at:
            if datetime.now() < self.token_expires_at - self.TOKEN_REFRESH_MARGIN:
                return self.token

        try:
            response = requests.post(
                TOKEN_ENDPOINT,
                data={
                    'username': self.username,
                    'password': self.password,
                    'grant_type': 'password'
                },
                timeout=10
            )
            response.raise_for_status()
            data = response.json()
        # ValueError covers JSON decode errors on requests versions where
        # they are not a RequestException subclass.
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.error(f"Failed to authenticate with Record Cleaner: {str(e)}", exc_info=True)
            raise Exception(f"BRC Record Cleaner API error: Failed to authenticate. {str(e)}") from e

        token = data.get('access_token') if isinstance(data, dict) else None
        if not token:
            logger.error("Record Cleaner token response did not contain an access_token")
            raise Exception(
                "BRC Record Cleaner API error: Failed to authenticate. "
                "Token response did not contain an access_token."
            )

        self.token = token
        self.token_expires_at = datetime.now() + self.TOKEN_LIFETIME
        return self.token

    @staticmethod
    def _strip_query(url: str) -> str:
        """Return *url* with any query string (from the first '?') removed."""
        return url.partition('?')[0]

    @staticmethod
    def _batches(records: List[Dict[str, Any]]):
        """Yield ``(start_index, batch)`` pairs covering *records* in slices of ``PAGE_SIZE``."""
        for start in range(0, len(records), PAGE_SIZE):
            yield start, records[start:start + PAGE_SIZE]

    def _post_json(self, url: str, payload: Any, params: Any = None) -> Any:
        """
        POST *payload* as JSON to *url* with a bearer token and return the decoded body.

        Args:
            url: Endpoint to call
            payload: JSON-serialisable request body
            params: Optional query parameters passed through to requests

        Raises:
            Exception: Propagated from ``_get_token`` if authentication fails.
            requests.exceptions.RequestException: On transport errors or a non-2xx status.
            ValueError: If the response body is not valid JSON.
        """
        token = self._get_token()
        response = requests.post(
            url,
            headers={
                'Authorization': f'Bearer {token}',
                'Content-Type': 'application/json'
            },
            json=payload,
            params=params,
            timeout=120
        )
        response.raise_for_status()
        return response.json()

    def validate(self, records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Validate records for format, spatial references, dates, and taxon names.
        Processes records in batches of PAGE_SIZE if necessary.

        Args:
            records: List of records to validate in Record Cleaner format

        Returns:
            List of validated records with result and messages fields,
            concatenated across batches in request order

        Raises:
            Exception: If authentication or a validation request fails
        """
        if not records:
            return []

        all_results: List[Dict[str, Any]] = []
        for start, batch in self._batches(records):
            logger.info(f"Validating batch {start // PAGE_SIZE + 1}: records {start+1} to {start+len(batch)}")
            try:
                batch_results = self._post_json(VALIDATE_ENDPOINT, batch)
            except (requests.exceptions.RequestException, ValueError) as e:
                logger.error(f"Failed to validate batch starting at record {start+1}: {str(e)}", exc_info=True)
                raise Exception(f"BRC Record Cleaner API error: Failed to validate records (batch starting at record {start+1}). {str(e)}") from e
            all_results.extend(batch_results)

        return all_results

    def verify(
        self,
        records: List[Dict[str, Any]],
        verbose: int = 1
    ) -> Dict[str, Any]:
        """
        Verify records against biological rules (distribution, phenology, etc.).
        Processes records in batches of PAGE_SIZE if necessary.

        Args:
            records: List of records to verify in Record Cleaner format
            verbose: 0 to suppress ID difficulty messages, 1 to include (default: 1).
                Sent as the only ``verbose`` query parameter.

        Returns:
            Dictionary with 'records' (concatenated across batches) and
            'duration_ns' (sum of the per-batch durations)

        Raises:
            Exception: If authentication or a verification request fails
        """
        if not records:
            return {'records': [], 'duration_ns': 0}

        # VERIFY_ENDPOINT hard-codes "?verbose=0"; requests would append our
        # params after it and send "verbose" twice, so drop the baked-in query.
        verify_url = self._strip_query(VERIFY_ENDPOINT)

        all_records: List[Dict[str, Any]] = []
        total_duration = 0
        for start, batch in self._batches(records):
            logger.info(f"Verifying batch {start // PAGE_SIZE + 1}: records {start+1} to {start+len(batch)}")
            try:
                batch_result = self._post_json(
                    verify_url,
                    {'records': batch},
                    params={'verbose': verbose}
                )
            except (requests.exceptions.RequestException, ValueError) as e:
                logger.error(f"Failed to verify batch starting at record {start+1}: {str(e)}", exc_info=True)
                raise Exception(f"BRC Record Cleaner API error: Failed to verify records (batch starting at record {start+1}). {str(e)}") from e

            # Tolerate explicit nulls in the response as well as missing keys.
            all_records.extend(batch_result.get('records') or [])
            total_duration += batch_result.get('duration_ns') or 0

        return {
            'records': all_records,
            'duration_ns': total_duration
        }

    @staticmethod
    def _to_verify_record(record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build the reduced payload sent to verify from a validated record.

        'id', 'date' and 'sref' are required (KeyError if absent). 'tvk' is
        taken from 'tvk', falling back to 'preferred_tvk'. 'name' and
        'stage' are copied only when truthy.
        """
        verify_record = {
            'id': record['id'],
            'date': record['date'],
            'sref': record['sref']
        }

        tvk = record.get('tvk') or record.get('preferred_tvk')
        if tvk:
            verify_record['tvk'] = tvk

        for optional_key in ('name', 'stage'):
            if record.get(optional_key):
                verify_record[optional_key] = record[optional_key]

        return verify_record

    @staticmethod
    def _strip_difficulty_messages(record: Dict[str, Any]) -> None:
        """
        Remove messages containing 'difficulty:' (case-insensitive) in place.

        Always leaves ``record['messages']`` set to a list, even when it was
        missing or None.
        """
        record['messages'] = [
            m for m in (record.get('messages') or [])
            if 'difficulty:' not in m.lower()
        ]

    def validate_and_verify(
        self,
        records: List[Dict[str, Any]]
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Validate records, then verify those that passed validation.

        Records whose validation result is 'pass' or 'warn' are reduced to
        the fields verify needs and sent for verification. Afterwards, ID
        difficulty messages are removed from both result lists in place.

        Args:
            records: List of records to process in Record Cleaner format

        Returns:
            Tuple of (validated_records, verified_records)

        Raises:
            Exception: If any API call fails
        """
        validated = self.validate(records)

        to_verify = [
            self._to_verify_record(record)
            for record in validated
            if record.get('result') in ('pass', 'warn')
        ]

        verified: List[Dict[str, Any]] = []
        if to_verify:
            verified = self.verify(to_verify).get('records', [])

        logger.info(f"Record Cleaner: {len(validated)} validated, {len(verified)} verified")

        for record in [*verified, *validated]:
            self._strip_difficulty_messages(record)

        return validated, verified


def generate_summary_report(
    validated: List[Dict[str, Any]],
    verified: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """
    Summarise the outcome of a validate/verify run.

    Args:
        validated: Records returned by validation, each with a 'result' field
        verified: Records returned by verification, each with a 'result' field

    Returns:
        Dictionary with per-stage counts under 'validation' and
        'verification' (keys 'total', 'pass', 'warn', 'fail'), the share of
        'pass' results per stage as a percentage rounded to one decimal
        place (0 for an empty stage), and 'not_verified_count' (validated
        total minus verified total).
    """
    def tally(rows):
        # Results other than pass/warn/fail (or missing) count toward the total only.
        counts = {'total': len(rows), 'pass': 0, 'warn': 0, 'fail': 0}
        for row in rows:
            outcome = row.get('result')
            if outcome in ('pass', 'warn', 'fail'):
                counts[outcome] += 1
        return counts

    def pass_rate(counts):
        if counts['total'] > 0:
            return counts['pass'] / counts['total'] * 100
        return 0

    validation_counts = tally(validated)
    verification_counts = tally(verified)

    return {
        'validation': validation_counts,
        'verification': verification_counts,
        'validation_pass_rate': round(pass_rate(validation_counts), 1),
        'verification_pass_rate': round(pass_rate(verification_counts), 1),
        'not_verified_count': validation_counts['total'] - verification_counts['total'],
    }
