"""Configuration management for roformer models."""

from types import SimpleNamespace
from typing import Optional
import yaml


def _tuple_constructor(loader, node):
    """Build a Python ``tuple`` from a YAML ``!!python/tuple`` sequence node.

    Args:
        loader: The active YAML loader; its ``construct_sequence`` turns the
            node's children into Python values.
        node: The sequence node carrying the ``python/tuple`` tag.

    Returns:
        A tuple of the constructed element values, in document order.
    """
    elements = loader.construct_sequence(node)
    return tuple(elements)


# Teach yaml.SafeLoader (and therefore yaml.safe_load) to build tuples from
# `!!python/tuple` sequences; the stock SafeLoader rejects that tag.
# NOTE: this registers on the shared SafeLoader class at import time, so the
# tag becomes accepted by every yaml.safe_load call in the process, not only
# by this module's loader.
yaml.add_constructor(
    'tag:yaml.org,2002:python/tuple',
    _tuple_constructor,
    Loader=yaml.SafeLoader,
)


def _dict_to_namespace(d: dict):
    """Recursively turn dicts into ``SimpleNamespace`` objects.

    - A dict becomes a ``SimpleNamespace`` whose attributes are its keys. Keys
      must therefore be strings, because they are passed as keyword arguments.
    - A list or tuple is rebuilt as the same container type, with every element
      converted.
    - Any other value is returned unchanged.

    Args:
        d: The structure to convert; usually a dict, but any value is accepted.

    Returns:
        The converted structure.
    """
    if isinstance(d, dict):
        attributes = {name: _dict_to_namespace(value) for name, value in d.items()}
        return SimpleNamespace(**attributes)
    if isinstance(d, (list, tuple)):
        container_type = type(d)
        return container_type(_dict_to_namespace(item) for item in d)
    return d


def load_config_from_yaml(config_path: str):
    """Load config from YAML file.

    The file is parsed with ``yaml.safe_load``. Nested mappings become nested
    ``SimpleNamespace`` objects; lists and tuples are preserved.

    Args:
        config_path: Path to YAML config file.

    Returns:
        SimpleNamespace config object. An empty (or comments-only) file yields
        an empty ``SimpleNamespace()``.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the top-level YAML value is not a mapping.
    """
    # Binary mode lets PyYAML detect the encoding itself (UTF-8, or UTF-16
    # with a BOM) instead of depending on the platform's default text
    # encoding, which is not UTF-8 everywhere (e.g. Windows).
    with open(config_path, 'rb') as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty document. Return an empty namespace
    # rather than None, which callers would only trip over later.
    if data is None:
        return SimpleNamespace()
    if not isinstance(data, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping at the "
            f"top level, got {type(data).__name__}"
        )
    return _dict_to_namespace(data)


def build_default_config():
    """Return the built-in default config as a nested ``SimpleNamespace``.

    The values are based on ``config_karaoke_becruily.yaml``. The config has
    four sections: ``audio``, ``model``, ``training`` and ``inference``. A new
    dict (and so a new namespace tree) is built on every call.
    """
    # These values appear in more than one section, so they are defined once
    # here. The resulting numbers are identical to the literal config.
    sample_rate = 44100
    n_fft = 2048
    hop_length = 441

    audio_section = {
        "chunk_size": 485100,
        "dim_f": 1024,
        "dim_t": 256,
        "hop_length": hop_length,
        "n_fft": n_fft,
        "num_channels": 2,
        "sample_rate": sample_rate,
        "min_mean_abs": 0.0,
    }

    model_section = {
        "dim": 384,
        "depth": 6,
        "stereo": True,
        "num_stems": 2,
        "time_transformer_depth": 1,
        "freq_transformer_depth": 1,
        "num_bands": 60,
        "dim_head": 64,
        "heads": 8,
        "attn_dropout": 0,
        "ff_dropout": 0,
        "flash_attn": True,
        # n_fft // 2 + 1 == 1025, the one-sided STFT bin count for n_fft=2048.
        "dim_freqs_in": n_fft // 2 + 1,
        "sample_rate": sample_rate,
        "stft_n_fft": n_fft,
        "stft_hop_length": hop_length,
        "stft_win_length": n_fft,
        "stft_normalized": False,
        "mask_estimator_depth": 2,
        "multi_stft_resolution_loss_weight": 1.0,
        "multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
        "multi_stft_hop_size": 147,
        "multi_stft_normalized": False,
    }

    training_section = {
        "batch_size": 1,
        "gradient_accumulation_steps": 1,
        "grad_clip": 0,
        "instruments": ["Vocals", "Instrumental"],
        "target_instrument": None,
        "use_amp": True,
    }

    inference_section = {
        "batch_size": 1,
        "dim_t": 1101,
        "num_overlap": 8,
    }

    return _dict_to_namespace({
        "audio": audio_section,
        "model": model_section,
        "training": training_section,
        "inference": inference_section,
    })

