#!/usr/bin/env python3
"""
UK Energy Bill Analyzer - Basic Tests

These tests import the `server` module directly and check its error
messages, address parsing, EPC HTML parsers, error classification, and
related helpers.
Run with: python -m pytest tests/ -v

NOTE: These are unit tests that don't require a running server.
Integration tests would additionally require the server to be started first.
"""

import sys
import os

# Add parent directory to path so we can import server module
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import unittest
from unittest.mock import patch, MagicMock
import json


class TestErrorMessages(unittest.TestCase):
    """Test the error message system."""

    def test_error_messages_exist(self):
        """Check that all required error messages are defined."""
        from server import ERROR_MESSAGES

        required_keys = [
            'claude_key_missing',
            'claude_key_invalid',
            'epc_key_missing',
            'epc_key_invalid',
            'network_error',
            'timeout',
            'postcode_missing',
            'postcode_not_found',
            'no_certificates',
            'server_error',
            'service_unavailable'
        ]

        for key in required_keys:
            self.assertIn(key, ERROR_MESSAGES,
                         f"Missing error message: {key}")

    def test_error_messages_have_required_fields(self):
        """Check that each error message has title, message, and help."""
        from server import ERROR_MESSAGES

        for key, error in ERROR_MESSAGES.items():
            self.assertIn('title', error,
                         f"{key} missing 'title'")
            self.assertIn('message', error,
                         f"{key} missing 'message'")
            self.assertIn('help', error,
                         f"{key} missing 'help'")

    def test_create_error_response(self):
        """Test that create_error_response returns correct structure."""
        from server import create_error_response

        response = create_error_response('claude_key_missing')

        self.assertTrue(response['error'])
        self.assertEqual(response['errorKey'], 'claude_key_missing')
        self.assertIn('title', response)
        self.assertIn('message', response)
        self.assertIn('help', response)

    def test_create_error_response_with_details(self):
        """Test that details are included when provided."""
        from server import create_error_response

        response = create_error_response('server_error', 'Extra info')

        self.assertEqual(response['details'], 'Extra info')

    def test_create_error_response_unknown_key(self):
        """Test that unknown keys fall back to server_error."""
        from server import create_error_response

        response = create_error_response('unknown_error_key')

        # Should fall back to server_error
        self.assertEqual(response['title'], 'Something Went Wrong')


class TestAddressParsing(unittest.TestCase):
    """Test the address parsing and matching functions."""

    def test_extract_house_number(self):
        """Test extracting house numbers from addresses."""
        from server import extract_house_identifier

        # Standard house number
        id_type, id_value = extract_house_identifier("21 High Street")
        self.assertEqual(id_type, 'number')
        self.assertEqual(id_value, '21')

        # House number with comma
        id_type, id_value = extract_house_identifier("21, High Street")
        self.assertEqual(id_type, 'number')
        self.assertEqual(id_value, '21')

        # House number with letter
        id_type, id_value = extract_house_identifier("21A High Street")
        self.assertEqual(id_type, 'number')
        self.assertEqual(id_value, '21a')

    def test_extract_flat_number(self):
        """Test extracting flat numbers from addresses."""
        from server import extract_house_identifier

        id_type, id_value = extract_house_identifier("Flat 5, High Street")
        self.assertEqual(id_type, 'flat')
        self.assertEqual(id_value, '5')

    def test_extract_no_identifier(self):
        """Test when no identifier can be extracted."""
        from server import extract_house_identifier

        id_type, id_value = extract_house_identifier("")
        self.assertIsNone(id_type)
        self.assertIsNone(id_value)

        id_type, id_value = extract_house_identifier(None)
        self.assertIsNone(id_type)
        self.assertIsNone(id_value)

    def test_match_address_exact(self):
        """Test exact address matching."""
        from server import match_address

        certificates = [
            {'address': '21 High Street', 'fullAddress': '21 High Street, London'},
            {'address': '22 High Street', 'fullAddress': '22 High Street, London'}
        ]

        idx, score, reason = match_address("21 High Street", certificates)
        self.assertEqual(idx, 0)
        self.assertGreater(score, 0)

    def test_match_address_no_address(self):
        """Test matching with no address provided."""
        from server import match_address

        certificates = [
            {'address': '21 High Street', 'fullAddress': '21 High Street, London'}
        ]

        idx, score, reason = match_address("", certificates)
        self.assertEqual(idx, 0)  # Should default to first
        self.assertEqual(score, 0)


class TestEPCParsers(unittest.TestCase):
    """Test the HTML parsers for EPC data."""

    def test_epc_search_parser_extracts_certificates(self):
        """Test that EPCSearchParser extracts certificate numbers."""
        from server import EPCSearchParser

        # Sample HTML with a certificate link
        html = '''
        <a href="/energy-certificate/1234-5678-9012-3456-7890">
            21 High Street
        </a>
        '''

        parser = EPCSearchParser()
        parser.feed(html)
        parser.close()

        self.assertEqual(len(parser.certificates), 1)
        self.assertEqual(parser.certificates[0]['certificateNumber'],
                        '1234-5678-9012-3456-7890')

    def test_epc_improvement_parser_extracts_improvements(self):
        """Test that EPCImprovementParser extracts improvement data."""
        from server import EPCImprovementParser

        # Sample HTML with improvement data
        html = '''
        <h2>Steps you could take to save energy</h2>
        <h3>Step 1: Add loft insulation</h3>
        <dt>Cost</dt>
        <dd>100 - 200</dd>
        <dt>Saving</dt>
        <dd>50 per year</dd>
        <dt>Rating after</dt>
        <dd>C (71)</dd>
        '''

        parser = EPCImprovementParser()
        parser.feed(html)

        # Note: The parser requires specific HTML structure
        # This is a simplified test


class TestEnvLoading(unittest.TestCase):
    """Test environment variable loading."""

    def test_load_env_file_returns_false_when_missing(self):
        """Test that load_env_file returns False when .env is missing."""
        from server import load_env_file

        # Create a temporary path that doesn't exist
        with patch('server.Path') as mock_path:
            mock_path.return_value.__truediv__.return_value.exists.return_value = False
            # Since the function uses global import, this test is limited


class TestClassifyError(unittest.TestCase):
    """Test the error classification function."""

    def test_classify_timeout_error(self):
        """Test that timeout errors are classified correctly."""
        from server import classify_error
        import socket

        error = socket.timeout("Connection timed out")
        error_key, details = classify_error(error)

        self.assertEqual(error_key, 'timeout')

    def test_classify_connection_error(self):
        """Test that connection errors are classified correctly."""
        from server import classify_error

        error = ConnectionError("Connection refused")
        error_key, details = classify_error(error)

        self.assertEqual(error_key, 'network_error')

    def test_classify_generic_error(self):
        """Test that generic errors fall back to server_error."""
        from server import classify_error

        error = Exception("Something random happened")
        error_key, details = classify_error(error)

        self.assertEqual(error_key, 'server_error')


class TestPlainEnglish(unittest.TestCase):
    """Test that error messages follow plain English guidelines."""

    def test_messages_are_not_too_long(self):
        """Check that messages are reasonably short."""
        from server import ERROR_MESSAGES

        max_message_length = 200  # Characters

        for key, error in ERROR_MESSAGES.items():
            self.assertLessEqual(
                len(error['message']),
                max_message_length,
                f"Message for '{key}' is too long ({len(error['message'])} chars)"
            )

    def test_messages_dont_use_jargon(self):
        """Check that messages avoid technical jargon."""
        from server import ERROR_MESSAGES

        # Words that should be avoided in user-facing messages
        jargon_words = [
            'exception', 'error code', 'stack trace', 'null',
            'undefined', 'syntax', 'parse', 'invalid token'
        ]

        for key, error in ERROR_MESSAGES.items():
            message_lower = error['message'].lower()
            for jargon in jargon_words:
                self.assertNotIn(
                    jargon,
                    message_lower,
                    f"Message for '{key}' contains jargon: '{jargon}'"
                )


if __name__ == "__main__":
    # Running this file directly (instead of through pytest) uses unittest's
    # own runner and prints one line per test.
    unittest.main(verbosity=2)
