import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from models.flow2flow import Flow2Flow
from torch.utils.data import Dataset
import numpy as np

class Args:
    """Plain attribute bag of hyper-parameters passed to ``Flow2Flow``.

    The consumer (``Flow2Flow``) is defined outside this file, so the exact
    meaning of each field is not documented here; the values below are the
    defaults this module trains with.

    Args:
        input_dim: Dimensionality of one input vector. Stored both as
            ``num_channels`` and as ``input_dim``.
    """

    def __init__(self, input_dim):
        # Pick the device the same way Flow2FlowAgent does, so the config
        # handed to the model never claims CUDA on a CPU-only host.
        use_cuda = torch.cuda.is_available()

        # NOTE(review): assumes Flow2Flow treats an empty gpu_ids list as
        # "CPU only" -- confirm against models/flow2flow.py.
        self.gpu_ids = [0] if use_cuda else []
        self.is_training = True
        self.num_channels = input_dim
        self.num_scales = 2
        self.num_channels_g = 32
        self.num_blocks = 4
        self.initializer = 'xavier'
        self.lambda_mle = 1e-4
        self.clip_gradient = 1.0
        self.rnvp_lr = 0.001
        self.rnvp_beta_1 = 0.9
        self.rnvp_beta_2 = 0.999
        self.lr = 0.001
        self.beta_1 = 0.5
        self.beta_2 = 0.999
        self.use_least_squares = True
        self.jc_lambda_min = 0.1
        self.jc_lambda_max = 10.0
        self.use_mixer = False
        self.clamp_jacobian = True
        self.weight_norm_l2 = 0.01
        self.num_coupling_layers = 4
        self.input_dim = input_dim
        self.device = 'cuda' if use_cuda else 'cpu'
        self.hidden_dim = 128
        self.lr_policy = 'step'
        self.lr_step_epochs = 50

class Flow2FlowAgent:
    """Training / inference wrapper around a ``Flow2Flow`` model.

    The agent builds the model from a default ``Args`` configuration, runs
    the per-epoch training and validation loop with early stopping, and
    exposes helpers that map vectors between the two domains (A = source
    flow ``g_src``, B = target flow ``g_tgt``) and the shared latent Z.
    """

    def __init__(self, input_dim):
        """Create the model on CUDA when available, otherwise on CPU.

        Args:
            input_dim: Dimensionality of one input vector, forwarded to
                ``Args``.
        """
        self.args = Args(input_dim)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = Flow2Flow(self.args).to(self.device)
        # NOTE(review): this optimizer is never stepped inside this class;
        # training is delegated to the model's own train_iter_* methods.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)

        # Early-stopping bookkeeping, updated by fit().
        self.best_val_loss = float('inf')
        self.best_epoch = -1
        self.patience = 10
        self.patience_counter = 0

        self.best_state_dict = None

    def split_labeled_data(self, data_list, val_ratio=0.2):
        """Randomly split ``data_list`` into training and validation parts.

        The caller's sequence is left untouched: a copy is shuffled.

        Args:
            data_list: Sequence of samples to split.
            val_ratio: Fraction of samples (rounded down) put into the
                validation part.

        Returns:
            ``(train_part, val_part)`` as two new lists.
        """
        import random

        # Shuffle a copy so the caller's list keeps its original order.
        shuffled = list(data_list)
        random.shuffle(shuffled)
        n_val = int(len(shuffled) * val_ratio)
        return shuffled[n_val:], shuffled[:n_val]

    def set_labeled_data(self, labeled_data_a, labeled_data_b, val_ratio=0.2):
        """Split both labeled sets and register every sample with the model.

        Args:
            labeled_data_a: Iterable of ``(x, y)`` pairs for domain A.
            labeled_data_b: Iterable of ``(x, y)`` pairs for domain B.
            val_ratio: Fraction of each set held out for validation.
        """
        train_a, val_a = self.split_labeled_data(labeled_data_a, val_ratio)
        train_b, val_b = self.split_labeled_data(labeled_data_b, val_ratio)

        for (x_i, y_i) in train_a:
            self.model.add_labeled_data_a(x_i, y_i)
        for (x_i, y_i) in train_b:
            self.model.add_labeled_data_b(x_i, y_i)

        for (x_i, y_i) in val_a:
            self.model.add_labeled_data_a_val(x_i, y_i)
        for (x_i, y_i) in val_b:
            self.model.add_labeled_data_b_val(x_i, y_i)

    def _set_batch_inputs(self, batch):
        """Move one ``{'src', 'tgt'}`` batch to the device and feed the model."""
        src_input = batch['src'].to(self.device)
        tgt_input = batch['tgt'].to(self.device)
        self.model.set_inputs(src_input, tgt_input)

    def _validate_alignflow(self, val_loader):
        """Return the per-batch mean of ``loss_g`` and ``loss_d`` on ``val_loader``.

        Both values stay at ``0.0`` when the loader yields no batches.
        """
        self.model.eval()
        total_g = 0.0
        total_d = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                self._set_batch_inputs(batch)
                loss_dict = self.model.compute_losses_no_grad()
                total_g += loss_dict['loss_g']
                total_d += loss_dict['loss_d']
                num_batches += 1

        if num_batches > 0:
            total_g /= num_batches
            total_d /= num_batches
        return total_g, total_d

    def fit(self, train_dataset, val_dataset, epochs, batch_size, min_epochs=20):
        """Train the model, validating and checking early stopping every epoch.

        Each epoch runs (1) one pass of ``train_iter_alignflow`` over
        ``train_dataset``, (2) one call to ``train_iter_all_one_epoch`` for
        the labeled data, and (3) validation on ``val_dataset`` plus the
        predictor validation loss.

        Args:
            train_dataset: Dataset whose items are dicts with ``'src'`` and
                ``'tgt'`` tensors.
            val_dataset: Dataset with the same item layout, used for
                validation only.
            epochs: Maximum number of epochs.
            batch_size: Batch size for both loaders and the labeled pass.
            min_epochs: Number of epochs that must have run before the best
                model is tracked and early stopping can trigger.
        """
        import copy

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        for epoch in range(epochs):
            # =========== 1) AlignFlow with unlabeled data ===========
            self.model.train()
            for batch in train_loader:
                self._set_batch_inputs(batch)
                self.model.train_iter_alignflow()

            # =========== 2) Predictor with labeled data (one pass) ===========
            self.model.train_iter_all_one_epoch(batch_size=batch_size)

            # =========== 3) VALIDATION ===========
            val_loss_g, val_loss_d = self._validate_alignflow(val_loader)
            val_pred_loss = self.model.validate_predictor_loss()
            val_total_loss = val_loss_g + val_loss_d + val_pred_loss

            print(
                f"Epoch [{epoch+1}/{epochs}], "
                f"Val LossG={val_loss_g:.4f}, "
                f"Val LossD={val_loss_d:.4f}, "
                f"Val Pred={val_pred_loss:.4f}, "
                f"Val Total={val_total_loss:.4f}"
            )

            # =========== 4) Early Stopping ===========
            if (epoch + 1) < min_epochs:
                continue

            if val_total_loss < self.best_val_loss:
                self.best_val_loss = val_total_loss
                self.best_epoch = epoch
                self.patience_counter = 0
                # state_dict() returns references to the live parameters;
                # without a deep copy the "best" snapshot would silently
                # follow every later optimizer step.
                self.best_state_dict = {
                    'model': copy.deepcopy(self.model.state_dict()),
                }
            else:
                self.patience_counter += 1

            if self.patience_counter >= self.patience:
                print(f"[EarlyStop] val_total_loss not improved for {self.patience} epochs")
                print(f"Best epoch={self.best_epoch+1}, val_total_loss={self.best_val_loss:.4f}")
                if self.best_state_dict is not None:
                    self.model.load_state_dict(self.best_state_dict['model'])
                    print("Restored best weights.")
                break

        print("Training done.")

    def _apply_flows(self, vectors, steps):
        """Push ``vectors`` through a chain of flows in eval mode, without grads.

        Args:
            vectors: Input tensor; cast to float32 and moved to the device.
            steps: Sequence of ``(flow, reverse)`` pairs applied in order.
                Each flow is called as ``flow(x, reverse=reverse)`` and only
                the first element of its return value is kept.

        Returns:
            The final tensor, moved back to the CPU.
        """
        self.model.eval()
        with torch.no_grad():
            out = vectors.to(self.device, dtype=torch.float32)
            for flow, reverse in steps:
                out, _ = flow(out, reverse=reverse)
        return out.cpu()

    def genAfromB(self, b_vectors):
        """Translate domain-B vectors to domain A (B -> Z -> A)."""
        return self._apply_flows(
            b_vectors,
            ((self.model.g_tgt, False), (self.model.g_src, True)),
        )

    def genBfromA(self, a_vectors):
        """Translate domain-A vectors to domain B (A -> Z -> B)."""
        return self._apply_flows(
            a_vectors,
            ((self.model.g_src, False), (self.model.g_tgt, True)),
        )

    def genAfromZ(self, z_vectors):
        """Decode latent vectors into domain A (Z -> A)."""
        return self._apply_flows(z_vectors, ((self.model.g_src, True),))

    def genBfromZ(self, z_vectors):
        """Decode latent vectors into domain B (Z -> B)."""
        return self._apply_flows(z_vectors, ((self.model.g_tgt, True),))

    def save_model(self, path, epoch=None):
        """Save the model's ``state_dict``.

        Args:
            path: Target file path, or a path prefix when ``epoch`` is given.
            epoch: If not ``None``, the file is written to
                ``f"{path}_epoch{epoch}.pth"`` instead of ``path``.
        """
        if epoch is not None:
            save_path = f"{path}_epoch{epoch}.pth"
        else:
            save_path = path

        torch.save(self.model.state_dict(), save_path)
        print(f"Model saved to {save_path}")

    def load_model(self, path):
        """Load a ``state_dict`` saved by ``save_model`` into the model.

        Args:
            path: Checkpoint file to read.
        """
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        self.model.to(self.device)
        print(f"Model loaded from {path}")

class RandomDataset(Dataset):
    """Synthetic dataset of standard-normal vectors for smoke-testing ``fit``.

    Every item is a dict with a ``'src'`` and a ``'tgt'`` tensor, because the
    training and validation loops read both keys from each batch. The two
    tensors are drawn independently.

    Args:
        num_samples: Number of items in the dataset.
        input_dim: Length of each generated vector.
    """

    def __init__(self, num_samples, input_dim):
        self.data = torch.randn(num_samples, input_dim)
        # Independent draw for the target side; without it batch['tgt']
        # raises KeyError in Flow2FlowAgent.fit.
        self.tgt_data = torch.randn(num_samples, input_dim)

    def __len__(self):
        """Return the number of items."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``{'src': ..., 'tgt': ...}`` item."""
        return {'src': self.data[idx], 'tgt': self.tgt_data[idx]}


if __name__ == '__main__':
    # Smoke test: train briefly on random data, exercise every generation
    # helper once, then round-trip the weights through a checkpoint file.
    input_dim = 16
    num_samples_a = 1000
    num_samples_b = 1000
    checkpoint_path = "flow2flow_model.pth"

    dataset_a = RandomDataset(num_samples_a, input_dim)
    dataset_b = RandomDataset(num_samples_b, input_dim)

    agent = Flow2FlowAgent(input_dim=input_dim)

    print("Starting training...")
    agent.fit(dataset_a, dataset_b, epochs=5, batch_size=32)

    print("Testing generation functions...")
    # Cross-domain translation: each direction gets its own fresh sample.
    for label, generate in (
        ("Generated A from B:", agent.genAfromB),
        ("Generated B from A:", agent.genBfromA),
    ):
        print(label, generate(torch.randn(1, input_dim)))

    # Decoding from the latent space: both directions share one sample.
    z_sample = torch.randn(1, input_dim)
    for label, generate in (
        ("Generated A from Z:", agent.genAfromZ),
        ("Generated B from Z:", agent.genBfromZ),
    ):
        print(label, generate(z_sample))

    print("Saving model...")
    agent.save_model(checkpoint_path)

    print("Loading model...")
    agent.load_model(checkpoint_path)

    print("All tests completed successfully.")