import os
import pickle
import torch
import numpy as np
from tqdm import tqdm
import wandb
import ray
from scipy.optimize import linprog
from aux.ppo import PPOBinaryActorCritic, ppo_update
from aux.seed import set_seed
from aux.CPD import CPD
from test import test_run
import time
from aux.solver import solve_ilp

def _advantage_targets(rewards, values, gamma, lam, adv_type):
    """Compute advantages and value targets for one finished trajectory.

    The step following the last recorded one is treated as having value
    zero, so the final return is just the final reward.

    Args:
        rewards: Stacked per-step rewards.
        values: Stacked per-step critic values; detached and reshaped to
            the shape of ``rewards`` before use.
        gamma: Discount factor.
        lam: GAE mixing coefficient; only used when ``adv_type`` is 'gae'.
        adv_type: 'td' for one-step TD targets, 'gae' for generalized
            advantage estimation.

    Returns:
        Tuple ``(advantages, returns)`` with the same shape as ``rewards``.
    """
    # Targets are constants for the later PPO update: no autograd graph.
    values = values.detach().reshape(rewards.shape)
    next_values = torch.cat([values[1:], torch.zeros_like(values[:1])], dim=0)

    if adv_type == 'td':
        returns = rewards + gamma * next_values
        advantages = returns - values
        return advantages, returns

    # adv_type == 'gae': discounted backward sum of the TD residuals.
    deltas = rewards + gamma * next_values - values
    advantages = torch.zeros_like(deltas)
    running = torch.zeros_like(deltas[0])
    for t in reversed(range(deltas.shape[0])):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    returns = advantages + values
    return advantages, returns


def collect_ppo_dataset(data, model_state_dict, dim, con_num, T, gamma, lam, c, adv_type, device_str):
    """Roll out the current policy on every instance in parallel with Ray.

    For each instance the policy decides, at every step t in 1..T-1, whether
    to re-solve the problem (action 1) using the average of the cost vectors
    since the last change point reported by ``CPD``. The per-step reward is
    the negative of ``c_all[t] . x - f_all[t] + c * action``.

    The policy input at step t is the concatenation of
    ``[model_time - t, start - t, model_sample_size]``,
    ``c_avg - lambda @ A - mu``, the current ILP solution ``x``, the LP
    lower-bound marginals ``mu`` and the LP inequality marginals ``lambda``.

    Args:
        data: Sequence of instances; each is indexed with the keys 'A', 'b',
            'c_all' and 'f_all_IP'.
        model_state_dict: Weights loaded into a fresh
            ``PPOBinaryActorCritic(3 + 3 * dim + con_num)`` in every worker.
        dim: Number of decision variables.
        con_num: Number of inequality-marginal entries in the state.
        T: Horizon; steps 1..T-1 are simulated.
        gamma: Discount factor for the advantage targets.
        lam: GAE coefficient, used only when ``adv_type == 'gae'``.
        c: Cost added to the regret whenever the action is 1.
        adv_type: 'td' (one-step TD targets) or 'gae'.
        device_str: Torch device string used inside the workers.

    Returns:
        Dict with keys 'states', 'actions', 'log_probs', 'values',
        'advantages' and 'returns'; each is the concatenation over all
        instances whose initial LP solve succeeded.

    Raises:
        ValueError: If ``adv_type`` is not 'td' or 'gae'.
        RuntimeError: If no instance produced a trajectory.
    """
    # Fail before launching any work instead of after a full rollout.
    if adv_type not in ('td', 'gae'):
        raise ValueError(f"Unsupported adv_type {adv_type!r}; expected 'td' or 'gae'")

    state_dim = 3 + 3 * dim + con_num

    def collect_trajectory(data_i):
        device = torch.device(device_str)
        model = PPOBinaryActorCritic(state_dim).to(device)
        model.load_state_dict(model_state_dict)

        A = torch.tensor(data_i['A'], dtype=torch.float32, device=device)
        b = torch.tensor(data_i['b'], dtype=torch.float32, device=device)
        c_all = torch.tensor(data_i['c_all'], dtype=torch.float32, device=device)
        f_all = torch.tensor(data_i['f_all_IP'], dtype=torch.float32, device=device)

        # Loop-invariant NumPy views for the solvers and change-point detector.
        A_np = A.cpu().numpy()
        b_np = b.cpu().numpy()
        c_hist = c_all.cpu().numpy().astype(np.float64)
        bounds = [(0, None)] * dim

        def solve(cost):
            """Return (x, mu, lambda) tensors, or None if the LP solve fails."""
            sol = linprog(cost, A_ub=A_np, b_ub=b_np, bounds=bounds, method='highs')
            if not sol.success:
                return None
            # Only pay for the integer solve once the LP is known to be solvable.
            _, x_val = solve_ilp(A_np, b_np, cost)
            return (
                torch.tensor(x_val, dtype=torch.float32, device=device),
                torch.tensor(sol.lower.marginals, dtype=torch.float32, device=device),
                torch.tensor(sol.ineqlin.marginals, dtype=torch.float32, device=device),
            )

        solution = solve(np.ones(dim))
        if solution is None:
            return None
        x_sol, mu_sol, lambda_sol = solution

        model_time = 0
        model_sample_size = 0
        all_states, all_actions, all_rewards, all_log_probs, all_values = [], [], [], [], []

        for t in tqdm(range(1, T)):
            cp = CPD(c_hist[:t])
            start = cp[-1] if cp else 0
            c_avg = c_all[start:t].mean(dim=0)

            state = torch.cat([
                torch.tensor(
                    [float(model_time - t), float(start - t), float(model_sample_size)],
                    dtype=torch.float32,
                    device=device,
                ),
                c_avg - lambda_sol @ A - mu_sol,
                x_sol,
                mu_sol,
                lambda_sol,
            ])

            # The policy gets a copy so the stored state cannot be mutated.
            action, log_prob, value = model.act(state.clone())

            if action == 1:
                solution = solve(c_avg.cpu().numpy())
                if solution is not None:
                    x_sol, mu_sol, lambda_sol = solution
                # The bookkeeping is updated even when the re-solve fails.
                model_time = t
                model_sample_size = t - start

            reg = torch.dot(c_all[t], x_sol) - f_all[t] + c * action

            all_states.append(state)
            all_actions.append(torch.as_tensor(action, device=device))
            all_rewards.append(-reg)
            # Rollout statistics are stored without any autograd graph.
            all_log_probs.append(log_prob.detach())
            all_values.append(value.detach())

        if not all_rewards:
            return None

        rewards = torch.stack(all_rewards)
        values = torch.stack(all_values)
        advantages, returns = _advantage_targets(rewards, values, gamma, lam, adv_type)

        return {
            "states": torch.stack(all_states),
            "actions": torch.stack(all_actions),
            "log_probs": torch.stack(all_log_probs),
            "values": values,
            "advantages": advantages,
            "returns": returns,
        }

    # Request a GPU share only for CUDA devices; a CPU-only cluster cannot
    # satisfy a GPU request, so such tasks would never be scheduled.
    num_gpus = 0.02 if torch.device(device_str).type == 'cuda' else 0
    remote_collect = ray.remote(num_gpus=num_gpus)(collect_trajectory)

    futures = [remote_collect.remote(data_i) for data_i in data]
    results = [r for r in ray.get(futures) if r is not None]
    if not results:
        raise RuntimeError("No trajectories were collected: every instance was skipped")

    keys = ("states", "actions", "log_probs", "values", "advantages", "returns")
    return {key: torch.cat([r[key] for r in results], dim=0) for key in keys}


if __name__ == "__main__":
    # Script entry point: collect PPO rollouts on the training set, update
    # the policy, save it, then evaluate it on the test set with Ray.
    set_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dim = 100
    con_num = 50
    T = 1000
    gamma = 0.9
    lam = 0.9
    c = 10.0
    adv_type = 'td'

    model_path = 'result.pth'
    train_data_path = 'train_dataset_01_1k.pkl'

    model = PPOBinaryActorCritic(3 + 3 * dim + con_num)
    model = torch.nn.DataParallel(model).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Resume from an existing checkpoint; weights live on the wrapped module.
    if os.path.exists(model_path):
        model.module.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Loaded model from {model_path}")

    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(train_data_path, 'rb') as f:
        data = pickle.load(f)[:]

    ray.init(ignore_reinit_error=True)
    try:
        ppo_dataset = collect_ppo_dataset(
            data, model.module.state_dict(), dim, con_num, T, gamma, lam, c, adv_type, str(device)
        )

        for k, v in ppo_dataset.items():
            print(f"{k}: shape={v.shape}, total={v.numel()}")

        epoch_losses = ppo_update(model, optimizer, ppo_dataset, device=device, epochs=100)

        # Save the inner module's weights (note the .module) so the
        # checkpoint can be loaded without the DataParallel wrapper.
        torch.save(model.module.state_dict(), model_path)
        print(f"Model saved to {model_path}")

        with open('test_dataset_01.pkl', 'rb') as f:
            dataset = pickle.load(f)

        # Request a GPU share only when CUDA is available; on a CPU-only
        # machine a GPU request could never be scheduled.
        test_gpus = 0.05 if device.type == 'cuda' else 0
        test_run_ray = ray.remote(num_gpus=test_gpus)(test_run)
        futures = [
            test_run_ray.remote(i, test_case, c, gamma, model_path=model_path)
            for i, test_case in enumerate(dataset)
        ]
        results = ray.get(futures)

        if results:
            avg_cost = sum(r['total_cost'] for r in results) / len(results)
            avg_retrain_num = sum(r['retrain_num'] for r in results) / len(results)
            print(f"✅ Average Total Cost over {len(results)} test cases: {avg_cost:.2f}")
            print(f"✅ Average Retrain Num over {len(results)} test cases: {avg_retrain_num:.2f}")
        else:
            # Avoid dividing by zero when the test set is empty.
            print("No test cases found; skipping evaluation summary.")
    finally:
        # Always release Ray resources, even if training or evaluation fails.
        ray.shutdown()
