import enum
import dataclasses
import pandas
import os
import ast
import os.path
import gzip
import numpy as np

# Full list of column names (indices), for reference:
# id,full_name,first_name,last_name,linkedin_url,linkedin_username,linkedin_id,facebook_url,facebook_username,facebook_id,mobile_phone,industry,job_title,job_title_role,job_title_levels,job_company_id,job_company_name,job_company_website,job_company_size,job_company_founded,job_company_industry,job_company_linkedin_url,job_company_linkedin_id,job_company_facebook_url,job_company_twitter_url,job_company_location_name,job_company_location_country,job_company_location_continent,job_last_updated,job_start_date,job_summary,location_name,location_locality,location_region,location_country,location_continent,location_geo,location_last_updated,linkedin_connections,inferred_salary,inferred_years_experience,summary,phone_numbers,emails,interests,skills,location_names,regions,countries,street_addresses,experience,education,profiles,certifications,languages,version_status,gender,work_email,job_title_sub_role,location_metro,middle_initial,birth_year,birth_date,job_company_location_locality,job_company_location_metro,job_company_location_region,job_company_location_geo,job_company_location_street_address,job_company_location_postal_code,location_street_address,location_postal_code,middle_name,location_address_line_2,job_company_location_address_line_2,twitter_url,twitter_username,github_url,github_username


class QueryFormatTypes(enum.Enum):
    """Names of the available query formats; each member's value equals its name."""

    basic = "basic"

class QueryFormat:
    """Base interface for turning a record into prompt text.

    ``get_info_block``, ``get_full_info`` and ``target_indices`` are hooks
    that only raise ``NotImplementedError`` here; subclasses provide the real
    implementations. ``get_summary_prompt`` is built on ``get_info_block``.
    """

    def get_info_block(self) -> str:
        """Return the descriptive text block for this record (subclass hook)."""
        raise NotImplementedError()

    def get_full_info(self) -> str:
        """Return the complete text for this record (subclass hook)."""
        raise NotImplementedError()

    def get_summary_prompt(self, summary_prefix: str) -> str:
        """Return the info block followed by a ``Summary:`` line.

        Args:
            summary_prefix: Text placed right after ``"Summary: "``.
        """
        info_block = self.get_info_block()
        return f"{info_block}\nSummary: {summary_prefix}"

    @classmethod
    def target_indices(cls) -> list[str]:
        """Return the column names this format needs (subclass hook)."""
        raise NotImplementedError()

@dataclasses.dataclass
class BasicQueryFormat(QueryFormat):
    """Query format holding a person's basic profile fields and summary."""

    full_name: str
    gender: str
    industry: str
    job_title: str
    job_company_size: str
    inferred_salary: str
    skills: list[str]
    summary: str

    @classmethod
    def target_indices(cls) -> list[str]:
        """Return the column names used by this format.

        The list contains ``"id"`` in addition to the dataclass fields.
        """
        columns = ["id", "full_name"]
        columns += ["gender", "industry", "job_title", "job_company_size"]
        columns += ["inferred_salary", "skills", "summary"]
        return columns

    def get_info_block(self) -> str:
        """Return one ``Label: value`` line per field (summary excluded).

        Skills are joined with ``", "``; the block ends with a newline.
        """
        lines = [
            f"Full Name: {self.full_name}",
            f"Gender: {self.gender}",
            f"Industry: {self.industry}",
            f"Job Title: {self.job_title}",
            f"Company Size: {self.job_company_size}",
            f"Inferred Salary: {self.inferred_salary}",
            "Skills: " + ", ".join(self.skills),
        ]
        return "\n".join(lines) + "\n"

    def get_full_info(self) -> str:
        """Return the info block followed by the ``Summary:`` line."""
        info_block = self.get_info_block()
        return f"{info_block}\nSummary: {self.summary}"


@dataclasses.dataclass
class FiveFactorQueryFormat(QueryFormat):
    """Query format holding basic profile fields plus five-factor scores."""

    full_name: str
    gender: str
    industry: str
    job_title: str
    job_company_size: str
    inferred_salary: str
    # Read with the keys 'openness', 'conscientiousness', 'extraversion',
    # 'agreeableness' and 'neuroticism'.
    five_factor_scores: dict[str, float]
    skills: list[str]
    summary: str

    @classmethod
    def target_indices(cls) -> list[str]:
        """Return the column names used by this format.

        The list contains ``"id"`` in addition to the dataclass fields.
        """
        columns = ["id", "full_name"]
        columns += ["gender", "industry", "job_title", "job_company_size"]
        columns += ["inferred_salary", "skills", "five_factor_scores", "summary"]
        return columns

    def get_info_block(self) -> str:
        """Return one ``Label: value`` line per field (summary excluded).

        The five personality scores are listed on indented lines with two
        decimals each; skills are joined with ``", "``. The block ends with
        a newline.
        """
        trait_labels = (
            ("Openness", "openness"),
            ("Conscientiousness", "conscientiousness"),
            ("Extraversion", "extraversion"),
            ("Agreeableness", "agreeableness"),
            ("Neuroticism", "neuroticism"),
        )
        # Look up every score first so a missing key surfaces before any
        # formatting is attempted.
        scored = [
            (label, self.five_factor_scores[key]) for label, key in trait_labels
        ]
        personality_str = "\n".join(
            f"   {label}: {value:.2f}" for label, value in scored
        )
        lines = [
            f"Full Name: {self.full_name}",
            f"Gender: {self.gender}",
            f"Industry: {self.industry}",
            f"Job Title: {self.job_title}",
            f"Company Size: {self.job_company_size}",
            f"Inferred Salary: {self.inferred_salary}",
            "Five Factor Personality Scores:",
            personality_str,
            "Skills: " + ", ".join(self.skills),
        ]
        return "\n".join(lines) + "\n"

    def get_full_info(self) -> str:
        """Return the info block followed by the ``Summary:`` line."""
        info_block = self.get_info_block()
        return f"{info_block}\nSummary: {self.summary}"


def parse_skills(raw_str: str) -> list[str]:
    """Parse the string representation of a list of strings back into a list.

    The result is always a ``list`` of ``str``.

    Args:
        raw_str: Text such as ``"['python', 'sql']"``. An already-parsed
            list/tuple is accepted and returned as a list of strings; any
            other non-string value (e.g. ``None`` or a NaN float) yields
            an empty list.

    Returns:
        The parsed items. When the text is a valid Python list/tuple literal
        its elements are returned (converted with ``str``). Otherwise the
        text is stripped of surrounding brackets, split on commas, and each
        piece is stripped of whitespace and surrounding quotes.
    """
    if isinstance(raw_str, (list, tuple)):
        return [str(item) for item in raw_str]
    if not isinstance(raw_str, str) or not raw_str.strip():
        return []

    try:
        # literal_eval only evaluates literals, so it is safe on raw text.
        parsed = ast.literal_eval(raw_str)
    except (SyntaxError, ValueError, TypeError):
        # TypeError covers literals such as a dict with an unhashable key.
        parsed = None

    if isinstance(parsed, (list, tuple)):
        return [str(item) for item in parsed]

    # Fallback for malformed text and for literals that are not a sequence
    # (a bare string, number, dict, ...): strip brackets and split on commas.
    pieces = raw_str.strip().strip("[]").split(",")
    return [piece.strip().strip("'\"") for piece in pieces if piece.strip()]


def _stable_seed(value) -> int:
    """Derive a non-negative RNG seed from *value* that is stable across runs."""
    if isinstance(value, (str, bytes)):
        # Builtin hash() of str/bytes is randomised per process, so use a
        # digest instead to get the same seed on every run.
        import hashlib

        data = value.encode("utf-8") if isinstance(value, str) else value
        return int.from_bytes(hashlib.sha256(data).digest()[:8], "big")
    # Other values (e.g. integer ids) keep the original abs(hash(...)) seed.
    return abs(hash(value))


def convert_li_to_basic_training_data(
    file_path: str,
    output_path: str,
    train_fraction: float,
    val_fraction: float,
    test_fraction: float,
) -> None:
    """Convert a CSV file into gzipped JSON-lines training data.

    Reads ``file_path``, keeps only the ``BasicQueryFormat.target_indices()``
    columns, drops rows with a missing value in any of them, parses the
    ``skills`` column with ``parse_skills`` and adds a ``train_test_val``
    column before writing one JSON record per line to ``output_path``.

    Args:
        file_path: Path of the CSV file to read.
        output_path: Path of the gzip file to write; an existing file is
            replaced.
        train_fraction: Probability of a row being labelled ``0``.
        val_fraction: Probability of a row being labelled ``1``.
        test_fraction: Probability of a row being labelled ``2``.

    The three fractions are passed to NumPy as a probability vector, which
    is expected to reject values that do not sum to 1.
    """
    columns = BasicQueryFormat.target_indices()

    df = pandas.read_csv(file_path)
    # Copy so the column assignments below act on an independent frame
    # rather than on a filtered view of the original.
    df = df.dropna(subset=columns)[columns].copy()
    df["skills"] = df["skills"].apply(parse_skills)

    # Seed from the first surviving id; fall back to 0 when no rows remain
    # so an empty input produces an empty output instead of an IndexError.
    seed = _stable_seed(df.iloc[0]["id"]) if len(df) else 0
    rng = np.random.default_rng(seed=seed)

    # Label order follows the probability vector: 0, 1, 2 for the train,
    # val and test fractions respectively.
    df["train_test_val"] = rng.choice(
        [0, 1, 2],
        p=[train_fraction, val_fraction, test_fraction],
        size=len(df),
    )

    if os.path.exists(output_path):
        os.remove(output_path)

    with gzip.open(output_path, "wt", encoding="utf-8") as f:
        df.to_json(
            f,
            orient="records",
            lines=True,
            force_ascii=False,
            date_format="iso",
            double_precision=10,
        )

