import os
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torchvision
import numpy as np
from tqdm import tqdm
from arguments import get_args
from augmentations import get_aug
from models import get_model_ori
from tools import AverageMeter, knn_monitor, Logger, file_exist_check
from datasets import get_dataset
from datetime import datetime
from utils.loggers import *
from utils.metrics import mask_classes
from utils.loggers import CsvLogger
from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from typing import Tuple
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
import re

def soft_update_params(net, target_net, tau):
    """Move ``target_net``'s parameters a fraction ``tau`` towards ``net``'s.

    Each target parameter is overwritten in place with
    ``tau * source + (1 - tau) * target``. Parameters are paired positionally
    via ``zip``, so pairing stops at the shorter of the two parameter lists.
    ``net`` itself is left untouched.
    """
    keep = 1 - tau
    for source, target in zip(net.parameters(), target_net.parameters()):
        blended = tau * source.data + keep * target.data
        target.data.copy_(blended)

def schedule(schdl, step):
    """Evaluate a schedule specification at ``step``.

    ``schdl`` may be:

    * anything ``float()`` accepts -- a constant schedule, returned as a float;
    * ``'linear(init,final,duration)'`` -- linear interpolation from ``init``
      to ``final`` over ``duration`` steps, held at ``final`` afterwards;
    * ``'step_linear(init,final1,duration1,final2,duration2)'`` -- ``init`` to
      ``final1`` during the first ``duration1`` steps, then ``final1`` to
      ``final2`` over the following ``duration2`` steps.

    Raises ``NotImplementedError`` when a string matches none of these forms.
    """
    def lerp(start, end, progress):
        # Progress is clipped to [0, 1] so the schedule saturates at both ends.
        weight = np.clip(progress, 0.0, 1.0)
        return (1.0 - weight) * start + weight * end

    try:
        return float(schdl)
    except ValueError:
        linear = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
        if linear:
            init, final, duration = map(float, linear.groups())
            return lerp(init, final, step / duration)
        piecewise = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl)
        if piecewise:
            init, final1, duration1, final2, duration2 = map(
                float, piecewise.groups())
            if step <= duration1:
                return lerp(init, final1, step / duration1)
            else:
                return lerp(final1, final2, (step - duration1) / duration2)
    raise NotImplementedError(schdl)

class TruncatedNormal(pyd.Normal):
    """Normal distribution whose samples are clamped into ``[low + eps, high - eps]``.

    Only ``sample`` is specialised; every other method is inherited unchanged
    from ``torch.distributions.Normal``.
    """

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        # Argument validation of the base class is explicitly switched off.
        super().__init__(loc, scale, validate_args=False)
        self.low = low
        self.high = high
        self.eps = eps

    def _clamp(self, x):
        """Clamp ``x`` in value while leaving its gradient as the identity.

        The returned tensor equals the clamped value, but because the clamped
        term is detached, back-propagation sees ``d(out)/d(x) == 1``.
        """
        clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
        x = x - x.detach() + clamped_x.detach()
        return x

    def sample(self, clip=None, sample_shape=torch.Size()):
        """Draw ``loc + noise`` and clamp the result into the support.

        :param clip: when given, the scaled noise is first clamped to
            ``[-clip, clip]``. ``None`` (the default) and the legacy string
            ``'None'`` both disable this step.
        :param sample_shape: extra leading sample dimensions.
        :return: tensor of shape ``sample_shape + batch_shape``.
        """
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape,
                               dtype=self.loc.dtype,
                               device=self.loc.device)
        eps *= self.scale
        # The previous check only compared against the string 'None', so the
        # default ``clip=None`` fell through to ``-clip`` and raised TypeError.
        clip_disabled = clip is None or (isinstance(clip, str) and clip == 'None')
        if not clip_disabled:
            eps = torch.clamp(eps, -clip, clip)
        x = self.loc + eps
        return self._clamp(x)

class Actor(nn.Module):
    """MLP that maps a representation vector to ``action_shape[0]`` outputs.

    ``trunk`` projects the ``repr_dim`` input to ``feature_dim`` (linear,
    layer norm, tanh); ``policy`` is a three-layer ReLU MLP on top of it.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()
        n_outputs = action_shape[0]

        trunk_layers = [
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        ]
        self.trunk = nn.Sequential(*trunk_layers)

        policy_layers = [
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, n_outputs),
        ]
        self.policy = nn.Sequential(*policy_layers)

    def forward(self, obs, std):
        """Return the raw policy outputs for ``obs``.

        ``std`` is accepted for interface compatibility but is not used.
        """
        features = self.trunk(obs)
        return self.policy(features)


class Critic(nn.Module):
    """Twin Q-network over (representation, action) pairs.

    A shared ``trunk`` (linear, layer norm, tanh) embeds the observation; the
    embedding is concatenated with the action and fed to two independent
    three-layer ReLU heads, ``Q1`` and ``Q2``, each producing one value.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()
        head_in_dim = feature_dim + action_shape[0]

        def build_q_head():
            # Both heads share this architecture but own separate weights.
            return nn.Sequential(
                nn.Linear(head_in_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, 1),
            )

        self.trunk = nn.Sequential(
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        )
        self.Q1 = build_q_head()
        self.Q2 = build_q_head()

    def forward(self, obs, action):
        """Return the pair ``(q1, q2)`` for ``obs`` and ``action``."""
        features = self.trunk(obs)
        head_input = torch.cat([features, action], dim=-1)
        return self.Q1(head_input), self.Q2(head_input)


def evaluate(model: ContinualModel, dataset: ContinualDataset, device, classifier=None) -> Tuple[list, list]:
    """
    Evaluates the accuracy of the model on the test loader of each task.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if evaluation
    raises. The forward passes run under ``torch.no_grad()``.

    :param model: the model to be evaluated
    :param dataset: the continual dataset at hand; its ``test_loaders`` are
                    iterated in order and its ``SETTING`` is inspected
    :param device: device the inputs and labels are moved to
    :param classifier: optional callable applied to the model outputs before
                       the prediction is taken
    :return: a tuple of two lists with one percentage per test loader: the
             accuracy on the raw outputs, and the accuracy after
             ``mask_classes`` has been applied (stays 0 unless
             ``dataset.SETTING == 'class-il'``)
    """
    status = model.training
    model.eval()
    accs, accs_mask_classes = [], []
    try:
        # Inference only: no autograd graph is needed for accuracy counting.
        with torch.no_grad():
            for k, test_loader in enumerate(dataset.test_loaders):
                correct, correct_mask_classes, total = 0.0, 0.0, 0.0
                for inputs, labels in test_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    if classifier is not None:
                        outputs = classifier(outputs)

                    _, pred = torch.max(outputs.data, 1)
                    correct += torch.sum(pred == labels).item()
                    total += labels.shape[0]

                    if dataset.SETTING == 'class-il':
                        # The return value is ignored, so mask_classes is
                        # presumably modifying `outputs` in place -- TODO confirm.
                        mask_classes(outputs, dataset, k)
                        _, pred = torch.max(outputs.data, 1)
                        correct_mask_classes += torch.sum(pred == labels).item()

                accs.append(correct / total * 100)
                accs_mask_classes.append(correct_mask_classes / total * 100)
    finally:
        # Restore whichever mode the caller had the model in.
        model.train(status)
    return accs, accs_mask_classes

def penalty(model, big_omega, checkpoint, device):
    """Return the importance-weighted squared drift of the backbone parameters.

    Computes ``sum(big_omega * (params - checkpoint) ** 2)`` where ``params``
    comes from ``model.net.module.backbone.get_params()``. While ``big_omega``
    is still ``None`` a zero scalar tensor on ``device`` is returned instead.
    """
    if big_omega is not None:
        drift = model.net.module.backbone.get_params() - checkpoint
        weighted_sq_drift = big_omega * drift ** 2
        return weighted_sq_drift.sum()
    return torch.tensor(0.0).to(device)

def end_task(model, big_omega, small_omega, checkpoint, device, xi, dataset):
    """Fold the finished task's importance estimate into ``big_omega``.

    The per-parameter contribution is
    ``small_omega / ((params - checkpoint) ** 2 + xi)`` where ``params`` comes
    from ``model.net.module.backbone.get_params()``. It is *added* to the
    running ``big_omega`` (zero-initialised on the first call). Previously the
    zero-initialised value was immediately overwritten, so contributions of
    earlier tasks were dropped.

    :param model: object exposing ``net.module.backbone.get_params()``
    :param big_omega: accumulated importance so far, or ``None`` on first use;
                      the tensor passed in is not modified in place
    :param small_omega: importance accumulated during the task just finished
    :param checkpoint: parameter values recorded at the start of that task
    :param device: device for the returned tensors
    :param xi: constant added to the denominator (keeps it away from zero
               when a parameter did not move)
    :param dataset: unused; kept for interface compatibility
    :return: ``(big_omega, small_omega, checkpoint)`` with ``small_omega``
             reset to ``0`` and ``checkpoint`` set to a copy of the current
             parameters
    """
    params = model.net.module.backbone.get_params().data

    # big omega calculation step
    if big_omega is None:
        big_omega = torch.zeros_like(params).to(device)

    # Out-of-place add so the caller's tensor is left untouched.
    big_omega = big_omega + small_omega / ((params - checkpoint) ** 2 + xi)

    checkpoint = params.clone().to(device)
    small_omega = 0
    return big_omega, small_omega, checkpoint

def D(p, z, version='simplified'):
    """Return the mean negative cosine similarity between ``p`` and ``z``.

    ``z`` is detached in both variants, so gradients only reach ``p``.

    :param p: first batch of vectors
    :param z: second batch of vectors, treated as a constant target
    :param version: ``'original'`` normalises both inputs along ``dim=1`` and
        takes their dot product; ``'simplified'`` calls
        ``F.cosine_similarity`` along the last dimension
    :return: scalar tensor
    :raises ValueError: if ``version`` is neither of the two names above
    """
    if version == 'original':
        z = z.detach()  # stop gradient
        p = F.normalize(p, dim=1)  # l2-normalize
        z = F.normalize(z, dim=1)  # l2-normalize
        return -(p * z).sum(dim=1).mean()

    if version == 'simplified':
        return -F.cosine_similarity(p, z.detach(), dim=-1).mean()

    # ValueError is still an Exception, so existing broad handlers keep working.
    raise ValueError(
        f"unknown version {version!r}; expected 'original' or 'simplified'")

def _save_checkpoint(net, path, epoch, log_dir, record_path):
    """Save ``net``'s weights to ``path`` and optionally record that path.

    The checkpoint stores ``epoch + 1`` together with ``net.state_dict()``.
    When ``record_path`` is true, ``path`` is also written to
    ``checkpoint_path.txt`` inside ``log_dir`` (overwriting earlier content).
    """
    torch.save({
        'epoch': epoch+1,
        'state_dict': net.state_dict()
    }, path)
    print(f"Task Model saved to {path}")
    if record_path:
        with open(os.path.join(log_dir, "checkpoint_path.txt"), 'w+') as f:
            f.write(f'{path}')


def main(device, args):
    """Train the continual model task by task with an actor-critic controller.

    For each of ``dataset.N_TASKS`` tasks the model is trained for
    ``args.train.stop_at_epoch`` epochs. Every training step

    1. computes the model loss on the current batch or, once the replay
       buffer is non-empty, on a convex mix of the current batch and buffer
       samples weighted by ``lam`` (picked from an evenly spaced grid between
       ``args.train.minact`` and ``args.train.maxact`` by the index ``act``);
    2. stores the batch in the buffer and steps the model optimiser;
    3. samples actions from the actor's clamped softmax scores and, every
       ``args.train.rectify_step`` batches, sets ``act`` to the most
       frequently sampled action;
    4. regresses the critic towards ``reward + discount * target_V`` where
       the reward is ``-args.train.reward_lam * next_loss``;
    5. soft-updates the target critic.

    A checkpoint is written after every task; during the last task one is
    also written whenever the kNN accuracy matches or beats the best so far.

    :param device: device the networks and batches are moved to
    :param args: parsed run configuration; ``args.eval_from`` is set to the
                 last checkpoint path when ``args.eval`` is not ``False``
                 and ``args.cl_default`` is ``False``
    """

    dataset = get_dataset(args)
    # A second dataset instance is used only to obtain a train loader whose
    # length is handed to the model factory.
    dataset_copy = get_dataset(args)
    train_loader, memory_loader, test_loader = dataset_copy.get_data_loaders(args)

    # define model
    model = get_model_ori(args, device, len(train_loader), dataset.get_transform(args))

    # Actor Critic hyper
    repr_dim = 512
    feature_dim = 1024
    hidden_dim = 512
    action_shape = [args.train.action]
    actor = Actor(repr_dim, action_shape, feature_dim,
                        hidden_dim).to(device)
    critic = Critic(repr_dim, action_shape, feature_dim,
                            hidden_dim).to(device)
    critic_target = Critic(repr_dim, action_shape, feature_dim,
                            hidden_dim).to(device)
    critic_target.load_state_dict(critic.state_dict())
    actor_opt = torch.optim.Adam(actor.parameters(), lr=args.train.ac_lr)
    critic_opt = torch.optim.Adam(critic.parameters(), lr=args.train.ac_lr)
    logger = Logger(matplotlib=args.logger.matplotlib, log_dir=args.log_dir)

    # Candidate mixing coefficients, one per discrete action (loop-invariant).
    lam_choices = np.linspace(args.train.minact, args.train.maxact, action_shape[0])

    global_step = 0
    soft = nn.Softmax(dim=-1)
    for t in range(dataset.N_TASKS):
        train_loader, memory_loader, test_loader = dataset.get_data_loaders(args)
        global_progress = tqdm(range(0, args.train.stop_at_epoch), desc=f'Training')

        act = 1
        max_acc = 0
        # Initialised so the per-epoch logging and best-model check below do
        # not hit an unbound name when kNN monitoring is disabled.
        mean_acc = 0.0
        for epoch in global_progress:
            model.train()
            results, results_mask_classes = [], []

            local_progress=tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}', disable=args.hide_progress)
            for idx, ((images1, images2, notaug_images), labels) in enumerate(local_progress):
                global_step += 1
                outputs = model.net.module.backbone(images1.to(device))
                if model.buffer.is_empty():
                    data_dict = model.net.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True))
                else:
                    buf_inputs, buf_logits, buf_inputs1, buf_inputs2 = model.buffer.get_data(
                        args.train.batch_size, transform=model.transform)
                    lam = lam_choices[act]
                    # Buffer batches are truncated to the current batch size before mixing.
                    mixed_x = lam * images1.to(device) + (1 - lam) * buf_inputs[:images1.shape[0]].to(device)
                    mixed_x_aug = lam * images2.to(device) + (1 - lam) * buf_inputs1[:images1.shape[0]].to(device)
                    data_dict = model.net.forward(mixed_x.to(device, non_blocking=True), mixed_x_aug.to(device, non_blocking=True))
                loss = data_dict['loss'].mean()
                data_dict['loss'] = loss

                # add buffer
                model.buffer.add_data(examples=notaug_images, logits=outputs.data, inputs1=images1,inputs2=images2)
                model.opt.zero_grad()
                loss.backward()
                model.opt.step()

                # Update Actor
                stddev_clip = args.train.stddev_clip
                discount = args.train.discount
                # A constant schedule: schedule(1, ...) always returns 1.0, and
                # Actor.forward does not use the value.
                stddev = schedule(1, global_step)
                action = actor(outputs.detach(), stddev)
                score = soft(action)
                score = torch.clamp(score, 0.08, 0.12)
                actions = torch.multinomial(score, 1)

                # NOTE(review): the counter is reset on the same iteration on
                # which `act` is chosen, so `act` reflects only the current
                # batch's samples; counts accumulated on the iterations in
                # between are discarded. Confirm this is intended.
                rectify = idx % args.train.rectify_step == 0
                if rectify:
                    count_action = torch.zeros(action_shape[0])
                # Histogram of sampled action indices (replaces a per-sample Python loop).
                count_action += torch.bincount(
                    actions.flatten().cpu(), minlength=action_shape[0]).to(count_action.dtype)
                if rectify:
                    act = int(torch.argmax(count_action))
                # NOTE(review): both critic inputs are detached, so actor_loss
                # has no gradient path to the actor's parameters and
                # actor_opt.step() has nothing to apply. The backward pass only
                # fills critic gradients, which are cleared again by
                # critic_opt.zero_grad() below. Left unchanged because fixing
                # it alters the training dynamics.
                Q1, Q2 = critic(outputs.detach(), score.detach())
                Q = torch.min(Q1, Q2)
                actor_loss = -Q.mean()
                data_dict['actor_loss'] = actor_loss
                actor_opt.zero_grad(set_to_none=True)
                actor_loss.backward()
                actor_opt.step()

                # Update Critic
                buf_inputs, buf_logits, buf_inputs1, buf_inputs2 = model.buffer.get_data(
                    args.train.batch_size, transform=model.transform)
                next_data_dict_buf = model.observe(buf_inputs, labels, buf_inputs1, notaug_images)
                next_data_dict = model.observe(images1, labels, images2, notaug_images)
                next_loss = next_data_dict['loss'] + 0.1 * next_data_dict_buf['loss']
                with torch.no_grad():
                    next_obs = model.net.module.backbone(images1.to(device))
                    stddev = schedule(1, global_step)
                    next_action = actor(next_obs, stddev)
                    next_action = soft(next_action)
                    target_Q1, target_Q2 = critic_target(next_obs, next_action)
                    target_V = torch.min(target_Q1, target_Q2)
                    reward = - args.train.reward_lam *next_loss
                    # done signal: the bootstrap term is dropped on batch `maxidx`
                    if idx == args.train.maxidx:
                        done_signal = 0
                    else:
                        done_signal = 1
                    target_Q = reward + (done_signal * discount * target_V)
                # NOTE(review): the critic is evaluated on the raw actor
                # outputs (`action`) here, whereas the target above and the
                # actor step use softmax scores -- confirm which is intended.
                Q1, Q2 = critic(outputs.detach(), action.detach())
                critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
                model.opt.zero_grad(set_to_none=True)
                critic_opt.zero_grad(set_to_none=True)
                next_loss.backward()
                critic_loss.backward()
                critic_opt.step()
                model.opt.step()

                # Update target Critic
                soft_update_params(critic, critic_target, 0.01)
                data_dict['act'] = act
                logger.update_scalers(data_dict)
            global_progress.set_postfix(data_dict)

            if args.train.knn_monitor and epoch % args.train.knn_interval == 0:
                for i in range(len(dataset.test_loaders)):
                    acc, acc_mask = knn_monitor(model.net.module.backbone, dataset, dataset.memory_loaders[i], dataset.test_loaders[i], device, args.cl_default, task_id=t, k=min(args.train.knn_k, len(memory_loader.dataset)))
                    results.append(acc)
                mean_acc = np.mean(results)
            # On epochs without a kNN pass the most recent mean_acc is reused.
            epoch_dict = {"epoch":epoch, "accuracy": mean_acc}
            global_progress.set_postfix(epoch_dict)
            logger.update_scalers(epoch_dict)

            # During the last task, keep the weights of the best epoch so far.
            if t == dataset.N_TASKS-1 and max_acc <= mean_acc:
                max_acc = mean_acc
                model_path = os.path.join(args.ckpt_dir, f"{args.model.cl_model}_{args.name}_{t}.pth")
                _save_checkpoint(model.net, model_path, epoch, args.log_dir, True)
        if t < dataset.N_TASKS-1:
            model_path = os.path.join(args.ckpt_dir, f"{args.model.cl_model}_{args.name}_{t}.pth")
            _save_checkpoint(model.net, model_path, epoch, args.log_dir, True)
        else:
            # Final weights go to a separate file; checkpoint_path.txt keeps
            # pointing at the best-epoch checkpoint written above.
            model_path = os.path.join(args.ckpt_dir, f"{args.model.cl_model}_{args.name}_{t}_final.pth")
            _save_checkpoint(model.net, model_path, epoch, args.log_dir, False)
    if args.eval is not False and args.cl_default is False:
        args.eval_from = model_path

if __name__ == "__main__":
    # Parse the run configuration and train.
    args = get_args()
    main(device=args.device, args=args)

    # Rename the log directory so its name reflects how the run ended:
    # 'in-progress' becomes 'debug' for debug runs and 'completed' otherwise.
    status_tag = 'debug' if args.debug else 'completed'
    completed_log_dir = args.log_dir.replace('in-progress', status_tag)
    os.rename(args.log_dir, completed_log_dir)
    print(f'Log file has been saved to {completed_log_dir}')


