import React, { useCallback, useEffect, useRef, useState } from 'react';

/**
 * Options accepted by {@link useFocus}.
 */
interface useFocusProps {
  /** Ref to the element whose focusable descendants form the focus trap. */
  container : React.MutableRefObject<HTMLElement | null>;
  /** When true (the hook defaults it to true), initial focus prefers form inputs. */
  focusInput? : boolean;
  /** Invoked whenever an Escape keydown reaches `document`. */
  onEscape? : () => void;
}

/**
 * Elements that can receive keyboard focus inside the trapped container.
 * Hidden inputs are excluded: they can never be focused, so if one were last
 * (or first) in DOM order, Tab (or Shift+Tab) would escape the trap.
 */
const FOCUSABLE_SELECTOR =
  'a[href], button, input:not([type="hidden"]), select, textarea';

/** The form-input subset of FOCUSABLE_SELECTOR, preferred when `focusInput` is set. */
const INPUT_SELECTOR = 'input:not([type="hidden"]), select, textarea';

/**
 * True unless the element exposes a truthy `disabled` property
 * (buttons, inputs, selects, textareas); disabled controls cannot take focus.
 */
function isEnabled(element : HTMLElement) : boolean {
  return !('disabled' in element) || !element.disabled;
}

/**
 * Traps keyboard focus inside `container` and optionally reacts to Escape.
 *
 * `takeFocus` records the currently focused element, then calls `contain`.
 * `contain` snapshots the container's focusable descendants and focuses the
 * first enabled one, preferring form inputs when `focusInput` is true. While
 * that snapshot is non-empty, Tab / Shift+Tab pressed inside the container
 * wrap between the first and last enabled elements. `returnFocus` restores
 * focus to the element recorded by `takeFocus` / `setLastActive`.
 *
 * @param container ref to the element whose descendants form the trap
 * @param focusInput when true (default), initial focus goes to the first
 *   enabled form input if there is one, otherwise the first enabled focusable
 * @param onEscape invoked on every Escape keydown on `document`, regardless
 *   of whether focus is inside the container
 * @returns `contain`, `takeFocus`, `returnFocus`, and `setLastActive`
 */
function useFocus({
  container,
  focusInput = true,
  onEscape,
} : useFocusProps) {
  const lastActive = useRef<HTMLElement | null>(null);
  // Snapshot taken by `contain`. Changing it also recreates `handleTabKey`,
  // which re-runs the listener effect once the container element exists.
  const [focusable, setFocusable] = useState<HTMLElement[] | null>(null);

  const setLastActive = useCallback((element : HTMLElement | null) => {
    lastActive.current = element;
  }, []);

  const contain = useCallback(() => {
    const root = container.current;
    if (!root) {
      setFocusable(null);
      return;
    }

    const contained = Array.from(
      root.querySelectorAll<HTMLElement>(FOCUSABLE_SELECTOR),
    );
    setFocusable(contained);

    // Skip disabled controls: focus() on them is a silent no-op.
    const firstInput = focusInput
      ? Array.from(root.querySelectorAll<HTMLElement>(INPUT_SELECTOR)).find(isEnabled)
      : undefined;
    const first = firstInput ?? contained.find(isEnabled);
    first?.focus();
  }, [container, focusInput]);

  const takeFocus = useCallback(() => {
    // activeElement is typed Element | null; the value is only used to call
    // focus() later, which SVG and HTML elements both provide.
    setLastActive(document.activeElement as HTMLElement | null);
    contain();
  }, [setLastActive, contain]);

  const returnFocus = useCallback(() => {
    // preventScroll keeps the page from jumping back to the restored element.
    lastActive.current?.focus({ preventScroll: true });
  }, []);

  const handleTabKey = useCallback((event : KeyboardEvent) => {
    if (event.key !== 'Tab' || !container.current) return;
    if (!focusable || focusable.length === 0) return;

    const enabledElements = focusable.filter(isEnabled);
    const firstElement = enabledElements[0];
    const lastElement = enabledElements[enabledElements.length - 1];
    // Every snapshotted element is disabled: there is nothing to wrap between.
    if (!firstElement || !lastElement) return;

    if (!event.shiftKey && document.activeElement === lastElement) {
      firstElement.focus();
      event.preventDefault();
    } else if (event.shiftKey && document.activeElement === firstElement) {
      lastElement.focus();
      event.preventDefault();
    }
  }, [container, focusable]);

  const handleEscape = useCallback((event : KeyboardEvent) => {
    if (event.key === 'Escape') onEscape?.();
  }, [onEscape]);

  useEffect(() => {
    // Capture the element so cleanup detaches from the same node even if
    // the ref has since changed.
    const current = container.current;
    if (!current) return undefined;

    current.addEventListener('keydown', handleTabKey);
    return () => current.removeEventListener('keydown', handleTabKey);
  }, [container, handleTabKey]);

  useEffect(() => {
    document.addEventListener('keydown', handleEscape, false);
    return () => document.removeEventListener('keydown', handleEscape, false);
  }, [handleEscape]);

  return {
    contain,
    takeFocus,
    returnFocus,
    setLastActive,
  };
}

export { useFocus as default };
