import logging

import torch
import torch.optim as optim

from robustbench.data import load_imagenetc, load_cifar10c, load_cifar100c
from robustbench.model_zoo.enums import ThreatModel
from robustbench.utils import load_model
from robustbench.utils import clean_accuracy as accuracy

import tent
import cotta
import vida
import vida_mae
import mimic

from conf import cfg, load_cfg_fom_args

import vit_patch                    # <-- import FIRST

import torch.nn as nn
import operators
from collections import OrderedDict
import numpy as np
import random
from daisy_layer_tpu import DaisyTorch

from timm.models import load_checkpoint
from pathlib import Path
import os

# Module-level logger shared by evaluate() and the setup_* helpers in this file.
logger = logging.getLogger(__name__)



# Datasets whose 32x32 images are upsampled to the ViT input resolution.
_CIFAR_DATASETS = ("cifar10", "cifar100")


def evaluate(description):
    """Evaluate the configured model / adaptation method on corrupted data.

    Steps:
      1. Parse the configuration via ``load_cfg_fom_args(description)``.
      2. Delete stale ``*.pt`` files under ``<CKPT_DIR>/<dataset>/corruptions``.
      3. Load the RobustBench base model and wrap it with the test-time
         adaptation method named by ``cfg.MODEL.ADAPTATION``.
      4. For every severity in ``cfg.CORRUPTION.SEVERITY`` and corruption in
         ``cfg.CORRUPTION.TYPE``:
         - load the split;
         - upsample CIFAR images to the model input size;
         - log the error rate.

    Args:
        description: Description string forwarded to ``load_cfg_fom_args``.

    Returns:
        List of error rates (``1 - accuracy``), one per (severity, corruption)
        pair, in evaluation order.

    Raises:
        ValueError: If the adaptation method or the dataset name is unknown.
    """
    args = load_cfg_fom_args(description)

    _remove_stale_corruption_checkpoints()

    # NOTE(review): this variable is assumed to stop RobustBench from
    # re-downloading checkpoints; confirm the installed robustbench honors it.
    os.environ["ROBUSTBENCH_SKIP_DOWNLOAD"] = "1"

    base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR,
                            cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda()
    model = _setup_adaptation(args, base_model)

    errors = []
    for severity in cfg.CORRUPTION.SEVERITY:
        for i_c, corruption_type in enumerate(cfg.CORRUPTION.TYPE):
            # Reset only before the first corruption of each severity; later
            # corruptions continue adapting from the current state.
            if i_c == 0:
                try:
                    model.reset()
                except Exception as exc:  # best effort: e.g. the plain source model may lack reset()
                    logger.warning("not resetting model: %s", exc)
            else:
                logger.warning("not resetting model")

            x_test, y_test = _load_corrupted_split(
                args.data_set, severity, corruption_type)

            # CIFAR images are upsampled to the ViT input size; ImageNet is left as-is.
            if args.data_set in _CIFAR_DATASETS:
                side = _input_resolution(args.data_set)
                x_test = torch.nn.functional.interpolate(
                    x_test, size=(side, side),
                    mode='bilinear', align_corners=False
                )

            acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE, device='cuda')
            err = 1. - acc
            errors.append(err)
            logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}")

    if errors:
        logger.info(f"mean error over {len(errors)} splits: {sum(errors) / len(errors):.2%}")
    return errors


def _remove_stale_corruption_checkpoints():
    """Delete ``*.pt`` files in ``<CKPT_DIR>/<CORRUPTION.DATASET>/corruptions``."""
    corrupt_dir = Path(cfg.CKPT_DIR) / cfg.CORRUPTION.DATASET / "corruptions"
    if not corrupt_dir.exists():
        return
    for ckpt in corrupt_dir.glob("*.pt"):
        # glob can also match directories named "*.pt"; only unlink files.
        if ckpt.is_file():
            print(f"[MIMIC] removing stale checkpoint: {ckpt}")
            ckpt.unlink()


def _input_resolution(dataset_name):
    """Return the square input side length used for ``dataset_name``.

    Raises:
        ValueError: For datasets other than cifar10, cifar100 and imagenet.
    """
    if dataset_name in _CIFAR_DATASETS:
        return 384
    if dataset_name == "imagenet":
        return 224
    raise ValueError(f"data set name error: {dataset_name!r}")


def _setup_adaptation(args, base_model):
    """Wrap ``base_model`` with the method named by ``cfg.MODEL.ADAPTATION``.

    Raises:
        ValueError: If the method name is not recognised (previously this left
            ``model`` unbound and failed later with a confusing error).
    """
    method = cfg.MODEL.ADAPTATION
    if method == "source":
        logger.info("test-time adaptation: NONE")
        return setup_source(args, base_model)
    if method == "tent":
        logger.info("test-time adaptation: TENT")
        return setup_tent(base_model)
    if method == "cotta":
        logger.info("test-time adaptation: CoTTA")
        return setup_cotta(base_model)
    if method == "vida":
        logger.info("test-time adaptation: ViDA")
        return setup_vida(args, base_model)
    if method == "vida_mae":
        logger.info("test-time adaptation: ViDA_MAE")
        return setup_vida_mae(args, base_model)
    if method == "mimic":
        logger.info("test-time adaptation: MIMIC")
        # Raises ValueError for unknown datasets instead of exit(), which
        # terminated the process with a success status.
        side = _input_resolution(args.data_set)
        return setup_mimic(args, base_model, side, side)
    raise ValueError(f"Unknown test-time adaptation method: {method!r}")


def _load_corrupted_split(dataset_name, severity, corruption):
    """Load ``cfg.CORRUPTION.NUM_EX`` examples of one corruption at one severity.

    Returns whatever the RobustBench loader returns (unpacked by the caller
    as ``x, y``). The loaders are called with shuffle=False (4th positional arg).

    Raises:
        ValueError: For datasets other than imagenet, cifar10 and cifar100.
    """
    loaders = {
        "imagenet": load_imagenetc,
        "cifar10": load_cifar10c,
        "cifar100": load_cifar100c,
    }
    try:
        loader = loaders[dataset_name]
    except KeyError:
        raise ValueError(f"Unknown dataset {dataset_name}") from None
    return loader(cfg.CORRUPTION.NUM_EX, severity, cfg.DATA_DIR, False, [corruption])


def setup_source(args, model):
    """Set up the baseline source model without adaptation.

    If ``args.checkpoint`` is set, its weights are first loaded into ``model``
    non-strictly:
      * ``.npz`` files are loaded via timm's ``load_checkpoint``.
      * Any other file is read with ``torch.load``. The weights are unwrapped
        from a ``"model"`` or ``"state_dict"`` entry if present, and a leading
        ``module.`` prefix is stripped from each key that has one.

    Args:
        args: Parsed arguments; only ``args.checkpoint`` is read.
        model: The network to evaluate.

    Returns:
        ``model``, switched to eval mode.
    """
    if args.checkpoint:
        if args.checkpoint.endswith(".npz"):
            # timm converts the NPZ into a state dict internally.
            load_checkpoint(model, args.checkpoint, strict=False)
            print(f"[MIMIC] loaded NPZ checkpoint {args.checkpoint}")
        else:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state = torch.load(args.checkpoint, map_location='cpu')
            # Unwrap framework containers that nest the actual weights
            # (EMA-style "model", Lightning-style "state_dict").
            for wrapper_key in ("model", "state_dict"):
                if wrapper_key in state:
                    state = state[wrapper_key]
                    break
            # Strip the "module." prefix (presumably from a DataParallel/DDP
            # wrapped model) key by key. This also handles empty and partially
            # prefixed dicts, where next(iter(state)) used to raise or mislead.
            prefix = "module."
            state = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state.items()
            }
            incompatible = model.load_state_dict(state, strict=False)
            # strict=False ignores mismatches silently; surface them.
            if incompatible.missing_keys or incompatible.unexpected_keys:
                logger.warning(
                    "checkpoint %s: %d missing / %d unexpected keys ignored",
                    args.checkpoint,
                    len(incompatible.missing_keys),
                    len(incompatible.unexpected_keys),
                )
            print(f"[MIMIC] loaded Torch checkpoint {args.checkpoint}")

    model.eval()
    logger.info("model for evaluation: %s", model)
    return model

def setup_tent(model):
    """Wrap ``model`` with tent test-time adaptation.

    The model is reconfigured by ``tent.configure_model`` (training mode plus
    feature modulation by batch statistics). The parameters reported by
    ``tent.collect_params`` are handed to an optimizer from
    ``setup_optimizer``. Everything is wrapped in ``tent.Tent`` with
    ``steps=cfg.OPTIM.STEPS`` and ``episodic=cfg.MODEL.EPISODIC``.

    Returns:
        The ``tent.Tent`` wrapper.
    """
    configured = tent.configure_model(model)
    trainable, trainable_names = tent.collect_params(configured)
    optimizer = setup_optimizer(trainable)

    wrapped = tent.Tent(
        configured,
        optimizer,
        steps=cfg.OPTIM.STEPS,
        episodic=cfg.MODEL.EPISODIC,
    )

    logger.info("model for adaptation: %s", configured)
    logger.info("params for adaptation: %s", trainable_names)
    logger.info("optimizer for adaptation: %s", optimizer)
    return wrapped


def setup_optimizer(params):
    """Build the optimizer used for tent / CoTTA test-time adaptation.

    The algorithm is selected by ``cfg.OPTIM.METHOD``:
      * ``'Adam'``: lr ``cfg.OPTIM.LR``, betas ``(cfg.OPTIM.BETA, 0.999)``,
        weight decay ``cfg.OPTIM.WD``.
      * ``'SGD'``: lr ``cfg.OPTIM.LR``, Nesterov momentum fixed at 0.9, no
        dampening, weight decay ``cfg.OPTIM.WD``.

    Any gradient optimizer could drive entropy minimization; Adam or
    SGD+momentum are recommended. Prefer the end-of-training settings when
    known, otherwise start from a low learning rate (around 0.001). Tuning the
    learning rate and batch size usually matters most.

    Raises:
        NotImplementedError: For any other ``cfg.OPTIM.METHOD``.
    """
    method = cfg.OPTIM.METHOD
    if method not in ('Adam', 'SGD'):
        raise NotImplementedError

    if method == 'SGD':
        return optim.SGD(
            params,
            lr=cfg.OPTIM.LR,
            momentum=0.9,
            dampening=0,
            weight_decay=cfg.OPTIM.WD,
            nesterov=True,
        )

    return optim.Adam(
        params,
        lr=cfg.OPTIM.LR,
        betas=(cfg.OPTIM.BETA, 0.999),
        weight_decay=cfg.OPTIM.WD,
    )

def setup_cotta(model):
    """Set up CoTTA test-time adaptation.

    The model is reconfigured by ``cotta.configure_model``. The parameters
    returned by ``cotta.collect_params`` are optimized with the optimizer from
    ``setup_optimizer``. The result is wrapped in ``cotta.CoTTA`` with
    ``steps=cfg.OPTIM.STEPS`` and ``episodic=cfg.MODEL.EPISODIC``.

    Returns:
        The ``cotta.CoTTA`` wrapper.
    """
    model = cotta.configure_model(model)
    params, param_names = cotta.collect_params(model)
    optimizer = setup_optimizer(params)
    cotta_model = cotta.CoTTA(model, optimizer,
                              steps=cfg.OPTIM.STEPS,
                              episodic=cfg.MODEL.EPISODIC)
    # Plain strings with lazy %-args (the f-prefix had no placeholders).
    logger.info("model for adaptation: %s", model)
    logger.info("params for adaptation: %s", param_names)
    logger.info("optimizer for adaptation: %s", optimizer)
    return cotta_model

def setup_vida(args, model):
    """Wrap ``model`` with ViDA test-time adaptation.

    Two parameter groups from ``vida.collect_params`` get separate learning
    rates: base-model params use ``cfg.OPTIM.LR`` and ViDA params use
    ``cfg.OPTIM.ViDALR``. The uncertainty threshold comes from
    ``args.unc_thr``; the EMA factors come from ``cfg.OPTIM.MT`` and
    ``cfg.OPTIM.MT_ViDA``.

    Returns:
        The ``vida.ViDA`` wrapper.
    """
    adapted = vida.configure_model(model, cfg)
    base_params, vida_params = vida.collect_params(adapted)
    optimizer = setup_optimizer_vida(
        base_params, vida_params, cfg.OPTIM.LR, cfg.OPTIM.ViDALR
    )

    wrapper = vida.ViDA(
        adapted,
        optimizer,
        steps=cfg.OPTIM.STEPS,
        episodic=cfg.MODEL.EPISODIC,
        unc_thr=args.unc_thr,
        ema=cfg.OPTIM.MT,
        ema_vida=cfg.OPTIM.MT_ViDA,
    )

    logger.info("model for adaptation: %s", adapted)
    logger.info("optimizer for adaptation: %s", optimizer)
    return wrapper

def _vida_mae_target_descriptor():
    """Return ``(hogs, num_class, hog_ratio)`` describing the ViDA-MAE target.

    The config flags are checked in this order:
      * ``cfg.use_hog``: ``operators.HOGLayerC`` (9 bins, pool 8), wrapped in
        DataParallel and moved to CUDA. Target width is
        ``9 * 3 * (16/8) * (16/8)`` = 108; ratio is ``cfg.hog_ratio``.
      * ``cfg.use_daisy``: ``DaisyTorch``, wrapped in DataParallel and moved to
        CUDA. Target width is ``(rings*histograms + 1) * orientations`` = 200;
        ratio is ``cfg.daisy_ratio``.
      * Otherwise: no descriptor module, target width ``224*224*3``, ratio 0.5.
    """
    if cfg.use_hog:
        bins, cell = 9, 8
        descriptor = nn.DataParallel(operators.HOGLayerC(nbins=bins, pool=cell))
        descriptor.cuda()
        target_width = int(bins * 3 * (16 / cell) * (16 / cell))
        return descriptor, target_width, cfg.hog_ratio

    if cfg.use_daisy:
        descriptor = nn.DataParallel(DaisyTorch(
            step=16, radius=15,
            rings=3, histograms=8, orientations=8,
            normalization='daisy', fp16=False, return_numpy=False,
        ))
        descriptor.cuda()
        # DAISY length: (rings * histograms + 1) * orientations = (3*8+1)*8
        return descriptor, (3 * 8 + 1) * 8, cfg.daisy_ratio

    # NOTE(review): the original marks this as "should be image size"; it is
    # hard-coded to 224x224x3 even for CIFAR inputs upsampled to 384 - confirm.
    return None, 224 * 224 * 3, 0.5


def setup_vida_mae(args, model):
    """Wrap ``model`` with ViDA-MAE test-time adaptation.

    Steps:
      1. Pick the reconstruction target (HOG, DAISY or raw size) from cfg.
      2. Configure the model via ``vida_mae.configure_model``.
      3. Build a two-group optimizer: model params at ``cfg.OPTIM.LR``,
         ViDA-MAE params at ``cfg.OPTIM.ViDALR``.
      4. Return the ``vida_mae.VIDA_MAE`` wrapper.
    """
    head_dim = 768
    hogs, num_class, hog_ratio = _vida_mae_target_descriptor()

    model = vida_mae.configure_model(model, cfg, head_dim, num_class)
    base_params, adapter_params = vida_mae.collect_params(model)
    optimizer = setup_optimizer_vida(
        base_params, adapter_params, cfg.OPTIM.LR, cfg.OPTIM.ViDALR
    )

    adapted = vida_mae.VIDA_MAE(
        model,
        optimizer,
        hogs=hogs,
        # assumes configure_model returns a wrapper exposing `.module`
        # (e.g. DataParallel) that owns `mask_token` - TODO confirm
        mask_token=model.module.mask_token,
        hog_ratio=hog_ratio,
        block_size=cfg.block_size,
        mask_method=cfg.mask_method,
        mask_ratio=cfg.mask_ratio,
        steps=cfg.OPTIM.STEPS,
        episodic=cfg.MODEL.EPISODIC,
        ema=cfg.OPTIM.MT,
        ema_vida=cfg.OPTIM.MT_ViDA,
        unc_thr=args.unc_thr,
    )
    logger.info("model for adaptation: %s", model)
    logger.info("optimizer for adaptation: %s", optimizer)
    return adapted


def setup_mimic(args, model, width, height):
    """Wrap ``model`` with MIMIC test-time adaptation.

    If ``args.checkpoint`` is set, its weights are first loaded into ``model``
    non-strictly:
      * ``.npz`` files are loaded via timm's ``load_checkpoint``.
      * Any other file is read with ``torch.load``. The weights are unwrapped
        from a ``"model"`` or ``"state_dict"`` entry if present, and a leading
        ``module.`` prefix is stripped from each key that has one.

    Args:
        args: Parsed arguments; reads ``args.checkpoint``.
        model: Base network to adapt.
        width: Input width passed to ``mimic.MIMIC``.
        height: Input height passed to ``mimic.MIMIC``.

    Returns:
        The ``mimic.MIMIC`` wrapper.
    """
    head_dim = 768
    mim_loss_ratio = cfg.mim_loss_ratio

    if args.checkpoint:
        if args.checkpoint.endswith(".npz"):
            # timm converts the NPZ into a state dict internally.
            load_checkpoint(model, args.checkpoint, strict=False)
            print(f"[MIMIC] loaded NPZ checkpoint {args.checkpoint}")
        else:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state = torch.load(args.checkpoint, map_location='cpu')
            # Unwrap framework containers that nest the actual weights.
            for wrapper_key in ("model", "state_dict"):
                if wrapper_key in state:
                    state = state[wrapper_key]
                    break
            # Strip "module." key by key; safe for empty / mixed dicts.
            prefix = "module."
            state = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state.items()
            }
            incompatible = model.load_state_dict(state, strict=False)
            # strict=False ignores mismatches silently; surface them.
            if incompatible.missing_keys or incompatible.unexpected_keys:
                logger.warning(
                    "checkpoint %s: %d missing / %d unexpected keys ignored",
                    args.checkpoint,
                    len(incompatible.missing_keys),
                    len(incompatible.unexpected_keys),
                )
            print(f"[MIMIC] loaded Torch checkpoint {args.checkpoint}")

    model = mimic.configure_model(model, head_dim)

    model_param, mask_param, atten_param = mimic.collect_params(model)

    # NOTE(review): groups are passed positionally to
    # setup_optimizer_mimic(params, params_atten, params_proj, ...), so
    # mask_param gets cfg.OPTIM.ViDALR and atten_param gets cfg.OPTIM.proj_LR.
    # Confirm this pairing is intended.
    optimizer = setup_optimizer_mimic(model_param, mask_param, atten_param,
                                      cfg.OPTIM.LR, cfg.OPTIM.ViDALR, cfg.OPTIM.proj_LR)

    mimic_model = mimic.MIMIC(model, optimizer,
                              # assumes configure_model returns a wrapper with
                              # `.module` (e.g. DataParallel) - TODO confirm
                              mask_token=model.module.mask_token,
                              mim_loss_ratio=mim_loss_ratio,
                              block_size=cfg.block_size,
                              mask_method=cfg.mask_method,
                              mask_ratio=cfg.mask_ratio,
                              steps=cfg.OPTIM.STEPS,
                              episodic=cfg.MODEL.EPISODIC,
                              ema=cfg.OPTIM.MT,
                              ema_atten=cfg.OPTIM.MT_ViDA,
                              width=width,
                              height=height)
    logger.info("model for adaptation: %s", model)
    logger.info("optimizer for adaptation: %s", optimizer)
    return mimic_model

def setup_optimizer_vida(params, params_vida, model_lr, vida_lr):
    """Build a two-group optimizer for ViDA / ViDA-MAE.

    Param groups, in order:
      1. ``params`` at ``model_lr``.
      2. ``params_vida`` at ``vida_lr``.

    The constructor-level ``lr=1e-5`` only applies to groups lacking their own
    ``lr``; every group here sets one.

    ``cfg.OPTIM.METHOD`` selects the algorithm:
      * ``'Adam'``: betas ``(cfg.OPTIM.BETA, 0.999)``.
      * ``'SGD'``: momentum, dampening and nesterov from ``cfg.OPTIM``.

    Both use weight decay ``cfg.OPTIM.WD``.

    Raises:
        NotImplementedError: For any other method.
    """
    method = cfg.OPTIM.METHOD
    groups = [
        {"params": params, "lr": model_lr},
        {"params": params_vida, "lr": vida_lr},
    ]

    if method == 'SGD':
        return optim.SGD(
            groups,
            momentum=cfg.OPTIM.MOMENTUM,
            dampening=cfg.OPTIM.DAMPENING,
            nesterov=cfg.OPTIM.NESTEROV,
            lr=1e-5,
            weight_decay=cfg.OPTIM.WD,
        )
    if method == 'Adam':
        return optim.Adam(
            groups,
            lr=1e-5,
            betas=(cfg.OPTIM.BETA, 0.999),
            weight_decay=cfg.OPTIM.WD,
        )
    raise NotImplementedError

def setup_optimizer_mimic(params, params_atten, params_proj, model_lr, atten_lr, proj_lr):
    """Build the three-group optimizer used by MIMIC.

    Param groups, in order:
      1. ``params`` at ``model_lr``.
      2. ``params_atten`` at ``atten_lr``.
      3. ``params_proj`` at ``proj_lr``.

    The constructor-level ``lr=1e-5`` is only a fallback for groups without
    their own ``lr``.

    ``cfg.OPTIM.METHOD`` selects the algorithm:
      * ``'Adam'``: betas ``(cfg.OPTIM.BETA, 0.999)``.
      * ``'SGD'``: momentum, dampening and nesterov from ``cfg.OPTIM``.

    Both use weight decay ``cfg.OPTIM.WD``.

    Raises:
        NotImplementedError: For any other method.
    """
    method = cfg.OPTIM.METHOD
    lr_by_group = ((params, model_lr), (params_atten, atten_lr), (params_proj, proj_lr))
    groups = [{"params": group, "lr": lr} for group, lr in lr_by_group]

    if method == 'SGD':
        return optim.SGD(
            groups,
            momentum=cfg.OPTIM.MOMENTUM,
            dampening=cfg.OPTIM.DAMPENING,
            nesterov=cfg.OPTIM.NESTEROV,
            lr=1e-5,
            weight_decay=cfg.OPTIM.WD,
        )
    if method == 'Adam':
        return optim.Adam(
            groups,
            lr=1e-5,
            betas=(cfg.OPTIM.BETA, 0.999),
            weight_decay=cfg.OPTIM.WD,
        )
    raise NotImplementedError

if __name__ == '__main__':
    # Description forwarded to load_cfg_fom_args (the stray leading '"' in the
    # original literal was a typo).
    evaluate('Imagenet-C evaluation.')
