#import ray
import wandb
import os

from agent.workers.DreamerWorker import DreamerWorker
import numpy as np, random
import torch
import matplotlib.pyplot as plt


def plot_rewards(rewards, q, label):
    """Plot the moving average of ``rewards`` over windows of ``q`` entries.

    Each plotted point is ``np.sum(np.array(rewards[j:j + q])) / q``.
    Every full window is included, so a sequence of length ``n`` yields
    ``n - q + 1`` points, or none when ``n < q``. The line is drawn on the
    current matplotlib figure; adding a legend, saving and closing are left
    to the caller.

    Args:
        rewards: Sequence of per-episode rewards (anything ``np.array`` and
            ``np.sum`` accept per window).
        q: Window size; must be a positive integer.
        label: Legend label for the plotted line.

    Returns:
        The list of window averages that was plotted.

    Raises:
        ValueError: If ``q`` is smaller than 1.
    """
    if q < 1:
        raise ValueError(f"window size q must be >= 1, got {q}")

    # range(len - q + 1) covers every full window, including the last one
    # ending at the newest reward; it is empty when len(rewards) < q.
    avg_rew = [
        np.sum(np.array(rewards[j:j + q])) / q
        for j in range(len(rewards) - q + 1)
    ]

    plt.plot(avg_rew, label=label)
    return avg_rew


class DreamerRunner:
    """Alternating two-agent training loop around one learner and one worker.

    Each episode, one agent (id 0 or 1) is frozen via ``learner.fix_agent``
    while the learner is stepped on the rollout for the other agent; the
    trained agent alternates every episode. Metrics are logged to wandb and
    also written to ``.npy`` files and ``mamba.png`` in the current working
    directory.
    """

    def __init__(self, env_config, learner_config, controller_config, n_workers):
        # n_workers is stored, but only a single DreamerWorker (id 1) is created.
        self.n_workers = n_workers
        self.learner = learner_config.create_learner(env_config)
        self.worker = DreamerWorker(1, env_config, controller_config)

    def _train_one_episode(self, train_agent_id):
        """Run one rollout and update ``train_agent_id`` with the other agent frozen.

        Args:
            train_agent_id: Id of the agent to train (0 or 1).

        Returns:
            The ``info`` dict returned by ``self.worker.run``.
        """
        frozen_agent_id = 1 if train_agent_id == 0 else 0
        self.learner.fix_agent(frozen_agent_id)
        try:
            rollouts, info = self.worker.run(self.learner.models, self.learner.actors)
            self.learner.step(rollouts, train_agent_id)
        finally:
            # Unfreeze even when the rollout or update raises, so an error does
            # not leave the other agent permanently fixed.
            self.learner.unfix_agent(frozen_agent_id)
        return info

    @staticmethod
    def _save_reward_plot(stats):
        """Write the 2-episode moving average of ``stats`` to ``mamba.png``."""
        plot_rewards(stats, 2, 'Reward')
        plt.legend()
        plt.xlabel('Episodes')
        plt.ylabel('Episode_Rewards')
        plt.savefig('mamba.png')
        plt.close()

    def run(self, max_steps=10 ** 10, max_episodes=10 ** 10, log_interval=20):
        """Train until ``max_episodes`` episodes or ``max_steps`` env steps.

        Args:
            max_steps: Stop once the cumulative ``info["steps_done"]`` reaches
                this value.
            max_episodes: Stop once this many episodes have been run.
            log_interval: Episodes per win-rate window. The win rate
                (wins / log_interval) is logged once per window. Defaults to
                20, the previously hard-coded value.

        Raises:
            ValueError: If ``log_interval`` is smaller than 1.
        """
        if log_interval < 1:
            raise ValueError(f"log_interval must be >= 1, got {log_interval}")

        cur_steps, cur_episode = 0, 0
        stats = []  # info["reward"] of every episode, in order
        win_count = 0  # wins inside the current log_interval window
        win_rate_at_each_step = []  # (cur_steps, win_rate) per completed window
        rewards_at_each_step = []  # (cur_steps, info["reward"]) per episode
        total_rewards_at_each_step = []  # (cur_steps, info["total_rewards"]) per episode

        # win_rate has no step_metric, so wandb plots it on its default step axis;
        # the per-episode reward metrics are plotted against cumulative env steps.
        wandb.define_metric("win_rate")
        wandb.define_metric("steps")
        wandb.define_metric("reward", step_metric="steps")
        wandb.define_metric("aver_step_rewards", step_metric="steps")
        wandb.define_metric("total_rewards", step_metric="steps")

        train_agent_id = 0
        while True:
            info = self._train_one_episode(train_agent_id)
            # Alternate the trained agent: 0, 1, 0, 1, ...
            train_agent_id = (train_agent_id + 1) % 2

            print("win_flag", info["win_flag"])

            cur_steps += info["steps_done"]
            win_count += info["win_flag"]  # win: 1, lose: 0
            cur_episode += 1

            if cur_episode % log_interval == 0:
                win_rate = win_count / log_interval
                wandb.log({'win_rate': win_rate})
                win_rate_at_each_step.append((cur_steps, win_rate))
                win_count = 0

            wandb.log({'reward': info["reward"], 'steps': cur_steps})
            wandb.log({'aver_step_rewards': info["aver_step_rewards"], 'steps': cur_steps})
            wandb.log({'total_rewards': info["total_rewards"], 'steps': cur_steps})

            total_rewards_at_each_step.append((cur_steps, info["total_rewards"]))
            rewards_at_each_step.append((cur_steps, info["reward"]))

            # All snapshot files below are rewritten in full every episode.
            np.save('win_rate_at_each_step.npy', np.array(win_rate_at_each_step))
            np.save('total_rewards_at_each_step.npy', np.array(total_rewards_at_each_step))
            np.save('rewards_at_each_step.npy', np.array(rewards_at_each_step))

            stats.append(info["reward"])
            np.save('mamba_rew', np.array(stats))  # numpy appends ".npy"
            self._save_reward_plot(stats)

            if cur_episode >= max_episodes or cur_steps >= max_steps:
                break
            if cur_episode % 100 == 0:
                # NOTE(review): closes the worker's env every 100 episodes;
                # presumably DreamerWorker re-creates it on its next run() — confirm.
                self.worker.env.close()

