import numpy as np
import os

import pickle
from collections import defaultdict
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torch.nn as nn
import torch.nn.init as init
import math

from torch.optim.lr_scheduler import CosineAnnealingLR
from divmorph.config import cfg
from divmorph.algos.ppo.ppo import PPO
from divmorph.algos.ppo.envs import *
from divmorph.algos.ppo.svd_model import ActorCritic
from divmorph.envs.vec_env.running_mean_std import RunningMeanStd

from tools.train_ppo import set_cfg_options

# Seed the global RNGs at import time so later torch / numpy random draws
# (model init, np.random.choice sub-sampling) are repeatable across runs.
torch.manual_seed(0)
np.random.seed(0)

# Per-limb feature indices. The two lists are disjoint and together cover
# indices 0..51, i.e. one full per-limb feature vector (see the 52-wide
# `merged_obs` rebuilt in distill_policy).
# Indices kept as the student's proprioceptive observation.
DEFAULT_OBS_DIM = [*range(13), 30, 31, 41, 42]
# Remaining indices, filled from the per-agent context.
DEFAULT_CONTEXT_DIM = [*range(13, 30), *range(32, 41), *range(43, 52)]


class DistillationDataset(Dataset):
    """Map-style dataset over one task's distillation buffers.

    ``data`` is a dict of indexable per-sample containers. The keys 'obs',
    'act', 'act_mean', 'unimal_ids' and 'val_teacher' are required. The keys
    'hfield', 'goal' and 'obj' are optional; when absent, ``torch.zeros(1)``
    is returned in their place so every sample has the same tuple layout.
    """

    def __init__(self, data, task_name=None):
        self.task_name = task_name
        self.data = data

    def __len__(self):
        observations = self.data['obs']
        return len(observations)

    def __getitem__(self, index):
        """Return ``(obs, act, act_mean, hfield, unimal_ids, val_teacher, goal, obj)``."""
        source = self.data
        obs, act, act_mean, unimal_ids, val_teacher = [
            source[field][index]
            for field in ('obs', 'act', 'act_mean', 'unimal_ids', 'val_teacher')
        ]
        # Optional modalities fall back to a 1-element zero placeholder.
        hfield, goal, obj = (
            source[field][index] if field in source else torch.zeros(1)
            for field in ('hfield', 'goal', 'obj')
        )
        return obs, act, act_mean, hfield, unimal_ids, val_teacher, goal, obj

def lcm(a, b):
    """Return the non-negative least common multiple of integers *a* and *b*.

    By convention ``lcm(x, 0) == 0`` (the same as ``math.lcm``). Returning
    early also avoids dividing by ``gcd(0, 0) == 0``.
    """
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // math.gcd(a, b)

def lcm_list(nums):
    """Return the least common multiple of every integer in the sequence *nums*.

    MultiTaskDataLoader uses this to size an epoch so that every task's
    batch count divides it.

    Raises:
        ValueError: if *nums* is empty.
    """
    if len(nums) == 0:
        raise ValueError("lcm_list() requires at least one number")
    result = nums[0]
    for n in nums[1:]:
        result = lcm(result, n)
    return result

class MultiTaskDataLoader:
    """Iterate several per-task datasets in lock-step.

    Each step yields a dict mapping every task name to one batch from that
    task's own ``DataLoader``. An epoch lasts ``lcm`` of the per-task batch
    counts, so every task's loader completes a whole number of passes. A
    loader that runs out before the epoch ends is restarted with a fresh
    ``iter()``.

    Raises:
        ValueError: if ``datasets`` is empty, or if any task's loader would
            yield no batches. In that case the epoch length would be 0 and
            training would silently do nothing.
    """

    def __init__(self, datasets, batch_size=32, shuffle=True, drop_last=False):
        if not datasets:
            raise ValueError("MultiTaskDataLoader requires at least one dataset")
        self.datasets = datasets
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.loaders = {
            task: DataLoader(ds, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
            for task, ds in datasets.items()
        }
        self.task_names = list(datasets.keys())

        self.task_batch_counts = {task: len(loader) for task, loader in self.loaders.items()}
        empty_tasks = [task for task, count in self.task_batch_counts.items() if count == 0]
        if empty_tasks:
            raise ValueError(
                f"Tasks produce no batches (empty dataset, or drop_last with "
                f"fewer than batch_size={batch_size} samples): {empty_tasks}"
            )

        self.total_batches = lcm_list(list(self.task_batch_counts.values()))
        print(f"Total batches per epoch: {self.total_batches}")

    def __len__(self):
        """Number of steps (multi-task batch dicts) per epoch."""
        return self.total_batches

    def __iter__(self):
        self.iters = {task: iter(loader) for task, loader in self.loaders.items()}
        self.step_count = 0
        return self

    def __next__(self):
        if self.step_count >= self.total_batches:
            raise StopIteration

        batch_dict = {}
        for task in self.task_names:
            try:
                batch_dict[task] = next(self.iters[task])
            except StopIteration:
                # This task's loader is exhausted mid-epoch: start a new pass.
                self.iters[task] = iter(self.loaders[task])
                batch_dict[task] = next(self.iters[task])

        self.step_count += 1
        return batch_dict

def distill_policy():
    """Distill per-task teacher data into a single ActorCritic student.

    Steps:
    1. For every folder in ``cfg.DISTILL.MULTI_SOURCE``, read the task name
       from ``task.txt``.
    2. Load each agent's ``{folder}/{agent}.pkl`` and undo the teacher's
       observation normalization using ``{folder}/obs_rms.pkl``.
    3. Sub-sample up to ``cfg.DISTILL.PER_AGENT_SAMPLE_NUM`` rows per agent.
    4. Record each agent's ``context`` and padding masks from a fresh env
       reset.
    5. Re-normalize observations with the statistics stored at
       ``cfg.MODEL.RMS_PATH``.
    6. Train for ``cfg.DISTILL.EPOCH_NUM`` epochs. Each optimizer step
       accumulates gradients from one batch of every task.

    ``[state_dict, obs_rms]`` checkpoints and optimizer states are written
    to ``cfg.OUT_DIR``.
    """
    agents = cfg.ENV.WALKERS
    envs = make_vec_envs()
    model = ActorCritic(envs.observation_space, envs.action_space).to(cfg.DEVICE)

    # Report the size of the parameters whose name contains 'mu', excluding
    # the hfield/obj encoders and gate parameters.
    count_param = []
    for name, param in model.named_parameters():
        if any(skip in name for skip in ('hfield_encoder', 'obj_encoder', 'gate')):
            continue
        if 'mu' in name:
            print(name)
            count_param.append(param)

    total_params = sum(p.numel() for p in count_param)
    print(f"Total model parameters: {total_params}")
    all_buffers = {}
    all_contexts = {}

    for task_folder in cfg.DISTILL.MULTI_SOURCE:
        with open(f"{task_folder}/task.txt", "r", encoding="utf-8") as f:
            task_name = f.read().strip()
        print(task_name)

        all_buffers[task_name] = {
            'obs': [],
            'act': [],
            'act_mean': [],
            'unimal_ids': [],
            'val_teacher': [],
        }
        if task_name in cfg.DISTILL.HFIELD_TASK:
            print(task_name, "has hfield")
            all_buffers[task_name]['hfield'] = []

        if task_name in cfg.DISTILL.GOAL_TASK:
            print(task_name, "has goal")
            all_buffers[task_name]['goal'] = []

        if task_name in cfg.DISTILL.OBJ_TASK:
            print(task_name, "has obj")
            all_buffers[task_name]['obj'] = []

        all_contexts[task_name] = defaultdict(list)

        # The teacher's normalization statistics, used to recover raw observations.
        rms_path = f"{task_folder}/obs_rms.pkl"
        with open(rms_path, "rb") as f:
            obs_rms_src = pickle.load(f)

        mean_src = obs_rms_src['proprioceptive'].mean
        var_src = obs_rms_src['proprioceptive'].var

        epsilon = 1e-8
        # Index of the agent within this task. It is stored as 'unimal_ids'
        # and used to look up the per-agent context/masks during training.
        # It only advances for agents whose data file exists, which keeps it
        # aligned with the all_contexts lists.
        i = 0
        for agent in agents:
            data_path = f'{task_folder}/{agent}.pkl'

            if not os.path.exists(data_path):
                print(data_path)
                print("not exist")
                continue
            print(agent)
            with open(data_path, 'rb') as f:
                agent_data = pickle.load(f)
            # Undo the teacher's observation normalization.
            obs_orig = (agent_data['obs'] * np.sqrt(var_src + epsilon) + mean_src).float()
            agent_data['obs'] = obs_orig

            env = make_env(cfg.ENV_NAME, 0, 0, xml_file=agent)()
            init_obs = env.reset()
            env.close()
            all_contexts[task_name]['context'].append(init_obs['context'])
            all_contexts[task_name]['obs_mask'].append(init_obs['obs_padding_mask'])
            all_contexts[task_name]['act_mask'].append(init_obs['act_padding_mask'])

            # With 6 proprioceptive obs types, keep only the DEFAULT_OBS_DIM
            # subset of each limb's features.
            if len(cfg.MODEL.PROPRIOCEPTIVE_OBS_TYPES) == 6 and agent_data['obs'].shape[-1] == 624:
                data_size = agent_data['obs'].shape[0]
                new_obs = agent_data['obs'].view(data_size, cfg.MODEL.MAX_LIMBS, -1)
                new_obs = new_obs[:, :, DEFAULT_OBS_DIM]
                agent_data['obs'] = new_obs.view(data_size, -1)

            if cfg.DISTILL.SAMPLE_STRATEGY == 'random':
                sample_index = np.random.choice(agent_data['obs'].shape[0], cfg.DISTILL.PER_AGENT_SAMPLE_NUM, replace=False)

            for key in ['obs', 'act', 'act_mean', 'val_teacher', 'hfield', 'goal', 'obj']:
                if key not in all_buffers[task_name]:
                    continue
                if cfg.DISTILL.SAMPLE_STRATEGY == 'random':
                    all_buffers[task_name][key].append(agent_data[key][sample_index])
                elif cfg.DISTILL.SAMPLE_STRATEGY == 'timestep_first':
                    # Regroups rows in blocks of 64 and transposes before
                    # truncating. Assumes rows are stored as consecutive
                    # groups of 64 (presumably parallel envs per timestep);
                    # TODO confirm.
                    data_size = agent_data[key].shape[0]
                    feat_dim = agent_data[key].shape[-1]
                    all_buffers[task_name][key].append(agent_data[key].reshape(-1, 64, feat_dim).permute(1, 0, 2).reshape(data_size, feat_dim)[:cfg.DISTILL.PER_AGENT_SAMPLE_NUM])
                elif cfg.DISTILL.SAMPLE_STRATEGY == 'env_first':
                    all_buffers[task_name][key].append(agent_data[key][:cfg.DISTILL.PER_AGENT_SAMPLE_NUM])
                else:
                    raise ValueError("Unsupported sample strategy")

            # Label every kept row with this agent's index. Use the number of
            # rows actually kept: `[:PER_AGENT_SAMPLE_NUM]` keeps fewer rows
            # when an agent has less data. A fixed-length id tensor would then
            # shift the ids of every later agent after concatenation.
            num_kept = all_buffers[task_name]['obs'][-1].shape[0]
            all_buffers[task_name]['unimal_ids'].append(torch.ones(num_kept, dtype=torch.long) * i)
            i += 1
        for ctx_key in ('context', 'obs_mask', 'act_mask'):
            all_contexts[task_name][ctx_key] = torch.from_numpy(np.stack(all_contexts[task_name][ctx_key])).float().to(cfg.DEVICE)
        for key in all_buffers[task_name]:
            all_buffers[task_name][key] = torch.cat(all_buffers[task_name][key], dim=0)
            print(all_buffers[task_name][key].shape)

    if not all_buffers:
        raise ValueError("cfg.DISTILL.MULTI_SOURCE is empty; nothing to distill")
    # Only the observation width is needed here. Avoid concatenating every
    # task's observations into one large array just to read it.
    feat_dim = next(iter(all_buffers.values()))['obs'].shape[-1]
    obs_rms = {'proprioceptive': RunningMeanStd(shape=feat_dim)}

    rms = torch.load(cfg.MODEL.RMS_PATH, map_location='cpu')[1]
    obs_rms['proprioceptive'].mean = rms['proprioceptive'].mean.astype(np.float32)
    obs_rms['proprioceptive'].var = rms['proprioceptive'].var.astype(np.float32)

    # Reduce the statistics to the DEFAULT_OBS_DIM subset BEFORE normalizing.
    # In the 6-obs-type case the observations were already sliced to that
    # subset above, so full-width statistics would not broadcast against them.
    if len(cfg.MODEL.PROPRIOCEPTIVE_OBS_TYPES) == 6:
        new_mean = obs_rms['proprioceptive'].mean.reshape(cfg.MODEL.MAX_LIMBS, -1)
        obs_rms['proprioceptive'].mean = new_mean[:, DEFAULT_OBS_DIM].ravel()
        new_var = obs_rms['proprioceptive'].var.reshape(cfg.MODEL.MAX_LIMBS, -1)
        obs_rms['proprioceptive'].var = new_var[:, DEFAULT_OBS_DIM].ravel()

    for buffer in all_buffers.values():
        buffer['obs'] = np.clip(
            (buffer['obs'] - obs_rms['proprioceptive'].mean) / np.sqrt(obs_rms['proprioceptive'].var + 1e-8),
            -10.,
            10.
        )

    datasets = {}
    for task_name, buffer in all_buffers.items():
        datasets[task_name] = DistillationDataset(buffer, task_name=task_name)

    # The global batch size is split evenly across tasks.
    train_dataloader = MultiTaskDataLoader(
        datasets,
        batch_size=int(cfg.DISTILL.BATCH_SIZE / len(datasets)),
        shuffle=True,
        drop_last=False
    )

    # Group parameters by name so each group can get its own learning rate.
    v_param = []
    u_param = []
    tailor_param = []
    other_param = []
    for name, param in model.named_parameters():
        if 'other' in name:
            tailor_param.append(param)
        elif 'v_net' in name:
            v_param.append(param)
        elif 'u_net' in name:
            u_param.append(param)
        else:
            other_param.append(param)

    if cfg.DISTILL.OPTIMIZER == 'adam':
        optimizer = optim.Adam(
            [
                {"params": v_param, "lr": cfg.DISTILL.V_LR},
                {"params": u_param, "lr": cfg.DISTILL.U_LR},
                {"params": other_param, "lr": cfg.DISTILL.BASE_LR},
                {'params': tailor_param, "lr": cfg.DISTILL.TAILOR_LR},
            ],
            eps=cfg.DISTILL.EPS,
            weight_decay=cfg.DISTILL.WEIGHT_DECAY
        )

    elif cfg.DISTILL.OPTIMIZER == 'adamw':
        # NOTE(review): AdamW uses BASE_LR for every parameter; the per-group
        # V_LR/U_LR/TAILOR_LR settings are ignored here. Confirm intended.
        optimizer = optim.AdamW(
            model.parameters(),
            lr=cfg.DISTILL.BASE_LR,
            eps=cfg.DISTILL.EPS,
            weight_decay=cfg.DISTILL.WEIGHT_DECAY
        )
    else:
        raise ValueError("Unsupported Optimizer Type")

    if cfg.DISTILL.USE_COSINE_SCHEDULER:
        scheduler = CosineAnnealingLR(optimizer, T_max=cfg.DISTILL.EPOCH_NUM, eta_min=0)

    def loss_function(obs_dict, act, act_mean, val, task_name, unimal_ids):
        """Return ``(total_loss, actor_loss, critic_loss)`` for one task batch.

        ``critic_loss`` is None when ``cfg.DISTILL.USE_CRITIC`` is false.
        """
        # The action fed to the model is the teacher's sampled action or the
        # teacher's mean action, depending on IMITATION_TARGET.
        model_act = act if cfg.DISTILL.IMITATION_TARGET == 'act' else act_mean
        model_kwargs = dict(
            act=model_act.to(cfg.DEVICE),
            compute_val=bool(cfg.DISTILL.USE_CRITIC),
            unimal_ids=unimal_ids,
        )
        if cfg.MODEL.TRANSFORMER.SVD:
            # The SVD model takes the task name and returns 8 outputs.
            val_student, pi, logp, _, _, _, _, _ = model(obs_dict, task_name=task_name, **model_kwargs)
        else:
            val_student, pi, logp, _, _, _ = model(obs_dict, **model_kwargs)

        # Per-joint validity mask: 1 for real joints, 0 for padding.
        valid = 1 - obs_dict['act_padding_mask']
        if cfg.DISTILL.LOSS_TYPE == 'KL':
            # Implemented as a masked squared error between model.action_mu
            # (set by the forward pass above) and the chosen target.
            if cfg.DISTILL.KL_TARGET == 'act':
                target = act
            elif cfg.DISTILL.KL_TARGET == 'act_mean':
                target = act_mean
            else:
                raise ValueError("Unsupported KL target")
            if cfg.PPO.TANH == 'action':
                pred = torch.tanh(model.action_mu)
                target = torch.tanh(target)
            else:
                pred = model.action_mu
            sq_err = (pred - target.to(cfg.DEVICE)).square()
            if cfg.DISTILL.SAMPLE_WEIGHT:
                # Exponentially down-weight targets whose magnitude exceeds
                # LARGE_ACT_DECAY.
                threshold = cfg.DISTILL.LARGE_ACT_DECAY
                w = torch.where(target.abs() > threshold, torch.exp(threshold - target.abs()), 1.).to(cfg.DEVICE)
                sq_err = sq_err * w
            if cfg.DISTILL.BALANCED_LOSS:
                # Average per sample over its valid joints first, so agents
                # with more joints do not dominate the loss.
                loss = 0.5 * ((sq_err * valid).sum(dim=1) / valid.sum(dim=1)).mean()
            else:
                loss = 0.5 * (sq_err * valid).mean()
        elif cfg.DISTILL.LOSS_TYPE == 'logp':
            if cfg.DISTILL.BALANCED_LOSS:
                loss = -((model.limb_logp * valid).sum(dim=1, keepdim=True) / valid.sum(dim=1, keepdim=True)).mean()
            else:
                loss = -logp.mean()
        else:
            raise ValueError("Unsupported loss type")

        if cfg.DISTILL.USE_CRITIC:
            critic_loss = 0.5 * (val_student - val.to(cfg.DEVICE)).pow(2).mean()
            total_loss = loss + cfg.DISTILL.CRITIC_LOSS_WEIGHT * critic_loss
            return total_loss, loss, critic_loss

        return loss, loss, None

    def save_checkpoint(epoch):
        """Write ``[state_dict, obs_rms]`` and the optimizer state for *epoch*."""
        # Move to CPU so the saved state_dict holds CPU tensors.
        model.to('cpu')
        torch.save([model.state_dict(), obs_rms], f'{cfg.OUT_DIR}/checkpoint_{epoch}.pt')
        torch.save(optimizer.state_dict(), f'{cfg.OUT_DIR}/optimizer_{epoch}.pt')
        model.to(cfg.DEVICE)

    def append_log(line):
        """Print *line* and append it to ``{OUT_DIR}/log.txt``."""
        print(line)
        with open(cfg.OUT_DIR + "/log.txt", "a", encoding="utf-8") as f:
            f.write(line + "\n")

    loss_curve, in_domain_valid_curve, out_domain_valid_curve = [], [], []
    for i in range(cfg.DISTILL.EPOCH_NUM):
        # Save every SAVE_FREQ epochs, plus every 5 epochs during the first 50.
        if i % cfg.DISTILL.SAVE_FREQ == 0 or (i <= 50 and i % 5 == 0):
            save_checkpoint(i)

        batch_losses = []
        batch_losses_mu = []
        batch_losses_v = []
        for batch_dict in train_dataloader:
            optimizer.zero_grad()
            # Gradients from every task's batch accumulate into one step.
            for task, (obs, train_act, train_act_mean, hfield, unimal_ids, train_val, goal, obj) in batch_dict.items():
                context = all_contexts[task]['context'][unimal_ids]
                obs_mask = all_contexts[task]['obs_mask'][unimal_ids]
                act_mask = all_contexts[task]['act_mask'][unimal_ids]
                obs = obs.to(cfg.DEVICE)

                if cfg.DISTILL.CONCAT_CONTEXT_TO_OBS:
                    # Rebuild the full 52-wide per-limb vector: proprioceptive
                    # dims from obs, remaining dims from the agent's context.
                    batch_size = obs.shape[0]
                    merged_obs = torch.zeros(batch_size, cfg.MODEL.MAX_LIMBS, 52, device=cfg.DEVICE)
                    merged_obs[:, :, DEFAULT_OBS_DIM] = obs.reshape(batch_size, cfg.MODEL.MAX_LIMBS, -1)
                    merged_obs[:, :, DEFAULT_CONTEXT_DIM] = context.reshape(batch_size, cfg.MODEL.MAX_LIMBS, -1)
                    obs = merged_obs.reshape(batch_size, -1)

                train_obs_dict = {
                    'proprioceptive': obs,
                    'context': context,
                    'obs_padding_mask': obs_mask,
                    'act_padding_mask': act_mask,
                }
                if task in cfg.DISTILL.HFIELD_TASK:
                    train_obs_dict['hfield'] = hfield.to(cfg.DEVICE)
                if task in cfg.DISTILL.GOAL_TASK:
                    train_obs_dict['goal'] = goal.to(cfg.DEVICE)
                if task in cfg.DISTILL.OBJ_TASK:
                    train_obs_dict['obj'] = obj.to(cfg.DEVICE)

                loss, mu_loss, v_loss = loss_function(train_obs_dict, train_act, train_act_mean, train_val, task, unimal_ids)
                batch_losses.append(loss.item())
                batch_losses_mu.append(mu_loss.item())
                if cfg.DISTILL.USE_CRITIC:
                    batch_losses_v.append(v_loss.item())
                loss.backward()

            if cfg.DISTILL.GRAD_NORM is not None:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.DISTILL.GRAD_NORM)

            optimizer.step()
        if cfg.DISTILL.USE_COSINE_SCHEDULER:
            print(f"Epoch {i}, LR:", scheduler.get_last_lr())
            scheduler.step()

        log_str = f'Epoch {i}, Average Batch Loss: {np.mean(batch_losses)}, mu_Loss: {np.mean(batch_losses_mu)}'
        if cfg.DISTILL.USE_CRITIC:
            log_str += f', v_loss: {np.mean(batch_losses_v)}'
        append_log(log_str)

        params_norm = torch.norm(torch.cat([p.view(-1) for p in model.parameters()]), 2).item()
        append_log(f'model norm: {params_norm}')
        loss_curve.append(np.mean(batch_losses))
        with open(f'{cfg.OUT_DIR}/loss_curve.pkl', 'wb') as f:
            pickle.dump([loss_curve, in_domain_valid_curve, out_domain_valid_curve], f)

    model.to('cpu')
    torch.save([model.state_dict(), obs_rms], f'{cfg.OUT_DIR}/checkpoint_{cfg.DISTILL.EPOCH_NUM}.pt')
