import React from 'react';
import useUiFlag from 'hooks/useFeatureFlags/useUiFlag';
import { IncrementalityDataChart, dataChart } from './MeasurementDashboard';
import { Layer, ResponsiveLine, Serie } from '@nivo/line';
import { Square } from '@phosphor-icons/react';
import { format } from 'friendly-numbers';
import { generateYTicksCustomInterval } from 'utils/Measurement/generateYTicksArray';
import { theme } from '@klover/attain-design-system';
import * as ChartStyles from './MeasurementChart.styles';
import * as Styled from './MeasurmentIncrementalityChart.styles';
import * as d3 from 'd3-shape';

/**
 * Fallback lower bound of the lift axis (values are rendered with a `%`
 * suffix). Used as the lowest y when the lower CI bound has no points, and
 * passed to `generateYTicksCustomInterval`.
 */
const LOWYAXIS = -2;

/**
 * Fallback upper bound of the lift axis. Used as the highest y when the upper
 * CI bound has no points, and passed to `generateYTicksCustomInterval`.
 */
const HIGHYAXIS = 2;

/**
 * Interval argument for `generateYTicksCustomInterval` (presumably the tick
 * step — confirm against that helper).
 */
const AXISINTERVAL = 2;

/** Props accepted by the incrementality-with-impressions chart. */
interface Props {
  /**
   * Lift series to plot (may be null). Besides its points, the chart also
   * reads `obfuscateToData` and `midDate` from it.
   */
  data: IncrementalityDataChart | null;
  /**
   * Impressions series; only the first entry is used, and its points supply
   * `x` (date) and `value` (impressions count).
   */
  dataImpressions?: dataChart[];
  /** Confidence-interval bounds, upper first and lower second. */
  ciBounds: [
    upper: IncrementalityDataChart | null,
    lower: IncrementalityDataChart | null,
  ];
}

/**
 * Builds a lookup from each point's `x` (date string) to its `y` value.
 * When the same `x` appears more than once, the last occurrence wins.
 */
function createDataMap(
  points: { x: string; y: number }[]
): Map<string, number> {
  const map = new Map<string, number>();
  for (const point of points) {
    map.set(point.x, point.y);
  }
  return map;
}

/**
 * Forward-fills gaps in a series: every null/undefined `y` that comes after at
 * least one defined `y` is replaced by the most recent defined value. Leading
 * gaps stay null. Filled entries are new objects; the input is not modified.
 */
function interpolateData<T extends { y: number | null }>(points: T[]): T[] {
  let lastValidY: number | null = null;
  return points.map((point) => {
    if (point.y !== null && point.y !== undefined) {
      lastValidY = point.y;
      return point;
    }
    return lastValidY === null ? point : { ...point, y: lastValidY };
  });
}

/**
 * Line chart of incrementality (lift, rendered with a `%` suffix) over time,
 * optionally wrapped in a shaded confidence-interval band, with daily
 * impressions drawn as bars along the bottom and labelled on a separate
 * right-hand axis.
 *
 * @param data - lift series. `data.obfuscateToData` (when set) marks the date up
 *   to which the plot is covered by a shaded overlay and the CI band is not
 *   drawn; `data.midDate` is shown as the middle x-axis label.
 * @param dataImpressions - only the first entry is used; its points provide
 *   `x` (date) and `value` (impressions).
 * @param ciBounds - `[upper, lower]` bounds. They always size the lift axis but
 *   are only plotted when the `ui_incrementality_chart_ci` flag is enabled.
 */
export const MeasurementIncrementalityChartImpressionsAdded = ({
  data,
  dataImpressions,
  ciBounds,
}: Props) => {
  const lineTheme = {
    axis: {
      ticks: {
        text: {
          fontSize: 14,
          fill: theme.colors.textBrand, // Default color for other axes
        },
      },
    },
    grid: {
      line: {
        stroke: theme.colors.borderLight,
      },
    },
  };
  const hasCiCharting = useUiFlag('ui_incrementality_chart_ci');
  const showCiChart = hasCiCharting.isReady && hasCiCharting.enabled;
  const chartRef = React.useRef<HTMLDivElement>(null);

  const upperBound = ciBounds?.at(0);
  const lowerBound = ciBounds?.at(1);

  // CI series are only plotted when the feature flag is on.
  const upperData = showCiChart && upperBound?.data ? upperBound.data : [];
  const lowerData = showCiChart && lowerBound?.data ? lowerBound.data : [];

  // The lift axis is sized from the CI bounds regardless of the flag
  // (non-numeric y counts as 0); empty bounds fall back to the defaults.
  const highYvals =
    upperBound?.data?.map((i) => (typeof i.y === 'number' ? i.y : 0)) ?? [];
  const lowYvals =
    lowerBound?.data?.map((i) => (typeof i.y === 'number' ? i.y : 0)) ?? [];

  const highestY = highYvals.length ? Math.max(...highYvals) : HIGHYAXIS;
  const lowestY = lowYvals.length ? Math.min(...lowYvals) : LOWYAXIS;

  const ticksY: number[] = generateYTicksCustomInterval(
    AXISINTERVAL,
    lowestY,
    highestY,
    LOWYAXIS,
    HIGHYAXIS,
    0.1
  );

  // Final y-domain: wide enough for both the CI extremes and every tick.
  // Math.min/Math.max keep this independent of tick order, and a 0 tick is
  // treated as a real bound rather than a missing one.
  const tickMin = ticksY.length ? Math.min(...ticksY) : LOWYAXIS;
  const tickMax = ticksY.length ? Math.max(...ticksY) : HIGHYAXIS;
  const yScaleMin = Math.min(lowestY, tickMin);
  const yScaleMax = Math.max(highestY, tickMax);

  const liftPoints = (data?.data as { x: string; y: number }[]) ?? [];
  const upperPoints = upperData as { x: string; y: number }[];
  const lowerPoints = lowerData as { x: string; y: number }[];
  const impressionPoints =
    dataImpressions?.at(0)?.data?.map((d) => ({ x: d.x, y: d.value })) ?? [];

  // Union of every date present in any series, in chronological order
  // (assumes the x values are parseable by `Date`).
  const allDatesSet = new Set<string>();
  for (const points of [
    liftPoints,
    upperPoints,
    lowerPoints,
    impressionPoints,
  ]) {
    for (const point of points) {
      allDatesSet.add(point.x);
    }
  }
  const allDates = Array.from(allDatesSet).sort(
    (a, b) => new Date(a).getTime() - new Date(b).getTime()
  );

  const incrementalityMap = createDataMap(liftPoints);
  const upperCiMap = createDataMap(upperPoints);
  const lowerCiMap = createDataMap(lowerPoints);
  const impressionsMap = createDataMap(impressionPoints);

  const impressionsYvals = Array.from(impressionsMap.values());
  const impressionsMin = 0;
  // Falls back to 1 when there are no impressions at all.
  const impressionsMax =
    impressionsYvals.length > 0 ? Math.max(...impressionsYvals) : 1;

  // Linearly maps an impressions count onto the lift-axis range
  // (0 -> lowestY, impressionsMax -> highestY).
  const scaleImpressionsToY = (impressionValue: number) => {
    if (impressionsMax === impressionsMin) {
      return lowestY;
    }
    return (
      ((impressionValue - impressionsMin) / (impressionsMax - impressionsMin)) *
        (highestY - lowestY) +
      lowestY
    );
  };

  // Align every series to the shared date axis; missing values become null.
  const completeIncrementalityData = allDates.map((date) => {
    const upperY = upperCiMap.get(date) ?? null;
    const lowerY = lowerCiMap.get(date) ?? null;
    return {
      x: date,
      y: incrementalityMap.get(date) ?? null,
      // Half-width of the CI at this date, shown as "+/-" in the tooltip.
      ci: upperY !== null && lowerY !== null ? (upperY - lowerY) / 2 : null,
    };
  });

  const completeUpperData = allDates.map((date) => ({
    x: date,
    y: upperCiMap.get(date) ?? null,
  }));

  const completeLowerData = allDates.map((date) => ({
    x: date,
    y: lowerCiMap.get(date) ?? null,
  }));

  const completeImpressionsData = allDates.map((date) => {
    const originalY = impressionsMap.get(date);
    const isMissing = originalY === undefined || originalY === null;
    return {
      x: date,
      // null breaks the line for dates without impressions.
      y: isMissing ? null : scaleImpressionsToY(originalY),
      originalY, // raw count, used by the tooltip and the bar heights
      isMissing,
    };
  });

  // One entry per plotted series; `colors` below is in the same order.
  const chartData: Serie[] = [
    { id: 'Incrementality', data: interpolateData(completeIncrementalityData) },
    { id: 'Upper ci', data: interpolateData(completeUpperData) },
    { id: 'Lower ci', data: interpolateData(completeLowerData) },
    // Rendered transparent; the visible bars come from ImpressionsBarLayer.
    // Presumably kept so slices/tooltips also cover impression-only dates.
    { id: 'Daily Impressions', data: completeImpressionsData },
  ];

  const colors = [
    theme.colors.CHART_PRIMARY, // 'Incrementality'
    'hsla(254, 29%, 48%, 0.1)', // 'Upper ci'
    'hsla(254, 29%, 48%, 0.1)', // 'Lower ci'
    'transparent', // 'Daily Impressions'
  ];

  // Shades the plot from the left edge up to `data.obfuscateToData`.
  const ObfuscatedLayer = ({ xScale, innerHeight }) => {
    const width = data?.obfuscateToData
      ? xScale(data.obfuscateToData) ?? 0
      : 0;
    return (
      <rect
        x={0}
        y={0}
        width={width}
        height={innerHeight}
        fillOpacity={0.3}
        fill={theme.colors.backgroundBrand}
      />
    );
  };

  // Fills the band between the 'Upper ci' and 'Lower ci' series, starting at
  // `data.obfuscateToData` when that date is present in the series.
  const CiLayer = ({ series, yScale }) => {
    const upperSerie = series.find((s) => s.id === 'Upper ci');
    const lowerSerie = series.find((s) => s.id === 'Lower ci');
    if (!upperSerie || !lowerSerie) {
      return null;
    }

    const obfuscateIdx = upperSerie.data.findIndex(
      (point) => point.data.x === data?.obfuscateToData
    );
    const startIdx = Math.max(obfuscateIdx, 0);

    // NOTE(review): this reassigns `data` on the series objects nivo passes in.
    // Because this layer is listed before 'lines', the built-in faint CI lines
    // appear to be trimmed the same way. Replacing this with local copies would
    // draw those lines over the obfuscated range — confirm before changing.
    upperSerie.data = upperSerie.data.slice(startIdx);
    lowerSerie.data = lowerSerie.data.slice(startIdx);

    if (!showCiChart) {
      return null;
    }

    const areaGenerator = d3
      .area()
      // Leave a gap where either bound is null; yScale cannot place it.
      .defined((d: any) => d.yTop != null && d.yBottom != null)
      .x((d: any) => d.position.x)
      .y0((d: any) => yScale(d.yBottom))
      .y1((d: any) => yScale(d.yTop))
      .curve(d3.curveMonotoneX);
    const areaData = upperSerie.data.map((point, i) => ({
      ...point,
      yBottom: lowerSerie.data[i]?.data.y,
      yTop: point.data.y,
    }));

    const path = areaGenerator(areaData);
    return (
      <path
        d={path ?? undefined}
        fill="hsla(254, 29%, 48%, 0.15)"
        stroke="none"
      />
    );
  };

  // The built-in right axis shows no labels; Styled.RightAxis renders them.
  const formatTick = () => {
    return '';
  };

  // Enable point markers when at most one lift point remains visible after the
  // obfuscated range (a single point draws no visible line). A missing
  // obfuscation date (findIndex -> -1) counts as index 0.
  const lastLiftIdx = (data?.data?.length ?? 0) - 1;
  const fowIdx = Math.max(
    data?.data?.findIndex((datum) => datum.x === data.obfuscateToData) ?? -1,
    0
  );
  const isPoint = lastLiftIdx <= fowIdx;

  // Draws one bar per date that has impressions, anchored to the bottom edge.
  const ImpressionsBarLayer = ({ xScale, innerHeight, innerWidth }) => {
    const barWidth = innerWidth / completeImpressionsData.length / 3;
    return (
      <g id="bars">
        {completeImpressionsData.map((datum, index) => {
          if (datum.y === null || datum.y === undefined) {
            return null; // no impressions for this date
          }
          // The largest day reaches 25% of the plot height. Height depends only
          // on the raw count: the scaled `y` can be exactly 0 for a non-zero
          // count, so it must not be used to suppress a bar.
          const barHeight = datum.originalY
            ? innerHeight * (datum.originalY / impressionsMax) * 0.25
            : 0;
          const x = xScale(datum.x);

          return (
            <rect
              key={index}
              x={x - barWidth / 2}
              y={innerHeight - barHeight}
              width={barWidth}
              height={barHeight}
              fill={theme.colors.CHART_TERTIARY}
              opacity={0.5}
            />
          );
        })}
      </g>
    );
  };

  // Outer x-axis labels: the earliest and latest valid dates across the lift
  // and impressions series (CI-only dates are not considered).
  const validDates = [
    ...liftPoints.map((d) => d.x),
    ...impressionPoints.map((d) => d.x),
  ]
    .filter((d) => !isNaN(new Date(d).getTime()))
    .sort((a, b) => new Date(a).getTime() - new Date(b).getTime());
  const startDate = validDates.at(0);
  const endDate = validDates.at(-1);

  return (
    <Styled.Wrapper ref={chartRef}>
      <ChartStyles.BottomAxisIncrementality>
        <span>{startDate}</span>
        <span>{data?.midDate}</span>
        <span>{endDate}</span>
      </ChartStyles.BottomAxisIncrementality>

      <Styled.ChartWrapper>
        <ResponsiveLine
          isInteractive={true}
          axisBottom={null}
          axisLeft={{
            tickSize: 0,
            tickPadding: 18,
            format: (val) => val + '%',
            tickValues: [...ticksY],
            legend: 'LIFT',
            legendPosition: 'middle',
            legendOffset: -60,
          }}
          axisRight={{
            tickSize: 0,
            tickPadding: 18,
            format: formatTick,
            legend: 'DAILY IMPRESSIONS',
            legendPosition: 'middle',
            legendOffset: 70,
          }}
          theme={lineTheme}
          colors={colors}
          data={chartData}
          enableGridX={false}
          enableGridY={true}
          enableSlices="x"
          gridYValues={[...ticksY]}
          curve="monotoneX"
          yScale={{
            type: 'linear',
            min: yScaleMin,
            max: yScaleMax,
            stacked: false,
            reverse: false,
          }}
          margin={{
            top: 7,
            right: 75,
            bottom: 7,
            left: 65,
          }}
          pointSize={isPoint ? 8 : 0}
          sliceTooltip={({ slice }) => {
            // Points for this x-position, looked up by series id.
            const incPoint = slice.points.find(
              (point) => point.serieId === 'Incrementality'
            );
            const upperPoint = slice.points.find(
              (point) => point.serieId === 'Upper ci'
            );

            const date = slice.points[0].data.xFormatted;

            // Raw (unscaled) impressions for this date.
            const impressionsDatum = completeImpressionsData.find(
              (datum) => datum.x === date
            );

            const incY =
              incPoint && incPoint.data && incPoint.data.y != null
                ? Number(incPoint.data.y)
                : 0;
            const percentage = incY.toFixed(2);

            const ci = incPoint?.data?.['ci'] ?? null;

            const impressionValue = impressionsDatum
              ? impressionsDatum.originalY
              : null;
            const isMissing = impressionsDatum?.isMissing ?? true;

            return (
              <ChartStyles.Tooltip>
                <ChartStyles.LabelWrapper>
                  <ChartStyles.Label>{date}</ChartStyles.Label>
                </ChartStyles.LabelWrapper>
                {incPoint && (
                  <ChartStyles.LabelWrapper>
                    <Square size={18} fill={incPoint.color} weight="fill" />
                    Lift:
                    <ChartStyles.Label>{`${percentage}%`}</ChartStyles.Label>
                  </ChartStyles.LabelWrapper>
                )}
                {showCiChart && ci != null && (
                  <ChartStyles.LabelWrapper>
                    <Square size={18} fill={upperPoint?.color} weight="fill" />
                    CI delta:
                    <ChartStyles.Label>{`+/- ${ci.toFixed(
                      2
                    )}%`}</ChartStyles.Label>
                  </ChartStyles.LabelWrapper>
                )}

                <ChartStyles.LabelWrapper>
                  <Square
                    size={18}
                    fill={theme.colors.CHART_TERTIARY}
                    weight="fill"
                  />
                  Daily Impressions:
                  <ChartStyles.Label>
                    {isMissing ? '0' : Number(impressionValue).toLocaleString()}
                  </ChartStyles.Label>
                </ChartStyles.LabelWrapper>
              </ChartStyles.Tooltip>
            );
          }}
          layers={[
            'grid',
            'markers',
            'axes',
            'areas',
            'crosshair',
            ObfuscatedLayer as Layer,
            CiLayer as Layer,
            ImpressionsBarLayer as Layer,
            'lines',
            'slices',
            'points',
            'mesh',
            'legends',
          ]}
        />

        <Styled.RightAxis>
          {/* Five labels, top to bottom: 4x, 3x, 2x, 1x and 0x the daily
              maximum. Bars put the maximum at 25% of the plot height, so these
              line up if RightAxis spaces its children evenly (confirm in styles). */}
          {Array.from({ length: 5 }, (_, i) => (
            <div key={i}>{format(impressionsMax * (4 - i))}</div>
          ))}
        </Styled.RightAxis>
      </Styled.ChartWrapper>
    </Styled.Wrapper>
  );
};

// Default export of the chart component; the named export above is equivalent.
export { MeasurementIncrementalityChartImpressionsAdded as default };
