import os
import torch
import logging
import torch.nn as nn
import torch.nn.functional as F

# SET OPTIMIZER
from torch.optim import Adam

# SET DATALOADER SAMPLERS
import torch.distributed as dist
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler


from src.utils.seed import set_seed

# Default compute device for non-distributed runs: the current CUDA device if any, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

from src.models.model_config import JCGConfig

# SET Model
from src.models.jcgel.jcgel_resnet_cls import JCGResNet_cls
from src.models.jcgel.jcgel_imbalance_cls import JCG_Imbalance_CLS

# Load Model
from src.utils.load_weight import load_best_model
# SET dataloader
from src.dataloaders.imbalance.imbalance_factory import color_rot_mnist_bias, color_rot_mnist_longtail
from src.dataloaders.classification.cls_factory import get_cls_factory_dataloader
from src.dataloaders.classification.aircraft import get_aircraft_loaders
from src.dataloaders.classification.eurosat import get_eurosat_loaders


# Set file dir
from src.file.file_cls import make_finetuning_files

# set info
from src.info.info_cls import write_info


# Registry mapping ``args.model_type`` to the model class Trainer instantiates.
Models = dict(
    cgeresnet=JCGResNet_cls,
    cgeconv=JCG_Imbalance_CLS,
)

def is_ddp():
    """Return True when torch.distributed is available and a process group is initialised."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
def is_main_process(args=None):
    """Return True on the rank-0 process, or always when not running distributed.

    Args:
        args: Unused; accepted for call-site compatibility.

    Returns:
        True if this process should perform "main process only" work.
    """
    # dist.get_rank() raises when no default process group exists, so a
    # single-process run is treated as the main process.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0

def attach_bn_guards(model):
    """Register pre-hooks that validate the input of every BatchNorm2d in *model*.

    Each hook raises RuntimeError before the layer runs when its input:
      * is not 4-D,
      * is not contiguous in the default memory format, or
      * has fewer than 2 values per channel while the layer computes batch
        statistics (training mode, or no running stats tracked).

    Args:
        model: Any nn.Module; submodules are searched recursively.

    Returns:
        List of hook handles; call ``handle.remove()`` to detach a guard.
    """
    def _guard(module, inputs):
        (x,) = inputs
        if x.dim() != 4:
            raise RuntimeError(f"Expected 4D NCHW, got {x.dim()}D before {module.__class__.__name__}")
        if not x.is_contiguous():
            raise RuntimeError(
                f"Non-contiguous before {module.__class__.__name__}: "
                f"shape={tuple(x.shape)}, stride={x.stride()}, dtype={x.dtype}"
            )
        N, C, H, W = x.shape
        # Only batch statistics need >1 value per channel; in eval mode with
        # running stats a single sample is a valid input.
        uses_batch_stats = module.training or (
            module.running_mean is None and module.running_var is None
        )
        if uses_batch_stats and N * H * W < 2:
            raise RuntimeError(
                f"Too few elems per channel for BN: N={N},C={C},H={H},W={W}"
            )

    return [
        m.register_forward_pre_hook(_guard)
        for m in model.modules()
        if isinstance(m, torch.nn.BatchNorm2d)
    ]

# Un-patched functional, captured at import time so the guard (even if
# installed repeatedly) always delegates to the real implementation.
_ORIG_F_BATCH_NORM = F.batch_norm


def install_functional_bn_guard():
    """Monkey-patch ``torch.nn.functional.batch_norm`` with a sanitising wrapper.

    The wrapper casts non-floating inputs (integer/bool) to float32 and makes
    non-contiguous inputs contiguous. It then calls the original
    ``F.batch_norm`` with all remaining arguments unchanged. Floating inputs keep
    their dtype.
    """
    # The first parameter keeps the name ``input`` so keyword callers
    # (``F.batch_norm(input=...)``) keep working.
    def _safe_bn(input, *args, **kwargs):
        x = input
        # Keeping float64 (and other floating types) avoids a dtype mismatch
        # with running stats / affine parameters of the same dtype.
        if not x.is_floating_point():
            x = x.float()
        if not x.is_contiguous():
            x = x.contiguous()
        return _ORIG_F_BATCH_NORM(x, *args, **kwargs)

    F.batch_norm = _safe_bn

def attach_bn_safety_hooks(model):
    """Register pre-hooks that sanitise the input of every 2-D / synchronised BatchNorm.

    Each hook casts non-floating inputs (integer/bool) to float32 and forces the
    default contiguous memory format before the layer runs. Floating inputs keep
    their dtype.

    Args:
        model: Any nn.Module (including a DistributedDataParallel wrapper);
            submodules are searched recursively.

    Returns:
        List of hook handles; call ``handle.remove()`` to detach a hook.
    """
    def _bn_safe_hook(module, inputs):
        (x,) = inputs
        # Keeping float64 avoids a dtype mismatch with float64 BN parameters.
        if not x.is_floating_point():
            x = x.float()
        # No-op for already NCHW-contiguous tensors; also re-lays-out
        # e.g. channels_last or transposed views.
        x = x.contiguous(memory_format=torch.contiguous_format)
        return (x,)

    # convert_sync_batchnorm swaps BatchNorm2d for SyncBatchNorm, which is not a
    # BatchNorm2d subclass, so both types must be matched explicitly.
    bn_types = (nn.BatchNorm2d, nn.SyncBatchNorm)
    return [
        m.register_forward_pre_hook(_bn_safe_hook)
        for m in model.modules()
        if isinstance(m, bn_types)
    ]



class Trainer:
    """Base trainer shared by the JCG classification / imbalance experiments.

    ``__init__`` does the following:
      * loads the datasets selected by ``args.task`` / ``args.dataset``;
      * derives the step/epoch schedule onto ``args``;
      * builds the model from ``Models``;
      * creates the output files;
      * for evaluation-only runs, restores the best checkpoint.

    ``setting`` builds the data loaders, optimizer and scheduler, and places the
    model on its device(s). Subclasses implement ``train``, ``eval``,
    ``qualitative``, ``analysis`` and ``run``.
    """

    def __init__(self, args):
        set_seed(args)
        self.args = args

        self.trainset, self.testset = self._load_datasets()

        # Fixed values that override anything passed in on ``args``.
        self.args.steps = 500000
        self.args.dense_dim = [256, 256]
        self._configure_schedule()

        # Set model
        self.config = JCGConfig(self.args)
        self.model = Models[self.args.model_type](self.config)
        self.model.init_weights()
        self.save_file, self.run_file, self.output_dir = make_finetuning_files(self.args)

        # ONLY FOR EVALUATION: restore the best checkpoint when not training.
        if not args.do_train and (args.do_eval or args.do_analysis):
            sub_model, path = load_best_model(args, self.save_file)
            if os.path.exists(path):
                self.model.load_state_dict(sub_model, strict=False)

        # Populated by ``setting()`` / the training loop.
        self.train_sampler, self.train_dataloader = None, None
        self.test_sampler, self.test_dataloader = None, None
        self.optimizer, self.scheduler = None, None
        if self.args.model_type == "factorvae":
            self.disc_optimizer, self.disc_scheduler = None, None
        self.global_step = 0

    def _load_datasets(self):
        """Return ``(trainset, testset)`` for ``args.task`` / ``args.dataset``.

        Raises:
            ValueError: if ``args.task`` is not 'longtail', 'bias' or 'cls'.
        """
        task = self.args.task
        if task == 'longtail':
            return color_rot_mnist_longtail()
        if task == 'bias':
            return color_rot_mnist_bias()
        if task == 'cls':
            if self.args.dataset == 'eurosat':
                return get_eurosat_loaders()
            if self.args.dataset == 'aircraft':
                return get_aircraft_loaders()
            return get_cls_factory_dataloader(name=self.args.dataset)
        raise ValueError(f"Unknown task {task!r}; expected 'longtail', 'bias' or 'cls'")

    def _configure_schedule(self):
        """Derive batch size, total steps, epoch count and logging interval onto ``args``.

        Sets ``n_gpu``, ``train_batch_size``, ``t_total``, ``logging_steps`` and,
        when ``args.max_steps`` is non-zero, ``num_epoch``. A non-zero
        ``max_steps`` takes precedence: ``num_epoch`` becomes the number of
        epochs (rounded up) needed to reach it.

        Raises:
            ValueError: if ``max_steps`` is set but the training set holds fewer
                samples than one ``train_batch_size`` batch.
        """
        args = self.args
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
        args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

        # Whole batches of train_batch_size per pass over the training set.
        steps_per_epoch = len(self.trainset) // args.train_batch_size
        if args.max_steps == 0:
            t_total = steps_per_epoch * args.num_epoch
        else:
            if steps_per_epoch == 0:
                raise ValueError(
                    f"Training set ({len(self.trainset)} samples) is smaller than "
                    f"train_batch_size ({args.train_batch_size}); cannot derive "
                    f"num_epoch from max_steps"
                )
            t_total = args.max_steps
            full_epochs, remainder = divmod(args.max_steps, steps_per_epoch)
            # Round up so a trailing partial epoch is still run.
            args.num_epoch = full_epochs + (1 if remainder else 0)
        args.logging_steps = t_total // args.num_epoch
        args.t_total = t_total

    def setting(self):
        """Build loaders, optimizer and scheduler, and place the model on its device(s).

        Optimizer/scheduler state is resumed from ``save_file`` when both
        ``optimizer.pt`` and ``scheduler.pt`` exist there.

        With ``args.local_rank != -1``, the model is moved to that GPU,
        converted to SyncBatchNorm (when a process group is initialised) and
        wrapped in DistributedDataParallel. Otherwise it is moved to the
        module-level ``device``.
        """
        set_seed(self.args)

        if self.args.local_rank == -1:
            self.train_sampler = RandomSampler(self.trainset)
        else:
            self.train_sampler = DistributedSampler(self.trainset)
        self.train_dataloader = DataLoader(
            self.trainset,
            sampler=self.train_sampler,
            batch_size=self.args.per_gpu_train_batch_size,
            drop_last=True,
            num_workers=4,
            pin_memory=True,
        )

        if is_ddp():
            self.test_sampler = DistributedSampler(self.testset, shuffle=False, drop_last=False)
        else:
            self.test_sampler = SequentialSampler(self.testset)
        self.test_dataloader = DataLoader(
            self.testset,
            sampler=self.test_sampler,
            batch_size=self.args.test_batch_size,
            drop_last=False,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
        )

        distributed = self.args.local_rank != -1

        # Move the model BEFORE building the optimizer. Optimizer.load_state_dict
        # casts resumed state onto each parameter's *current* device. Loading
        # while parameters are still on the CPU would leave Adam's moment
        # buffers on the CPU after the model moves to the GPU.
        if distributed:
            torch.cuda.set_device(self.args.local_rank)  # default device for this process
            self.model = self.model.to(memory_format=torch.contiguous_format)
            self.model.to(self.args.local_rank)
        else:
            self.model.to(device)

        self.optimizer = Adam(self.model.parameters(),
                              lr=self.args.lr_rate,
                              betas=(0.9, 0.999),
                              weight_decay=self.args.weight_decay,
                              )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.args.num_epoch)

        # Resume optimizer/scheduler state if a previous run saved both.
        optimizer_path = os.path.join(self.save_file, "optimizer.pt")
        scheduler_path = os.path.join(self.save_file, "scheduler.pt")
        if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
            # map_location="cpu" stops every rank from materialising the
            # checkpoint on the GPU it was saved from. load_state_dict then moves
            # the optimizer state onto the parameters' device.
            self.optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
            self.scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))

        # SyncBatchNorm conversion and DDP wrapping.
        if distributed:
            if is_ddp():
                self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=False,
            )
            attach_bn_safety_hooks(self.model)

    def train(self):
        """Run training. Must be implemented by subclasses."""
        raise NotImplementedError

    def eval(self):
        """Run evaluation. Must be implemented by subclasses."""
        raise NotImplementedError

    def qualitative(self):
        """Produce qualitative results. Must be implemented by subclasses."""
        raise NotImplementedError

    def analysis(self):
        """Run post-hoc analysis. Must be implemented by subclasses."""
        raise NotImplementedError

    def save_results(self, best_results, last_results):
        """Write best/last results via ``write_info``.

        When both training and evaluation ran, results go to ``results.csv``.
        Otherwise they go to ``eval_only_results.csv``. Either file is placed in
        ``output_dir``.
        """
        if self.args.do_train and self.args.do_eval:
            self.args.results_file = os.path.join(self.output_dir, "results.csv")
        else:
            self.args.results_file = os.path.join(self.output_dir, "eval_only_results.csv")

        write_info(self.args, best_results, last_results)

    def run(self):
        """Top-level entry point. Must be implemented by subclasses."""
        raise NotImplementedError


















