import os
import time
import numpy as np

import sys
sys.path.insert(0, '/raid/home/q615005/xfac/build/python')
import xfacpy
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

from stylesheet import *
from hamiltonian_xy_transverse import get_gap_oracle
from config_loader import cfg


# ----------------------------
# 1) TCI Resolution Benchmark
# ----------------------------
def _make_recording_oracle(gap_fn):
    """Wrap *gap_fn* so every evaluated point is recorded and evaluated once.

    Returns:
        A pair ``(oracle, sampled)``. ``oracle(arg)`` returns ``float(gap_fn(arg))``
        and ``sampled`` maps the rounded coordinate tuple of each evaluated
        point to that value.
    """
    sampled = {}

    def oracle(arg):
        point = np.asarray(arg, dtype=float)
        # Round so that floating-point noise in the grid coordinates does not
        # produce distinct keys for the same grid point.
        key = tuple(np.round(point, 14))
        if key not in sampled:
            # Only the gap is needed for the TCI construction. Repeated
            # requests for a point reuse the stored value instead of calling
            # the oracle again (assumes the oracle is deterministic).
            sampled[key] = float(gap_fn(point))
        return sampled[key]

    return oracle, sampled


def run_tci_resolution_benchmark(
    n_bit_values=(6, 9, 12),
    bond_dim=300,
    rel_tol=1e-10,
    base_res_dir="results/tci_resolution"
):
    """
    Run TCI at different resolutions and save results for Figure 4.

    For every resolution a quantics TCI of the gap oracle is built on a
    two-dimensional quantics grid with bounds a=0.0 and b=1.0, iterating
    until ``ci.isDone()`` reports completion. The results are written to
    ``{base_res_dir}_chain_{n_bit}bits``:

    - ``xy_chain.xfac`` and ``qtt.xfac``: the same quantics tensor train
    - ``sampled_points.npy``: pickled dict of evaluated points -> gap value
    - ``bond_dimensions.npy``: last-axis size of every core except the last
    - ``pivot_errors.npy``: ``ci.pivotError`` as an array
    - ``function_calls.npy``: pickled dict mapping ``len(ci.pivotError)``
      after each iteration to the number of distinct sampled points

    Args:
        n_bit_values: Iterable of quantics bit resolutions to test
        bond_dim: Bond dimension for TCI
        rel_tol: Relative tolerance for TCI convergence
        base_res_dir: Base directory (prefix) for results

    Returns:
        Dict keyed by bit resolution; each value is a dict with the keys
        'bond_dimensions', 'pivot_errors', 'function_calls', 'total_samples',
        'duration' (seconds spent in the iteration loop) and 'res_dir'.
    """
    print("=" * 70)
    print("STARTING TCI RESOLUTION BENCHMARK (Figure 4)")
    print("=" * 70)

    # Build Oracle
    gap_fn = get_gap_oracle()

    results = {}

    for n_bit in n_bit_values:
        print(f"\n[TCI Resolution] Running with {n_bit} bits...")

        # Setup results directory
        res_dir = f"{base_res_dir}_chain_{n_bit}bits"
        os.makedirs(res_dir, exist_ok=True)

        # Quantics Grid (dim=2 for J and h)
        qgrid = xfacpy.QuanticsGrid(a=0.0, b=1.0, dim=2, nBit=n_bit)

        # Fresh recorder per resolution so sample counts are not mixed
        oracle, sampled = _make_recording_oracle(gap_fn)

        # Initialize TCI
        args_tci = xfacpy.TensorCI2Param()
        args_tci.bondDim = bond_dim
        args_tci.reltol = rel_tol
        # One starting index per quantics site (n_bit sites per dimension)
        args_tci.pivot1 = [1] * (n_bit * 2)
        args_tci.ncheckhistory = 10

        ci = xfacpy.QTensorCI(f=oracle, qgrid=qgrid, args=args_tci)

        print(f"[TCI Resolution] Starting TCI for {n_bit}-bit")
        start_time = time.perf_counter()

        # len(ci.pivotError) after an iteration -> distinct points sampled so far
        function_calls_dict = {}

        while not ci.isDone():
            ci.iterate()
            function_calls_dict[len(ci.pivotError)] = len(sampled)

        duration = time.perf_counter() - start_time
        print(f"[TCI Resolution] {n_bit}-bit completed in {duration:.2f}s with {len(sampled)} samples.")

        # Save Results
        qtt = ci.get_qtt()
        # The tensor train is stored under both file names.
        qtt.save(os.path.join(res_dir, "xy_chain.xfac"))
        qtt.save(os.path.join(res_dir, "qtt.xfac"))

        bond_dimensions = [site.shape[-1] for site in qtt.tt.core[:-1]]
        pivot_errors = np.array(ci.pivotError)

        # The two dicts are stored as pickled object arrays.
        np.save(os.path.join(res_dir, "sampled_points.npy"), sampled, allow_pickle=True)
        np.save(os.path.join(res_dir, "bond_dimensions.npy"), bond_dimensions)
        np.save(os.path.join(res_dir, "pivot_errors.npy"), pivot_errors)
        np.save(os.path.join(res_dir, "function_calls.npy"), function_calls_dict, allow_pickle=True)

        results[n_bit] = {
            'bond_dimensions': bond_dimensions,
            'pivot_errors': pivot_errors,
            'function_calls': function_calls_dict,
            'total_samples': len(sampled),
            'duration': duration,
            'res_dir': res_dir
        }

        print(f"[TCI Resolution] {n_bit}-bit results saved to {res_dir}")

    return results


# ----------------------------
# 2) Figure 4: TCI Resolution and Bond Dimensions Analysis
# ----------------------------
def create_tci_resolution_figure(
    results_dict,
    n_bit_values=(6, 9, 12),
    base_res_dir="results/tci_resolution",
    output_path="results/convergence_and_bonds.pdf",
    eps_target=1e-10
):
    """
    Create Figure 4 showing TCI convergence and bond dimensions.

    This creates a side-by-side comparison of:
    - Left panel: TCI convergence (pivot errors normalized by their first
      entry) for the largest resolution in ``n_bit_values``
    - Right panel: Bond dimensions comparison across resolutions

    Args:
        results_dict: Dict keyed by bit resolution whose values provide
            'bond_dimensions' and 'pivot_errors'. Resolutions missing from
            it are loaded from ``{base_res_dir}_chain_{n_bit}bits``. ``None``
            is treated as an empty dict.
        n_bit_values: Bit resolutions to include in the figure
        base_res_dir: Base directory (prefix) of saved benchmark results
        output_path: File the figure is saved to; a missing parent directory
            is created
        eps_target: Target tolerance drawn as a horizontal line in the
            convergence panel

    Returns:
        The matplotlib figure, or None when no convergence data is available
        for the largest resolution.
    """
    print("[Figure 4] Creating TCI resolution and bond dimensions plot")

    if results_dict is None:
        results_dict = {}

    # The largest resolution supplies the convergence curve and the x-range
    # of the bond-dimension panel. None for an empty n_bit_values.
    finest = max(n_bit_values, default=None)

    # Load data from results or from saved files
    bd_data = {}
    pivot_errors = None

    for n_bit in n_bit_values:
        if n_bit in results_dict:
            bd_data[n_bit] = results_dict[n_bit]['bond_dimensions']
            if n_bit == finest:
                pivot_errors = results_dict[n_bit]['pivot_errors']
        else:
            # Fallback: try to load from files
            res_dir = f"{base_res_dir}_chain_{n_bit}bits"
            try:
                bd_data[n_bit] = np.load(os.path.join(res_dir, "bond_dimensions.npy"))
                if n_bit == finest:
                    pivot_errors = np.load(os.path.join(res_dir, "pivot_errors.npy"))
                print(f"[Figure 4] Loaded data for {n_bit}-bit from {res_dir}")
            except FileNotFoundError:
                print(f"[Figure 4] Warning: Could not load data for {n_bit}-bit resolution")
                continue

    if pivot_errors is None:
        print("[Figure 4] Error: No convergence data available")
        return None

    # Accept plain sequences as well as arrays for the division below
    pivot_errors = np.asarray(pivot_errors)

    # ==========================================
    # Create Figure 4 (Side-by-Side)
    # ==========================================
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12.8, 4.8))

    # -----------------------------------------------------------
    # PANEL LEFT: Convergence (Log10 scale)
    # -----------------------------------------------------------
    # With fewer than two entries there is no curve; the panel stays empty.
    if len(pivot_errors) > 1:
        relative_error = pivot_errors / pivot_errors[0]
        sweeps = np.arange(len(relative_error))

        ax1.plot(sweeps, relative_error, marker='o', markersize=4, linestyle='-',
                 color=color_palette['blue'], label='Relative error')

        ax1.axhline(y=eps_target, color='black', linestyle=':', alpha=0.8,
                    label='Target tolerance')

        ax1.set_yscale("log")
        ax1.set_xlabel("Number of half sweeps")
        ax1.set_ylabel("Relative in-sample error")
        ax1.set_title("(a) TCI Convergence")
        ax1.set_ylim(eps_target * 0.1, 2)
        ax1.grid(True, which="both", alpha=0.2)
        ax1.legend()

    # -----------------------------------------------------------
    # PANEL RIGHT: Bond Dimensions Comparison (Log2 scale)
    # -----------------------------------------------------------
    # Two quantics sites per bit (dim=2), sized for the largest resolution
    L_max = finest * 2
    l_ref = np.arange(1, L_max)

    # Base-2 reference lines
    ax2.plot(l_ref, 2.0**l_ref, color='gray', linewidth=line_width,
             label='Full Rank $2^\\ell$', linestyle='--')
    ax2.plot(l_ref, 2.0**(L_max - l_ref), color='gray',
             linewidth=line_width, linestyle='--')

    # Plot bond dimensions for each resolution
    colors = [color_palette['red'], color_palette['orange'], color_palette['blue']]
    for i, n_bit in enumerate(n_bit_values):
        if n_bit in bd_data:
            bd = bd_data[n_bit]
            R = n_bit * 2  # Total number of qubits
            # NOTE(review): bond dimensions are plotted in reversed order;
            # confirm this matches the intended bond indexing.
            ax2.plot(np.arange(1, len(bd) + 1), bd[::-1],
                     color=colors[i % len(colors)], ls='-', lw=line_width,
                     label=f"R = {R} ({n_bit}-bit)")

    # Configure log2 scale
    ax2.set_yscale('log', base=2)
    ax2.yaxis.set_major_formatter(mticker.ScalarFormatter())
    ax2.yaxis.set_major_locator(mticker.LogLocator(base=2.0))

    ax2.set_xlabel(r"Bond Index $\ell$")
    ax2.set_ylabel(r"Bond Dimension $\chi_\ell$")
    ax2.set_xlim(0, L_max)
    ax2.set_yticks(np.logspace(0, 9, num=10, base=2))
    ax2.set_ylim(1, 64)
    ax2.set_title("(b) QTT Bond Dimensions")
    ax2.legend(loc='upper right')
    ax2.grid(True, which="both", axis='both', alpha=0.2)

    # ==========================================
    # Finalize
    # ==========================================
    fig.tight_layout()

    # Ensure output directory exists. A bare file name has an empty dirname,
    # which os.makedirs would reject, so only create a real parent directory.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print(f"[Figure 4] Saving TCI resolution figure to {output_path}")
    fig.savefig(output_path)

    return fig


# ----------------------------
# 3) Main Function
# ----------------------------
def main():
    """
    Main function to run TCI resolution benchmark and create Figure 4.

    Reads 'N_BIT_VALUES', 'BOND_DIM', 'REL_TOL_BENCHMARK' and 'RES_DIR' from
    the loaded configuration, runs the benchmark for every resolution and
    then builds the figure from the in-memory results.

    Returns:
        Tuple ``(results, fig)``: the benchmark results dict and the figure,
        where ``fig`` is None if the figure could not be created.
    """
    print("=" * 70)
    print("STARTING TCI RESOLUTION BENCHMARK (Figure 4)")
    print("=" * 70)

    # Load configuration
    config = cfg()
    N_BIT_VALUES = config['N_BIT_VALUES']
    BOND_DIM = config['BOND_DIM']
    REL_TOL = config['REL_TOL_BENCHMARK']
    BASE_RES_DIR = config['RES_DIR']

    # Both steps must use the same prefix so the figure step can fall back
    # to the files written by the benchmark.
    resolution_dir = f"{BASE_RES_DIR}/tci_resolution"

    # Run TCI benchmark at different resolutions
    print("[Main] Running TCI resolution benchmark...")
    results = run_tci_resolution_benchmark(
        n_bit_values=N_BIT_VALUES,
        bond_dim=BOND_DIM,
        rel_tol=REL_TOL,
        base_res_dir=resolution_dir
    )

    # Create Figure 4
    # NOTE(review): output_path is left at its default, so the figure is not
    # written under RES_DIR unless RES_DIR is "results" - confirm intended.
    print("\n[Main] Creating Figure 4...")
    fig = create_tci_resolution_figure(
        results_dict=results,
        n_bit_values=N_BIT_VALUES,
        eps_target=REL_TOL,
        base_res_dir=resolution_dir
    )

    # The figure step returns None on failure; do not report that as success.
    if fig is None:
        print("[Main] TCI resolution benchmark completed, but Figure 4 could not be created")
    else:
        print("[Main] TCI resolution benchmark and Figure 4 completed successfully")
    return results, fig


# Run the benchmark and build Figure 4 only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
