from typing import List, Dict, Any, Optional, Tuple
from algorithm.domain import Domain
from torch_geometric.data import Data
import random
import torch
import numpy as np
import ipdb
class Env:
    """Search state for a constraint-satisfaction / optimization problem.

    Holds the current partial assignment together with the (possibly
    forward-checked) domain of every variable, and offers helpers used by a
    search/RL agent: applying an action, enumerating candidate actions,
    scoring a state, and encoding it as a variable/constraint graph.

    Attributes:
        assignments: Mapping ``variable -> assigned value``.
        domains: Mapping ``variable -> Domain``; values must support
            ``copy_domain()``, ``len()`` and iteration.
        constraints: List of ``(constraint, variables)`` pairs.
        vconstraints: Mapping ``variable -> [(constraint, variables), ...]``
            for the constraints each variable participates in.
        var_num: Number of variables (``len(domains)``).
        done: Set to True by ``get_actions`` when the chosen variable has no
            value that survives forward checking (a dead end).
    """

    def __init__(self, assignments: Dict, domains: Dict, constraints: List[Tuple[Any, List]],
                 vconstraints: Dict[Any, List[Tuple[Any, List]]]):
        self.assignments = assignments
        self.domains = domains
        self.constraints = constraints
        self.vconstraints = vconstraints
        self.var_num = len(self.domains)
        self.done = False

    def copy(self) -> 'Env':
        """Return an independent copy suitable for trial moves.

        Assignments are shallow-copied and every domain is copied via
        ``copy_domain()``; the constraint lists are shared (read-only here).
        The ``done`` flag is not carried over (the copy starts with False).
        """
        new_assignments = self.assignments.copy()
        new_domains = {v: d.copy_domain() for v, d in self.domains.items()}
        return Env(new_assignments, new_domains, self.constraints, self.vconstraints)

    def apply_action(self, var: Any, value: Any) -> bool:
        """Assign ``value`` to ``var`` and forward-check its constraints.

        On success the assignment is kept and the domains of still-unassigned
        neighbours may have been pruned by the constraints. On failure the
        assignment and all domains are rolled back to their previous state.

        Returns:
            True if every constraint of ``var`` that still involves an
            unassigned variable accepted the forward check, else False.
        """
        had_previous = var in self.assignments
        previous_value = self.assignments.get(var)
        # Snapshot all domains: forward checking may prune neighbours in place.
        original_domains = {v: d.copy_domain() for v, d in self.domains.items()}
        self.assignments[var] = value
        self.domains[var] = Domain({value})

        for constraint, variables in self.vconstraints.get(var, []):
            # NOTE(review): constraints whose variables are all assigned are not
            # re-checked here; this relies on earlier forward checking having
            # pruned inconsistent values. Confirm that this is intended.
            if any(v not in self.assignments for v in variables):
                if not constraint(variables, self.domains, self.assignments, forwardcheck=True):
                    # Restore the prior assignment (if any) instead of always
                    # dropping the key.
                    if had_previous:
                        self.assignments[var] = previous_value
                    else:
                        del self.assignments[var]
                    self.domains = original_domains
                    return False
        return True

    def is_terminal(self) -> bool:
        """Return True when every variable has been assigned."""
        return len(self.assignments) == len(self.domains)

    def is_valid(self) -> bool:
        """Return True if every fully-assigned constraint is satisfied.

        Constraints with at least one unassigned variable are ignored. The
        first violated constraint is printed before returning False.
        """
        for constraint, variables in self.constraints:
            if all(v in self.assignments for v in variables):
                if not constraint(variables, self.domains, self.assignments):
                    print(f"is_valid failed: constraint {variables} not satisfied, assignments={self.assignments}")
                    return False
        return True

    def get_unassigned_vars(self) -> List[Any]:
        """Return the variables without an assignment, in ``domains`` order."""
        return [v for v in self.domains if v not in self.assignments]

    def get_constraint_satisfaction(self) -> float:
        """Return the fraction of fully-assigned constraints that are satisfied.

        Returns 0.0 when no constraint is fully assigned yet.
        """
        satisfied = 0
        total = 0
        for constraint, variables in self.constraints:
            if all(v in self.assignments for v in variables):
                total += 1
                if constraint(variables, self.domains, self.assignments):
                    satisfied += 1
        return satisfied / total if total > 0 else 0.0

    def get_state_representation(self, device: Any = None) -> 'Data':
        """Encode the state as a ``torch_geometric`` ``Data`` graph.

        Args:
            device: Target device. When None, ``cuda:1`` is used if it exists
                (the historical hard-coded choice); otherwise the default CUDA
                device, otherwise the CPU.

        Returns:
            ``Data`` with ``x`` (variable features), ``c`` (constraint
            features), ``edge_index`` of shape ``(2, num_edges)`` and
            ``var_num``, moved to ``device``.
        """
        variable_feature = self.get_variable_feature()
        constraint_feature = self.get_constraint_feature()
        edge_index = self.get_edges()
        if not isinstance(variable_feature, torch.Tensor):
            variable_feature = torch.tensor(variable_feature, dtype=torch.float)
        if not isinstance(constraint_feature, torch.Tensor):
            constraint_feature = torch.tensor(constraint_feature, dtype=torch.float)
        if not isinstance(edge_index, torch.Tensor):
            # view(-1, 2) keeps an empty edge list as (0, 2) -> (2, 0) after
            # transposing; PyG expects a contiguous (2, num_edges) tensor.
            edge_index = torch.tensor(edge_index, dtype=torch.long).view(-1, 2).t().contiguous()
        state = Data(x=variable_feature, c=constraint_feature, edge_index=edge_index, var_num=self.var_num)

        if device is None:
            if torch.cuda.is_available() and torch.cuda.device_count() > 1:
                device = torch.device('cuda:1')
            elif torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

        return state.to(device)

    def get_variable_feature(self):
        """Return one ``(domain_size, is_assigned)`` tuple per variable.

        Assumes variables are the integers ``0 .. var_num - 1`` (they are
        looked up by ``range(var_num)``) — TODO confirm for all callers.
        """
        variable_feature = []
        for variable in range(self.var_num):
            domain_len = len(self.domains[variable])
            is_assigned = (variable in self.assignments)
            variable_feature.append((domain_len, is_assigned))
        return variable_feature

    def get_constraint_feature(self):
        """Return one ``(ddeg, num_variables, min_cost)`` tuple per constraint."""
        constraint_feature = []
        for constraint, variables in self.constraints:
            ddeg = constraint.get_ddeg(variables, self.domains)
            num_variable = len(variables)
            cost = constraint.get_min_cost()
            constraint_feature.append((ddeg, num_variable, cost))
        return constraint_feature

    def get_edges(self):
        """Return bipartite edges ``(variable, var_num + constraint_index)``.

        Constraint ``i`` becomes graph node ``var_num + i``, and each of its
        variables is linked to that node, in the order the variables are listed.
        """
        edge_index = []
        for index, (_, variables) in enumerate(self.constraints):
            constraint_node = self.var_num + index
            edge_index.extend((var, constraint_node) for var in variables)
        return edge_index

    def get_actions(self, policy, epsilon=0.1, max_values: Optional[int] = None) -> List[Tuple[Any, Any]]:
        """Pick one unassigned variable and list its viable ``(var, value)`` actions.

        Variable selection:
          * ``policy is None``: smallest current domain (MRV heuristic);
          * otherwise epsilon-greedy: a random unassigned variable with
            probability ``epsilon``, else the one with the largest ``policy[v]``.

        Each value of the chosen variable is tried on a copy of the state; the
        values that pass forward checking are returned, sorted by descending
        constraint-satisfaction score. If none survives, ``self.done`` is set.

        Args:
            policy: Indexable by variable (score per variable), or None.
            epsilon: Exploration probability when ``policy`` is given.
            max_values: Optional cap on the number of returned actions
                (None keeps all of them).
        """
        unassigned = self.get_unassigned_vars()
        if not unassigned:
            return []
        if policy is None:
            var = min(unassigned, key=lambda v: len(self.domains[v]))
        elif random.random() < epsilon:
            var = random.choice(unassigned)
        else:
            var = max(unassigned, key=lambda v: policy[v])

        valid_values = []
        for value in self.domains[var]:
            temp_env = self.copy()
            if temp_env.apply_action(var, value):
                valid_values.append((value, temp_env.get_constraint_satisfaction()))

        if not valid_values:
            self.done = True
        values = sorted(valid_values, key=lambda x: x[1], reverse=True)
        if max_values is not None:
            values = values[:max_values]
        return [(var, value) for value, _ in values]

    def get_gap(self):
        """Return the optimality gap of the current state.

        Returns 1 if any fully-assigned constraint is violated; otherwise the
        sum of ``constraint.get_gap(...)`` over fully-assigned constraints,
        each weighted by ``1 / len(self.constraints)``.
        """
        if not self.is_valid():
            return 1
        gap = 0
        for constraint, variables in self.constraints:
            if all(v in self.assignments for v in variables):
                gap += 1 / len(self.constraints) * constraint.get_gap(variables, self.domains, self.assignments)
        return gap

    def __str__(self) -> str:
        assignments_str = str(self.assignments) if self.assignments else "empty"
        return f"Assignments={assignments_str}"