"""Download models and configs from HuggingFace."""

import os
import urllib.request
from pathlib import Path
from typing import Tuple
from tqdm.auto import tqdm


# Download URLs for every supported instrument. Each entry maps to a model
# checkpoint ("checkpoint") and its accompanying YAML file ("config"); the
# filenames cached locally are taken from the last path segment of each URL.
MODEL_CONFIGS = dict(
    vocals=dict(
        checkpoint="https://huggingface.co/xavriley/source_separation_mirror/resolve/main/Lead_VocalDereverb.ckpt",
        config="https://huggingface.co/xavriley/source_separation_mirror/resolve/main/config_karaoke_becruily.yaml",
    ),
    guitar=dict(
        checkpoint="https://huggingface.co/xavriley/source_separation_mirror/resolve/main/becruily_guitar.ckpt",
        config="https://huggingface.co/xavriley/source_separation_mirror/resolve/main/config_guitar_becruily.yaml",
    ),
)


def get_cache_dir() -> Path:
    """Return ``~/.cache/roformer-separation``, creating it (and parents) if missing.

    Safe to call repeatedly: an already-existing directory is left untouched.
    """
    target = Path.home().joinpath(".cache", "roformer-separation")
    target.mkdir(exist_ok=True, parents=True)
    return target


class DownloadProgressBar(tqdm):
    """tqdm progress bar whose ``update_to`` can be used as a ``urlretrieve`` report hook."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to the absolute position reported by ``urlretrieve``.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of each block in bytes.
            tsize: Total size in bytes. ``urlretrieve`` passes -1 when the
                server does not report a size, so only positive values are
                used as the bar's total.
        """
        if tsize is not None and tsize > 0:
            self.total = tsize
        downloaded = b * bsize
        # The last block is usually partial, so b * bsize can exceed the real
        # size; clamp so the bar never reports more than 100%.
        if self.total:
            downloaded = min(downloaded, self.total)
        # tqdm.update takes a delta, while the hook reports a cumulative count.
        self.update(downloaded - self.n)


def download_file(url: str, output_path: Path) -> None:
    """Download ``url`` to ``output_path`` with a progress bar.

    The data is first written to a sibling ``<name>.part`` file. It is moved
    onto ``output_path`` with ``os.replace`` only after the transfer
    completes. An interrupted or failed download (including Ctrl-C) therefore
    never leaves a truncated file at ``output_path`` that a later
    ``exists()`` check could mistake for a complete cached copy.

    Args:
        url: Remote URL to fetch.
        output_path: Final destination of the downloaded file.

    Raises:
        Any exception from ``urllib.request.urlretrieve`` (e.g. ``URLError``,
        ``ContentTooShortError``) or ``os.replace``, after the temporary
        ``.part`` file has been removed.
    """
    print(f"Downloading {url} to {output_path}")

    tmp_path = output_path.with_name(output_path.name + ".part")
    try:
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=output_path.name) as t:
            urllib.request.urlretrieve(url, filename=tmp_path, reporthook=t.update_to)
        # Same directory => same filesystem, so the rename is atomic.
        os.replace(tmp_path, output_path)
    except BaseException:
        # BaseException so KeyboardInterrupt also discards the partial file.
        if tmp_path.exists():
            tmp_path.unlink()
        raise


def _filename_from_url(url: str) -> str:
    """Return the last path segment of ``url``, ignoring any query string or fragment."""
    from urllib.parse import urlsplit

    return urlsplit(url).path.rsplit("/", 1)[-1]


def _remove_if_exists(path: Path) -> None:
    """Delete ``path`` if it exists (used to discard a partial download)."""
    if path.exists():
        path.unlink()


def _ensure_downloaded(url: str, path: Path, label: str) -> None:
    """Download ``url`` to ``path`` unless ``path`` already exists.

    Args:
        url: Remote URL to fetch.
        path: Local cache location.
        label: Human-readable kind of file ("checkpoint" or "config"), used
            in the status and error messages.

    Raises:
        RuntimeError: If the download fails; the original error is chained
            as ``__cause__``.
    """
    if path.exists():
        print(f"Using cached {label}: {path}")
        return
    try:
        download_file(url, path)
    except Exception as e:
        _remove_if_exists(path)  # Clean up partial download
        raise RuntimeError(f"Failed to download {label}: {e}") from e
    except BaseException:
        # e.g. KeyboardInterrupt: still drop any partial file so the next run
        # does not treat it as a valid cached copy, then let it propagate.
        _remove_if_exists(path)
        raise


def get_model_paths(instrument: str) -> Tuple[Path, Path]:
    """Get or download model checkpoint and config for the given instrument.

    Files are cached in ``get_cache_dir()`` under the last path segment of
    their URL. Existing files are reused without being re-downloaded.

    Args:
        instrument: Instrument name ('vocals' or 'guitar')

    Returns:
        Tuple of (checkpoint_path, config_path)

    Raises:
        ValueError: If instrument is not supported
        RuntimeError: If downloading the checkpoint or config fails
    """
    if instrument not in MODEL_CONFIGS:
        raise ValueError(f"Unsupported instrument: {instrument}. Must be one of {list(MODEL_CONFIGS.keys())}")

    cache_dir = get_cache_dir()
    urls = MODEL_CONFIGS[instrument]

    checkpoint_url = urls["checkpoint"]
    config_url = urls["config"]

    checkpoint_path = cache_dir / _filename_from_url(checkpoint_url)
    config_path = cache_dir / _filename_from_url(config_url)

    _ensure_downloaded(checkpoint_url, checkpoint_path, "checkpoint")
    _ensure_downloaded(config_url, config_path, "config")

    return checkpoint_path, config_path

