#!/usr/bin/env python3
"""
AttestLayer Offline Verification Kit
=====================================

Standalone verifier for AttestLayer Registry entries.
Requires Python 3.9+ with NO external dependencies.

Usage:
    python verify.py --entry entry.json
    python verify.py --log 2025-01-28.ndjson
    python verify.py --checkpoint checkpoint.json --log 2025-01-28.ndjson
    python verify.py --chain 2025-01-28.ndjson

Author: AttestLayer
Spec: REG-1.0
"""

import argparse
import base64
import hashlib
import json
import os
import sys
from typing import List, Dict, Any, Optional, Tuple

# =============================================================================
# EMBEDDED CRYPTO (pure Python Ed25519 verify only)
# =============================================================================

# Ed25519 constants
# B: bit length of an encoded point / scalar (B // 8 == 32 bytes).
B = 256
# Q: the field prime 2^255 - 19; coordinate arithmetic is done modulo Q.
Q = 2**255 - 19
# L: order of the base-point subgroup; verify_ed25519 reduces the hash
# scalar modulo L before multiplying.
L = 2**252 + 27742317777372353535851937790883648493
# D: curve constant d = -121665/121666 (mod Q); the division is performed by
# multiplying with the modular inverse 121666^(Q-2) mod Q.
D = -121665 * pow(121666, Q - 2, Q) % Q
# I: 2^((Q-1)/4) mod Q, a square root of -1 modulo Q; xrecover multiplies by
# it when its first square-root candidate does not check out.
I = pow(2, (Q - 1) // 4, Q)

def sha512(m: bytes) -> bytes:
    """Return the raw (binary, 64-byte) SHA-512 digest of ``m``."""
    hasher = hashlib.sha512()
    hasher.update(m)
    return hasher.digest()

def sha256_hex(data: bytes) -> str:
    """Return the SHA-256 digest of ``data`` as a lowercase hex string."""
    hasher = hashlib.sha256()
    hasher.update(data)
    return hasher.hexdigest()

def expmod(b, e, m):
    """Return ``b ** e mod m`` for a non-negative integer exponent ``e``.

    Delegates to the built-in three-argument ``pow``, which performs modular
    exponentiation iteratively. The previous hand-written version recursed
    once per exponent bit and therefore raised ``RecursionError`` for
    exponents longer than the interpreter's recursion limit.

    Args:
        b: Integer base (may be negative or larger than ``m``).
        e: Integer exponent.
        m: Integer modulus.

    Returns:
        The result reduced modulo ``m``; ``1`` when ``e == 0``.
    """
    if e == 0:
        # Kept as an explicit case so the result for e == 0 stays exactly 1,
        # as before, regardless of the modulus.
        return 1
    return pow(b, e, m)

def inv(x):
    """Return the multiplicative inverse of ``x`` modulo the field prime Q.

    Computed as ``x ** (Q - 2) mod Q`` (Fermat's little theorem), so an ``x``
    that is a multiple of Q yields 0 rather than raising.
    """
    fermat_exponent = Q - 2
    return expmod(x, fermat_exponent, Q)

def xrecover(y):
    """Recover an x-coordinate candidate for the given y-coordinate.

    Solves x^2 = (y^2 - 1) / (D*y^2 + 1) modulo Q. The first candidate root is
    x_sq ** ((Q + 3) / 8); if its square does not match, it is multiplied by
    I (a square root of -1). The returned value is always even: an odd root
    is replaced by ``Q - root``.

    No check is made here that a square root actually exists, so for a ``y``
    that is not on the curve the returned value is not a valid coordinate.
    """
    y_sq = y * y
    x_sq = (y_sq - 1) * inv(D * y_sq + 1)
    root = expmod(x_sq, (Q + 3) // 8, Q)
    if (root * root - x_sq) % Q:
        # First candidate was off by a factor of sqrt(-1).
        root = root * I % Q
    return Q - root if root % 2 else root

# Base point: y = 4/5 (mod Q, left unreduced here); x is recovered from y by
# xrecover, which returns the even root.
BY = 4 * inv(5)
BX = xrecover(BY)
# BASE is the base point as an (x, y, z, t) tuple with z = 1 and t = x*y,
# every coordinate reduced mod Q -- the same 4-tuple layout that edwards_add
# unpacks.
BASE = (BX % Q, BY % Q, 1, (BX * BY) % Q)

def edwards_add(P, Q_pt):
    """Add two curve points given as (x, y, z, t) 4-tuples.

    Both operands use the coordinate layout of BASE (z is the projective
    denominator, t carries x*y). The result is a new 4-tuple in the same
    layout with every coordinate reduced modulo Q. The same routine is used
    for doubling by passing the same point twice.
    """
    px, py, pz, pt = P
    qx, qy, qz, qt = Q_pt

    diff_product = (py - px) * (qy - qx) % Q
    sum_product = (py + px) * (qy + qx) % Q
    t_term = pt * 2 * D * qt % Q
    z_term = pz * 2 * qz % Q

    e = sum_product - diff_product
    f = z_term - t_term
    g = z_term + t_term
    h = sum_product + diff_product

    # Result order is (x, y, z, t).
    return (e * f % Q, g * h % Q, f * g % Q, e * h % Q)

def scalarmult(P, e):
    """Return the scalar multiple ``e * P`` of a curve point.

    Uses most-significant-bit-first double-and-add, which performs exactly
    the same sequence of edwards_add calls as the former recursive version
    but without recursion, so the call depth no longer grows with the bit
    length of ``e``.

    Args:
        P: Point as an (x, y, z, t) 4-tuple.
        e: Non-negative integer scalar.

    Returns:
        The resulting point as an (x, y, z, t) 4-tuple; the neutral element
        ``(0, 1, 1, 0)`` when ``e == 0``.

    Raises:
        ValueError: If ``e`` is negative (previously this recursed without
            bound and ended in RecursionError).
    """
    if e < 0:
        raise ValueError("scalar must be non-negative")

    acc = (0, 1, 1, 0)  # neutral element
    for shift in range(e.bit_length() - 1, -1, -1):
        acc = edwards_add(acc, acc)
        if (e >> shift) & 1:
            acc = edwards_add(acc, P)
    return acc

def encode_point(P):
    """Serialize a point to its 32-byte compressed encoding.

    The point is first normalised by dividing x and y by z. The output is the
    little-endian encoding of the low B-1 bits of y, with the lowest bit of x
    stored in the top bit of the final byte.
    """
    x, y, z, _ = P
    z_inv = inv(z)
    affine_x = (x * z_inv) % Q
    affine_y = (y * z_inv) % Q

    y_bits = affine_y & ((1 << (B - 1)) - 1)
    sign_bit = (affine_x & 1) << (B - 1)
    return (y_bits | sign_bit).to_bytes(B // 8, "little")

def decode_point(s):
    """Decode a 32-byte compressed point into an (x, y, z, t) 4-tuple.

    The low B-1 bits (little-endian) are the y-coordinate and the top bit of
    the last byte is the parity of x. Unlike the previous version, malformed
    encodings are rejected instead of being turned into arbitrary tuples:
    the length must be exactly B // 8 bytes, y must be below Q (canonical),
    the recovered x must satisfy the curve equation
    ``-x^2 + y^2 = 1 + D*x^2*y^2 (mod Q)``, and x == 0 may not carry a set
    sign bit.

    Args:
        s: The encoded point (bytes-like, 32 bytes).

    Returns:
        ``(x, y, 1, x*y mod Q)``.

    Raises:
        ValueError: If ``s`` is not a valid point encoding. verify_ed25519
            treats any exception from decoding as a failed verification.
    """
    if len(s) != B // 8:
        raise ValueError("encoded point must be exactly 32 bytes")

    sign = (s[B // 8 - 1] >> 7) & 1
    y = int.from_bytes(s, "little") & ((1 << (B - 1)) - 1)
    if y >= Q:
        raise ValueError("non-canonical point encoding (y >= Q)")

    x = xrecover(y)
    # xrecover does not verify that a square root exists, so check the curve
    # equation explicitly.
    if (x * x - y * y + 1 + D * x * x * y * y) % Q != 0:
        raise ValueError("encoded point is not on the curve")

    if x == 0 and sign:
        raise ValueError("invalid sign bit for x == 0")
    if x & 1 != sign:
        x = Q - x
    return (x, y, 1, (x * y) % Q)

def verify_ed25519(public_key: bytes, message: bytes, signature: bytes) -> bool:
    """Verify an Ed25519 signature.

    Checks the group equation ``S*BASE == R + H(R || A || M)*A`` where the
    signature is ``R || S``, ``A`` is the public key and ``H`` is SHA-512
    interpreted as a little-endian integer reduced modulo L.

    Args:
        public_key: 32-byte encoded public key.
        message: The signed message bytes.
        signature: 64-byte signature (32-byte R followed by 32-byte S).

    Returns:
        True only if the signature is well-formed and valid. Wrong lengths,
        a scalar ``S`` that is not below L (non-canonical, malleable
        signature), undecodable points, or any other error yield False.
    """
    if len(public_key) != 32:
        return False
    if len(signature) != 64:
        return False

    try:
        r_bytes = signature[:32]

        # S must be a canonical scalar; otherwise S + k*L would verify too.
        S = int.from_bytes(signature[32:], "little")
        if S >= L:
            return False

        R = decode_point(r_bytes)
        A = decode_point(public_key)

        h_int = int.from_bytes(sha512(r_bytes + public_key + message), "little")

        sB = scalarmult(BASE, S)
        hA = scalarmult(A, h_int % L)

        # R + hA should equal sB; compare the canonical encodings.
        return encode_point(sB) == encode_point(edwards_add(R, hA))
    except Exception:
        # Deliberately fail closed: anything that goes wrong while decoding or
        # computing means the signature is not accepted.
        return False

# =============================================================================
# JWK HANDLING
# =============================================================================

def jwk_x_to_public_key(x_b64: str) -> bytes:
    """Decode a JWK ``x`` value (unpadded base64url) into raw key bytes.

    JWK values omit the trailing ``=`` padding, so the string is padded back
    to a multiple of four characters before decoding.
    """
    missing = -len(x_b64) % 4
    if missing:
        x_b64 = x_b64 + '=' * missing
    return base64.urlsafe_b64decode(x_b64)

def load_jwks(filepath: str) -> Dict[str, bytes]:
    """Load a JWKS file and return a mapping of ``kid`` to raw public key bytes.

    Only keys with ``kty == "OKP"`` and ``crv == "Ed25519"`` that carry both a
    ``kid`` and an ``x`` value are returned; everything else is ignored.

    Args:
        filepath: Path to the JWKS JSON document.

    Returns:
        Dict mapping key id to the decoded ``x`` bytes (empty if the document
        has no usable keys).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid JSON or a key's ``x`` value is
            not valid base64url.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(filepath, 'r', encoding='utf-8') as f:
        jwks = json.load(f)

    # Tolerate a missing or null "keys" member and a non-object document.
    key_list = jwks.get("keys") if isinstance(jwks, dict) else None

    keys: Dict[str, bytes] = {}
    for key in key_list or []:
        if not isinstance(key, dict):
            continue
        if key.get("kty") != "OKP" or key.get("crv") != "Ed25519":
            continue
        kid = key.get("kid")
        x = key.get("x")
        if kid and x:
            keys[kid] = jwk_x_to_public_key(x)

    return keys

# =============================================================================
# CANONICAL JSON
# =============================================================================

def canonical_json(obj: Any) -> str:
    """Serialize ``obj`` to JSON deterministically.

    Keys are sorted, no whitespace is emitted between tokens, and non-ASCII
    characters are written as ``\\uXXXX`` escapes.
    """
    encoder = json.JSONEncoder(
        sort_keys=True,
        separators=(',', ':'),
        ensure_ascii=True,
    )
    return encoder.encode(obj)

def canonical_bytes(obj: Any) -> bytes:
    """Return the canonical JSON form of ``obj`` encoded as UTF-8 bytes."""
    text = canonical_json(obj)
    return text.encode('utf-8')

# =============================================================================
# MERKLE TREE
# =============================================================================

def compute_internal_hash(left: str, right: str) -> str:
    """Hash two child node values into their parent node value.

    The parent is the hex SHA-256 of the UTF-8 text ``"<left>:<right>"``.
    """
    preimage = f"{left}:{right}".encode('utf-8')
    hasher = hashlib.sha256()
    hasher.update(preimage)
    return hasher.hexdigest()

def compute_merkle_root(hashes: List[str]) -> str:
    """Compute the Merkle root over a list of leaf hashes.

    The leaves are sorted first, so the result does not depend on input
    order. Each level pairs neighbours with compute_internal_hash; an
    unpaired last node is carried up unchanged. An empty list yields the hex
    SHA-256 of ``b"EMPTY"`` and a single leaf is its own root.
    """
    if not hashes:
        return hashlib.sha256(b"EMPTY").hexdigest()

    level = sorted(hashes)
    while len(level) > 1:
        parents = [
            compute_internal_hash(level[k], level[k + 1])
            for k in range(0, len(level) - 1, 2)
        ]
        if len(level) % 2:
            # Odd node out is promoted to the next level as-is.
            parents.append(level[-1])
        level = parents

    return level[0]

def verify_merkle_proof(leaf_hash: str, proof: List[Dict], root: str) -> bool:
    """Check that ``leaf_hash`` hashes up to ``root`` along ``proof``.

    Each proof step is a dict with a ``sibling`` hash and a ``position``.
    When ``position`` is ``"left"`` the sibling is the left operand of the
    parent hash; any other value places the sibling on the right.
    """
    node = leaf_hash

    for step in proof:
        neighbour = step.get("sibling")
        if step.get("position") == "left":
            operands = (neighbour, node)
        else:
            operands = (node, neighbour)
        node = compute_internal_hash(*operands)

    return node == root

# =============================================================================
# ENTRY VERIFICATION
# =============================================================================

def verify_entry_hash(entry: Dict) -> Tuple[bool, str]:
    """Recompute an entry's hash and compare it with its ``entry_hash`` field.

    The hash covers the canonical JSON of a fixed set of entry fields (absent
    fields are hashed as null) plus ``receipt_ref`` when it is present and
    truthy. ``entry_hash`` and ``registry_sig`` themselves are not part of
    the hashed payload.

    Args:
        entry: The registry entry.

    Returns:
        ``(True, "")`` on a match, otherwise ``(False, message)``. A missing
        or non-string ``entry_hash`` is reported as a mismatch (it previously
        raised TypeError while the message was being formatted).
    """
    hashed_fields = (
        "spec",
        "registry_id",
        "date",
        "offset",
        "issuer",
        "lane",
        "issued_at",
        "received_at",
        "pubsub_message_id",
        "artifact",
        "prev_entry_hash",
        "issuer_attest",
    )
    hash_payload = {name: entry.get(name) for name in hashed_fields}

    # Optional field: only hashed when present and non-empty.
    receipt_ref = entry.get("receipt_ref")
    if receipt_ref:
        hash_payload["receipt_ref"] = receipt_ref

    expected_hash = sha256_hex(canonical_bytes(hash_payload))
    actual_hash = entry.get("entry_hash")

    if expected_hash == actual_hash:
        return True, ""

    if isinstance(actual_hash, str):
        shown = f"{actual_hash[:16]}..."
    else:
        # Missing (None) or wrongly typed value: show it verbatim.
        shown = repr(actual_hash)
    return False, f"entry_hash mismatch: expected {expected_hash[:16]}..., got {shown}"

def verify_signature(entry: Dict, registry_keys: Dict[str, bytes]) -> Tuple[bool, str]:
    """Verify the registry's Ed25519 signature on an entry.

    The signed message is the UTF-8 encoding of the entry's ``entry_hash``
    string; the signature is read from ``registry_sig.sig_b64`` (standard
    base64) and the key is selected by ``registry_sig.kid``.

    Args:
        entry: The registry entry.
        registry_keys: Mapping of key id to raw 32-byte public key.

    Returns:
        ``(True, "")`` when the signature verifies, otherwise
        ``(False, reason)``. Malformed input (``registry_sig`` missing, null
        or not an object; a ``kid`` that is not a string) is reported as a
        failure rather than raising.
    """
    registry_sig = entry.get("registry_sig")
    if not isinstance(registry_sig, dict):
        # Covers both a missing member and an explicit JSON null.
        return False, "Missing registry signature"

    kid = registry_sig.get("kid")
    sig_b64 = registry_sig.get("sig_b64")

    if not kid or not sig_b64:
        return False, "Missing registry signature"

    # A non-string kid (e.g. a list) cannot be a dict key; treat it as unknown.
    if not isinstance(kid, str) or kid not in registry_keys:
        return False, f"Unknown registry key: {kid}"

    public_key = registry_keys[kid]
    entry_hash = entry.get("entry_hash", "")

    try:
        signature = base64.b64decode(sig_b64)
        message = entry_hash.encode('utf-8')

        if verify_ed25519(public_key, message, signature):
            return True, ""
        return False, "Invalid registry signature"
    except Exception as e:
        return False, f"Signature verification error: {e}"

def verify_issuer_signature(entry: Dict, issuer_keys: Dict[str, bytes]) -> Tuple[bool, str]:
    """Verify the issuer's Ed25519 signature carried in an entry.

    The signed message is the UTF-8 encoding of ``issuer_attest.event_hash``;
    the signature is ``issuer_attest.issuer_sig_b64`` (standard base64) and
    the key is selected by ``issuer.issuer_kid``.

    Args:
        entry: The registry entry.
        issuer_keys: Mapping of key id to raw 32-byte public key.

    Returns:
        ``(True, "")`` when the signature verifies, otherwise
        ``(False, reason)``. Malformed input (``issuer`` or ``issuer_attest``
        missing, null or not an object; a ``kid`` that is not a string) is
        reported as a failure rather than raising.
    """
    issuer_attest = entry.get("issuer_attest")
    issuer = entry.get("issuer")
    if not isinstance(issuer_attest, dict) or not isinstance(issuer, dict):
        # Covers missing members as well as explicit JSON nulls.
        return False, "Missing issuer attestation"

    kid = issuer.get("issuer_kid")
    event_hash = issuer_attest.get("event_hash")
    sig_b64 = issuer_attest.get("issuer_sig_b64")

    if not kid or not event_hash or not sig_b64:
        return False, "Missing issuer attestation"

    # A non-string kid (e.g. a list) cannot be a dict key; treat it as unknown.
    if not isinstance(kid, str) or kid not in issuer_keys:
        return False, f"Unknown issuer key: {kid}"

    public_key = issuer_keys[kid]

    try:
        signature = base64.b64decode(sig_b64)
        message = event_hash.encode('utf-8')

        if verify_ed25519(public_key, message, signature):
            return True, ""
        return False, "Invalid issuer signature"
    except Exception as e:
        return False, f"Issuer signature verification error: {e}"

# =============================================================================
# CHAIN VERIFICATION
# =============================================================================

def verify_chain(entries: List[Dict]) -> Tuple[bool, List[str]]:
    """Verify offset continuity and hash-chain links across a list of entries.

    For the entry at list index ``i`` this checks that ``offset == i`` and
    that ``prev_entry_hash`` equals the previous entry's ``entry_hash``. The
    first entry must have a null or empty ``prev_entry_hash``.

    A link whose predecessor has no ``entry_hash`` is reported as a chain
    break; previously two entries that both lacked the hashes compared equal
    (None == None) and passed. A non-string ``prev_entry_hash`` on the first
    entry is reported instead of raising TypeError.

    Args:
        entries: Entries in log order.

    Returns:
        ``(ok, issues)`` where ``ok`` is True only if ``issues`` is empty.
        An empty list of entries is considered valid.
    """
    issues: List[str] = []

    for i, entry in enumerate(entries):
        offset = entry.get("offset")

        # Check offset sequence
        if offset != i:
            issues.append(f"Offset gap at index {i}: expected {i}, got {offset}")

        # Check chain link
        actual_prev = entry.get("prev_entry_hash")
        if i == 0:
            if actual_prev is not None and actual_prev != "":
                issues.append(
                    f"First entry has non-null prev_entry_hash: {str(actual_prev)[:16]}..."
                )
            continue

        expected_prev = entries[i - 1].get("entry_hash")
        if not expected_prev:
            # Nothing to link to: never accept an absent hash as a valid link.
            issues.append(f"Chain break at offset {offset}: previous entry has no entry_hash")
        elif expected_prev != actual_prev:
            issues.append(f"Chain break at offset {offset}: prev_entry_hash mismatch")

    return not issues, issues

# =============================================================================
# CHECKPOINT VERIFICATION
# =============================================================================

def verify_checkpoint(checkpoint: Dict, entries: List[Dict], registry_keys: Dict[str, bytes]) -> Tuple[bool, List[str]]:
    """Verify a checkpoint's signature, Merkle root and entry count.

    Three independent checks are run and all findings are collected:

    1. The registry signature over the UTF-8 ``checkpoint_hash`` string.
    2. The Merkle root of the ``entry_hash`` values of all entries whose
       ``offset`` lies in ``range.start_offset .. range.end_offset``
       (inclusive) against ``merkle.root_sha256``.
    3. The number of entries in that range against ``range.count``.

    NOTE(review): ``checkpoint_hash`` itself is not recomputed from the
    checkpoint contents here, so this function does not establish that the
    signed hash covers the Merkle root -- confirm whether that is checked
    elsewhere.

    Args:
        checkpoint: The checkpoint document.
        entries: Log entries to check the checkpoint against.
        registry_keys: Mapping of key id to raw 32-byte public key.

    Returns:
        ``(ok, issues)`` where ``ok`` is True only if ``issues`` is empty.
        Malformed documents (null sub-objects, non-integer offsets, entries
        without a string ``entry_hash``) are reported as issues rather than
        raising.
    """
    issues: List[str] = []

    def _as_dict(value: Any) -> Dict:
        # Treat a missing member, JSON null, or a wrongly typed value as empty.
        return value if isinstance(value, dict) else {}

    # Verify checkpoint signature
    registry_sig = _as_dict(checkpoint.get("registry_sig"))
    kid = registry_sig.get("kid")
    sig_b64 = registry_sig.get("sig_b64")
    checkpoint_hash = checkpoint.get("checkpoint_hash")

    if not isinstance(kid, str) or kid not in registry_keys:
        issues.append(f"Unknown registry key: {kid}")
    else:
        public_key = registry_keys[kid]
        try:
            signature = base64.b64decode(sig_b64)
            message = checkpoint_hash.encode('utf-8')
            if not verify_ed25519(public_key, message, signature):
                issues.append("Invalid checkpoint signature")
        except Exception as e:
            issues.append(f"Checkpoint signature verification error: {e}")

    # Select the entries covered by the checkpoint
    entry_range = _as_dict(checkpoint.get("range"))
    start_offset = entry_range.get("start_offset", 0)
    end_offset = entry_range.get("end_offset", len(entries) - 1)

    if not isinstance(start_offset, int) or not isinstance(end_offset, int):
        issues.append("Invalid checkpoint range: start_offset and end_offset must be integers")
        return False, issues

    range_entries = [
        e for e in entries
        if isinstance(e.get("offset"), int) and start_offset <= e.get("offset") <= end_offset
    ]
    entry_hashes = [e.get("entry_hash") for e in range_entries]

    # Verify Merkle root
    checkpoint_root = _as_dict(checkpoint.get("merkle")).get("root_sha256")
    missing = sum(1 for h in entry_hashes if not isinstance(h, str))
    if missing:
        # Sorting a mix of None and str inside compute_merkle_root would raise.
        issues.append(f"{missing} entries in range have no entry_hash; cannot compute Merkle root")
    else:
        computed_root = compute_merkle_root(entry_hashes)
        if computed_root != checkpoint_root:
            shown_root = str(checkpoint_root)[:16] if checkpoint_root else 'None'
            issues.append(f"Merkle root mismatch: expected {computed_root[:16]}..., got {shown_root}...")

    # Verify entry count
    expected_count = entry_range.get("count", 0)
    if len(range_entries) != expected_count:
        issues.append(f"Entry count mismatch: expected {expected_count}, got {len(range_entries)}")

    return not issues, issues

# =============================================================================
# MAIN VERIFIER
# =============================================================================

class Verifier:
    """File-level verification front end for the offline kit.

    On construction the registry and issuer key sets are loaded from
    ``jwks/registry.jwks.json`` and ``jwks/issuer.jwks.json`` below
    ``kit_dir``; a key set whose file does not exist is left empty.

    When a key set is empty, the corresponding signature check is skipped and
    a message starting with ``"Warning:"`` is added to the returned issues.
    Warnings alone do not make a verification fail.
    """

    _WARNING_PREFIX = "Warning:"

    def __init__(self, kit_dir: str = "."):
        self.kit_dir = kit_dir
        self.registry_keys = self._load_keys("jwks/registry.jwks.json")
        self.issuer_keys = self._load_keys("jwks/issuer.jwks.json")

    def _load_keys(self, path: str) -> Dict[str, bytes]:
        """Load the JWKS at ``path`` (relative to the kit dir), or {} if absent."""
        full_path = os.path.join(self.kit_dir, path)
        if os.path.exists(full_path):
            return load_jwks(full_path)
        return {}

    @staticmethod
    def _read_ndjson(filepath: str) -> List[Dict]:
        """Parse an NDJSON file into a list of entries, skipping blank lines."""
        entries = []
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    entries.append(json.loads(line))
        return entries

    @classmethod
    def _passed(cls, issues: List[str]) -> bool:
        """Return True if ``issues`` contains nothing but warnings."""
        return all(issue.startswith(cls._WARNING_PREFIX) for issue in issues)

    def verify_entry_file(self, filepath: str) -> Tuple[bool, List[str]]:
        """Verify a single entry file.

        Checks the entry hash, the registry signature and the issuer
        signature. Returns ``(ok, issues)``; ``issues`` may contain warnings
        even when ``ok`` is True.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            entry = json.load(f)

        issues = []

        # Verify entry hash
        ok, msg = verify_entry_hash(entry)
        if not ok:
            issues.append(msg)

        # Verify registry signature
        if self.registry_keys:
            ok, msg = verify_signature(entry, self.registry_keys)
            if not ok:
                issues.append(msg)
        else:
            issues.append("Warning: No registry keys loaded, skipping signature verification")

        # Verify issuer signature
        if self.issuer_keys:
            ok, msg = verify_issuer_signature(entry, self.issuer_keys)
            if not ok:
                issues.append(msg)
        else:
            issues.append("Warning: No issuer keys loaded, skipping issuer signature verification")

        return self._passed(issues), issues

    def verify_log_file(self, filepath: str) -> Tuple[bool, List[str]]:
        """Verify an NDJSON log file.

        Checks every entry's hash and registry signature, then the hash chain
        across the whole file. Issuer signatures are not checked in this
        mode. Returns ``(ok, issues)``; when no registry keys are loaded a
        single warning is included (signature checks used to be skipped
        silently) and ``ok`` is still decided by the remaining checks.
        """
        entries = self._read_ndjson(filepath)

        all_issues = []
        if not self.registry_keys:
            all_issues.append("Warning: No registry keys loaded, skipping signature verification")

        # Verify each entry
        for entry in entries:
            ok, msg = verify_entry_hash(entry)
            if not ok:
                all_issues.append(f"Entry {entry.get('offset')}: {msg}")

            if self.registry_keys:
                ok, msg = verify_signature(entry, self.registry_keys)
                if not ok:
                    all_issues.append(f"Entry {entry.get('offset')}: {msg}")

        # Verify chain
        ok, chain_issues = verify_chain(entries)
        all_issues.extend(chain_issues)

        return self._passed(all_issues), all_issues

    def verify_checkpoint_file(self, checkpoint_path: str, log_path: Optional[str] = None) -> Tuple[bool, List[str]]:
        """Verify a checkpoint file against the entries of a log file.

        If ``log_path`` is not given or does not exist, the checkpoint is
        checked against an empty entry list; when that check fails, a note
        saying so is appended to the issues.
        """
        with open(checkpoint_path, 'r', encoding='utf-8') as f:
            checkpoint = json.load(f)

        entries = []
        if log_path and os.path.exists(log_path):
            entries = self._read_ndjson(log_path)

        ok, issues = verify_checkpoint(checkpoint, entries, self.registry_keys)
        if not ok and not entries:
            issues.append(
                "Note: no log entries were loaded, so the Merkle root and "
                "entry count were compared against an empty log"
            )
        return ok, issues

def _read_ndjson_entries(filepath: str) -> List[Dict]:
    """Parse an NDJSON file into a list of entries, skipping blank lines."""
    entries = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                entries.append(json.loads(line))
    return entries


def _run_check(title: str, target: str, check) -> int:
    """Run one verification, print its report, and return 0 (pass) or 1 (fail).

    ``check`` is a zero-argument callable returning ``(ok, issues)`` or
    ``(ok, issues, detail)``, where ``detail`` is text appended to the PASSED
    line. Unreadable or malformed input files are reported as a failure of
    this check instead of aborting the whole run. Issues are printed even on
    success so that "Warning:" entries (e.g. skipped signature checks) are
    always visible.
    """
    print(f"\n=== Verifying {title.lower()}: {target} ===")
    try:
        ok, issues, *extra = check()
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        ok, issues, extra = False, [f"Could not read input: {exc}"], []

    if ok:
        detail = extra[0] if extra else ""
        print(f"✓ {title} verification PASSED{detail}")
    else:
        print(f"✗ {title} verification FAILED")
    for issue in issues:
        print(f"  - {issue}")
    return 0 if ok else 1


def main():
    """Command-line entry point.

    Runs every verification selected on the command line (entry, log, chain,
    checkpoint -- in that order), prints a report for each, and exits with
    status 0 if all of them passed and 1 otherwise. With no verification
    selected, prints the help text and exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="AttestLayer Offline Verification Kit",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python verify.py --entry entry.json
    python verify.py --log 2025-01-28.ndjson
    python verify.py --checkpoint checkpoint.json --log 2025-01-28.ndjson
    python verify.py --chain 2025-01-28.ndjson
        """
    )

    parser.add_argument("--entry", help="Verify a single entry JSON file")
    parser.add_argument("--log", help="Verify an NDJSON log file")
    parser.add_argument("--checkpoint", help="Verify a checkpoint file")
    parser.add_argument("--chain", help="Verify chain integrity of an NDJSON log file")
    parser.add_argument("--kit-dir", default=".", help="Path to verification kit directory")

    args = parser.parse_args()

    if not any([args.entry, args.log, args.checkpoint, args.chain]):
        parser.print_help()
        sys.exit(1)

    try:
        verifier = Verifier(args.kit_dir)
    except (OSError, ValueError) as exc:
        # A present-but-broken JWKS file must not be mistaken for "no keys".
        print(f"✗ Failed to load verification keys: {exc}")
        sys.exit(1)

    def check_chain():
        entries = _read_ndjson_entries(args.chain)
        ok, issues = verify_chain(entries)
        return ok, issues, f" ({len(entries)} entries)"

    exit_code = 0

    if args.entry:
        exit_code |= _run_check(
            "Entry", args.entry, lambda: verifier.verify_entry_file(args.entry)
        )

    if args.log:
        exit_code |= _run_check(
            "Log", args.log, lambda: verifier.verify_log_file(args.log)
        )

    if args.chain:
        exit_code |= _run_check("Chain", args.chain, check_chain)

    if args.checkpoint:
        # The checkpoint is checked against the entries of --log, if given.
        exit_code |= _run_check(
            "Checkpoint",
            args.checkpoint,
            lambda: verifier.verify_checkpoint_file(args.checkpoint, args.log),
        )

    print()
    if exit_code == 0:
        print("=== ALL VERIFICATIONS PASSED ===")
    else:
        print("=== SOME VERIFICATIONS FAILED ===")

    sys.exit(exit_code)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
