import argparse
import torch
import time
import datetime
import os
import numpy as np
import pandas as pd
import csv
from typing import Dict, List, Tuple, Optional

from utils.train_util import setup_logger, set_random_seed, AverageMeter, load_clip_to_cpu
from torchvision.transforms import RandomResizedCrop, RandomHorizontalFlip, ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
from datasets.train.imagenet import ImageNetDataset
from datasets.eval.test_loader import set_test_loader
import clip_w_local
from clip_w_local import clip
from clip_w_local.simple_tokenizer import SimpleTokenizer as _Tokenizer
from trainers import build_optimizer, build_lr_scheduler, CustomCLIP
from utils.losses import compute_accuracy, entropy_select_topk
from tqdm import tqdm
from torch.nn import functional as F
from utils.eval_util import get_and_print_results, add_results, add_overall_results, save_results_to_json
from configs.implemention import get_cfg_default
from yacs.config import CfgNode as CN


class LoCoOpExperiment:
    """Drive a LoCoOp / CoOp prompt-learning experiment from data to OOD scores.

    The object owns the training dataset and loader, the prompt-tuned
    ``CustomCLIP`` model, its optimizer / LR scheduler, and the MCM-based
    out-of-distribution evaluation.  Expected call order::

        exp = LoCoOpExperiment(cfg, args)
        exp.setup_data()
        exp.setup_model()
        exp.train()          # or exp.load_model(path)
        exp.evaluate()
    """

    # Batch size written onto ``args`` before the evaluation loaders are built.
    EVAL_BATCH_SIZE = 512
    # OOD benchmarks scored against the ImageNet in-distribution set.
    OOD_DATASETS = ('iNaturalist', 'SUN', 'places365', 'Texture')

    def __init__(self, cfg: CN, args: argparse.Namespace):
        """Store the config / CLI arguments and pick the compute device.

        Args:
            cfg: Experiment configuration node.
            args: Parsed command-line arguments.
        """
        self.cfg = cfg
        self.args = args
        self.device = self._setup_device()
        self.model = None
        self.train_dataset = None
        self.train_loader = None
        self.optimizer = None
        self.scheduler = None

    def _setup_device(self) -> torch.device:
        """Return the CUDA device when available and enabled, else the CPU."""
        if torch.cuda.is_available() and self.cfg.USE_CUDA:
            torch.backends.cudnn.benchmark = True
            return torch.device("cuda")
        return torch.device("cpu")

    def _validate_config(self) -> None:
        """Check that the config names an existing dataset root and an output dir.

        Raises:
            ValueError: If ``DATASET.ROOT`` is missing or does not exist on
                disk, or if ``OUTPUT_DIR`` is missing.
        """
        if not hasattr(self.cfg, 'DATASET') or not hasattr(self.cfg.DATASET, 'ROOT'):
            raise ValueError("DATASET.ROOT must be specified in config")

        if not os.path.exists(self.cfg.DATASET.ROOT):
            raise ValueError(f"Dataset path does not exist: {self.cfg.DATASET.ROOT}")

        if not hasattr(self.cfg, 'OUTPUT_DIR'):
            raise ValueError("OUTPUT_DIR must be specified in config")

    def _create_train_transform(self) -> Compose:
        """Build the training transform pipeline from ``cfg.INPUT``.

        ``INPUT.SIZE`` (default 224x224) sets the crop size.  The names listed
        in ``INPUT.TRANSFORMS`` switch on random-resized-crop, horizontal flip
        and normalisation; ``ToTensor`` is always applied.
        """
        if hasattr(self.cfg, 'INPUT') and hasattr(self.cfg.INPUT, 'SIZE'):
            resize_size = tuple(self.cfg.INPUT.SIZE)
        else:
            resize_size = (224, 224)

        enabled = self.cfg.INPUT.TRANSFORMS
        steps = []
        if "random_resized_crop" in enabled:
            steps.append(RandomResizedCrop(resize_size))
        if "random_flip" in enabled:
            steps.append(RandomHorizontalFlip())
        steps.append(ToTensor())
        if "normalize" in enabled:
            steps.append(Normalize(mean=self.cfg.INPUT.PIXEL_MEAN, std=self.cfg.INPUT.PIXEL_STD))

        return Compose(steps)

    def setup_data(self) -> None:
        """Create the training dataset and its ``DataLoader``.

        Reads ``DATASET.NUM_SHOTS`` (default 1), ``SEED`` (default 1),
        ``DATALOADER.TRAIN_X.BATCH_SIZE`` (default 32) and
        ``DATALOADER.NUM_WORKERS`` (default 8) from the config.
        """
        print("Setting up data loaders...")

        num_shots = getattr(self.cfg.DATASET, 'NUM_SHOTS', 1)
        seed = getattr(self.cfg, 'SEED', 1)

        self.train_dataset = ImageNetDataset(
            root=self.cfg.DATASET.ROOT,
            split="train",
            num_shots=num_shots,
            seed=seed,
            transform=self._create_train_transform(),
        )

        batch_size = self.cfg.DATALOADER.TRAIN_X.BATCH_SIZE if hasattr(self.cfg.DATALOADER, 'TRAIN_X') else 32
        num_workers = getattr(self.cfg.DATALOADER, 'NUM_WORKERS', 8)

        print(f"Using {num_workers} workers for data loading")

        loader_kwargs = dict(
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            # Pinned host memory only helps host-to-GPU copies.
            pin_memory=self.device.type == "cuda",
        )
        if num_workers > 0:
            # Worker-only options: leaving them at their defaults for
            # num_workers == 0 avoids the argument check that rejects an
            # explicit prefetch_factor on some PyTorch versions.
            loader_kwargs.update(persistent_workers=True, prefetch_factor=2)

        self.train_loader = DataLoader(self.train_dataset, **loader_kwargs)

        print(f"Train dataset size: {len(self.train_dataset)}")

    def setup_model(self) -> None:
        """Build ``CustomCLIP``, freeze everything but the prompt learner, and
        create the optimizer and LR scheduler.

        Raises:
            RuntimeError: If ``setup_data()`` has not been called yet (the
                class names come from the training dataset).
        """
        if self.train_dataset is None:
            raise RuntimeError("setup_data() must be called before setup_model()")

        print("Setting up model and optimizer...")

        print(f"Loading CLIP (backbone: {self.cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(self.cfg)

        if self.cfg.TRAINER.LOCOOP.PREC in ["fp32", "amp"]:
            clip_model.float()

        classnames = self.train_dataset.classnames
        print("Building custom CLIP")
        self.model = CustomCLIP(self.cfg, classnames, clip_model)

        # Only parameters whose name contains "prompt_learner" stay trainable.
        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.model.named_parameters():
            if "prompt_learner" not in name:
                param.requires_grad_(False)

        self.model.to(self.device)

        self.optimizer = build_optimizer(self.model.prompt_learner, self.cfg.OPTIM)
        self.scheduler = build_lr_scheduler(self.optimizer, self.cfg.OPTIM)

        print("Model setup completed")

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Run one optimisation pass over the training loader.

        The loss is the cross-entropy on the global logits; when
        ``args.method == 'locoop'`` the negated ``entropy_select_topk`` term on
        the local logits is added with weight ``cfg.lambda_value``.  The LR
        scheduler is stepped once at the end of the epoch.

        Args:
            epoch: Zero-based epoch index (used for progress / ETA output).

        Returns:
            ``{'loss': ..., 'accuracy': ...}`` averaged over the samples seen;
            both are ``0.0`` when the loader yields no samples.
        """
        self.model.train()
        loss_sum = 0.0
        acc_sum = 0.0
        total_samples = 0
        batch_time = AverageMeter()
        data_time = AverageMeter()
        num_batches = len(self.train_loader)

        lambda_value = self.cfg.lambda_value
        top_k = self.cfg.topk
        max_epoch = self.cfg.OPTIM.MAX_EPOCH
        use_locoop = self.args.method == 'locoop'

        end = time.time()

        for batch_idx, (images, labels, _) in enumerate(self.train_loader):
            data_time.update(time.time() - end)
            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            output, output_local = self.model(images)

            # CoOp loss on the global logits.
            loss_id = F.cross_entropy(output, labels)

            if use_locoop:
                # OOD regularisation: flatten (batch, local) into one axis
                # before handing the local logits to entropy_select_topk.
                batch_size, num_of_local_feature = output_local.shape[0], output_local.shape[1]
                output_local = output_local.view(batch_size * num_of_local_feature, -1)
                loss_en = -entropy_select_topk(output_local, top_k, labels, num_of_local_feature)
                loss = loss_id + lambda_value * loss_en
            else:  # coop
                loss_en = torch.tensor(0.0, device=loss_id.device)
                loss = loss_id

            loss.backward()
            self.optimizer.step()

            # compute_accuracy reports a percentage; weight it by batch size.
            acc = compute_accuracy(output, labels)[0].item()
            n_batch = images.size(0)
            loss_sum += loss.item() * n_batch
            acc_sum += acc * n_batch
            total_samples += n_batch

            batch_time.update(time.time() - end)
            end = time.time()

            if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == num_batches:
                # ETA is only needed for the progress line.
                nb_remain = (num_batches - batch_idx - 1) + (max_epoch - epoch - 1) * num_batches
                eta = str(datetime.timedelta(seconds=int(batch_time.avg * nb_remain)))
                print(
                    f"epoch [{epoch + 1}/{max_epoch}] "
                    f"batch [{batch_idx + 1}/{num_batches}] "
                    f"time {batch_time.val:.3f} ({batch_time.avg:.3f}) "
                    f"data {data_time.val:.3f} ({data_time.avg:.3f}) "
                    f"loss {loss.item():.4f} "
                    f"loss_id {loss_id.item():.4f} "
                    f"loss_en {loss_en.item():.4f} "
                    f"acc {acc:.2f} "
                    f"lr {self.optimizer.param_groups[0]['lr']:.4e} "
                    f"eta {eta}"
                )

        self.scheduler.step()

        # An empty loader would otherwise divide by zero.
        if total_samples == 0:
            return {'loss': 0.0, 'accuracy': 0.0}

        return {
            'loss': loss_sum / total_samples,
            'accuracy': acc_sum / total_samples,
        }

    def train(self) -> None:
        """Train for ``cfg.OPTIM.MAX_EPOCH`` epochs, printing per-epoch averages.

        NOTE: no checkpoint is written by this method, so the tuned prompt
        weights exist only in memory after training.
        """
        print("Starting training...")
        max_epoch = self.cfg.OPTIM.MAX_EPOCH

        for epoch in range(max_epoch):
            metrics = self.train_epoch(epoch)
            print(f"Epoch [{epoch+1}/{max_epoch}] Avg Loss: {metrics['loss']:.4f} Avg Acc: {metrics['accuracy']:.2f}")

        print("Training completed")

    def load_model(self, model_path: str) -> None:
        """Load prompt-learner weights from ``model_path``.

        Args:
            model_path: File readable by ``torch.load`` holding the
                prompt learner's ``state_dict``.

        Raises:
            RuntimeError: If ``setup_model()`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("setup_model() must be called before load_model()")

        print(f"Loading model from {model_path}")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # NOTE(security): torch.load unpickles the file - only load checkpoints
        # from a trusted source.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.prompt_learner.load_state_dict(state_dict)
        print("Model loaded successfully")

    def test_ood_detection(self, model: torch.nn.Module, data_loader: DataLoader, T: float = 1.0) -> np.ndarray:
        """Compute the MCM score (negated max softmax) for every sample.

        Args:
            model: Model returning ``(global_logits, local_logits)``.
            data_loader: Loader whose batches start with ``(images, labels, ...)``.
            T: Softmax temperature.

        Returns:
            1-D array of scores, truncated to ``len(data_loader.dataset)``;
            empty when the loader yields no batches.

        NOTE: only the global-logit score is produced; the local (GL-MCM)
        score is not computed here.
        """
        model.eval()
        scores = []

        with torch.no_grad():
            for images, _labels, *_extra in tqdm(data_loader, desc="Testing OOD"):
                images = images.to(self.device)
                output, _ = model(images)

                # Logits are scaled down by 100 before the softmax (presumably
                # undoing CLIP's logit scale - confirm against CustomCLIP).
                logits = output / 100.0
                smax_global = F.softmax(logits / T, dim=-1).detach().cpu().numpy()
                scores.append(-np.max(smax_global, axis=1))

        if not scores:
            # np.concatenate rejects an empty list.
            return np.empty(0, dtype=np.float32)

        return np.concatenate(scores, axis=0)[:len(data_loader.dataset)].copy()

    def evaluate(self) -> None:
        """Score ImageNet against each OOD benchmark and save the results.

        Uses ``args.T`` as softmax temperature (1.0 when absent).  Writes
        ``scores.npz`` and ``results.json`` to ``args.output_dir``, falling
        back to ``cfg.OUTPUT_DIR`` and then the current directory when that is
        empty.  Sets ``args.in_dataset`` and ``args.batch_size`` as a side
        effect because the test-loader factory reads them from ``args``.
        """
        print("Starting evaluation...")

        self.model.eval()

        # Resolve and create the output directory first so a bad path fails
        # before the (long) scoring loop rather than after it.
        output_dir = (
            getattr(self.args, 'output_dir', '')
            or getattr(self.cfg, 'OUTPUT_DIR', '')
            or '.'
        )
        os.makedirs(output_dir, exist_ok=True)

        temperature = getattr(self.args, 'T', 1.0)

        # Only the preprocessing transform is needed from this load.
        _, preprocess = clip_w_local.load(self.cfg.MODEL.BACKBONE.NAME)
        self.args.in_dataset = "imagenet"
        self.args.batch_size = self.EVAL_BATCH_SIZE

        id_data_loader = set_test_loader(self.args, "imagenet", preprocess)

        # In-distribution scores are computed once and reused for every OOD set.
        in_score_mcm = self.test_ood_detection(self.model, id_data_loader, temperature)

        # Passed to get_and_print_results for every dataset and then to
        # add_overall_results (presumably filled in place - confirm in eval_util).
        auroc_list_mcm, fpr_list_mcm = [], []
        results_data = []

        scores_dict: Dict[str, Dict[str, np.ndarray]] = {"MCM": {"ImageNet": in_score_mcm}}

        for out_dataset in self.OOD_DATASETS:
            print(f"Evaluating OOD dataset: {out_dataset}")
            ood_loader = set_test_loader(self.args, out_dataset, preprocess)
            out_score_mcm = self.test_ood_detection(self.model, ood_loader, temperature)

            print("MCM score")
            mcm_results = get_and_print_results(
                self.args, in_score_mcm, out_score_mcm,
                auroc_list_mcm, fpr_list_mcm
            )
            scores_dict["MCM"][out_dataset] = out_score_mcm

            results_data = add_results(results_data, mcm_results, out_dataset)

        # Append the overall results; at this point results_data is a list of dicts.
        results_data = add_overall_results(results_data, auroc_list_mcm, fpr_list_mcm)

        # np.savez stores each keyword as one array, so the nested dict under
        # "MCM" is pickled as a 0-d object array: read it back with
        # np.load(..., allow_pickle=True)["MCM"].item().
        np.savez(os.path.join(output_dir, "scores.npz"), **scores_dict)

        save_results_to_json(results_data, output_dir, "results.json")
        print("Evaluation completed")


def print_args(args: argparse.Namespace, cfg: CN) -> None:
    """Dump the parsed arguments (sorted by name) and then the config to stdout.

    Args:
        args: Parsed command-line arguments; every attribute is printed as
            ``name: value``.
        cfg: Configuration object, printed via its ``str()`` form.
    """
    for banner_line in ("***************", "** Arguments **", "***************"):
        print(banner_line)

    values = vars(args)
    for name in sorted(values):
        print(f"{name}: {values[name]}")

    for banner_line in ("************", "** Config **", "************"):
        print(banner_line)
    print(cfg)


def reset_cfg(cfg: CN, args: argparse.Namespace) -> None:
    """Copy command-line overrides from ``args`` onto ``cfg`` in place.

    String options and ``seed`` are applied only when truthy (an empty string
    means "not given").  The numeric LoCoOp options ``lambda_value`` and
    ``topk`` are applied whenever they are not ``None``, so an explicit ``0``
    (e.g. ``--lambda_value 0`` to switch the regulariser off) is honoured
    instead of being silently dropped.

    Args:
        cfg: Configuration node to modify; must not be frozen yet.
        args: Parsed command-line arguments.
    """
    if args.root:
        cfg.DATASET.ROOT = args.root

    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir

    if args.resume:
        cfg.RESUME = args.resume

    if args.seed:
        cfg.SEED = args.seed

    if args.trainer:
        cfg.TRAINER.NAME = args.trainer

    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone

    if args.head:
        cfg.MODEL.HEAD.NAME = args.head

    # 0 is a meaningful value for these two, so test against None rather than
    # relying on truthiness.
    if args.lambda_value is not None:
        cfg.lambda_value = args.lambda_value

    if args.topk is not None:
        cfg.topk = args.topk


def extend_cfg(cfg: CN) -> None:
    """Add the LoCoOp-specific default keys to ``cfg`` in place.

    Creates the ``TRAINER.LOCOOP`` node and the ``DATASET.SUBSAMPLE_CLASSES``
    key so that later config merges can override them.
    """
    locoop = CN()
    locoop.N_CTX = 16  # number of context vectors
    locoop.CSC = False  # class-specific context
    locoop.CTX_INIT = ""  # initialization words
    locoop.PREC = "fp16"  # fp16, fp32, amp
    locoop.CLASS_TOKEN_POSITION = "end"  # 'middle' or 'end' or 'front'
    cfg.TRAINER.LOCOOP = locoop

    cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new


def setup_cfg(args: argparse.Namespace) -> CN:
    """Build the frozen experiment config from defaults, files and CLI input.

    Sources are merged in increasing priority: built-in defaults (plus the
    LoCoOp extensions), the dataset config file, the method config file, the
    dedicated CLI flags, and finally the free-form ``opts`` key/value list.

    Args:
        args: Parsed command-line arguments.

    Returns:
        The merged configuration, frozen against further modification.
    """
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # 1-2. Dataset config file first, then the method config file.
    for config_path in (args.dataset_config_file, args.config_file):
        if config_path:
            cfg.merge_from_file(config_path)

    # 3. Dedicated command-line flags.
    reset_cfg(cfg, args)

    # 4. Free-form overrides.  ``opts`` is None when the namespace was built
    #    without going through the parser, so skip the merge when it is unset
    #    or empty.
    if args.opts:
        cfg.merge_from_list(args.opts)

    cfg.freeze()
    return cfg


def main(args: argparse.Namespace) -> None:
    """Run the experiment: configure, (optionally) train, then evaluate.

    * ``--eval-only`` loads prompt weights from ``--model-dir`` instead of
      training.
    * ``--no-train`` skips training without loading weights.
    * Evaluation always runs last.

    Args:
        args: Parsed command-line arguments.

    Raises:
        ValueError: If ``--eval-only`` is set without ``--model-dir``.
    """
    # Fail fast: check the flag combination before the dataset and the CLIP
    # weights are loaded.
    if args.eval_only and not args.model_dir:
        raise ValueError("--model-dir must be specified for eval-only mode")

    cfg = setup_cfg(args)

    if cfg.SEED >= 0:
        print("Setting fixed seed: {}".format(cfg.SEED))
        set_random_seed(cfg.SEED)

    setup_logger(cfg.OUTPUT_DIR)

    experiment = LoCoOpExperiment(cfg, args)
    experiment._validate_config()
    experiment.setup_data()
    experiment.setup_model()

    if args.eval_only:
        experiment.load_model(args.model_dir)
    elif getattr(args, "no_train", False):
        # Honour --no-train, which was previously parsed but ignored.
        print("Skipping training (--no-train)")
    else:
        experiment.train()

    experiment.evaluate()


if __name__ == "__main__":
    # Command-line interface: (flags, add_argument keyword options), listed in
    # registration order so the --help layout stays the same.
    _ARGUMENT_SPECS = [
        (("--root",), dict(type=str, default="", help="path to dataset")),
        (("--output-dir",), dict(type=str, default="", help="output directory")),
        (("--resume",), dict(
            type=str,
            default="",
            help="checkpoint directory (from which the training resumes)",
        )),
        (("--seed",), dict(
            type=int, default=-1, help="only positive value enables a fixed seed",
        )),
        (("--config-file",), dict(type=str, default="", help="path to config file")),
        (("--dataset-config-file",), dict(
            type=str,
            default="",
            help="path to config file for dataset setup",
        )),
        (("--trainer",), dict(type=str, default="", help="name of trainer")),
        (("--backbone",), dict(type=str, default="", help="name of CNN backbone")),
        (("--head",), dict(type=str, default="", help="name of head")),
        (("--eval-only",), dict(action="store_true", help="evaluation only")),
        (("--model-dir",), dict(
            type=str,
            default="",
            help="load model from this directory for eval-only mode",
        )),
        (("--load-epoch",), dict(
            type=int, help="load model weights at this epoch for evaluation",
        )),
        (("--no-train",), dict(action="store_true", help="do not call trainer.train()")),
        # Everything left on the command line is collected as config overrides.
        (("opts",), dict(
            default=None,
            nargs=argparse.REMAINDER,
            help="modify config options using the command-line",
        )),
        # Arguments for LoCoOp
        (("--lambda_value",), dict(
            type=float, default=0.25, help='weight for regulization loss',
        )),
        (("--topk",), dict(
            type=int, default=200, help='topk for extracted OOD regions',
        )),
        # Arguments for test-time OOD detection
        (("--T",), dict(type=float, default=1, help='temperature for softmax')),
        (("--method",), dict(
            type=str,
            default='locoop',
            choices=['locoop', 'coop'],
            help='method type: locoop or coop',
        )),
        (("--sample-size",), dict(
            type=int, default=500, help='sample size for test set',
        )),
    ]

    parser = argparse.ArgumentParser()
    for _flags, _options in _ARGUMENT_SPECS:
        parser.add_argument(*_flags, **_options)

    args = parser.parse_args()
    main(args)