from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple, Optional
import csv
from sympy import true


class Range(NamedTuple):
    """Immutable ``(start, end)`` pair of integers.

    RunConfig uses it to describe timestep windows (for cross-attention and
    AdaIN). Whether ``end`` is inclusive is decided by the code that consumes
    the range, not by this type.
    """

    # Lower bound of the window.
    start: int
    # Upper bound of the window.
    end: int


@dataclass
class RunConfig:
    """Configuration for a single structure/appearance transfer run.

    ``__post_init__`` normalises the path fields to ``Path`` objects, creates
    ``output_path`` and ``output_path / "latents"`` on disk, derives the text
    prompts (from two caption CSVs when both are given, otherwise from the
    image file names), fills in ``object_noun``, and sets the file paths used
    to cache inverted latents.
    """

    # Not read by __post_init__.
    struct_image_dir: Optional[str] = None
    # Not read by __post_init__.
    app_image_dir: Optional[str] = None
    # CSV with "image" and "text" columns: captions keyed by structure image file name
    struct_caption_file: Optional[str] = None
    # CSV with "name" and "text" columns: captions keyed by appearance image name
    style_caption_file: Optional[str] = None
    # Not read by __post_init__.
    save_dir: Optional[str] = None
    # Appearance image path (a str is accepted and converted to Path)
    app_image_path: Optional[Path] = None
    # Struct image path (a str is accepted and converted to Path)
    struct_image_path: Optional[Path] = None
    # Domain name (e.g., buildings, animals).
    # NOTE: overwritten with a 3-element list derived from the image file names
    # whenever no prompt is available.
    domain_name: Optional[str] = None
    # Output path (a str is accepted and converted to Path)
    output_path: Path = Path('./output')
    # Random seed
    seed: int = 42
    # Input prompt(s) for inversion. When both caption files are set, this is
    # replaced by [struct caption, appearance caption, struct caption]; otherwise,
    # if None, it becomes ["A photo of a <words from file name>", ...] for
    # [struct, appearance, struct]. Annotated as str but ends up as a list.
    prompt: Optional[str] = None
    # Number of timesteps
    num_timesteps: int = 100
    # Whether to apply AdaIN
    use_adain: bool = True
    # Whether to use a binary mask when performing AdaIN
    use_masked_adain: bool = False
    # Timesteps to apply cross-attention on 64x64 layers
    cross_attn_64_range: Range = Range(start=10, end=90)
    # Timesteps to apply cross-attention on 32x32 layers
    cross_attn_32_range: Range = Range(start=10, end=70)
    # Timesteps to apply AdaIN
    adain_range: Range = Range(start=20, end=100)
    # Swap guidance scale
    swap_guidance_scale: float = 0.0
    # Attention contrasting strength
    contrast_strength: float = 1.67
    # Object nouns to use for self-segmentation (defaults to domain_name, wrapped
    # in a list if it is not one already)
    object_noun: Optional[str] = None
    # Whether to load previously saved inverted latent codes
    load_latents: bool = False
    # Number of steps to skip in the denoising process (used value from original edit-friendly DDPM paper)
    skip_steps: int = 40
    # Segmentation mask settings
    save_mask: bool = True
    mask_type: str = "sam"
    mask_save_dir: str = "./temp/masks"
    # Whether to swap keys/values in self-attention
    swap_kv: bool = True
    # Timestep at which guidance starts being added
    start_guidance_timestep: int = 1000
    # Feature guidance
    w_app: float = 3.0
    w_global: float = 0.9
    w_struct: float = 0.6
    w_background: float = 3.0
    pe_scale: float = 3.0
    # Annotated float to match the float defaults (3e3 / 1e4 are floats).
    bg_energy_scale: float = 3e3
    energy_scale: float = 1e4
    feat_guidance_type: str = "app_struct"
    bg_affine: bool = True
    # Cross-attention guidance
    attention_guidance_type: str = "cross"
    cross_energy_scale: int = 10
    cross_guidance: float = 1.0
    cross_guidance_type: str = "l2"

    @staticmethod
    def _read_caption_csv(csv_path, key_column):
        """Return ``{row[key_column]: row['text']}`` for every row of a caption CSV.

        Later rows win on duplicate keys. Raises KeyError if a row lacks
        ``key_column`` or ``text``.
        """
        # utf-8-sig decodes UTF-8 independently of the platform locale and strips
        # a leading BOM, which would otherwise corrupt the first header name.
        with open(csv_path, mode='r', newline='', encoding='utf-8-sig') as file:
            return {row[key_column]: row['text'] for row in csv.DictReader(file)}

    @staticmethod
    def _file_name_to_words(image_path):
        """Return the file name up to its first '.', with underscores turned into spaces."""
        return image_path.name.split('.')[0].replace('_', ' ')

    def __post_init__(self):
        # The original `self.output_path = self.output_path` was a no-op, so a str
        # output_path crashed on .mkdir(); normalise all path fields to Path.
        self.output_path = Path(self.output_path)
        if self.app_image_path is not None:
            self.app_image_path = Path(self.app_image_path)
        if self.struct_image_path is not None:
            self.struct_image_path = Path(self.struct_image_path)

        self.output_path.mkdir(parents=True, exist_ok=True)

        if self.style_caption_file is not None and self.struct_caption_file is not None:
            style_dict = self._read_caption_csv(self.style_caption_file, key_column='name')
            struct_dict = self._read_caption_csv(self.struct_caption_file, key_column='image')
            # Appearance key: text after the last "_s_" in the file name.
            # Structure key: text before the first "_s_", with every "c_" removed, plus ".jpg".
            style_image_name = self.app_image_path.name.split("_s_")[-1]
            struct_image_name = self.struct_image_path.name.split("_s_")[0].replace("c_", "") + ".jpg"
            struct_caption = struct_dict[struct_image_name]
            # Replaces any user-supplied prompt; domain_name is left untouched here.
            self.prompt = [struct_caption, style_dict[style_image_name], struct_caption]

        if self.prompt is None:
            # NOTE(review): a user-supplied domain_name is overwritten here, despite
            # the field comment suggesting it is the prompt default — confirm intent.
            struct_words = self._file_name_to_words(self.struct_image_path)
            app_words = self._file_name_to_words(self.app_image_path)
            self.domain_name = [struct_words, app_words, struct_words]
            self.prompt = [f"A photo of a {name}" for name in self.domain_name]

        if self.object_noun is None:
            # NOTE(review): in the caption branch domain_name may still be None,
            # giving object_noun == [None] — confirm downstream handles this.
            self.object_noun = self.domain_name if isinstance(self.domain_name, list) else [self.domain_name]

        # Define the paths to store the inverted latents to
        self.latents_path = self.output_path / "latents"
        self.latents_path.mkdir(parents=True, exist_ok=True)
        self.app_latent_save_path = self.latents_path / f"{self.app_image_path.stem}.pt"
        self.struct_latent_save_path = self.latents_path / f"{self.struct_image_path.stem}.pt"
