import PIL.Image

import gym
import numpy as np
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

def evaluation(diffusion_env, env_name, *, samples=100, evaluate_steps=1,
               num_noise_samples=5, device="cuda"):
    """Compare rollouts of ``diffusion_env`` against the real gym environment.

    For each of ``samples`` episodes, the real environment ``gym.make(env_name)``
    is synced (via ``set_state``) to the diffusion env's initial state. Then, for
    up to ``evaluate_steps`` steps, a uniformly random action in [-1, 1] is
    applied to both:

    * the diffusion env is reset to the current state and stepped once per fixed
      noise draw; the std across those predictions (averaged over state
      dimensions) is recorded as a "discrepancy" measure;
    * the state/reward of the LAST noise draw are compared with the real env's
      step (L2 error on unnormalized states, absolute reward error).

    Prints the mean state error over rows with positive discrepancy, then a
    table of (threshold, mean state error, mean reward error, count) for
    discrepancy thresholds 0.05, 0.049, ..., 0.001. Also calls
    ``linear_regression`` on the recorded rows.

    Args:
        diffusion_env: learned environment exposing ``dataset.normalizer``,
            ``reset``, ``setstate`` and ``original_step``.
        env_name: gym environment id for the real environment.
        samples: number of episodes to evaluate.
        evaluate_steps: maximum number of steps per episode.
        num_noise_samples: number of fixed noise draws used per step.
        device: torch device on which the noise tensor is allocated.

    Raises:
        ValueError: if ``num_noise_samples`` is less than 1.
    """
    if num_noise_samples < 1:
        raise ValueError("num_noise_samples must be at least 1")

    env = gym.make(env_name)
    normalizer = diffusion_env.dataset.normalizer['observations']
    act_normalizer = diffusion_env.dataset.normalizer['actions']
    action_dim = diffusion_env.action_space.shape[0]
    obs_dim = diffusion_env.observation_space.shape[0]

    # Noise draws are fixed once and reused at every step, so the discrepancy
    # values are computed against the same perturbations throughout.
    noises = torch.randn((num_noise_samples, 1, obs_dim), device=device)
    n_rows = samples * evaluate_steps
    # Columns: [state L2 error, absolute reward error, discrepancy (std)].
    cur_loss = np.zeros(shape=(n_rows, 3))
    # Rows stay all-zero when an episode terminates early; remember which rows
    # were actually written so they are not mistaken for real (0, 0) samples.
    filled = np.zeros(n_rows, dtype=bool)

    for i in tqdm(range(samples)):
        env.reset()
        init_state = diffusion_env.reset()
        unnorm_state = normalizer.unnormalize(init_state)

        # The real qpos is built as one leading 0 followed by the first half of
        # the observation; presumably the observation omits the root x-position
        # (TODO confirm for each env).
        qpos = np.concatenate((np.array([0]), unnorm_state[:obs_dim // 2]))
        qvel = unnorm_state[obs_dim // 2:]
        env.set_state(qpos, qvel)
        cur_state = init_state

        for j in range(evaluate_steps):
            row = i * evaluate_steps + j
            action = np.float32(np.random.uniform(-1, 1, action_dim))
            norm_action = act_normalizer.normalize(action)

            e_states = []
            for noise in noises:  # iterates dim 0: same as noises[k]
                diffusion_env.setstate(cur_state)
                e_state, e_reward, e_done, _ = diffusion_env.original_step(
                    norm_action, apply_noise=noise)
                e_states.append(e_state)
            # e_state / e_reward / e_done now refer to the last noise draw.
            e_state_std = np.std(np.array(e_states), axis=0).mean()
            e_state_unnorm = normalizer.unnormalize(e_state)
            r_state, r_reward, r_done, _ = env.step(action)

            cur_loss[row, 0] = np.linalg.norm(e_state_unnorm - r_state)
            cur_loss[row, 1] = np.abs(e_reward - r_reward)
            cur_loss[row, 2] = e_state_std
            filled[row] = True

            cur_state = e_state
            if e_done or r_done:
                break

    positive_std = cur_loss[:, 2] > 0
    state_loss = cur_loss[positive_std, 0]

    # Regress only on rows that were actually recorded.
    linear_regression(cur_loss[filled], env_name)

    print(np.mean(state_loss))
    for i in range(50):
        threshold = 0.05 - 0.001 * i
        mask = positive_std & (cur_loss[:, 2] < threshold)
        state_loss = cur_loss[mask, 0]
        reward_loss = cur_loss[mask, 1]
        # np.mean of an empty selection prints nan (with a RuntimeWarning).
        print(threshold, np.mean(state_loss), np.mean(reward_loss), np.shape(state_loss))

def linear_regression(final_loss, env_name):
    """Fit and plot an ordinary least-squares line of column 0 against column 2.

    Args:
        final_loss: 2-D array; column 2 is the regressor (x, "discrepancy") and
            column 0 the target (y). ``evaluation`` passes rows of
            [state error, reward error, discrepancy].
        env_name: used in the plot title and the output filename.

    Returns:
        The fitted ``LinearRegression`` model.

    Side effects:
        Writes ``./csv/linear_regression_{env_name}.pdf``, creating ``./csv``
        if it does not exist.
    """
    import os

    a = final_loss[:, 2]
    b = final_loss[:, 0]
    model = LinearRegression()
    model.fit(a.reshape(-1, 1), b)

    # Predict on sorted x so the fitted line is drawn left-to-right.
    a_sorted = np.sort(a)
    b_pred = model.predict(a_sorted.reshape(-1, 1))

    # A dedicated figure, closed afterwards: drawing on pyplot's implicit
    # current figure would accumulate points/lines across repeated calls.
    fig, ax = plt.subplots()
    ax.scatter(a, b, color='blue', label='Data points')
    ax.plot(a_sorted, b_pred, color='red', label='Linear regression line')
    ax.set_xlabel('Discrepancy')
    ax.set_ylabel('State error')
    ax.set_title('Linear Regression of {}'.format(env_name))
    ax.legend()

    os.makedirs('./csv', exist_ok=True)
    fig.savefig('./csv/linear_regression_{}.pdf'.format(env_name))
    plt.close(fig)
    return model

def reward_test(diffusion_env, env_name, *, episodes=10000):
    """Measure the mean absolute error of ``diffusion_env``'s reward prediction.

    Runs ``episodes`` episodes of uniformly random actions in [-1, 1] in the
    real environment ``gym.make(env_name)`` until each terminates. For every
    transition, ``diffusion_env.reward_test`` is evaluated on the normalized
    (state, action, next_state) and compared against the real reward.

    Args:
        diffusion_env: learned environment exposing ``dataset.normalizer``,
            ``action_space`` and ``reward_test``.
        env_name: gym environment id for the real environment.
        episodes: number of episodes to roll out.

    Returns:
        The mean absolute reward error over all transitions (also printed).
    """
    env = gym.make(env_name)
    # Loop-invariant: fetch the observation normalizer once, not every step.
    normalizer = diffusion_env.dataset.normalizer['observations']
    action_dim = diffusion_env.action_space.shape[0]
    loss = []
    for _ in tqdm(range(episodes)):
        state = env.reset()
        done = False
        while not done:
            action = np.float32(np.random.uniform(-1, 1, action_dim))
            new_state, r_reward, done, _ = env.step(action)
            norm_state = normalizer(state)
            norm_new_state = normalizer(new_state)
            # NOTE(review): states are normalized but the raw action is passed,
            # whereas `evaluation` normalizes actions before stepping the
            # diffusion env — confirm which convention `reward_test` expects.
            e_reward = diffusion_env.reward_test(
                np.float32(norm_state), action, np.float32(norm_new_state))
            loss.append(np.abs(e_reward - r_reward))
            state = new_state
    mean_loss = np.array(loss).mean()
    print(mean_loss)
    return mean_loss