import os
import torch
import random
import numpy as np
import yaml


from omegaconf import DictConfig, OmegaConf

import json
import random

import torch
import torch.nn.functional as F
import torch.nn as nn

import math
from torch.nn.functional import one_hot

from clip import clip
from pathlib import Path

###############################
######## config utils #########
###############################

def save_config(config: DictConfig) -> None:
    """Serialize *config* to ``config.yaml``.

    The target path is relative, so the file is written into the current
    working directory of the process.
    """
    target = "config.yaml"
    OmegaConf.save(config, target)


def get_workdir(path):
    """Resolve the working directory associated with *path*.

    If any component of *path* is named ``"results"``, return everything
    before the first such component (``"."`` when ``"results"`` is the very
    first component). Otherwise create ``<path>/results`` (with parents,
    tolerating an existing directory) and return that directory.

    NOTE(review): the two branches return different levels of the tree (the
    parent of ``results`` vs. ``results`` itself) — confirm callers expect this.
    """
    base = Path(path)
    components = base.parts

    if "results" not in components:
        results_dir = base / "results"
        results_dir.mkdir(parents=True, exist_ok=True)
        return str(results_dir)

    cut = components.index("results")
    return str(Path(*components[:cut]))



###############################
######### data utils ##########
###############################
def get_class_order(cfg, file_name=None) -> list:
    """Return the order in which class labels are presented across tasks.

    If *file_name* is given, the order is read from the ``class_order`` key of
    that YAML file and nothing else in *cfg* is consulted (note that
    ``cfg.num_tasks`` is NOT set on this path).

    Otherwise the labels ``0 .. cfg.num_classes - 1`` are ordered according to
    the config flags:

    * ``cfg.shuffle``: labels are chunked into consecutive groups of
      ``cfg.increment`` classes; the order of the groups is shuffled and then
      each group is shuffled internally (using the global ``random`` state).
      Each task's classes are therefore still one contiguous block of the
      natural order, just presented in random task/class order.
    * ``cfg.reverse`` (only checked when not shuffling): descending order.
    * neither flag: ascending (natural) order.

    Side effect: sets ``cfg.num_tasks = cfg.num_classes // cfg.increment``.

    Raises:
        AssertionError: if ``cfg.num_classes`` is not divisible by
            ``cfg.increment``.
    """
    if file_name is not None:
        # Read-only access is sufficient; "r+" needlessly demanded write
        # permission and failed on read-only class-order files.
        with open(file_name, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        return data["class_order"]

    nb_classes = cfg.num_classes
    classes_per_task = cfg.increment
    assert nb_classes % classes_per_task == 0, (
        f"invalid increment: num_classes ({nb_classes}) must be divisible by "
        f"increment ({classes_per_task})"
    )
    cfg.num_tasks = nb_classes // classes_per_task

    labels = list(range(nb_classes))

    if cfg.shuffle:
        # Split labels into per-task groups of consecutive classes.
        grouped_labels = [
            labels[i:i + classes_per_task]
            for i in range(0, nb_classes, classes_per_task)
        ]
        # Shuffle the task order, then the class order within each task.
        random.shuffle(grouped_labels)
        for group in grouped_labels:
            random.shuffle(group)
        labels = [label for group in grouped_labels for label in group]
        print("Shuffled Scenario")
    elif cfg.reverse:
        labels.reverse()
        # Reversed labels run from high to low, i.e. descending order.
        print("Descending Scenario")
    else:
        # Natural order 0, 1, ..., n-1 is ascending.
        print("Ascending Scenario")

    return labels


def get_dataset_class_names(path, dataset_name, long=False):
    """Read the human-readable class names for *dataset_name*.

    Loads ``<path>/<dataset_name>_classes.txt`` and returns one entry per line:
    the text after the last tab on that line (the whole line when it has no
    tab).

    Args:
        path: Directory containing the class-name file.
        dataset_name: Prefix of the ``*_classes.txt`` file name.
        long: Currently unused; kept for call-site compatibility.

    Returns:
        list[str]: Class names in file order.
    """
    file_path = os.path.join(path, f"{dataset_name}_classes.txt")
    # Decode explicitly as UTF-8 so non-ASCII class names do not depend on the
    # platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()
    return [line.split("\t")[-1] for line in lines]

def get_classes_names_cur_task(task_id, classes_names, initial_increment, increment):
    """Return the slice of *classes_names* introduced at task *task_id*.

    Task 0 covers the first ``initial_increment`` names. Every later task
    ``t`` covers the next ``increment`` names, starting immediately after the
    classes of tasks ``0 .. t-1``.
    """
    if task_id != 0:
        first = initial_increment + increment * (task_id - 1)
        return classes_names[first:first + increment]
    return classes_names[:initial_increment]



###############################
######## train utils ##########
###############################
def assign_learning_rate(param_group, new_lr):
    """Overwrite the ``"lr"`` entry of *param_group* with *new_lr*, in place.

    Returns ``None``; the mapping passed in is mutated.
    """
    param_group["lr"] = new_lr

def _warmup_lr(base_lr, warmup_length, step):
    return base_lr * (step + 1) / warmup_length

def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Build a per-step learning-rate updater: linear warmup, then cosine decay.

    Args:
        optimizer: Object exposing ``param_groups``, a sequence of dict-like
            groups with an ``"lr"`` entry (e.g. a ``torch.optim.Optimizer``).
        base_lrs: Peak learning rate. Either a single number shared by every
            param group, or a list/tuple with one value per param group.
        warmup_length: Number of warmup steps. For ``step < warmup_length`` the
            lr ramps linearly via ``_warmup_lr``.
        steps: Total number of steps. After warmup the lr is
            ``0.5 * (1 + cos(pi * e / es)) * base_lr`` with
            ``e = step - warmup_length`` and ``es = steps - warmup_length``.

    Returns:
        A callable ``_lr_adjuster(step)`` that writes the lr for *step* into
        every param group of *optimizer*.

    Raises:
        AssertionError: if a per-group ``base_lrs`` sequence does not match
            the number of param groups.
    """
    # A tuple used to be treated as one scalar lr, which made ``_warmup_lr``
    # repeat the tuple (``tuple * int``) instead of scaling a number.
    if isinstance(base_lrs, (list, tuple)):
        base_lrs = list(base_lrs)
    else:
        base_lrs = [base_lrs for _ in optimizer.param_groups]
    assert len(base_lrs) == len(optimizer.param_groups), (
        f"got {len(base_lrs)} base lrs for "
        f"{len(optimizer.param_groups)} param groups"
    )

    def _lr_adjuster(step):
        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
            if step < warmup_length:
                lr = _warmup_lr(base_lr, warmup_length, step)
            else:
                e = step - warmup_length
                es = steps - warmup_length
                lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
            assign_learning_rate(param_group, lr)

    return _lr_adjuster

def seed_worker(worker_id):
    """Seed Python's ``random`` and NumPy's global RNG from torch's seed.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32, the range that
    ``np.random.seed`` accepts. Presumably used as a DataLoader
    ``worker_init_fn`` (the signature matches); *worker_id* itself is unused.
    """
    derived = torch.initial_seed() % (1 << 32)
    random.seed(derived)
    np.random.seed(derived)