package io.quarkiverse.mcp.server.runtime;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import jakarta.enterprise.inject.Any;
import jakarta.enterprise.inject.Instance;
import jakarta.inject.Singleton;

import org.jboss.logging.Logger;

import com.fasterxml.jackson.databind.ObjectMapper;

import io.quarkiverse.mcp.server.FilterContext;
import io.quarkiverse.mcp.server.Icon;
import io.quarkiverse.mcp.server.IconsProvider;
import io.quarkiverse.mcp.server.JsonRpcErrorCodes;
import io.quarkiverse.mcp.server.McpException;
import io.quarkiverse.mcp.server.McpLog;
import io.quarkiverse.mcp.server.McpMethod;
import io.quarkiverse.mcp.server.MetaKey;
import io.quarkiverse.mcp.server.PromptFilter;
import io.quarkiverse.mcp.server.PromptManager;
import io.quarkiverse.mcp.server.PromptManager.PromptInfo;
import io.quarkiverse.mcp.server.PromptResponse;
import io.quarkus.arc.All;
import io.quarkus.security.identity.CurrentIdentityAssociation;
import io.smallrye.mutiny.Uni;
import io.vertx.core.Vertx;
import io.vertx.core.json.Json;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;

/**
 * Default {@link PromptManager} implementation.
 * <p>
 * Every known prompt is stored by name. Prompts declared by annotated methods are collected from
 * {@link McpMetadata} when the bean is constructed and wrapped in {@link PromptMethod}. Prompts added at runtime via
 * {@link #newPrompt(String)} are stored as {@link PromptDefinitionInfo}; only these programmatic prompts can be
 * removed again.
 * <p>
 * All {@link PromptFilter} beans are consulted when filtering prompt streams and when resolving an invoker for
 * {@link McpMethod#PROMPTS_GET}.
 */
@Singleton
public class PromptManagerImpl extends FeatureManagerBase<PromptResponse, PromptInfo> implements PromptManager {

    private static final Logger LOG = Logger.getLogger(PromptManagerImpl.class);

    // Prompt name -> prompt; holds both method-backed and programmatically registered prompts
    final ConcurrentMap<String, PromptInfo> prompts;

    // All PromptFilter beans, evaluated in list order by test(...)
    final List<PromptFilter> filters;

    final Instance<IconsProvider> iconsProviders;

    PromptManagerImpl(McpMetadata metadata,
            Vertx vertx,
            ObjectMapper mapper,
            ConnectionManager connectionManager,
            Instance<CurrentIdentityAssociation> currentIdentityAssociation,
            ResponseHandlers responseHandlers,
            @All List<PromptFilter> filters,
            @Any Instance<IconsProvider> iconsProviders) {
        super(vertx, mapper, connectionManager, currentIdentityAssociation, responseHandlers);
        this.prompts = new ConcurrentHashMap<>();
        for (FeatureMetadata<PromptResponse> f : metadata.prompts()) {
            this.prompts.put(f.info().name(), new PromptMethod(f, iconsProviders));
        }
        this.filters = filters;
        this.iconsProviders = iconsProviders;
    }

    /**
     * @return all known prompts, without any filtering applied
     */
    @Override
    Stream<PromptInfo> infos() {
        return prompts.values().stream();
    }

    /**
     * Keeps only the prompts accepted by every {@link PromptFilter}.
     */
    @Override
    Stream<PromptInfo> filter(Stream<PromptInfo> infos, FilterContext filterContext) {
        return infos.filter(p -> test(p, filterContext));
    }

    /**
     * @return the MCP method used to list prompts
     */
    @Override
    protected McpMethod mcpListMethod() {
        // This is the *list* method; PROMPTS_GET is the method used to retrieve a single prompt
        return McpMethod.PROMPTS_LIST;
    }

    /**
     * @param name the prompt name, must not be {@code null}
     * @return the prompt with the given name, or {@code null} if no such prompt exists
     * @throws NullPointerException if {@code name} is {@code null}
     */
    @Override
    public PromptInfo getPrompt(String name) {
        return prompts.get(Objects.requireNonNull(name));
    }

    /**
     * Starts the definition of a new programmatic prompt.
     * <p>
     * The uniqueness check here is only a fast fail; {@link PromptDefinitionImpl#register()} repeats it atomically.
     *
     * @throws IllegalArgumentException if a prompt with the given name already exists
     */
    @Override
    public PromptDefinition newPrompt(String name) {
        if (prompts.containsKey(name)) {
            throw promptWithNameAlreadyExists(name);
        }
        return new PromptDefinitionImpl(name);
    }

    /**
     * Removes a programmatically registered prompt.
     * <p>
     * Method-backed prompts ({@link PromptInfo#isMethod()}) are never removed. Connected clients are notified only
     * if a prompt was actually removed.
     *
     * @return the removed prompt, or {@code null} if no removable prompt with the given name exists
     */
    @Override
    public PromptInfo removePrompt(String name) {
        AtomicReference<PromptInfo> removed = new AtomicReference<>();
        prompts.computeIfPresent(name, (key, value) -> {
            if (value.isMethod()) {
                // Prompts declared by annotated methods cannot be removed at runtime
                return value;
            }
            removed.set(value);
            // Returning null removes the mapping
            return null;
        });
        PromptInfo ret = removed.get();
        if (ret != null) {
            // Notify outside the remapping function: keeps the map lock short and guarantees
            // the prompt is already gone when clients re-list
            notifyConnections(McpMethod.NOTIFICATIONS_PROMPTS_LIST_CHANGED);
        }
        return ret;
    }

    IllegalArgumentException promptWithNameAlreadyExists(String name) {
        return new IllegalArgumentException("A prompt with name [" + name + "] already exists");
    }

    @Override
    protected McpException notFound(String id) {
        return new McpException("Invalid prompt name: " + id, JsonRpcErrorCodes.INVALID_PARAMS);
    }

    /**
     * Resolves the invoker for the prompt with the given name.
     *
     * @return the invoker, or {@code null} if the prompt does not exist, is not invocable, does not belong to the
     *         server of the request, or is rejected by a filter
     */
    @SuppressWarnings("unchecked")
    @Override
    protected FeatureInvoker<PromptResponse> getInvoker(String id, McpRequest mcpRequest, JsonObject message) {
        PromptInfo prompt = prompts.get(id);
        if (prompt instanceof FeatureInvoker fi
                && matchesServer(prompt, mcpRequest)
                && test(prompt, FilterContextImpl.of(McpMethod.PROMPTS_GET, message, mcpRequest))) {
            return fi;
        }
        return null;
    }

    /**
     * Evaluates all filters in order.
     * <p>
     * Returns {@code false} on the first rejection. A filter that throws a {@link RuntimeException} is logged and
     * otherwise ignored, which is a deliberate best-effort behaviour.
     */
    private boolean test(PromptInfo prompt, FilterContext filterContext) {
        for (PromptFilter filter : filters) {
            try {
                if (!filter.test(prompt, filterContext)) {
                    return false;
                }
            } catch (RuntimeException e) {
                LOG.errorf(e, "Unable to apply filter: %s", filter);
            }
        }
        return true;
    }

    /**
     * A prompt backed by an annotated method, described by build-time {@link FeatureMetadata}.
     */
    class PromptMethod extends FeatureMetadataInvoker<PromptResponse> implements PromptManager.PromptInfo {

        private PromptMethod(FeatureMetadata<PromptResponse> metadata, Instance<IconsProvider> iconsProviders) {
            super(metadata, iconsProviders);
        }

        @Override
        public String name() {
            return metadata.info().name();
        }

        @Override
        public String title() {
            return metadata.info().title();
        }

        @Override
        public String description() {
            return metadata.info().description();
        }

        @Override
        public String serverName() {
            return metadata.info().serverName();
        }

        /**
         * Decodes the stored metadata values with {@link Json#decodeValue(String)} on every call.
         */
        @Override
        public Map<MetaKey, Object> metadata() {
            return metadata.info().metadata().entrySet()
                    .stream()
                    .collect(Collectors.toUnmodifiableMap(e -> MetaKey.from(e.getKey()), e -> Json.decodeValue(e.getValue())));
        }

        @Override
        public boolean isMethod() {
            return true;
        }

        @Override
        public List<PromptArgument> arguments() {
            return metadata.info().serializedArguments().stream()
                    .map(a -> new PromptArgument(a.name(), a.title(), a.description(), a.required(), a.defaultValue()))
                    .toList();
        }

        /**
         * Serializes the metadata and, if an icons provider is present, adds the icons it returns. A failure of the
         * provider is logged and the icons are omitted.
         */
        @Override
        public JsonObject asJson() {
            JsonObject prompt = metadata.asJson();
            if (iconsProvider != null) {
                try {
                    List<Icon> icons = iconsProvider.get(this);
                    prompt.put("icons", icons);
                } catch (Exception e) {
                    LOG.errorf(e, "Unable to get icons for %s", name());
                }
            }
            return prompt;
        }

    }

    /**
     * Builder for a programmatic prompt; {@link #register()} adds it to the manager.
     */
    class PromptDefinitionImpl
            extends FeatureManagerBase.FeatureDefinitionBase<PromptInfo, PromptArguments, PromptResponse, PromptDefinitionImpl>
            implements PromptManager.PromptDefinition {

        private String title;
        private final List<PromptArgument> arguments;
        private Map<MetaKey, Object> metadata = Map.of();

        PromptDefinitionImpl(String name) {
            super(name);
            this.arguments = new ArrayList<>();
        }

        @Override
        public PromptDefinition setTitle(String title) {
            this.title = title;
            return this;
        }

        @Override
        public PromptDefinition addArgument(String name, String title, String description, boolean required,
                String defaultValue) {
            arguments.add(new PromptArgument(name, title, description, required, defaultValue));
            return this;
        }

        /**
         * @throws NullPointerException if {@code metadata} is {@code null} (it would otherwise fail later in
         *         {@link #register()})
         */
        @Override
        public PromptDefinition setMetadata(Map<MetaKey, Object> metadata) {
            this.metadata = Objects.requireNonNull(metadata, "metadata");
            return this;
        }

        /**
         * Validates the definition, stores it atomically and notifies connected clients.
         *
         * @throws IllegalArgumentException if a prompt with the same name was registered in the meantime
         */
        @Override
        public PromptInfo register() {
            validate();
            PromptDefinitionInfo ret = new PromptDefinitionInfo(name, title, description, serverName, fun, asyncFun,
                    runOnVirtualThread, arguments, metadata, icons);
            PromptInfo existing = prompts.putIfAbsent(name, ret);
            if (existing != null) {
                throw promptWithNameAlreadyExists(name);
            }
            notifyConnections(McpMethod.NOTIFICATIONS_PROMPTS_LIST_CHANGED);
            return ret;
        }
    }

    /**
     * An immutable, registered programmatic prompt. The arguments and metadata are defensively copied.
     */
    class PromptDefinitionInfo extends FeatureManagerBase.FeatureDefinitionInfoBase<PromptArguments, PromptResponse>
            implements PromptManager.PromptInfo {

        private final String title;
        private final List<PromptArgument> arguments;
        private final Map<MetaKey, Object> metadata;

        private PromptDefinitionInfo(String name, String title, String description, String serverName,
                Function<PromptArguments, PromptResponse> fun,
                Function<PromptArguments, Uni<PromptResponse>> asyncFun, boolean runOnVirtualThread,
                List<PromptArgument> arguments, Map<MetaKey, Object> metadata, List<Icon> icons) {
            super(name, description, serverName, fun, asyncFun, runOnVirtualThread, icons);
            this.title = title;
            this.arguments = List.copyOf(arguments);
            this.metadata = Map.copyOf(metadata);
        }

        @Override
        public String title() {
            return title;
        }

        @Override
        public List<PromptArgument> arguments() {
            return arguments;
        }

        @Override
        public Map<MetaKey, Object> metadata() {
            return metadata;
        }

        /**
         * Converts the provided argument values to strings and fills in declared default values for missing
         * arguments.
         * <p>
         * {@code null} values are skipped, so a declared default (if any) applies. They could not be carried anyway,
         * because {@link PromptArgumentsImpl} copies the map with {@link Map#copyOf(Map)}, which rejects nulls.
         */
        @Override
        protected PromptArguments createArguments(ArgumentProviders argumentProviders) {
            Map<String, String> args = new HashMap<>();
            for (Entry<String, Object> e : argumentProviders.args().entrySet()) {
                Object value = e.getValue();
                if (value != null) {
                    args.put(e.getKey(), value.toString());
                }
            }
            for (PromptArgument a : arguments) {
                if (a.defaultValue() != null) {
                    // args holds no null values, so putIfAbsent only fills genuinely missing arguments
                    args.putIfAbsent(a.name(), a.defaultValue());
                }
            }
            // Locale.ROOT: the logger name must not depend on the default locale
            return new PromptArgumentsImpl(argumentProviders, args,
                    log(Feature.PROMPT.toString().toLowerCase(Locale.ROOT) + ":" + name, name, argumentProviders));
        }

        /**
         * Serializes the prompt, with its arguments, optional title, {@code _meta} and icons.
         */
        @Override
        public JsonObject asJson() {
            JsonObject prompt = new JsonObject()
                    .put("name", name())
                    .put("description", description());
            if (title != null) {
                prompt.put("title", title);
            }
            JsonArray argsJson = new JsonArray();
            for (PromptArgument a : arguments) {
                JsonObject arg = new JsonObject()
                        .put("name", a.name())
                        .put("description", a.description())
                        .put("required", a.required());
                if (a.title() != null) {
                    arg.put("title", a.title());
                }
                argsJson.add(arg);
            }
            prompt.put("arguments", argsJson);
            if (!metadata.isEmpty()) {
                JsonObject meta = new JsonObject();
                for (Map.Entry<MetaKey, Object> e : metadata.entrySet()) {
                    meta.put(e.getKey().toString(), e.getValue());
                }
                prompt.put("_meta", meta);
            }
            if (icons != null) {
                prompt.put("icons", icons);
            }
            return prompt;
        }

    }

    /**
     * Immutable prompt arguments passed to a programmatic prompt function.
     */
    static class PromptArgumentsImpl extends AbstractRequestFeatureArguments implements PromptArguments {

        private final Map<String, String> args;
        private final McpLog log;

        PromptArgumentsImpl(ArgumentProviders argProviders, Map<String, String> args, McpLog log) {
            super(argProviders);
            this.args = Map.copyOf(args);
            this.log = log;
        }

        @Override
        public McpLog log() {
            return log;
        }

        @Override
        public Map<String, String> args() {
            return args;
        }

    }

}
