import os

from config import gen_args
from data import normalize, denormalize
from graph_models.CompositionalKoopmanOperators import CompositionalKoopmanOperators
from graph_models.ControllableEmbedding import ControllableEmbedding
from models.KoopmanBaselineModel import KoopmanBaseline
from physics_engine import SoftEngine, RopeEngine, SwimEngine, VoltageControlEngine
from utils import *
from utils import to_var, to_np, Tee
from progressbar import ProgressBar
import time
import matplotlib.pyplot as plt

args = gen_args()
print_args(args)

# args.fit_num is the number of sibling trajectories used for system
# identification. Siblings are taken cyclically from the same group as the
# evaluated rollout (offsets 1..fit_num, see get_more_trajectories), so the
# group must contain at least fit_num + 1 trajectories or the evaluated
# rollout would be reused as its own fitting data.
# Explicit raise instead of `assert` so the check survives `python -O`.
if args.group_size - 1 < args.fit_num:
    raise ValueError(
        f"group_size - 1 ({args.group_size - 1}) must be >= fit_num ({args.fit_num})")

data_names = ['attrs', 'states', 'actions']                   # keys stored in the stat file
prepared_names = ['attrs', 'states', 'actions', 'rel_attrs']  # keys stored per rollout file

data_dir = os.path.join(args.dataf, args.eval_set)

print(f"Load stored dataset statistics from {args.stat_path}!")
stat = load_data(data_names, args.stat_path)

# Note: VoltageControlEngine takes dt as its LAST argument, unlike the others.
if args.env == 'Rope':
    engine = RopeEngine(args.dt, args.state_dim, args.action_dim, args.param_dim)
elif args.env == 'VoltageControl':
    engine = VoltageControlEngine(args.state_dim, args.action_dim, args.param_dim, args.dt)
elif args.env == 'Soft':
    engine = SoftEngine(args.dt, args.state_dim, args.action_dim, args.param_dim)
elif args.env == 'Swim':
    engine = SwimEngine(args.dt, args.state_dim, args.action_dim, args.param_dim)
else:
    raise ValueError(f"Unsupported env: {args.env!r}")


# Create the output directory without spawning a shell (the original
# `os.system('mkdir -p ' + path)` was non-portable and shell-injectable).
os.makedirs(args.evalf, exist_ok=True)
log_path = os.path.join(args.evalf, 'log.txt')
tee = Tee(log_path, 'w')  # keep a module-level reference so the log stays open

# ---------------------------------------------------------------------------
# Model: build the network and load its checkpoint (or use the baseline).
# ---------------------------------------------------------------------------
use_gpu = torch.cuda.is_available()
if not args.baseline:
    # Controllable Embedding is chosen when the config has `embed_mode` or
    # uses one of its fit types; otherwise Compositional Koopman Operators.
    is_controllable_embedding = hasattr(args, 'embed_mode') or args.fit_type in ['Gaussian_reweight', 'Hom', 'dense']

    if is_controllable_embedding:
        print("Using Controllable Embedding model")
        model = ControllableEmbedding(args, residual=False, use_gpu=use_gpu)
    else:
        print("Using Compositional Koopman Operators model")
        model = CompositionalKoopmanOperators(args, residual=False, use_gpu=use_gpu)

    # eval_epoch == -1 selects the best checkpoint; otherwise a specific
    # (epoch, iter) snapshot.
    if args.eval_epoch == -1:
        model_path = os.path.join(args.outf, 'net_best.pth')
    else:
        model_path = os.path.join(args.outf, 'net_epoch_%d_iter_%d.pth' % (args.eval_epoch, args.eval_iter))

    print("Loading saved checkpoint from %s" % model_path)
    device = torch.device('cuda:0') if use_gpu else torch.device('cpu')

    # Read the file once; a missing/corrupt file raises here directly rather
    # than being masked by a retry.
    state_dict = torch.load(model_path, map_location=device)
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as e:
        # load_state_dict raises RuntimeError on key/shape mismatch; fall
        # back to a non-strict load and report which keys did not line up.
        print(f"Warning: Error loading model state dict: {e}")
        print("Trying to load with different compatibility settings...")
        incompatible = model.load_state_dict(state_dict, strict=False)
        print(f"Missing keys: {list(incompatible.missing_keys)}")
        print(f"Unexpected keys: {list(incompatible.unexpected_keys)}")

    model.eval()
    if use_gpu:
        model.cuda()

else:
    # Koopman Baseline: no checkpoint is loaded for it here.
    model = KoopmanBaseline(args)

'''
eval
'''

def get_more_trajectories(roll_idx):
    """Load the sibling rollouts used to fit the dynamics for ``roll_idx``.

    Rollouts are grouped in consecutive blocks of ``args.group_size``. Starting
    right after ``roll_idx`` and wrapping around inside its own group, the next
    ``args.fit_num`` rollouts are read from ``data_dir``.

    Returns:
        A list of four float32 arrays, one per entry of ``prepared_names``
        (attrs, states, actions, rel_attrs), each stacked along a new leading
        axis of length ``args.fit_num``.
    """
    group_start = (roll_idx // args.group_size) * args.group_size
    local_idx = roll_idx - group_start

    rollouts = []
    for k in range(1, args.fit_num + 1):
        # Cycle within the group so siblings never cross a group boundary.
        sibling = (local_idx + k) % args.group_size + group_start
        rollout_path = os.path.join(data_dir, str(sibling) + '.rollout.h5')
        rollouts.append(load_data(prepared_names, rollout_path))

    return [np.array([r[field] for r in rollouts], dtype=np.float32) for field in range(4)]

def _identify_system(fit_data):
    """Fit the model's latent dynamics on the sibling rollouts in ``fit_data``.

    Encodes every state of the ``args.fit_num`` fitting trajectories with
    ``model.to_g`` and pairs consecutive latent codes (g_t, g_{t+1}) with the
    actions u_t. Those pairs are passed to ``model.system_identify``.
    The returned A/B matrices are not used here.
    NOTE(review): presumably the model keeps them internally for later
    ``model.step`` calls; confirm in the model classes.

    Args:
        fit_data: ``[attrs, states, actions, rel_attrs]`` tensors whose leading
            dims are ``(args.fit_num, args.time_step, ...)``.
    """
    attrs_flat = get_flat(fit_data[0])
    states_flat = get_flat(fit_data[1])
    rel_attrs_flat = get_flat(fit_data[3])

    # Encode all fitting states, then restore the (trajectory, time) axes.
    g = model.to_g(attrs_flat, states_flat, rel_attrs_flat, args.pstep)
    g = g.view(torch.Size([args.fit_num, args.time_step]) + g.size()[1:])

    G_tilde = get_flat(g[:, :-1], keep_dim=True)               # current codes
    H_tilde = get_flat(g[:, 1:], keep_dim=True)                # next codes
    U_tilde = get_flat(fit_data[2][:, :-1], keep_dim=True)     # actions

    print("Performing system identification...")
    t_start = time.time()
    A, B, fit_err = model.system_identify(
        G=G_tilde, H=H_tilde, U=U_tilde, rel_attrs=fit_data[3][:1, 0], I_factor=args.I_factor)
    elapsed = time.time() - t_start
    print(f"System identification completed in {elapsed:.2f} seconds with error {fit_err:.6f}")


def _rollout_predictions(attrs, states, actions, rel_attrs, states_gt):
    """Simulate ``args.time_step - 1`` steps with the identified dynamics.

    Args:
        attrs, states, actions, rel_attrs: tensors of one rollout, indexed by
            time along dim 0. ``states`` is in the stored (normalized) space.
        states_gt: the same states denormalized, as a numpy array.

    Returns:
        Denormalized predicted states with the same shape as ``states_gt``.
        Row 0 is the ground-truth initial state.
    """
    half = args.state_dim // 2
    states_pred = states_gt.copy()   # T x N x D, denormalized
    s_pred = states.clone()          # T x N x D, normalized (model input space)
    g = model.to_g(attrs, states, rel_attrs, args.pstep)

    pred_g = None
    for step in range(args.time_step - 1):
        cur = slice(step, step + 1)
        nxt = slice(step + 1, step + 2)

        if step == 0:
            current_g = g[cur]
        elif args.eval_type == 'koopman':
            # Pure latent-space rollout: feed the predicted code back in.
            current_g = pred_g
        else:
            # 'valid' re-encodes the ground-truth state; 'rollout' re-encodes
            # the model's own previous prediction.
            source = states if args.eval_type == 'valid' else s_pred
            current_g = model.to_g(attrs[cur], source[cur], rel_attrs[cur], args.pstep)

        # Step in latent space, then decode back to state space.
        pred_g = model.step(g=current_g, u=actions[cur], rel_attrs=rel_attrs[cur])
        pred_s = model.to_s(attrs=attrs[cur], gcodes=pred_g,
                            rel_attrs=rel_attrs[cur], pstep=args.pstep)

        states_pred[nxt] = denormalize([to_np(pred_s)], [stat[1]])[0]
        # Overwrite predicted positions by integrating the predicted
        # velocities with step args.dt. This assumes the state layout is
        # [positions | velocities], split at state_dim // 2.
        states_pred[nxt, :, :half] = states_pred[cur, :, :half] + \
            args.dt * states_pred[nxt, :, half:]

        s_next = normalize([states_pred[nxt]], [stat[1]])[0]
        s_pred[nxt] = to_var(s_next, use_gpu=use_gpu)

    return states_pred


def _plot_predictions(states_gt, states_pred, idx_rollout, horizon=100):
    """Plot true vs predicted position/velocity for one object and save the figure.

    The first ``min(2, state_dim // 2)`` position components go in the top row
    and the matching velocity components in the bottom row. Only the first
    ``horizon`` steps are drawn. The figure is saved to
    ``args.evalf/pred_vis_<idx>.png``.
    """
    sample_obj = min(2, states_gt.shape[1] - 1)  # object index to visualize
    half = args.state_dim // 2
    n_components = min(2, half)  # avoid indexing past the position block

    plt.figure(figsize=(12, 8))
    plt.suptitle(f'Predictions for Object {sample_obj} using {type(model).__name__}', fontsize=14)

    panels = [(0, ['x', 'y'], 'Position'), (half, ['vx', 'vy'], 'Velocity')]
    for row, (offset, axis_names, quantity) in enumerate(panels):
        for i in range(n_components):
            plt.subplot(2, 2, 2 * row + i + 1)
            plt.plot(states_gt[:horizon, sample_obj, offset + i], label='True', linestyle='-', color='blue')
            plt.plot(states_pred[:horizon, sample_obj, offset + i], label='Predicted', linestyle='--', color='red')
            plt.title(f'{quantity} ({axis_names[i]})')
            plt.xlabel('Time step')
            plt.ylabel(quantity)
            plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(args.evalf, f'pred_vis_{idx_rollout}.png'))
    plt.show()
    plt.close()


def eval(idx_rollout, video=True):
    """Fit dynamics on sibling rollouts, simulate rollout ``idx_rollout``, and score it.

    Steps:
      1. Load the rollout as stored, plus a denormalized copy used as ground
         truth.
      2. Fit the latent dynamics on ``args.fit_num`` sibling rollouts of the
         same group (see ``get_more_trajectories``).
      3. Simulate ``args.time_step - 1`` steps from the first state. The mode
         comes from ``args.eval_type``:
           - ``'valid'``: re-encode the ground-truth state every step;
           - ``'rollout'``: re-encode the model's own previous prediction;
           - ``'koopman'``: stay in latent space, feeding the prediction back.
      4. Save a position/velocity plot and optionally render a video.

    Args:
        idx_rollout: index of the ``<idx>.rollout.h5`` file in ``data_dir``.
        video: if True and the engine has a ``render`` method, render the
            predicted trajectory.

    Returns:
        Mean squared error between denormalized ground-truth and predicted
        states, averaged over all time steps, objects and state dims.

    Raises:
        ValueError: if ``args.eval_type`` is not one of the three modes above.
    """
    # Previously an unknown mode silently reused a stale latent code each step.
    if args.eval_type not in ('valid', 'rollout', 'koopman'):
        raise ValueError(f"Unsupported eval_type: {args.eval_type!r}")

    print(f'\n=== Forward Simulation on Example {idx_rollout} ===')

    seq_data = load_data(prepared_names, os.path.join(data_dir, str(idx_rollout) + '.rollout.h5'))
    # Copy before converting so the denormalization below cannot alias them.
    attrs, states, actions, rel_attrs = [to_var(d.copy(), use_gpu=use_gpu) for d in seq_data]

    seq_data = denormalize(seq_data, stat)
    states_gt, actions_gt = seq_data[1], seq_data[2]

    param_file = os.path.join(data_dir, str(idx_rollout // args.group_size) + '.param')
    # weights_only=False unpickles arbitrary objects; only use with trusted data files.
    param = torch.load(param_file, weights_only=False)

    fit_data = [to_var(d, use_gpu=use_gpu) for d in get_more_trajectories(idx_rollout)]

    # Pure inference: disable autograd so graphs are not built (and, in
    # 'koopman' mode, chained) across the whole rollout.
    with torch.no_grad():
        _identify_system(fit_data)
        print("Starting forward prediction...")
        states_pred = _rollout_predictions(attrs, states, actions, rel_attrs, states_gt)

    print("Generating prediction visualization...")

    mse = np.mean((states_gt - states_pred) ** 2)
    print(f"Mean squared error: {mse:.6f}")

    _plot_predictions(states_gt, states_pred, idx_rollout)

    if video and hasattr(engine, 'render'):
        print("Rendering simulation video...")
        engine.render(states_pred, actions_gt, param, act_scale=args.act_scale, video=True, image=True,
                      path=os.path.join(args.evalf, str(idx_rollout) + '.pred'),
                      states_gt=states_gt)

    return mse

if __name__ == '__main__':
    num_train = int(args.n_rollout * args.train_valid_ratio)
    num_valid = args.n_rollout - num_train

    # Evaluate roughly evenly spaced validation rollouts. Clamp the stride to 1:
    # when num_valid < n_splits the integer division is 0, and np.arange
    # would raise on a zero step.
    stride = max(1, num_valid // args.n_splits)
    ls_rollout_idx = np.arange(0, num_valid, stride)

    if args.demo:
        ls_rollout_idx = np.arange(8) * 25

    print(f"Evaluating on {len(ls_rollout_idx)} trajectories...")
    all_mse = [eval(roll_idx) for roll_idx in ls_rollout_idx]

    # Explicit NaN for an empty evaluation set (np.mean([]) would warn).
    avg_mse = np.mean(all_mse) if all_mse else float('nan')
    print("\n=== Evaluation Complete ===")
    print(f"Average MSE across all trajectories: {avg_mse:.6f}")

    with open(os.path.join(args.evalf, 'summary.txt'), 'w', encoding='utf-8') as f:
        f.write(f"Model: {type(model).__name__}\n")
        f.write(f"Evaluation type: {args.eval_type}\n")
        f.write(f"Average MSE: {avg_mse:.6f}\n")
        f.write("Individual trajectory MSE values:\n")
        for idx, mse in zip(ls_rollout_idx, all_mse):
            f.write(f"Trajectory {idx}: {mse:.6f}\n")
