import warnings
import os
from copy import deepcopy

warnings.filterwarnings("ignore")

import numpy as np
import torch
from tqdm import tqdm
import wandb

from maml_rl.buffer import MultiTaskBuffer
from maml_rl.optimizer import DifferentiableSGD
from maml_rl.sampler import Sampler


class MetaLearner:
    def __init__(self,
                 env,
                 train_env,
                 agent,
                 observ_dim,
                 action_dim,
                 train_tasks,
                 device,
                 debug=False,
                 **config):
        
        self.env = env
        self.train_env = train_env
        self.agent = agent
        self.train_tasks = train_tasks
        self.debug = debug

        self.num_iterations = config["num_iterations"]
        self.meta_batch_size = config["meta_batch_size"]
        self.num_samples = config["num_samples"]
        self.max_steps = config["max_steps"]

        self.num_adapt_epochs = config["num_adapt_epochs"]
        self.backtrack_iters = config["backtrack_iters"]
        self.backtrack_coeff = config["backtrack_coeff"]
        self.max_kl = config["max_kl"]

        self.sampler = Sampler(
            env=env,
            agent=agent,
            action_dim=action_dim,
            max_step=config["max_steps"],
            device=device,
        )

        self.buffers = MultiTaskBuffer(
            observ_dim=observ_dim,
            action_dim=action_dim,
            agent=agent,
            num_tasks=self.meta_batch_size,
            num_episodes=(self.num_adapt_epochs + 1),  # [num of adapatation for train] + [validation]
            max_size=self.num_samples,
            device=device
        )

        self.inner_optimizer = DifferentiableSGD(
            self.agent.policy,
            lr=config["inner_learning_rate"],
        )
        self.best_return = -np.inf
        if self.debug:
            pass
        else:
            wandb.init(project = 'Meta RL',name = f'MAML meta training({train_env})')

    def collect_train_data(self, indices, is_eval= False):
        # inner update로 변경된 parameter를 원래대로 되돌리기 위해, 현재 parameter를 저장
        backup_params = dict(self.agent.policy.named_parameters())

        mode = "test" if is_eval else "train"
        print(f"Collecting samples for meta-{mode}")
        for cur_task, task_index in enumerate(tqdm(indices)):

            self.env.reset_task(task_index)
            for cur_adapt in range(self.num_adapt_epochs + 1):

                # 메타-테스트에 대해서는 deterministic한 정책으로 경로 생성
                self.agent.policy.is_deterministic = (
                    True if cur_adapt == self.num_adapt_epochs and is_eval else False
                )
                # self.num_samples만큼 trajectory를 샘플링
                trajs = self.sampler.obtain_samples(max_samples=self.num_samples)
                # sampling된 trajectory를 buffer에 저장
                self.buffers.add_trajs(cur_task, cur_adapt, trajs)

                # inner_loop의 마지막 step이 아닌경우 적용 (마지막 step에서 모은 traj는 업데이트에 사용하지 않음->outer loop 업데이트를 위한 validation sample)
                if cur_adapt < self.num_adapt_epochs:
                    train_batch = self.buffers.get_trajs(cur_task, cur_adapt)

                    inner_loss = self.agent.policy_loss(train_batch)
                    self.inner_optimizer.zero_grad(set_to_none=True)
                    require_grad = cur_adapt < self.num_adapt_epochs - 1
                    # inner loop의 마지막 전까지는 higher-order gradient를 계산할 수 있도록 그래프 유지
                    inner_loss.backward(create_graph=require_grad)

                    with torch.set_grad_enabled(require_grad):
                        self.inner_optimizer.step()
            # adaptation 후 각 task에 대한 파라미터 저장
            self.buffers.add_params(
                cur_task,
                self.num_adapt_epochs,
                dict(self.agent.policy.named_parameters()),
            )

            # 원래 parameter로 되돌리기.
            self.agent.update_model(self.agent.policy, backup_params)
            self.agent.policy.is_deterministic = False

    def meta_surrogate_loss(self, set_grad):
        losses, kls, entropies = [], [], []
        
        # meta update 전 원래 parameter를 보관
        backup_params = dict(self.agent.policy.named_parameters())

        for cur_task in range(self.meta_batch_size):
            for cur_adapt in range(self.num_adapt_epochs):
                # require_grad = True일 동안에는 higer-order gradient를 계산할 수 있도록 그래프를 유지
                require_grad = cur_adapt < self.num_adapt_epochs - 1 or set_grad
                train_batch = self.buffers.get_trajs(cur_task, cur_adapt)

                # inner_loss는 단순한 reinforce 스타일의 -Advantage*log_prob loss
                inner_loss = self.agent.policy_loss(train_batch)
                self.inner_optimizer.zero_grad(set_to_none=True)
                inner_loss.backward(create_graph=require_grad)

                # set_grad 인자에 따라 마지막 스텝에만 상위 그래디언트(그래프)를 남길지 결정
                with torch.set_grad_enabled(require_grad):
                    self.inner_optimizer.step()

            # Surrogate 손실 계산을 위해 line search 초기 정책으로 초기화
            # 해당 task에 해당하는 inner loop update된 parameter 불러오기 (collect_train_data에서)
            valid_params = self.buffers.get_params(cur_task, self.num_adapt_epochs)
            # inner loop update된 parameter로 agent.oldpolicy 업데이트
            # 여기서 old_policy는 train set으로만 학습된 policy이다.
            self.agent.update_model(self.agent.old_policy, valid_params)

            # 메타-배치 태스크에 대한 메타러닝 손실로서 평가경로의 surrogage 손실 계산
            # 마지막 inner_loop에서의 batch sampling
            valid_batch = self.buffers.get_trajs(cur_task, self.num_adapt_epochs)
            # old_log_prob과 current_log_prob의 비율을 이용해 advantage를 곱하는 방식의 loss
            # 이 때, policy는 valid set으로 한번 더 update되며, old_policy는 \theta_i이고
            # 여기서 구해지는 loss는 \theta_i를 통해 valid set에 대해 구한 loss이다.
            loss = self.agent.policy_loss(valid_batch, is_meta_loss=True)
            losses.append(loss)

            # 배치 태스크에 대한 평가 경로의 평균 KL divergence 계산
            # old_policy와 current_policy사이의 kl divergence 계산
            kl = self.agent.kl_divergence(valid_batch)
            kls.append(kl)

            # 배치 태스크에 대한 평가 경로의 평균 정책 엔트로피 계산
            entropy = self.agent.compute_policy_entropy(valid_batch)
            entropies.append(entropy)

            self.agent.update_model(self.agent.policy, backup_params)
        return torch.stack(losses).mean(), torch.stack(kls).mean(), torch.stack(entropies).mean()

    def meta_update(self):
        # 외부 루프 (outer loop)
        # Line search를 시작하기 위한 첫 경사하강 스텝 계산
        loss_before, kl_before, _ = self.meta_surrogate_loss(set_grad=True)

        # meta_loss를 이용해서 meta_parameter를 업데이트 하기 위한 gradient 계산
        gradient = torch.autograd.grad(loss_before, self.agent.policy.parameters(), retain_graph=True)
        # gradient를 펴서 한 줄로 만들기
        gradient = self.agent.flat_grad(gradient)
        Hvp = self.agent.hessian_vector_product(kl_before, self.agent.policy.parameters())

        search_dir = self.agent.conjugate_gradient(Hvp, gradient)
        descent_step = self.agent.compute_descent_step(Hvp, search_dir, self.max_kl)
        loss_before.detach_()

        # Line search 역추적을 통한 파라미터 업데이트
        backup_params = deepcopy(dict(self.agent.policy.named_parameters()))
        for i in range(self.backtrack_iters):
            ratio = self.backtrack_coeff**i

            for params, step in zip(self.agent.policy.parameters(), descent_step):
                params.data.add_(step, alpha=-ratio)

            loss_after, kl_after, policy_entropy = self.meta_surrogate_loss(set_grad=False)
            print(kl_after)
            # KL 제약조건을 만족할 경우 정책 업데이트
            is_improved = loss_after < loss_before
            is_constrained = kl_after <= self.max_kl
            print(f"{i}-Backtracks | Loss {loss_after:.4f} < Loss_old {loss_before:.4f} : ", end="")
            print(f"{is_improved} | KL {kl_after:.4f} <= maxKL {self.max_kl:.4f} : {is_constrained}")

            if is_improved and is_constrained:
                print(f"Update meta-policy through {i+1} backtracking line search step(s)")
                break

            self.agent.update_model(self.agent.policy, backup_params)

            if i == self.backtrack_iters - 1:
                print("Keep current meta-policy skipping meta-update")

        self.buffers.clear()
        return dict(
            loss_after=loss_after.item(),
            kl_after=kl_after.item(),
            policy_entropy=policy_entropy.item(),
        )

    def meta_train(self):
        for iteration in range(self.num_iterations):
            print(f"\n=============== Iteration {iteration} ===============")
            indices = np.random.randint(len(self.train_tasks), size=self.meta_batch_size)
            self.collect_train_data(indices)
            self.meta_update()
            return_after_grad = self.meta_test(iteration)
            if not os.path.exists(f'./maml_policy/{self.train_env}'):
                os.makedirs(f'./maml_policy/{self.train_env}')
            if return_after_grad > self.best_return:
                self.best_return = return_after_grad
                torch.save(self.agent.policy.state_dict(),f'./maml_policy/{self.train_env}/maml_policy({iteration}).pt')


    def meta_test(self,iter):

        returns_before_grad = []
        returns_after_grad = []

        self.collect_train_data(np.array(self.train_tasks), is_eval=True)

        for task in range(len(self.train_tasks)):
            batch_before_grad = self.buffers.get_trajs(task, 0)
            batch_after_grad = self.buffers.get_trajs(task, self.num_adapt_epochs)

            rewards_before_grad = batch_before_grad["rewards"][: self.max_steps]
            rewards_after_grad = batch_after_grad["rewards"][: self.max_steps]
            returns_before_grad.append(torch.sum(rewards_before_grad).item())
            returns_after_grad.append(torch.sum(rewards_after_grad).item())
            wandb.log({f"Task {task} returns_b":returns_before_grad[-1]},step=iter)
            wandb.log({f"Task {task} returns_a":returns_after_grad[-1]},step=iter)

        self.buffers.clear()

        return_before_grad = sum(returns_before_grad) / len(self.train_tasks)
        return_after_grad = sum(returns_after_grad) / len(self.train_tasks)

        wandb.log({f"Task mean return_b":return_before_grad},step=iter)
        wandb.log({f"Task mean return_a":return_after_grad},step=iter)
        return return_after_grad
        