"""Command-line interface for roformer-separation."""

import argparse
import os
import glob
import time
from pathlib import Path

import torch
import librosa
import soundfile as sf
import numpy as np
from tqdm.auto import tqdm

from roformer_separation.download import get_model_paths
from roformer_separation.config import load_config_from_yaml
from roformer_separation.model import MelBandRoformer
from roformer_separation.inference import demix, prefer_target_instrument


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. This is useful for programmatic use and tests.

    Returns:
        argparse.Namespace holding the parsed options.

    The hidden ``--checkpoint`` and ``--config`` options must be supplied
    together. Supplying only one of them is reported as a usage error rather
    than being silently ignored in favour of the auto-downloaded model.
    """
    parser = argparse.ArgumentParser(
        description="Separate audio sources using Roformer models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Separate vocals from audio files in a folder
  roformer-separate --instrument vocals --input_folder /path/to/audio --store_dir /path/to/output
  
  # Separate guitar from a single audio file
  roformer-separate --instrument guitar --input_folder /path/to/song.wav --store_dir /path/to/output
"""
    )

    parser.add_argument(
        "--instrument",
        type=str,
        required=True,
        choices=["vocals", "guitar"],
        help="Type of instrument to separate"
    )

    parser.add_argument(
        "--input_folder",
        type=str,
        required=True,
        help="Path to audio file or folder containing audio files"
    )

    parser.add_argument(
        "--store_dir",
        type=str,
        required=True,
        help="Directory to store separated audio outputs"
    )

    # Hidden arguments for advanced users; they only take effect as a pair.
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=argparse.SUPPRESS  # Hidden from help
    )

    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help=argparse.SUPPRESS  # Hidden from help
    )

    parser.add_argument(
        "--force_cpu",
        action="store_true",
        help="Force CPU usage even if GPU is available"
    )

    parser.add_argument(
        "--extract_instrumental",
        action="store_true",
        help="Also extract instrumental track (original minus separated instrument)"
    )

    parser.add_argument(
        "--use_tta",
        action="store_true",
        help="Use test-time augmentation for better quality (slower)"
    )

    parser.add_argument(
        "--output_format",
        type=str,
        choices=["wav", "flac"],
        default="wav",
        help="Output audio format (default: wav)"
    )

    args = parser.parse_args(argv)

    # Truthiness matches how the caller decides whether custom paths were given.
    # A lone --checkpoint or --config would otherwise be silently discarded.
    if bool(args.checkpoint) != bool(args.config):
        parser.error("--checkpoint and --config must be provided together")

    return args


def initialize_device(force_cpu: bool = False):
    """Choose the torch device string used for inference.

    Args:
        force_cpu: When True, return "cpu" immediately without probing or
            printing anything.

    Returns:
        "cuda:0" if CUDA is available, otherwise "mps" if the Apple MPS
        backend is present and available, otherwise "cpu". A one-line notice
        describing the choice is printed whenever probing happens.
    """
    if force_cpu:
        return "cpu"

    chosen, notice = "cpu", 'No GPU available, using CPU.'
    if torch.cuda.is_available():
        chosen, notice = 'cuda:0', 'CUDA is available, using GPU.'
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        # Older torch builds have no `mps` backend attribute at all.
        chosen, notice = 'mps', 'MPS is available, using Apple Silicon GPU.'

    print(notice)
    return chosen


def get_audio_files(input_path: str) -> list:
    """Get list of audio files from input path.

    A single file is accepted when its extension is a known audio format.
    A directory is scanned non-recursively. Its regular files whose extension
    matches case-insensitively (``.wav``, ``.WAV``, ``.Wav`` ...) are returned
    exactly once each, sorted by path. Hidden entries (names starting with
    ``.``) are skipped, as the previous glob-based scan did.

    Args:
        input_path: Path to a file or directory

    Returns:
        List of audio file paths as strings

    Raises:
        ValueError: If the file is not a recognized audio format, the
            directory contains no audio files, or the path does not exist.
    """
    input_path = Path(input_path)

    # Common audio extensions, compared against the lower-cased suffix.
    audio_extensions = {'.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac'}

    if input_path.is_file():
        if input_path.suffix.lower() in audio_extensions:
            return [str(input_path)]
        raise ValueError(f"File {input_path} is not a recognized audio format")

    if input_path.is_dir():
        # Direct iteration instead of glob. Globbing "*.wav" and "*.WAV"
        # returned duplicates on case-insensitive filesystems, missed
        # mixed-case suffixes, and broke on directory names containing glob
        # metacharacters such as "[".
        audio_files = sorted(
            str(entry)
            for entry in input_path.iterdir()
            if entry.is_file()
            and not entry.name.startswith('.')
            and entry.suffix.lower() in audio_extensions
        )

        if not audio_files:
            raise ValueError(f"No audio files found in directory {input_path}")

        return audio_files

    raise ValueError(f"Input path {input_path} does not exist")


def load_model_and_config(checkpoint_path: str, config_path: str, device: str):
    """Build a MelBandRoformer from a YAML config and load checkpoint weights into it.

    Args:
        checkpoint_path: Path to model checkpoint
        config_path: Path to config YAML
        device: Device to move the model to after loading

    Returns:
        Tuple of (model, config). The model is on ``device`` and in eval mode.

    Weights are loaded with ``strict=False``. Missing or unexpected keys are
    reported as printed warnings (counts only) rather than raising.
    """
    config = load_config_from_yaml(config_path)
    net = MelBandRoformer(**vars(config.model))

    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    payload = torch.load(checkpoint_path, map_location='cpu')

    # Some checkpoint formats nest the weights under 'state' and/or
    # 'state_dict'; unwrap in that order, each at most once.
    for wrapper_key in ('state', 'state_dict'):
        if wrapper_key in payload:
            payload = payload[wrapper_key]

    outcome = net.load_state_dict(payload, strict=False)
    missing = getattr(outcome, 'missing_keys', [])
    unexpected = getattr(outcome, 'unexpected_keys', [])

    if missing:
        print(f"Warning: {len(missing)} missing keys in checkpoint")
    if unexpected:
        print(f"Warning: {len(unexpected)} unexpected keys in checkpoint")

    net = net.to(device)
    net.eval()
    return net, config


def _load_mix(path: str, sample_rate: int, num_channels: int):
    """Load ``path`` resampled to ``sample_rate`` with an explicit leading channel axis."""
    mix, _ = librosa.load(path, sr=sample_rate, mono=False)
    # librosa returns a 1-D array for mono files; add the channel axis.
    if mix.ndim == 1:
        mix = np.expand_dims(mix, axis=0)
        if num_channels == 2:
            print('Converting mono to stereo...')
            mix = np.concatenate([mix, mix], axis=0)
    return mix


def _find_instrumental_source(waveforms: dict, instruments: list):
    """Return the ``waveforms`` key to subtract from the mix, or None if absent.

    'vocals' is preferred when it is a target instrument; otherwise the first
    target is used. Keys are matched case-insensitively.
    """
    if not instruments:
        return None
    lowered = [name.lower() for name in instruments]
    target = 'vocals' if 'vocals' in lowered else lowered[0]
    return next((key for key in waveforms if key.lower() == target), None)


def _save_waveforms(waveforms: dict, output_dir: Path, sample_rate: int, output_format: str):
    """Write every separated stem to ``output_dir/<stem name>.<output_format>``."""
    # FLAC stores integer samples only, so it is written as 16-bit PCM; WAV keeps float.
    subtype = 'PCM_16' if output_format == 'flac' else 'FLOAT'
    output_dir.mkdir(parents=True, exist_ok=True)
    for instr_name, audio_data in waveforms.items():
        output_path = output_dir / f"{instr_name}.{output_format}"
        # Assumes each stem is (channels, samples), hence the transpose to the
        # (frames, channels) layout soundfile writes. TODO confirm against demix.
        sf.write(str(output_path), audio_data.T, sample_rate, subtype=subtype)


def process_audio_files(model, config, audio_files: list, store_dir: str, device: str, 
                       extract_instrumental: bool = False, use_tta: bool = False,
                       output_format: str = "wav"):
    """Process multiple audio files for separation.

    Each file is written to ``store_dir/<file stem>/<stem name>.<output_format>``.
    A file that fails to load or separate is reported and skipped, so one bad
    input does not abort the rest of the batch.

    Args:
        model: Loaded model
        config: Model configuration
        audio_files: List of audio file paths
        store_dir: Output directory
        device: Device for inference
        extract_instrumental: Whether to also write ``instrumental``
            (original mix minus the separated target)
        use_tta: Whether to use test-time augmentation
        output_format: Output file format ("wav" or "flac")

    Returns:
        List of input paths that could not be loaded or separated (empty on
        full success).
    """
    os.makedirs(store_dir, exist_ok=True)
    sample_rate = getattr(config.audio, 'sample_rate', 44100)
    num_channels = getattr(config.audio, 'num_channels', 2)
    instruments = list(prefer_target_instrument(config))

    print(f"Processing {len(audio_files)} file(s)")
    print(f"Target instruments: {instruments}")
    print(f"Output directory: {store_dir}")

    failed = []
    for path in tqdm(audio_files, desc="Processing files"):
        print(f"\nProcessing: {Path(path).name}")

        try:
            mix = _load_mix(path, sample_rate, num_channels)
        except Exception as e:
            print(f'Error loading {path}: {e}')
            failed.append(path)
            continue

        # Kept separately in case separation modifies `mix` in place.
        mix_orig = mix.copy()

        try:
            waveforms = demix(config, model, mix, device, model_type='mel_band_roformer', pbar=True)
            if use_tta:
                from roformer_separation.inference import apply_tta
                print("Applying test-time augmentation...")
                waveforms = apply_tta(config, model, mix, waveforms, device, 'mel_band_roformer')
        except Exception as e:
            # Same per-file policy as load errors: report and continue the batch.
            print(f'Error separating {path}: {e}')
            failed.append(path)
            continue

        if extract_instrumental:
            source_key = _find_instrumental_source(waveforms, instruments)
            if source_key is not None:
                waveforms['instrumental'] = mix_orig - waveforms[source_key]

        output_dir = Path(store_dir) / Path(path).stem
        _save_waveforms(waveforms, output_dir, sample_rate, output_format)
        print(f"Saved to: {output_dir}")

    if failed:
        print(f"\n{len(failed)} of {len(audio_files)} file(s) failed:")
        for path in failed:
            print(f"  {path}")

    return failed


def main():
    """Main entry point for the CLI.

    The input path is resolved first, so a bad ``--input_folder`` is reported
    before any model is downloaded or loaded. Then the device is chosen, the
    model is loaded, every file is separated, and the elapsed time is printed.

    Exits with status 1 when the input path is invalid, or when
    ``process_audio_files`` returns a non-empty list of failed files.
    """
    args = parse_args()

    start_time = time.time()

    # Validate input before the potentially large model download/load.
    try:
        audio_files = get_audio_files(args.input_folder)
    except ValueError as e:
        raise SystemExit(f"Error: {e}") from None

    # Initialize device
    device = initialize_device(args.force_cpu)
    print(f"Using device: {device}")

    # Get model and config paths
    if args.checkpoint and args.config:
        # Use provided paths
        checkpoint_path = args.checkpoint
        config_path = args.config
        print("Using provided checkpoint and config")
    else:
        # Auto-download based on instrument
        print(f"Downloading model for instrument: {args.instrument}")
        checkpoint_path, config_path = get_model_paths(args.instrument)

    # Load model
    print("Loading model...")
    model, config = load_model_and_config(str(checkpoint_path), str(config_path), device)
    print("Model loaded successfully")

    # Process audio files
    failed = process_audio_files(
        model,
        config,
        audio_files,
        args.store_dir,
        device,
        extract_instrumental=args.extract_instrumental,
        use_tta=args.use_tta,
        output_format=args.output_format
    )

    elapsed = time.time() - start_time
    print(f"\n✓ Complete! Elapsed time: {elapsed:.2f} seconds")

    # Signal partial failure to calling scripts via the exit status.
    if failed:
        raise SystemExit(1)


# Run the CLI when this file is executed directly as a script; importing the
# module has no side effects beyond the definitions above.
if __name__ == "__main__":
    main()

