import { NextResponse } from 'next/server';
import fs from 'fs/promises';
import path from 'path';
import { parse } from 'csv-parse/sync';

// --- Configuration ---

/** Names of the embedding shards; each is read from `public/embeddings/<name>.json`. */
const SHARD_NAMES = [
    'keywords',
    'questions',
    'thesis',
    'search_boost',
    'query_match_1',
    'query_match_2',
    'query_match_3',
];

/**
 * Per-shard weights used when a request does not supply its own.
 * The values are the result of an optimization run and sum to roughly 1.
 */
const DEFAULT_WEIGHTS: Record<string, number> = {
    keywords: 0.134207,
    questions: 0.226103,
    thesis: 0.094972,
    search_boost: 0.029563,
    query_match_1: 0.217395,
    query_match_2: 0.241111,
    query_match_3: 0.056650,
};

/** Model identifier handed to the feature-extraction pipeline. */
const MODEL_NAME = 'Romelianism/MedEmbed-small-v0.1';

// Shape of the `pipeline` factory as this module calls it: given a task name, a model
// id and optional loading options, it resolves to an embedder function. The embedder
// takes the input text plus optional pooling/normalization settings and resolves to an
// object whose `data` field holds the embedding values.
type PipelineFunction = (
    task: string,
    model: string,
    options?: { quantized?: boolean }
) => Promise<(text: string, options?: { pooling?: string; normalize?: boolean }) => Promise<{ data: Float32Array | number[] }>>;

// --- Module-scope cache for embeddings, model, and records ---
// Every slot starts as null and is populated on first use by the matching loader
// function in this file, so later requests served by the same module instance reuse it.

// Embedding shards keyed by shard name (populated by loadAllShards).
let SHARDS: Record<string, { id: string; vector: number[] }[]> | null = null;
// The `pipeline` factory imported from '@xenova/transformers' (populated by getPipeline).
let pipeline: PipelineFunction | null = null;
// CSV rows keyed by submission number (populated by loadRecords).
let RECORDS: Record<string, RecordData> | null = null;
// Lexical index over the records (populated by buildBM25Index).
let BM25_INDEX: BM25Index | null = null;

/** One entry of an embedding shard file: a document id and its embedding vector. */
interface ShardDoc {
    // Used as the lookup key into the records table, which is keyed by submission number.
    id: string;
    // Embedding compared against the query embedding with cosine similarity.
    vector: number[];
}

/**
 * One row of the records CSV. The CSV is parsed with header-derived keys, so each
 * property corresponds to a column; all are optional because a column may be absent
 * or empty for a given row.
 */
interface RecordData {
    device_model?: string;
    company?: string;
    // Key under which the row is stored in the records table.
    submission_number?: string;
    date_of_final_decision?: string;
    summary_pdf_link?: string;
    // Rows whose thesis is the literal string "Error" are discarded at load time.
    thesis?: string;
    summary_keywords?: string;
    concepts?: string;
    // Any additional CSV columns are retained as plain strings.
    [key: string]: string | undefined;
}

// --- BM25 Implementation ---
/**
 * Minimal in-memory BM25 index over a fixed set of documents.
 *
 * The index is built once in the constructor from an `id -> text` map and is
 * read-only afterwards. Term statistics are kept in `Map`s rather than plain
 * objects so that tokens which happen to match `Object.prototype` members
 * (for example "constructor") are counted like any other term.
 *
 * Scoring uses k1 = 1.2, b = 0.75 and the non-negative IDF
 * `ln(1 + (N - df + 0.5) / (df + 0.5))`, so a term that occurs in most of the
 * corpus contributes a small positive amount instead of a negative one.
 */
class BM25Index {
    private readonly documents: Record<string, string>;
    /** Number of documents that contain each term. */
    private readonly docFreq = new Map<string, number>();
    /** Term counts per document, keyed by document id. */
    private readonly termFreq = new Map<string, Map<string, number>>();
    /** Token count per document, keyed by document id. */
    private readonly docLengths = new Map<string, number>();
    private avgDocLength = 0;
    private totalDocs = 0;
    private readonly k1 = 1.2;
    private readonly b = 0.75;

    /**
     * @param documents Map from document id to the text to index.
     */
    constructor(documents: Record<string, string>) {
        this.documents = documents;
        this.buildIndex();
    }

    /** Lower-cases the text, replaces non-word characters with spaces and splits on whitespace. */
    private tokenize(text: string): string[] {
        return text
            .toLowerCase()
            .replace(/[^\w\s]/g, ' ')
            .split(/\s+/)
            .filter(token => token.length > 0);
    }

    /** Computes term frequencies, document frequencies and length statistics. */
    private buildIndex(): void {
        let totalLength = 0;

        for (const [docId, content] of Object.entries(this.documents)) {
            const tokens = this.tokenize(content);
            this.docLengths.set(docId, tokens.length);
            totalLength += tokens.length;

            const termCounts = new Map<string, number>();
            for (const token of tokens) {
                termCounts.set(token, (termCounts.get(token) ?? 0) + 1);
            }
            this.termFreq.set(docId, termCounts);

            // Each distinct term of this document bumps its document frequency once.
            for (const term of termCounts.keys()) {
                this.docFreq.set(term, (this.docFreq.get(term) ?? 0) + 1);
            }
        }

        this.totalDocs = this.termFreq.size;
        // An empty corpus would otherwise produce 0 / 0 = NaN.
        this.avgDocLength = this.totalDocs > 0 ? totalLength / this.totalDocs : 0;
    }

    /**
     * Scores every indexed document against the query.
     *
     * @param query Free-text query; tokenized the same way as the documents.
     * @param topK Maximum number of hits to return.
     * @returns Hits with a positive score, best first. Documents that share no
     *          term with the query are omitted.
     */
    public search(query: string, topK: number = 10): Array<{ id: string; score: number }> {
        const queryTerms = this.tokenize(query);
        const hits: Array<{ id: string; score: number }> = [];

        for (const [docId, termCounts] of this.termFreq) {
            const docLength = this.docLengths.get(docId) ?? 0;
            const lengthRatio = this.avgDocLength > 0 ? docLength / this.avgDocLength : 0;
            // Length-normalisation part of the BM25 denominator; constant per document.
            const lengthNorm = this.k1 * (1 - this.b + this.b * lengthRatio);

            let score = 0;
            for (const term of queryTerms) {
                const tf = termCounts.get(term) ?? 0;
                if (tf === 0) {
                    continue;
                }
                // tf > 0 implies the term was counted at build time, so df >= 1.
                const df = this.docFreq.get(term) ?? 0;
                const idf = Math.log(1 + (this.totalDocs - df + 0.5) / (df + 0.5));
                score += idf * ((tf * (this.k1 + 1)) / (tf + lengthNorm));
            }

            if (score > 0) {
                hits.push({ id: docId, score });
            }
        }

        return hits.sort((a, b) => b.score - a.score).slice(0, topK);
    }
}

// --- Helper Functions ---

/**
 * Returns the Transformers.js `pipeline` factory, importing the library on first use.
 * The factory is stored in the module-level `pipeline` slot so the dynamic import
 * only happens once per module instance.
 */
async function getPipeline(): Promise<PipelineFunction> {
    if (pipeline !== null) {
        return pipeline;
    }
    const transformers = await import('@xenova/transformers');
    pipeline = transformers.pipeline as PipelineFunction;
    return pipeline;
}

/**
 * Returns the records table, reading and parsing `api/fda_ai_records.csv`
 * (relative to the process working directory) on the first call.
 *
 * Rows whose `thesis` is the literal "Error" and rows without a
 * `submission_number` are skipped. The remaining rows are keyed by submission
 * number; a later row with the same number replaces an earlier one.
 */
async function loadRecords(): Promise<Record<string, RecordData>> {
    if (RECORDS !== null) {
        return RECORDS;
    }

    console.log("Loading records CSV for the first time...");
    const csvPath = path.join(process.cwd(), 'api', 'fda_ai_records.csv');
    const csvText = await fs.readFile(csvPath, 'utf8');

    const rows = parse(csvText, {
        columns: true,
        skip_empty_lines: true,
        trim: true,
    }) as RecordData[];

    const bySubmission: Record<string, RecordData> = {};
    for (const row of rows) {
        if (row.thesis === "Error") {
            continue;
        }
        const key = row.submission_number;
        if (!key) {
            continue;
        }
        bySubmission[key] = row;
    }

    RECORDS = bySubmission;
    console.log(`Records loaded and cached. Found ${Object.keys(RECORDS).length} records.`);
    return RECORDS;
}

/**
 * Returns the BM25 index, building it from `records` on the first call.
 *
 * Each record contributes one document made of its non-empty
 * `summary_keywords`, `thesis` and `concepts` fields joined by spaces; records
 * with no text in any of those fields are left out. Once built, the cached
 * index is returned regardless of the `records` argument.
 */
async function buildBM25Index(records: Record<string, RecordData>): Promise<BM25Index> {
    if (BM25_INDEX !== null) {
        return BM25_INDEX;
    }

    console.log("Building BM25 index for the first time...");

    // Fields that make up the searchable text (mirrors the Python implementation).
    const fieldsToConcat = ['summary_keywords', 'thesis', 'concepts'];
    const corpus: Record<string, string> = {};

    for (const [submissionNumber, record] of Object.entries(records)) {
        const parts: string[] = [];
        for (const field of fieldsToConcat) {
            const value = record[field] || '';
            if (value.length > 0) {
                parts.push(value);
            }
        }

        const docText = parts.join(' ');
        if (docText.trim()) {
            corpus[submissionNumber] = docText;
        }
    }

    BM25_INDEX = new BM25Index(corpus);
    console.log(`BM25 index built with ${Object.keys(corpus).length} documents.`);
    return BM25_INDEX;
}

/** Callable produced by the feature-extraction pipeline factory, as `embedText` uses it. */
type CachedEmbedder = (
    text: string,
    options?: { pooling?: string; normalize?: boolean }
) => Promise<{ data: Float32Array | number[] }>;

/**
 * In-flight or completed creation of the embedder. Sharing the promise means
 * concurrent first requests wait on the same load instead of each starting one.
 */
let embedderPromise: Promise<CachedEmbedder> | null = null;

/**
 * Generates a mean-pooled, normalized embedding for `text` with `MODEL_NAME`.
 *
 * The embedder is created once (quantized model) and reused for every later
 * call; invoking the pipeline factory per query would repeat the model
 * construction on each request.
 *
 * @param text Text to embed.
 * @returns The embedding as a plain number array.
 * @throws Whatever the pipeline factory or the embedder rejects with.
 */
async function embedText(text: string): Promise<number[]> {
    let pending = embedderPromise;
    if (pending === null) {
        const creation = (async (): Promise<CachedEmbedder> => {
            const p = await getPipeline();
            return p('feature-extraction', MODEL_NAME, {
                quantized: true,
            });
        })();
        embedderPromise = creation;
        // Do not keep a rejected promise around: clear the slot so the next call retries.
        creation.catch(() => {
            if (embedderPromise === creation) {
                embedderPromise = null;
            }
        });
        pending = creation;
    }

    const embedder = await pending;
    const result = await embedder(text, { pooling: 'mean', normalize: true });
    return Array.from(result.data);
}

/**
 * Returns all embedding shards, reading `public/embeddings/<name>.json` for each
 * configured shard name on the first call.
 *
 * Shards are read concurrently. A shard that cannot be read or parsed is logged
 * and left out of the result, so the returned (and cached) map may contain
 * fewer entries than there are shard names.
 */
async function loadAllShards(): Promise<Record<string, ShardDoc[]>> {
    if (SHARDS !== null) {
        return SHARDS;
    }

    console.log("Loading embedding shards for the first time...");
    const byName: Record<string, ShardDoc[]> = {};

    // Reads one shard into `byName`; failures are reported but never rethrown.
    const readShard = async (name: string): Promise<void> => {
        const filePath = path.join(process.cwd(), 'public', 'embeddings', `${name}.json`);
        try {
            const raw = await fs.readFile(filePath, 'utf8');
            byName[name] = JSON.parse(raw) as ShardDoc[];
            console.log(`Successfully loaded and parsed shard: ${name}`);
        } catch (error) {
            console.error(`Failed to load shard ${name} from ${filePath}:`, error);
        }
    };

    const pendingReads: Promise<void>[] = [];
    for (const name of SHARD_NAMES) {
        pendingReads.push(readShard(name));
    }
    await Promise.all(pendingReads);

    SHARDS = byName;
    console.log("All shards loaded and cached.");
    return SHARDS;
}

/**
 * Calculates the cosine similarity between two vectors.
 *
 * @param vecA First vector.
 * @param vecB Second vector; must have the same length as `vecA`.
 * @returns A value in [-1, 1]. Returns 0 when the vectors differ in length,
 *          when either has zero magnitude, or when the result is not a finite
 *          number, so callers never receive `NaN` (which would make any sort
 *          on the scores unreliable).
 */
function cosineSim(vecA: number[], vecB: number[]): number {
    // Vectors of different dimensionality are not comparable; reading past the
    // end of the shorter one would turn the whole sum into NaN.
    if (vecA.length !== vecB.length) {
        return 0;
    }

    let dotProduct = 0;
    let normA = 0;
    let normB = 0;
    for (let i = 0; i < vecA.length; i++) {
        dotProduct += vecA[i] * vecB[i];
        normA += vecA[i] * vecA[i];
        normB += vecB[i] * vecB[i];
    }
    if (normA === 0 || normB === 0) {
        return 0;
    }

    const similarity = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    return Number.isFinite(similarity) ? similarity : 0;
}

/**
 * Min-max normalizes scores to the 0-1 range.
 *
 * @param scores Map from id to raw score.
 * @returns A new map with the same keys where the smallest score maps to 0 and
 *          the largest to 1. When all scores are equal every key maps to 1.
 *          An empty input is returned as-is (same object).
 *
 * The extremes are found with a single pass instead of `Math.max(...values)`,
 * because spreading a very large array into a call can exceed the engine's
 * argument limit and throw a RangeError.
 */
function normalizeScores(scores: Record<string, number>): Record<string, number> {
    const entries = Object.entries(scores);
    if (entries.length === 0) return scores;

    let minScore = Infinity;
    let maxScore = -Infinity;
    for (const [, value] of entries) {
        // Math.min/Math.max (rather than < / >) keep NaN propagation identical
        // to the variadic form.
        minScore = Math.min(minScore, value);
        maxScore = Math.max(maxScore, value);
    }
    const range = maxScore - minScore;

    const normalized: Record<string, number> = {};
    for (const [key, value] of entries) {
        normalized[key] = range === 0 ? 1.0 : (value - minScore) / range;
    }
    return normalized;
}

// --- Search Functions ---

/**
 * Ranks documents by a weighted sum of per-shard cosine similarities to the query.
 *
 * The query is embedded once, compared against every vector of every available
 * shard, and each document's per-shard similarities are combined using
 * `weights` (a shard without a weight contributes nothing). The best `topK`
 * documents are logged and returned joined with their record fields; fields
 * missing from the record are reported as "N/A" (or `null` for the PDF link).
 */
async function performEmbeddingSearch(query: string, topK: number, weights: Record<string, number>, allShards: Record<string, ShardDoc[]>, allRecords: Record<string, RecordData>) {
    const queryVector = await embedText(query);

    // perDocShardScores[docId][shardName] = similarity of that shard's vector to the query.
    const perDocShardScores: Record<string, { [key: string]: number }> = {};
    for (const shardName of SHARD_NAMES) {
        const entries = allShards[shardName];
        if (!entries) {
            // Shard was not loaded; skip it.
            continue;
        }
        for (const entry of entries) {
            const similarity = cosineSim(queryVector, entry.vector);
            const bucket = perDocShardScores[entry.id] || (perDocShardScores[entry.id] = {});
            bucket[shardName] = similarity;
        }
    }

    // Collapse each document's shard scores into one weighted score.
    const ranked: Array<{ id: string; score: number }> = [];
    for (const [id, shardScores] of Object.entries(perDocShardScores)) {
        let weighted = 0;
        for (const [shardName, similarity] of Object.entries(shardScores)) {
            weighted += (similarity || 0) * (weights[shardName] || 0);
        }
        ranked.push({ id, score: weighted });
    }
    ranked.sort((left, right) => right.score - left.score);
    const topMatches = ranked.slice(0, topK);

    console.log(`\n🔎 Top ${topK} matches for query: "${query}"`);
    for (const [rank, match] of topMatches.entries()) {
        const shardScores = perDocShardScores[match.id];
        const breakdown = SHARD_NAMES.map(name => `${name}: ${(shardScores[name] || 0).toFixed(4)}`).join(', ');
        console.log(`#${rank + 1}: ID=${match.id}, Hybrid Score=${match.score.toFixed(4)} | [${breakdown}]`);
    }

    return topMatches.map(match => {
        const record = allRecords[match.id] || {};
        return {
            deviceName: record.device_model || "N/A",
            applicant: record.company || "N/A",
            submissionNumber: match.id,
            decisionDate: record.date_of_final_decision || "N/A",
            similarity: match.score,
            pdfLink: record.summary_pdf_link || null,
            thesis: record.thesis || "N/A",
            keywords: record.summary_keywords || "N/A",
            concepts: record.concepts || "N/A",
        };
    });
}

/**
 * Hybrid search: blends the embedding ranking with a BM25 ranking.
 *
 * Both rankers are asked for `max(topK * 3, 100)` candidates. Each score set is
 * min-max normalized on its own, then combined per document as
 * `lambdaVal * embedding + (1 - lambdaVal) * bm25`; a document missing from one
 * ranking gets 0 for that part. The best `topK` combined results are logged
 * and returned together with the two normalized component scores.
 */
async function performEmbeddingBM25Search(
    query: string,
    topK: number,
    weights: Record<string, number>,
    allShards: Record<string, ShardDoc[]>,
    allRecords: Record<string, RecordData>,
    lambdaVal: number = 0.5 // Weight between embedding (lambda) and BM25 (1-lambda)
) {
    const embeddingWeight = lambdaVal;
    const bm25Weight = 1 - lambdaVal;

    // Over-fetch from both rankers so their candidate sets overlap well.
    const candidateCount = Math.max(topK * 3, 100);

    // Semantic side.
    const semanticHits = await performEmbeddingSearch(query, candidateCount, weights, allShards, allRecords);
    const rawEmbedding: Record<string, number> = {};
    for (const hit of semanticHits) {
        rawEmbedding[hit.submissionNumber] = hit.similarity;
    }

    // Lexical side.
    const lexicalIndex = await buildBM25Index(allRecords);
    const lexicalHits = lexicalIndex.search(query, candidateCount);
    const rawBM25: Record<string, number> = {};
    for (const hit of lexicalHits) {
        rawBM25[hit.id] = hit.score;
    }

    // Put both score sets on the same 0-1 scale before mixing them.
    const normalizedEmbedding = normalizeScores(rawEmbedding);
    const normalizedBM25 = normalizeScores(rawBM25);

    const candidateIds = new Set(Object.keys(normalizedEmbedding).concat(Object.keys(normalizedBM25)));
    const combinedScores: Record<string, number> = {};
    for (const id of candidateIds) {
        const embPart = normalizedEmbedding[id] || 0.0;
        const bm25Part = normalizedBM25[id] || 0.0;
        combinedScores[id] = (embeddingWeight * embPart) + (bm25Weight * bm25Part);
    }

    const best = Object.entries(combinedScores);
    best.sort((left, right) => right[1] - left[1]);
    const sortedResults = best.slice(0, topK);

    console.log(`\n🔎 Hybrid Embedding+BM25 search (λ=${lambdaVal}) for query: "${query}"`);
    for (const [rank, [id, score]] of sortedResults.entries()) {
        const embScore = normalizedEmbedding[id] || 0;
        const bm25Score = normalizedBM25[id] || 0;
        console.log(`#${rank + 1}: ID=${id}, Combined=${score.toFixed(4)} | EMB=${embScore.toFixed(4)}, BM25=${bm25Score.toFixed(4)}`);
    }

    return sortedResults.map(([submissionNumber, combinedScore]) => {
        const record = allRecords[submissionNumber] || {};
        return {
            deviceName: record.device_model || "N/A",
            applicant: record.company || "N/A",
            submissionNumber,
            decisionDate: record.date_of_final_decision || "N/A",
            similarity: combinedScore,
            pdfLink: record.summary_pdf_link || null,
            thesis: record.thesis || "N/A",
            keywords: record.summary_keywords || "N/A",
            concepts: record.concepts || "N/A",
            // Debug info
            embeddingScore: normalizedEmbedding[submissionNumber] || 0,
            bm25Score: normalizedBM25[submissionNumber] || 0,
        };
    });
}

/**
 * Case-insensitive substring search over the records with pagination.
 *
 * A record matches when every whitespace-separated query term occurs somewhere
 * in its device model, company, thesis, keywords, concepts or submission
 * number. An empty query matches every record.
 *
 * @param query Search text; anything that is not a non-empty string is treated as "no filter".
 * @param page 1-based page number. Values that are not a finite number >= 1 fall back to 1.
 * @param limit Page size. Values that are not a finite number >= 1 fall back to 20.
 *              Fractional values of `page` and `limit` are rounded down.
 * @param allRecords Records keyed by submission number.
 * @returns The requested page of results plus paging metadata; `currentPage`
 *          is the page actually used after sanitizing.
 */
function performKeywordSearch(query: string, page: number, limit: number, allRecords: Record<string, RecordData>) {
    // Guard against 0 / negative / NaN paging values: they would otherwise give a
    // negative slice start (empty page) or an Infinity/NaN page count.
    const safeLimit = Number.isFinite(limit) && limit >= 1 ? Math.floor(limit) : 20;
    const safePage = Number.isFinite(page) && page >= 1 ? Math.floor(page) : 1;

    const allRecordsArray = Object.values(allRecords);
    let filteredRecords = allRecordsArray;

    const queryTerms = typeof query === 'string'
        ? query.toLowerCase().split(/\s+/).filter(term => term)
        : [];

    if (query && typeof query === 'string') {
        filteredRecords = allRecordsArray.filter(record => {
            const searchableText = [
                record.device_model,
                record.company,
                record.thesis,
                record.summary_keywords,
                record.concepts,
                record.submission_number
            ].join(' ').toLowerCase();
            return queryTerms.every(term => searchableText.includes(term));
        });
    }

    const totalRecords = filteredRecords.length;
    const totalPages = Math.ceil(totalRecords / safeLimit);
    const startIndex = (safePage - 1) * safeLimit;
    const paginatedRecords = filteredRecords.slice(startIndex, startIndex + safeLimit);

    const results = paginatedRecords.map(record => ({
        deviceName: record.device_model || "N/A",
        applicant: record.company || "N/A",
        submissionNumber: record.submission_number || "N/A",
        decisionDate: record.date_of_final_decision || "N/A",
        pdfLink: record.summary_pdf_link || null,
        thesis: record.thesis || "N/A",
        keywords: record.summary_keywords || "N/A",
        concepts: record.concepts || "N/A",
    }));

    return {
        results,
        currentPage: safePage,
        totalPages,
        totalRecords,
        message: totalRecords === 0 ? "No results found for your search." : ""
    };
}

// --- API Route Handler ---

/** Coerces an untrusted JSON value to an integer >= 1; anything else yields `fallback`. */
function toPositiveInt(value: unknown, fallback: number): number {
    if (typeof value !== 'number' || !Number.isFinite(value)) {
        return fallback;
    }
    const floored = Math.floor(value);
    return floored >= 1 ? floored : fallback;
}

/**
 * Keeps the finite numeric entries of a client-supplied weights object.
 * Anything that is not a plain JSON object falls back to `DEFAULT_WEIGHTS`.
 */
function toShardWeights(value: unknown): Record<string, number> {
    if (typeof value !== 'object' || value === null || Array.isArray(value)) {
        return DEFAULT_WEIGHTS;
    }
    const cleaned: Record<string, number> = {};
    for (const [name, weight] of Object.entries(value)) {
        if (typeof weight === 'number' && Number.isFinite(weight)) {
            cleaned[name] = weight;
        }
    }
    return cleaned;
}

/**
 * POST handler for the search endpoint.
 *
 * Expects a JSON object body with the optional fields `query`, `mode`
 * ('embedding' (default), 'hybrid' or 'keyword'), `topK`, `weights`,
 * `lambdaVal`, `page` and `limit`. The body is untrusted, so every field is
 * validated or coerced before use:
 *  - a body that is not valid JSON / not an object, a non-string `query`, an
 *    unknown `mode`, or a missing query for embedding/hybrid mode → 400;
 *  - `topK`, `page`, `limit` become integers >= 1 (defaults 300 / 1 / 20);
 *  - `lambdaVal` is clamped to [0, 1] (default 0.5);
 *  - `weights` keeps only finite numeric entries (default `DEFAULT_WEIGHTS`).
 *
 * Data is loaded only after validation, and keyword mode loads just the
 * records because it does not use the embedding shards.
 *
 * @returns The search results as JSON, or `{ error, details? }` with status 400/500.
 */
export async function POST(req: Request) {
    try {
        let body: unknown;
        try {
            body = await req.json();
        } catch {
            return NextResponse.json({ error: 'Request body must be valid JSON.' }, { status: 400 });
        }
        if (typeof body !== 'object' || body === null || Array.isArray(body)) {
            return NextResponse.json({ error: 'Request body must be a JSON object.' }, { status: 400 });
        }
        const params = body as Record<string, unknown>;

        const rawQuery = params.query ?? '';
        if (typeof rawQuery !== 'string') {
            return NextResponse.json({ error: 'Query must be a string.' }, { status: 400 });
        }
        const query = rawQuery;

        // 'embedding', 'keyword', 'hybrid'
        const mode = params.mode === undefined ? 'embedding' : params.mode;

        if (mode === 'keyword') {
            const page = toPositiveInt(params.page, 1);
            const limit = toPositiveInt(params.limit, 20);
            const allRecords = await loadRecords();
            const results = performKeywordSearch(query, page, limit, allRecords);
            return NextResponse.json(results);
        }

        if (mode !== 'embedding' && mode !== 'hybrid') {
            return NextResponse.json({ error: `Invalid search mode: ${String(mode)}` }, { status: 400 });
        }
        if (!query) {
            return NextResponse.json({ error: `Query is required for ${mode} search.` }, { status: 400 });
        }

        const topK = toPositiveInt(params.topK, 300);
        const weights = params.weights === undefined ? DEFAULT_WEIGHTS : toShardWeights(params.weights);
        const [allShards, allRecords] = await Promise.all([loadAllShards(), loadRecords()]);

        if (mode === 'embedding') {
            const results = await performEmbeddingSearch(query, topK, weights, allShards, allRecords);
            return NextResponse.json(results);
        }

        // Hybrid: lambdaVal is the embedding share, so it must stay within [0, 1].
        const rawLambda = params.lambdaVal;
        const lambdaVal = typeof rawLambda === 'number' && Number.isFinite(rawLambda)
            ? Math.min(1, Math.max(0, rawLambda))
            : 0.5;
        const results = await performEmbeddingBM25Search(query, topK, weights, allShards, allRecords, lambdaVal);
        return NextResponse.json(results);

    } catch (error) {
        console.error('Error in /api/search:', error);
        const errorMessage = error instanceof Error ? error.message : 'An unknown error occurred';
        return NextResponse.json({ error: 'An internal server error occurred.', details: errorMessage }, { status: 500 });
    }
}