import os
import random

import multiprocessing as mp
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from progressbar import ProgressBar
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

from config import gen_args
from data import PhysicsDataset
from data import load_data
from graph_models.ControllableEmbedding import ControllableEmbedding
from utils import count_parameters, Tee, AverageMeter, rand_int, mix_iters, get_flat, print_args


class ShuffledDataset(Dataset):
    def __init__(self,
                 mother_dataset,
                 idx,
                 batch_size,
                 args):
        """
        Per-split dataset serving (rollout, start-timestep) windows in shuffled order.

        Split ``idx`` owns the rollouts ``[idx * n_rollout, (idx + 1) * n_rollout)`` of
        ``mother_dataset``, where ``n_rollout = mother_dataset.n_rollout // args.n_splits``.
        The sample order is shuffled once, at construction, and truncated to a multiple
        of ``batch_size`` so that a ``DataLoader`` using the same batch size only ever
        yields full batches.

        Args:
            mother_dataset: Base dataset; must expose ``n_rollout`` and ``data_dir``.
            idx: Index of this split among the ``args.n_splits`` splits.
            batch_size: The sample count is truncated to a multiple of this value.
            args: Configuration; reads ``time_step``, ``len_seq`` and ``n_splits`` here
                and ``group_size`` in ``__getitem__``.
        """
        self.samples_per_rollout = args.time_step - args.len_seq
        self.mother = mother_dataset
        self.n_rollout = mother_dataset.n_rollout // args.n_splits
        self.idx = idx
        self.batch_size = batch_size
        self.args = args
        self.prepared_names = ['attrs', 'states', 'actions', 'rel_attrs']

        # One entry per (local rollout index, start timestep) window, shuffled once.
        self.sample_table = [(i, j)
                             for i in range(self.n_rollout)
                             for j in range(self.samples_per_rollout)]
        random.shuffle(self.sample_table)

        # Drop the tail so the sample count is a multiple of batch_size (full batches only).
        n_samples = (len(self.sample_table) // batch_size) * batch_size
        self.sample_table = self.sample_table[:n_samples]

    def __len__(self):
        # Number of individual samples. __getitem__ returns a single window and the
        # DataLoader groups windows into batches of batch_size. Returning the batch
        # count here would make the DataLoader visit only the first 1/batch_size of
        # sample_table per epoch.
        return len(self.sample_table)

    def __getitem__(self, idx):
        """
        Return one sample as ``(seq_data, fit_data)``.

        Args:
            idx: Position in the shuffled ``sample_table``.

        Returns:
            ``seq_data``: a list with the prepared fields (``attrs``, ``states``,
            ``actions``, ``rel_attrs``) of this sample's rollout, each sliced to the
            ``len_seq + 1`` frames starting at the sample's timestep.
            ``fit_data``: the same fields, unsliced, as returned by ``load_data`` for
            a randomly chosen rollout in the same block of ``args.group_size``
            consecutive rollouts.
        """
        # Map the split-local rollout index to a global rollout index.
        idx_rollout = self.sample_table[idx][0] + self.n_rollout * self.idx
        idx_timestep = self.sample_table[idx][1]

        # Prepare input sequence data
        seq_data = load_data(self.prepared_names, os.path.join(self.mother.data_dir, str(idx_rollout) + '.rollout.h5'))
        seq_data = [d[idx_timestep:idx_timestep + self.args.len_seq + 1] for d in seq_data]

        # Prepare fit data from another trajectory in the same parameter group.
        # NOTE(review): whether this can pick every rollout in the group depends on
        # whether rand_int's upper bound is inclusive -- confirm.
        fit_idx = rand_int(0, self.args.group_size - 1)  # new traj idx in group
        fit_idx = fit_idx + idx_rollout // self.args.group_size * self.args.group_size  # new traj idx in global
        fit_data = load_data(self.prepared_names, os.path.join(self.mother.data_dir, str(fit_idx) + '.rollout.h5'))

        return seq_data, fit_data


class SubPreparedDataset(Dataset):
    """Deterministic (unshuffled) view over one split of a rollout dataset."""

    def __init__(self, mother_dataset, idx, args):
        """
        Expose every (rollout, start timestep) window of split ``idx`` in order.

        Args:
            mother_dataset: Base dataset; must provide ``n_rollout`` and ``data_dir``.
            idx: Which of the ``args.n_splits`` splits this view covers.
            args: Configuration; ``time_step``, ``len_seq`` and ``n_splits`` are read
                here, ``group_size`` in ``__getitem__``.
        """
        self.args = args
        self.mother = mother_dataset
        self.idx = idx
        self.prepared_names = ['attrs', 'states', 'actions', 'rel_attrs']
        # Rollouts owned by this split, and start positions available per rollout.
        self.n_rollout = mother_dataset.n_rollout // args.n_splits
        self.samples_per_rollout = args.time_step - args.len_seq

    def __len__(self):
        """Total number of windows: one per (rollout, start timestep) pair."""
        return self.samples_per_rollout * self.n_rollout

    def __getitem__(self, idx):
        """
        Return ``(seq_data, fit_data)`` for flat window index ``idx``.

        ``seq_data`` holds the prepared fields of the window's rollout, each sliced
        to ``len_seq + 1`` frames. ``fit_data`` holds the unsliced fields of a random
        rollout from the same block of ``args.group_size`` consecutive rollouts.
        """
        def rollout_path(rollout_id):
            return os.path.join(self.mother.data_dir, str(rollout_id) + '.rollout.h5')

        local_rollout, idx_timestep = divmod(idx, self.samples_per_rollout)
        idx_rollout = self.n_rollout * self.idx + local_rollout
        window_end = idx_timestep + self.args.len_seq + 1

        # Input window: len_seq + 1 consecutive frames starting at idx_timestep.
        seq_data = [field[idx_timestep:window_end]
                    for field in load_data(self.prepared_names, rollout_path(idx_rollout))]

        # Fit data: a random rollout inside this rollout's parameter group.
        group_start = idx_rollout // self.args.group_size * self.args.group_size
        fit_idx = rand_int(0, self.args.group_size - 1) + group_start
        fit_data = load_data(self.prepared_names, rollout_path(fit_idx))
        return seq_data, fit_data


def split_dataset(ds, args):
    """
    Partition ``ds`` into ``args.n_splits`` shuffled sub-datasets.

    Args:
        ds: Dataset to split; must expose ``n_rollout``.
        args: Configuration; reads ``group_size``, ``n_splits`` and ``batch_size``.

    Returns:
        A list of ``ShuffledDataset``, one per split index in ascending order.

    Raises:
        AssertionError: if ``ds.n_rollout`` is not divisible by ``args.group_size``
            or by ``args.n_splits``.
    """
    # The rollouts must divide evenly into parameter groups and into splits.
    assert ds.n_rollout % args.group_size == 0
    assert ds.n_rollout % args.n_splits == 0

    parts = []
    for split_idx in range(args.n_splits):
        parts.append(ShuffledDataset(mother_dataset=ds,
                                     idx=split_idx,
                                     batch_size=args.batch_size,
                                     args=args))
    return parts



if __name__ == '__main__':
    # Only the entry point needs these; imported here to leave the module header untouched.
    import subprocess
    import sys

    mp.set_start_method('spawn')  # Explicitly set spawn method

    args = gen_args()

    # Create output/data directories (the `mkdir -p` equivalent, without a shell).
    os.makedirs(args.outf, exist_ok=True)
    os.makedirs(args.dataf, exist_ok=True)
    # NOTE(review): Tee presumably mirrors stdout into train.log; keep the reference alive.
    tee = Tee(os.path.join(args.outf, 'train.log'), 'w')
    print_args(args)

    # Generate or load data
    datasets = {phase: PhysicsDataset(args, phase) for phase in ['train', 'valid']}
    for phase in ['train', 'valid']:
        if args.gen_data:
            datasets[phase].gen_data()
        else:
            datasets[phase].load_data()

    if args.gen_data:
        print("Preprocessing data ...")
        # Pass an argument list (no shell) so args.env cannot inject shell syntax.
        # Reuse the current interpreter, and stop if preprocessing fails rather than
        # continuing silently.
        subprocess.run([sys.executable, 'preprocess_data.py', '--env', str(args.env)],
                       check=True)

    args.stat = datasets['train'].stat

    # Check for GPU availability
    use_gpu = torch.cuda.is_available()

    # Setup dataloaders: each phase's dataset is split into args.n_splits
    # sub-datasets with one DataLoader each; mix_iters combines their iterators.
    dataloaders = {}
    data_n_batches = {}
    loaders = {}

    for phase in ['train', 'valid']:
        loaders[phase] = [DataLoader(
            dataset=dataset, batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers)
            for dataset in split_dataset(datasets[phase], args)]

        # Bind this phase's loaders as a default argument. A plain closure over
        # `phase` is late-bound: it reads whatever `phase` holds when the factory
        # is called, not the phase it was registered for.
        dataloaders[phase] = lambda phase_loaders=loaders[phase]: mix_iters(
            iters=[iter(loader) for loader in phase_loaders])

        data_n_batches[phase] = sum(len(loader) for loader in loaders[phase])

    # Initialize model.
    # NOTE(review): this previously instantiated `CompositionalKoopmanOperators`,
    # a name never imported in this file (NameError at runtime). ControllableEmbedding
    # is the only model class imported here. Confirm it accepts these constructor
    # arguments and provides the interface used below (to_g, to_s,
    # system_identify, simulate, A, B).
    model = ControllableEmbedding(args, residual=False, use_gpu=use_gpu)

    # Print model parameter count
    print("Model #parameters: %d" % count_parameters(model))

    # Resume from checkpoint if specified
    if args.resume_epoch >= 0:
        model_path = os.path.join(args.outf, 'net_epoch_%d_iter_%d.pth' % (args.resume_epoch, args.resume_iter))
        print("Loading saved checkpoint from %s" % model_path)
        # The model is still on the CPU here; map_location lets a checkpoint saved
        # from GPU tensors load on a machine without CUDA.
        model.load_state_dict(torch.load(model_path, map_location='cpu'))

    # Move the model to the GPU *before* building the optimizer, as the
    # torch.optim documentation recommends.
    if use_gpu:
        model = model.cuda()

    # Optimizer setup
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.6, patience=2, verbose=True)

    # Training setup
    st_epoch = args.resume_epoch if args.resume_epoch > 0 else 0
    best_valid_loss = np.inf
    scaling_factor = 10  # multiplier on latent distances in the metric loss

    with open(os.path.join(args.outf, 'log_st_epoch_%d.txt' % st_epoch), 'w') as log_fout:
        # Main training loop
        for epoch in range(st_epoch, args.n_epoch):
            phases = ['train', 'valid'] if args.eval == 0 else ['valid']

            for phase in phases:
                model.train(phase == 'train')

                # Initialize metrics
                meter_loss = AverageMeter()
                meter_loss_metric = AverageMeter()
                meter_loss_ae = AverageMeter()
                meter_loss_pred = AverageMeter()
                meter_fit_error = AverageMeter()
                meter_dist_g = AverageMeter()
                meter_dist_s = AverageMeter()

                bar = ProgressBar(max_value=data_n_batches[phase])
                loader = dataloaders[phase]()

                for i, (seq_data, fit_data) in bar(enumerate(loader)):
                    if use_gpu:
                        seq_data = [x.cuda() for x in seq_data]
                        fit_data = [x.cuda() for x in fit_data]
                    # seq_data: a (len_seq + 1)-frame window; fit_data: full rollouts
                    # from the same parameter group, used for system identification.
                    attrs, states, actions, rel_attrs = seq_data
                    attrs_2, states_2, actions_2, rel_attrs_2 = fit_data

                    with torch.set_grad_enabled(phase == 'train'):
                        data = [attrs, states, actions, rel_attrs]

                        T = args.len_seq
                        bs = len(attrs)

                        # Flatten fit data and encode it into the latent space.
                        attrs_flat = get_flat(attrs_2)
                        states_flat = get_flat(states_2)
                        rel_attrs_flat = get_flat(rel_attrs_2)

                        g = model.to_g(attrs_flat, states_flat, rel_attrs_flat, args.pstep)
                        g = g.view(torch.Size([bs, args.time_step]) + g.size()[1:])

                        # Fit A with fit data: latent codes at step t (G) against t + 1 (H).
                        # !!! need to force that rel_attrs in one group to be the same !!!
                        G_tilde = get_flat(g[:, :-1], keep_dim=True)
                        H_tilde = get_flat(g[:, 1:], keep_dim=True)
                        U_left = get_flat(actions_2[:, :-1], keep_dim=True)

                        # Perform system identification
                        A, B, fit_err = model.system_identify(G=G_tilde, H=H_tilde, U=U_left,
                                                              rel_attrs=rel_attrs[:1, 0], I_factor=args.I_factor)

                        # Share the identified operators across the batch.
                        # NOTE(review): relies on system_identify having set model.A /
                        # model.B; the returned A, B are not used here.
                        model.A = model.A.repeat(bs, 1, 1)
                        model.B = model.B.repeat(bs, 1, 1)

                        meter_fit_error.update(fit_err.item(), bs)

                        # Forward on sequential data
                        attrs_flat = get_flat(attrs)
                        states_flat = get_flat(states)
                        rel_attrs_flat = get_flat(rel_attrs)

                        g = model.to_g(attrs_flat, states_flat, rel_attrs_flat, args.pstep)

                        # Metric loss: for random pairs of encoded frames, the scaled
                        # latent distance should match the state-space distance. Use two
                        # halves of equal size. When bs * (T + 1) is odd, splitting at
                        # the midpoint gives halves of different lengths, which either
                        # broadcast silently or fail with a shape mismatch.
                        n_frames = bs * (T + 1)
                        half = n_frames // 2
                        permu = np.random.permutation(n_frames)
                        split_0 = permu[:half]
                        split_1 = permu[half:2 * half]

                        dist_g = torch.mean((g[split_0] - g[split_1]) ** 2, dim=(1, 2))
                        dist_s = torch.mean((states_flat[split_0] - states_flat[split_1]) ** 2, dim=(1, 2))
                        loss_metric = torch.abs(dist_g * scaling_factor - dist_s).mean()

                        g = g.view(torch.Size([bs, T + 1]) + g.size()[1:])

                        # Rollout 0 -> 1 : T + 1 in the latent space.
                        U_for_pred = actions[:, :T]
                        G_for_pred = model.simulate(T=T, g=g[:, 0], u_seq=U_for_pred, rel_attrs=rel_attrs[:, 0])

                        # Prepare data for autoencoding and prediction
                        data_for_ae = [x[:, :T + 1] for x in data]
                        data_for_pred = [x[:, 1:T + 1] for x in data]

                        # Decode state for auto-encoding
                        attrs_for_ae_flat = get_flat(data_for_ae[0])
                        rel_attrs_for_ae_flat = get_flat(data_for_ae[3])
                        decode_s_for_ae = model.to_s(attrs=attrs_for_ae_flat, gcodes=get_flat(g[:, :T + 1]),
                                                     rel_attrs=rel_attrs_for_ae_flat, pstep=args.pstep)

                        # Decode state for prediction
                        attrs_for_pred_flat = get_flat(data_for_pred[0])
                        rel_attrs_for_pred_flat = get_flat(data_for_pred[3])
                        decode_s_for_pred = model.to_s(attrs=attrs_for_pred_flat, gcodes=get_flat(G_for_pred),
                                                       rel_attrs=rel_attrs_for_pred_flat, pstep=args.pstep)

                        # Calculate losses
                        loss_auto_encode = F.l1_loss(
                            decode_s_for_ae, states[:, :T + 1].reshape(decode_s_for_ae.shape))
                        loss_prediction = F.l1_loss(
                            decode_s_for_pred, states[:, 1:].reshape(decode_s_for_pred.shape))

                        # Combine losses
                        loss = loss_auto_encode + loss_prediction + loss_metric * args.lambda_loss_metric

                        # Update metrics
                        meter_loss_metric.update(loss_metric.item(), bs)
                        meter_loss_ae.update(loss_auto_encode.item(), bs)
                        meter_loss_pred.update(loss_prediction.item(), bs)
                        meter_dist_g.update(dist_g.mean().item(), bs)
                        meter_dist_s.update(dist_s.mean().item(), bs)

                    # Update total loss meter
                    meter_loss.update(loss.item(), bs)

                    # Backward pass and optimization (training only); the graph was
                    # built with grad enabled, so backward works outside the context.
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                        optimizer.step()

                    # Log progress
                    if i % args.log_per_iter == 0:
                        log = '%s [%d/%d][%d/%d] Loss: %.6f (%.6f), sysid_error: %.6f (%.6f), loss_ae: %.6f (%.6f), loss_pred: %.6f (%.6f), ' \
                            'loss_metric: %.6f (%.6f)' % (
                            phase, epoch, args.n_epoch, i, data_n_batches[phase],
                            loss.item(), meter_loss.avg,
                            fit_err.item(), meter_fit_error.avg,
                            loss_auto_encode.item(), meter_loss_ae.avg,
                            loss_prediction.item(), meter_loss_pred.avg,
                            loss_metric.item(), meter_loss_metric.avg,
                            )

                        print()
                        print(log)
                        log_fout.write(log + '\n')
                        log_fout.flush()

                    # Save checkpoint periodically during training
                    if phase == 'train' and i % args.ckp_per_iter == 0:
                        torch.save(model.state_dict(),
                                   os.path.join(args.outf, 'net_epoch_%d_iter_%d.pth' % (epoch, i)))

                # End of epoch summary
                log = '%s [%d/%d] Loss: %.4f, Best valid: %.4f' % (phase, epoch, args.n_epoch, meter_loss.avg, best_valid_loss)
                print(log)
                log_fout.write(log + '\n')
                log_fout.flush()

                # Update learning rate and save best model (validation only)
                if phase == 'valid' and not args.eval:
                    scheduler.step(meter_loss.avg)
                    if meter_loss.avg < best_valid_loss:
                        best_valid_loss = meter_loss.avg
                        torch.save(model.state_dict(), os.path.join(args.outf, 'net_best.pth'))
