import os
import numpy as np
import time
from typing import List
import torch
import random
import configparser
import argparse
from train.lora_vit_helper import ViTPEARL, split_qkv
from train.lora_helper import SeparableConvLoRAModelBuilder, ResNet18LoRAModelBuilder, ViTLoRAModelBuilder
from models.base import BasicBlock
from models.separable_conv import SeparableResNet
from models.residual_net import ResNet, BasicBlockResNet
import torch.nn as nn
from dataset.helper import create_sequential_dataloaders, load_dataset
from deprecated import deprecated
from logger import setup_logger
import time



class Args:
    """Process-wide singleton holding the parsed argument namespace.

    Every ``Args()`` call returns the very same object, so any module can
    retrieve the arguments that were registered once through ``set_args``.
    """
    _instance = None

    def __new__(cls):
        shared = cls._instance
        if shared is None:
            shared = super(Args, cls).__new__(cls)
            cls._instance = shared
            # Nothing registered until set_args() is called.
            shared.args = None
        return shared

    def set_args(self, args):
        """Register ``args`` on the shared instance."""
        self.args = args

    def get_args(self):
        """Return the registered arguments, or ``None`` if none were set."""
        return self.args


def timer_decorator(func):
    """Decorator that prints how long each call to ``func`` takes.

    The wrapped function's return value is passed through unchanged, and the
    elapsed time is printed as ``Execution time: X.XX seconds``.
    """
    import functools  # local import keeps this helper self-contained

    @functools.wraps(func)  # keep __name__, __doc__, __wrapped__ of the original
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, so the measurement is
        # not distorted by system clock adjustments (unlike time.time()).
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"Execution time: {elapsed:.2f} seconds")
        return result
    return wrapper


def update_args(args):
    """Finalize the runtime arguments in place and return them.

    Points ``args.log_dir`` at a timestamped sub-directory, attaches a logger
    created by ``setup_logger`` there, records the arguments, selects the
    compute device, and derives the total number of classes.
    """
    stamp = time.strftime('%Y%m%d_%H%M%S')
    args.log_dir = os.path.join(args.log_dir, stamp)

    run_logger = setup_logger(args.log_dir, f"{stamp}.log")
    run_logger.info(args)
    args.logger = run_logger

    # get_device reports through args.logger, so the logger must exist first.
    args.device = get_device(args)
    args.n_classes = args.n_classes_per_task * args.n_tasks
    return args


def get_dataloaders():
    """Build the sequential (per-task) train and test data loaders.

    Reads the global arguments from ``Args``, loads the dataset, and splits it
    into ``args.n_tasks`` tasks of ``args.n_classes_per_task`` classes each.

    Returns:
        ``(train_dataloaders, test_dataloaders)`` as produced by
        ``create_sequential_dataloaders``.
    """
    cfg = Args().get_args()
    train_set, test_set = load_dataset(cfg)
    classes_per_task = [cfg.n_classes_per_task for _ in range(cfg.n_tasks)]
    train_loaders, test_loaders = create_sequential_dataloaders(
        train_set,
        test_set,
        classes_per_task,
        cfg.n_classes,
        cfg.batch_size,
    )
    return train_loaders, test_loaders


def get_device(args):
    """Choose the best available torch device: CUDA, then Apple MPS, then CPU.

    The selection is reported via ``log_and_print`` using ``args.logger`` and
    ``args.verbose``.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"

    selected = torch.device(backend)
    log_and_print(f"Device: {selected}", args.logger, args.verbose)
    return selected


def retrieve_model_and_builder():
    """Build the network and its matching LoRA builder from the global args.

    ``args.model`` selects the architecture:
      * ``"separable_conv"`` -> ``SeparableResNet`` + ``SeparableConvLoRAModelBuilder``
      * ``"resnet"``         -> ``ResNet`` (``BasicBlockResNet``, ``[2, 2, 2, 2]``)
                                + ``ResNet18LoRAModelBuilder``
      * ``"vit"``            -> ``ViTPEARL`` (q/k/v projections split)
                                + ``ViTLoRAModelBuilder``
    Only ViT is pre-trained; the others are randomly initialized.

    When more than one CUDA device is visible the network is wrapped in
    ``nn.DataParallel``. It is then moved to ``args.device``, and its number of
    trainable parameters is logged.

    Returns:
        ``(net, lora_builder)``.

    Raises:
        ValueError: if ``args.model`` is not a supported name. ``ValueError``
            subclasses ``Exception``, so existing broad handlers still catch it.
    """
    args = Args().get_args()
    if args.model == "separable_conv":
        net = SeparableResNet(
            BasicBlock,
            [args.n_classes_per_task],
            factor=args.factor,
            depth=args.depth,
            logger=args.logger,
            device=args.device,
            forward_transfer=args.forward_transfer,
        )
        lora_builder = SeparableConvLoRAModelBuilder(args)
    elif args.model == "resnet":
        net = ResNet(
            block=BasicBlockResNet,
            num_blocks=[2, 2, 2, 2],
            num_classes=args.n_classes_per_task,
            nf=64,
            n_tasks=1,
            args=args,
            device=args.device,
        )
        lora_builder = ResNet18LoRAModelBuilder(args)
    elif args.model == "vit":
        net = ViTPEARL(args).to(args.device)
        split_qkv(net.image_encoder, args.logger)  # Split qkv to q, k, v
        lora_builder = ViTLoRAModelBuilder(args)
    else:
        # Name the offending value so misconfigured runs are easy to diagnose.
        raise ValueError(f"Unknown model architecture: {args.model!r}")

    # Use DataParallel
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    net.to(args.device)
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    log_and_print(f"Number of trainable parameters: {num_params}", args.logger, args.verbose)

    return net, lora_builder


def read_config_file(config_file):
    """Read an ``*.ini`` file holding training / inference arguments.

    Uses ``configparser.ConfigParser``. ``SafeConfigParser`` was only a
    deprecated alias of it and no longer exists in Python 3.12+, where it
    fails with ``AttributeError``.

    Args:
        config_file: path (or list of paths) accepted by ``ConfigParser.read``.

    Returns:
        The populated ``ConfigParser``.

    Raises:
        Exception: if no sections were parsed. This also covers a missing or
            unreadable file, because ``ConfigParser.read`` silently skips those.
    """
    config = configparser.ConfigParser()
    config.read(config_file)
    if not config.sections():
        raise Exception("No sections found in the configuration file. Please check the file format.")
    print(f"Config sections: {config.sections()}")
    return config


def bool_from_string(value):
    """Convert a textual flag (e.g. an argparse or config value) to ``bool``.

    Accepts ``yes/true/t/1`` and ``no/false/f/0``, case-insensitively. A value
    that is already a ``bool`` is returned unchanged, so this is safe to apply
    to defaults that were never strings.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    if isinstance(value, bool):
        return value
    normalized = value.lower()
    if normalized in ('yes', 'true', 't', '1'):
        return True
    if normalized in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Invalid boolean value: {value}")
        

def set_seed(seed: int):
    """Seed the torch, Python ``random`` and NumPy RNGs for reproducible runs.

    When CUDA is available it also seeds all GPUs and puts cuDNN into
    deterministic mode with the autotuner (``benchmark``) disabled.
    """
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def log_and_print(message, logger, verbose):
    """Always send ``message`` to ``logger.info``; echo it to stdout first when ``verbose`` is truthy."""
    if not verbose:
        logger.info(message)
        return
    print(message)
    logger.info(message)


def save_task_performance(
    data: List[List[float]],
    filename: str = "task_performance",
    output_dir: str = "./results/",
):
    """Write a task-wise performance table to a timestamped ``*.txt`` file.

    ``data`` is reshaped into an ``n x n`` matrix with ``n = len(data)``.
    Missing entries in a (shorter) row are filled with ``0``, and entries
    beyond column ``n`` are dropped. Each row is written on its own line,
    with values separated by single spaces.
    """
    size = len(data)
    square = [
        [row[col] if col < len(row) else 0 for col in range(size)]
        for row in data
    ]

    os.makedirs(output_dir, exist_ok=True)
    target = os.path.join(output_dir, add_generated_id(filename) + ".txt")

    with open(target, "w") as handle:
        handle.writelines(" ".join(str(v) for v in row) + "\n" for row in square)


def save_task_probabilities(
    data: List[float],
    filename: str = "task_probabilities",
    output_dir: str = "./results/",
) -> None:
    """Normalize ``data`` to sum to 1 and write it to a timestamped ``*.txt`` file.

    All values go on a single line, each followed by one space (no trailing
    newline). A zero sum is not guarded against.
    """
    denominator = sum(data)
    normalized = [value / denominator for value in data]

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, add_generated_id(filename) + ".txt")

    with open(out_path, "w") as fh:
        fh.write("".join(f"{p} " for p in normalized))


def save_stability_plasticity(
    stability: float,
    plasticity: float,
    filename: str = "stability_plasticity_tradeoff",
    output_dir: str = "./results/",
) -> None:
    """Save stability, plasticity and their trade-off to a timestamped ``*.txt`` file.

    The trade-off is the harmonic mean ``2 * s * p / (s + p)``. When
    ``s + p == 0`` (both scores zero) it is reported as ``0.0`` instead of
    raising ``ZeroDivisionError``. The three values are written on one line,
    each followed by a space.
    """
    denominator = stability + plasticity
    # Guard the degenerate case where both scores are zero: the harmonic
    # mean is then taken as 0 rather than crashing after a full run.
    trade_off = (2 * stability * plasticity) / denominator if denominator else 0.0
    combined = [stability, plasticity, trade_off]

    os.makedirs(output_dir, exist_ok=True)

    filename = add_generated_id(filename) + ".txt"
    file_path = os.path.join(output_dir, filename)

    with open(file_path, "w") as f:
        for value in combined:
            f.write(f"{value} ")


def save_confusion_matrix(
    confusion_matrix: List[List[int]],
    filename: str = "confusion_matrix",
    output_dir: str = "./results/",
) -> None:
    """Save a confusion matrix to a timestamped ``*.txt`` file.

    Each row of ``confusion_matrix`` is written on its own line, with values
    separated by single spaces. ``output_dir`` is created if it does not exist.
    """
    # Create the directory like the other save_* helpers do; without this,
    # open() raises FileNotFoundError on a fresh results directory.
    os.makedirs(output_dir, exist_ok=True)

    filename = add_generated_id(filename) + ".txt"
    file_path = os.path.join(output_dir, filename)

    with open(file_path, "w") as f:
        for row in confusion_matrix:
            f.write(" ".join(map(str, row)) + "\n")


def save_calibration(
    probabilities: np.ndarray,
    predictions: np.ndarray,
    labels: np.ndarray,
    n_bins: int = 10,
    filename: str = "calibration",
    output_dir: str = "./results/",
) -> None:
    """Compute reliability-diagram statistics and the ECE, then save them to ``*.txt``.

    Confidences in ``probabilities`` are grouped into ``n_bins`` equal-width,
    left-open / right-closed bins ``(lo, hi]`` spanning ``[0, 1]``. For every
    non-empty bin the accuracy (fraction where ``predictions == labels``) and
    the mean confidence are recorded; empty bins keep ``0`` for both.

    Expected Calibration Error: ``sum_b |acc_b - conf_b| * n_b / len(labels)``.

    The output file has three lines: per-bin accuracies, per-bin confidences,
    and the ECE with four decimals.
    """
    edges = np.linspace(0, 1, n_bins + 1)

    accuracies = np.zeros(n_bins)
    confidences = np.zeros(n_bins)
    bin_counts = np.zeros(n_bins)

    for idx in range(n_bins):
        lo, hi = edges[idx], edges[idx + 1]
        mask = (probabilities > lo) & (probabilities <= hi)
        count = np.sum(mask)
        if not count > 0:
            continue
        accuracies[idx] = np.sum(labels[mask] == predictions[mask]) / count
        confidences[idx] = np.mean(probabilities[mask])
        bin_counts[idx] = count

    # Weight each bin's |accuracy - confidence| gap by its share of all samples.
    weights = bin_counts / len(labels)
    ece = np.sum(np.abs(accuracies - confidences) * weights)

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, add_generated_id(filename) + ".txt")

    with open(out_path, "w") as fh:
        fh.write(" ".join(str(a) for a in accuracies) + "\n")
        fh.write(" ".join(str(c) for c in confidences) + "\n")
        fh.write(f"{ece:.4f}\n")


def save_sequential_performance(
    data: List[float],
    filename: str = "sequential_performance",
    output_dir: str = "./results/",
) -> None:
    """Write sequential task-wise scores as one space-separated line to a timestamped ``*.txt`` file."""
    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f"{add_generated_id(filename)}.txt")

    with open(out_path, "w") as fh:
        line = " ".join(str(v) for v in data)
        fh.write(f"{line}\n")


def add_generated_id(original_string: str) -> str:
    """Append ``_<unix time in whole seconds>`` to ``original_string``.

    Resolution is one second, so two calls with the same base name within the
    same second produce the same result.
    """
    return "{}_{}".format(original_string, int(time.time()))


@deprecated(reason="This function will be removed in future versions.")
def freeze_layers(model, task_id):
    """ Freeze certain layers for separable convolution.

    Deprecated. Leaves only the sub-modules of ``model`` associated with
    ``task_id`` trainable:

    * Every module's parameters are first set to ``requires_grad = False``.
    * A module is re-enabled when its qualified name contains one of the
      ``layers_to_unfreeze`` substrings AND the last dotted component of the
      name equals ``str(task_id)`` (presumably an index into a per-task
      ModuleList -- verify against the model definition).
    * Otherwise, a module whose name contains ``"bns."`` is re-enabled when
      the second-to-last dotted component equals ``str(task_id)``.

    Modules whose name contains the substring ``"bn"`` get
    ``track_running_stats = False``, switched back to ``True`` for the
    re-enabled ones.

    Args:
        model: module modified in place; assumed to be a ``torch.nn.Module``
            since ``named_modules``/``parameters`` are used (presumably a
            ``SeparableResNet`` -- confirm).
        task_id: task whose layers become trainable; compared as a string.

    NOTE(review): ``module.parameters()`` is recursive, and ``named_modules``
    presumably yields a parent before its children. If a re-enabled module has
    sub-modules, their later iterations re-freeze those parameters unless they
    match as well -- confirm this is intended.
    """
    task_id_str = str(task_id)

    # Substrings identifying task-specific layers (matched anywhere in the name).
    layers_to_unfreeze = {
        "SeparableConv2d1.depthwise.",
        "SeparableConv2d2.depthwise.",
        "conv1.depthwise.",
        "fcs.",
        "SeparableConv2d1.pointwise.",
        "SeparableConv2d2.pointwise.",
        "conv1.pointwise.",
        "pre_bn.",
        "bns1.",
        "bns2.",
        "ds.",
    }

    for name, module in model.named_modules():
        # Freeze all parameters by default
        for param in module.parameters():
            param.requires_grad = False

        # Plain substring test: any module whose name contains "bn".
        if "bn" in name:
            module.track_running_stats = False

        # Last dotted component of the module name, compared with the task id.
        layer_id = name.split(".")[-1]
        if (
            any(layer in name for layer in layers_to_unfreeze)
            and task_id_str == layer_id
        ):
            for param in module.parameters():
                param.requires_grad = True
            if "bn" in name:
                module.track_running_stats = True
        elif "bns." in name:
            # For "bns." modules the task id is the second-to-last component.
            bns_id = name.split(".")[-2]
            if task_id_str == bns_id:
                for param in module.parameters():
                    param.requires_grad = True
                module.track_running_stats = True


def max_excluding_outliers_iqr(data: torch.Tensor, k: float = 0.0) -> float:
    """Return the largest element of ``data`` that is not an upper outlier.

    Elements above ``Q3 + k * IQR`` (``IQR = Q3 - Q1``, quartiles taken over all
    elements) are discarded, and the maximum of the rest is returned. Only the
    upper tail is filtered; low values never change the result. With the
    default ``k = 0`` the cut-off is Q3 itself. Used in weight re-norm of
    classification weights.

    Integer tensors are supported. Quartiles are computed on a float64 copy
    because ``torch.quantile`` only accepts floating-point input, while the
    returned maximum keeps the original dtype's value.

    Args:
        data: tensor of values; ``torch.quantile`` flattens it when no ``dim``
            is given.
        k: multiple of the IQR added above Q3 to form the cut-off.

    Returns:
        The filtered maximum as a Python scalar (via ``.item()``).
    """
    # torch.quantile rejects integer dtypes; float inputs are used as-is so
    # their results are unchanged.
    quantile_input = data if data.is_floating_point() else data.to(torch.float64)
    q1 = torch.quantile(quantile_input, 0.25)
    q3 = torch.quantile(quantile_input, 0.75)
    upper_bound = q3 + k * (q3 - q1)

    filtered_data = data[data <= upper_bound]

    return torch.max(filtered_data).item()
