import os
import time
import random
import numpy as np
from datetime import timedelta
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from utils import *
from utils.load_train_setting import *  
from network.Network import Network

def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also enables cuDNN autotuning (``torch.backends.cudnn.benchmark``), which
    trades bit-exact reproducibility of convolution kernels for speed.

    Args:
        seed: value passed unchanged to every seeding function.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.benchmark = True


def ensure_dir(path: str):
    """Create directory *path*, including missing parents; no-op if it exists."""
    os.makedirs(name=path, exist_ok=True)


def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is built in AND a process group exists."""
    if not dist.is_available():
        return False
    return dist.is_initialized()

def get_rank():
    """Return this process's global rank, or 0 outside a distributed run."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()

def get_world_size():
    """Return the number of processes in the group, or 1 outside a distributed run."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()

def is_main_process():
    """True on rank 0, which is also the only rank in a non-distributed run."""
    rank = get_rank()
    return rank == 0

def all_reduce_sum_scalar(v: float, device) -> float:
    """Sum a Python scalar across all ranks and return the result as a float.

    Without an initialised process group this simply returns ``float(v)``.

    The value travels in a float64 tensor. The callers reduce running sums of
    per-batch metrics and batch counts, and float32's 24-bit mantissa loses
    precision on such sums. NCCL and Gloo both support float64 all_reduce.

    Args:
        v: local value to contribute.
        device: device on which to allocate the communication tensor
            (must be a CUDA device when the backend is NCCL).

    Returns:
        The sum of ``v`` over all ranks.
    """
    t = torch.tensor([float(v)], device=device, dtype=torch.float64)
    if is_dist_avail_and_initialized():
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return float(t.item())

def all_reduce_meter_sums(meter: dict, device) -> dict:
    """Sum every value of *meter* across ranks and return a new dict of floats.

    All values are packed into a single float64 tensor, so the whole meter
    costs one collective instead of one blocking collective per key.

    Every rank must pass a meter with the same keys in the same insertion
    order, because values are matched by position in the packed tensor.

    Args:
        meter: mapping of metric name to a local (per-rank) sum.
        device: device for the communication tensor (CUDA for NCCL).

    Returns:
        A new dict with the same keys, each holding the global sum as a float.
        An empty meter returns ``{}`` without issuing any collective.
    """
    keys = list(meter)
    if not keys:
        return {}
    packed = torch.tensor(
        [float(meter[k]) for k in keys], device=device, dtype=torch.float64
    )
    if is_dist_avail_and_initialized():
        dist.all_reduce(packed, op=dist.ReduceOp.SUM)
    return {k: float(x) for k, x in zip(keys, packed.tolist())}


def _psnr_stable_last_k(hist, k: int = 5, tol: float = 1.0) -> bool:
    """Return True when the last *k* values of *hist* span at most *tol*.

    Returns False until at least *k* values have been recorded.
    """
    if len(hist) < k:
        return False
    window = hist[-k:]
    return (max(window) - min(window)) <= tol


def _reduce_meter(meter, n_local, device, world_size):
    """Sum a meter and its batch count across ranks; identity when world_size == 1."""
    if world_size > 1:
        # Meter first, then count: every rank must issue collectives in the same order.
        meter = all_reduce_meter_sums(meter, device)
        n_total = all_reduce_sum_scalar(n_local, device)
        return meter, n_total
    return meter, n_local


def _write_epoch_log(log_path, header, meter, n_batches):
    """Append *header* plus per-batch averages of *meter* to *log_path* and echo it.

    Each meter value is a sum over *n_batches* batches and is divided by
    ``max(1, n_batches)`` so an empty epoch does not divide by zero.
    """
    denom = max(1, n_batches)
    content = header + "".join(f"{k}={v/denom:.6f}," for k, v in meter.items()) + "\n"
    with open(log_path, "a") as f:
        f.write(content)
    print(content, end="")


def _train_one_epoch(network, loader, device, epoch):
    """Run one training epoch on this rank.

    Returns:
        (meter, n_batches): per-key sums of the metrics that ``network.train``
        reported, and the number of local batches processed.
    """
    meter = {
        "acc": 0.0,
        "encoder_weight": 0.0,
        "psnr": 0.0,
        "g_loss": 0.0,
        "g_loss_on_encoder": 0.0,
        "g_loss_on_decoder": 0.0,
    }
    n_batches = 0
    main_proc = is_main_process()

    network.encoder_decoder.train()
    pbar = tqdm(loader, total=len(loader), desc=f"Train Epoch {epoch}", disable=not main_proc)
    for batch_images in pbar:
        images = batch_images.to(device, non_blocking=True)
        # Random binary payload (0.0 / 1.0) per image.
        messages = torch.randint(
            0, 2, (images.size(0), message_length), device=device, dtype=torch.float32
        )
        result = network.train(epoch, images, messages)

        # Only accumulate the metrics the network actually reported.
        for key in meter:
            if key in result:
                meter[key] += float(result[key])
        n_batches += 1

        if main_proc:
            denom = max(1, n_batches)
            pbar.set_postfix({
                "acc": f"{meter['acc']/denom:.4f}",
                "psnr": f"{meter['psnr']/denom:.2f}",
                "loss": f"{meter['g_loss']/denom:.4f}",
            })

    network.flush_accum()
    return meter, n_batches


def _validate(network, loader, device, epoch):
    """Run one validation pass on this rank without gradients.

    Returns:
        (meter, n_batches): per-key metric sums from ``network.validation``
        and the number of local batches processed.
    """
    meter = {
        "acc": 0.0,
        "psnr": 0.0,
        "g_loss": 0.0,
        "g_loss_on_encoder": 0.0,
        "g_loss_on_decoder": 0.0,
    }
    n_batches = 0

    network.encoder_decoder.eval()
    with torch.no_grad():
        pbar = tqdm(loader, total=len(loader), desc=f"Val   Epoch {epoch}", disable=not is_main_process())
        for batch_images in pbar:
            images = batch_images.to(device, non_blocking=True)
            messages = torch.randint(
                0, 2, (images.size(0), message_length), device=device, dtype=torch.float32
            )
            result, _ = network.validation(images, messages)
            for key in meter:
                if key in result:
                    meter[key] += float(result[key])
            n_batches += 1
    return meter, n_batches


def main():
    """Train the watermarking ``Network`` on single or multiple GPUs.

    Topology is read from the LOCAL_RANK / RANK / WORLD_SIZE environment
    variables. When WORLD_SIZE > 1 an NCCL process group is created.

    Hyperparameters and paths (``noise_layers``, ``lr``, ``batch_size``,
    ``epoch_number``, ``train_continue*``, ``result_folder``, ...) are
    module-level names, presumably provided by
    ``from utils.load_train_setting import *`` (TODO: confirm).

    Per epoch, the function:
    - trains and validates;
    - logs globally averaged metrics on rank 0;
    - saves a checkpoint every 5 epochs;
    - may multiply ``network.encoder_weight`` by 10 once train accuracy exceeds
      0.999. The new value is broadcast from rank 0 so all ranks agree.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1 and not is_dist_avail_and_initialized():
        # Bind this process to its GPU before creating the NCCL group so the
        # communicators (and later barriers) use the intended device.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            timeout=timedelta(minutes=30),
        )

    # Per-rank seed so ranks draw different random messages.
    set_seed(42 + rank)
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")

    network = Network(
        noise_layers=noise_layers,
        device=device,
        lr=lr,
        accum_steps=accum_steps,
        use_ddp=(world_size > 1),
        use_amp=False,
    )

    train_dataset = CoCoDataset(os.path.join(dataset_path, "train"), H, W)
    val_dataset = CoCoDataset(os.path.join(dataset_path, "test"), H, W)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True) if world_size > 1 else None
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False) if world_size > 1 else None

    num_workers = 8
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=(num_workers > 0),
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        sampler=val_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        persistent_workers=(num_workers > 0),
    )

    start_epoch = 0
    if train_continue:
        resume_dir = os.path.join("results", train_continue_path, "models")
        resume_ckpt = os.path.join(resume_dir, f"EC_{train_continue_epoch}.pth")
        if is_main_process():
            print(f"=> Resume from {resume_ckpt}")
        network.load_model(resume_ckpt)
        # Continue counting after the restored epoch. Otherwise sampler
        # shuffling, the 5-epoch checkpoint cadence and EC_{epoch} file names
        # would all restart from 0.
        start_epoch = int(train_continue_epoch) + 1
    total_epochs = start_epoch + epoch_number

    result_dir = result_folder.rstrip("/")
    if is_main_process():
        ensure_dir(result_dir)
        ensure_dir(os.path.join(result_dir, "models"))
        ensure_dir(os.path.join(result_dir, "samples"))
    if is_dist_avail_and_initialized():
        # Other ranks must not proceed until rank 0 has created the directories.
        dist.barrier()

    train_log_path = os.path.join(result_dir, "train_log.txt")
    val_log_path = os.path.join(result_dir, "val_log.txt")

    if is_main_process():
        print("\nStart training:\n")

    psnr_hist = []  # epoch-averaged train PSNR, recorded on rank 0 only

    for epoch in range(start_epoch, total_epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        # ---- train ----
        epoch_start_time = time.time()
        train_meter, n_train_local = _train_one_epoch(network, train_loader, device, epoch)
        train_meter, n_train = _reduce_meter(train_meter, n_train_local, device, world_size)

        if is_main_process():
            train_elapsed = int(time.time() - epoch_start_time)
            psnr_hist.append(float(train_meter["psnr"] / max(1, n_train)))
            _write_epoch_log(
                train_log_path, f"Epoch {epoch} (train) : {train_elapsed}s\n", train_meter, n_train
            )

        # ---- validation ----
        val_start_time = time.time()
        val_meter, n_val_local = _validate(network, val_loader, device, epoch)
        val_meter, n_val = _reduce_meter(val_meter, n_val_local, device, world_size)

        if is_main_process():
            val_elapsed = int(time.time() - val_start_time)
            _write_epoch_log(
                val_log_path, f"Epoch {epoch} (val) : {val_elapsed}s\n", val_meter, n_val
            )

        # ---- checkpoint ----
        if is_main_process() and ((epoch + 1) % 5 == 0):
            model_dir = os.path.join(result_dir, "models")
            ensure_dir(model_dir)
            ckpt_path = os.path.join(model_dir, f"EC_{epoch}.pth")
            network.save_model(ckpt_path)
            print(f"Saved: {ckpt_path}")

        # ---- encoder-weight schedule (decided on rank 0, broadcast to all) ----
        new_enc_w = float(network.encoder_weight)
        if is_main_process():
            train_acc = train_meter["acc"] / max(1, n_train)
            # NOTE(review): psnr_is_stable is computed but not consulted below;
            # confirm whether the bump should also require a stable train PSNR.
            psnr_is_stable = _psnr_stable_last_k(psnr_hist, k=5, tol=1.0)  # noqa: F841
            if train_acc > 0.999:
                new_enc_w = network.encoder_weight * 10.0

        if is_dist_avail_and_initialized():
            t_w = torch.tensor([new_enc_w], device=device, dtype=torch.float32)
            dist.broadcast(t_w, src=0)
            new_enc_w = float(t_w.item())

        network.encoder_weight = new_enc_w

        if is_dist_avail_and_initialized():
            dist.barrier()

    if is_main_process():
        print("\nTraining finished.\n")

    if is_dist_avail_and_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()