"""Inference utilities for audio source separation."""

import numpy as np
import torch
import torch.nn as nn
from types import SimpleNamespace
from tqdm.auto import tqdm


# -----------------------------
# Audio utility functions
# -----------------------------
def normalize_audio(audio: np.ndarray):
    """Standardize ``audio`` using statistics of its channel-averaged signal.

    The mean and standard deviation are computed on ``audio.mean(0)`` (the
    average over axis 0, presumably the channel axis), and the same scalar
    shift/scale is applied to every channel. Individual channels are therefore
    not guaranteed to have exactly zero mean / unit variance.

    Args:
        audio: Audio array; axis 0 is averaged to obtain the reference signal.

    Returns:
        A tuple ``(normalized_audio, params)``. ``params`` is a dict with keys
        ``"mean"`` and ``"std"``, suitable for ``denormalize_audio``. A zero
        standard deviation is replaced by ``1.0`` to avoid division by zero.
    """
    reference = audio.mean(0)
    center = reference.mean()
    scale = reference.std()
    if scale == 0:
        # Silent / constant input: fall back to a unit scale.
        scale = 1.0
    return (audio - center) / scale, {"mean": center, "std": scale}


def denormalize_audio(audio: np.ndarray, norm_params):
    """Undo the shift/scale applied by ``normalize_audio``.

    Args:
        audio: Normalized audio array.
        norm_params: Mapping with ``"mean"`` and ``"std"`` entries, as returned
            by ``normalize_audio``.

    Returns:
        ``audio * norm_params["std"] + norm_params["mean"]``.
    """
    scale = norm_params["std"]
    offset = norm_params["mean"]
    return audio * scale + offset


# -----------------------------
# Inference functions
# -----------------------------
def _get_windowing_array(window_size: int, fade_size: int) -> torch.Tensor:
    """Create a window of ones with linear fade-in and fade-out edges.

    Args:
        window_size: Total number of samples in the window.
        fade_size: Length of each linear ramp (0 -> 1 at the start,
            1 -> 0 at the end). A non-positive value yields a window of
            all ones (no fades).

    Returns:
        A 1-D float tensor of length ``window_size``.
    """
    window = torch.ones(window_size)
    if fade_size <= 0:
        # ``window[-0:]`` would select the whole tensor and fail to accept a
        # size-0 ramp, so a zero fade must be handled explicitly.
        return window
    window[-fade_size:] = torch.linspace(1, 0, fade_size)
    window[:fade_size] = torch.linspace(0, 1, fade_size)
    return window


def prefer_target_instrument(config) -> list:
    """Return the instrument name(s) the model produces.

    If ``config.training.target_instrument`` is set (truthy), a single-element
    list with that name is returned. Otherwise ``config.training.instruments``
    is returned as a list. A missing ``training`` section or a missing/None
    ``instruments`` entry yields an empty list instead of raising.
    """
    training = getattr(config, 'training', SimpleNamespace())
    target = getattr(training, 'target_instrument', None)
    if target:
        return [target]
    return list(getattr(training, 'instruments', None) or [])


def demix(config, model: torch.nn.Module, mix: np.ndarray, device: str, model_type: str, pbar: bool = False):
    """Separate audio into stems using chunked overlap-add inference.

    The mixture is cut into windows of ``chunk_size`` samples taken every
    ``step = chunk_size // num_overlap`` samples. Windows are run through
    ``model`` in batches, weighted by a linear fade-in/fade-out window, and
    accumulated. The sum is then divided by the accumulated weights. When the
    input is long enough, it is reflect-padded by ``chunk_size - step`` on
    both sides, and that padding is cropped from the output.

    Args:
        config: Model configuration. Reads ``inference.chunk_size`` (falling
            back to ``audio.chunk_size``, then to the input length),
            ``inference.num_overlap``, ``inference.batch_size`` and
            ``training.use_amp`` (default True).
        model: Trained model, called on a ``(batch, channels, chunk_size)``
            tensor on ``device``. Its per-item output must broadcast to
            ``(num_instruments, channels, samples)`` — presumably
            ``(batch, num_instruments, channels, chunk_size)``; confirm
            against the model.
        mix: Input audio as numpy array (channels, samples).
        device: Device the chunks are moved to before calling ``model``.
        model_type: Type of model (e.g. 'mel_band_roformer'); not used here.
        pbar: Whether to show a tqdm progress bar.

    Returns:
        Dictionary mapping instrument names (from
        ``prefer_target_instrument``) to separated numpy audio arrays.
    """
    mix = torch.tensor(mix, dtype=torch.float32)

    # Resolve chunk_size lazily so that a config without an ``audio`` section
    # still works when ``inference.chunk_size`` is provided.
    chunk_size = getattr(config.inference, 'chunk_size', None)
    if chunk_size is None:
        chunk_size = getattr(getattr(config, 'audio', None), 'chunk_size', mix.shape[-1])

    instruments = prefer_target_instrument(config)
    num_instruments = len(instruments)
    num_overlap = config.inference.num_overlap
    fade_size = chunk_size // 10
    # A zero step (num_overlap > chunk_size) would never advance the loop.
    step = max(1, chunk_size // num_overlap)
    border = chunk_size - step
    length_init = mix.shape[-1]
    windowing_array = _get_windowing_array(chunk_size, fade_size)

    use_border = length_init > 2 * border and border > 0
    if use_border:
        mix = nn.functional.pad(mix, (border, border), mode="reflect")
    total_len = mix.shape[1]

    batch_size = config.inference.batch_size
    use_amp = getattr(getattr(config, 'training', None), 'use_amp', True)

    with torch.cuda.amp.autocast(enabled=use_amp):
        with torch.inference_mode():
            req_shape = (num_instruments,) + tuple(mix.shape)
            result = torch.zeros(req_shape, dtype=torch.float32)
            counter = torch.zeros(req_shape, dtype=torch.float32)

            batch_data = []
            batch_locations = []
            progress_bar = tqdm(total=total_len, desc="Processing audio chunks", leave=False) if pbar else None

            i = 0
            while i < total_len:
                part = mix[:, i:i + chunk_size].to(device)
                chunk_len = part.shape[-1]
                # Reflect padding needs pad < input length; short tails are zero-padded.
                pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
                part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0)
                batch_data.append(part)
                batch_locations.append((i, chunk_len))
                i += step

                if len(batch_data) >= batch_size or i >= total_len:
                    arr = torch.stack(batch_data, dim=0)
                    x = model(arr)
                    for j, (start, seg_len) in enumerate(batch_locations):
                        # Edge handling is decided per chunk (not per batch) so it
                        # is correct for any batch_size.
                        window = windowing_array.clone()
                        if fade_size > 0:
                            if start == 0:
                                # First chunk: nothing precedes it, so no fade-in.
                                window[:fade_size] = 1
                            if start + step >= total_len:
                                # Last chunk: nothing follows it, so no fade-out.
                                window[-fade_size:] = 1
                        result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() * window[:seg_len]
                        counter[..., start:start + seg_len] += window[:seg_len]
                    batch_data.clear()
                    batch_locations.clear()

                if progress_bar is not None:
                    # Clamp so the bar never exceeds its total on the final chunk.
                    progress_bar.update(min(step, total_len - (i - step)))
            if progress_bar is not None:
                progress_bar.close()

            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            # Samples whose accumulated window weight is zero produce 0/0 = NaN.
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)
            if use_border:
                estimated_sources = estimated_sources[..., border:-border]

    return dict(zip(instruments, estimated_sources))


def apply_tta(config, model: torch.nn.Module, mix: np.ndarray, waveforms_orig: dict, device: str, model_type: str):
    """Average separation results over two test-time augmentations.

    Two augmented copies of ``mix`` are separated with ``demix``:
      1. ``mix[::-1]`` — axis 0 reversed (channel order swapped for a
         (channels, samples) input); each output is reversed back on axis 0.
      2. ``-mix`` — polarity inverted; each output is negated back.
    The un-augmented results are added to ``waveforms_orig``, and every entry
    is divided by 3 (original pass plus two augmented passes).

    Note: ``waveforms_orig`` and its arrays are modified in place and the same
    dict is returned.

    Args:
        config: Model configuration (forwarded to ``demix``).
        model: Trained model.
        mix: Input audio.
        waveforms_orig: Separated waveforms from the un-augmented pass,
            keyed by instrument name.
        device: Device to run inference on.
        model_type: Type of model (forwarded to ``demix``).

    Returns:
        ``waveforms_orig`` holding the TTA-averaged waveforms.
    """
    channel_flipped = mix[::-1].copy()
    polarity_inverted = -1.0 * mix.copy()
    num_passes = 2 + 1  # two augmented passes plus the original

    flipped_out = demix(config, model, channel_flipped, device, model_type=model_type)
    for name, wave in flipped_out.items():
        waveforms_orig[name] += wave[::-1].copy()

    inverted_out = demix(config, model, polarity_inverted, device, model_type=model_type)
    for name, wave in inverted_out.items():
        waveforms_orig[name] -= wave

    for name in waveforms_orig:
        waveforms_orig[name] /= num_passes
    return waveforms_orig

