# billing.py - Complete billing and payment system
import os
from datetime import datetime, timedelta
from functools import wraps

import stripe
from flask import Blueprint, render_template, request, redirect, url_for, flash, jsonify
from flask_login import login_required, current_user

from models import db, Organization, Subscription, Invoice

# Flask blueprint that owns every billing route defined in this module.
billing_bp = Blueprint('billing', __name__)

# Stripe credentials are read from the environment once, at import time.
# The secret key authenticates the server-side Stripe calls made below; the
# publishable key is only handed to templates. Both are None when unset.
stripe.api_key = os.environ.get('STRIPE_SECRET_KEY')
STRIPE_PUBLISHABLE_KEY = os.environ.get('STRIPE_PUBLISHABLE_KEY')

# Subscription plans configuration.
#
# Keys ('free', 'pro', 'enterprise') are the plan identifiers accepted by the
# /upgrade/<plan> route, stored in Organization.subscription_plan and sent to
# Stripe as Checkout metadata. Each plan has:
#   name      - display name; used in flash messages and passed to the
#               billing dashboard template
#   price     - display price only; this module never sends it to Stripe.
#               Checkout charges whatever the Stripe Price behind price_id
#               costs (presumably USD per month -- TODO confirm)
#   price_id  - Stripe Price id read from the environment; None for the free
#               plan, and also None when the env variable is unset
#   features  - human-readable feature bullets passed to the template
#   limits    - quotas where -1 means unlimited. In this module only
#               assessments_per_month is applied (copied into
#               Organization.monthly_assessments_limit on upgrade)
SUBSCRIPTION_PLANS = {
    'free': {
        'name': 'Free',
        'price': 0,
        'price_id': None,  # no Stripe price: the free plan is never sold via Checkout
        'features': [
            '5 assessments per month',
            'Basic reporting',
            'Community support',
            'GitHub/GitLab integration'
        ],
        'limits': {
            'assessments_per_month': 5,
            'users': 3,
            'api_calls': 100
        }
    },
    'pro': {
        'name': 'Professional',
        'price': 29,
        'price_id': os.getenv('STRIPE_PRO_PRICE_ID'),
        'features': [
            'Unlimited assessments',
            'Advanced analytics',
            'Priority support',
            'API access',
            'Custom reporting',
            'Team collaboration'
        ],
        'limits': {
            'assessments_per_month': -1,  # Unlimited
            'users': 10,
            'api_calls': 1000
        }
    },
    'enterprise': {
        'name': 'Enterprise',
        'price': 299,
        'price_id': os.getenv('STRIPE_ENTERPRISE_PRICE_ID'),
        'features': [
            'Everything in Pro',
            'Multi-organization management',
            'SSO/SAML integration',
            'Compliance reports',
            'Dedicated support',
            'Custom integrations',
            'SLA guarantee'
        ],
        'limits': {
            'assessments_per_month': -1,  # Unlimited
            'users': -1,  # Unlimited
            'api_calls': -1  # Unlimited
        }
    }
}

@billing_bp.route('/billing')
@login_required
def billing_dashboard():
    """Render the billing dashboard for the logged-in user's organization.

    The template receives the organization, its ``Subscription`` row (or
    None when there is none), its five newest invoices, the full plan
    catalogue and the Stripe publishable key.
    """
    organization = current_user.organization
    org_id = organization.id

    current_subscription = Subscription.query.filter_by(organization_id=org_id).first()

    invoice_query = Invoice.query.filter_by(organization_id=org_id)
    newest_invoices = invoice_query.order_by(Invoice.created_at.desc()).limit(5).all()

    template_context = {
        'organization': organization,
        'subscription': current_subscription,
        'recent_invoices': newest_invoices,
        'plans': SUBSCRIPTION_PLANS,
        'stripe_publishable_key': STRIPE_PUBLISHABLE_KEY,
    }
    return render_template('billing/dashboard.html', **template_context)

@billing_bp.route('/upgrade/<plan>')
@login_required
def upgrade_plan(plan):
    """Send the user to Stripe Checkout to subscribe the organization to ``plan``.

    Steps:
      1. Reject unknown plans. Requests for the free plan need no payment and
         only produce an informational message.
      2. Require ``can_manage_org()``, the same permission the cancel and
         reactivate views demand for billing changes.
      3. Skip the plan the organization is already on, and plans whose Stripe
         price id is not configured.
      4. Create the organization's Stripe customer on first use, then open a
         subscription-mode Checkout session and redirect to it.

    Args:
        plan: Plan key from the URL; must be a key of ``SUBSCRIPTION_PLANS``.

    Returns:
        A redirect to the hosted Checkout page, or back to the billing
        dashboard (with a flash message) when the upgrade cannot start.
    """
    dashboard_url = url_for('billing.billing_dashboard')

    if plan not in SUBSCRIPTION_PLANS:
        flash('Invalid subscription plan', 'error')
        return redirect(dashboard_url)

    org = current_user.organization
    plan_config = SUBSCRIPTION_PLANS[plan]

    if plan == 'free':
        current_plan = org.subscription_plan
        if current_plan and current_plan != 'free':
            # The free plan has no Stripe price, so leaving a paid plan is done
            # by cancelling the subscription, not through Checkout.
            flash('To switch to the Free plan, cancel your current subscription', 'info')
        else:
            flash('You are already on the free plan', 'info')
        return redirect(dashboard_url)

    # Starting a paid subscription is a billing change; gate it like
    # cancel_subscription / reactivate_subscription do.
    if not current_user.can_manage_org():
        flash('You do not have permission to change the subscription plan', 'error')
        return redirect(dashboard_url)

    if org.subscription_plan == plan:
        flash(f'You are already on the {plan_config["name"]} plan', 'info')
        return redirect(dashboard_url)

    price_id = plan_config['price_id']
    if not price_id:
        # The STRIPE_*_PRICE_ID variable is unset. Stop before creating a
        # Stripe customer rather than letting Checkout reject a null price.
        flash(f'The {plan_config["name"]} plan is not available right now', 'error')
        return redirect(dashboard_url)

    # NOTE(review): if the organization already pays for another plan, this
    # opens a *second* Stripe subscription. Nothing in this module cancels the
    # previous one, so plan changes may need stripe.Subscription.modify instead.
    try:
        # Create the Stripe customer once and remember it on the organization.
        if not org.stripe_customer_id:
            customer = stripe.Customer.create(
                name=org.name,
                email=current_user.email,
                metadata={
                    'organization_id': org.id,
                    'organization_name': org.name
                }
            )
            org.stripe_customer_id = customer.id
            db.session.commit()

        checkout_session = stripe.checkout.Session.create(
            customer=org.stripe_customer_id,
            payment_method_types=['card'],
            line_items=[{
                'price': price_id,
                'quantity': 1,
            }],
            mode='subscription',
            # Stripe substitutes {CHECKOUT_SESSION_ID} itself, so the literal
            # placeholder must be appended after url_for() builds the URL.
            success_url=url_for('billing.upgrade_success', _external=True) + '?session_id={CHECKOUT_SESSION_ID}',
            cancel_url=url_for('billing.billing_dashboard', _external=True),
            metadata={
                'organization_id': org.id,
                'plan': plan
            }
        )
    except stripe.error.StripeError as e:
        flash(f'Payment processing error: {str(e)}', 'error')
        return redirect(dashboard_url)

    return redirect(checkout_session.url)

@billing_bp.route('/upgrade/success')
@login_required
def upgrade_success():
    """Apply a completed Stripe Checkout upgrade to the user's organization.

    Stripe redirects here with ``?session_id=...``. That id arrives in the
    query string and can be replayed or shared, so the session is re-fetched
    from Stripe. It is applied only when all of the following hold:

    * it belongs to this organization's Stripe customer;
    * it has been paid;
    * its subscription is still active or trialing;
    * its metadata names a known plan.

    Returns:
        A redirect to the billing dashboard, with a flash message describing
        the outcome.
    """
    dashboard_url = url_for('billing.billing_dashboard')

    session_id = request.args.get('session_id')
    if not session_id:
        flash('Invalid payment session', 'error')
        return redirect(dashboard_url)

    try:
        # Expand the subscription so its current status can be checked.
        checkout_session = stripe.checkout.Session.retrieve(
            session_id, expand=['subscription']
        )
    except stripe.error.StripeError as e:
        flash(f'Error verifying payment: {str(e)}', 'error')
        return redirect(dashboard_url)

    org = current_user.organization

    # Without this check any logged-in user could submit someone else's paid
    # session id. That would upgrade their organization for free and attach
    # another customer's subscription id, which cancel/reactivate then modify.
    if not org.stripe_customer_id or checkout_session.customer != org.stripe_customer_id:
        flash('Invalid payment session', 'error')
        return redirect(dashboard_url)

    if checkout_session.payment_status != 'paid':
        flash('Payment was not completed', 'error')
        return redirect(dashboard_url)

    subscription = checkout_session.subscription
    if not subscription or subscription.status not in ('active', 'trialing'):
        # An old paid session must not resurrect a subscription that has
        # since been cancelled.
        flash('This subscription is no longer active', 'error')
        return redirect(dashboard_url)

    plan = (checkout_session.metadata or {}).get('plan')
    if plan not in SUBSCRIPTION_PLANS:
        # Never store an unknown plan; the SUBSCRIPTION_PLANS lookups below
        # would also raise KeyError.
        flash('Invalid subscription plan', 'error')
        return redirect(dashboard_url)

    plan_config = SUBSCRIPTION_PLANS[plan]
    org.subscription_plan = plan
    org.subscription_status = 'active'
    org.stripe_subscription_id = subscription.id
    org.monthly_assessments_limit = plan_config['limits']['assessments_per_month']
    db.session.commit()

    flash(f'Successfully upgraded to {plan_config["name"]} plan!', 'success')
    return redirect(dashboard_url)

@billing_bp.route('/cancel-subscription', methods=['POST'])
@login_required
def cancel_subscription():
    """Ask Stripe to end the organization's subscription at period end.

    Only users for whom ``can_manage_org()`` is true may cancel. The
    subscription is flagged ``cancel_at_period_end`` in Stripe. No local
    organization fields are changed here.

    Returns:
        A redirect to the billing dashboard with a flash message.
    """
    dashboard_url = url_for('billing.billing_dashboard')

    if not current_user.can_manage_org():
        flash('You do not have permission to cancel the subscription', 'error')
        return redirect(dashboard_url)

    subscription_id = current_user.organization.stripe_subscription_id
    if not subscription_id:
        flash('No active subscription to cancel', 'error')
        return redirect(dashboard_url)

    try:
        stripe.Subscription.modify(subscription_id, cancel_at_period_end=True)
    except stripe.error.StripeError as exc:
        flash(f'Error canceling subscription: {str(exc)}', 'error')
    else:
        flash('Your subscription will be canceled at the end of the current billing period', 'info')

    return redirect(dashboard_url)

@billing_bp.route('/reactivate-subscription', methods=['POST'])
@login_required
def reactivate_subscription():
    """Clear a pending end-of-period cancellation in Stripe.

    Requires ``can_manage_org()``. Sets ``cancel_at_period_end=False`` on the
    organization's Stripe subscription so it renews normally.

    Returns:
        A redirect to the billing dashboard with a flash message.
    """
    dashboard_url = url_for('billing.billing_dashboard')

    if not current_user.can_manage_org():
        flash('You do not have permission to reactivate the subscription', 'error')
        return redirect(dashboard_url)

    subscription_id = current_user.organization.stripe_subscription_id
    if not subscription_id:
        flash('No subscription to reactivate', 'error')
        return redirect(dashboard_url)

    try:
        stripe.Subscription.modify(subscription_id, cancel_at_period_end=False)
    except stripe.error.StripeError as exc:
        flash(f'Error reactivating subscription: {str(exc)}', 'error')
    else:
        flash('Your subscription has been reactivated', 'success')

    return redirect(dashboard_url)

@billing_bp.route('/download-invoice/<invoice_id>')
@login_required
def download_invoice(invoice_id):
    """Redirect to Stripe's hosted page for one of the organization's invoices.

    The invoice is looked up by id *and* the user's organization, so invoices
    of other organizations are reported as not found.

    Args:
        invoice_id: Local ``Invoice`` primary key from the URL.

    Returns:
        A redirect to the invoice's ``hosted_invoice_url``. Otherwise a
        redirect to the billing dashboard with an error flash, when the
        invoice is unknown, Stripe fails, or Stripe has no hosted page for it.
    """
    dashboard_url = url_for('billing.billing_dashboard')

    invoice = Invoice.query.filter_by(
        id=invoice_id,
        organization_id=current_user.organization.id
    ).first()
    if not invoice:
        flash('Invoice not found', 'error')
        return redirect(dashboard_url)

    try:
        stripe_invoice = stripe.Invoice.retrieve(invoice.stripe_invoice_id)
    except stripe.error.StripeError as e:
        flash(f'Error retrieving invoice: {str(e)}', 'error')
        return redirect(dashboard_url)

    hosted_url = stripe_invoice.hosted_invoice_url
    if not hosted_url:
        # Stripe leaves hosted_invoice_url null until an invoice is finalized.
        # Passing None to redirect() would fail with a server error.
        flash('This invoice is not available online yet', 'error')
        return redirect(dashboard_url)

    return redirect(hosted_url)

@billing_bp.route('/webhooks/stripe', methods=['POST'])
def stripe_webhook():
    """Verify a Stripe webhook request and dispatch it to its handler.

    The raw body is checked against the ``Stripe-Signature`` header using
    ``STRIPE_WEBHOOK_SECRET`` before any handler runs. Event types without a
    handler are acknowledged and otherwise ignored.

    Returns:
        ``('Success', 200)`` once the event is handled or ignored.
        ``400`` for an unparseable payload or a bad signature.
        ``500`` when the webhook secret is not configured, so Stripe keeps
        retrying until it is.
    """
    endpoint_secret = os.getenv('STRIPE_WEBHOOK_SECRET')
    if not endpoint_secret:
        # Without a secret the signature cannot be verified. The stripe
        # library would otherwise fail with an unhandled error inside its
        # HMAC computation.
        return 'Webhook secret not configured', 500

    payload = request.get_data()
    sig_header = request.headers.get('Stripe-Signature')

    try:
        event = stripe.Webhook.construct_event(
            payload, sig_header, endpoint_secret
        )
    except ValueError:
        return 'Invalid payload', 400
    except stripe.error.SignatureVerificationError:
        return 'Invalid signature', 400

    # Event type -> handler taking the event's data object.
    handlers = {
        'customer.subscription.created': handle_subscription_created,
        'customer.subscription.updated': handle_subscription_updated,
        'customer.subscription.deleted': handle_subscription_deleted,
        'invoice.payment_succeeded': handle_payment_succeeded,
        'invoice.payment_failed': handle_payment_failed,
    }
    handler = handlers.get(event['type'])
    if handler is not None:
        handler(event['data']['object'])

    return 'Success', 200

def handle_subscription_created(subscription):
    """Attach a newly created Stripe subscription to its organization.

    The organization is found through the subscription's ``customer`` id.
    Events for unknown customers are ignored. Stores the subscription id,
    its status and the end of its current period, then commits.

    Args:
        subscription: The Stripe subscription object from the webhook event.
    """
    matching = Organization.query.filter_by(stripe_customer_id=subscription['customer'])
    org = matching.first()
    if not org:
        return

    org.stripe_subscription_id = subscription['id']
    org.subscription_status = subscription['status']
    # NOTE(review): fromtimestamp() gives a naive datetime in the server's
    # local timezone, not UTC. Confirm this matches how subscription_ends_at
    # is compared elsewhere. Also confirm the account's Stripe API version
    # still exposes current_period_end on the subscription itself.
    period_end = subscription['current_period_end']
    org.subscription_ends_at = datetime.fromtimestamp(period_end)

    db.session.commit()

def handle_subscription_updated(subscription):
    """Mirror a Stripe subscription update onto the matching organization.

    The organization is found by subscription id; unknown ids are ignored.
    Copies the subscription's ``status`` and ``current_period_end``, then
    commits.

    Args:
        subscription: The Stripe subscription object from the webhook event.
    """
    org = Organization.query.filter_by(stripe_subscription_id=subscription['id']).first()
    if not org:
        return

    new_status = subscription['status']
    org.subscription_status = new_status
    # NOTE(review): naive local-time datetime, not UTC -- confirm intended.
    org.subscription_ends_at = datetime.fromtimestamp(subscription['current_period_end'])

    db.session.commit()

def handle_subscription_deleted(subscription):
    """Downgrade an organization to the free plan when its subscription ends.

    The organization is found by subscription id; unknown ids are ignored.
    The plan, status and assessment limit are reset to the free plan's values
    from ``SUBSCRIPTION_PLANS``. The now-deleted subscription id is cleared.

    Args:
        subscription: The Stripe subscription object from the webhook event.
    """
    org = Organization.query.filter_by(
        stripe_subscription_id=subscription['id']
    ).first()

    if not org:
        return

    free_limits = SUBSCRIPTION_PLANS['free']['limits']

    org.subscription_plan = 'free'
    org.subscription_status = 'canceled'
    # Read the limit from the plan table instead of duplicating the number.
    org.monthly_assessments_limit = free_limits['assessments_per_month']
    # The Stripe subscription no longer exists. Clearing the id stops the
    # cancel/reactivate views from trying to modify a deleted subscription.
    org.stripe_subscription_id = None

    db.session.commit()

def handle_payment_succeeded(invoice):
    """Record a paid Stripe invoice and reset usage when a new period starts.

    The organization is found through the invoice's ``customer`` id; unknown
    customers are ignored. Stripe may deliver the same event more than once,
    so an invoice whose ``stripe_invoice_id`` is already stored is skipped
    entirely (no new row, no usage reset).

    Args:
        invoice: The Stripe invoice object from the webhook event.
    """
    org = Organization.query.filter_by(stripe_customer_id=invoice['customer']).first()
    if not org:
        return

    # Webhooks are at-least-once: ignore redeliveries of a recorded invoice.
    if Invoice.query.filter_by(stripe_invoice_id=invoice['id']).first():
        return

    paid_at = (invoice.get('status_transitions') or {}).get('paid_at')
    invoice_record = Invoice(
        organization_id=org.id,
        stripe_invoice_id=invoice['id'],
        amount_paid=invoice['amount_paid'],
        amount_due=invoice['amount_due'],
        currency=invoice['currency'],
        status='paid',
        invoice_date=datetime.fromtimestamp(invoice['created']),
        paid_at=datetime.fromtimestamp(paid_at) if paid_at else None
    )
    db.session.add(invoice_record)

    # Only invoices that open a billing period start a fresh usage month.
    # Proration or manual invoices must not wipe the counter.
    if invoice.get('billing_reason') in ('subscription_create', 'subscription_cycle'):
        org.monthly_assessments_used = 0

    db.session.commit()

def handle_payment_failed(invoice):
    """Mark the customer's organization ``past_due`` after a failed payment.

    Unknown customers are ignored. Nothing is written when the organization
    is already ``past_due``.

    Args:
        invoice: The Stripe invoice object from the webhook event.
    """
    org = Organization.query.filter_by(stripe_customer_id=invoice['customer']).first()
    if not org:
        return

    already_past_due = org.subscription_status == 'past_due'
    if not already_past_due:
        org.subscription_status = 'past_due'
        db.session.commit()

    # TODO: notify the organization about the failed payment (e.g. by email);
    # no notification is sent yet.

# Usage tracking decorators
def track_assessment_usage(f):
    """Decorate a view so it enforces and counts the monthly assessment quota.

    For authenticated users, the organization's ``can_create_assessment()``
    is checked first. If it refuses, the user is redirected to the billing
    dashboard with a warning and the view does not run. Otherwise the view
    runs, and if it returns without raising, ``monthly_assessments_used`` is
    incremented and committed. Anonymous requests call the view unchanged.

    ``functools.wraps`` keeps the wrapped view's ``__name__``. Flask derives
    endpoint names from it, so without it every decorated view would register
    as ``decorated_function`` and collide.

    Args:
        f: The view function to wrap.

    Returns:
        The wrapping view function.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if not current_user.is_authenticated:
            return f(*args, **kwargs)

        org = current_user.organization

        if not org.can_create_assessment():
            flash('You have reached your monthly assessment limit. Please upgrade your plan.', 'warning')
            return redirect(url_for('billing.billing_dashboard'))

        result = f(*args, **kwargs)

        # NOTE(review): the view's return value is not inspected, so a view
        # that re-renders an error page still counts as one assessment. The
        # read-modify-write increment can also lose counts under concurrent
        # requests.
        org.monthly_assessments_used += 1
        db.session.commit()

        return result

    return decorated_function
