import numpy as np
import random
import torch
import torch.nn as nn
import os
import inspect
import pickle
import gdown
from network import Actor


def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.
        Reference: https://github.com/MishaLaskin/rad/blob/master/curl_sac.py"""

    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        m.bias.data.fill_(0.0)
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


def set_seed(random_seed):
    if random_seed <= 0:
        random_seed = np.random.randint(1, 9999)
    else:
        random_seed = random_seed

    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)

    return random_seed


def make_env(env_name, seed):
    import gymnasium as gym
    # openai gym
    env = gym.make(env_name)
    env.action_space.seed(seed)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    action_bound = [env.action_space.low[0], env.action_space.high[0]]

    env_info = {'name': env_name, 'state_dim': state_dim, 'action_dim': action_dim, 'action_bound': action_bound, 'seed': seed}

    return env, env_info


def get_learning_info(args, seed):
    env, env_info = make_env(args.env_name, seed)
    device = 'cuda'

    alpha_dict = {'HalfCheetah-v3': args.alpha_threshold, 'Walker2d-v3': args.alpha_threshold,
                  'Ant-v3': args.alpha_threshold, 'Hopper-v3': args.alpha_threshold,
                  'Carbon_Mini_Drone': args.alpha_threshold}

    thresholds = {"ALPHA_THRESHOLD": alpha_dict[args.env_name], "THETA_THRESHOLD": args.theta_threshold}
    max_action = 1

    t_p = Actor(env_info['state_dim'], env_info['action_dim'], (400, 300), 1)
    num_teacher_param = sum(p2.numel() for p2 in t_p.parameters())

    kwargs = {
        "env": env,
        "args": args,
        "env_info": env_info,
        "thresholds": thresholds,
        "discount": args.discount,
        "datasize": args.datasize,
        "tau": args.tau,
        "device": device,
        "num_teacher_param": num_teacher_param,
        "noise_clip": args.noise_clip * max_action,
        "policy_freq": args.policy_freq,
        "h": args.h,
    }
    return kwargs


def get_compression_ratio(num_teacher_param, agent):
    kep_w = 0
    for c in agent.actor.children():
        kep_w += c.get_num_remained_weights()
    #

    return kep_w / num_teacher_param


def load_buffer(env_name, level, datasize):
    current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    file_path = os.path.join(current_dir, "teacher_buffer", "[" + level + "_buffer]_" + env_name + ".pickle")
    try:
        with open(file_path, "rb") as fr:
            buffer = pickle.load(fr)
            buffer.size = datasize
    except FileNotFoundError:
        # Download the file
        if level == 'expert':
            print("Downloading the teacher buffer...")
            if env_name == "Ant-v3":
                file_id = "10VBf3bM38bNw9WsniQvirpNjRFWp8HZO"
            elif env_name == "Walker2d-v3":
                file_id = "1ungLoqNKS4NIldZ9H2mswwGh-3Ipgy0D"
            elif env_name == "HalfCheetah-v3":
                file_id = "1wO0HwDi1GNf9d9SrDJrf9x8XMZDOTkzl"
            elif env_name == "Hopper-v3":
                file_id ="10pqCliJSM_Iyb05dxHZfYs9VlmCmPryE"
            else:
                raise ValueError("Invalid Environment Name")

            url = f"https://drive.google.com/uc?id={file_id}"
            gdown.download(url, file_path, quiet=False)
            print("Download Complete!")
        elif level == 'medium':
            if env_name == "Ant-v3":
                file_id = "1-SKleNu6l-tY2awkx3tgVDUKbjkOaj_D"
            elif env_name == "Walker2d-v3":
                file_id = "1x6nkBBSWMRb3bENxUzcntHT1WlSNJmoh"
            elif env_name == "HalfCheetah-v3":
                file_id = "1OHkB6yVK3QcqbuJH0B_iNW_2cBnv96mR"
            elif env_name == "Hopper-v3":
                file_id ="1uqH2pgKKrhadsCXCwQWrvDvZ4ZyYFkM-"
            else:
                raise ValueError("Invalid Environment Name")

            url = f"https://drive.google.com/uc?id={file_id}"
            gdown.download(url, file_path, quiet=False)

        else:
            raise ValueError("Invalid Level. Choose from ['expert', 'medium']")

        with open(file_path, "rb") as fr:
            buffer = pickle.load(fr)
            buffer.size = datasize

    return buffer
