import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.TestWatcher;
import org.openqa.selenium.WebDriver;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Optional;

/**
 * JUnit 5 extension that captures a browser screenshot when a test fails and,
 * when {@code ScreenshotConfig.CAPTURE_ON_FAILURE_ONLY} is false, when it passes.
 *
 * <p>Register it with {@code @ExtendWith}:
 *
 * <pre>{@code
 * @ExtendWith(JUnitScreenshotWatcher.class)
 * public class MyTest {
 *     private WebDriver driver;
 *     // Test methods
 * }
 * }</pre>
 *
 * <p>The {@link WebDriver} is looked up, in order, from:
 * <ol>
 *   <li>this extension's {@link ExtensionContext.Store} (see {@link #storeDriver});</li>
 *   <li>a field named {@code driver} declared on the test class or any of its superclasses
 *       (the first declaration found, walking upward, is used);</li>
 *   <li>a public no-arg {@code getDriver()} method on the test instance.</li>
 * </ol>
 *
 * <p>NOTE: JUnit invokes {@link TestWatcher} callbacks only after the test's {@code @AfterEach}
 * methods have run. A driver that is quit in {@code @AfterEach} can therefore no longer take a
 * screenshot here.
 */
public class JUnitScreenshotWatcher implements TestWatcher {

    /**
     * Namespace under which the driver and the last screenshot path are kept in the ExtensionContext store.
     */
    private static final ExtensionContext.Namespace NAMESPACE =
            ExtensionContext.Namespace.create("screenshot-watcher");

    /** Name of the test-class field that is inspected reflectively for a driver. */
    private static final String DRIVER_FIELD = "driver";

    /** Name of the optional accessor method that is invoked reflectively for a driver. */
    private static final String DRIVER_ACCESSOR = "getDriver";

    @Override
    public void testSuccessful(ExtensionContext context) {
        ScreenshotConfig.log("Test passed: " + context.getDisplayName());

        // Passing tests are captured only when screenshots are enabled at all (same gate as
        // testFailed) AND the configuration asks for more than failure-only capture.
        if (ScreenshotConfig.isEnabled() && !ScreenshotConfig.CAPTURE_ON_FAILURE_ONLY) {
            captureScreenshot(context, "PASSED");
        }
    }

    @Override
    public void testFailed(ExtensionContext context, Throwable cause) {
        ScreenshotConfig.log("Test failed: " + context.getDisplayName());
        // Throwable.toString() includes the exception type and is never "null", unlike
        // getMessage() for message-less failures such as a bare fail() or AssertionError.
        ScreenshotConfig.log("Failure cause: " + cause);

        if (ScreenshotConfig.isEnabled()) {
            captureScreenshot(context, "FAILED");
        }
    }

    @Override
    public void testAborted(ExtensionContext context, Throwable cause) {
        ScreenshotConfig.log("Test aborted: " + context.getDisplayName());
    }

    @Override
    public void testDisabled(ExtensionContext context, Optional<String> reason) {
        ScreenshotConfig.log("Test disabled: " + context.getDisplayName()
                + reason.map(r -> " - Reason: " + r).orElse(""));
    }

    /**
     * Captures a screenshot for the current test.
     *
     * <p>The screenshot is named {@code <TestClass>_<testMethod>_<status>}. When capture succeeds,
     * its path is stored under {@code "screenshotPath"} (see {@link #getScreenshotPath}). Every
     * failure is reported to {@code System.err} and never propagated, so a broken screenshot never
     * masks the real test outcome.
     *
     * @param context extension context of the test that just finished
     * @param status  test status label embedded in the file name (PASSED/FAILED)
     */
    private void captureScreenshot(ExtensionContext context, String status) {
        try {
            WebDriver driver = getDriver(context);

            if (driver != null) {
                String className = context.getTestClass()
                        .map(Class::getSimpleName)
                        .orElse("UnknownClass");

                String methodName = context.getTestMethod()
                        .map(Method::getName)
                        .orElse("unknownMethod");

                String screenshotName = className + "_" + methodName + "_" + status;

                String path = ScreenshotUtil.captureBrowserScreenshot(
                        driver,
                        screenshotName,
                        ScreenshotConfig.getSuiteDirectory(className)
                );

                if (path != null) {
                    ScreenshotConfig.log("Screenshot saved: " + path);
                    // Store screenshot path in context for potential retrieval
                    context.getStore(NAMESPACE).put("screenshotPath", path);
                } else {
                    System.err.println("Failed to capture screenshot for: " + methodName);
                }
            } else {
                System.err.println("WebDriver not found for test: "
                        + context.getDisplayName());
            }
        } catch (Exception e) {
            // Best effort: report and continue; never fail the test run because of a screenshot.
            System.err.println("Exception while capturing screenshot: " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * Locates the WebDriver for the current test.
     *
     * <p>Lookups are tried in order: the extension store, then a {@code driver} field anywhere in
     * the test class hierarchy, then a public {@code getDriver()} method on the test instance.
     *
     * @param context extension context of the test
     * @return the WebDriver instance, or {@code null} if none could be found
     */
    private WebDriver getDriver(ExtensionContext context) {
        // Approach 1: a driver explicitly registered via storeDriver().
        WebDriver stored = context.getStore(NAMESPACE).get("driver", WebDriver.class);
        if (stored != null) {
            return stored;
        }

        Object instance = context.getTestInstance().orElse(null);
        if (instance == null) {
            return null;
        }

        // Approach 2: a 'driver' field on the test class or any superclass.
        WebDriver fromField = readDriverField(instance);
        if (fromField != null) {
            return fromField;
        }

        // Approach 3: a getDriver() accessor.
        return invokeDriverAccessor(instance);
    }

    /**
     * Reads the first field named {@code driver} found while walking from the instance's class up
     * toward {@link Object}.
     *
     * @param instance the test instance
     * @return the field's value if it is a WebDriver, otherwise {@code null}
     */
    private static WebDriver readDriverField(Object instance) {
        for (Class<?> type = instance.getClass();
                type != null && type != Object.class;
                type = type.getSuperclass()) {
            try {
                Field field = type.getDeclaredField(DRIVER_FIELD);
                field.setAccessible(true);
                Object value = field.get(instance);
                // The nearest declaration wins (it shadows any inherited one), even if unset.
                return value instanceof WebDriver ? (WebDriver) value : null;
            } catch (NoSuchFieldException ignored) {
                // Not declared on this class; keep climbing the hierarchy.
            } catch (ReflectiveOperationException | RuntimeException ignored) {
                // Declared but not readable (e.g. module access rules); give up on fields.
                return null;
            }
        }
        return null;
    }

    /**
     * Invokes a public no-arg {@code getDriver()} method on the test instance, if one exists.
     *
     * @param instance the test instance
     * @return the accessor's result if it is a WebDriver, otherwise {@code null}
     */
    private static WebDriver invokeDriverAccessor(Object instance) {
        try {
            Method accessor = instance.getClass().getMethod(DRIVER_ACCESSOR);
            // JUnit 5 test classes are often package-private; without this, invoking even a
            // public method from another package fails with IllegalAccessException.
            accessor.setAccessible(true);
            Object result = accessor.invoke(instance);
            return result instanceof WebDriver ? (WebDriver) result : null;
        } catch (ReflectiveOperationException | RuntimeException ignored) {
            // No such method, not accessible, or the accessor itself threw.
            return null;
        }
    }

    /**
     * Stores a WebDriver in the ExtensionContext so this watcher can find it later.
     *
     * <p>Call this from code that has access to the test's {@link ExtensionContext}, such as
     * another extension's {@code BeforeEachCallback}. JUnit does not inject an ExtensionContext
     * into {@code @BeforeEach} methods by default.
     *
     * @param context extension context of the test
     * @param driver  WebDriver instance to store
     */
    public static void storeDriver(ExtensionContext context, WebDriver driver) {
        context.getStore(NAMESPACE).put("driver", driver);
    }

    /**
     * Returns the path of the screenshot most recently saved into this context's store.
     *
     * @param context extension context of the test
     * @return screenshot path, or {@code null} if none was stored
     */
    public static String getScreenshotPath(ExtensionContext context) {
        return context.getStore(NAMESPACE).get("screenshotPath", String.class);
    }
}
