import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import imageio
import torch
import os
import pandas as pd
import re
import random
from models.GANDataset import TensorDataset

import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision.utils import make_grid
from models.Generator import Generator
from tqdm import tqdm

import plot

# Global figure styling, applied once when this module is imported.
# Render figure text with LaTeX.
mpl.rcParams['text.usetex'] = True
# Request a serif face and list Times New Roman as the serif candidate.
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = ['Times New Roman']
# NOTE(review): this theme call runs after the rcParams assignments above;
# seaborn's theme may write font-related rcParams itself and so override
# them -- confirm the intended order.
sns.set(style="darkgrid", context="paper", font_scale=1.2)

def extract_data_from_file(filename):
    """Parse a text log of residuals/errors into a long-format DataFrame.

    The file is scanned line by line for blocks of the form::

        History length: <int>
        Resid: <float> <float>
        Err: <float>

    One record is emitted per block when its ``Err`` line is reached. The
    ``Err`` value is read as a base-10 exponent and stored as ``10 ** value``.
    Lines outside a block (before a ``History length`` line or after an
    ``Err`` line) are ignored. Exactly 100 records are required for each
    history length 1..20; records with other history lengths are dropped.

    Args:
        filename: Path of the log file to read.

    Returns:
        A DataFrame with 2000 rows and the columns ``Key`` (history length,
        1..20), ``Sample`` (1..100, in file order within a key), ``L1``
        (``10 ** Err``) and ``L_infty`` (the first ``Resid`` value).

    Raises:
        ValueError: If an ``Err`` line is reached before the ``Resid`` line
            of its block, or if a history length does not have exactly 100
            records (this includes an empty file).
    """
    n_keys, n_samples = 20, 100

    # Compile once; the patterns are applied to every line of the log.
    hlen_re = re.compile(r'\s*History length:\s+(\d+)')
    resid_re = re.compile(r'\s*Resid:\s+([-\d.eE+]+)\s+([-\d.eE+]+)')
    err_re = re.compile(r'\s*Err:\s+([-\d.eE+]+)')

    records = []
    current_hlen = None
    resid = None  # (r1, r2) of the block currently being read, once seen

    with open(filename, 'r') as f:
        for line in f:
            m_h = hlen_re.match(line)
            if m_h:
                current_hlen = int(m_h.group(1))
                # A new block must bring its own residuals; never reuse the
                # values of the previous block.
                resid = None
                continue

            if current_hlen is None:
                # Not inside a block: nothing to attach this line to.
                continue

            m_r = resid_re.match(line)
            if m_r:
                resid = (float(m_r.group(1)), float(m_r.group(2)))
                continue

            m_e = err_re.match(line)
            if m_e:
                if resid is None:
                    raise ValueError(
                        f"'Err' line without a preceding 'Resid' line for "
                        f"history length {current_hlen} in {filename}"
                    )
                e = np.power(10, float(m_e.group(1)))
                # Now we have a complete record:
                records.append((current_hlen, resid[0], resid[1], e))
                current_hlen = None  # reset for next block
                resid = None

    # reshape keeps the array 2-D even when no record was found, so the
    # per-key check below reports the problem instead of an IndexError.
    table = np.array(records, dtype=float).reshape(-1, 4)

    # Step 1: Subtract 1 from keys to make them zero-indexed
    keys = table[:, 0].astype(int) - 1
    data = table[:, 1:]

    # Step 2: Fill the output array by grouping based on keys
    output = np.empty((n_keys, n_samples, 3))
    for key in range(n_keys):
        rows = data[keys == key]
        if rows.shape[0] != n_samples:
            raise ValueError(
                f"Key {key + 1} does not have exactly {n_samples} rows "
                f"(found {rows.shape[0]})"
            )
        output[key] = rows

    return pd.DataFrame({
        'Key': np.repeat(np.arange(1, n_keys + 1), n_samples),
        'Sample': np.tile(np.arange(1, n_samples + 1), n_keys),
        'L1': output[:, :, 2].flatten(),
        'L_infty': output[:, :, 0].flatten(),
    })

def plot_curves(name1, name2, y_col):
    """Plot one error column against history size for two result logs.

    Both files are parsed with ``extract_data_from_file``. Rows from
    ``name1`` are labelled ``Heat`` and rows from ``name2`` are labelled
    ``Wave``; the two are drawn as separate hues on a log-scaled y axis. A
    dashed red vertical line with a caption is drawn at x = 16, and the
    figure is shown with ``plt.show()``.

    Args:
        name1: Path of the log plotted as the ``Heat`` dataset.
        name2: Path of the log plotted as the ``Wave`` dataset.
        y_col: Column to plot. ``'L1'`` selects the L1 axis label; any other
            value gets the L-infinity axis label.
    """
    df1 = extract_data_from_file(name1)
    df2 = extract_data_from_file(name2)

    df1['Dataset'] = 'Heat'
    df2['Dataset'] = 'Wave'

    combined = pd.concat([df1, df2], ignore_index=True)

    # NOTE(review): the second tuple element is passed to seaborn as the
    # error-bar level/scale; 36 looks unusually large for 'sd' -- confirm.
    g = sns.lineplot(data=combined, x='Key', y=y_col, hue='Dataset', errorbar=('sd', 36))
    g.set(yscale='log')

    # Raw strings keep the LaTeX backslash literal; "\i" in a normal string
    # is an invalid escape sequence and triggers a warning on recent Pythons.
    if y_col == 'L1':
        g.set_ylabel(r"$L_1$ Error")
    else:
        g.set_ylabel(r"$L_\infty$ Error")
    g.set_xlabel("History size")
    g.set_xticks([0, 5, 10, 15, 20])

    # Mark the history size of 16 and caption it just left of the line.
    plt.axvline(x=16, color='red', linestyle='--', linewidth=1)
    ymax = plt.ylim()[1]
    plt.text(15.9, ymax * 0.9, 'History size = compression factor', color='red', rotation=90, va='top', ha='right')
    plt.show()


def make_comparison_grid(real, fake, diff):
    """Show a 2x3 grid of real, generated and difference frames.

    Each argument is handed to ``torch.load``; the loaded objects are indexed
    at positions 16 and 100. The top row shows index 16 and the bottom row
    index 100, each as real / generated / difference. 2-D images are drawn
    with the ``magma`` colormap and a colorbar; 3-channel images are drawn as
    given. The figure is shown with ``plt.show()``.

    Args:
        real: Path (or file object) of the saved real frames.
        fake: Path (or file object) of the saved generated frames.
        diff: Path (or file object) of the saved difference frames.

    Raises:
        ValueError: If a selected frame is neither (3, H, W) nor (H, W).
    """
    real = torch.load(real)
    fake = torch.load(fake)
    diff = torch.load(diff)

    # NOTE(review): the titles below say "Frame 1" and "Frame 200" while the
    # indices used are 16 and 100 -- confirm which is intended.
    early_idx, late_idx = 16, 100

    def tensor_to_numpy(t):
        """Return a numpy image: (3, H, W) becomes (H, W, 3), (H, W) is kept."""
        t = t.detach().cpu()
        if t.ndim == 3 and t.shape[0] == 3:
            return t.permute(1, 2, 0).numpy()
        if t.ndim == 2:
            # Convert here too, so the helper always returns a numpy array.
            return t.numpy()
        raise ValueError(
            f"Expected tensor with shape (3, H, W) or (H, W), got {tuple(t.shape)}"
        )

    titles = [
        "Frame 1 – Real", "Frame 1 – Our Model", "Difference – Frame 1",
        "Frame 200 – Real", "Frame 200 – Our Model", "Difference – Frame 200"
    ]
    images = [
        real[early_idx], fake[early_idx], diff[early_idx],
        real[late_idx], fake[late_idx], diff[late_idx],
    ]

    # Set up plot
    fig, axes = plt.subplots(2, 3, figsize=(14, 8))

    for ax, img, title in zip(axes.flat, images, titles):
        img_np = tensor_to_numpy(img)
        if img_np.ndim == 2:  # difference (grayscale)
            im = ax.imshow(img_np, cmap='magma')
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        else:
            ax.imshow(img_np)
        ax.set_title(title, fontsize=10)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

def save_vids(solver, real, gen, diff, real_res, name, speed_up=1):
    """Load four saved frame tensors and write each out as a video.

    Every input is loaded with ``torch.load``, moved to the CPU and
    subsampled by keeping every ``speed_up``-th entry along the first axis.
    The videos are written through ``plot.save_video`` at 60 fps into
    ``data/generated/<solver.name>`` under the names ``gen_<name>``,
    ``real_<name>``, ``real_full_<name>`` and ``diff_<name>``. Only the
    difference video is written with ``adjust_clim=True``.

    Args:
        solver: Object whose ``name`` attribute selects the output folder.
        real: Path of the saved real frames.
        gen: Path of the saved generated frames.
        diff: Path of the saved difference frames.
        real_res: Path of the saved frames written as ``real_full_<name>``.
        name: Suffix used in the output file names.
        speed_up: Step used when subsampling frames.
    """
    def _load_frames(path):
        return torch.load(path).cpu()[::speed_up]

    # Load everything first (same order as before), then write the videos.
    gen_frames = _load_frames(gen)
    real_frames = _load_frames(real)
    diff_frames = _load_frames(diff)
    real_res_frames = _load_frames(real_res)

    jobs = (
        (gen_frames, f'gen_{name}', False),
        (real_frames, f'real_{name}', False),
        (real_res_frames, f'real_full_{name}', False),
        (diff_frames, f'diff_{name}', True),
    )
    for frames, stem, rescale in jobs:
        plot.save_video(
            frames.detach().numpy(),
            f'data/generated/{solver.name}',
            stem,
            fps=60,
            adjust_clim=rescale,
        )


def plot_model_err(diff_heat, diff_wave, cutoff=100):
    """Plot the per-frame mean squared difference of two saved runs.

    Both arguments are loaded with ``torch.load``. For each, the entries
    ``16 .. 16 + cutoff`` along the first axis are squared and averaged over
    axes 1 and 2, giving one value per frame. The wave values are further
    divided by 16. The two curves are drawn on twin log-scaled y axes (heat
    on the left, wave on the right), each axis coloured like its curve, and
    the figure is shown with ``plt.show()``.

    Args:
        diff_heat: Path of the saved heat difference tensor.
        diff_wave: Path of the saved wave difference tensor.
        cutoff: Maximum number of frames to include, starting at index 16.
    """
    diff_heat = torch.load(diff_heat)
    diff_wave = torch.load(diff_wave)

    def _per_frame_mse(diff):
        # .cpu() makes this work for tensors that were saved from a GPU.
        window = diff[16:16 + cutoff].detach().cpu().numpy()
        return np.mean(np.square(window), axis=(1, 2))

    norm_heat = _per_frame_mse(diff_heat)
    # NOTE(review): constant kept from the original ("Normalise"); the other
    # squared-error plots in this file divide by a squared scale -- confirm
    # that 16 (rather than 16 ** 2) is intended here.
    norm_wave = _per_frame_mse(diff_wave) / 16  # Normalise

    # Each curve gets its own frame index so that files of different length
    # (shorter than 16 + cutoff) can still be plotted together.
    df_heat = pd.DataFrame({'Frame': np.arange(len(norm_heat)), 'Value': norm_heat})
    df_wave = pd.DataFrame({'Frame': np.arange(len(norm_wave)), 'Value': norm_wave})

    _, ax1 = plt.subplots()

    heat_colour = (0.2980392156862745, 0.4470588235294118, 0.6901960784313725)
    wave_colour = (0.8666666666666667, 0.5176470588235295, 0.3215686274509804)  # Seaborn colours

    sns.lineplot(data=df_heat, x='Frame', y='Value', ax=ax1, color=heat_colour, legend=False, linewidth=0.5, label='Heat')
    ax1.set_ylabel('Heat (MSE)', color=heat_colour)
    ax1.set_xlabel('Frame')
    ax1.set_yscale('log')

    ax2 = ax1.twinx()
    sns.lineplot(data=df_wave, x='Frame', y='Value', ax=ax2, color=wave_colour, legend=False, linewidth=0.5, label='Wave')
    ax2.set_ylabel('Wave (MSE)', color=wave_colour)
    ax2.set_yscale('log')

    # Create custom legend
    lines = [
        plt.Line2D([], [], color=heat_colour, label='Heat'),
        plt.Line2D([], [], color=wave_colour, label='Wave')
    ]
    ax1.legend(handles=lines, loc='upper center', frameon=False)

    # Colour major and minor tick labels of each axis like its curve.
    ax1.tick_params(axis='y', which='both', labelcolor=heat_colour)
    ax2.tick_params(axis='y', which='both', labelcolor=wave_colour)

    plt.tight_layout()
    plt.show()

def get_time_correlations_and_plot(data_path):
    """Plot lagged self-correlation curves for real and generated videos.

    Every ``.pt`` file in ``data_path`` is loaded with ``torch.load``. For
    each one, the value at position (128, 128) is correlated with itself at
    lags 1..299 along the first axis. Files whose name contains ``real``
    form the ``Real`` group; all other ``.pt`` files form the ``Generated``
    group. The groups are drawn with seaborn (standard-deviation error
    bands) and the figure is shown with ``plt.show()``.

    Args:
        data_path: Directory containing the saved ``.pt`` tensors.
    """
    lags = range(1, 300)
    probe_h, probe_w = 128, 128

    def get_correlation(video, ns, h, w):
        """Return the correlation of video[:, h, w] with itself at each lag in ns."""
        series = video.detach().cpu().numpy()[:, h, w]
        return [np.corrcoef(series[:-n], series[n:])[0][1] for n in ns]

    real_cors = []
    gen_cors = []
    for file_name in os.listdir(data_path):
        if not file_name.endswith('.pt'):
            continue
        tensor = torch.load(os.path.join(data_path, file_name))
        group = real_cors if 'real' in file_name else gen_cors
        group.append(get_correlation(tensor, lags, probe_h, probe_w))

    # Label the columns with the actual lag so the x axis starts at lag 1
    # instead of the 0-based position of the value in the list.
    lag_columns = list(lags)

    df1 = pd.DataFrame(real_cors, columns=lag_columns)  # Each row is one real curve
    df1['Group'] = 'Real'

    df2 = pd.DataFrame(gen_cors, columns=lag_columns)  # Each row is one generated curve
    df2['Group'] = 'Generated'

    df = pd.concat([df1, df2], ignore_index=True)

    df_long = pd.melt(df, id_vars='Group', var_name='Time', value_name='Correlation')
    sns.lineplot(data=df_long, x='Time', y='Correlation', hue='Group', linewidth=0.5, errorbar='sd')
    plt.show()

def plot_kse_errors(folder_path):
    """Evaluate saved generator checkpoints on a test subset and plot the losses.

    A random subset of at most 200 items of the ``data/kse/test`` dataset is
    drawn once. For every ``<n>.pt`` file in ``folder_path`` a ``Generator``
    is created, loaded and evaluated on that subset, giving a normalised
    average L1 and L2 (MSE) loss per checkpoint. Both curves are plotted
    against the integer checkpoint name on a log y axis, the minimum of each
    is marked with a dashed vertical line and a caption, and the figure is
    saved to ``kse_l1_l2_test_residue.png``.

    Args:
        folder_path: Directory listed for ``.pt`` checkpoint files. File
            names (without ``.pt``) must be integers.

    Raises:
        ValueError: If the test dataset yields no samples, if no ``.pt``
            file is found in ``folder_path``, or if a checkpoint name is not
            an integer.
    """
    def evaluate_model(model, dataloader, device):
        """Return the (L1, L2) losses averaged over all samples, normalised."""
        model.eval()
        l1_loss_fn = nn.L1Loss(reduction='mean')
        l2_loss_fn = nn.MSELoss(reduction='mean')

        norm_factor_l1 = 16 * 8
        norm_factor_l2 = (16 * 8) ** 2  # Normalise data to [-1, 1] (and thus residues)

        l1_total = 0.0
        l2_total = 0.0
        n_samples = 0

        with torch.no_grad():
            for x_batch, y_batch in dataloader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                preds = model(x_batch)

                # Weight each batch mean by its size so that a smaller final
                # batch does not count as much as a full one.
                batch_size = y_batch.shape[0]
                l1_total += l1_loss_fn(preds, y_batch).item() * batch_size
                l2_total += l2_loss_fn(preds, y_batch).item() * batch_size
                n_samples += batch_size

        if n_samples == 0:
            raise ValueError("Cannot evaluate a model on an empty dataloader.")

        return l1_total / (n_samples * norm_factor_l1), l2_total / (n_samples * norm_factor_l2)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = TensorDataset('data/kse/test', 4, 16, device)

    total_size = len(dataset)

    # Random subsample of indices
    indices = random.sample(range(total_size), min(200 * 1, total_size))
    eval_subset = Subset(dataset, indices)

    dataloader = DataLoader(eval_subset, batch_size=768, shuffle=False)

    records = []

    for filename in tqdm(sorted(os.listdir(folder_path))):
        if not filename.endswith(".pt"):
            continue
        stem = filename.removesuffix(".pt")

        model = Generator(1, 1).to(dataset.device)
        # NOTE(review): the checkpoint is loaded from a fixed
        # 'kse-tok-256-16/' prefix, not from folder_path -- confirm that
        # both refer to the same checkpoints.
        model.load('kse-tok-256-16/' + stem, 'gan')
        model.to(device)

        l1, l2 = evaluate_model(model, dataloader, device)
        records.append({'Model': int(stem), 'Average L1 Loss': l1, 'Average L2 Loss': l2})

    if not records:
        raise ValueError(f"No .pt checkpoints found in {folder_path}")

    df = pd.DataFrame(records)

    # Melt DataFrame for Seaborn
    df_melted = df.melt(id_vars="Model", value_vars=["Average L1 Loss", "Average L2 Loss"],
                        var_name="Loss Type", value_name="Loss Value")

    # Plot
    ax = sns.lineplot(data=df_melted, x="Model", y="Loss Value", hue="Loss Type", marker='o', linewidth=0.5)

    # Per loss: line/caption colour (Seaborn colours) and caption precision.
    annotation_style = {
        "Average L1 Loss": ((0.2980392156862745, 0.4470588235294118, 0.6901960784313725), ".4f"),
        "Average L2 Loss": ((0.8666666666666667, 0.5176470588235295, 0.3215686274509804), ".6f"),
    }

    # Minima and vertical lines
    for loss_type, (color, precision) in annotation_style.items():
        best_rows = df[df[loss_type] == df[loss_type].min()]
        min_model = best_rows["Model"].values[0]
        min_value = best_rows[loss_type].values[0]
        ymax = plt.ylim()[1]
        plt.text(min_model - 0.2, ymax * 0.9, f"Min: {min_value:{precision}}", color=color, rotation=90,
                 va='top', ha='right')
        plt.axvline(x=min_model, linestyle='--', color=color, linewidth=0.5, alpha=0.6, label=f"Min {loss_type}: {min_value:.4f}")

    ax.set_xlabel("Model Epoch")
    ax.set_ylabel("Loss")
    ax.set_xticks([0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70])
    ax.set_yscale('log')
    plt.savefig("kse_l1_l2_test_residue.png", dpi=300)

def load_and_plot_multi_run(csv_path, x_col='test/step', metric_suffix='test/loss'):
    """Plot one metric of several runs, read from a wide-format CSV.

    The CSV must contain the column ``x_col`` plus one column per run. With
    a non-empty ``metric_suffix`` only columns ending in it are used. Column
    names are reduced to run names by removing a trailing
    ``" - <metric_suffix>"`` (and the default ``" - test/loss"``). The metric
    values are divided by ``16 ** 2`` and plotted against ``x_col`` on a log
    y axis, one hue per run, and the figure is shown with ``plt.show()``.

    Args:
        csv_path: Path of the CSV file.
        x_col: Name of the column used for the x axis.
        metric_suffix: Suffix selecting the metric columns; a falsy value
            selects every column other than ``x_col``.

    Raises:
        FileNotFoundError: If ``csv_path`` is not an existing file.
        ValueError: If ``x_col`` is missing or no metric column is found.
    """
    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    df = pd.read_csv(csv_path)

    if x_col not in df.columns:
        raise ValueError(f"x-axis column '{x_col}' not found in CSV.")

    # Identify metric columns for all runs
    metric_cols = [
        col for col in df.columns
        if col != x_col and (not metric_suffix or col.endswith(metric_suffix))
    ]
    if not metric_cols:
        raise ValueError("No metric columns found matching the expected pattern.")

    # NOTE(review): runs whose cleaned name is not listed here become NaN in
    # the categorical conversion below -- confirm that omitting them from
    # the plot is intended.
    run_order = ['History size 2', 'History size 6', 'History size 10', 'History size 14', 'History size 16', 'History size 20', 'Unmodified heat operator']

    # Melt the DataFrame to long format
    df_melted = df.melt(id_vars=[x_col], value_vars=metric_cols,
                        var_name='run', value_name='metric')

    # Reduce "<run> - <metric>" to "<run>". The tail follows metric_suffix so
    # that non-default metrics are cleaned as well; the default tail is still
    # tried so that partial suffixes (e.g. "loss") keep working.
    tails = [' - test/loss']
    if metric_suffix and f' - {metric_suffix}' not in tails:
        tails.insert(0, f' - {metric_suffix}')
    for tail in tails:
        df_melted['run'] = df_melted['run'].str.removesuffix(tail)

    cat_type = pd.CategoricalDtype(run_order, ordered=True)
    df_melted['run'] = df_melted['run'].astype(cat_type)
    df_melted['metric'] /= 16 ** 2  # Data normalisation factor to [-1, 1]

    ax = sns.lineplot(data=df_melted, x=x_col, y='metric', hue='run', marker='o', linewidth=0.5, errorbar=None)

    ax.set_xlabel("Model Epoch")
    ax.set_ylabel("Average $L_2$ Loss (Test set)")
    ax.set_xticks([0, 5, 10, 15, 20, 25])
    ax.set_yscale('log')
    plt.legend(title='Run')
    plt.show()