import dataclasses
import torch
import torch.nn
import torch.nn.functional
import typing
import json
import enum
import numpy as np
import uuid
import os.path

import sentence_transformers
from .llm_models import LIModelWrapper
from .query_format import BasicQueryFormat
import glob
from .fairness_correction import get_corrected_score

# Process-wide sentence-embedding model; stays None until get_embedding()
# creates it on first use.
_EMBED_MODEL: None | sentence_transformers.SentenceTransformer = None
# Width reserved for the text embedding when QuestionSelector / ScoreFunction
# size their inputs. Presumably the output width of EMBED_MODEL -- TODO confirm.
EMBED_DIM = 384
# Model name handed to sentence_transformers.SentenceTransformer in get_embedding().
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"


class SkillQuestionSet:
    def __init__(
        self,
        skills: list[str],
    ) -> None:
        self.skills = skills
        self.skill_to_idx = {skill: i for i, skill in enumerate(self.skills)}
        self.idx_to_skill = {i: skill for i, skill in enumerate(self.skills)}

    def __len__(self) -> int:
        return len(self.skills)

    def __repr__(self) -> str:
        return f"SkillQuestionSet({','.join(self.skills)})"

    @classmethod
    def load_from_json(cls, json_path: str) -> typing.Self:
        with open(json_path, "r") as f:
            skills = json.load(f)
        if not isinstance(skills, list):
            raise ValueError(f"Expected list of skills, got {type(skills)}")
        if not all(isinstance(skill, str) for skill in skills):
            raise ValueError(f"Expected list of strings, got {type(skills)}")
        if len(skills) != len(set(skills)):
            raise ValueError(f"Duplicate skills found in {json_path}")
        return cls(skills)


class LIGender(enum.Enum):
    """Gender label attached to a candidate; each value is its lowercase name."""

    # The string values are what Candidate.from_json_dict matches against and
    # what Candidate.basic_skill_query passes on via ``self.gender.value``.
    male = "male"
    female = "female"


@dataclasses.dataclass
class FiveFactors:
    openness: float
    conscientiousness: float
    extraversion: float
    agreeableness: float
    neuroticism: float

    def to_int_dict(
        self,
    ) -> dict[int, float]:
        return {
            0: self.openness,
            1: self.conscientiousness,
            2: self.extraversion,
            3: self.agreeableness,
            4: self.neuroticism,
        }

    @classmethod
    def from_dict(
        cls,
        json_dict: dict[str, float],
    ) -> typing.Self:
        return cls(
            openness=json_dict["openness"],
            conscientiousness=json_dict["conscientiousness"],
            extraversion=json_dict["extraversion"],
            agreeableness=json_dict["agreeableness"],
            neuroticism=json_dict["neuroticism"],
        )


@dataclasses.dataclass
class Candidate:
    name: str
    gender: LIGender
    industry: str
    job_title: str
    job_company_size: str
    inferred_salary: str
    skills: list[str]
    summary: str | None
    five_factors: FiveFactors
    id: str
    cache_file: str | None

    @classmethod
    def from_json_dict(
        cls,
        json_dict: dict[str, str | list[str] | str | dict[str, float]],
        cache_dir: str | None,
        cache_file: str | None,
    ) -> typing.Self:
        match json_dict["gender"]:
            case "male" | LIGender.male:
                gender = LIGender.male
            case "female" | LIGender.female:
                gender = LIGender.female
            case _:
                raise ValueError(
                    f"I didn't think there were non-binary candidates in the dataset, got {json_dict}"
                )

        raw_five_factors = json_dict["five_factors"]
        assert isinstance(raw_five_factors, dict), (
            f"Expected dict, got {type(raw_five_factors)}"
        )
        five_factors = FiveFactors.from_dict(raw_five_factors)

        if cache_file is not None:
            cache_file_path = cache_file
        elif cache_dir is not None:
            cache_file_path = os.path.join(
                cache_dir, "real_candidates", f"{json_dict['id']}.json"
            )
        else:
            cache_file_path = None

        return cls(
            name=str(json_dict["name"]),
            gender=gender,
            industry=str(json_dict["industry"]),
            job_title=str(json_dict["job_title"]),
            job_company_size=str(json_dict["job_company_size"]),
            inferred_salary=str(json_dict["inferred_salary"]),
            skills=list(json_dict["skills"]),
            summary=str(json_dict["summary"]),
            five_factors=five_factors,
            id=str(json_dict["id"]),
            cache_file=cache_file_path,
        )

    def basic_skill_query(
        self,
        query_skills: list[str],
        summary_prefix: str,
        model: LIModelWrapper,
        max_tokens: int = 200,
    ) -> "Candidate":
        # get skills not already present
        new_skills = [s for s in query_skills if s not in self.skills]
        if len(new_skills) == 0 and self.summary is not None:
            return self

        query = BasicQueryFormat(
            full_name=self.name,
            gender=self.gender.value,
            industry=self.industry,
            job_title=self.job_title,
            job_company_size=self.job_company_size,
            inferred_salary=self.inferred_salary,
            skills=self.skills + list(new_skills),
            summary="",
        )
        model_response = model.run_basic_summary(
            query,
            summary_prefix=summary_prefix,
            max_new_tokens=max_tokens,
            do_sample=True,
        )
        new_candidate = Candidate(
            name=self.name,
            gender=self.gender,
            industry=self.industry,
            job_title=self.job_title,
            job_company_size=self.job_company_size,
            inferred_salary=self.inferred_salary,
            skills=self.skills + list(new_skills),
            five_factors=self.five_factors,
            # summary=model_response.text,
            summary=model_response.text,
            id=uuid.uuid4().hex,
            cache_file=self.cache_file,
        )
        return new_candidate


class CandidateSet:
    """Pool from which real or synthetic candidates are sampled.

    Real candidates come from ``full_candidates``. Synthetic ones are
    assembled by independently drawing a (name, gender) pair, an industry
    tuple and a five-factor set. When ``cache_dir`` is given, the JSON files
    found under ``real_candidates`` / ``synthetic_candidates`` at construction
    time take precedence over both sources.
    """

    def __init__(
        self,
        full_candidates: list[Candidate],
        name_doubles: list[tuple[str, LIGender]],
        industry_tuples: list[tuple[str, str, str, str, list[str]]],
        five_factors: list[dict[str, float]],
        cache_dir: str | None,
    ) -> None:
        """Store the pools and index the cache directory, creating it if needed.

        Args:
            full_candidates: Complete candidate records.
            name_doubles: ``(name, gender)`` pairs.
            industry_tuples: ``(industry, job_title, job_company_size,
                inferred_salary, skills)`` entries.
            five_factors: Dicts accepted by ``FiveFactors.from_dict``.
            cache_dir: Cache root, or ``None`` to disable caching.
        """
        self.full_candidates = full_candidates
        self.full_candidates_index_array = np.arange(len(full_candidates))
        self.name_doubles = name_doubles  # name, gender,
        self.name_doubles_index_array = np.arange(len(name_doubles))
        self.industry_tuples = industry_tuples  # industry, job_title, job_company_size, inferred_salary, skills
        self.industry_tuples_index_array = np.arange(len(industry_tuples))
        self.five_factors_index_array = np.arange(len(five_factors))
        self.five_factors = [FiveFactors.from_dict(fd) for fd in five_factors]

        self.cache_dir = cache_dir

        self.real_cache = []
        self.synthetic_cache = []

        if cache_dir is not None:
            real_dir = os.path.join(cache_dir, "real_candidates")
            synthetic_dir = os.path.join(cache_dir, "synthetic_candidates")
            os.makedirs(real_dir, exist_ok=True)
            os.makedirs(synthetic_dir, exist_ok=True)
            # glob returns files in arbitrary order; sort so that sampling
            # with a seeded generator is reproducible across runs.
            self.real_cache = sorted(glob.glob(os.path.join(real_dir, "*.json")))
            self.synthetic_cache = sorted(
                glob.glob(os.path.join(synthetic_dir, "*.json"))
            )

    @classmethod
    def load_from_json(
        cls,
        json_path: str,
        cache_dir: str | None,
    ) -> typing.Self:
        """Load a candidate set from a JSON document.

        The document must be a dict with exactly the keys ``full_candidates``,
        ``name_doubles``, ``industry_tuples`` and ``five_factors_sets``.

        Args:
            json_path: Path to the JSON file.
            cache_dir: Cache root forwarded to the candidates and the set.

        Raises:
            ValueError: If the document is not a dict, a required key is
                missing, or a ``full_candidates`` entry is not a dict.
            AssertionError: If the document does not have exactly four keys.
        """
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise ValueError(f"Expected dict, got {type(data)}")
        # "five_factors_sets" is read below as well, so validate it with the
        # others instead of letting it surface later as a bare KeyError.
        for required_key in (
            "full_candidates",
            "name_doubles",
            "industry_tuples",
            "five_factors_sets",
        ):
            if required_key not in data:
                raise ValueError(f"Missing {required_key} in {json_path}")
        assert len(data) == 4, f"Expected 4 fields, got {len(data)}"

        candidates = []
        for raw_c in data["full_candidates"]:
            if not isinstance(raw_c, dict):
                raise ValueError(f"Expected dict, got {type(raw_c)}")
            # NOTE(review): every value other than "male" is mapped to female
            # here, so the unknown-gender ValueError in
            # Candidate.from_json_dict can never fire for these records.
            # Kept as-is to avoid rejecting existing data; confirm intent.
            raw_c["gender"] = (
                LIGender.male if raw_c["gender"] == "male" else LIGender.female
            )
            candidates.append(
                Candidate.from_json_dict(raw_c, cache_dir, cache_file=None)
            )
        return cls(
            full_candidates=candidates,
            name_doubles=[
                (d[0], LIGender.male if d[1] == "male" else LIGender.female)
                for d in data["name_doubles"]
            ],
            industry_tuples=data["industry_tuples"],
            five_factors=data["five_factors_sets"],
            cache_dir=cache_dir,
        )

    def get_candidate_file_path(
        self,
        candidate_id: str,
        is_real: bool,
    ) -> str | None:
        """Return the cache path for a candidate id, or ``None`` without a cache.

        Args:
            candidate_id: Identifier used as the file stem.
            is_real: Selects the ``real_candidates`` sub-directory when true,
                ``synthetic_candidates`` otherwise.
        """
        if self.cache_dir is None:
            return None
        subdir = "real_candidates" if is_real else "synthetic_candidates"
        return os.path.join(self.cache_dir, subdir, f"{candidate_id}.json")

    def load_from_cache(
        self,
        cache_files: list[str],
        rng: np.random.Generator,
    ) -> Candidate:
        """Load one uniformly chosen candidate from ``cache_files``.

        ``sample_candidate`` calls this method whenever a cache list is
        non-empty, but the class did not define it, so that path raised
        ``AttributeError``.

        NOTE(review): assumes each cache file holds the dict layout accepted
        by ``Candidate.from_json_dict`` -- confirm against whatever writes
        the cache.

        Args:
            cache_files: Non-empty list of JSON file paths.
            rng: Generator used to pick the file.

        Raises:
            ValueError: If the chosen file does not contain a JSON object.
        """
        cache_path = cache_files[int(rng.integers(len(cache_files)))]
        with open(cache_path, "r", encoding="utf-8") as f:
            raw_candidate = json.load(f)
        if not isinstance(raw_candidate, dict):
            raise ValueError(f"Expected dict in {cache_path}, got {type(raw_candidate)}")
        return Candidate.from_json_dict(
            raw_candidate, self.cache_dir, cache_file=cache_path
        )

    def sample_candidate(
        self, rng: np.random.Generator, real_prob: float = 0.5
    ) -> Candidate:
        """Draw one candidate, real with probability ``real_prob``.

        For either kind, a cached candidate is returned when the matching
        cache list is non-empty. Otherwise a real candidate is picked from
        ``full_candidates``, or a synthetic one (``summary`` is ``None``,
        random ``id``) is assembled from independent draws.

        Args:
            rng: Source of randomness for every draw.
            real_prob: Probability of choosing the real branch.
        """
        if rng.random() < real_prob:
            if len(self.real_cache) > 0:
                return self.load_from_cache(self.real_cache, rng)
            # Sample from full_candidates
            candidate_id = rng.choice(self.full_candidates_index_array)
            return self.full_candidates[candidate_id]

        if len(self.synthetic_cache) > 0:
            return self.load_from_cache(self.synthetic_cache, rng)
        name_id = rng.choice(self.name_doubles_index_array)
        industry_id = rng.choice(self.industry_tuples_index_array)
        five_factor_id = rng.choice(self.five_factors_index_array)

        sampled_name, sampled_gender = (
            self.name_doubles[name_id][0],
            self.name_doubles[name_id][1],
        )
        industry_entry = self.industry_tuples[industry_id]
        _id = uuid.uuid4().hex
        return Candidate(
            name=sampled_name,
            gender=sampled_gender,
            industry=industry_entry[0],
            job_title=industry_entry[1],
            job_company_size=industry_entry[2],
            inferred_salary=industry_entry[3],
            # Copy: the tuple's list is shared by every candidate drawn from it.
            skills=list(industry_entry[4]),
            summary=None,
            five_factors=self.five_factors[five_factor_id],
            id=_id,
            cache_file=self.get_candidate_file_path(
                candidate_id=_id, is_real=False
            ),
        )


@dataclasses.dataclass
class InterviewNode:
    """One node in a tree of interview questions asked of a candidate.

    The root has no parent and no question; its ``question_id`` becomes 0
    once ``ask_root_summary`` has run. Every other node records the question
    that led to it and the candidate returned for that question. Children are
    keyed by question text.
    """

    parent: "InterviewNode | None"
    children: dict[str, "InterviewNode"]
    question: str | None
    question_id: int | None
    num_questions: int
    candidate: Candidate

    # Cache for get_node_tensor(); None until first computed.
    _node_tensor: torch.Tensor | None

    def get_question_sequence(
        self,
        new_question: str,
    ) -> list[str]:
        """Return the ancestors' questions, oldest first, followed by ``new_question``.

        Only *ancestor* questions are collected; this node's own ``question``
        is not part of the result. (Its skill is already present in
        ``self.candidate.skills`` when the node was produced by
        ``ask_question``.)

        Raises:
            AssertionError: If an ancestor without a question does not have
                ``question_id == 0``.
        """
        questions = [new_question]
        current_node = self
        while current_node.parent is not None:
            if current_node.parent.question is None:
                assert current_node.parent.question_id == 0, (
                    "question_id must be 0 for root node"
                )
            else:
                questions.append(current_node.parent.question)
            current_node = current_node.parent
        questions.reverse()
        return questions

    def ask_root_summary(
        self,
        summary_prefix: str,
        model: LIModelWrapper,
        max_tokens: int = 200,
    ) -> "InterviewNode":
        """Make sure the root candidate has a summary and mark the root as asked.

        If the candidate has no summary yet, one is generated through
        ``Candidate.basic_skill_query`` and stored on the existing candidate
        object. ``question_id`` is then set to 0 in either case.

        Returns:
            This node.

        Raises:
            AssertionError: If called on a node that has a parent, a question,
                or an already-set ``question_id``.
        """
        assert self.parent is None, (
            "ask_root_summary can only be called on the root node"
        )
        assert self.question is None, (
            "ask_root_summary can only be called on the root node"
        )
        assert self.question_id is None, (
            "ask_root_summary can only be called on the root node"
        )
        if self.candidate.summary is None:
            response_candidate = self.candidate.basic_skill_query(
                [],
                summary_prefix,
                model,
                max_tokens=max_tokens,
            )
            self.candidate.summary = response_candidate.summary
        self.question_id = 0
        return self

    def ask_question(
        self,
        question: str,
        question_id: int,
        num_questions: int,
        summary_prefix: str,
        model: LIModelWrapper,
        max_tokens: int = 200,
    ) -> "InterviewNode":
        """Return the child node for ``question``, creating it on first request.

        A question that was already asked from this node returns the existing
        child without querying the model again.

        Args:
            question: Question text; also the key in ``children``.
            question_id: Non-zero id stored on the new child.
            num_questions: Stored on the new child.
            summary_prefix: Forwarded to ``Candidate.basic_skill_query``.
            model: Wrapper used to produce the answer.
            max_tokens: Forwarded to ``Candidate.basic_skill_query``.

        Raises:
            AssertionError: If ``question_id`` is 0.
        """
        assert question_id != 0, "question_id must be greater than 0"

        if question not in self.children:
            response_candidate = self.candidate.basic_skill_query(
                self.get_question_sequence(question),
                summary_prefix,
                model,
                max_tokens=max_tokens,
            )
            new_node = InterviewNode(
                parent=self,
                children={},
                question=question,
                question_id=question_id,
                num_questions=num_questions,
                candidate=response_candidate,
                _node_tensor=None,
            )
            self.children[question] = new_node
        return self.children[question]

    def get_root(self) -> "InterviewNode":
        """Follow parent links and return the node that has no parent."""
        current_node = self
        while current_node.parent is not None:
            current_node = current_node.parent
        return current_node

    def get_y(
        self,
        target_skill: str,
        device: str,
    ) -> torch.Tensor:
        """Return a float one-hot label of shape ``[2]`` for ``target_skill``.

        Index 1 is set when the *root* candidate lists the skill, index 0
        otherwise.
        """
        score_label = 1 if target_skill in self.get_root().candidate.skills else 0
        score_labels_tensor = torch.tensor([score_label], device=device)
        score_labels_tensor_onehot = torch.nn.functional.one_hot(
            score_labels_tensor,
            num_classes=2,
        )
        return score_labels_tensor_onehot.squeeze(0).float()

    def get_x(
        self,
        output_dim: int,
        device: str,
    ) -> torch.Tensor:
        """Return the model input for this node; alias of ``get_node_tensor``."""
        return self.get_node_tensor(
            output_dim=output_dim,
            device=device,
        )

    def get_node_tensor(
        self,
        output_dim: int,
        device: str,
    ) -> torch.Tensor:
        """Encode the path from the root to this node as a 2-D tensor.

        Each row holds one node of the path (root first): the embedding of
        that node's candidate summary, followed by a one-hot encoding of its
        ``question_id`` over ``self.num_questions + 1`` classes, zero-padded
        on the right up to ``output_dim``. The result is cached on the node
        and reused while the requested width matches.

        Args:
            output_dim: Required number of columns.
            device: Device the tensor is moved to when it is built.

        Returns:
            Tensor of shape ``[path_length, output_dim]``.

        Raises:
            AssertionError: If a node on the path lacks a summary or a
                ``question_id``.
            ValueError: If the unpadded rows are wider than ``output_dim``.
        """
        cached = self._node_tensor
        # Reuse the cache only when it has the requested width; a tensor
        # built for a different ``output_dim`` would not fit the caller.
        if cached is not None and cached.shape[1] == output_dim:
            return cached

        answer_vectors = []
        current_node = self
        while current_node is not None:
            assert current_node.candidate is not None, "Candidate must be set"
            assert current_node.candidate.summary is not None, (
                "Candidate summary must be set"
            )
            summary_vector = get_embedding(
                current_node.candidate.summary,
            )

            assert current_node.question_id is not None, "Question must be set"
            # one-hot encoding of question_id
            question_vector = (
                torch.nn.functional.one_hot(
                    torch.tensor([current_node.question_id]),
                    num_classes=self.num_questions + 1,
                )
                .float()
                .to(summary_vector.device)
            )
            answer_vectors.append(
                torch.cat(
                    [
                        summary_vector,
                        question_vector.squeeze(0),
                    ],
                    dim=-1,
                )
            )
            current_node = current_node.parent
        answer_vectors.reverse()

        # Build into a local and publish to the cache only once the tensor is
        # complete. Caching before padding left a wrong-width tensor behind
        # whenever padding failed, and later calls returned it silently.
        node_tensor = torch.stack(answer_vectors, dim=0).to(device)
        feature_dim = node_tensor.shape[1]
        if feature_dim > output_dim:
            raise ValueError(
                f"Node features have width {feature_dim}, which exceeds "
                f"output_dim={output_dim}"
            )
        if feature_dim < output_dim:
            # pad to output_dim
            padding = torch.zeros(
                (node_tensor.shape[0], output_dim - feature_dim),
                device=device,
            )
            node_tensor = torch.cat([node_tensor, padding], dim=1)
        self._node_tensor = node_tensor
        return node_tensor

    def previous_questions(self) -> list[int]:
        """Return the ``question_id`` of every ancestor that has a question, oldest first.

        This node's own ``question_id`` is not included.

        Raises:
            AssertionError: If an ancestor without a question does not have
                ``question_id == 0``.
        """
        questions = []
        current_node = self
        while current_node.parent is not None:
            if current_node.parent.question is None:
                assert current_node.parent.question_id == 0, (
                    "question_id must be 0 for root node"
                )
            else:
                questions.append(current_node.parent.question_id)
            current_node = current_node.parent
        questions.reverse()
        return questions


class AttentionModule(torch.nn.Module):
    """Transformer-encoder stack with a learned per-position weighting.

    Args:
        dim: Feature width of the input and output.
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability inside the encoder layers.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_layers: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.head_dim = dim // num_heads

        # Create a stack of transformer encoder layers
        encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=num_heads,
            dim_feedforward=dim * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=num_layers
        )

        # Projects each encoded position to a single logit for the weights.
        self.attn_weights_proj = torch.nn.Linear(dim, 1)

    def forward(
        self,
        input_sequence: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a sequence and compute one normalised weight per position.

        Args:
            input_sequence: ``[batch_size, sequence_length, dim]``, or
                ``[sequence_length, dim]`` for a single unbatched sequence.

        Returns:
            ``(output, weights)``. ``output`` has the shape of the input;
            ``weights`` has that shape without the trailing ``dim`` axis and
            sums to 1 along the sequence axis.
        """
        # output has the same shape as input_sequence
        output = self.transformer_encoder(input_sequence)

        # One logit per position: [..., sequence_length, 1]
        attention_logits = self.attn_weights_proj(output)
        # Normalise over the sequence axis. Counting from the end (-2) selects
        # that axis for batched and unbatched input alike; a fixed ``dim=1``
        # would hit the size-1 logit axis for unbatched input and yield
        # weights that are all 1.
        attention_weights = torch.softmax(attention_logits, dim=-2)

        return output, attention_weights.squeeze(-1)


class EENNBase(torch.nn.Module):
    """Base module adding state-dict save/load helpers."""

    def save_to_file(
        self,
        file_path: str,
    ) -> None:
        """Write this module's ``state_dict`` to ``file_path``."""
        torch.save(self.state_dict(), file_path)

    def load_from_file(
        self,
        file_path: str,
        # "cuda" is the torch device name for GPUs; "gpu" is not a valid
        # device string, so the old default failed on every CUDA machine.
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        """Load a ``state_dict`` from ``file_path`` and switch to eval mode.

        Args:
            file_path: File previously written by ``save_to_file``.
            device: Device the stored tensors are mapped to. The default is
                chosen once, when the class is defined.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources (or pass weights_only=True where supported).
        self.load_state_dict(torch.load(file_path, map_location=device))
        self.eval()


@dataclasses.dataclass
class question_dist:
    """Distribution over question indices plus the index ranked first.

    The two properties translate indices to skill names through
    ``question_set.idx_to_skill``.
    """

    top_question_int: int
    question_dist_int: dict[int, float]
    question_set: SkillQuestionSet

    @property
    def top_question(self) -> str:
        """Skill name for ``top_question_int``."""
        name_by_index = self.question_set.idx_to_skill
        return name_by_index[self.top_question_int]

    @property
    def question_dist(self) -> dict[str, float]:
        """The distribution re-keyed by skill name, in the same order."""
        named_dist = {}
        for index, probability in self.question_dist_int.items():
            named_dist[self.question_set.idx_to_skill[index]] = probability
        return named_dist


class QuestionSelector(EENNBase):
    """Attention model that outputs a probability for each question in a set."""

    def __init__(
        self,
        question_set: SkillQuestionSet,
    ) -> None:
        """Size the network for ``question_set``.

        The input width is ``EMBED_DIM + len(question_set) + 1`` rounded up
        to a multiple of 8 (the attention module is built with 8 heads).
        """
        super().__init__()

        num_choices = len(question_set)
        unpadded_width = EMBED_DIM + num_choices + 1
        # round up to nearest multiple of 8
        self.input_dim = (unpadded_width + 7) // 8 * 8

        self.attention = AttentionModule(
            dim=self.input_dim,
            num_heads=8,
            num_layers=2,
        )
        self.linear = torch.nn.Linear(self.input_dim, num_choices)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.question_set = question_set

    def forward(
        self,
        response_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Return per-position question probabilities.

        The last axis of ``response_embeddings`` must have width
        ``self.input_dim``; the last axis of the result has one entry per
        question and sums to 1.
        """
        encoded, _unused_weights = self.attention(response_embeddings)
        return self.softmax(self.linear(encoded))

    def get_questions_for_node(
        self,
        node: InterviewNode,
        action_set: SkillQuestionSet,
        device: str,
    ) -> question_dist:
        """Score every question for ``node`` and report the highest-ranked one.

        The node's path tensor (one row per node on the path from the root)
        is run through the model, and the probabilities of the *last* row
        are used.

        Args:
            node: Node to choose the next question for.
            action_set: Question set attached to the returned ``question_dist``.
            device: Device passed to ``node.get_node_tensor``.
        """
        assert node is not None, "No questions for the parent node"
        # path_features: [path_length, self.input_dim]
        path_features = node.get_node_tensor(
            self.input_dim,
            device=device,
        )
        # probs: [path_length, num_questions]
        probs = self(path_features).detach().cpu().numpy()

        num_columns = probs.shape[1]
        question_dist_int = {
            column: float(probs[-1, column]) for column in range(num_columns)
        }
        # Stable descending sort: among equal probabilities the lowest index
        # stays first.
        ranked_indices = sorted(
            question_dist_int,
            key=question_dist_int.__getitem__,
            reverse=True,
        )
        return question_dist(
            top_question_int=ranked_indices[0],
            question_dist_int=question_dist_int,
            question_set=action_set,
        )


@dataclasses.dataclass
class score_dist:
    """Pair of scores for the "true" and "false" outcomes."""

    score_true: float
    score_false: float

    def to_z_score(self) -> float:
        """Return twice the distance of ``score_true`` from 0.5."""
        distance_from_even = abs(self.score_true - 0.5)
        return distance_from_even * 2

    def get_corrected_score(
        self,
        class_weights: dict[int, float],
        multipliers: dict[int, float],
    ) -> "score_dist":
        """Return a new ``score_dist`` built from the corrected true score.

        ``score_true`` is passed with ``class_weights`` and ``multipliers``
        to the module-level ``get_corrected_score``; the false score of the
        result is one minus the corrected value.
        """
        adjusted_true = get_corrected_score(
            self.score_true,
            class_weights,
            multipliers,
        )
        adjusted_false = 1 - adjusted_true
        return score_dist(score_true=adjusted_true, score_false=adjusted_false)



class ScoreFunction(EENNBase):
    """Attention model that outputs a two-way (false/true) score for a node."""

    def __init__(
        self,
        question_set: SkillQuestionSet,
    ) -> None:
        """Size the network for ``question_set``.

        The input width is ``EMBED_DIM + len(question_set) + 1`` rounded up
        to a multiple of 8 (the attention module is built with 8 heads).
        """
        super().__init__()
        input_dim = EMBED_DIM + len(question_set) + 1
        # round up to nearest multiple of 8
        self.input_dim = (input_dim + 7) // 8 * 8

        self.attention = AttentionModule(
            dim=self.input_dim,
            num_heads=8,
            num_layers=2,
        )

        self.linear = torch.nn.Linear(self.input_dim, 2)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.question_set = question_set

    def forward(
        self,
        input_sequence: torch.Tensor,
    ) -> torch.Tensor:
        """Return per-position probabilities over the two classes.

        The last axis of ``input_sequence`` must have width
        ``self.input_dim``; the last axis of the result has size 2 and sums
        to 1.
        """
        # attention_output: same shape as input_sequence
        attention_output, _ = self.attention(input_sequence)
        # logits: [..., num_responses, 2]
        logits = self.linear(attention_output)
        # probabilities: [..., num_responses, 2]
        probabilities = self.softmax(logits)
        return probabilities

    def get_score_for_node(
        self,
        node: InterviewNode,
        device: str,
    ) -> score_dist:
        """Score ``node`` from the first row of the model output.

        Args:
            node: Node whose path tensor is scored.
            device: Device passed to ``node.get_node_tensor``.

        Returns:
            ``score_dist`` with ``score_true`` taken from class index 1 and
            ``score_false`` from class index 0.
        """
        node_tensor = node.get_node_tensor(
            self.input_dim,
            device,
        )
        # node_tensor: [num_responses, self.input_dim]
        resp = self(node_tensor)
        # resp: [num_responses, 2]

        # Select the first row explicitly. The previous ``squeeze(0)``
        # collapsed a one-row result (a root node) to shape [2], so the
        # following ``[0]`` produced a scalar and indexing it raised
        # IndexError.
        # NOTE(review): QuestionSelector reads the *last* row, this reads the
        # first; kept unchanged for multi-row inputs -- confirm intent.
        first_row = resp.detach().reshape(-1, 2)[0].cpu().numpy()
        return score_dist(
            score_true=float(first_row[1]),
            score_false=float(first_row[0]),
        )


def get_embedding(
    target_text: str,
) -> torch.Tensor:
    """Embed ``target_text`` with the shared sentence-transformer model.

    The model named by ``EMBED_MODEL`` is created on the first call and kept
    in the module-level ``_EMBED_MODEL`` for later calls.

    Returns:
        The embedding as a tensor (``convert_to_tensor=True``).
    """
    global _EMBED_MODEL
    encoder = _EMBED_MODEL
    if encoder is None:
        encoder = sentence_transformers.SentenceTransformer(
            EMBED_MODEL,
        )
        _EMBED_MODEL = encoder
    assert encoder is not None, "Embedding model not initialized"
    embedding = encoder.encode(target_text, convert_to_tensor=True)
    return embedding
