import { Dispatch, useEffect, useMemo, useRef, useState } from 'react';
import { TreeApi as ArboristTreeApi } from 'react-arborist';
import { getNode, getPathToNode, updateDescendantSelections } from './utils';
import {
  Actions,
  rootReducer,
} from 'react-arborist/dist/module/state/root-reducer';
import { NodeData, TNode } from './Node';

/**
 * Recursive map: each string key maps to another map of the same shape.
 * The hook keeps one instance in a ref and hands it to
 * `updateDescendantSelections`.
 * NOTE(review): keys are presumably node ids — confirm against `./utils`.
 */
export type SelectionDescendantMap = Map<string, SelectionDescendantMap>;

/** react-arborist's own actions plus the two bulk open/close actions added by this module. */
type ExtendedActions = Actions | { type: 'OPEN_ALL' } | { type: 'CLOSE_ALL' };

/** A node's data together with an optional list of child entries of the same shape. */
export type TreeData = NodeData & {
  children?: TreeData[];
};

/** Props accepted by `useTreeApi` regardless of controlled/uncontrolled usage. */
export type TreeApiBaseProps = {
  /** Top-level tree entries. */
  data: TreeData[];
  /** Initial value of the hook's `isExpanded` flag (the hook defaults it to `false`). */
  openByDefault?: boolean;
  /** Ids used to seed the hook's internal selection state. */
  initialSelections?: string[];
};
/** Either every property of `T` is required, or every property is absent/`undefined`. */
type AllOrNone<T> = Required<T> | Partial<Record<keyof T, undefined>>;
/**
 * Controlled-selection props: `value` and `onChange` must be supplied
 * together or both omitted.
 */
export type TreeApiControlProps = AllOrNone<{
  value: string[];
  onChange: (value: string[]) => void;
}>;
/** Full props of `useTreeApi`: base props combined with the controlled-selection props. */
export type TreeApiProps = TreeApiBaseProps & TreeApiControlProps;

/**
 * Hook that manages selection, expand/collapse state and search text for a
 * react-arborist tree.
 *
 * Selection is controlled when `value` is supplied (changes are reported via
 * `onChange`) and otherwise kept in internal state seeded from
 * `initialSelections`.
 *
 * @param props.data - Tree entries; used to collect the ids of all nodes that have `children`.
 * @param props.openByDefault - Initial value of the returned `isExpanded` flag.
 * @param props.initialSelections - Initial ids for the internal selection state.
 * @param props.value - Selected ids when the selection is controlled.
 * @param props.onChange - Receives the next selection; also invoked from an
 *   effect whenever the effective selection (or `onChange` itself) changes.
 * @returns The tree ref to attach to the tree component, plus selection,
 *   expansion and search state with their handlers.
 */
export const useTreeApi = ({
  data,
  openByDefault = false,
  initialSelections,
  value,
  onChange,
}: TreeApiProps) => {
  const treeRef = useRef<ArboristTreeApi<TreeData>>();

  /** ================================
   * Selection
   ================================ */
  const [internalSelectedIds, setInternalSelectedIds] = useState<string[]>(
    initialSelections || []
  );

  // Controlled when the caller supplies `value`; otherwise internal state wins.
  const isControlled = value != null;
  const selectedIds = value ?? internalSelectedIds;

  // Report the effective selection whenever it (or the callback) changes.
  useEffect(() => {
    if (onChange) onChange(selectedIds);
  }, [selectedIds, onChange]);

  const [indeterminateSelectedIds, setIndeterminateSelectedIds] = useState<
    string[]
  >([]);

  const selectionDescendantMap = useRef<SelectionDescendantMap>(new Map());

  /**
   * Applies a pure transformation to the selection, routing it to `onChange`
   * in controlled mode and to internal state otherwise.
   */
  const commitSelection = (computeNext: (current: string[]) => string[]) => {
    if (isControlled) {
      onChange?.(computeNext(selectedIds));
    } else {
      setInternalSelectedIds(computeNext);
    }
  };

  /**
   * Selects or deselects a group of ids in ONE update. Batching matters in
   * controlled mode: separate `onChange` calls would each start from the same
   * stale `selectedIds`, so only the last one would survive.
   */
  const setNodesSelected = (ids: string[], shouldSelect: boolean) => {
    const uniqueIds = Array.from(new Set(ids));
    commitSelection((current) =>
      shouldSelect
        ? // Skip ids that are already selected so no duplicates are stored.
          [...current, ...uniqueIds.filter((id) => !current.includes(id))]
        : current.filter((selectedId) => !uniqueIds.includes(selectedId))
    );
  };

  // Propagate selection changes into the descendant map / indeterminate list.
  const prevSelectionIds = useRef<string[]>([]);
  useEffect(() => {
    // Without a mounted tree nothing can be resolved. Leave the previous
    // snapshot untouched so these ids are still seen as "added" next time.
    const rootNode = treeRef.current?.root;
    if (!rootNode) return;

    const previousIds = prevSelectionIds.current;
    const addedIds = selectedIds.filter((id) => !previousIds.includes(id));
    const removedIds = previousIds.filter((id) => !selectedIds.includes(id));
    prevSelectionIds.current = selectedIds;

    // `updateDescendantSelections` is given the mutable ref map, so it is
    // called here rather than inside a state updater (updaters must be pure).
    // Only the result of the last call is stored.
    let nextIndeterminateIds: string[] | undefined;

    for (const id of addedIds) {
      const node = getNode({
        nodes: rootNode.children,
        pathToNode: getPathToNode(id),
      });
      if (!node) continue;
      nextIndeterminateIds = updateDescendantSelections({
        operation: 'INSERT',
        node,
        selectionDescendantMap: selectionDescendantMap.current,
      });
    }

    for (const id of removedIds) {
      const node = getNode({
        nodes: rootNode.children,
        pathToNode: getPathToNode(id),
      });
      if (!node) continue;
      nextIndeterminateIds = updateDescendantSelections({
        operation: 'REMOVE',
        node,
        selectionDescendantMap: selectionDescendantMap.current,
        selectedIds,
      });
    }

    if (nextIndeterminateIds !== undefined) {
      setIndeterminateSelectedIds(nextIndeterminateIds);
    }
  }, [selectedIds]);

  /** Clears the selection in both controlled and uncontrolled mode. */
  const clearAll = () => {
    commitSelection(() => []);
  };

  /**
   * Toggles the node with the given id and applies the same new state to every
   * node listed in its `linkedIds` that can be found in the tree.
   */
  const toggleNodeSelection = (id: string) => {
    const rootNode = treeRef.current?.root;
    const node = getNode({
      nodes: rootNode?.children || null,
      pathToNode: getPathToNode(id),
    });

    if (!node) return;

    const shouldSelect = !selectedIds.includes(id);
    const affectedIds: string[] = [node.id];

    /** ======== Collect LinkedIDs ======== */
    if (rootNode) {
      node.data.linkedIds?.forEach((linkedId) => {
        const linkedNode = getNode({
          nodes: rootNode.children,
          pathToNode: getPathToNode(linkedId),
        });
        if (linkedNode) affectedIds.push(linkedNode.id);
      });
    }

    /** ======== Select/Deselect Node + Linked Nodes ======== */
    setNodesSelected(affectedIds, shouldSelect);
  };

  /** ================================
   * Visibility
   ================================ */
  // Ids of every entry that has a `children` array, at any depth.
  const allNonLeafNodeIds = useMemo(() => {
    const getChildIds = (nodes: TreeData[]): string[] => {
      return nodes.flatMap((node) => {
        if (node.children) return [node.id, ...getChildIds(node.children)];
        return [];
      });
    };
    return getChildIds(data);
  }, [data]);

  /** ======== Extend Reducer ======== */
  // Wrap react-arborist's root reducer so the store also understands the
  // custom OPEN_ALL / CLOSE_ALL actions; everything else is delegated.
  useEffect(() => {
    const treeApi = treeRef.current;
    if (treeApi) {
      treeApi.store.replaceReducer((state, action) => {
        if (!state) return rootReducer(state, action);

        if (
          ['OPEN_ALL', 'CLOSE_ALL'].includes((action as ExtendedActions).type)
        ) {
          const isFiltered = treeApi.isFiltered;
          const treeOpenByDefault = treeApi.props.openByDefault;

          const isOpen = (action as ExtendedActions).type === 'OPEN_ALL';

          // In these cases the open-state map is reset to `{}` instead of
          // listing every node.
          // NOTE(review): presumably an empty map makes the tree fall back to
          // its default open state, which already matches the request here —
          // confirm against react-arborist.
          const clearFilterState =
            (isFiltered && isOpen) ||
            (!isFiltered && treeOpenByDefault && isOpen) ||
            (!isFiltered && !treeOpenByDefault && !isOpen);

          const filteredState = clearFilterState
            ? {}
            : allNonLeafNodeIds.reduce<{
                [key: string]: boolean;
              }>((acc, id) => {
                acc[id] = isOpen;
                return acc;
              }, {});

          // Only the open-state slice matching the current filter mode is replaced.
          const newState = {
            ...state,
            nodes: {
              ...state.nodes,
              open: {
                ...state.nodes.open,
                ...(isFiltered
                  ? { filtered: filteredState }
                  : { unfiltered: filteredState }),
              },
            },
          };
          return newState;
        } else {
          return rootReducer(state, action);
        }
      });
    }
  }, [allNonLeafNodeIds]);

  /** ======== Expand/Collapse ======== */
  const [isExpanded, setIsExpanded] = useState<boolean>(openByDefault);

  /** Dispatches to the tree's store; a no-op while the tree ref is unset. */
  const dispatchToTree = (action: ExtendedActions) => {
    const dispatch = treeRef.current?.store.dispatch as
      | Dispatch<ExtendedActions>
      | undefined;
    dispatch?.(action);
  };

  const openAll = () => {
    dispatchToTree({ type: 'OPEN_ALL' });
  };

  const closeAll = () => {
    dispatchToTree({ type: 'CLOSE_ALL' });
  };

  const expand = () => {
    setIsExpanded(true);
    openAll();
  };
  const collapse = () => {
    setIsExpanded(false);
    closeAll();
  };

  /** ================================
   * Search
   ================================ */
  const [search, setSearch] = useState<string>('');
  // A non-empty search term flips `isExpanded` on; clearing it flips it off.
  const onSearch = (search: string) => {
    setSearch(search);
    setIsExpanded(search.length > 0);
  };

  return {
    treeRef,
    selectedIds,
    indeterminateSelectedIds,
    toggleNodeSelection,
    isExpanded,
    expand,
    collapse,
    search,
    onSearch,
    clearAll,
  };
};

/** Shape of the object returned by `useTreeApi`, derived from the hook itself. */
export type TreeApi = ReturnType<typeof useTreeApi>;
