import numpy as np
import json
import os
import asyncio
from tqdm.asyncio import tqdm_asyncio
from pathlib import Path
from langchain.schema import SystemMessage
from langchain_community.chat_models import ChatOpenAI
import argparse
import random

from pathlib import Path
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

# Seconds a single LLM request may run before api_call_with_retry's
# asyncio.wait_for raises a timeout (which the retry policy then retries).
DEFAULT_API_TIMEOUT = 30

@retry(
    stop=stop_after_attempt(10),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type(asyncio.TimeoutError),
)
async def api_call_with_retry(llm, system_message, timeout=DEFAULT_API_TIMEOUT):
    """Send one system message to ``llm`` in JSON mode, bounded by ``timeout`` seconds.

    A timeout is logged and re-raised so the tenacity policy above retries it
    (at most 10 attempts, exponential backoff clamped to 2-10 seconds). Any
    other exception is logged and re-raised; it is not retried.

    Returns:
        Whatever ``llm.agenerate`` resolves to; callers read
        ``.generations[0][0].text`` from it.
    """
    json_mode = {"type": "json_object"}
    try:
        pending_call = llm.agenerate([[system_message]], response_format=json_mode)
        response = await asyncio.wait_for(pending_call, timeout=timeout)
    except asyncio.TimeoutError:
        print(f"API call timed out after {timeout}s, retrying...")
        raise
    except Exception as e:
        print(f"API call failed with error: {e}")
        raise
    return response
    
# Export the API credentials from config/api_info.json as environment
# variables (presumably picked up by ChatOpenAI -- confirm against the
# langchain client configuration). read_text() closes the file after reading.
personal_info = json.loads(Path("config/api_info.json").read_text(encoding="utf-8"))
os.environ["OPENAI_API_KEY"] = personal_info["api_key"]
os.environ["OPENAI_ORGANIZATION"] = personal_info["org_id"]

# Upper bound on concurrently running LLM requests; acquired by
# async_propose_factors and assign_feature_to_factor.
# NOTE(review): on Python < 3.10 an asyncio.Semaphore created at import time
# binds to get_event_loop(), not the loop asyncio.run() creates later --
# confirm the target Python version.
semaphore = asyncio.Semaphore(256)

def load_prompt(file_path):
    """Return the entire contents of the UTF-8 text file at ``file_path``."""
    with open(file_path, encoding="utf-8") as prompt_file:
        template_text = prompt_file.read()
    return template_text

async def async_propose_factors(llm, features, previous_factors=None, num_factors=10, is_later_iteration=False, uncovered_count=0, random_seed=42, sample_size=32):
    """Ask the LLM to propose candidate factor names for a set of features.

    A reproducible sample of at most ``sample_size`` features (seeded by
    ``random_seed``) is rendered into ``prompt/12_propose_factors.txt``
    together with up to four previously identified factors.

    Args:
        llm: chat model handed to ``api_call_with_retry``.
        features: list of feature dicts; ``feature_name`` and ``context`` are read.
        previous_factors: factor names found so far; the first four are shown.
        num_factors: number of factors requested and maximum number returned.
        is_later_iteration: when True and ``uncovered_count > 0``, an extra
            instruction asks for new factors covering the remaining features.
        uncovered_count: number of still-uncovered features quoted in that
            instruction.
        random_seed: seed for the feature sample.
        sample_size: maximum number of features shown to the model.

    Returns:
        Up to ``num_factors`` factor names. Names containing "general",
        "dimension", "cluster" or "factor" (case-insensitive), names shorter
        than 3 non-blank characters, and non-string entries are dropped. An
        empty list is returned when the reply is not a JSON object with a
        ``factors`` list.
    """
    async with semaphore:
        # Private RNG: same sample as random.seed(seed); random.sample(...),
        # but without reseeding the process-wide `random` module.
        rng = random.Random(random_seed)
        sampled_features = rng.sample(features, min(sample_size, len(features)))

        feature_examples = "\n\n".join(
            f"Feature {i}: {feature.get('feature_name')} - Context: {feature.get('context')}"
            for i, feature in enumerate(sampled_features, start=1)
        )

        prev_factors = ""
        if previous_factors:
            prev_factors = "Previously identified factors:\n" + "\n".join(
                f"- {factor}" for factor in previous_factors[:4]
            )

        iteration_focus = ""
        if is_later_iteration and uncovered_count > 0:
            iteration_focus = f"\nIMPORTANT: There are still {uncovered_count} features that don't fit into existing factors. Please create NEW and DIFFERENT factors specifically designed to cover these remaining features. These factors MUST be meaningful, descriptive names (not generic placeholders). Avoid proposing factors similar to existing ones."

        system_prompt = load_prompt("prompt/12_propose_factors.txt").format(
            num_factors=num_factors,
            feature_examples=feature_examples,
            prev_factors=prev_factors
        ) + iteration_focus

        system_message = SystemMessage(content=system_prompt)
        response = await api_call_with_retry(llm, system_message)
        result = response.generations[0][0].text.strip()

    # A malformed reply should not abort the whole run; the caller already
    # retries when this function returns no factors.
    try:
        json_result = json.loads(result)
    except json.JSONDecodeError as err:
        print(f"Could not parse factor proposal as JSON: {err}")
        return []

    factors = json_result.get("factors", []) if isinstance(json_result, dict) else []
    if not isinstance(factors, list):
        factors = []

    generic_markers = ("general", "dimension", "cluster", "factor")
    meaningful_factors = []
    for factor in factors:
        if not isinstance(factor, str):
            continue
        lowered = factor.lower()
        if any(marker in lowered for marker in generic_markers):
            continue
        if len(factor.strip()) < 3:
            continue
        meaningful_factors.append(factor)

    return meaningful_factors[:num_factors]

def _parse_assignment_index(value, num_factors):
    """Return the 0-based factor index named by an LLM ``assignments`` value, or None.

    Accepts a 1-based int, a numeric string (only the first entry of a
    comma-separated list is used), or a list whose first element is either.
    Booleans, non-numeric text, empty values and out-of-range numbers give None.
    """
    if isinstance(value, list):
        if not value:
            return None
        value = value[0]
    if isinstance(value, bool):
        # bool is a subclass of int; JSON true/false is never a factor number.
        return None
    if isinstance(value, int):
        number = value
    elif isinstance(value, str):
        first_entry = value.replace(" ", "").split(",")[0]
        try:
            number = int(first_entry)
        except ValueError:
            return None
    else:
        return None
    if 1 <= number <= num_factors:
        return number - 1
    return None


async def assign_feature_to_factor(llm, feature, factors):
    """Ask the LLM which single factor in ``factors`` best fits ``feature``.

    Renders ``prompt/12_assign_to_factors.txt`` with the feature's name (and
    context, when present) and a 1-based numbered list of factors.

    Args:
        llm: chat model handed to ``api_call_with_retry``.
        feature: feature dict; ``feature_name`` and optional ``context`` are read.
        factors: list of factor names.

    Returns:
        A list of ``len(factors)`` ints with at most one 1, marking the first
        factor named under the reply's ``assignments`` key. All zeros when the
        reply is unparseable or names no valid factor.
    """
    async with semaphore:
        formatted_factors = "\n".join(
            f"{number}. {factor}" for number, factor in enumerate(factors, start=1)
        )

        feature_summary = f"Feature: {feature.get('feature_name')}"
        if 'context' in feature:
            feature_summary += f" (Context: {feature['context']})"

        system_prompt = load_prompt("prompt/12_assign_to_factors.txt").format(
            feature=feature_summary,
            formatted_factors=formatted_factors
        )

        system_message = SystemMessage(content=system_prompt)
        response = await api_call_with_retry(llm, system_message)
        result = response.generations[0][0].text.strip()

    assignments = [0] * len(factors)

    # One malformed reply must not abort the gather over all features; an
    # all-zero row simply leaves this feature uncovered.
    try:
        json_result = json.loads(result)
    except json.JSONDecodeError as err:
        print(f"Could not parse factor assignment as JSON: {err}")
        return assignments

    if isinstance(json_result, dict) and "assignments" in json_result:
        idx = _parse_assignment_index(json_result["assignments"], len(factors))
        if idx is not None:
            assignments[idx] = 1

    return assignments

def select_optimal_factors(factor_assignment_matrix, overlap_penalty=0.1, not_cover_penalty=1.0, preferred_indices=None, iteration=0, new_indices=None):
    """Greedily choose factor columns that cover the features with little overlap.

    Args:
        factor_assignment_matrix: 2-D array of shape (n_features, n_factors)
            with 0/1 entries marking which factor each feature was assigned to.
        overlap_penalty: weight on features covered by more than one selected factor.
        not_cover_penalty: weight on features left uncovered by the selection.
        preferred_indices: columns taken first, in order, while they add new coverage.
        iteration: 0-based round number. The selection budget is
            ``min(8 + 2 * iteration, n_factors)``. After the first round the
            overlap weight is halved and the not-covered weight scaled by 0.8.
        new_indices: on rounds after the first, columns taken before anything
            else as long as they cover at least one feature.

    Returns:
        List of selected column indices in selection order (empty when the
        matrix has no columns).
    """
    n, d = factor_assignment_matrix.shape

    if d == 0:
        print("Warning: No factors to select from!")
        return []

    num_factors = min(8 + (iteration * 2), d)

    # Rounds after the first soften both penalties.
    if iteration == 0:
        overlap_weight, not_cover_weight = overlap_penalty, not_cover_penalty
    else:
        overlap_weight, not_cover_weight = overlap_penalty * 0.5, not_cover_penalty * 0.8

    selected = []
    current_coverage = np.zeros(n, dtype=int)

    if iteration > 0 and new_indices:
        print(f"Prioritizing {len(new_indices)} new factors for iteration {iteration+1}")
        for idx in new_indices:
            if idx >= d or idx in selected:
                continue
            if np.sum(factor_assignment_matrix[:, idx]) > 0:
                selected.append(idx)
                current_coverage += factor_assignment_matrix[:, idx]
                if len(selected) >= num_factors:
                    return selected

    if preferred_indices:
        for idx in preferred_indices:
            if idx >= d or idx in selected:
                continue
            uncovered_mask = 1 - np.minimum(current_coverage, 1)
            new_features_covered = np.sum(np.minimum(factor_assignment_matrix[:, idx], uncovered_mask))
            if new_features_covered > 0:
                selected.append(idx)
                current_coverage += factor_assignment_matrix[:, idx]
                if len(selected) >= num_factors:
                    return selected

    while len(selected) < num_factors:
        uncovered_mask = 1 - np.minimum(current_coverage, 1)
        best_score = float('-inf')
        best_idx = -1

        for j in range(d):
            if j in selected:
                continue
            column = factor_assignment_matrix[:, j]
            new_coverage = current_coverage + column
            new_features_covered = np.sum(np.minimum(column, uncovered_mask))
            overlap_count = np.sum(np.maximum(0, new_coverage - 1))
            not_covered_count = n - np.sum(np.minimum(new_coverage, 1))

            # Reward newly covered features; both penalty terms must LOWER the
            # score. (`score -= a*overlap - b*not_covered` flipped the sign of
            # the not-covered term and rewarded leaving features uncovered.)
            score = (
                2 * new_features_covered
                - overlap_weight * overlap_count
                - not_cover_weight * not_covered_count
            )
            if score > best_score:
                best_score = score
                best_idx = j

        if best_idx == -1:
            # Only reachable when no score beats -inf (e.g. infinite or NaN penalties).
            if not selected:
                coverage_counts = np.sum(factor_assignment_matrix, axis=0)
                if np.max(coverage_counts) > 0:
                    best_idx = int(np.argmax(coverage_counts))
                    selected.append(best_idx)
                    current_coverage += factor_assignment_matrix[:, best_idx]
                    print(f"Using fallback selection: picking factor with maximum coverage ({coverage_counts[best_idx]} features)")
                    continue
            break

        selected.append(best_idx)
        current_coverage += factor_assignment_matrix[:, best_idx]

    if not selected:
        coverage_counts = np.sum(factor_assignment_matrix, axis=0)
        top_count = min(5, d)
        top_indices = np.argsort(coverage_counts)[-top_count:]
        selected = [idx for idx in top_indices if coverage_counts[idx] > 0]

    return selected

def prune_factors(factors, factor_assignment_matrix, min_factor_fraction=0.01, max_factor_fraction=0.5, iteration=0):
    """Drop factors whose feature count falls outside the allowed size band.

    The band is ``[max(1, int(n * min_factor_fraction)), int(n * max_factor_fraction)]``.
    From the second round on, the lower bound is 1. If fewer than 10 factors
    survive in those rounds, every other factor covering at least one feature
    is added back. If nothing survives at all, up to four of the
    highest-coverage factors with non-zero coverage are kept.

    Args:
        factors: factor names, one per column of ``factor_assignment_matrix``.
        factor_assignment_matrix: (n_features, n_factors) 0/1 assignment array.
        min_factor_fraction, max_factor_fraction: size band as fractions of n.
        iteration: 0-based round number.

    Returns:
        ``(kept_factor_names, matrix_restricted_to_kept_columns)``.
    """
    num_rows = factor_assignment_matrix.shape[0]
    column_sizes = [np.sum(factor_assignment_matrix[:, col]) for col in range(len(factors))]

    lower_bound = 1 if iteration > 0 else max(1, int(num_rows * min_factor_fraction))
    upper_bound = int(num_rows * max_factor_fraction)

    kept_columns = [
        col for col, size in enumerate(column_sizes)
        if lower_bound <= size <= upper_bound
    ]

    if iteration > 0 and len(kept_columns) < 10:
        already_kept = set(kept_columns)
        kept_columns += [
            col for col, size in enumerate(column_sizes)
            if col not in already_kept and size > 0
        ]

    if not kept_columns:
        largest_first_four = np.argsort(column_sizes)[-4:]
        kept_columns = [col for col in largest_first_four if column_sizes[col] > 0]

    kept_factors = [factors[col] for col in kept_columns]
    if kept_columns:
        pruned_matrix = factor_assignment_matrix[:, kept_columns]
    else:
        pruned_matrix = np.zeros((num_rows, 0))

    print(f"Pruned from {len(factors)} to {len(kept_factors)} factors")
    return kept_factors, pruned_matrix

async def run_iterative_factorization(features, max_rounds=3, min_factor_fraction=0.0, max_factor_fraction=0.4, proposer_num_factors=16, overlap_penalty=4.0, not_cover_penalty=8.0, model_name="gpt-4o-mini", stop_criteria=None, random_seed=42, coverage_threshold=0.9):
    """Group one user's features into LLM-proposed factors over several rounds.

    Each round:
    1. Proposes candidate factor names from a sample of (mostly uncovered) features.
    2. Asks the LLM to assign every feature to at most one new candidate.
    3. Prunes candidates by coverage size.
    4. Greedily selects a factor set.

    Rounds stop early once ``coverage_threshold`` of the features is covered,
    or once fewer than ``stop_criteria`` remain uncovered. A final pass asks
    the LLM to place any still-uncovered features into the selected factors.

    Args:
        features: list of feature dicts passed to the propose/assign prompts.
        max_rounds: maximum number of propose/assign/select rounds.
        min_factor_fraction, max_factor_fraction: size band for ``prune_factors``.
        proposer_num_factors: number of factors requested per proposal.
        overlap_penalty: selection weight for ``select_optimal_factors``. It is
            halved (cumulatively) on each round after the first while more
            than 30% of the features are uncovered.
        not_cover_penalty: selection weight for ``select_optimal_factors``.
        model_name: chat model name for ``ChatOpenAI``.
        stop_criteria: stop once fewer than this many features are uncovered
            (falsy disables).
        random_seed: seeds numpy, ``random`` and the proposal sampling.
        coverage_threshold: stop once this fraction of features is covered.

    Returns:
        dict with keys:
        - ``selected_factors``: the selected factor names.
        - ``factors``: factor name -> list of assigned feature dicts.
        - ``uncovered_count``: number of features left uncovered.
        - ``uncovered_features``: the uncovered feature dicts.
    """
    np.random.seed(random_seed)
    random.seed(random_seed)

    n = len(features)
    if n == 0:
        # Nothing to factorize. Returning early also avoids a 0/0 coverage
        # ratio and an empty assignment array that cannot be concatenated.
        return {
            "selected_factors": [],
            "factors": {},
            "uncovered_count": 0,
            "uncovered_features": [],
        }

    llm = ChatOpenAI(temperature=0.0, model_name=model_name)

    all_factors = []
    all_feature_factor_matrix = np.zeros((n, 0), dtype=np.int64)
    uncovered_feature_indicators = np.ones(n, dtype=bool)
    # Factor name -> number of rounds in which it was selected; previously
    # selected factors are passed as preferred indices in later rounds.
    selected_factors_memory = {}
    selected_factors = []
    selected_feature_factor_matrix = np.zeros((n, 0), dtype=np.int64)

    for iteration in range(max_rounds):
        uncovered_count = np.sum(uncovered_feature_indicators)
        covered_fraction = 1.0 - (uncovered_count / n)
        print(f"\n--- Iteration {iteration+1}/{max_rounds} ---")
        print(f"Uncovered features: {uncovered_count}/{n} (Coverage: {covered_fraction:.1%})")

        if covered_fraction >= coverage_threshold:
            print(f"Coverage threshold reached: {covered_fraction:.1%} >= {coverage_threshold:.1%}")
            print(f"Stopping iterations early as {covered_fraction:.1%} of features are now covered")
            break

        if stop_criteria and uncovered_count < stop_criteria:
            print(f"Stop criteria met: fewer than {stop_criteria} features remain uncovered")
            break

        if iteration >= 1 and uncovered_count > 0.3 * n:
            adjusted_overlap_penalty = overlap_penalty * 0.5
            print(f"Many features still uncovered. Adjusting overlap penalty from {overlap_penalty} to {adjusted_overlap_penalty}")
            overlap_penalty = adjusted_overlap_penalty

        if iteration >= 1 and uncovered_count > 0:
            # Later rounds propose factors only from what is still uncovered.
            focus_features = [features[i] for i in range(n) if uncovered_feature_indicators[i]]
            print(f"Focusing exclusively on {len(focus_features)} uncovered features in iteration {iteration+1}")
        else:
            # First 32 uncovered features, topped up with covered ones if short.
            focus_features = [features[i] for i in range(n) if uncovered_feature_indicators[i]][:32]
            if len(focus_features) < 32:
                covered_indices = [i for i in range(n) if not uncovered_feature_indicators[i]]
                num_extra = min(32 - len(focus_features), len(covered_indices))
                focus_features.extend(features[i] for i in covered_indices[:num_extra])

        print("Proposing new factors...")

        max_attempts = 3 if iteration > 0 and uncovered_count > 10 else 1
        new_factors = []

        for attempt in range(max_attempts):
            if attempt > 0:
                print(f"Attempt {attempt+1}: Trying again to generate meaningful factors...")

            factors = await async_propose_factors(
                llm,
                focus_features,
                previous_factors=all_factors,
                num_factors=proposer_num_factors,
                is_later_iteration=(iteration > 0),
                uncovered_count=uncovered_count,
                # Shift the seed per attempt so a retry shows the model a
                # different sample/order; at temperature 0 an identical prompt
                # would just reproduce the previous (empty) answer.
                random_seed=random_seed + attempt,
            )

            if factors:
                new_factors = factors
                break

        # Keep only names not seen in earlier rounds, without duplicates.
        filtered_factors = []
        for factor in new_factors:
            if factor not in all_factors and factor not in filtered_factors:
                filtered_factors.append(factor)

        if iteration >= 1 and len(filtered_factors) == 0:
            print("No new meaningful factors generated. Moving to the next iteration.")
            continue

        # Previously selected factors always come from all_factors, so only the
        # freshly proposed names need to be added and assigned.
        new_factors = filtered_factors
        all_factors.extend(new_factors)

        print(f"Generated {len(new_factors)} new unique factors")
        for factor in new_factors:
            print(f"- {factor}")

        if not new_factors:
            print("No new factors generated, moving to next iteration")
            continue

        print("Assigning features to factors...")
        assignment_tasks = [assign_feature_to_factor(llm, feature, new_factors) for feature in features]

        new_assignments = await tqdm_asyncio.gather(*assignment_tasks, desc="Assigning features")
        new_feature_factor_matrix = np.array(new_assignments, dtype=np.int64)

        all_feature_factor_matrix = np.concatenate(
            [all_feature_factor_matrix, new_feature_factor_matrix],
            axis=1
        )

        print("Pruning and selecting factors...")
        pruned_factors, pruned_factor_matrix = prune_factors(
            factors=all_factors,
            factor_assignment_matrix=all_feature_factor_matrix,
            min_factor_fraction=min_factor_fraction,
            max_factor_fraction=max_factor_fraction,
            iteration=iteration
        )

        if pruned_factor_matrix.shape[1] == 0:
            print("No factors left after pruning, moving to next iteration")
            continue

        new_factor_indices = []
        if filtered_factors:
            new_factor_indices = [
                i for i, factor in enumerate(pruned_factors)
                if factor in filtered_factors
            ]
            if new_factor_indices:
                print(f"Found {len(new_factor_indices)} new factors that survived pruning")

        selected_indices = select_optimal_factors(
            factor_assignment_matrix=pruned_factor_matrix,
            overlap_penalty=overlap_penalty,
            not_cover_penalty=not_cover_penalty,
            preferred_indices=[i for i, factor in enumerate(pruned_factors) if factor in selected_factors_memory],
            iteration=iteration,
            new_indices=new_factor_indices
        )

        if not selected_indices:
            print("No factors were selected, moving to next iteration")
            continue

        selected_factors = [pruned_factors[i] for i in selected_indices]
        selected_feature_factor_matrix = pruned_factor_matrix[:, selected_indices]

        for factor in selected_factors:
            selected_factors_memory[factor] = selected_factors_memory.get(factor, 0) + 1

        if selected_feature_factor_matrix.shape[1] > 0:
            uncovered_feature_indicators = (np.max(selected_feature_factor_matrix, axis=1) == 0)
        else:
            uncovered_feature_indicators = np.ones(n, dtype=bool)

        print(f"Selected {len(selected_factors)} factors:")
        for i, factor in enumerate(selected_factors):
            features_in_factor = np.sum(selected_feature_factor_matrix[:, i])
            print(f"- {factor} ({features_in_factor} features)")

        assignments_per_feature = np.sum(selected_feature_factor_matrix, axis=1)
        features_with_multiple = np.sum(assignments_per_feature > 1)
        print(f"Features with multiple assignments: {features_with_multiple}/{n}")

    uncovered_count = np.sum(uncovered_feature_indicators)
    if uncovered_count > 0 and len(selected_factors) > 0:
        print(f"\n--- Final allocation step ---")
        print(f"Assigning {uncovered_count} remaining uncovered features to existing factors")

        uncovered_indices = [i for i in range(n) if uncovered_feature_indicators[i]]
        assignment_tasks = [
            assign_feature_to_factor(llm, features[i], selected_factors)
            for i in uncovered_indices
        ]

        final_assignments = await tqdm_asyncio.gather(*assignment_tasks, desc="Assigning remaining features")

        for idx, assignments in zip(uncovered_indices, final_assignments):
            if sum(assignments) > 0:
                selected_feature_factor_matrix[idx] = assignments
                uncovered_feature_indicators[idx] = False

        newly_covered = uncovered_count - np.sum(uncovered_feature_indicators)
        final_covered_fraction = 1.0 - (np.sum(uncovered_feature_indicators) / n)
        print(f"Final allocation covered {newly_covered}/{uncovered_count} remaining features")
        print(f"Final coverage: {final_covered_fraction:.1%}")

    final_factors = {}
    for i, factor in enumerate(selected_factors):
        final_factors[factor] = [features[j] for j in range(n) if selected_feature_factor_matrix[j, i] == 1]

    result = {
        "selected_factors": selected_factors,
        "factors": final_factors,
        "uncovered_count": int(np.sum(uncovered_feature_indicators)),
        "uncovered_features": [features[i] for i in range(n) if uncovered_feature_indicators[i]]
    }

    return result

async def process_user_data(input_data, args):
    """Run the iterative factorization for each user in ``input_data``.

    Args:
        input_data: list of user records. Each record is read for ``user_id``
            and a ``profile`` list of items; each item is read for ``item_id``
            and a ``feature`` list of feature dicts. Missing or null
            ``profile``/``feature`` lists are treated as empty.
        args: parsed CLI namespace. ``num_users`` (0 = all) limits how many
            users are processed; the remaining factorization settings are
            forwarded to ``run_iterative_factorization``.

    Returns:
        One ``{"user_id": ..., "factorization": ...}`` dict per user that had
        at least one feature; users without features are skipped.

    Note:
        Every feature dict is annotated in place with ``user_id`` and ``item_id``.
    """
    results = []

    if args.num_users > 0:
        input_data = input_data[:args.num_users]

    for user_index, user in enumerate(input_data):
        user_id = user.get('user_id')
        print(f"\n\nProcessing user {user_id} ({user_index+1}/{len(input_data)})")

        all_features = []
        # Tolerate missing/null lists, consistent with the .get() lookups used
        # for 'profile' and 'item_id'; one malformed item should not abort the
        # whole run (results are only written once every user is processed).
        for item in user.get('profile') or []:
            for feature in item.get('feature') or []:
                feature['user_id'] = user_id
                feature['item_id'] = item.get('item_id')
                all_features.append(feature)

        print(f"Found {len(all_features)} individual features for user {user_id}")

        if not all_features:
            print(f"No features found for user {user_id}, skipping...")
            continue

        user_result = await run_iterative_factorization(
            features=all_features,
            max_rounds=args.max_rounds,
            min_factor_fraction=args.min_factor_fraction,
            max_factor_fraction=args.max_factor_fraction,
            proposer_num_factors=args.proposer_factors,
            overlap_penalty=args.overlap_penalty,
            not_cover_penalty=args.not_cover_penalty,
            model_name=args.model_name,
            stop_criteria=args.stop_criteria,
            coverage_threshold=args.coverage_threshold
        )

        results.append({
            "user_id": user_id,
            "factorization": user_result
        })

    return results

async def main():
    """
    Main function to parse arguments and run the feature factorization process.

    This function:
    1. Parses command-line arguments
    2. Loads the feature data (UTF-8 JSON) from the input file
    3. Processes user data to extract and factorize features
    4. Saves the results to the specified output file

    Any exception is reported as ``Error: ...`` and then re-raised, so a failed
    run exits with a non-zero status instead of appearing to succeed.
    """
    parser = argparse.ArgumentParser(description="Run iterative factorization on user features")
    parser.add_argument("--num_users", type=int, default=0,
                        help="Number of users to process (0 for all)")
    parser.add_argument("--input_path", type=str, default="result/lamp3_only_feature.json",
                        help="Path to the input file with feature data")
    parser.add_argument("--output_path", type=str, default="result/lamp3_only_factor.json",
                        help="Path to save the factorization results")
    parser.add_argument("--max_rounds", type=int, default=3,
                        help="Maximum number of factorization iterations")
    parser.add_argument("--min_factor_fraction", type=float, default=0.0,
                        help="Minimum fraction of features in a factor")
    parser.add_argument("--max_factor_fraction", type=float, default=0.4,
                        help="Maximum fraction of features in a factor")
    parser.add_argument("--proposer_factors", type=int, default=16,
                        help="Number of factors to propose in each round")
    parser.add_argument("--overlap_penalty", type=float, default=0.1,
                        help="Penalty for features appearing in multiple factors")
    parser.add_argument("--not_cover_penalty", type=float, default=2.0,
                        help="Penalty for features not covered by any factor")
    parser.add_argument("--model_name", type=str, default="gpt-4o-mini",
                        help="Name of the language model to use")
    parser.add_argument("--stop_criteria", type=int, default=16,
                        help="Stop if fewer than this many features remain uncovered")
    # %(default)s keeps the help text in sync with the actual default (0.95).
    parser.add_argument("--coverage_threshold", type=float, default=0.95,
                        help="Stop if this fraction of features is covered (default: %(default)s)")

    args = parser.parse_args()

    try:
        with open(args.input_path, 'r', encoding='utf-8') as f:
            input_data = json.load(f)

        results = await process_user_data(input_data, args)

        output_path = Path(args.output_path)
        output_path.parent.mkdir(exist_ok=True, parents=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to {args.output_path}")

    except Exception as e:
        print(f"Error: {e}")
        # Re-raise so the traceback is shown and the exit status is non-zero.
        raise

# Script entry point: asyncio.run creates an event loop, runs the async CLI
# driver to completion, and closes the loop.
if __name__ == "__main__":
    asyncio.run(main())