import {useState, useRef, useCallback, useEffect} from 'react';
import {
  TranscribeStreamingClient,
  StartStreamTranscriptionCommand,
  StartStreamTranscriptionCommandOutput,
} from '@aws-sdk/client-transcribe-streaming';
import {CognitoIdentityClient} from '@aws-sdk/client-cognito-identity';
import {fromCognitoIdentityPool} from '@aws-sdk/credential-provider-cognito-identity';
import {AWS_REGION, AWS_TRANSCRIBE_IDENTITY_POOL_ID} from '@/config';

/** Sample rate (Hz) requested when constructing the AudioContext. */
const SAMPLE_RATE = 16_000;
/** Delay in milliseconds, with no transcript change, before transcription auto-stops. */
const INACTIVITY_TIMEOUT = 15_000;

/**
 * React hook that streams microphone audio to Amazon Transcribe Streaming
 * (`en-US`, PCM) and exposes the running transcript.
 *
 * Audio is captured with `getUserMedia`, tapped through an inline
 * AudioWorklet, converted to 16-bit PCM and sent to Transcribe using
 * credentials from the configured Cognito identity pool.
 *
 * A session ends when `stopTranscription` is called, when the transcript has
 * not changed for `INACTIVITY_TIMEOUT` ms, when the Transcribe stream ends or
 * errors, or when the owning component unmounts. In every case the
 * microphone tracks and the audio graph are released.
 *
 * @returns `transcription` (all final results followed by the current partial
 *   result), `isTranscribing`, and the `startTranscription` /
 *   `stopTranscription` controls.
 */
export const useAwsTranscribe = () => {
  const [transcription, setTranscription] = useState<string>('');
  const [isTranscribing, setIsTranscribing] = useState<boolean>(false);

  const isTranscribingRef = useRef<boolean>(false);
  // Bumped on every start so a previous session that is still draining its
  // stream can tell it has been superseded and must leave state/refs alone.
  const sessionIdRef = useRef<number>(0);
  const audioContextRef = useRef<AudioContext | null>(null);
  const sourceRef = useRef<MediaStreamAudioSourceNode | null>(null);
  const processorRef = useRef<AudioWorkletNode | null>(null);
  const audioStreamRef = useRef<MediaStream | null>(null);

  /**
   * Yields queued PCM chunks as Transcribe `AudioEvent`s, polling every 10 ms
   * while the queue is empty, until the session is inactive and the queue has
   * been drained.
   */
  async function* audioGenerator(
    pcmQueue: Uint8Array[],
    isActive: () => boolean,
  ) {
    while (isActive() || pcmQueue.length > 0) {
      const chunk = pcmQueue.shift();
      if (chunk !== undefined) {
        yield {AudioEvent: {AudioChunk: chunk}};
      } else {
        await new Promise(resolve => setTimeout(resolve, 10));
      }
    }
  }

  /**
   * Ends the current session and releases the audio graph and microphone.
   * Refs are cleared as they are released, so calling this repeatedly (or
   * when nothing is running) is safe.
   */
  const stopTranscription = useCallback(() => {
    setIsTranscribing(false);
    isTranscribingRef.current = false;

    const processor = processorRef.current;
    processorRef.current = null;
    if (processor) {
      processor.port.onmessage = null;
      processor.port.close();
      processor.disconnect();
    }

    sourceRef.current?.disconnect();
    sourceRef.current = null;

    const audioContext = audioContextRef.current;
    audioContextRef.current = null;
    // close() returns a promise that rejects on an already-closed context;
    // guard against that and never leave the rejection unhandled.
    if (audioContext && audioContext.state !== 'closed') {
      audioContext.close().catch(err => {
        console.error('Error closing audio context:', err);
      });
    }

    audioStreamRef.current?.getTracks().forEach(track => track.stop());
    audioStreamRef.current = null;
  }, []);

  /**
   * Starts a new session: clears the transcript, opens the microphone, wires
   * the AudioWorklet, and consumes Transcribe results until stopped. Ignored
   * while a session is already running.
   */
  const startTranscription = useCallback(async () => {
    // A second concurrent session would overwrite the refs and leak the first
    // microphone stream.
    if (isTranscribingRef.current) {
      return;
    }

    const sessionId = ++sessionIdRef.current;
    const isActive = (): boolean =>
      isTranscribingRef.current && sessionIdRef.current === sessionId;

    setTranscription('');
    setIsTranscribing(true);
    isTranscribingRef.current = true;

    try {
      const audioStream: MediaStream =
        await navigator.mediaDevices.getUserMedia({
          audio: true,
        });
      if (!isActive()) {
        // Stopped while the permission prompt was pending; this stream was
        // never stored in a ref, so release it here.
        audioStream.getTracks().forEach(track => track.stop());
        return;
      }
      audioStreamRef.current = audioStream;

      const audioContext: AudioContext = new (window.AudioContext ||
        // @ts-expect-error it's a WebkitAudioContext
        window.webkitAudioContext)({sampleRate: SAMPLE_RATE});
      audioContextRef.current = audioContext;

      if (audioContext.sampleRate !== SAMPLE_RATE) {
        console.warn(
          `Sample rate is ${audioContext.sampleRate}, expected ${SAMPLE_RATE}.`,
        );
      }

      // Load the AudioWorklet module from an in-memory blob.
      const processorBlob = new Blob(
        [
          `
          class AudioProcessor extends AudioWorkletProcessor {
            process(inputs, outputs, parameters) {
              const input = inputs[0];
              if (input && input.length > 0) {
                const channelData = input[0];
                this.port.postMessage(channelData);
              }
              return true;
            }
          }

          registerProcessor('audio-processor', AudioProcessor);
          `,
        ],
        {type: 'application/javascript'},
      );
      const processorUrl = URL.createObjectURL(processorBlob);
      try {
        await audioContext.audioWorklet.addModule(processorUrl);
      } finally {
        // The module is fetched by addModule; the URL is not needed after.
        URL.revokeObjectURL(processorUrl);
      }
      if (!isActive()) {
        // stopTranscription already released the stream and context.
        return;
      }

      const audioWorkletNode = new AudioWorkletNode(
        audioContext,
        'audio-processor',
      );
      processorRef.current = audioWorkletNode;

      const source = audioContext.createMediaStreamSource(audioStream);
      sourceRef.current = source;

      source.connect(audioWorkletNode);
      audioWorkletNode.connect(audioContext.destination);

      const pcmQueue: Uint8Array[] = [];

      audioWorkletNode.port.onmessage = (
        event: MessageEvent<Float32Array>,
      ) => {
        if (isActive()) {
          pcmQueue.push(convertFloat32ToInt16(event.data));
        }
      };

      const transcribeClient = new TranscribeStreamingClient({
        region: AWS_REGION,
        credentials: fromCognitoIdentityPool({
          client: new CognitoIdentityClient({region: AWS_REGION}),
          identityPoolId: AWS_TRANSCRIBE_IDENTITY_POOL_ID,
        }),
      });

      const command = new StartStreamTranscriptionCommand({
        LanguageCode: 'en-US',
        MediaEncoding: 'pcm',
        MediaSampleRateHertz: audioContext.sampleRate,
        AudioStream: audioGenerator(pcmQueue, isActive),
      });

      const response: StartStreamTranscriptionCommandOutput =
        await transcribeClient.send(command);
      let transcriptSoFar = '';

      if (response.TranscriptResultStream) {
        for await (const event of response.TranscriptResultStream) {
          if (!isActive()) {
            break;
          }

          const result = event.TranscriptEvent?.Transcript?.Results?.[0];
          if (!result) {
            continue;
          }
          // Alternatives may be an empty array; guard the element access too.
          const transcript = result.Alternatives?.[0]?.Transcript ?? '';

          if (result.IsPartial) {
            // Partial result: show it after the committed text, don't commit.
            setTranscription(transcriptSoFar + transcript);
          } else {
            // Final result: commit it.
            transcriptSoFar += transcript + ' ';
            setTranscription(transcriptSoFar);
          }
        }
      }
    } catch (err: unknown) {
      console.error('Error starting transcription:', err);
    } finally {
      // The session is over (stopped, stream ended, or failed). Release
      // resources unless a newer session already owns the shared refs.
      if (sessionIdRef.current === sessionId) {
        stopTranscription();
      }
    }
  }, [stopTranscription]);

  // Auto-stop after INACTIVITY_TIMEOUT ms without a transcript change.
  // `transcription` is a dependency on purpose: every update restarts the timer.
  useEffect(() => {
    if (!isTranscribing) {
      return undefined;
    }
    const timeout = setTimeout(() => stopTranscription(), INACTIVITY_TIMEOUT);
    return () => clearTimeout(timeout);
  }, [isTranscribing, transcription, stopTranscription]);

  // Release the microphone and audio graph if the owning component unmounts
  // while a session is running.
  useEffect(() => () => stopTranscription(), [stopTranscription]);

  return {transcription, isTranscribing, startTranscription, stopTranscription};
};

/**
 * Converts Web Audio float samples to signed 16-bit little-endian PCM bytes,
 * the byte layout Amazon Transcribe expects for `MediaEncoding: 'pcm'`.
 *
 * Samples are clamped to [-1, 1]. Negative values scale by 0x8000 and
 * positive values by 0x7fff, so both ends map onto the full Int16 range.
 * Fractional results are truncated toward zero.
 *
 * @param buffer Float samples, nominally in [-1, 1].
 * @returns A new byte array of length `buffer.length * 2`.
 */
function convertFloat32ToInt16(buffer: Float32Array): Uint8Array {
  const bytes = new Uint8Array(buffer.length * 2);
  // DataView pins the byte order to little-endian. A raw Int16Array view
  // would use the host's native endianness instead.
  const view = new DataView(bytes.buffer);
  buffer.forEach((sample, i) => {
    const s = Math.max(-1, Math.min(1, sample));
    view.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7fff, true);
  });
  return bytes;
}
