import os
import random
import numpy as np
import torch
from models.triplepinn import TriplePINN
from models.quadpinn import QuadPINN
from models.pinn import PINN
from models.ncl import NCL
from models.dualpinn import DualPINN
from models.dualpinnncl import DualPINN_NCL
from models.triplepinnncl import TriplePINN_NCL
from params import *

def set_seed(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    
def get_model(model_config:dict):
    model_name = model_config["model_name"]
    if model_name == "MP-3xPINN":
        return TriplePINN(**model_config)
    elif model_name == "MP-2xPINN":
        return DualPINN(**model_config)
    elif model_name == "PINN":
        return PINN(**model_config)
    elif model_name == "NCL":
        return NCL(**model_config)
    elif model_name == "MP-4xPINN":
        return QuadPINN(**model_config)
    elif model_name == "MP-2xPINN+NCL":
        return DualPINN_NCL(**model_config)
    elif model_name == "MP-3xPINN+NCL":
        return TriplePINN_NCL(**model_config)
    else:
        raise ValueError(f"Unknown model name: {model_name}")


def get_data():
    try:
        sol_dataset = torch.load('data/sol.pt', weights_only=False)
        ic_dataset = torch.load('data/initial_condition.pt', weights_only=False)
    except FileNotFoundError as e:
        print(f"Error loading datasets: {e}, run the generate.py script.")

    return sol_dataset, ic_dataset


def generate_random_points(n_points):
    pts = torch.rand(n_points, 3)
    pts[:, 0] = pts[:, 0] * (t_max - t_min) + t_min
    pts[:, 1] = pts[:, 1] * (x_max - x_min) + x_min
    pts[:, 2] = pts[:, 2] * (y_max - y_min) + y_min
    return pts


# Compute gradients for each loss component and their norms
def grad_collect(loss_fn, optimizer, model, retain_graph, **kwargs):
    optimizer.zero_grad()
    loss = loss_fn(**kwargs)
    loss = loss.sum() if loss.ndim > 0 else loss
    loss.backward(retain_graph=retain_graph)
    grads_per_layer = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grads_per_layer[name] = param.grad.detach().cpu().numpy().copy()
        else:
            grads_per_layer[name] = None
            
    return grads_per_layer


def collect_grad_by_name(model, func, net_prefix, grad_dict, optimizer, retain_graph=True, **kwargs):
    grads_per_layer = grad_collect(func, optimizer, model, retain_graph=retain_graph, **kwargs)
    for name, grad in grads_per_layer.items():
        if name.startswith(net_prefix) and grad is not None:
            if name not in grad_dict:
                grad_dict[name] = []
            grad_dict[name].append(grad.flatten())

def eigvals_from_output(output, net_params, max_samples=100):
    # Limit number of samples for memory efficiency
    n_samples = min(output.shape[0], max_samples)
    grads_list = []
    for i in range(n_samples):
        out_i = output[i]
        if out_i.ndim == 0:
            grad = torch.autograd.grad(
                outputs=out_i,
                inputs=net_params,
                retain_graph=True,
                create_graph=False,  # Don't build graph for higher-order grads
                allow_unused=True
            )
            grad_flat = torch.cat([g.flatten().detach().cpu() for g in grad if g is not None])
            grads_list.append(grad_flat)
        else:
            # If output[i] has more than one component, compute grad for each component
            for j in range(out_i.numel()):
                grad = torch.autograd.grad(
                    outputs=out_i.flatten()[j],
                    inputs=net_params,
                    retain_graph=True,
                    create_graph=False,
                    allow_unused=True
                )
                grad_flat = torch.cat([g.flatten().detach().cpu() for g in grad if g is not None])
                grads_list.append(grad_flat)
    grads_mat = torch.stack(grads_list)
    # Compute Gram matrix (K) efficiently
    K = grads_mat @ grads_mat.T
    eigs = torch.linalg.eigvals(K).real
    eigs = torch.clamp(eigs, min=0).cpu().numpy()
    return np.sort(eigs)[::-1]


def plot_eigvals(eigvals:dict, save_dir:str, step:int):
    import matplotlib.pyplot as plt
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # Determine objectives from eigvals dict
    objectives = list(next(iter(eigvals.values())).keys()) if eigvals else []
    fig, axs = plt.subplots(len(eigvals), len(objectives), figsize=(4 * len(objectives), 3 * len(eigvals)), squeeze=False)
    for i, net_prefix in enumerate(eigvals.keys()):
        for j, objective in enumerate(objectives):
            vals = eigvals[net_prefix].get(objective, [])[-1]
            vals = np.array(vals)
            axs[i][j].plot(range(len(vals)), vals, marker='o')
            axs[i][j].set_title(f"{net_prefix} - {objective}")
            axs[i][j].set_xlabel("Index")
            axs[i][j].set_ylabel("Eigenvalue")
            axs[i][j].set_ylim(bottom=1e-6)  # Ensure y-axis starts at 0
            axs[i][j].set_yscale("log")
            axs[i][j].grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f'eigvals_step_{step}.png'))
    plt.close()

def plot_grad_hists(grad_values, save_dir:str, step:int):
    import matplotlib.pyplot as plt
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    import matplotlib.pyplot as plt

    # Collect all gradient values to determine global x and y limits
    all_grads = []
    for losses_dict in grad_values.values():
        for layer_dict in losses_dict.values():
            for grads in layer_dict.values():
                if grads:
                    all_grads.extend(grads[-1])
    if all_grads:
        y_min, y_max = 1, max(np.histogram(all_grads, bins=100)[0])  # Avoid log(0)
        x_min, x_max = min(all_grads), max(all_grads)
    else:
        y_min, y_max, x_min, x_max = 1, 10, -1, 1  # Fallback values
        
    # For each net_prefix, only plot layers that start with that net_prefix
    net_layers = {net_prefix: sorted([layer for layer in {layer for loss_dict in losses_dict.values() for layer in loss_dict}
                                      if layer.startswith(net_prefix)])
                  for net_prefix, losses_dict in grad_values.items()}
    num_nets = len(grad_values)
    num_layers_per_net = [len(layers) for layers in net_layers.values()]
    max_layers = max(num_layers_per_net)
    bins = np.linspace(x_min, x_max, 100)

    # Define colors for each loss
    loss_colors = {
        "bc": "tab:red",
        "mom": "tab:blue",
        "div": "tab:green",
        "alignment": "tab:purple",
        "ic": "yellow"
        # Add more loss names and colors as needed
    }

    # Plot gradient histograms for each layer (all params)
    fig, axs = plt.subplots(num_nets, max_layers, figsize=(4 * max_layers, 3 * num_nets), squeeze=False)
    for row, (net_prefix, losses_dict) in enumerate(grad_values.items()):
        layers = net_layers[net_prefix]
        for col, layer in enumerate(layers):
            handles = []
            labels = []
            for loss_name, layer_dict in losses_dict.items():
                grads = layer_dict.get(layer, [])
                if grads:
                    color = loss_colors.get(loss_name, None)
                    hist_handle = axs[row, col].hist(grads[-1], bins=bins, alpha=0.3, label=loss_name, color=color)
                    #counts, bin_edges = np.histogram(grads[-1], bins=bins)
                    #bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
                    #line_handle, = axs[row, col].plot(bin_centers, counts, '-', label=f"{loss_name} (line)", color=color)
                    handles.append(hist_handle[2][0])
                    labels.append(loss_name)
            axs[row, col].set_title(f"{net_prefix} - {layer}")
            axs[row, col].set_xlabel("Gradient Value")
            axs[row, col].set_ylabel("Frequency")
            axs[row, col].set_yscale("log")
            axs[row, col].set_ylim(y_min, y_max)
            axs[row, col].set_xlim(x_min, x_max)
            axs[row, col].grid(True)
            if handles and labels:
                axs[row, col].legend(handles, labels)
        for col in range(len(layers), max_layers):
            fig.delaxes(axs[row, col])
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'grad_hist_step_{step}.png'))
    plt.close()

    # Additional plot: gradient histograms for weights only (exclude biases)
    fig_w, axs_w = plt.subplots(num_nets, max_layers, figsize=(4 * max_layers, 3 * num_nets), squeeze=False)
    for row, (net_prefix, losses_dict) in enumerate(grad_values.items()):
        layers = net_layers[net_prefix]
        for col, layer in enumerate(layers):
            if "weight" not in layer:
                fig_w.delaxes(axs_w[row, col])
                continue
            handles = []
            labels = []
            for loss_name, layer_dict in losses_dict.items():
                grads = layer_dict.get(layer, [])
                if grads:
                    color = loss_colors.get(loss_name, None)
                    hist_handle = axs_w[row, col].hist(grads[-1], bins=bins, alpha=0.3, label=loss_name, color=color)
                    #counts, bin_edges = np.histogram(grads[-1], bins=bins)
                    #bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
                    #line_handle, = axs_w[row, col].plot(bin_centers, counts, '-', label=f"{loss_name} (line)", color=color)
                    handles.append(hist_handle[2][0])
                    labels.append(loss_name)
            axs_w[row, col].set_title(f"{net_prefix} - {layer} (weights only)")
            axs_w[row, col].set_xlabel("Gradient Value")
            axs_w[row, col].set_ylabel("Frequency")
            axs_w[row, col].set_yscale("log")
            axs_w[row, col].set_ylim(y_min, y_max)
            axs_w[row, col].set_xlim(x_min, x_max)
            axs_w[row, col].grid(True)
            if handles and labels:
                axs_w[row, col].legend(handles, labels)
        for col in range(len(layers), max_layers):
            if "weight" not in layers[col]:
                fig_w.delaxes(axs_w[row, col])
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'weights_step_{step}.png'))
    plt.close()

    # Additional plot: aggregate gradients across all layers for each net/loss
    bins = np.linspace(x_min, x_max, 100)
    fig2, axs2 = plt.subplots(num_nets, 1, figsize=(5, 3 * num_nets), squeeze=False)
    for row, (net_prefix, losses_dict) in enumerate(grad_values.items()):
        for loss_name, layer_dict in losses_dict.items():
            grads = []
            for layer_grads in layer_dict.values():
                if layer_grads:
                    grads.extend(layer_grads)
            if grads:
                color = loss_colors.get(loss_name, None)
                axs2[row, 0].hist(grads[-1], bins=bins, alpha=0.5, label=loss_name, color=color)
                #counts, bin_edges = np.histogram(grads, bins=bins)
                #bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
                #axs2[row, 0].plot(bin_centers, counts, '-', color=color)
        axs2[row, 0].set_title(f"{net_prefix} (all layers)")
        axs2[row, 0].set_xlabel("Gradient Value")
        axs2[row, 0].set_ylabel("Frequency")
        axs2[row, 0].set_yscale("log")
        axs2[row, 0].set_ylim(y_min, y_max)
        axs2[row, 0].set_xlim(x_min, x_max)
        axs2[row, 0].grid(True)
        axs2[row, 0].legend()
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'aggregate_step_{step}.png'))
    plt.close()
    
def hvp(loss, params, v):
    # Compute gradients with create_graph for Hessian-vector product
    grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
    grads_not_none = [g for g in grads if g is not None]
    idxs = [g.numel() for g in grads if g is not None]
    starts = np.cumsum([0] + idxs[:-1])
    # Slice v to match grads_not_none shapes
    v_filtered = torch.cat([
        v[start:start + g.numel()].to(g.device) for g, start in zip(grads_not_none, starts)
    ])
    flat_grads = torch.cat([g.reshape(-1) for g in grads_not_none])
    grad_v = torch.dot(flat_grads, v_filtered)
    # Compute Hessian-vector product, avoid moving to CPU
    hv = torch.autograd.grad(grad_v, params, retain_graph=True, allow_unused=True)
    return torch.cat([h.reshape(-1) for h in hv if h is not None])

def update_weights(
    model:torch.nn.Module,
    x_pde:torch.Tensor,
    x_ic:torch.Tensor,
    y_ic:torch.Tensor,
    optimizer:torch.optim.Optimizer,
    net_prefixes:list,
    losses:list,
    grad_values:dict,
    hessian_singular_values:dict,
    eigvals:dict,
    save_dir:str,
    step:int,
    weight_type:str,
    alpha:float
):
    '''
    # Here we collect the gradients per layer
    for net_prefix in net_prefixes:
        if net_prefix not in grad_values:
            grad_values[net_prefix] = {}
        for loss in losses:
            if loss not in grad_values[net_prefix]:
                grad_values[net_prefix][loss] = {}
            collect_grad_by_name(
                model, getattr(model, f'calc_{loss}_loss'), net_prefix, grad_values[net_prefix][loss], optimizer,
                x_pde=x_pde, x_ic=x_ic, y_ic=y_ic, retain_graph=True)

    _ = grad_collect(getattr(model, 'calc_ic_loss'), optimizer, model, retain_graph=False, x_ic=x_ic, y_ic=y_ic)


    plot_grad_hists(grad_values, save_dir=f'{save_dir}/grad_hists', step=step)

    for net_prefix in net_prefixes:
        if net_prefix not in eigvals:
            eigvals[net_prefix] = {}
        for objective in losses:
            if objective not in eigvals[net_prefix]:
                eigvals[net_prefix][objective] = []
            output = getattr(model, f'calc_{objective}')(x_pde=x_pde[:100], x_ic=x_ic[:100], y_ic=y_ic[:100])
            net_params = list(getattr(model, net_prefix).parameters())
            eigvals_calc = eigvals_from_output(output, net_params)
            try:
                eigvals[net_prefix][objective].append(eigvals_calc)
            except KeyError:
                eigvals[net_prefix][objective] = [eigvals_calc]
    plot_eigvals(eigvals, save_dir=f'{save_dir}/eigvals', step=step)
    '''
    if weight_type == 'ntk':
        # Compute the trace of the last eigenvalue matrices for each net/objective
        traces = {}
        for objective in losses:
            traces[objective] = 0
            for net_prefix in net_prefixes:
                eig_list = eigvals[net_prefix].get(objective, [])
                if eig_list:
                    traces[objective] += np.sum(eig_list[-1])

        all_trace = np.sum([trace for trace in traces.values() if trace is not None])

        new_weights = {objective: all_trace/traces[objective] for objective, trace in traces.items() if trace is not None}
    elif weight_type == 'grad':
        def grad_norm_reweight(loss_fn, retain_graph=False, **kwargs):
            optimizer.zero_grad()
            loss = loss_fn(**kwargs)
            loss = loss.sum() if loss.ndim > 0 else loss
            loss.backward(retain_graph=retain_graph)
            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6)
            return norm.detach()

        grad_norms = {objective: grad_norm_reweight(getattr(model, f'calc_{objective}_loss'), retain_graph=True, x_pde=x_pde[:100], x_ic=x_ic[:100], y_ic=y_ic[:100]) for objective in losses}
        _ = grad_norm_reweight(getattr(model, f'calc_ic_loss'), retain_graph=False, x_ic=x_ic[:100], y_ic=y_ic[:100])
        loss_tot = sum(grad_norms.values())
        
        new_weights = {objective: loss_tot / grad_norm for objective, grad_norm in grad_norms.items() if grad_norm != 0}
        
    elif weight_type == 'static':
        new_weights = {objective: 1.0 for objective in losses}

    for weight, weight_val in new_weights.items():
        setattr(model, f'{weight}_weight', alpha*getattr(model, f'{weight}_weight') + (1-alpha) * weight_val)
        print(f"Updated {weight} weight: {getattr(model, f'{weight}_weight')}")
    
    # Power iteration
    def largest_singular_value(loss, params, device, iters=50):
        n_params = sum(p.numel() for p in params)
        v = torch.randn(n_params).to(device)
        v = v / v.norm()

        for _ in range(iters):
            Hv = hvp(loss, params, v)
            norm_Hv = Hv.norm()
            v = Hv / norm_Hv  # normalize
        return norm_Hv.item()
    
    # Compute largest singular value of Hessian for each net in DualPINN
    # Store results in a dict keyed by net_prefix

    all_internal_x = x_pde
    all_initial_x = x_ic
    all_initial_y = y_ic

    for net_prefix in net_prefixes:
        net_attr = getattr(model, net_prefix, None)
        if net_attr is not None:
            if net_prefix not in hessian_singular_values:
                hessian_singular_values[net_prefix] = []
            params = [p for p in net_attr.parameters() if p.requires_grad]
            loss = model.loss_fn(x_pde=all_internal_x, x_ic=all_initial_x, y_ic=all_initial_y)
            sigma_max = largest_singular_value(loss, params, device=next(model.parameters()).device)
            hessian_singular_values[net_prefix].append(sigma_max)
