# -*- coding: utf-8 -*-
import torch
from math import cos, sin, pi
import torch
import math
import matplotlib.pyplot as plt
import numpy as np
# Select the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# NOTE: this turns on autograd anomaly detection for the whole process.
# PyTorch documents it as a debugging mode that slows down execution.
torch.autograd.set_detect_anomaly(True)

class evaluate:
    """Evaluate skill trajectories and log coverage/performance metrics.

    Each call to :meth:`eval` does three things:

    - renders trajectories through ``drawer`` every 100 iterations;
    - logs the metrics ``cgs``, ``entropy``, ``performance``, ``performance2``
      and ``training time`` through ``writer.add_scalar``;
    - stores each metric in a per-iteration numpy buffer and saves that buffer
      as an ``.npy`` file under ``'main experiment/result/'``.
    """

    def __init__(self, env, cfg, drawer, simulater,  writer):
        self.cfg = cfg
        self.env = env
        self.eplen = cfg.ep_len
        self.sknum = cfg.sk_num
        self.drawer = drawer
        self.simulater = simulater
        # Points generated per segment between consecutive angle-sorted skills.
        self.interpolate = cfg.copies
        # Number of evenly spaced reference directions used by per_vector.
        self.new_skillnum = cfg.vnum
        # Divisor in the Gaussian kernel exp(-d^2 / g_var) of gaussian_kde_2d.
        self.g_var = cfg.g_var
        self.writer = writer
        self.state_dim = cfg.state_dim
        self.iter = cfg.train_iter
        # Per-iteration metric buffers, indexed by the iteration passed to eval.
        self.kde = np.zeros(cfg.train_iter)
        self.entropy = np.zeros(cfg.train_iter)
        self.perf = np.zeros(cfg.train_iter)
        self.perf2 = np.zeros(cfg.train_iter)

    def draw(self, points):
        """Forward ``points`` to ``self.drawer.render``."""
        self.drawer.render(points)

    @staticmethod
    def vector_to_polar(X):
        """Convert the first two columns of ``X`` to polar coordinates.

        Returns:
            ``(magnitude, angle_radians)``. The magnitude is
            ``sqrt(x**2 + y**2)`` and the angle is ``atan2(y, x)`` in
            ``(-pi, pi]``.
        """
        magnitude = torch.sqrt(X[:, 0]**2 + X[:, 1]**2)
        angle_radians = torch.atan2(X[:, 1], X[:, 0])
        return magnitude, angle_radians

    @staticmethod
    def interpolate_f(tensor, num):
        """Linearly interpolate ``num`` points between each pair of consecutive rows.

        For rows ``t[i]`` and ``t[i + 1]`` this emits
        ``t[i] + (t[i + 1] - t[i]) * j / num`` for ``j = 0 .. num - 1``.
        The last row of ``tensor`` is not emitted on its own.

        Args:
            tensor: tensor with at least two rows along dim 0.
            num: number of points generated per segment (must be >= 1).

        Returns:
            Tensor of shape ``((len(tensor) - 1) * num, *tensor.shape[1:])``.

        Raises:
            RuntimeError: from ``torch.stack`` when no points are produced,
                i.e. fewer than two rows or ``num == 0``.
        """
        points = []
        for start, end in zip(tensor[:-1], tensor[1:]):
            step = end - start
            # Divide by ``num`` (previously a hard-coded 10) so every segment
            # is sampled evenly for any number of copies.
            points.extend(start + step * (j / num) for j in range(num))
        return torch.stack(points)

    def sorting_idx(self, _total_state, sknum):
        """Sort skills by the polar angle of their final position and interpolate their index vectors.

        ``_total_state`` is transposed on dims 0/1, so the skill axis is
        expected on dim 1 of the input (presumably ``(T, sknum, D)`` — TODO
        confirm). At the last step, the first ``self.state_dim`` features are
        the position. The remaining features are the per-skill index vector.

        Returns:
            The index vectors sorted by angle, closed into a loop by appending
            the first one again, then passed through :meth:`interpolate_f`
            with ``self.interpolate`` points per segment.
        """
        final_step = _total_state.transpose(0, 1)[:sknum, -1]
        coord = final_step[:, :self.state_dim]
        idx_set = final_step[:, self.state_dim:]
        # Vectorised atan2 on the input's own device (no per-skill Python loop).
        angles = torch.atan2(coord[:, 1], coord[:, 0])
        order = torch.sort(angles, 0).indices
        ordered = idx_set[order]
        closed_loop = torch.cat((ordered, ordered[:1]), dim=0)
        return self.interpolate_f(closed_loop, self.interpolate)

    @staticmethod
    def cross_prod(A, B):
        """Compute the pairwise z-component of the 2-D cross product.

        Args:
            A: ``(adim, 2)`` tensor of row vectors.
            B: ``(2, bdim)`` tensor of column vectors.

        Returns:
            ``(adim, bdim)`` tensor ``C`` with
            ``C[i, j] = A[i, 0] * B[1, j] - A[i, 1] * B[0, j]``. This is the
            z-component of ``(A[i], 0) x (B[:, j], 0)``.
        """
        # Closed form: avoids padding to 3-D, torch.cross, and the
        # (adim, 3, bdim) temporaries on a global device.
        return A[:, 0:1] * B[1:2, :] - A[:, 1:2] * B[0:1, :]

    def per_vector(self, _total_state, vnum):
        """For each of ``vnum`` evenly spaced unit directions, select the best-aligned trajectory.

        The final state ``x_i = _total_state[i, -1]`` must be 2-D, because it
        is multiplied with 2-D direction vectors. The score for direction
        ``d`` is ``<x_i, d> - beta * (x_i x d) ** 2`` with ``beta = 10``. It
        rewards distance along ``d`` and penalises perpendicular offset.

        Returns:
            ``_total_state[max_indices]``, one selected trajectory per
            direction, ordered by direction angle ``0, 2*pi/vnum, ...``.
        """
        final = _total_state[:, -1, :]
        ang_array = 2 * torch.pi * torch.arange(vnum, device=final.device) / vnum
        # (2, vnum): column k is the unit vector at angle ang_array[k].
        directions = torch.stack(
            (torch.cos(ang_array), torch.sin(ang_array)), dim=0
        ).to(final.dtype)

        dot_mat = final @ directions
        cross_mat = self.cross_prod(final, directions)
        beta = 10
        obj_mat = dot_mat - beta * (cross_mat**2)
        _, max_indices = torch.max(obj_mat, dim=0)
        return _total_state[max_indices]

    @staticmethod
    def performance(_total_state, sknum):
        """Sum the areas of the triangles (origin, p_i, p_{i+1}) over consecutive points, wrapping around.

        Each term is ``0.5 * r_i * r_{i+1} * |sin(theta_{i+1} - theta_i)|``:

        - ``r`` is the Euclidean norm of each full row;
        - ``theta`` is ``atan2(row[1], row[0])``;
        - ``p_n = p_0``, so the last point connects back to the first.

        When the points are ordered by angle around the origin, this is the
        area of the polygon they enclose.

        ``sknum`` is unused and kept for interface compatibility.
        """
        dist = torch.sqrt(torch.sum(torch.square(_total_state), dim=-1))
        angles = torch.atan2(_total_state[:, 1], _total_state[:, 0])
        next_dist = torch.roll(dist, shifts=-1, dims=0)
        next_angles = torch.roll(angles, shifts=-1, dims=0)
        # |sin(dtheta)| rather than sin(|dtheta|): the latter turns negative
        # when the atan2 angles wrap across +-pi (|dtheta| > pi).
        areas = dist * next_dist * torch.abs(torch.sin(next_angles - angles))
        return torch.sum(areas) / 2

    @staticmethod
    def performance2(_total_state):
        """Return the mean Euclidean norm of the rows of ``_total_state``."""
        dist = torch.sqrt(torch.sum(torch.square(_total_state), dim=-1))
        return torch.mean(dist)

    def gaussian_kde_2d(self, traj):
        """Compute a leave-one-out Gaussian kernel density for every point in ``traj``.

        ``traj`` is transposed on dims 0/1 and flattened to
        ``(-1, self.state_dim)``. Earlier comments suggest an input shape of
        ``(sknum, T, 2)`` (e.g. 17, 11, 2) — TODO confirm. After flattening,
        the first ``sknum`` rows belong to time index 0. Distances among those
        rows are excluded, as are self-distances.

        Returns:
            1-D tensor whose entry ``i`` is
            ``mean_j exp(-d_ij**2 / g_var) / sqrt(pi * g_var)``. Excluded
            pairs contribute 0 but still count in the mean.
        """
        sknum = len(traj)
        points = traj.transpose(0, 1).reshape(-1, self.state_dim)
        n = points.size(0)

        diff = points.unsqueeze(1) - points.unsqueeze(0)
        distance_squared = torch.sum(diff**2, dim=2)

        # Build the mask on the same device as the distances.
        eye = torch.eye(n, dtype=torch.bool, device=distance_squared.device)
        distance_squared[eye] = float('inf')
        distance_squared[:sknum, :sknum] = float('inf')

        kernel = torch.exp(-distance_squared / self.g_var)
        return torch.mean(kernel, dim=-1) / ((torch.pi * self.g_var)**(0.5))

    def _result_path(self, suffix):
        """Return the ``.npy`` path used to save the metric buffer with ``suffix``."""
        return 'main experiment/result/' + self.cfg.algorithm + str(self.cfg.d) + suffix

    def _record(self, tag, buffer, suffix, value, iternal):
        """Log ``value`` under ``tag``, store it in ``buffer[iternal]``, then save ``buffer``."""
        self.writer.add_scalar(tag, value, iternal)
        buffer[iternal] = value
        np.save(self._result_path(suffix), buffer)

    def eval(self, policy, trajlist, iternal, time):
        """Run one evaluation pass and log/save all metrics for iteration ``iternal``.

        Steps:

        1. Sort the skills in ``trajlist`` by final angle and interpolate
           their index vectors.
        2. Roll out ``sknum * interpolate`` trajectories with
           ``self.simulater.make_traj``.
        3. Pick the best trajectory per reference direction.
        4. Compute density, entropy and area metrics on that selection.

        Drawing happens only when ``iternal % 100 == 0``. Every call
        overwrites the ``.npy`` files.

        Args:
            policy: forwarded to ``self.simulater.make_traj``.
            trajlist: trajectory tensor whose dims 0/1 are transposed before
                rendering (see :meth:`sorting_idx`).
            iternal: iteration index used as the logging step and buffer index.
            time: value logged under ``'training time'``.
        """
        should_render = iternal % 100 == 0

        idx = self.sorting_idx(trajlist, self.sknum)
        if should_render:
            self.drawer.render(trajlist.transpose(0, 1)[:, :, :self.state_dim])

        tmp_sknum = self.sknum * self.interpolate
        _, total_state, _ = self.simulater.make_traj(policy, tmp_sknum, idxset=idx)
        renderlist = total_state.transpose(0, 1)[:, :, :self.state_dim]
        if should_render:
            self.drawer.render(renderlist)

        max_state = self.per_vector(renderlist, vnum=self.new_skillnum)
        if should_render:
            self.drawer.render(max_state)

        density = self.gaussian_kde_2d(max_state)
        self._record('cgs', self.kde, '_kde.npy', torch.sum(density).item(), iternal)

        entropy = torch.mean(-torch.log(density))
        self._record('entropy', self.entropy, '_entropy.npy', entropy.item(), iternal)

        # Final step of each selected trajectory.
        m_edge_state = max_state[:, -1, :]
        perf = self.performance(m_edge_state, tmp_sknum)
        self._record('performance', self.perf, '_perf.npy', perf.item(), iternal)

        perf2 = self.performance2(m_edge_state)
        self._record('performance2', self.perf2, '_perf2.npy', perf2.item(), iternal)
        self.writer.add_scalar('training time', time, iternal)


