import os
import argparse
import time
import yaml
import numpy as np
import shutil  # Add this import for copying files

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

###############################################################################
# Teacher Network Definition
###############################################################################
class TeacherNN(nn.Module):
    """Frozen teacher MLP whose outputs are used as the student's regression targets.

    Layout of ``self.network`` (an ``nn.Sequential``):
      * ``num_layers == 1``: a single ``Linear(input_dim, output_dim)``.
      * otherwise: ``Linear -> Tanh`` pairs mapping ``input_dim`` to
        ``hidden_dim`` and then ``hidden_dim`` to ``hidden_dim``, followed by
        a final ``Linear(hidden_dim, output_dim)`` with no activation.

    Intended to mirror generate_teacher_network.py (minus its multiplier
    logic). The Sequential indices define the state_dict keys, so the layer
    order must stay compatible with the saved teacher checkpoint.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        if num_layers == 1:
            stack = [nn.Linear(input_dim, output_dim)]
        else:
            # Count of hidden->hidden layers. The input->hidden layer is
            # always built, even when num_layers < 2.
            n_hidden_to_hidden = max(num_layers - 2, 0)
            widths = [input_dim] + [hidden_dim] * (n_hidden_to_hidden + 1)
            stack = []
            for fan_in, fan_out in zip(widths, widths[1:]):
                stack.extend((nn.Linear(fan_in, fan_out), nn.Tanh()))
            stack.append(nn.Linear(hidden_dim, output_dim))
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the teacher network to ``x``."""
        return self.network(x)

###############################################################################
# Student Network Definition (wrapped in DDP)
###############################################################################
class FeedForwardNet(nn.Module):
    """Student MLP: ``(Linear -> BatchNorm1d -> ReLU)`` blocks plus a linear head.

    Args:
        input_dim: number of input features.
        hidden_dim: width of every hidden block.
        n_inner_layers: number of hidden->hidden blocks after the first
            input->hidden block.
        teacher_output_dim: width of the final linear head (matches the
            teacher's output).

    ``fc_layers`` holds ``n_inner_layers + 2`` Linear layers (the last one is
    the head); ``bn_layers`` holds one BatchNorm1d per hidden block.
    """

    def __init__(self, input_dim, hidden_dim, n_inner_layers, teacher_output_dim):
        super().__init__()
        # Input widths of the hidden blocks: the first consumes input_dim,
        # every inner block consumes hidden_dim.
        block_in_widths = [input_dim] + [hidden_dim] * n_inner_layers
        self.fc_layers = nn.ModuleList(
            nn.Linear(width, hidden_dim) for width in block_in_widths
        )
        self.bn_layers = nn.ModuleList(
            nn.BatchNorm1d(hidden_dim) for _ in block_in_widths
        )
        # Output head, appended last so it is fc_layers[-1].
        self.fc_layers.append(nn.Linear(hidden_dim, teacher_output_dim))

    def forward(self, x):
        """Run the hidden blocks (Linear, BN, ReLU), then the bare linear head."""
        *hidden_fcs, head = self.fc_layers
        for fc, bn in zip(hidden_fcs, self.bn_layers):
            x = F.relu(bn(fc(x)))
        return head(x)

###############################################################################
# Training Function (called by each process)
###############################################################################
def ddp_train(args):
    """Train the student network against the teacher in this torchrun process.

    Reads RANK and WORLD_SIZE from the environment, plus LOCAL_RANK for GPU
    selection (falling back to RANK if unset). It then joins the NCCL process
    group and runs training, evaluation and (rank 0 only) result saving. The
    process group is always destroyed afterwards, even if training raises.

    Args:
        args: parsed command-line arguments; must provide
            ``calc_specification_path`` and ``calc_folder``.
    """
    # Get rank and world size from environment variables set by torchrun.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # The CUDA device index must be the node-local rank. The global RANK
    # exceeds the per-node GPU count on multi-node runs. On a single node
    # LOCAL_RANK == RANK, so the fallback preserves single-node behaviour.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))

    print("my rank is", rank)
    print("my world size is", world_size)
    # Initialize the process group from the environment variables set by
    # torchrun (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE).
    dist.init_process_group(backend="nccl", init_method='env://')
    try:
        _run_training(args, rank, world_size, local_rank)
    finally:
        # Cleanup, even when training/evaluation/saving raised.
        dist.destroy_process_group()


def _run_training(args, rank, world_size, local_rank):
    """Load config and models, train the DDP student, evaluate, and save on rank 0."""
    # Bind this process to its node-local GPU.
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # Load config
    CALC_SPECIFICATION_PATH = args.calc_specification_path
    CALC_FOLDER = args.calc_folder

    start_time = time.time()
    with open(CALC_SPECIFICATION_PATH, 'r') as file:
        config = yaml.safe_load(file)

    # Extract variables from the configuration
    BASE_LR = float(config['BASE_LR'])
    LEARNING_RATE_DECAY_SCALE = config['LEARNING_RATE_DECAY_SCALE']
    NUM_STEPS = config['NUM_STEPS']
    BATCH_SIZE = config['BATCH_SIZE']   # This is per-GPU batch size
    HIDDEN_DIM = config['HIDDEN_DIM']
    TEACHER_PATH = config['TEACHER_PATH']
    N_INNER_LAYERS = config['N_INNER_LAYERS']
    RECORD_INTERVAL = config['RECORD_INTERVAL']

    # Load teacher hypers
    hypers_path = os.path.join(TEACHER_PATH, "teacher_hypers.yaml")
    with open(hypers_path, 'r') as f:
        teacher_hypers = yaml.safe_load(f)

    TEACHER_INPUT_DIM = teacher_hypers['TEACHER_INPUT_DIM']
    TEACHER_HIDDEN_DIM = teacher_hypers['TEACHER_HIDDEN_DIM']
    TEACHER_NUM_LAYERS = teacher_hypers['TEACHER_NUM_LAYERS']
    TEACHER_OUTPUT_DIM = teacher_hypers.get('TEACHER_OUTPUT_DIM', 1)  # Default to 1 if not found

    # Teacher lives on this rank's GPU and is not wrapped in DDP (never trained).
    teacher = TeacherNN(
        TEACHER_INPUT_DIM, TEACHER_HIDDEN_DIM, TEACHER_NUM_LAYERS, TEACHER_OUTPUT_DIM
    ).to(device)
    weights_path = os.path.join(TEACHER_PATH, "teacher_weights.pth")
    # NOTE(review): torch.load unpickles the file; only load trusted teacher
    # checkpoints (consider weights_only=True on torch >= 1.13).
    teacher.load_state_dict(torch.load(weights_path, map_location=device))
    teacher.eval()

    # Create student model, move to GPU, and wrap in DDP on the local device.
    student_model = FeedForwardNet(
        TEACHER_INPUT_DIM, HIDDEN_DIM, N_INNER_LAYERS, TEACHER_OUTPUT_DIM
    ).to(device)
    ddp_model = DDP(student_model, device_ids=[local_rank], output_device=local_rank)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(ddp_model.parameters(), lr=BASE_LR)
    # Halve the LR every LEARNING_RATE_DECAY_SCALE optimizer steps.
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=LEARNING_RATE_DECAY_SCALE,
        gamma=0.5  # Decrease LR by factor of 2
    )

    # For logging
    step_history, pure_loss_history, loss_history = [], [], []
    lr_history = []

    # Seed per global rank so that the random inputs differ across GPUs.
    torch.manual_seed(1234 + rank)

    for step in range(NUM_STEPS):
        # LR actually used for this step (read before scheduler.step()).
        current_lr = optimizer.param_groups[0]['lr']

        # Fresh random inputs each step; the teacher provides the targets.
        inputs = torch.randn(BATCH_SIZE, TEACHER_INPUT_DIM, device=device)
        with torch.no_grad():
            targets = teacher(inputs)

        outputs = ddp_model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if step % RECORD_INTERVAL == 0:
            # .item() forces a host-device sync; do it once per record.
            loss_value = loss.item()
            # Print from rank 0 only, to avoid clutter
            if rank == 0:
                elapsed_time = time.time() - start_time
                print(
                    f"[Rank {rank}] Step {step}/{NUM_STEPS}, "
                    f"Loss: {loss_value:.8f}, "
                    f"LR: {current_lr:.6f}, "
                    f"Elapsed: {elapsed_time:.2f}s"
                )
            step_history.append(step)
            pure_loss_history.append(loss_value)
            loss_history.append(loss_value)
            lr_history.append(current_lr)

    aggregated_loss = _evaluate_student(
        ddp_model, teacher, criterion, TEACHER_INPUT_DIM, device, rank, world_size
    )

    # Save model and logs only on rank 0
    if rank == 0:
        history = {
            'step_history': step_history,
            'pure_loss_history': pure_loss_history,
            'loss_history': loss_history,
            'lr_history': lr_history
        }
        _save_results(
            CALC_FOLDER, CALC_SPECIFICATION_PATH, history, ddp_model.module,
            aggregated_loss, start_time, BATCH_SIZE, world_size
        )


def _evaluate_student(model, teacher, criterion, input_dim, device, rank, world_size,
                      batch_size=1024 * 1024, num_batches=10):
    """Return the student's mean loss vs. the teacher, averaged over all ranks."""
    # Put BatchNorm into inference mode. In train mode it would normalise with
    # per-batch statistics and overwrite the running buffers that are saved
    # with the student weights afterwards.
    model.eval()
    # A different seed per rank so each rank evaluates on different samples.
    torch.manual_seed(5678 + rank)

    local_losses = []
    with torch.no_grad():
        for _ in range(num_batches):
            eval_inputs = torch.randn(batch_size, input_dim, device=device)
            eval_targets = teacher(eval_inputs)
            eval_outputs = model(eval_inputs)
            local_losses.append(criterion(eval_outputs, eval_targets).item())

    # SUM across ranks, then divide: each element becomes the cross-rank mean.
    losses_t = torch.tensor(local_losses, device=device)
    dist.all_reduce(losses_t, op=dist.ReduceOp.SUM)
    return float(torch.mean(losses_t / world_size))


def _save_results(calc_folder, spec_path, history, student, aggregated_loss,
                  start_time, batch_size, world_size):
    """Write history, student weights, a summary and a copy of the spec YAML."""
    os.makedirs(calc_folder, exist_ok=True)
    np.savez(os.path.join(calc_folder, 'history.npz'), **history)

    # Save the unwrapped (non-DDP) student weights.
    torch.save(student.state_dict(), os.path.join(calc_folder, 'model_weights.pth'))

    total_time = time.time() - start_time

    summary_path = os.path.join(calc_folder, 'summary.txt')
    with open(summary_path, 'w') as f:
        f.write(f"final loss: {aggregated_loss}\n")
        f.write(f"fitting time: {total_time}\n")
        # BATCH_SIZE is per GPU, so the global batch is scaled by world_size.
        effective_batch_size = batch_size * world_size
        f.write(f"number of GPUs: {world_size}\n")
        f.write(f"effective batch size: {effective_batch_size}\n")

    # Copy the specification YAML file to the calculation folder
    spec_filename = os.path.basename(spec_path)
    shutil.copy(spec_path, os.path.join(calc_folder, spec_filename))
    print(f"Copied specification file to {calc_folder}")

###############################################################################
# Parse arguments and run training (torchrun launches one process per GPU)
###############################################################################
def main():
    """Parse the command line and run DDP training for this process.

    Meant to be launched via torchrun, which sets RANK and WORLD_SIZE in the
    environment; ddp_train reads the rank information from there itself.

    NOTE: BATCH_SIZE in the spec is per GPU, so the effective batch size is
    BATCH_SIZE * world_size. Consider rescaling the learning rate when
    comparing against a single-GPU run.
    """
    cli = argparse.ArgumentParser(
        description='Train a neural network using a teacher network with DDP.'
    )
    # (flag, help text) for each required string option, in registration order.
    option_specs = (
        ('--calc_specification_path', 'Path to the YAML configuration file'),
        ('--calc_folder', 'Output folder for calculation results'),
    )
    for flag, help_text in option_specs:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    parsed = cli.parse_args()
    ddp_train(parsed)

if __name__ == "__main__":
    # Run only when executed as a script (e.g. launched via torchrun), not on import.
    main()
