import torch
# Process-wide scaled-dot-product-attention backend switches: the flash and
# memory-efficient kernels are enabled and the math implementation is disabled.
# These run at import time, before any model is built.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

from gla_v2.model import GLA_Model_v2
from gla_v2.model_mot import GLA_Model_MoT
from gla_v2.data_preprocess import data_prepare_gla_v2
from bitsandbytes.optim import AdamW8bit
from pypower.api import *
import numpy as np


def _cosine_on_intersection(ga_list, gb_list, eps=1e-12):
    """Return the cosine similarity of two gradient lists over their shared entries.

    The lists are walked in lockstep. A position contributes only when both
    gradients are present (not ``None``), so the similarity is measured on the
    intersection of parameters that received a gradient from both sides.

    Args:
        ga_list: Sequence of gradient tensors (or ``None``), one per parameter.
        gb_list: Sequence aligned position-by-position with ``ga_list``.
        eps: Threshold on the *squared* norm of either side; at or below it
            the cosine is treated as undefined.

    Returns:
        The cosine as a Python float, or ``nan`` when no entry is shared or
        when either side's squared norm is ``<= eps``.
    """
    dot = na2 = nb2 = None
    for ga, gb in zip(ga_list, gb_list):
        if ga is None or gb is None:
            continue
        # Align dtype/device of the second operand with the first.
        if gb.dtype != ga.dtype:
            gb = gb.to(dtype=ga.dtype)
        if gb.device != ga.device:
            gb = gb.to(device=ga.device)
        ga_f = ga.flatten()
        gb_f = gb.flatten()
        if dot is None:
            # Create the accumulators where the gradients actually live
            # instead of assuming CUDA, so CPU-resident (offloaded) gradients
            # and CUDA-less machines work as well.
            dot = torch.zeros((), dtype=torch.float32, device=ga.device)
            na2 = torch.zeros_like(dot)
            nb2 = torch.zeros_like(dot)
        # `.to(dot.device)` is a no-op in the single-device case and keeps the
        # accumulation valid if parameters are spread over several devices.
        dot += torch.dot(ga_f, gb_f).to(dot.device)
        na2 += torch.dot(ga_f, ga_f).to(dot.device)
        nb2 += torch.dot(gb_f, gb_f).to(dot.device)
    if dot is None:
        # No position had a gradient on both sides.
        return float("nan")
    if na2.item() <= eps or nb2.item() <= eps:
        return float("nan")
    return (dot / (na2.sqrt() * nb2.sqrt())).item()

@torch.no_grad()
def _maybe_to_cpu(grads_per_task):
    """Return per-task gradient lists with every tensor detached and on the CPU.

    Args:
        grads_per_task: Iterable of per-task gradient sequences; entries may
            be ``None``.

    Returns:
        A new list of lists with the same layout. ``None`` entries are kept
        as ``None``; every tensor is replaced by ``tensor.detach().cpu()``.
    """
    def _offload(grad):
        # Keep the placeholder for parameters that received no gradient.
        if grad is None:
            return None
        return grad.detach().cpu()

    return [[_offload(grad) for grad in task_grads] for task_grads in grads_per_task]

def grad_cos_matrix_on_intersection(
    model,
    task_losses,   # list[Tensor]: one loss per task; each loss's forward graph must still be alive
    params_selector=lambda n, p: p.requires_grad,
    move_to_cpu=False,
):
    """Compute pairwise cosine similarities between per-task gradients.

    For every loss in ``task_losses`` the gradient w.r.t. the selected model
    parameters is taken with ``torch.autograd.grad`` and cast to float32. Each
    pair of tasks is then compared with ``_cosine_on_intersection``, i.e. only
    on parameters that received a gradient from both tasks.

    Args:
        model: Module whose ``named_parameters()`` are differentiated.
        task_losses: List of scalar loss tensors, one per task. Their autograd
            graphs must not have been freed yet.
        params_selector: ``(name, param) -> bool`` filter choosing which
            parameters take part. Defaults to those with ``requires_grad``.
        move_to_cpu: When true, each task's gradients are detached and moved
            to the CPU right after they are computed.

    Returns:
        A ``T x T`` nested list (``T = len(task_losses)``). Only entries with
        ``i < j`` are filled; the diagonal and lower triangle stay ``0.0``.
        An entry is ``nan`` when the cosine for that pair is undefined.
    """
    params = [p for n, p in model.named_parameters() if params_selector(n, p)]
    T = len(task_losses)

    grads_per_task = []
    for t, loss in enumerate(task_losses):
        grads = torch.autograd.grad(
            loss, params,
            retain_graph=(t < T - 1),  # keep the graph for every loss except the last one
            allow_unused=True,
            create_graph=False,
        )
        grads = [g.to(torch.float32) if g is not None else None for g in grads]
        if move_to_cpu:
            # Offload immediately so the float32 copies of earlier tasks do
            # not stay on the accelerator while later tasks are differentiated.
            grads = [g.detach().cpu() if g is not None else None for g in grads]
        grads_per_task.append(grads)

    cos = [[0.0] * T for _ in range(T)]
    for i in range(T):
        for j in range(i + 1, T):
            cos[i][j] = _cosine_on_intersection(grads_per_task[i], grads_per_task[j])
    return cos


def train_gla_v2(model, dataloader, optimizer, epochs=1, phase=1, prefix='', tasks=None):
    """Train the model, or (when ``tasks`` is given) analyse cross-task gradient similarity.

    The model is moved to CUDA and ``dataloader`` is iterated ``epochs`` times.
    Every batch is run through the model under bfloat16 autocast and the loss
    is read from ``outputs.loss``.

    * ``tasks is None`` (training mode): for every non-NaN loss the optimizer
      is zeroed, the loss is back-propagated and a step is taken. The loss is
      printed whenever the per-epoch batch counter is a multiple of 200 and a
      checkpoint is written whenever it is a multiple of 10000 (which includes
      the very first batch of each epoch).
    * ``tasks is not None`` (analysis mode): no optimizer step is taken. Losses
      are stored round-robin by ``cnt % len(tasks)``; after each full round the
      pairwise gradient cosine matrix is computed and accumulated. When the
      counter reaches ``100 * len(tasks)`` the accumulated matrix is divided by
      100, printed, and an ``ipdb`` breakpoint is opened.

    Args:
        model: Model to train; must provide ``save(model_path=, tokenizer_path=)``.
        dataloader: Iterable yielding 7-field batches (see the unpacking below).
        optimizer: Optimizer over the model's parameters.
        epochs: Number of passes over ``dataloader``.
        phase: Training-phase tag, used only in log lines and file names.
        prefix: Run tag, used only in log lines and file names.
        tasks: Optional list of task names that switches on analysis mode.
            NOTE(review): the round-robin indexing presumably relies on the
            dataloader emitting tasks in this same cyclic order - confirm.

    Returns:
        None. The model is saved once more after the last epoch.
    """
    model.cuda()
    # Only referenced by the commented-out graph-token dump further down.
    graph_token_per_task = [[] for _ in range(4)]
    if tasks is not None:
        # Analysis-mode state: latest loss per task and the running sum of cosine matrices.
        loss_per_task = [0.0 for _ in range(len(tasks))]
        cos_matrix = np.zeros((len(tasks), len(tasks)))
    for epoch in range(epochs):
        cnt = 0  # batch counter; restarts at every epoch
        model.train()
        max_loss = 0.0  # largest loss seen since the last log line
        for batch in tqdm(dataloader):
            # The last three batch fields are not used here.
            graph_data, edge_index, language_input, ground_truth_response, _, _ , _ = batch
            # Take the first element of the collated text fields
            # (presumably because the loader uses batch_size=1 - confirm with the caller).
            language_input, ground_truth_response = language_input[0], ground_truth_response[0]
            # Cast to int64, move to the GPU and swap the last two axes.
            edge_index = edge_index.long().cuda().permute(0, 2, 1)
            graph_data = graph_data.float().cuda()
            with autocast(dtype=torch.bfloat16):
                # outputs = model(graph_data, edge_index, language_input, target_response=ground_truth_response)
                outputs, graph_tokens = model(graph_data, edge_index, language_input, target_response=ground_truth_response, return_graph_tokens=True)
                loss = outputs.loss

            # Training mode: NaN losses are skipped without an optimizer step.
            if not torch.isnan(loss) and tasks is None:
                # Backward
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
                # # scaler.scale(loss).backward()
                # # scaler.step(optimizer)
                # # scaler.update()

                # Disabled debugging aid: dumps graph tokens per task to a .npy file.
                # task_idx = cnt % 4
                # if graph_tokens.shape[1] == 126:
                #     graph_tokens = graph_tokens[:, :118, :]
                # graph_token_per_task[task_idx].append(graph_tokens.float().detach().cpu().numpy())
                # if (cnt+1) == 40:
                #     graph_token_per_task = np.asarray(graph_token_per_task)
                #     np.save(f'graph_tokens_ieee300.npy', graph_token_per_task)
                #     import ipdb
                #     ipdb.set_trace()

                max_loss = max(max_loss, loss.item())
                if cnt % 200 == 0:
                    print(f"Phase {phase}, Epoch {epoch}, prefix={prefix}, max loss within period: {max_loss:.4f}, loss: {loss.item():.4f}")
                    max_loss = 0.0
                if cnt % 10000 == 0:
                    # Periodic checkpoint; also fires at cnt == 0.
                    model.save(model_path=f'./openglav2_{phase}_{prefix}_checkpoint.p', tokenizer_path=f'./openglav2_tokenizer_{phase}_{prefix}_checkpoint')

            # Analysis mode: accumulate gradient cosine similarities instead of training.
            if tasks is not None:
                if cnt == 100 * len(tasks):
                    # Report the mean (divisor is the hard-coded 100) and stop in the debugger.
                    cos_matrix /= 100
                    for i in range(len(tasks)):
                        for j in range(i + 1, len(tasks)):
                            print(f"({i},{j}) {tasks[i]} vs {tasks[j]}  cos={cos_matrix[i][j]:.4f}")
                    import ipdb
                    ipdb.set_trace()
                task_index = cnt % len(tasks)
                # The loss tensor is stored with its autograd graph so it can be differentiated later.
                loss_per_task[task_index] = loss
                print(f'{task_index}, loss: {loss_per_task[task_index].item():.4f}')

                # After the last task of a round, compare the gradients of the stored losses.
                if cnt > 0 and (cnt+1) % len(tasks) == 0:
                    cos_upper = grad_cos_matrix_on_intersection(model, loss_per_task, move_to_cpu=True)
                    cos_matrix += np.asarray(cos_upper)
            cnt += 1

    # Final save after all epochs.
    model.save(model_path=f'./openglav2_{phase}_{prefix}.p', tokenizer_path=f'./openglav2_tokenizer_{phase}_{prefix}')

if __name__ == '__main__':

    # Networks used for this run. 'Texas2000' has an entry in `ppcs` but is not enabled here.
    networks = ['IEEE14', 'IEEE39', 'IEEE57', 'SG126', 'IEEE300']
    # Case data for every known network name, built eagerly by the caseN() loaders.
    ppcs = {
        'IEEE14': case14(),
        'IEEE39': case39(),
        'IEEE57': case57(),
        'IEEE118': case118(),
        'SG126': case126(),
        'IEEE300': case300(),
        'Texas2000': case2000(),
    }
    tasks = ['opf', 'lmp_pred', 'state_est', 'fault_detect']

    # ---- GLA model / run configuration ----
    float_disc = True
    use_diffusion = False
    use_MoT = True
    use_lora = True
    node_parallel = False
    epochs = [1, 1]  # epoch count per training phase, indexed by `idx`
    # Other language models that have been used with this script:
    #   'Qwen/Qwen3-1.7B', 'google/gemma-3-1b-pt',
    #   'meta-llama/Llama-3.2-1b', 'meta-llama/Llama-3.2-3b'
    llm_model = 'meta-llama/Llama-3.1-8b'
    graph_input_dim = 64
    graph_hidden_dim = 1024
    data_amount = None

    # Run tag used in checkpoint file names: model short name, graph dims,
    # then one suffix per enabled option.
    llm_short_name = llm_model.split("/")[1]
    name_parts = [f'{llm_short_name}-{graph_input_dim}-{graph_hidden_dim}']
    if use_diffusion:
        name_parts.append('-diffusion')
    if use_MoT:
        name_parts.append('-MoT')
    if data_amount is not None:
        name_parts.append(f'-{data_amount}')
    name_parts.append('-RoPE13')
    prefix = ''.join(name_parts)

    samples = data_prepare_gla_v2(networks, ppcs, tasks, float_disc=float_disc, data_amount=data_amount)
    train_loader = DataLoader(samples, batch_size=1, shuffle=False)

    if use_MoT:
        Model = GLA_Model_MoT
    else:
        Model = GLA_Model_v2

    # Phases: 1 - pretrain W and the graph encoder with the LLM frozen; 2 - end-to-end.
    # Only phase 2 is run here; `idx` selects the matching entry of `epochs`.
    training_phase = 2
    idx = 1
    model = Model(
        graph_input_dim=graph_input_dim,
        graph_hidden_dim=graph_hidden_dim,
        language_model_name=llm_model,
        use_lora=use_lora,
        float_disc=float_disc,
        phase=training_phase,
        use_diffusion=use_diffusion,
        node_parallel=node_parallel,
    ).cuda()
    if training_phase == 2:
        # Phase 2 starts from the weights saved at the end of phase 1.
        model.load(
            model_path=f'./saved_models/openglav2_1_{prefix}.p',
            tokenizer_path=f'./saved_models/openglav2_tokenizer_1_{prefix}',
        )

    learning_rate = 2e-4 if training_phase == 1 else 2e-5
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, fused=True)
    # Alternative: AdamW8bit(model.parameters(), lr=learning_rate) - could save lots of GPU memory, hurt performance
    train_gla_v2(
        model,
        train_loader,
        optimizer,
        epochs=epochs[idx],
        phase=training_phase,
        prefix=prefix,
        # tasks=tasks    # if analyze gradient similarities across tasks
    )
    idx += 1