import sys
import os
import argparse
import time
import yaml
import numpy as np
import shutil  # Add this import for copying files

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

# If you have your custom modules (e.g., ldKAN),
# make sure they can be imported in a distributed run
from ldKAN.dkan_2d import DKAN_2D_Layer

###############################################################################
# Teacher Network Definition
###############################################################################
class TeacherNN(nn.Module):
    """
    Plain feed-forward teacher: a chain of Linear layers with Tanh in between.

    Mirrors the architecture of generate_teacher_network.py (excluding its
    multiplier logic), so the saved teacher state dict loads into
    ``self.network`` index-for-index.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        """
        Args:
            input_dim: number of input features.
            hidden_dim: width of every hidden layer (unused when num_layers == 1).
            num_layers: number of Linear layers. 1 maps input_dim straight to
                output_dim. Any other value builds input -> hidden, then
                max(num_layers - 2, 0) hidden -> hidden layers, then
                hidden -> output. Every Linear except the last is followed by Tanh.
            output_dim: number of output features.
        """
        super().__init__()

        if num_layers == 1:
            widths = [input_dim, output_dim]
        else:
            # Always at least one hidden width, matching the original loop,
            # which simply added no hidden->hidden layers for num_layers < 3.
            widths = [input_dim] + [hidden_dim] * max(num_layers - 1, 1) + [output_dim]

        final_index = len(widths) - 2
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            # No activation after the output layer.
            if idx < final_index:
                modules.append(nn.Tanh())

        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the teacher network to ``x`` (features along the last dimension)."""
        return self.network(x)

###############################################################################
# Student Network Definition (wrapped in DDP)
###############################################################################
class FeedForwardNet(nn.Module):
    """
    DKAN student network trained to reproduce the teacher's outputs.

    Stack:
      * one DKAN layer input_dim -> hidden_dim,
      * ``n_inner_layers`` DKAN layers hidden_dim -> hidden_dim,
      * one final DKAN layer hidden_dim -> padded_output_dim.
    ``padded_output_dim`` is ``teacher_output_dim`` rounded up to a multiple
    of ``tile_size_forward``. Every DKAN layer except the final one is
    followed by a BatchNorm1d without affine parameters. ``forward`` slices
    the padded output back down to ``teacher_output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, n_inner_layers, INIT_SCALE, teacher_output_dim, n_chunks,
                 block_size_forward, block_size_backward, tile_size_forward, tile_size_backward):
        super().__init__()

        def build_dkan(in_features, out_features):
            # Every layer shares the same chunking, block/tile sizes, flags and
            # INIT_SCALE. The positional flags are interpreted by DKAN_2D_Layer.
            return DKAN_2D_Layer(
                n_chunks, in_features, out_features,
                block_size_forward, block_size_backward,
                tile_size_forward, tile_size_backward,
                False, False, True, False, INIT_SCALE, True,
            )

        # Ceil-divide the teacher's output width into whole forward tiles.
        # The surplus columns are dropped again in forward().
        padded_output_dim = -(-teacher_output_dim // tile_size_forward) * tile_size_forward

        # Widths along the stack: input, (1 + n_inner_layers) hidden, padded output.
        # Clamping at 0 mirrors the original loop, which adds no inner layers
        # for a negative count.
        n_hidden = max(n_inner_layers, 0) + 1
        widths = [input_dim] + [hidden_dim] * n_hidden + [padded_output_dim]

        # Layers are built first -> last, so construction order and the
        # fc_layers / bn_layers registration order are unchanged.
        self.fc_layers = nn.ModuleList(
            build_dkan(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])
        )
        self.bn_layers = nn.ModuleList(
            nn.BatchNorm1d(hidden_dim, affine=False) for _ in range(n_hidden)
        )

        # Output width the student must match (the teacher's).
        self.target_output_dim = teacher_output_dim

    def forward(self, x, weight_dkan):
        """
        Run the student on a batch.

        Args:
            x: input batch laid out as (batch, input_dim). It is transposed to
                a (features, batch) layout for the DKAN layers.
            weight_dkan: scheduled dkan weight, passed unchanged to every DKAN
                layer together with a flag that is True for all but the final
                layer (its meaning is defined by DKAN_2D_Layer).

        Returns:
            Contiguous tensor of shape (batch, target_output_dim).
        """
        *hidden_stack, output_layer = self.fc_layers

        h = x.transpose(0, 1).contiguous()
        for dkan, norm in zip(hidden_stack, self.bn_layers):
            h = dkan(h, weight_dkan, True)
            # BatchNorm1d normalizes over dim 1, so flip to (batch, features) and back.
            h = norm(h.transpose(0, 1)).transpose(0, 1)

        # Final layer has no BatchNorm.
        out = output_layer(h, weight_dkan, False).transpose(0, 1).contiguous()
        # Drop the tile-padding columns so the width matches the teacher.
        return out[:, :self.target_output_dim].contiguous()

    def get_frobenius_regularization(self):
        """Return the sum of ``get_frobenius_regularization()`` over all DKAN layers."""
        return sum((layer.get_frobenius_regularization() for layer in self.fc_layers), 0.0)

###############################################################################
# Training Function (called by each process)
###############################################################################
def ddp_train(args):
    """
    Distill the teacher into a DKAN student on one GPU of a torchrun job.

    Every process started by torchrun runs this function. Each rank:
      * draws its own random inputs,
      * labels them with a local (non-DDP) copy of the teacher,
      * trains the DDP-wrapped ``FeedForwardNet`` student with MSE loss plus a
        scheduled Frobenius regularizer.

    After training, each rank evaluates the student on fresh random inputs.
    The per-step eval losses are averaged over ranks and then over eval steps.
    Rank 0 writes ``history.npz``, ``model_weights.pth``, ``summary.txt`` and a
    copy of the specification YAML into ``args.calc_folder``.

    Args:
        args: parsed command-line arguments with ``calc_specification_path``
            (training YAML) and ``calc_folder`` (output directory).

    Environment:
        RANK and WORLD_SIZE are required. LOCAL_RANK is optional and falls
        back to RANK, which is correct for single-node runs. The rendezvous
        variables are consumed by ``init_method='env://'``.
    """
    # Global rank and world size, as exported by torchrun.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # The GPU index on *this node* is LOCAL_RANK, not the global RANK. On a
    # multi-node job RANK can exceed the number of local GPUs (e.g. rank 8 on
    # an 8-GPU node would ask for cuda:8). Falling back to RANK keeps
    # single-node launches without LOCAL_RANK working as before.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))

    print("my rank is", rank)
    print("my world size is", world_size)

    # Bind this process to its GPU before creating the NCCL process group,
    # so NCCL does not have to guess the device for its communicators.
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # Initialize the process group from the torchrun environment
    # (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE).
    dist.init_process_group(backend="nccl", init_method='env://')

    # Load config
    CALC_SPECIFICATION_PATH = args.calc_specification_path
    CALC_FOLDER = args.calc_folder

    start_time = time.time()
    with open(CALC_SPECIFICATION_PATH, 'r') as file:
        config = yaml.safe_load(file)

    # Schedule, model-size and kernel-tiling settings from the configuration.
    INIT_SCALE = float(config['INIT_SCALE'])
    PURE_MLP_STEPS = config['PURE_MLP_STEPS']
    DKAN_TURN_ON_STEPS = config['DKAN_TURN_ON_STEPS']
    DKAN_TURN_ON_SCALE = config['DKAN_TURN_ON_SCALE']
    DKAN_TURN_ON_CAP = float(config['DKAN_TURN_ON_CAP'])
    DKAN_FROBENIUS_DECAY_STEPS = config['DKAN_FROBENIUS_DECAY_STEPS']
    DKAN_FROBENIUS_DECAY_SCALE = config['DKAN_FROBENIUS_DECAY_SCALE']
    FROBENIUS_WEIGHT_CAP = float(config['FROBENIUS_WEIGHT_CAP'])
    DKAN_LEARNING_RATE_DECAY_STEPS = config['DKAN_LEARNING_RATE_DECAY_STEPS']
    DKAN_LEARNING_RATE_DECAY_SCALE = config['DKAN_LEARNING_RATE_DECAY_SCALE']
    INITIAL_FROBENIUS_WEIGHT = float(config['INITIAL_FROBENIUS_WEIGHT'])
    BATCH_SIZE = config['BATCH_SIZE']   # This is per-GPU batch size
    HIDDEN_DIM = config['HIDDEN_DIM']
    TEACHER_PATH = config['TEACHER_PATH']
    N_INNER_LAYERS = config['N_INNER_LAYERS']
    RECORD_INTERVAL = config['RECORD_INTERVAL']
    DKAN_BASE_LR = float(config['DKAN_BASE_LR'])
    N_CHUNKS = config['N_CHUNKS']
    BLOCK_SIZE_FORWARD = config['BLOCK_SIZE_FORWARD']
    BLOCK_SIZE_BACKWARD = config['BLOCK_SIZE_BACKWARD']
    TILE_SIZE_FORWARD = config['TILE_SIZE_FORWARD']
    TILE_SIZE_BACKWARD = config['TILE_SIZE_BACKWARD']

    # Teacher architecture hyper-parameters, stored next to the teacher weights.
    hypers_path = os.path.join(TEACHER_PATH, "teacher_hypers.yaml")
    with open(hypers_path, 'r') as f:
        teacher_hypers = yaml.safe_load(f)

    TEACHER_INPUT_DIM = teacher_hypers['TEACHER_INPUT_DIM']
    TEACHER_HIDDEN_DIM = teacher_hypers['TEACHER_HIDDEN_DIM']
    TEACHER_NUM_LAYERS = teacher_hypers['TEACHER_NUM_LAYERS']
    # Default to 1 if not found, for backward compatibility with older teachers.
    TEACHER_OUTPUT_DIM = teacher_hypers.get('TEACHER_OUTPUT_DIM', 1)

    # Teacher lives on this rank's GPU and is NOT wrapped in DDP (frozen, eval only).
    teacher = TeacherNN(
        TEACHER_INPUT_DIM, TEACHER_HIDDEN_DIM, TEACHER_NUM_LAYERS, TEACHER_OUTPUT_DIM
    ).to(device)
    weights_path = os.path.join(TEACHER_PATH, "teacher_weights.pth")
    teacher.load_state_dict(torch.load(weights_path, map_location=device))
    teacher.eval()

    # Student model on this rank's GPU.
    student_model = FeedForwardNet(
        TEACHER_INPUT_DIM, HIDDEN_DIM, N_INNER_LAYERS, INIT_SCALE, TEACHER_OUTPUT_DIM, N_CHUNKS,
        BLOCK_SIZE_FORWARD, BLOCK_SIZE_BACKWARD, TILE_SIZE_FORWARD, TILE_SIZE_BACKWARD
    ).to(device)

    # DDP device ids are node-local GPU indices.
    ddp_model = DDP(student_model, device_ids=[local_rank], output_device=local_rank)

    def get_params(step):
        """
        Return ``(lr, dkan_weight, frobenius_weight)`` for a training step.

        Phases, in order:
          1. PURE_MLP_STEPS: lr 1e-3, dkan_weight 0, initial Frobenius weight.
          2. DKAN_TURN_ON_STEPS: lr DKAN_BASE_LR. dkan_weight ramps as
             offset / DKAN_TURN_ON_SCALE, capped at DKAN_TURN_ON_CAP.
          3. DKAN_FROBENIUS_DECAY_STEPS: the Frobenius weight shrinks 10x every
             DKAN_FROBENIUS_DECAY_SCALE steps. FROBENIUS_WEIGHT_CAP acts as a
             lower bound (floor).
          4. Afterwards: lr halves every DKAN_LEARNING_RATE_DECAY_SCALE steps.
             dkan_weight and the Frobenius weight stay at their phase-3 end values.
        """
        if step < PURE_MLP_STEPS:
            lr = 1e-3
            dkan_weight = 0.0
            frobenius_weight = INITIAL_FROBENIUS_WEIGHT
        elif step < PURE_MLP_STEPS + DKAN_TURN_ON_STEPS:
            offset = step - PURE_MLP_STEPS
            lr = DKAN_BASE_LR
            dkan_weight = min(offset / DKAN_TURN_ON_SCALE, DKAN_TURN_ON_CAP)
            frobenius_weight = INITIAL_FROBENIUS_WEIGHT
        elif step < PURE_MLP_STEPS + DKAN_TURN_ON_STEPS + DKAN_FROBENIUS_DECAY_STEPS:
            offset = step - (PURE_MLP_STEPS + DKAN_TURN_ON_STEPS)
            lr = DKAN_BASE_LR
            dkan_weight = min(DKAN_TURN_ON_STEPS / DKAN_TURN_ON_SCALE, DKAN_TURN_ON_CAP)
            frobenius_weight = INITIAL_FROBENIUS_WEIGHT / (10 ** (offset / DKAN_FROBENIUS_DECAY_SCALE))
            if frobenius_weight < FROBENIUS_WEIGHT_CAP:
                frobenius_weight = FROBENIUS_WEIGHT_CAP
        else:
            offset = step - (PURE_MLP_STEPS + DKAN_TURN_ON_STEPS + DKAN_FROBENIUS_DECAY_STEPS)
            num_halvings = offset // DKAN_LEARNING_RATE_DECAY_SCALE
            lr = DKAN_BASE_LR * (0.5 ** num_halvings)
            dkan_weight = min(DKAN_TURN_ON_STEPS / DKAN_TURN_ON_SCALE, DKAN_TURN_ON_CAP)
            frobenius_weight = (
                INITIAL_FROBENIUS_WEIGHT /
                (10 ** (DKAN_FROBENIUS_DECAY_STEPS / DKAN_FROBENIUS_DECAY_SCALE))
            )
            if frobenius_weight < FROBENIUS_WEIGHT_CAP:
                frobenius_weight = FROBENIUS_WEIGHT_CAP
        return lr, dkan_weight, frobenius_weight

    total_steps = (
        PURE_MLP_STEPS
        + DKAN_TURN_ON_STEPS
        + DKAN_FROBENIUS_DECAY_STEPS
        + DKAN_LEARNING_RATE_DECAY_STEPS
    )

    # Optimizer starts at the step-0 LR; the LR is overwritten every step.
    init_lr, _, _ = get_params(0)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(ddp_model.parameters(), lr=init_lr)

    # For logging
    step_history, pure_loss_history, loss_history = [], [], []
    lr_history, dkan_weight_history, frobenius_weight_history = [], [], []

    # A different seed per rank makes the random inputs differ across GPUs.
    # (DDP already broadcast rank 0's student parameters at construction.)
    torch.manual_seed(1234 + rank)

    if rank == 0:
        os.makedirs(CALC_FOLDER, exist_ok=True)

    # ---- Training loop ----------------------------------------------------
    for step in range(total_steps):
        lr, dkan_weight, frobenius_weight = get_params(step)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Fresh random inputs on this rank, labelled by the frozen teacher.
        inputs = torch.randn(BATCH_SIZE, TEACHER_INPUT_DIM, device=device)
        with torch.no_grad():
            targets = teacher(inputs)

        outputs = ddp_model(inputs, dkan_weight)

        pure_loss = criterion(outputs, targets)
        # Regularizer is read from the unwrapped module; its gradients still
        # reach the same parameters that DDP synchronizes.
        fro_reg = student_model.get_frobenius_regularization()
        loss = pure_loss + frobenius_weight * fro_reg

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % RECORD_INTERVAL == 0:
            # One device->host sync per recorded tensor.
            loss_value = loss.item()
            pure_loss_value = pure_loss.item()
            # Print from rank 0 only, to avoid clutter.
            if rank == 0:
                elapsed_time = time.time() - start_time
                # Exponent format: the LR halves and the Frobenius weight decays
                # by powers of ten, so fixed-point formats collapse to 0.
                print(
                    f"[Rank {rank}] Step {step}/{total_steps}, "
                    f"Loss: {loss_value:.8f}, "
                    f"Pure Loss: {pure_loss_value:.8f}, "
                    f"LR: {lr:.4e}, "
                    f"dkan_weight: {dkan_weight:.4f}, "
                    f"frobenius_weight: {frobenius_weight:.4e}, "
                    f"Elapsed: {elapsed_time:.2f}s"
                )
            step_history.append(step)
            pure_loss_history.append(pure_loss_value)
            loss_history.append(loss_value)
            lr_history.append(lr)
            dkan_weight_history.append(dkan_weight)
            frobenius_weight_history.append(frobenius_weight)

    # ---- Evaluation -------------------------------------------------------
    # Switch to inference mode. BatchNorm then normalizes with its running
    # statistics, and the large eval batches no longer overwrite the running
    # statistics that are saved with the weights below.
    ddp_model.eval()
    # dkan_weight of the last training step. It is taken from the schedule
    # rather than the loop variable so it is defined even when total_steps == 0.
    _, final_dkan_weight, _ = get_params(max(total_steps - 1, 0))

    batch_size_eval = 1024 * 1024
    num_eval_steps = 10
    local_eval_losses = []

    # Different seed per rank so eval inputs also differ across GPUs.
    torch.manual_seed(5678 + rank)

    with torch.no_grad():
        for _ in range(num_eval_steps):
            eval_inputs = torch.randn(batch_size_eval, TEACHER_INPUT_DIM, device=device)
            eval_targets = teacher(eval_inputs)
            eval_outputs = ddp_model(eval_inputs, final_dkan_weight)
            eval_loss = criterion(eval_outputs, eval_targets)
            local_eval_losses.append(eval_loss.item())

    # Average each eval step's loss over ranks (SUM all-reduce / world_size).
    # Every rank receives the result, which is then averaged over eval steps.
    local_eval_losses_t = torch.tensor(local_eval_losses, device=device)
    dist.all_reduce(local_eval_losses_t, op=dist.ReduceOp.SUM)
    global_eval_losses = local_eval_losses_t / world_size
    aggregated_loss = float(torch.mean(global_eval_losses))

    # ---- Persist results (rank 0 only) ------------------------------------
    if rank == 0:
        os.makedirs(CALC_FOLDER, exist_ok=True)
        history = {
            'step_history': step_history,
            'pure_loss_history': pure_loss_history,
            'loss_history': loss_history,
            'lr_history': lr_history,
            'dkan_weight_history': dkan_weight_history,
            'frobenius_weight_history': frobenius_weight_history
        }
        np.savez(os.path.join(CALC_FOLDER, 'history.npz'), **history)

        # Save the unwrapped student's weights (no "module." prefix).
        torch.save(ddp_model.module.state_dict(), os.path.join(CALC_FOLDER, 'model_weights.pth'))

        total_time = time.time() - start_time

        summary_path = os.path.join(CALC_FOLDER, 'summary.txt')
        with open(summary_path, 'w') as f:
            f.write(f"final loss: {aggregated_loss}\n")
            f.write(f"fitting time: {total_time}\n")
            # Each rank draws BATCH_SIZE samples per step.
            effective_batch_size = BATCH_SIZE * world_size
            f.write(f"number of GPUs: {world_size}\n")
            f.write(f"effective batch size: {effective_batch_size}\n")

        # Keep a copy of the specification next to the results.
        spec_filename = os.path.basename(CALC_SPECIFICATION_PATH)
        shutil.copy(CALC_SPECIFICATION_PATH, os.path.join(CALC_FOLDER, spec_filename))
        print(f"Copied specification file to {CALC_FOLDER}")

    dist.destroy_process_group()

###############################################################################
# Argument parsing and entry point (worker processes are launched by torchrun, not spawned here)
###############################################################################
def main():
    """
    Command-line entry point for a single training process.

    Parses ``--calc_specification_path`` (training YAML) and ``--calc_folder``
    (output directory), then runs ``ddp_train`` in the current process.

    This script does not spawn workers itself. Launch it with torchrun, which
    starts the processes and exports WORLD_SIZE, RANK, LOCAL_RANK,
    MASTER_ADDR and MASTER_PORT. ``ddp_train`` reads the rank and world size
    from that environment.

    Note: every process draws its own BATCH_SIZE samples, so the effective
    batch size is BATCH_SIZE * world_size. Learning rates tuned on a single
    GPU may need rescaling for multi-GPU runs.
    """
    parser = argparse.ArgumentParser(
        description='Train a neural network using a teacher network with DDP.'
    )

    # Both options are required string paths.
    cli_options = (
        ('--calc_specification_path', 'Path to the YAML configuration file'),
        ('--calc_folder', 'Output folder for calculation results'),
    )
    for flag, help_text in cli_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    ddp_train(parser.parse_args())


if __name__ == '__main__':
    main()
