import bisect
import math

import numpy as np
import torch
from torchvision.transforms import Normalize

from rewarder import cosine_distance, euclidean_distance

class AutomaticDiscountScheduling:
    """Schedule a discount factor from how well the agent's rollouts score.

    ``progress`` is an integer step counter in ``[1, horizon]``.  Each call to
    :meth:`update` may raise it by at most ``max_progress_delta`` while the
    score computed over the first ``progress`` steps passes the active
    criterion.  :meth:`get_discount` maps it to ``alpha ** (1 / progress)``.

    The scoring mode is chosen by two flags that must be set (through
    :meth:`init_encoder` or :meth:`init_demos`) before :meth:`compute_cost`
    or :meth:`update` is used:

    * ``use_kendall``: the mean of the per-step scores must reach
      ``threshold``.
    * ``use_clip``: the median (over rollouts) fraction of positive rewards
      must reach ``CLIP_POSITIVE_PROB_THRESHOLD``.
    * neither: a longest-increasing-subsequence score of the agent/demo cost
      matrices is compared with ``threshold`` times a reference score derived
      from the demonstrations in :meth:`init_demos`.
    """

    # Number of frames sampled from a rollout before it is fed to the encoder
    # in ``use_kendall`` mode.
    KENDALL_NUM_FRAMES = 32
    # Median fraction of positive rewards required to advance in ``use_clip``
    # mode (``threshold`` is not consulted in that mode).
    CLIP_POSITIVE_PROB_THRESHOLD = 0.9

    def __init__(self, horizon, alpha, threshold, progress_start, max_progress_delta, ref_score_percentile, agent_score_percentile, device):
        """Store the schedule hyper-parameters and set the initial progress.

        Args:
            horizon: Upper bound for ``progress``.
            alpha: Base of the discount, ``discount = alpha ** (1 / progress)``.
            threshold: Score level required to advance (its exact use depends
                on the scoring mode, see :meth:`update`).
            progress_start: Initial progress as a fraction of ``horizon``.
            max_progress_delta: Maximum progress increase per :meth:`update`.
            ref_score_percentile: Percentile taken over demo-vs-demo scores
                to build the reference score.
            agent_score_percentile: Percentile taken over agent-vs-demo
                scores in :meth:`update`.
            device: Torch device on which tensors are placed.
        """
        self.horizon = horizon
        self.alpha = alpha
        self.threshold = threshold
        self.progress_start = progress_start
        self.max_progress_delta = max_progress_delta
        self.ref_score_percentile = ref_score_percentile
        self.agent_score_percentile = agent_score_percentile
        self.device = device

        # progress is a divisor in get_discount() and indexes
        # ref_score[progress - 1] in update(), so it must never be 0.
        self.progress = max(1, int(progress_start * horizon))

    def init_encoder(self, cost_encoder, use_kendall=True, use_clip=False, text_feature=None):
        """Configure an encoder-based scoring mode.

        Args:
            cost_encoder: Callable used by :meth:`compute_cost`.
            use_kendall: Enable the Kendall-tau scoring mode.
            use_clip: Enable the CLIP scoring mode.
            text_feature: Passed to ``cost_encoder`` in CLIP mode; only
                stored when ``use_clip`` is true.
        """
        self.cost_encoder = cost_encoder
        self.use_kendall = use_kendall
        self.use_clip = use_clip
        # Per-channel statistics (the usual ImageNet mean/std multiplied by
        # 255), i.e. inputs are presumably 0-255 images -- TODO confirm.
        self.img_norm = Normalize(mean=torch.tensor([123.675, 116.28, 103.53]),
                                    std=torch.tensor([58.395, 57.12, 57.375]))
        if use_clip:
            self.text_feature = text_feature

    def init_demos(self, demos, obs_type, cost_encoder):
        """Register demonstrations and derive the per-step reference score.

        Selects the demo-matching mode (both ``use_kendall`` and ``use_clip``
        are reset to ``False``).  ``demos`` is modified in place: every entry
        becomes a float tensor on ``self.device`` and shorter demos are
        padded by repeating their last step up to the longest demo's length.

        For every ordered pair of distinct demos a cost matrix is built and,
        for each prefix length, the longest increasing subsequence of the
        row-wise arg-min positions is recorded.  ``self.ref_score`` is the
        ``ref_score_percentile`` percentile of those scores per step.  A bar
        chart of it is written to ``ref_score`` via matplotlib.

        Args:
            demos: Mutable sequence of demonstrations (tensors or array-likes).
            obs_type: ``'pixels'`` (cosine distance) or ``'features'``
                (Euclidean distance).
            cost_encoder: Encoder applied to pixel observations in
                :meth:`compute_cost`.

        Raises:
            NotImplementedError: If ``obs_type`` is not supported.
            ValueError: If fewer than two demonstrations are given.
        """
        self.obs_type = obs_type
        self.cost_encoder = cost_encoder
        self.demos = demos
        self.use_kendall = False
        self.use_clip = False

        # Convert before padding so that array/list demos can be padded with
        # tensor operations as well.
        for i, demo in enumerate(demos):
            if not isinstance(demo, torch.Tensor):
                demo = torch.tensor(demo)
            demos[i] = demo.to(self.device).float()

        max_len = max(len(d) for d in demos)
        for i, demo in enumerate(demos):
            current_len = len(demo)
            print(f"current_len: {current_len}, max_len: {max_len}")
            if current_len < max_len:
                # Repeat the last step along dim 0 only, whatever the rank.
                pad = demo[-1:].repeat(max_len - current_len, *([1] * (demo.dim() - 1)))
                demos[i] = torch.cat([demo, pad], dim=0)

        scores = []
        for i, demo1 in enumerate(demos):
            for j, demo2 in enumerate(demos):
                if i == j:
                    continue
                cost_matrix = self._pairwise_cost(demo1, demo2)
                num_steps = cost_matrix.shape[0]
                score = torch.zeros(num_steps)
                for k in range(num_steps):
                    # Arg-min column for each of the first k + 1 rows.
                    pos = cost_matrix[:k + 1, :k + 1].min(1)[1]
                    score[k] = longest_increasing_subsequence(pos)
                scores.append(score.numpy())
        if not scores:
            raise ValueError(
                "init_demos needs at least two demonstrations to compute the reference score"
            )
        scores = np.stack(scores, axis=0)
        self.ref_score = np.percentile(scores, self.ref_score_percentile, axis=0)

        import matplotlib.pyplot as plt
        plt.clf()
        plt.bar(range(self.ref_score.shape[0]), self.ref_score)
        plt.savefig('ref_score')

    def get_discount(self):
        """Return the current discount, ``alpha ** (1 / progress)``."""
        discount = math.exp(math.log(self.alpha) / self.progress)
        return discount

    def compute_cost(self, observation):
        """Score one rollout according to the active mode.

        Args:
            observation: Rollout observations (tensor or array-like); the
                first dimension is indexed as time in every mode.

        Returns:
            * ``use_kendall``: the first batch entry of the encoder's second
              output, resampled back to the rollout length.
            * ``use_clip``: a NumPy array of rewards derived from the
              encoder's values.
            * otherwise: a list with one cost matrix per demonstration.

        Raises:
            NotImplementedError: In demo-matching mode, if ``obs_type`` is
                not supported.
        """
        obs = self._as_float_tensor(observation)
        if self.use_kendall:
            # Keep the last three channels only.
            obs = obs[:, -3:]
            t, _, _, _ = obs.shape
            # Subsample the rollout to a fixed number of frames.
            idx = np.linspace(0, t - 1, self.KENDALL_NUM_FRAMES).astype(int)
            obs = obs[idx]
            obs = self.img_norm(obs)
            with torch.no_grad():
                # NOTE(review): assumes the encoder returns a 3-tuple whose
                # second item holds per-frame scores per batch entry.
                _, kendall_taus, _ = self.cost_encoder(obs.unsqueeze(0))
            kendall_tau = kendall_taus[0]
            # Map the fixed-length scores back onto the original t steps.
            idx = np.linspace(0, self.KENDALL_NUM_FRAMES - 1, t).astype(int)
            kendall_tau = kendall_tau[idx]

            return kendall_tau

        if self.use_clip:
            obs = obs[:, -3:]
            n, tc, h, w = obs.shape
            t = tc // 3
            obs = obs.view(n, t, 3, h, w)
            obs = obs.view(-1, 3, h, w)
            obs = self.img_norm(obs)
            obs = obs.view(n, t, 3, h, w)
            with torch.no_grad():
                original_value = self.cost_encoder(obs, self.text_feature)
            value = original_value[0]
            computed_reward = torch.zeros_like(value)
            for step in range(1, value.shape[0]):
                # Largest i such that 2 ** (i - 1) <= step.
                i = step.bit_length()
                s = 0
                # Weighted sum of the values 1, 2, 4, ... steps back.
                for k in range(1, i):
                    s += value[step - 2 ** (k - 1)] / 2 ** (k - 1)
                computed_reward[step] = ((2 - 1 / 2 ** (i - 1)) * value[step] - s) / i

            return computed_reward.cpu().numpy()

        if self.obs_type == 'pixels':
            with torch.no_grad():
                obs = self.cost_encoder(obs)
        print('len(self.demos):', len(self.demos))
        cost_matrixs = []
        num_steps = obs.shape[0]
        for demo in self.demos:
            exp = self._as_float_tensor(demo)
            if exp.shape[0] < num_steps:
                # Pad the demo by repeating its last step.
                needed = num_steps - exp.shape[0]
                exp = torch.cat([exp, exp[-1:].repeat(needed, *([1] * (exp.dim() - 1)))], dim=0)
            elif exp.shape[0] > num_steps:
                exp = exp[:num_steps]

            assert obs.shape[0] == exp.shape[0]
            assert obs.shape[1:] == exp.shape[1:]
            cost_matrixs.append(self._pairwise_cost(obs, exp))
        return cost_matrixs

    def update(self, cost_matrixs):
        """Advance ``progress`` where the scores allow it and return the discount.

        ``progress`` grows by at most ``max_progress_delta`` and never beyond
        ``horizon``; it stops growing as soon as the criterion of the active
        mode fails.

        Args:
            cost_matrixs: Output(s) of :meth:`compute_cost`.  In Kendall mode
                a per-step score sequence, in CLIP mode an iterable of
                per-step reward sequences, otherwise an iterable (rollouts)
                of iterables (demos) of cost matrices.

        Returns:
            ``(discount, metrics)`` where ``metrics`` holds ``discount``,
            ``discount_log``, ``discount_progress`` and the mode-specific
            scores.
        """
        if self.use_kendall:
            mode_metrics = self._advance_kendall(cost_matrixs)
        elif self.use_clip:
            mode_metrics = self._advance_clip(cost_matrixs)
        else:
            mode_metrics = self._advance_alignment(cost_matrixs)

        discount = self.get_discount()
        metrics = {
            'discount': discount,
            'discount_log': math.log(discount),
            'discount_progress': self.progress,
            **mode_metrics,
        }
        return discount, metrics

    def _as_float_tensor(self, data):
        """Return a detached float copy of ``data`` on ``self.device``."""
        if isinstance(data, torch.Tensor):
            # Same copy semantics as torch.tensor(tensor), without its warning.
            tensor = data.detach().clone()
        else:
            tensor = torch.tensor(data)
        return tensor.to(self.device).float()

    def _pairwise_cost(self, first, second):
        """Return the cost matrix between two sequences for ``self.obs_type``."""
        if self.obs_type == 'pixels':
            return cosine_distance(first, second)
        if self.obs_type == 'features':
            return euclidean_distance(first, second)
        raise NotImplementedError

    def _kendall_match_score(self, scores):
        """Mean of the first ``progress`` per-step scores (0-dim tensor)."""
        return torch.as_tensor(scores[:self.progress]).float().mean()

    def _clip_median_positive_prob(self, rewards_per_rollout):
        """Median over rollouts of the share of positive rewards up to ``progress``."""
        positive_probs = [
            (torch.as_tensor(rewards[:self.progress]) > 0).float().mean().item()
            for rewards in rewards_per_rollout
        ]
        return torch.tensor(positive_probs).median()

    def _alignment_scores(self, cost_matrixs):
        """Return ``(match_score, ref_score)`` for the current ``progress``."""
        match_scores = []
        for costs in cost_matrixs:
            for cost in costs:
                pos = cost[:self.progress, :self.progress].argmin(1)
                match_scores.append(longest_increasing_subsequence(pos))
        match_score = np.percentile(match_scores, self.agent_score_percentile)
        # NOTE(review): raises IndexError if progress exceeds the padded demo
        # length that ref_score was built from -- confirm horizon <= that length.
        ref_score = self.ref_score[self.progress - 1]
        return match_score, ref_score

    def _advance_kendall(self, cost_matrixs):
        """Advance ``progress`` in Kendall mode; return the mode metrics."""
        match_score = None
        for _ in range(self.max_progress_delta):
            match_score = self._kendall_match_score(cost_matrixs)
            # Once the check fails nothing changes any more, so stop early.
            # `not (a >= b)` keeps NaN scores from advancing progress.
            if self.progress >= self.horizon or not (match_score >= self.threshold):
                break
            self.progress += 1
        if match_score is None:  # max_progress_delta <= 0: still report a score
            match_score = self._kendall_match_score(cost_matrixs)
        return {'match_score': match_score}

    def _advance_clip(self, cost_matrixs):
        """Advance ``progress`` in CLIP mode; return the mode metrics."""
        median_positive_prob = None
        for _ in range(self.max_progress_delta):
            median_positive_prob = self._clip_median_positive_prob(cost_matrixs)
            if self.progress >= self.horizon:
                break
            if not (median_positive_prob >= self.CLIP_POSITIVE_PROB_THRESHOLD):
                break
            self.progress += 1
        if median_positive_prob is None:  # max_progress_delta <= 0
            median_positive_prob = self._clip_median_positive_prob(cost_matrixs)
        return {'median_positive_prob': median_positive_prob}

    def _advance_alignment(self, cost_matrixs):
        """Advance ``progress`` in demo-matching mode; return the mode metrics."""
        scores = None
        for _ in range(self.max_progress_delta):
            scores = self._alignment_scores(cost_matrixs)
            match_score, ref_score = scores
            if self.progress >= self.horizon:
                break
            if not (match_score >= int(self.threshold * ref_score)):
                break
            self.progress += 1
        if scores is None:  # max_progress_delta <= 0
            scores = self._alignment_scores(cost_matrixs)
        match_score, ref_score = scores
        return {'match_score': match_score, 'ref_score': ref_score}
            

def longest_increasing_subsequence(a):
    """Return the length of the longest strictly increasing subsequence.

    Runs in O(n log n): ``tails[k]`` is kept as the smallest value that can
    end an increasing subsequence of length ``k + 1``, and each new value is
    placed by binary search.

    Args:
        a: One-dimensional sequence of comparable numbers (list, NumPy array
            or torch tensor).

    Returns:
        The subsequence length as an ``int``; ``0`` for empty input.
    """
    # Work on plain Python numbers; element-wise tensor comparisons are slow.
    values = a.tolist() if hasattr(a, "tolist") else list(a)
    tails = []
    for value in values:
        # bisect_left makes equal values replace rather than extend, which
        # keeps the subsequence strictly increasing.
        slot = bisect.bisect_left(tails, value)
        if slot == len(tails):
            tails.append(value)
        else:
            tails[slot] = value
    return len(tails)
