import matplotlib.pyplot as plt
import numpy as np
import json
import re
import sys
from pathlib import Path
from dataclasses import dataclass
from itertools import product
from string import ascii_uppercase, ascii_lowercase
from matplotlib.ticker import FuncFormatter, MaxNLocator
from matplotlib.colors import PowerNorm

@dataclass
class TrainingResult:
    """Solve rate of one model variant at one training epoch.

    Instances are built by ``read_training_results`` from a variant directory's
    ``epoch_<N>_solve.json`` file; the architecture fields are the values parsed
    out of the variant name by ``parse_variant_name``.
    """
    variant: str  # full variant directory name, e.g. '10M_E192_L20_H1_IM4'
    base_params: str  # size-bucket prefix of the variant name, e.g. '10M'
    epoch: int  # epoch number taken from the solve-file name
    overall_solve_rate: float  # value of data["overall"]["solve_rate"] in that file
    embedding_dim: int  # E in the variant name
    layers: int  # L in the variant name
    heads: int  # H in the variant name
    inner_mult: int  # IM in the variant name (count_model_params uses n_inner = IM * E)

def count_model_params(E: int, L: int, H: int, IM: int, vocab_size: int = 120, n_positions: int = 128) -> int:
    """Estimate the parameter count of a GPT-2-style decoder.

    Counted terms:
      * token + position embeddings: ``(vocab_size + n_positions) * E``
      * per layer: ``4 * E * E`` attention projection weights (Q, K, V, O) plus
        ``2 * E * (IM * E)`` MLP weights (up + down projections)
      * ``E`` for a final norm
      * ``E * vocab_size`` for a separate output head

    ``H`` (number of heads) is accepted for signature symmetry but does not
    change the result.

    NOTE(review): no bias terms or per-block LayerNorm parameters are counted,
    and GPT-2 usually ties the output head to the token embedding; confirm this
    matches the formula in train_models.py.
    """
    d_ff = IM * E
    per_layer = 4 * E * E + 2 * E * d_ff
    embeddings = (vocab_size + n_positions) * E
    final_norm = E
    lm_head = vocab_size * E
    return embeddings + L * per_layer + final_norm + lm_head

def _allowed_variant_names() -> set[str]:
    """Return the variant names declared in ``train.train_models.VARIANTS``.

    The parent of this file's directory is put on ``sys.path`` so that the
    ``train`` package can be imported when this script is run directly. Only
    the first element of each ``VARIANTS`` entry is used as the name.
    """
    repo_root = str(Path(__file__).resolve().parents[1])
    # Add the path only once; repeated calls would otherwise keep growing sys.path.
    if repo_root not in sys.path:
        sys.path.append(repo_root)
    from train.train_models import VARIANTS  # type: ignore
    return {name for name, *_ in VARIANTS}

def _build_stoi(upper_grams: tuple[int, ...] = (3,)) -> dict[str, int]:
    specials = ["<pad>", "<unk>", "<eos>"]
    punct = ["-", "[", "]", "(", ")", ":"]
    up = ["".join(p) for k in upper_grams for p in product(ascii_uppercase, repeat=k)]
    low = list(ascii_lowercase)
    vocab = specials + punct + up + low
    return {t: i for i, t in enumerate(vocab)}

def _split_query_path(s: str) -> tuple[str, str]:
    i = s.index(":"); return s[:i], s[i+1:]

def _parse_qmeta(q: str) -> tuple[str, str, str, str]:
    b1, b2 = q.find("["), q.find("(")
    cut = len(q) if b1 < 0 else b1
    if b2 >= 0: cut = min(cut, b2)
    u, v = q[:cut].split("-", 1)
    inc = ""; exc = ""
    if b1 >= 0:
        e1 = q.index("]", b1); inc = q[b1+1:e1]
    if b2 >= 0:
        e2 = q.index(")", b2); exc = q[b2+1:e2]
    return u, v, inc, exc

def _tokenize_query(q: str) -> list[str]:
    """Tokenize a query string as parsed by ``_parse_qmeta``.

    Output: ``[u, "-", v]``, then ``"[" + include chars + "]"`` if there are
    include characters, then ``"(" + exclude chars + ")"`` if there are exclude
    characters, and finally ``":"``.
    """
    src, dst, include, exclude = _parse_qmeta(q)
    tokens = [src, "-", dst]
    for opener, chars, closer in (("[", include, "]"), ("(", exclude, ")")):
        if chars:
            tokens.extend([opener, *chars, closer])
    tokens.append(":")
    return tokens

def _tokenize_path(p: str, upper_grams: tuple[int, ...] = (3,)) -> list[str]:
    toks: list[str] = []; i, n = 0, len(p)
    while i < n:
        j = i
        while j < n and p[j] in ascii_uppercase: j += 1
        node = p[i:j]; i = j
        if len(node) == 0 or len(node) not in upper_grams: raise ValueError(f"bad node: {node}")
        toks.append(node)
        if i >= n: break
        l = p[i]
        if l not in ascii_lowercase: raise ValueError("expected lowercase label")
        toks.append(l); i += 1
    return toks

def _count_tokens_in_line(line: str, upper_grams: tuple[int, ...] = (3,)) -> int:
    """Return the number of tokens for one training line.

    The line is stripped and everything from the first ``" #d="`` onward is
    dropped. The remainder is split into query and path. The count is the query
    tokens plus the path tokens plus one trailing ``<eos>``.
    """
    payload = line.strip().split(" #d=")[0]
    query, path = _split_query_path(payload)
    n_query = len(_tokenize_query(query))
    n_path = len(_tokenize_path(path, upper_grams))
    return n_query + n_path + 1  # +1 for the trailing "<eos>" token

def _total_train_tokens(train_path: str = "data/paths_paper/train.txt") -> int:
    """Sum ``_count_tokens_in_line`` over every non-blank line of ``train_path`` (UTF-8)."""
    text = Path(train_path).read_text(encoding="utf-8")
    total = 0
    for raw in text.splitlines():
        if raw.strip():
            total += _count_tokens_in_line(raw)
    return total

def _flops_per_epoch(bucket: str, total_tokens: int) -> float:
    n_params = {"1M": 1_000_000, "10M": 10_000_000, "100M": 100_000_000}[bucket]
    return 6.0 * float(n_params) * float(total_tokens)

def _sci(x: float) -> str: return f"{x:.1e}"

def parse_variant_name(variant: str) -> tuple[str, int, int, int, int]:
    """Parse a variant name like ``'10M_E192_L20_H1_IM4'``.

    Returns ``(base_size, E, L, H, IM)`` with the four numbers as ints. The
    pattern is matched from the start of the string only, so trailing text is
    ignored.

    Raises:
        ValueError: if the name does not match the expected pattern.
    """
    m = re.match(r'(\w+)_E(\d+)_L(\d+)_H(\d+)_IM(\d+)', variant)
    if m is None:
        raise ValueError(f"Cannot parse variant name: {variant}")
    width, depth, heads, inner = (int(g) for g in m.groups()[1:])
    return m.group(1), width, depth, heads, inner

def read_training_results(results_dir: str = "out/paths_paper_final") -> list[TrainingResult]:
    """Collect per-epoch solve rates for every known variant under ``results_dir``.

    Each immediate subdirectory is scanned for ``epoch_<N>_solve.json`` files if
    its name is in ``_allowed_variant_names()`` and parses with
    ``parse_variant_name``. Every readable file contributes one
    ``TrainingResult`` built from ``data["overall"]["solve_rate"]``. Unreadable
    or malformed files are reported on stdout and skipped.

    Args:
        results_dir: Directory containing one subdirectory per model variant.

    Returns:
        Results sorted by ``(base_params, variant, epoch)``. ``base_params`` is
        compared as a string.

    Raises:
        FileNotFoundError: If ``results_dir`` does not exist.
    """
    results_path = Path(results_dir)
    if not results_path.exists():
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    epoch_re = re.compile(r'epoch_(\d+)_solve\.json')
    allowed = _allowed_variant_names()
    results: list[TrainingResult] = []

    for variant_dir in results_path.iterdir():
        # Plain files at the top level (e.g. vocab.json) are not variants.
        if not variant_dir.is_dir():
            continue
        variant_name = variant_dir.name
        if variant_name not in allowed:
            continue
        try:
            base_size, E, L, H, IM = parse_variant_name(variant_name)
        except ValueError:
            continue

        for epoch_file in variant_dir.glob("epoch_*_solve.json"):
            # The glob's '*' can match non-digits; keep only numeric epochs.
            epoch_match = epoch_re.search(epoch_file.name)
            if not epoch_match:
                continue
            epoch = int(epoch_match.group(1))

            try:
                with open(epoch_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # TypeError covers JSON that is not a dict of dicts or a null rate.
                solve_rate = float(data["overall"]["solve_rate"])
            except (OSError, json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
                # Best effort: report the bad file and keep reading the rest.
                print(f"Error reading {epoch_file}: {e}")
                continue

            results.append(TrainingResult(
                variant=variant_name,
                base_params=base_size,
                epoch=epoch,
                overall_solve_rate=solve_rate,
                embedding_dim=E,
                layers=L,
                heads=H,
                inner_mult=IM,
            ))

    return sorted(results, key=lambda x: (x.base_params, x.variant, x.epoch))

def reformat_training_data(results: list[TrainingResult]) -> list[dict]:
    """Average solve rates per ``(base_params, E/L ratio, epoch)`` group.

    The E/L ratio is ``embedding_dim / layers`` as a float, so variants of the
    same bucket with equal ratios are merged. Each output row is a dict with
    keys ``'base_params'``, ``'e/l_ratio'``, ``'avg_solve_rate'`` and ``'epoch'``.
    Rows are sorted by ``(base_params, e/l_ratio, epoch)``.
    """
    from collections import defaultdict

    buckets = defaultdict(list)
    for res in results:
        group_key = (res.base_params, res.embedding_dim / res.layers, res.epoch)
        buckets[group_key].append(res.overall_solve_rate)

    rows = [
        {
            'base_params': bp,
            'e/l_ratio': ratio,
            'avg_solve_rate': sum(rates) / len(rates),
            'epoch': ep,
        }
        for (bp, ratio, ep), rates in buckets.items()
    ]
    rows.sort(key=lambda row: (row['base_params'], row['e/l_ratio'], row['epoch']))
    return rows

def generate_sample_data():
    """Return three synthetic 100x100 surfaces for demoing the plot layout.

    Both axes span [0, 10] in 100 samples (meshgrid ``xy`` layout: rows vary
    with y, columns with x). The surfaces are a sin*cos wave, a Gaussian bump
    centred at (5, 5), and a mix of two products of sines and cosines.
    """
    axis = np.linspace(0, 10, 100)
    gx, gy = np.meshgrid(axis, axis)

    wave = np.sin(gx) * np.cos(gy)
    bump = np.exp(-((gx - 5) ** 2 + (gy - 5) ** 2) / 10)
    mixed = np.sin(gx / 2) * np.sin(gy / 2) + np.cos(gx) * np.cos(gy)
    return wave, bump, mixed

def create_inferno_plot(reformatted_data):
    """Draw banded inferno contour maps of solve rate for the 1M/10M/100M buckets.

    For each bucket, the ``(e/l_ratio, epoch, avg_solve_rate)`` entries of
    ``reformatted_data`` (dicts shaped like ``reformat_training_data`` output)
    are linearly interpolated onto a 50x50 grid. The E/L ratio is clipped to
    [0, 1000] and the epoch spans its observed range. Each grid is drawn with
    ``contourf`` in one of three side-by-side panels. The panels share contour
    levels, a ``PowerNorm`` and one colorbar.

    Y tick labels convert epochs to FLOPs (scaled by a fixed power of ten per
    panel) via ``_flops_per_epoch`` and the token count of
    ``data/paths_paper/train.txt``.

    A bucket is left as an empty panel if it has no entries, or if its points
    cannot be triangulated (e.g. only one distinct E/L ratio or epoch).

    The figure is saved to ``plots/creativity_plot.png`` and ``.pdf``, shown,
    then closed.
    """
    from scipy.interpolate import griddata

    labels = ('1M', '10M', '100M')

    # Group entries by model-size bucket; entries for other buckets are ignored.
    data_by_params = {label: [] for label in labels}
    for entry in reformatted_data:
        bucket = data_by_params.get(entry['base_params'])
        if bucket is not None:
            bucket.append(entry)

    # One row of three panels.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5), dpi=600)

    def prepare_grid_data(label, data_list):
        """Interpolate one bucket onto a regular grid.

        Returns ``(Xi, Yi, Zi, extent)`` with ``extent = [x_min, x_max, y_min,
        y_max]``, or four ``None`` values when the bucket cannot be gridded.
        """
        if not data_list:
            return None, None, None, None

        x = np.clip(np.array([d['e/l_ratio'] for d in data_list]), 0.0, 1000.0)
        y = np.array([d['epoch'] for d in data_list])
        z = np.array([d['avg_solve_rate'] for d in data_list])

        x_min, x_max = 0.0, 1000.0
        y_min, y_max = y.min(), y.max()
        Xi, Yi = np.meshgrid(np.linspace(x_min, x_max, 50), np.linspace(y_min, y_max, 50))

        # Linear interpolation avoids overshoot; grid points outside the convex
        # hull of the data become NaN.
        try:
            Zi = griddata((x, y), z, (Xi, Yi), method='linear', fill_value=np.nan)
        except RuntimeError as err:
            # scipy's QhullError (a RuntimeError) is raised for degenerate input:
            # fewer than 3 points, or all points on one line.
            print(f"Skipping {label} panel: cannot interpolate ({err})")
            return None, None, None, None

        # Solve rates live in [0, 1]; np.clip leaves NaNs untouched.
        Zi = np.clip(Zi, 0, 1)
        return Xi, Yi, Zi, [x_min, x_max, y_min, y_max]

    grids = {label: prepare_grid_data(label, data_by_params[label]) for label in labels}

    # Shared colour range across panels, from every non-NaN interpolated value.
    finite = [Z[~np.isnan(Z)] for _, _, Z, _ in grids.values() if Z is not None]
    values = np.concatenate(finite) if finite else np.empty(0)
    if values.size:
        vmin, vmax = float(values.min()), float(values.max())
    else:
        vmin, vmax = 0.0, 1.0
    # contourf needs strictly increasing levels and PowerNorm needs vmin < vmax,
    # so widen a constant surface slightly while staying inside [0, 1].
    if vmax <= vmin:
        vmin, vmax = max(0.0, vmin - 0.01), min(1.0, vmax + 0.01)

    # Banded appearance: 24 levels. A mild power-law norm (gamma > 1) spreads
    # the upper part of the range over more of the colormap.
    n_levels = 24
    levels = np.linspace(vmin, vmax, n_levels)
    norm = PowerNorm(gamma=1.5, vmin=vmin, vmax=vmax)

    contour_sets = {}
    for ax, label in zip(axes, labels):
        Xi, Yi, Zi, _ = grids[label]
        if Zi is not None:
            contour_sets[label] = ax.contourf(Xi, Yi, Zi, levels=levels, cmap='inferno',
                                              norm=norm, extend='neither')
        ax.set_title(label, fontsize=14)

    # Y axis: label epoch ticks with FLOPs, shown as scaled numbers plus a
    # fixed "x 10^k" note above each panel.
    total_tokens = _total_train_tokens("data/paths_paper/train.txt")
    fixed_exp = {'1M': 15, '10M': 16, '100M': 17}

    def make_fmt(mult: float, scale: float):
        return FuncFormatter(lambda e, _pos: f"{(max(0.0, e) * mult) / scale:.2f}")

    for ax, label in zip(axes, labels):
        ax.yaxis.set_major_locator(MaxNLocator(nbins=5, integer=True))
        if grids[label][3] is None:
            continue
        exponent = fixed_exp[label]
        ax.yaxis.set_major_formatter(make_fmt(_flops_per_epoch(label, total_tokens), 10.0 ** exponent))
        ax.text(0.0, 1.02, f"$\\times 10^{{{exponent}}}$", transform=ax.transAxes,
                ha='left', va='bottom', fontsize=10, clip_on=False)

    for ax in axes:
        ax.tick_params(axis='both', which='major', labelsize=14)
        ax.set_xlim(0, 700)

    # Leave room on the right for the colorbar and at left/bottom for the shared labels.
    fig.subplots_adjust(left=0.1, right=0.85, bottom=0.15)
    fig.text(0.475, 0.04, 'Width to Depth Tradeoff (E/L Ratio)', ha='center', fontsize=14)
    fig.text(0.04, 0.5, 'Compute (FLOPs)', va='center', rotation='vertical', fontsize=14)

    # One shared colorbar, taken from the first panel that has a contour plot.
    cbar_ax = fig.add_axes([0.87, 0.25, 0.015, 0.5])
    im_for_cbar = next(iter(contour_sets.values()), None)
    if im_for_cbar is not None:
        cbar = fig.colorbar(im_for_cbar, cax=cbar_ax)
        cbar.set_label('Creativity (Utility x Novelty)', rotation=270, labelpad=20, fontsize=14)
        cbar.ax.tick_params(labelsize=14)
        # Fewer colorbar ticks than contour levels.
        cbar.locator = MaxNLocator(nbins=7)
        cbar.update_ticks()

    plots_dir = Path('plots')
    plots_dir.mkdir(exist_ok=True)
    png_path = plots_dir / 'creativity_plot.png'
    pdf_path = plots_dir / 'creativity_plot.pdf'

    fig.savefig(png_path, dpi=600, bbox_inches='tight')
    fig.savefig(pdf_path, dpi=600, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    print(f"Plot saved to {png_path}")
    print(f"Plot saved to {pdf_path}")

if __name__ == "__main__":
    # Read results, average them per (bucket, E/L ratio, epoch), then plot.
    try:
        results = read_training_results()
        print(f"Read {len(results)} training results")

        if not results:
            # generate_sample_data() returns bare surfaces that create_inferno_plot
            # cannot consume, so there is nothing meaningful to plot.
            print("No training results found; nothing to plot.")
        else:
            print("\nSample raw results:")
            for r in results[:5]:
                print(f"  {r.variant}: epoch {r.epoch}, solve_rate={r.overall_solve_rate:.3f}")

            reformatted = reformat_training_data(results)
            print(f"\nReformatted to {len(reformatted)} entries")
            print("\nSample reformatted data (base_params, e/l_ratio, avg_solve_rate, epoch):")
            for data in reformatted[:10]:
                print(f"  {data['base_params']}, {data['e/l_ratio']:.3f}, {data['avg_solve_rate']:.3f}, {data['epoch']}")

            print("\nCreating inferno plot with real training data...")
            create_inferno_plot(reformatted)

    except Exception as e:
        # Top-level boundary: report with a traceback, then exit non-zero so
        # shell scripts and CI can detect the failure.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
