import { MutableRefObject } from 'react'
import * as d3 from 'd3'

import { FormattedPoint, HeatmapData } from '..'
import { formatNumber } from '../../../../helpers/rounds'
import { checkIfEventTargetIsSVGElement } from '../../../../types/common'

/** Arguments accepted by {@link drawChart}. */
interface DrawChartInput {
  /** Prefix used to find the tooltip elements (`#<chartId>-tooltip`, `#<chartId>-tooltip-title`, `.<chartId>-tooltip-<name>-value`). */
  chartId: string
  /** Container element that receives the chart's SVG. */
  chartRef: MutableRefObject<HTMLDivElement | null>
  /** Datasets to plot; the first one defines the column/row layout. */
  data: HeatmapData<FormattedPoint>
  /** Total SVG width in pixels, margins included. */
  width: number
  /** Index into `data` of the dataset whose values colour the cells. */
  displayedIntensityIndex: number
  /** Colour reached by the highest intensity (zero maps to white). */
  color: string
}

/**
 * Renders a heatmap into `chartRef.current`, replacing any SVG previously drawn there.
 *
 * The grid layout (x columns and y rows) is taken from `data[0]`; cell colours come from
 * `data[displayedIntensityIndex]`, interpolated from white (0) to `color` (maximum value).
 * Hovering a cell moves the `#${chartId}-tooltip` element to the cell centre and fills in
 * the value of every dataset for that (x, y) pair.
 *
 * Nothing is drawn when `data` is empty, when `displayedIntensityIndex` is out of range,
 * or when `data[0]` has no columns.
 *
 * @param chartId - Prefix of the tooltip element ids / value class names.
 * @param chartRef - Container element the SVG is appended to.
 * @param data - Datasets; the first one defines the grid.
 * @param width - Total SVG width in pixels, margins included.
 * @param displayedIntensityIndex - Index into `data` of the dataset used for cell colours.
 * @param color - Colour of the maximum intensity.
 */
export const drawChart = ({
  chartId,
  chartRef,
  data,
  width,
  displayedIntensityIndex,
  color
}: DrawChartInput) => {
  // Clear any SVG left over from a previous render.
  d3.select(chartRef.current).selectAll('svg').remove()

  if (
    data.length === 0 ||
    displayedIntensityIndex < 0 ||
    displayedIntensityIndex >= data.length ||
    // Without at least one column there is no grid to draw.
    data[0].categories.length === 0
  ) {
    return
  }

  // Column labels come from the first dataset, in data order.
  const xAxisValues = data[0].categories.map((d) => d.x)
  // Row labels: every distinct y value across all columns of the first dataset, in
  // first-seen order. The same list drives the chart height, the cell positions and
  // the y axis, so the three always agree.
  const yAxisValues = Array.from(
    new Set(
      data[0].categories.flatMap((category) =>
        category.values.map((value) => value.y)
      )
    )
  )
  const numberOfCategories = yAxisValues.length

  // Set the dimensions and margins of the chart
  const margin = {
    top: 30,
    right: Math.max(20, width / 10),
    bottom: 30,
    left: Math.max(70, width / 6)
  }
  const widthWithoutMargin = width - margin.left - margin.right
  // Overall height is the larger of 70px per row and widthWithoutMargin / 2.2;
  // the vertical margins are subtracted to get the plotting-area height.
  const height =
    Math.max(numberOfCategories * 70, widthWithoutMargin / 2.2) -
    margin.top -
    margin.bottom

  // Append the SVG object to the container element
  const svg = d3
    .select(chartRef.current)
    .append('svg')
    .attr('width', widthWithoutMargin + margin.left + margin.right)
    .attr('height', height + margin.top + margin.bottom)
    .append('g')
    .attr('transform', `translate(${margin.left},${margin.top})`)

  // Band scales shared by the cells and the axes, so cells line up with the ticks.
  const x = d3.scaleBand().domain(xAxisValues).range([0, widthWithoutMargin])
  const y = d3.scaleBand().domain(yAxisValues).range([0, height])

  // One cell per (column, row) value of the displayed dataset. A cell whose x or y
  // is not part of the grid has no band to occupy, so it is skipped rather than
  // being drawn outside the plotting area.
  const cells = data[displayedIntensityIndex].categories
    .flatMap((category) =>
      category.values.map((value) => ({
        category: category.x,
        y: value.y,
        value: value.value
      }))
    )
    .filter(
      (cell) => x(cell.category) !== undefined && y(cell.y) !== undefined
    )

  // Create color scale for the heatmap. A [0, 0] domain makes d3's sequential
  // scale return the interpolator's midpoint for every input, which would tint an
  // all-zero dataset; widening it keeps zero mapped to white.
  const maxIntensity = d3.max(cells, (cell) => cell.value) ?? 0
  const colorScale = d3
    .scaleSequential()
    .interpolator(d3.interpolateRgb('white', color))
    .domain([0, maxIntensity === 0 ? 1 : maxIntensity])

  const tooltip = d3.select(`#${chartId}-tooltip`)
  const tooltipTitle = d3.select(`#${chartId}-tooltip-title`)

  // Create heatmap rectangles
  svg
    .selectAll('rect')
    .data(cells)
    .enter()
    .append('rect')
    .attr('x', (d) => x(d.category) ?? 0)
    .attr('y', (d) => y(d.y) ?? 0)
    .attr('width', x.bandwidth())
    .attr('height', y.bandwidth())
    .style('fill', (d) => colorScale(d.value))
    .style('stroke', 'white')
    .style('stroke-width', 2)
    .on('mousemove', (e: MouseEvent, d) => {
      const hoveredElement = e.currentTarget
      if (!checkIfEventTargetIsSVGElement(hoveredElement)) {
        return
      }

      // Anchor the tooltip at the centre of the hovered cell (viewport coordinates).
      const elementRect = hoveredElement.getBoundingClientRect()
      const elementCenterX = elementRect.left + elementRect.width / 2
      const elementCenterY = elementRect.top + elementRect.height / 2

      // Value of every dataset at this (x, y); empty string when a dataset lacks it.
      const datasetValuesToDisplay = data.map((dataset) => {
        const valueFromXAndY = dataset.categories
          .find((category) => category.x === d.category)
          ?.values.find((value) => value.y === d.y)?.value

        return {
          name: dataset.name,
          value:
            valueFromXAndY !== undefined
              ? dataset.unit
                ? `${formatNumber(valueFromXAndY)} ${dataset.unit}`
                : formatNumber(valueFromXAndY)
              : ''
        }
      })

      hoveredElement.style.opacity = '0.9'

      tooltip.style('opacity', 1)
      tooltip.style('left', `${elementCenterX}px`)
      tooltip.style('top', `${elementCenterY}px`)

      tooltipTitle.text(`${d.y} (${d.category})`)

      datasetValuesToDisplay.forEach((dataset) => {
        // Anything that does not parse as a non-negative number (notably the empty
        // string for a missing value) is shown as '-'.
        // NOTE(review): this also hides genuinely negative values — confirm intended.
        tooltip
          .selectAll(`.${chartId}-tooltip-${dataset.name}-value`)
          .text(`${parseFloat(dataset.value) >= 0 ? dataset.value : '-'}`)
      })
    })
    .on('mouseout', () => {
      tooltip.style('opacity', 0)

      svg.selectAll('rect').style('opacity', 1)
    })

  // Add X axis
  svg
    .append('g')
    .attr('transform', `translate(0, ${height})`)
    .attr('class', 'x-axis')
    .call(d3.axisBottom(x).tickSize(0).tickPadding(14).tickSizeOuter(0))

  // Add Y axis
  svg
    .append('g')
    .attr('class', 'y-axis')
    .call(d3.axisLeft(y).tickSize(0).tickPadding(10).tickArguments([5]))

  // Style the axes ticks text
  svg.selectAll('.y-axis').selectAll('text').style('font-size', '14px')
  svg.selectAll('.x-axis').selectAll('text').style('font-size', '14px')
}
