                                                               

import json
import argparse
import sys
from datetime import datetime
from pathlib import Path
from collections import Counter, defaultdict
from distribution_similarity import (
    calculate_similarity_by_method,
    SUPPORTED_SIMILARITY_METHODS,
    METHOD_DESCRIPTIONS
)

      
# Defaults and allowed values for the command-line options built in parse_args().
DEFAULT_WINDOW_SIZE = 2  # default for --window_size (counted in --window_unit units)
DEFAULT_THRESHOLD = 0.001  # default for --threshold (divergence above this marks a change)
SUPPORTED_WINDOW_UNITS = ["months", "days"]  # accepted values for --window_unit
DEFAULT_WINDOW_UNIT = "months"  # default for --window_unit


def validate_input_file(file_path):
    """Confirm that the input path exists on disk.

    Args:
        file_path: Path (string or path-like) of the input file.

    Returns:
        True when the path exists.

    Raises:
        FileNotFoundError: If nothing exists at ``file_path``.
    """
    candidate = Path(file_path)
    if candidate.exists():
        return True
    raise FileNotFoundError(f"Input file not found: {file_path}")


def get_output_path(input_file_path):
    """Build the results-file path next to the input file.

    Args:
        input_file_path: Path (string or path-like) of the input file.

    Returns:
        A ``Path`` pointing at ``question_segments_results.json`` inside the
        directory that contains the input file.
    """
    return Path(input_file_path).parent.joinpath("question_segments_results.json")


def parse_args():
    """Parse and validate the command-line arguments.

    Returns:
        argparse.Namespace with the attributes ``input_file``, ``window_size``,
        ``window_unit``, ``threshold``, ``verbose``, ``similarity_method`` and
        ``weight_power``.

    Exits (via ``parser.error``, status 2) when ``--window_size`` is not a
    positive integer, because a zero or negative window yields empty or
    meaningless window slices during change detection.
    """
    parser = argparse.ArgumentParser(
        description="Temporal segmentation analysis tool based on question type distribution changes",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Usage examples:
    python temporal_segmentation.py /path/to/question_data.jsonl
    python temporal_segmentation.py /path/to/question_data.jsonl --window_size 10 --threshold 2.0
    python temporal_segmentation.py /path/to/question_data.jsonl --window_unit days --window_size 30
    python temporal_segmentation.py /path/to/question_data.jsonl --similarity_method weighted --weight_power 2.0
    python temporal_segmentation.py /path/to/question_data.jsonl --similarity_method kl --threshold 1.5 --window_unit days

Similarity calculation method descriptions:
    - kl: Standard symmetric KL divergence, treats all categories equally
    - weighted: Weighted squared difference, gives higher weight to high-frequency categories
    - js_weighted: Weighted Jensen-Shannon divergence, well-balanced, gives higher weight to high-frequency categories (recommended)
        """
    )

    parser.add_argument(
        'input_file',
        help="Input question data file path (.jsonl format)"
    )

    parser.add_argument(
        '--window_size',
        type=int,
        default=DEFAULT_WINDOW_SIZE,
        help=f"Sliding window size, default: {DEFAULT_WINDOW_SIZE} (used with --window_unit)"
    )

    parser.add_argument(
        '--window_unit',
        choices=SUPPORTED_WINDOW_UNITS,
        default=DEFAULT_WINDOW_UNIT,
        help=f"Window unit, supported: {', '.join(SUPPORTED_WINDOW_UNITS)}, default: {DEFAULT_WINDOW_UNIT}"
    )

    parser.add_argument(
        '--threshold',
        type=float,
        default=DEFAULT_THRESHOLD,
        help=f"Distribution change detection threshold, default: {DEFAULT_THRESHOLD}"
    )

    parser.add_argument(
        '--verbose',
        action='store_true',
        help="Show detailed output information"
    )

    parser.add_argument(
        '--similarity_method',
        type=str,
        choices=SUPPORTED_SIMILARITY_METHODS,
        default='js_weighted',
        help=f"Distribution similarity calculation method. Supported methods: {', '.join(SUPPORTED_SIMILARITY_METHODS)}"
    )

    parser.add_argument(
        '--weight_power',
        type=float,
        default=1.5,
        help="Weight exponent for weighted methods, default 1.5. Higher values give more weight to high-frequency categories"
    )

    args = parser.parse_args()

    # A window must span at least one time unit; reject the value up front
    # instead of letting the detection loop run on empty/negative slices.
    if args.window_size < 1:
        parser.error("--window_size must be a positive integer")

    return args


def load_data(file_path):
    """Load records from a JSON Lines file.

    Each non-blank line is parsed as JSON. Lines that are not valid JSON, or
    whose value is not a JSON object, are skipped with a warning on stderr so
    that one bad line does not abort the run.

    Args:
        file_path: Path of the ``.jsonl`` file to read.

    Returns:
        List of dicts, one per successfully parsed line, in file order.

    Exits with status 1 (after printing the reason to stderr) when the file
    cannot be opened or decoded.
    """
    records = []
    try:
        # utf-8-sig reads plain UTF-8 unchanged and drops a leading BOM, which
        # would otherwise make the first record fail to parse.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            for line_num, line in enumerate(f, 1):
                text = line.strip()
                if not text:
                    continue
                try:
                    record = json.loads(text)
                except json.JSONDecodeError as exc:
                    print(f"Warning: skipping malformed JSON on line {line_num}: {exc}",
                          file=sys.stderr)
                    continue
                if not isinstance(record, dict):
                    # Callers index records by key, so non-object values are unusable.
                    print(f"Warning: skipping non-object JSON value on line {line_num}",
                          file=sys.stderr)
                    continue
                records.append(record)
    except (OSError, UnicodeDecodeError) as exc:
        print(f"Error: could not read {file_path}: {exc}", file=sys.stderr)
        sys.exit(1)

    return records


def parse_date(date_str):
    """Parse a date string in one of the supported layouts.

    Supported layouts, tried in order: ``YYYY/MM/DD``, ``YYYY-MM-DD``,
    ``YYYY.MM.DD``.

    Args:
        date_str: The value to parse.

    Returns:
        A ``datetime`` on success; ``None`` when the value is empty, is not a
        string, or matches none of the supported layouts.
    """
    # Non-string values (e.g. a number coming from JSON) would make strptime
    # raise TypeError; treat them as unparseable instead.
    if not date_str or not isinstance(date_str, str):
        return None

    for fmt in ('%Y/%m/%d', '%Y-%m-%d', '%Y.%m.%d'):
        try:
            return datetime.strptime(date_str, fmt)
        except ValueError:
            continue
    return None


def detect_distribution_changes(data, window_size=6, threshold=0.005,
                                similarity_method='js_weighted', weight_power=1.5,
                                window_unit='months'):
    """Find the points in time where the question-type distribution shifts.

    Records are bucketed by month (``window_unit == 'months'``) or by day
    (any other value). Starting from the first bucket, the distribution of the
    accumulated span is compared with the distribution of the next
    ``window_size`` buckets; when the divergence reported by
    ``calculate_similarity_by_method`` exceeds ``threshold`` a change point is
    recorded and accumulation restarts from there.

    Args:
        data: Iterable of dict records carrying ``'date'`` and
            ``'question_type'`` keys. Records whose date cannot be parsed are
            ignored.
        window_size: Number of time buckets in the look-ahead window.
        threshold: Divergence above which a change point is recorded.
        similarity_method: Method name forwarded to
            ``calculate_similarity_by_method``.
        weight_power: Weight exponent forwarded to
            ``calculate_similarity_by_method``.
        window_unit: ``'months'`` for monthly buckets; otherwise daily buckets.

    Returns:
        Tuple ``(change_points, grouped_data, time_keys)``:
        ``change_points`` is a list of datetimes that starts with the first
        bucket and ends with the last bucket (month buckets are anchored to the
        first day of the month); ``grouped_data`` maps bucket key to the list of
        records in it; ``time_keys`` is the sorted list of bucket keys. When no
        record has a parseable date, ``change_points`` and ``time_keys`` are
        empty.
    """
    monthly = window_unit == 'months'
    key_format = '%Y-%m' if monthly else '%Y-%m-%d'

    def key_to_datetime(time_key):
        # Month keys carry no day component; anchor them to the 1st.
        return datetime.strptime(time_key + '-01' if monthly else time_key, '%Y-%m-%d')

    def type_counts(keys):
        return Counter(
            item['question_type'] for key in keys for item in grouped_data[key]
        )

    # Parse every date once and drop unparseable records before sorting, so a
    # None date can never be compared against a datetime.
    dated_items = []
    for item in data:
        parsed = parse_date(item['date'])
        if parsed:
            dated_items.append((parsed, item))
    dated_items.sort(key=lambda pair: pair[0])  # stable: ties keep input order

    grouped_data = defaultdict(list)
    for parsed, item in dated_items:
        grouped_data[parsed.strftime(key_format)].append(item)
    time_keys = sorted(grouped_data.keys())

    # Nothing to segment: avoid indexing into an empty key list.
    if not time_keys:
        return [], grouped_data, time_keys

    first_point = key_to_datetime(time_keys[0])
    last_point = key_to_datetime(time_keys[-1])

    # Too little data for two full windows: treat everything as one span.
    if len(time_keys) < window_size * 2:
        return [first_point, last_point], grouped_data, time_keys

    change_points = [first_point]
    i = 0

    while i <= len(time_keys) - window_size:
        found_change = False
        for j in range(i + 1, len(time_keys) - window_size + 1):
            # The accumulated span runs from the last change point through the
            # end of the look-ahead window; the look-ahead window starts at j.
            current_dist = type_counts(time_keys[i:j + window_size])
            next_dist = type_counts(time_keys[j:j + window_size])

            if current_dist and next_dist:
                divergence, method_name = calculate_similarity_by_method(
                    current_dist, next_dist, similarity_method, weight_power
                )

                if divergence > threshold:
                    change_points.append(key_to_datetime(time_keys[j]))
                    i = j  # restart accumulation at the new change point
                    found_change = True
                    break

        if not found_change:
            # No further change ahead of the current start; stop scanning.
            break

    change_points.append(last_point)

    return change_points, grouped_data, time_keys


def create_time_segments(change_points, grouped_data, time_keys, window_unit='months'):
    """Turn consecutive change points into data segments.

    Each pair of adjacent change points becomes one segment that covers the
    time keys from the start point up to, but not including, the end point
    (the end point is where the following segment begins). The final pair is
    the exception: its end point closes the whole range, so that key is
    included; otherwise the records of the last time key would belong to no
    segment at all.

    Args:
        change_points: List of datetimes, in chronological order.
        grouped_data: Mapping of time key to the list of records for that key.
        time_keys: Sorted list of time keys (``YYYY-MM`` for months,
            ``YYYY-MM-DD`` otherwise).
        window_unit: ``'months'`` for month keys; any other value means day keys.

    Returns:
        List of segment dicts with the keys ``'start_date'``, ``'end_date'``,
        ``'time_keys'`` and ``'data'``, after passing through
        ``remove_overlaps``. Pairs whose keys are unknown, or that contain no
        records, produce no segment.
    """
    key_format = '%Y-%m' if window_unit == 'months' else '%Y-%m-%d'

    # First-occurrence position of every key, replacing repeated list.index scans.
    key_index = {}
    for idx, key in enumerate(time_keys):
        key_index.setdefault(key, idx)

    final_pair = len(change_points) - 2
    segments = []

    for pair_idx, (start_point, end_point) in enumerate(zip(change_points, change_points[1:])):
        start_idx = key_index.get(start_point.strftime(key_format))
        end_idx = key_index.get(end_point.strftime(key_format))
        if start_idx is None or end_idx is None:
            # A change point that does not map to a known key cannot be sliced.
            continue

        if pair_idx == final_pair:
            # The closing boundary is inclusive so the last time key is kept.
            end_idx += 1

        segment_time_keys = time_keys[start_idx:end_idx]
        if not segment_time_keys:
            continue

        segment_data = []
        for time_key in segment_time_keys:
            segment_data.extend(grouped_data[time_key])

        if segment_data:
            segments.append({
                'start_date': start_point,
                'end_date': end_point,
                'time_keys': segment_time_keys,
                'data': segment_data
            })

    segments = remove_overlaps(segments, time_keys, grouped_data)

    return segments


def remove_overlaps(segments, time_keys, grouped_data):
    """Trim segments so that no time key is shared with the following segment.

    For every segment except the last, any time keys it shares with its
    successor are cut off its tail: the segment is shortened to end just before
    the earliest shared key, and its ``'end_date'`` and ``'data'`` are rebuilt
    from the remaining keys. A segment that would be left with no keys is
    dropped.

    Args:
        segments: List of segment dicts (``'start_date'``, ``'end_date'``,
            ``'time_keys'``, ``'data'``).
        time_keys: Sorted list of all time keys.
        grouped_data: Mapping of time key to the list of records for that key.

    Returns:
        The input list itself when it holds at most one segment; otherwise a
        new list of shallow-copied (and possibly trimmed) segments.
    """
    if len(segments) <= 1:
        return segments

    result = []
    final_pos = len(segments) - 1

    for pos, source_segment in enumerate(segments):
        trimmed = source_segment.copy()

        if pos != final_pos:
            own_keys = trimmed['time_keys']
            shared_keys = set(own_keys).intersection(segments[pos + 1]['time_keys'])

            if shared_keys:
                cut_idx = time_keys.index(min(shared_keys))
                first_idx = time_keys.index(own_keys[0])

                # The overlap begins at (or before) this segment's first key,
                # so nothing would remain: drop the segment entirely.
                if cut_idx <= first_idx:
                    continue

                kept_keys = time_keys[first_idx:cut_idx]
                if kept_keys:
                    kept_data = [
                        record
                        for key in kept_keys
                        for record in grouped_data[key]
                    ]
                    last_key = kept_keys[-1]
                    # Three dash-separated parts means a day key; otherwise a
                    # month key that needs a day appended before parsing.
                    is_day_key = len(last_key.split('-')) == 3
                    date_text = last_key if is_day_key else last_key + '-01'

                    trimmed['time_keys'] = kept_keys
                    trimmed['end_date'] = datetime.strptime(date_text, '%Y-%m-%d')
                    trimmed['data'] = kept_data

        result.append(trimmed)

    return result


def analyze_segment(segment):
    """Summarise the question types and topics of one segment.

    Args:
        segment: Dict whose ``'data'`` entry is a list of records, each with
            a ``'question_type'`` and a ``'topic'`` entry. A topic value may
            hold several topics separated by ``|``.

    Returns:
        Dict with:
            ``'total_questions'``: number of records in the segment;
            ``'type_distribution'``: Counter of question types;
            ``'type_percentages'``: question type -> share of records, in percent;
            ``'topic_distribution'``: Counter of individual, whitespace-stripped
            topics (empty topic values and empty fragments are ignored).
    """
    records = segment['data']

    type_dist = Counter(record['question_type'] for record in records)
    topic_fields = [record['topic'] for record in records]
    total_questions = sum(type_dist.values())

    type_percentages = {}
    for qtype, count in type_dist.items():
        type_percentages[qtype] = (count / total_questions) * 100

    # Split every non-empty topic field on '|' and count each non-blank piece.
    topic_dist = Counter(
        cleaned
        for field in topic_fields
        if field
        for cleaned in map(str.strip, field.split('|'))
        if cleaned
    )

    return {
        'total_questions': total_questions,
        'type_distribution': type_dist,
        'type_percentages': type_percentages,
        'topic_distribution': topic_dist
    }


def main():
    """Run the segmentation pipeline from the command line.

    Steps: parse arguments, load the input file, keep only records with a
    parseable date, detect distribution change points, build segments, analyse
    each segment, and write everything to ``question_segments_results.json``
    next to the input file.

    Exits with status 1, after printing the reason to stderr, when the input
    file is missing, when no record has a parseable date, or when the results
    file cannot be written. Returns without writing anything when no segment
    could be built. With ``--verbose`` a per-segment summary and the output
    path are printed to stdout.
    """
    args = parse_args()

    try:
        validate_input_file(args.input_file)
    except FileNotFoundError as exc:
        print(f"Error: {exc}", file=sys.stderr)
        sys.exit(1)

    uses_weight = args.similarity_method in ('weighted', 'js_weighted')
    if args.verbose and uses_weight:
        print(f"Using {args.similarity_method} similarity method with weight power {args.weight_power}")

    data = load_data(args.input_file)

    # .get() tolerates records that have no 'date' key at all; they simply
    # count as invalid instead of raising KeyError.
    valid_data = [
        item for item in data
        if isinstance(item, dict) and parse_date(item.get('date'))
    ]
    invalid_count = len(data) - len(valid_data)

    if invalid_count > 0:
        print(f"Warning: {invalid_count} invalid data found")

    if not valid_data:
        print("Error: no records with a parseable date were found", file=sys.stderr)
        sys.exit(1)

    change_points, grouped_data, time_keys = detect_distribution_changes(
        valid_data,
        window_size=args.window_size,
        threshold=args.threshold,
        similarity_method=args.similarity_method,
        weight_power=args.weight_power,
        window_unit=args.window_unit
    )

    segments = create_time_segments(change_points, grouped_data, time_keys, args.window_unit)

    if not segments:
        print("Warning: no segments could be built; no results written", file=sys.stderr)
        return

    output_path = get_output_path(args.input_file)

    results = {
        'input_file': args.input_file,
        'analysis_parameters': {
            'window_size': args.window_size,
            'window_unit': args.window_unit,
            'threshold': args.threshold,
            'similarity_method': args.similarity_method,
            'weight_power': args.weight_power if uses_weight else None,
            'method_description': METHOD_DESCRIPTIONS.get(args.similarity_method, ''),
            'total_data_points': len(valid_data),
            'analysis_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        },
        'change_points': [cp.strftime('%Y-%m-%d') for cp in change_points],
        'segments': []
    }

    display_format = '%Y-%m' if args.window_unit == 'months' else '%Y-%m-%d'

    # Analyse each segment once and use the result for both the optional
    # console summary and the JSON payload.
    for segment_id, segment in enumerate(segments, 1):
        analysis = analyze_segment(segment)

        if args.verbose:
            start_str = segment['start_date'].strftime(display_format)
            end_str = segment['end_date'].strftime(display_format)
            unit_str = f"{len(segment['time_keys'])} {args.window_unit}"
            print(f"Segment {segment_id}: {start_str} to {end_str} "
                  f"({unit_str}, {analysis['total_questions']} questions)")
            for qtype, percentage in sorted(analysis['type_percentages'].items(),
                                            key=lambda x: x[1], reverse=True):
                count = analysis['type_distribution'][qtype]
                print(f"  {qtype}: {count} ({percentage:.1f}%)")

        results['segments'].append({
            'segment_id': segment_id,
            'start_date': segment['start_date'].strftime('%Y-%m-%d'),
            'end_date': segment['end_date'].strftime('%Y-%m-%d'),
            'time_keys': segment['time_keys'],
            'time_units_count': len(segment['time_keys']),
            'total_questions': analysis['total_questions'],
            'type_distribution': dict(analysis['type_distribution']),
            'type_percentages': analysis['type_percentages'],
            'top_topics': dict(analysis['topic_distribution'].most_common(20))
        })

    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
    except (OSError, TypeError, ValueError) as exc:
        print(f"Error: could not write results to {output_path}: {exc}", file=sys.stderr)
        sys.exit(1)

    if args.verbose:
        print(f"Results saved to {output_path}")


if __name__ == "__main__":
    main()
