import { useIsFetching, useQueries, useQueryClient } from '@tanstack/react-query'
import { getDataset } from 'api/dgs'
import useDgsClient from 'api/hooks/useDgsClient'
import { useCallback, useMemo, useState } from 'react'
import { ColumnLineage, Dataset } from 'types/dgs'
import { transformColumnLineage } from './LineageTransformers'
import { Edge } from 'reactflow'
import { FlowNode } from 'components/ReactFlow/LineageGraph'

/**
 * Tracks a growing set of dataset ids and fetches each one for column lineage.
 *
 * Each tracked id is fetched with `getDataset` under the query key
 * `['dataset', id]` (five-minute `staleTime`, no refetch on mount).
 *
 * @returns
 *  - `columnLineage`: `transformColumnLineage` applied to the field-level edges
 *    of every successfully fetched dataset. It is recomputed only when one of
 *    those datasets changes.
 *  - `updateIds(nodeId, foundEdges, nodes)`: adds `nodeId` plus every DATASET
 *    node found at either end of `foundEdges`. Ids are never removed.
 *  - `isError`: true if any of the dataset queries is in an error state.
 */
export const useColumnLineage = () => {
  const [ids, setIds] = useState<Set<string>>(new Set())
  const client = useDgsClient()

  const queries = Array.from(ids).map((id) => ({
    queryKey: ['dataset', id],
    queryFn: async () => getDataset(client, id),
    refetchOnMount: false,
    staleTime: 5 * 60 * 1000
  }))

  const results = useQueries({ queries })

  // `results`, and any array derived from it, is a fresh array on every render,
  // so a memo keyed on it never hits. React Query keeps each `result.data`
  // reference stable while the data is unchanged (structural sharing). Collapse
  // the list into a reference that only changes when one of the datasets does.
  const datasets = useShallowStableArray(
    results.flatMap((result) =>
      result.status === 'success' ? [result.data as Dataset] : []
    )
  )
  const transformedResults = useMemo(
    () => transformColumnLineage(mapEdges(datasets)),
    [datasets]
  )

  const updateIds = useCallback(
    (nodeId: string, foundEdges: Edge[], nodes: FlowNode[]) => {
      // Only DATASET nodes at either end of an edge get fetched.
      const datasetNodeIds = new Set(
        nodes.filter((node) => node.type === 'DATASET').map((node) => node.id)
      )

      const candidateIds = [nodeId]
      foundEdges.forEach(({ source, target }) => {
        if (datasetNodeIds.has(source)) {
          candidateIds.push(source)
        }
        if (datasetNodeIds.has(target)) {
          candidateIds.push(target)
        }
      })

      setIds((prevIds) => {
        if (candidateIds.every((id) => prevIds.has(id))) {
          // Nothing new: keep the same Set so React can skip the re-render.
          return prevIds
        }
        const updatedIds = new Set(prevIds)
        candidateIds.forEach((id) => updatedIds.add(id))
        return updatedIds
      })
    },
    []
  )

  const isError = results.some((result) => result.isError)

  return { columnLineage: transformedResults, updateIds, isError }
}

/**
 * Returns an array reference that changes only when `next` differs from the
 * previously returned array, either in length or in any element (compared with
 * `===`).
 *
 * Uses React's "store information from previous renders" state pattern.
 * Updating this component's own state during render is allowed: React re-renders
 * immediately, and the equality guard stops the update from looping.
 */
function useShallowStableArray<T>(next: T[]): T[] {
  const [stable, setStable] = useState(next)
  const changed =
    stable.length !== next.length ||
    stable.some((item, index) => item !== next[index])
  if (changed) {
    setStable(next)
    return next
  }
  return stable
}

/**
 * Turns each dataset's field-level `sources` into flat lineage edges.
 *
 * Every entry in a field's `sources` becomes one `{ source, target }` edge. The
 * `target` identifies that field by its owning dataset's `id` and the field's
 * `name`. All edges from the same field share a single `target` object.
 * Output order follows datasets, then fields, then sources.
 */
function mapEdges(results: Dataset[]): ColumnLineage[] {
  return results.flatMap((dataset) =>
    dataset.fields.flatMap((field) => {
      const fieldTarget = { datasetId: dataset.id, fieldName: field.name }
      return field.sources.map((source) => ({ source, target: fieldTarget }))
    })
  )
}

/**
 * Reports whether any dataset referenced by `connectedEdges` is being fetched.
 *
 * Collects the source and target ids of the edges. It then looks for an
 * in-flight query whose key starts with `'dataset'` and whose second key
 * element is one of those ids.
 *
 * Built on `useIsFetching`, which subscribes to the query cache, so the calling
 * component re-renders when a matching fetch starts or finishes. The previous
 * version read the cache imperatively during render and could go stale.
 *
 * @param connectedEdges - Edges whose endpoint ids should be checked.
 * @returns `true` while at least one matching dataset query is fetching.
 */
export const useAreDatasetsLoading = (connectedEdges: Edge[]) => {
  const ids = new Set(
    connectedEdges.flatMap((edge) => [edge.source, edge.target])
  )
  const fetchingCount = useIsFetching({
    queryKey: ['dataset'],
    predicate: (query) => {
      const datasetId = query.queryKey[1]
      return typeof datasetId === 'string' && ids.has(datasetId)
    }
  })
  return fetchingCount > 0
}
