import warnings
import os
from copy import deepcopy

# Suppress every warning category process-wide. This runs before the
# third-party imports below, so warnings raised while importing them are
# hidden as well.
warnings.filterwarnings("ignore")

import numpy as np
import torch
from tqdm import tqdm
import wandb

from maml_rl.buffer import MultiTaskBuffer
from maml_rl.anil_optimizer import DifferentiableSGD
from maml_rl.sampler import Sampler


class MetaLearner:
    def __init__(self,
                 env,
                 train_env,
                 agent,
                 observ_dim,
                 action_dim,
                 train_tasks,
                 device,
                 debug=False,
                 **config):
        """Build the sampler, trajectory buffer and inner-loop optimizer.

        Args:
            env: Environment object handed to the ``Sampler`` and stored on
                the instance.
            train_env: Value interpolated into the wandb run name; also
                stored on the instance.
            agent: Agent object; its ``policy`` attribute is given to the
                inner-loop optimizer.
            observ_dim: Forwarded to ``MultiTaskBuffer``.
            action_dim: Forwarded to ``Sampler`` and ``MultiTaskBuffer``.
            train_tasks: Stored as ``self.train_tasks``.
            device: Forwarded to ``Sampler`` and ``MultiTaskBuffer``.
            debug: When truthy, no wandb run is started.
            **config: Must contain ``num_iterations``, ``meta_batch_size``,
                ``num_samples``, ``max_steps``, ``num_adapt_epochs``,
                ``backtrack_iters``, ``backtrack_coeff``, ``max_kl`` and
                ``inner_learning_rate``.

        Raises:
            KeyError: If any of the required ``config`` entries is missing.
        """
        self.env = env
        self.train_env = train_env
        self.agent = agent
        self.train_tasks = train_tasks
        self.debug = debug

        # Outer-loop schedule and sampling sizes.
        self.num_iterations = config["num_iterations"]
        self.meta_batch_size = config["meta_batch_size"]
        self.num_samples = config["num_samples"]
        self.max_steps = config["max_steps"]

        # Inner-loop length and line-search / trust-region settings.
        self.num_adapt_epochs = config["num_adapt_epochs"]
        self.backtrack_iters = config["backtrack_iters"]
        self.backtrack_coeff = config["backtrack_coeff"]
        self.max_kl = config["max_kl"]

        self.sampler = Sampler(
            env=env,
            agent=agent,
            action_dim=action_dim,
            max_step=self.max_steps,
            device=device,
        )

        # One buffer slot per adaptation round plus one for the validation
        # rollout collected after the last adaptation step.
        episodes_per_task = self.num_adapt_epochs + 1
        self.buffers = MultiTaskBuffer(
            observ_dim=observ_dim,
            action_dim=action_dim,
            agent=agent,
            num_tasks=self.meta_batch_size,
            num_episodes=episodes_per_task,
            max_size=self.num_samples,
            device=device,
        )

        self.inner_optimizer = DifferentiableSGD(
            self.agent.policy,
            lr=config["inner_learning_rate"],
        )

        # Best post-adaptation return seen so far.
        self.best_return = -np.inf

        if not self.debug:
            wandb.init(project='Meta RL', name=f'ANIL meta training({train_env})')

    def collect_train_data(self, indices, is_eval=False):
        """Sample adaptation and validation trajectories for each task.

        For every entry of ``indices`` the environment is switched to that
        task, ``num_adapt_epochs + 1`` batches of trajectories are sampled and
        stored in ``self.buffers``, and after each batch except the last one
        an inner-loop optimizer step is applied to the policy. The parameters
        after the last step are stored via ``buffers.add_params`` and the
        policy is then reset to the parameters it had on entry.

        Args:
            indices: Iterable of task identifiers passed to
                ``env.reset_task``; the position in the iterable is used as
                the buffer slot.
            is_eval: When truthy, the final (validation) batch of each task is
                sampled with ``policy.is_deterministic`` set to ``True``.
        """
        # Snapshot taken once so every task starts from the same parameters.
        saved_params = {
            name: param.detach().clone()
            for name, param in self.agent.policy.named_parameters()
        }

        phase = "test" if is_eval else "train"
        print(f"Collecting samples for meta-{phase}")

        final_round = self.num_adapt_epochs
        for slot, task_id in enumerate(tqdm(indices)):
            self.env.reset_task(task_id)

            for round_idx in range(final_round + 1):
                is_validation_round = round_idx == final_round

                # Only the validation rollout of an evaluation run is deterministic.
                self.agent.policy.is_deterministic = bool(is_validation_round and is_eval)

                # Sample self.num_samples worth of trajectories and store them.
                rollouts = self.sampler.obtain_samples(max_samples=self.num_samples)
                self.buffers.add_trajs(slot, round_idx, rollouts)

                # The last batch is validation data for the outer loop and
                # is never used for an inner update.
                if is_validation_round:
                    continue

                adapt_batch = self.buffers.get_trajs(slot, round_idx)
                adapt_loss = self.agent.policy_loss(adapt_batch)
                self.inner_optimizer.zero_grad(set_to_none=True)

                # Keep the graph (for higher-order gradients) on every
                # inner step except the last one.
                keep_graph = round_idx < final_round - 1
                adapt_loss.backward(create_graph=keep_graph)
                with torch.set_grad_enabled(keep_graph):
                    self.inner_optimizer.step()

            # Record the adapted parameters for this task.
            self.buffers.add_params(
                slot,
                final_round,
                dict(self.agent.policy.named_parameters()),
            )

            # Undo the inner-loop updates before moving to the next task.
            self.agent.update_model(self.agent.policy, saved_params)
            self.agent.policy.is_deterministic = False

    def meta_surrogate_loss(self, set_grad, policy_params_for_eval=None, meta_init_params=None):
        """Compute the task-averaged meta loss, KL and entropy on validation data.

        For each of the ``meta_batch_size`` buffered tasks:

        1. the policy is reset to ``meta_init_params`` and the inner loop is
           replayed on the stored adaptation batches;
        2. the adapted parameters are copied (detached) into
           ``agent.old_policy``;
        3. the policy is set to ``policy_params_for_eval`` when given,
           otherwise back to ``meta_init_params``;
        4. loss, KL and entropy are evaluated on the task's validation batch.

        Replaying the inner loop from ``meta_init_params`` while evaluating
        the supplied candidate lets the loss/KL respond to the candidate
        during the line search. The policy parameters found on entry are
        restored before returning.

        Args:
            set_grad: When truthy, every inner step is run with
                ``create_graph=True``; otherwise the last inner step is not.
            policy_params_for_eval: Optional name -> tensor mapping loaded
                into the policy for the validation evaluation.
            meta_init_params: Optional name -> tensor mapping used as the
                inner-loop starting point; defaults to a detached copy of the
                current policy parameters.

        Returns:
            Tuple ``(loss, kl, entropy)`` of per-task values stacked and
            averaged.
        """
        def snapshot():
            # Detached copy of the policy's current parameters.
            return {
                name: param.detach().clone()
                for name, param in self.agent.policy.named_parameters()
            }

        # Fall back to the current parameters as the meta-initialisation.
        if meta_init_params is None:
            meta_init_params = snapshot()

        # Parameters to put back once all tasks are processed.
        entry_params = snapshot()

        # Evaluate the candidate when one is supplied, else the meta-init.
        eval_params = meta_init_params if policy_params_for_eval is None else policy_params_for_eval

        task_losses = []
        task_kls = []
        task_entropies = []

        for task_slot in range(self.meta_batch_size):
            # Always replay the inner loop from the meta-initialisation.
            self.agent.update_model(self.agent.policy, meta_init_params)

            for step_idx in range(self.num_adapt_epochs):
                keep_graph = step_idx < self.num_adapt_epochs - 1 or set_grad
                adapt_batch = self.buffers.get_trajs(task_slot, step_idx)

                # Inner loss uses policy_loss's default mode (is_meta_loss
                # left unset) -- presumably the ANIL head-only objective;
                # confirm against the agent implementation.
                adapt_loss = self.agent.policy_loss(adapt_batch)
                self.inner_optimizer.zero_grad(set_to_none=True)
                adapt_loss.backward(create_graph=keep_graph)
                with torch.set_grad_enabled(keep_graph):
                    self.inner_optimizer.step()

            # Freeze the adapted parameters into old_policy.
            adapted_params = snapshot()
            self.agent.update_model(self.agent.old_policy, adapted_params)

            # Load the parameters that are to be evaluated.
            self.agent.update_model(self.agent.policy, eval_params)

            # Meta loss / KL / entropy on the validation batch.
            held_out = self.buffers.get_trajs(task_slot, self.num_adapt_epochs)
            task_losses.append(self.agent.policy_loss(held_out, is_meta_loss=True))
            task_kls.append(self.agent.kl_divergence(held_out))
            task_entropies.append(self.agent.compute_policy_entropy(held_out))

        # Leave the policy exactly as it was when the method was entered.
        self.agent.update_model(self.agent.policy, entry_params)

        mean_loss = torch.stack(task_losses).mean()
        mean_kl = torch.stack(task_kls).mean()
        mean_entropy = torch.stack(task_entropies).mean()
        return mean_loss, mean_kl, mean_entropy




    def meta_update(self):
        """Apply one trust-region meta-update with a backtracking line search.

        The surrogate loss and KL are evaluated at the current parameters, a
        descent step is derived from their gradient / Hessian-vector product
        through the agent's ``conjugate_gradient`` and
        ``compute_descent_step`` helpers, and the step is then tried with
        geometrically shrinking scale ``backtrack_coeff ** i``. The first
        candidate that both lowers the loss and keeps the KL within
        ``max_kl`` is kept; if none qualifies the policy is reset to the
        parameters it had on entry. The trajectory buffers are cleared before
        returning.

        Returns:
            dict with float entries ``loss_after``, ``kl_after`` and
            ``policy_entropy``. They describe the last candidate that was
            evaluated, or the pre-update policy when ``backtrack_iters`` is 0
            and no candidate was evaluated at all.
        """
        # Snapshot of the meta-initialisation; every candidate starts from it.
        meta_init_params = {k: p.detach().clone() for k, p in self.agent.policy.named_parameters()}

        # 1) Surrogate loss / KL at the current parameters, with the graph
        #    kept so that gradients can be taken below.
        loss_before, kl_before, entropy_before = self.meta_surrogate_loss(
            set_grad=True,
            policy_params_for_eval=None,
            meta_init_params=meta_init_params,
        )

        # 2) Descent step from the loss gradient and the KL curvature.
        gradient = torch.autograd.grad(loss_before, self.agent.policy.parameters(), retain_graph=True)
        gradient = self.agent.flat_grad(gradient)
        hvp = self.agent.hessian_vector_product(kl_before, self.agent.policy.parameters())

        search_dir = self.agent.conjugate_gradient(hvp, gradient)
        descent_step = self.agent.compute_descent_step(hvp, search_dir, self.max_kl)
        loss_before = loss_before.detach()

        # Metrics reported when the line search evaluates no candidate
        # (backtrack_iters == 0); without them the return below would fail
        # on unbound names.
        loss_after = loss_before
        kl_after = kl_before.detach()
        policy_entropy = entropy_before.detach()

        # 3) Backtracking line search, always restarting from the snapshot.
        for i in range(self.backtrack_iters):
            ratio = self.backtrack_coeff ** i

            self.agent.update_model(self.agent.policy, meta_init_params)
            for params, step in zip(self.agent.policy.parameters(), descent_step):
                params.data.add_(step, alpha=-ratio)  # move against the step to minimise the loss

            # Candidate parameters after the scaled step.
            candidate_params = {k: p.detach().clone() for k, p in self.agent.policy.named_parameters()}

            # Inner loop is replayed from the snapshot; only the evaluation
            # uses the candidate.
            loss_after, kl_after, policy_entropy = self.meta_surrogate_loss(
                set_grad=False,
                policy_params_for_eval=candidate_params,
                meta_init_params=meta_init_params,
            )

            # Plain bools so the log shows True/False rather than tensor(...).
            is_improved = bool(loss_after < loss_before)
            is_constrained = bool(kl_after <= self.max_kl)
            print(f"{i}-Backtracks | Loss {loss_after:.4f} < Loss_old {loss_before:.4f} : {is_improved} | "
                f"KL {kl_after:.4f} <= maxKL {self.max_kl:.4f} : {is_constrained}")

            if is_improved and is_constrained:
                print(f"Update meta-policy through {i+1} backtracking line search step(s)")
                # meta_surrogate_loss restores the parameters it was entered
                # with, i.e. the candidate, so the policy already holds it.
                break
        else:
            # No candidate was accepted: keep the original parameters.
            print("Keep current meta-policy skipping meta-update")
            self.agent.update_model(self.agent.policy, meta_init_params)

        self.buffers.clear()
        return dict(
            loss_after=loss_after.item(),
            kl_after=kl_after.item(),
            policy_entropy=policy_entropy.item(),
        )


    def _clone_policy_params(self):
        """Return a name -> tensor dict of detached copies of the policy parameters.

        The copies do not share storage with the live parameters, so later
        in-place updates of the policy leave the snapshot untouched.
        """
        snapshot = {}
        for name, param in self.agent.policy.named_parameters():
            snapshot[name] = param.detach().clone()
        return snapshot


    def meta_train(self):
        """Run the outer meta-training loop and checkpoint the best policy.

        Each of the ``num_iterations`` iterations draws ``meta_batch_size``
        task indices uniformly (with replacement) from
        ``range(len(self.train_tasks))``, collects adaptation/validation
        data, performs one meta-update and evaluates with ``meta_test``.
        Whenever the evaluation return exceeds ``self.best_return`` the
        policy ``state_dict`` is saved under
        ``./anil_policy/<train_env>/anil_policy(<iteration>).pt``.
        """
        save_dir = f'./anil_policy/{self.train_env}'

        for iteration in range(self.num_iterations):
            print(f"\n=============== Iteration {iteration} ===============")
            indices = np.random.randint(len(self.train_tasks), size=self.meta_batch_size)
            self.collect_train_data(indices)
            self.meta_update()
            return_after_grad = self.meta_test(iteration)

            if return_after_grad > self.best_return:
                self.best_return = return_after_grad
                # exist_ok avoids the check-then-create race of a separate
                # os.path.exists() test.
                os.makedirs(save_dir, exist_ok=True)
                torch.save(self.agent.policy.state_dict(), f'{save_dir}/anil_policy({iteration}).pt')


    def meta_test(self, iter):
        """Evaluate returns before and after adaptation on ``self.train_tasks``.

        Data is collected for every entry of ``self.train_tasks`` in
        evaluation mode. For each task the rewards of the first
        ``max_steps`` stored transitions are summed, once for the batch
        gathered before any inner update and once for the batch gathered
        after the last one. Per-task and mean returns are logged to wandb
        unless ``self.debug`` is set, and the buffers are cleared.

        Args:
            iter: Value passed as ``step`` to ``wandb.log``.

        Returns:
            Mean post-adaptation return over the tasks.
        """
        self.collect_train_data(np.array(self.train_tasks), is_eval=True)

        num_tasks = len(self.train_tasks)
        log_to_wandb = not self.debug
        pre_adapt_returns = []
        post_adapt_returns = []

        for task_slot in range(num_tasks):
            pre_batch = self.buffers.get_trajs(task_slot, 0)
            post_batch = self.buffers.get_trajs(task_slot, self.num_adapt_epochs)

            # Only the first max_steps rewards of each batch are counted.
            pre_rewards = pre_batch["rewards"][: self.max_steps]
            post_rewards = post_batch["rewards"][: self.max_steps]
            pre_return = torch.sum(pre_rewards).item()
            post_return = torch.sum(post_rewards).item()
            pre_adapt_returns.append(pre_return)
            post_adapt_returns.append(post_return)

            if log_to_wandb:
                wandb.log({f"Task {task_slot} returns_b": pre_return}, step=iter)
                wandb.log({f"Task {task_slot} returns_a": post_return}, step=iter)

        self.buffers.clear()

        mean_pre_return = sum(pre_adapt_returns) / num_tasks
        mean_post_return = sum(post_adapt_returns) / num_tasks
        if log_to_wandb:
            wandb.log({"Task mean return_b": mean_pre_return}, step=iter)
            wandb.log({"Task mean return_a": mean_post_return}, step=iter)
        return mean_post_return
        