import os
import json
import torch
import torch.nn as nn
from decentralized_opt import log
import experiment_utils as utils

# Use double precision as the default floating-point dtype for new tensors.
torch.set_default_dtype(torch.float64)

parser = utils.get_parser()
args = utils.parse_args(parser)


# === Initialize the distributed environment ===
# Each process restricts itself to one device through CUDA_VISIBLE_DEVICES.
# NOTE(review): this presumably has to happen before CUDA is first
# initialized in this process -- confirm nothing imported above touches CUDA.
if args.cpu:
    # Hide every GPU so all work stays on the CPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    args.device = torch.device("cpu")
elif args.single_gpu_sequential:
    # Pin the process to the first GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    args.device = torch.device("cuda:0")
else:
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible is not None:
        # Pick one of the already-visible devices, round-robin by rank.
        devices = visible.split(',')
        os.environ['CUDA_VISIBLE_DEVICES'] = devices[args.rank % len(devices)]
    else:
        # NOTE(review): the modulus 8 presumably reflects 8 GPUs per node -- confirm.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank % 8)
    args.device = torch.device("cuda")

# Fallback values for runs whose args carry no rank / world_size attribute.
rank = getattr(args, 'rank', 0)
world_size = getattr(args, 'world_size', 1)
if args.device.type == 'cuda':
    log.info(f"[Rank {args.rank}] CUDA device name: {torch.cuda.get_device_name(args.device)}")

utils.init(args)


# === Initialize the model and training set ===
# Choose the loader first so the (identical) argument list is written once.
if args.dataset == 'mnist':
    load_dataset = utils.load_mnist
elif args.dataset == 'fashionmnist':
    load_dataset = utils.load_fashion_mnist
else:
    raise NotImplementedError(f"Dataset {args.dataset} not supported.")

train_loader, val_loader, full_train_loader = load_dataset(
    args.rank, args.world_size, args.batch_size, args.val_batch_size, sort=args.sort)

# Log which labels this process's training set contains, and how many samples.
log.info(train_loader.dataset.targets.unique())
log.info(train_loader.dataset.targets.shape[0])

classes = list(range(10))
criterion = nn.CrossEntropyLoss()

# Per-run results, keyed by run index; averaged after all runs finish.
all_train_res, all_val_res, all_training_res = {}, {}, {}

# Maps the --model choice to the name of the network class in experiment_utils.
_MODEL_CLASS_NAMES = {
    'net': 'Net',
    'biggernet': 'BiggerNet',
    'kerasnet': 'KerasNet',
    'csdnnet': 'CSDNNET',
    'resnet': 'ResNet',
    'lenet': 'LeNet',
}

# Optimizers for which this script derives a results directory.
_OPTIMIZERS_WITH_RESULTS_DIR = ('DNSGD', 'DSGD', 'DSGT', 'DNASA')

# Optimizers whose init() is called after one gradient pass over the data.
_OPTIMIZERS_NEEDING_INIT = ('DNSGD', 'DSGT', 'DNASA')


def _build_model(name, device):
    """Instantiate the network selected by *name* and move it to *device*.

    Raises:
        NotImplementedError: if *name* is not one of the supported models.
    """
    class_name = _MODEL_CLASS_NAMES.get(name)
    if class_name is None:
        raise NotImplementedError(f"Model {name} not supported.")
    return getattr(utils, class_name)().to(device)


def _accumulate_gradients_and_init(model, criterion, optimizer, loader, device):
    """Run one backward pass over *loader*, then call ``optimizer.init()``.

    No optimizer step or ``zero_grad`` happens between batches, so the
    gradients of every batch accumulate on the model before ``init()`` runs.
    Gradients are cleared again afterwards.
    NOTE(review): ``optimizer.init()`` presumably reads those accumulated
    gradients -- confirm against the optimizer implementation.
    """
    model.train()
    model.zero_grad()
    for data, target in loader:
        target = target.to(device=device, non_blocking=True)
        data = data.to(device=device, non_blocking=True)
        loss = criterion(model.module(data), target)
        loss.backward()
    optimizer.init()
    model.zero_grad()


def _results_dir(args):
    """Return the results directory path encoding the experiment settings.

    The layout is the same for every optimizer except that DNASA paths carry
    no learning-rate tag and DNSGD paths additionally carry ``K`` / ``Khat``.
    """
    head = (
        f"results/{args.optimizer}/"
        f"{args.optimizer}_{args.dataset}_sort({args.sort})_{args.model}_"
        f"{args.world_size}_agents_{args.graph_type}_graph_"
    )
    lr_tag = "" if args.optimizer == 'DNASA' else f"lr_{args.lr:.6f}_"
    k_tag = f"K_{args.K:d}_Khat_{args.Khat:d}_" if args.optimizer == 'DNSGD' else ""
    return (
        f"{head}{lr_tag}"
        f"epochs_{args.epochs:d}_batch_size_{args.batch_size:d}_"
        f"{k_tag}runs_{args.runs:d}"
    )


for run in range(args.runs):
    if args.rank == 0:
        log.info(f"[Rank {args.rank}] ========================= Starting run {run} =========================")

    # A fresh model (and wrapped optimizer) is built for every run.
    model = _build_model(args.model, args.device)
    model, optimizer = utils.wrap_model(model, args)
    log.info('Model is on %s, size %d', next(model.parameters()).device, model.flat_parameters.shape[0])

    # === Initialize the optimizer ===
    if args.optimizer in _OPTIMIZERS_NEEDING_INIT:
        _accumulate_gradients_and_init(model, criterion, optimizer, train_loader, args.device)

    if args.optimizer in _OPTIMIZERS_WITH_RESULTS_DIR:
        args.results_dir = _results_dir(args)
    elif getattr(args, 'results_dir', None) is None:
        # Fail before training: with no results directory the script could
        # only crash when saving, after every run had already completed.
        raise NotImplementedError(
            f"Optimizer {args.optimizer} not supported: no results directory is defined for it."
        )

    # Training process, saving results
    train_res, val_res, training_res = utils.train(
        model, criterion, optimizer, train_loader, args,
        val_loader=val_loader,
        val_train_loader=full_train_loader,
        single_gpu_sequential=args.single_gpu_sequential
    )

    # === collect results ===
    all_train_res[run] = train_res
    all_val_res[run] = val_res
    all_training_res[run] = training_res

# === after all runs finish ===
def average_res(all_res_dict):
    """Average per-run result rows element-wise across runs.

    Args:
        all_res_dict: ``{run_id: [(x, y1, y2, ...), ...]}`` -- one sequence of
            rows per run.

    Returns:
        ``[(x, avg_y1, avg_y2, ...), ...]`` where ``x`` is taken from the
        first run and each ``avg_yk`` is the mean of column ``k`` over all
        runs. An empty dict yields ``[]``.

    Runs of different lengths are averaged over their common prefix only, so
    every returned row is a mean over *all* runs. (Previously a later run that
    was shorter than the first raised ``IndexError``, while extra rows in
    longer later runs were silently ignored.)
    """
    if not all_res_dict:
        return []

    averaged = []
    # zip() stops at the shortest run, which gives the common prefix.
    for rows in zip(*all_res_dict.values()):
        x = rows[0][0]
        columns = list(zip(*rows))  # transpose: one tuple per column
        means = [sum(col) / len(col) for col in columns[1:]]
        averaged.append((x, *means))
    return averaged

# === compute averages ===
# Collapse the per-run results of each series into one averaged series.
train_res, val_res, training_res = map(
    average_res, (all_train_res, all_val_res, all_training_res)
)

# === save final averaged results ===
# Every process writes into its own rank-specific subdirectory.
rank_dir = os.path.join(args.results_dir, f"rank_{args.rank}")
args.results_dir = rank_dir
os.makedirs(rank_dir, exist_ok=True)

def convert(obj):
    """Recursively turn *obj* into plain lists/dicts.

    Anything exposing ``tolist()`` is replaced by the result of that call.
    Dict values and list/tuple items are converted recursively (tuples become
    lists). Every other object is returned unchanged.
    """
    if hasattr(obj, "tolist"):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(convert, obj))
    return obj

# Convert each averaged series to plain containers before serializing.
results = dict(
    train=convert(train_res),
    val=convert(val_res),
    training=convert(training_res),
)

save_path = os.path.join(args.results_dir, "train_val_results.json")
with open(save_path, "w") as out_file:
    json.dump(results, out_file, indent=4)

log.info("Process %d saved averaged results to %s", args.rank, save_path)
