import os
import numpy as np
import cvxpy as cp
from cvxpy import quad_form

import torch
import torch.optim as optim
import torch.nn.functional as F

from physics_engine import VoltageControlEngine
from data import load_data, normalize, denormalize
from graph_models.ControllableEmbedding import ControllableEmbedding
from utils import to_var, to_np, Tee, norm, get_flat, print_args

from progressbar import ProgressBar

from config import gen_args
from socket import gethostname

import matplotlib.pyplot as plt


# Parse the run configuration; everything below reads its fields.
args = gen_args()

# Create the output directory in-process instead of shelling out to
# `mkdir -p`: no shell quoting issues with unusual paths and no dependency on
# a POSIX shell. The guard keeps an empty path a no-op, as it effectively was.
if args.shootf:
    os.makedirs(args.shootf, exist_ok=True)

# Log file inside the output directory.
# NOTE(review): Tee presumably mirrors stdout into this file — confirm in utils.
log_path = os.path.join(args.shootf, 'log.txt')
tee = Tee(log_path, 'w')

print_args(args)

# Dataset statistics; later code indexes stat[1] for states and stat[2] for
# actions, and reads column 0 / column 1 as normalisation offset / scale.
print(f"Load stored dataset statistics from {args.stat_path}!")
stat = load_data(args.data_names, args.stat_path)

# Field names stored in each '<idx>.rollout.h5' file.
data_names = ['attrs', 'states', 'actions']
prepared_names = ['attrs', 'states', 'actions', 'rel_attrs']
data_dir = os.path.join(args.dataf, args.shoot_set)

# On hosts whose name starts with 'netmit' the 'extra' set lives in a sibling
# directory '<dataf>_extra' rather than in '<dataf>/extra'.
if args.shoot_set == 'extra' and gethostname().startswith('netmit'):
    data_dir = args.dataf + '_' + args.shoot_set

# ---------------------------------------------------------------------------
# model
# ---------------------------------------------------------------------------
use_gpu = torch.cuda.is_available()

if args.baseline:
    # Baseline model: imported lazily so the module is only needed for
    # baseline runs. No checkpoint is loaded on this path.
    from models.KoopmanBaselineModel import KoopmanBaseline
    model = KoopmanBaseline(args)

else:
    # Controllable Embedding model
    model = ControllableEmbedding(args, residual=False, use_gpu=use_gpu)

    # shoot_epoch == -1 selects 'net_best.pth'; any other value selects the
    # checkpoint saved for that (epoch, iter) pair.
    if args.shoot_epoch == -1:
        ckp_name = 'net_best.pth'
    else:
        ckp_name = 'net_epoch_%d_iter_%d.pth' % (args.shoot_epoch, args.shoot_iter)
    model_path = os.path.join(args.outf, ckp_name)

    print("Loading saved ckp from %s" % model_path)

    # Map the stored tensors onto whichever device is available on this host.
    load_device = torch.device('cuda:0' if use_gpu else 'cpu')
    model.load_state_dict(torch.load(model_path, map_location=load_device))

    # Inference mode, then move to the GPU when there is one.
    model.eval()
    if use_gpu:
        model.cuda()

# ---------------------------------------------------------------------------
# shoot
# ---------------------------------------------------------------------------

# Module-level physics engine, selected by args.env.
# NOTE(review): RopeEngine, SoftEngine and SwimEngine are not imported in the
# visible part of this file (only VoltageControlEngine is), so those three
# branches raise NameError as written — import them or drop the branches.
# NOTE(review): shoot_mpc_qp builds its own VoltageControlEngine with
# `num_nodes=...`, a different call signature from the positional one used
# here — confirm which constructor signature is current.
if args.env == 'Rope':
    engine = RopeEngine(args.dt, args.state_dim, args.action_dim, args.param_dim)
elif args.env == 'VoltageControl':
    engine = VoltageControlEngine(args.state_dim, args.action_dim, args.param_dim, args.dt)
elif args.env == 'Soft':
    engine = SoftEngine(args.dt, args.state_dim, args.action_dim, args.param_dim)
elif args.env == 'Swim':
    engine = SwimEngine(args.dt, args.state_dim, args.action_dim, args.param_dim)
else:
    # Raise explicitly: a bare `assert False` is removed under `python -O`,
    # which would leave `engine` undefined without any error at this point.
    raise ValueError(f"Unsupported env: {args.env!r}")


def get_more_trajectories(roll_idx):
    """
    Load the trajectories used for system identification of one rollout.

    Rollouts are grouped in consecutive blocks of ``args.group_size`` indices.
    The trajectories are chosen from the block that contains ``roll_idx``:

    * ``args.fit_num == 1``: exactly one trajectory is used, the first other
      member of the group (or ``roll_idx`` itself when the group has no other
      member). Previously ``roll_idx`` was appended on top of it, returning
      two trajectories for a requested single-trajectory fit.
    * otherwise: the first ``args.fit_num - 1`` other members of the group
      (fewer if the group is smaller), followed by ``roll_idx``.

    Args:
        roll_idx: Base rollout index

    Returns:
        List with one numpy array per field returned by ``load_data`` for
        ``prepared_names``; each array stacks that field over the selected
        trajectories along a new leading axis.
    """
    group_start = roll_idx // args.group_size * args.group_size
    group_end = group_start + args.group_size
    neighbours = [idx for idx in range(group_start, group_end) if idx != roll_idx]

    if args.fit_num == 1:
        # Single-trajectory fitting: one deterministic trajectory only.
        fits = neighbours[:1] or [roll_idx]
    else:
        # At least one neighbour is always taken (as before), even for
        # degenerate fit_num values below 2.
        n_neighbours = max(args.fit_num - 1, 1)
        fits = neighbours[:n_neighbours] + [roll_idx]

    print(f"Fitting trajectories: {fits}")

    # Load every selected rollout, then regroup field-by-field so each field
    # becomes a single array with a leading trajectory axis.
    rollouts = [
        load_data(prepared_names, os.path.join(data_dir, str(idx) + '.rollout.h5'))
        for idx in fits
    ]
    return [np.array([d.copy() for d in field]) for field in zip(*rollouts)]

def mpc_qp(g_cur, g_goal, time_cur, T, rel_attrs, A_t, B_t, Q, R, node_attrs=None,
           actions=None, gt_info=None):
    """
    Solve a Model Predictive Control problem using quadratic programming.

    Only ``args.fit_type == 'structured'`` builds and solves a problem; every
    other fit type prints a message and returns an all-zero plan.

    The cost is a quadratic control-effort term (weighted by ``R``) on the
    deviation of each action from the normalised zero action, plus a terminal
    quadratic term (weighted by ``Q``) on the distance of the last latent
    state from ``g_goal``. Actions are bounded to the normalised image of
    [-2, 2].

    NOTE(review): the dynamics constraint concatenates each state/action row
    with a fresh ``cp.Parameter`` that is never given a value, and the shapes
    of ``A @ augG`` / ``B @ augU`` only line up with ``g[cur_idx]`` for very
    specific ``n_obj``. cvxpy needs every Parameter to have a value at solve
    time, so this formulation likely never solves as written — confirm
    against the model's system-identification layout before relying on it.

    Args:
        g_cur: Current latent state, indexed as [node, latent_dim]
        g_goal: Goal latent state, indexed per node
        time_cur: Current timestep (unused here)
        T: Prediction horizon; the returned plan has T - 1 steps
        rel_attrs: Relation attributes for graph (unused here)
        A_t, B_t: System dynamics matrices (transposed), assigned to cvxpy
            Parameters of shape (g_dim * n_obj, g_dim * n_obj) and
            (g_dim * n_obj, action_dim * n_obj)
        Q, R: State and control cost matrices
        node_attrs: Node attributes (optional); for 'VoltageControl', nodes
            with ``node_attrs[idx, 0] < 0.5`` are pinned to the zero action
        actions: Ground truth actions (optional, unused here)
        gt_info: Additional ground truth info (optional, unused here)

    Returns:
        Array of shape (T - 1, n_obj, args.action_dim) with the planned
        actions in normalised action units, or zeros of the same shape when
        no optimal solution was found.
    """
    n_obj = g_cur.shape[0]
    g_dim = g_cur.shape[1]

    # Guard clauses for the fit types that do not build a problem.
    if args.fit_type == 'diagonal':
        print("Diagonal fit not implemented for VoltageControl")
        return np.zeros((T - 1, n_obj, args.action_dim))
    if args.fit_type != 'structured':
        print(f"Unknown fit_type: {args.fit_type}")
        return np.zeros((T - 1, n_obj, args.action_dim))

    A = cp.Parameter((g_dim * n_obj, g_dim * n_obj))
    B = cp.Parameter((g_dim * n_obj, args.action_dim * n_obj))
    A.value = A_t
    B.value = B_t

    # Row t * n_obj + idx holds node `idx` at horizon step `t`.
    g = cp.Variable((T * n_obj, g_dim))
    u = cp.Variable((T * n_obj, args.action_dim))

    # Action statistics: raw values are mapped to (raw - offset) / scale.
    # These bounds do not depend on the node or the step, so compute them once.
    act_offset = stat[2][:, 0]
    act_scale = stat[2][:, 1]
    zero_normed = -act_offset / act_scale
    act_scale_max_normed = (2 - act_offset) / act_scale
    act_scale_min_normed = (-2 - act_offset) / act_scale

    cost = 0
    constraints = []

    for idx in range(n_obj):
        # Constrain the initial g
        constraints.append(g[idx] == g_cur[idx])

        # Only generator nodes may act in the VoltageControl environment.
        pin_to_zero = (args.env == 'VoltageControl'
                       and node_attrs is not None
                       and node_attrs[idx, 0] < 0.5)

        for t in range(1, T):
            cur_idx = t * n_obj + idx
            prv_idx = (t - 1) * n_obj + idx

            # Apply action bounds
            constraints.append(u[prv_idx] >= act_scale_min_normed)
            constraints.append(u[prv_idx] <= act_scale_max_normed)

            if pin_to_zero:
                constraints.append(u[prv_idx] == zero_normed)

            # Define system dynamics (see NOTE(review) in the docstring)
            augG = cp.hstack([g[prv_idx], cp.Parameter(g_dim)])
            augU = cp.hstack([u[prv_idx], cp.Parameter(args.action_dim)])
            constraints.append(g[cur_idx] == A @ augG + B @ augU)

            # Penalize control effort
            cost += quad_form(u[prv_idx] - zero_normed, R)

        # Terminal cost
        cost += quad_form(g[(T - 1) * n_obj + idx] - g_goal[idx], Q)

    # Solve with OSQP first; fall back to SCS only if OSQP raises.
    prob = cp.Problem(cp.Minimize(cost), constraints)
    status = 'failed'
    for solver in (cp.OSQP, cp.SCS):
        try:
            prob.solve(solver=solver, verbose=False)
        except Exception as err:
            # `Exception` rather than a bare except, so Ctrl-C and SystemExit
            # still propagate; report the failure instead of hiding it.
            print(f"Warning: solver {solver} raised {type(err).__name__}: {err}")
            continue
        status = prob.status
        break

    if status in ('optimal', 'optimal_inaccurate'):
        # The first (T - 1) * n_obj rows are the actions for steps 0..T-2.
        planned = u.value[:(T - 1) * n_obj]
        return planned.reshape(T - 1, n_obj, args.action_dim)

    print(f"Warning: Optimization failed with status: {status}")
    # Return zeros as fallback
    return np.zeros((T - 1, n_obj, args.action_dim))


def _plot_voltage_trajectory(states_roll, roll_idx):
    """Save a two-panel plot of state channel 0 and channel 1 over time."""
    panels = (
        # (state channel, title, reference line, y label)
        (0, "Voltage Magnitudes", 1.0, "Voltage (p.u.)"),
        (1, "Voltage Derivatives", 0.0, "dV/dt"),
    )
    plt.figure(figsize=(12, 6))
    for channel, title, reference, ylabel in panels:
        plt.subplot(1, 2, channel + 1)
        plt.title(title)
        plt.plot(states_roll[:, :, channel], linestyle='-', alpha=0.7)
        plt.axhline(y=reference, color='k', linestyle='--', label="Reference")
        plt.xlabel("Time Step")
        plt.ylabel(ylabel)

    plt.tight_layout()
    plt.savefig(os.path.join(args.shootf, f"voltage_trajectory_{roll_idx}.png"))
    plt.close()


def shoot_mpc_qp(roll_idx):
    """
    Perform model-based control using MPC on a specific rollout.

    Steps: load the rollout, run system identification on the trajectories
    returned by ``get_more_trajectories``, then roll a fresh
    ``VoltageControlEngine`` forward for ``args.roll_step`` steps, re-planning
    with ``mpc_qp`` every ``args.feedback`` steps. The goal is state
    ``[1.0, 0.0]`` at every node. Prints the final state and a relative error
    and saves ``voltage_trajectory_<roll_idx>.png`` under ``args.shootf``.

    Args:
        roll_idx: Rollout index to control
    """
    print(f'\n=== Model Based Control on Example {roll_idx} ===')

    # ---- Load data ----
    seq_data = load_data(prepared_names, os.path.join(data_dir, str(roll_idx) + '.rollout.h5'))
    # Model inputs keep the stored (normalised) values ...
    attrs, states, actions, rel_attrs = [to_var(d.copy(), use_gpu=use_gpu) for d in seq_data]

    # ... while the engine works with denormalised states.
    seq_data = denormalize(seq_data, stat)
    states_gt = seq_data[1]

    # ---- Setup engine ----
    # Determine number of nodes from the data
    n_obj = attrs.shape[1]
    print(f"Number of nodes in system: {n_obj}")

    # Local engine sized for this rollout (shadows the module-level `engine`).
    engine = VoltageControlEngine(num_nodes=n_obj)
    print(f"Engine initialized with {engine.vc_sim.num_nodes} nodes")

    # ---- System identification ----
    print('===> System identification!')
    fit_data = get_more_trajectories(roll_idx)
    fit_data = [to_var(d, use_gpu=use_gpu) for d in fit_data]
    # Use the number of trajectories actually loaded; it can differ from
    # args.fit_num (e.g. a group with too few members), and the view below
    # must match the real batch.
    bs = fit_data[0].shape[0]

    attrs_flat = get_flat(fit_data[0])
    states_flat = get_flat(fit_data[1])
    rel_attrs_flat = get_flat(fit_data[3])

    # Convert states to latent representation, then restore [traj, time, ...]
    g = model.to_g(attrs_flat, states_flat, rel_attrs_flat, args.pstep)
    g = g.view(torch.Size([bs, args.time_step]) + g.size()[1:])

    # One-step pairs: latent at t, latent at t + 1, and the action at t.
    G_tilde = get_flat(g[:, :-1], keep_dim=True)
    H_tilde = get_flat(g[:, 1:], keep_dim=True)
    U_left = get_flat(fit_data[2][:, :-1], keep_dim=True)

    # The return value is not needed: the control loop reads the identified
    # matrices back from model.A / model.B.
    model.system_identify(G=G_tilde, H=H_tilde, U=U_left,
                          rel_attrs=fit_data[3][:1, 0], I_factor=args.I_factor)

    # ---- Control ----
    print('===> Model based control start!')
    start_step = args.roll_start
    goal_step = args.roll_step + args.roll_start

    print(f"Control from step {start_step} to step {goal_step}")

    # Goal state in physical units: [1.0, 0.0] at every node.
    states_goal = np.tile(np.array([1.0, 0.0]), (n_obj, 1))

    # Encode the goal exactly like the rolled-out states in the loop below:
    # normalise with the state statistics and move it with to_var, so it is
    # in the model's input units and on the same device as attrs/rel_attrs.
    goal_input = normalize([states_goal[np.newaxis].copy()], [stat[1]])[0]
    goal_input_v = to_var(goal_input, use_gpu=use_gpu)
    g_goal_v = model.to_g(attrs=attrs[goal_step:goal_step + 1], states=goal_input_v,
                          rel_attrs=rel_attrs[goal_step:goal_step + 1], pstep=args.pstep)
    g_goal = to_np(g_goal_v[0])

    # Initialize engine with starting state
    print("Initial state:", states_gt[start_step])
    states_start = states_gt[start_step]
    # The engine is given the state as a flat vector, channel-major.
    engine.set_state(states_start.T.reshape(-1))

    # Verify engine state
    print("Engine state after initialization:", engine.get_state())

    # Trajectory storage: row 0 is the start state, row k the state after k steps.
    states_roll = np.zeros((args.roll_step + 1, n_obj, args.state_dim))
    states_roll[0] = states_start

    # Quantities that do not change across steps.
    Q = np.eye(g_goal.shape[-1] if args.baseline else args.g_dim)
    R_factor = 0.0001  # Specific to VoltageControl
    R = np.eye(args.action_dim) * R_factor
    rel_attrs_np = to_np(rel_attrs)[0]
    node_attrs = to_np(attrs)[0]  # Node attributes for generator/load identification

    # Run MPC control loop
    bar = ProgressBar()
    for step in bar(range(args.roll_step)):
        # Get current state in latent space
        states_input = normalize([states_roll[step:step + 1]], [stat[1]])[0]
        states_input_v = to_var(states_input, use_gpu=use_gpu)
        g_cur_v = model.to_g(attrs=attrs[:1], states=states_input_v,
                             rel_attrs=rel_attrs[:1], pstep=args.pstep)
        g_cur = to_np(g_cur_v[0])

        # Re-plan on the feedback schedule, and also whenever the previous
        # plan has no step left to shift to (feedback longer than the plan),
        # which would otherwise fail on u[0] below.
        if step % args.feedback == 0 or len(u) <= 1:
            T = min(args.roll_step - step + 1, 20)  # Limit prediction horizon
            A_t = to_np(model.A[0]).T
            B_t = to_np(model.B[0]).T
            u = mpc_qp(g_cur, g_goal, step, T, rel_attrs_np, A_t, B_t, Q, R, node_attrs=node_attrs,
                       actions=to_np(actions[step:]),
                       gt_info=[None, states_gt[goal_step:goal_step + 1], attrs[step:step + T],
                                rel_attrs[step:step + T]])
        else:
            # Reuse the previous plan, shifted by one step.
            u = u[1:]

        # Apply control action to engine.
        # NOTE(review): mpc_qp bounds its actions in normalised units (built
        # from stat[2]) and the plan is applied here without denormalize —
        # confirm which units engine.param expects.
        engine.param = u[0].reshape(-1)
        engine.step()
        states_roll[step + 1] = engine.get_state()

    # Compute final state and error
    states_result = states_roll[args.roll_step]
    print("Final state:", states_result)

    # Error relative to the norm of the achieved state (as before).
    rel_error = norm(states_goal.flatten() - states_result.flatten()) / norm(states_result.flatten())
    print(f"Relative error: {rel_error}")

    # Plot voltage trajectory
    _plot_voltage_trajectory(states_roll, roll_idx)


if __name__ == '__main__':
    # Create the output directory in-process rather than through a shell.
    if args.shootf:
        os.makedirs(args.shootf, exist_ok=True)

    # The validation rollouts are what remains after the training share.
    num_train = int(args.n_rollout * args.train_valid_ratio)
    num_valid = args.n_rollout - num_train

    # Evenly spaced rollout indices over the validation range. For small
    # validation sets the integer division yields 0, and np.arange rejects a
    # zero step, so fall back to a stride of 1.
    stride = max(1, num_valid // args.group_size // 5)
    ls_rollout_idx = np.arange(0, num_valid, stride)

    # Demo mode uses a fixed set of eight rollouts: 0, 25, ..., 175.
    if args.demo:
        ls_rollout_idx = np.arange(8) * 25

    for roll_idx in ls_rollout_idx:
        shoot_mpc_qp(roll_idx)
